python - 用不同的通道替换预训练模型的输入层?
问题描述
我想重新使用 的预训练权重MobiletNetv2
,但图像有12 个通道。我知道这需要增加重量,但这没关系,因为无论如何我都想重新训练。我找不到让它工作的方法。
import tensorflow as tf
class CNN(tf.keras.Model):
def __init__(self):
super(CNN, self).__init__()
self.input_layer = tf.keras.layers.InputLayer(input_shape=(None, 224, 224, 12))
self.base = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
include_top=False,
weights='imagenet')
_ = self.base._layers.pop(0)
self.flat1 = tf.keras.layers.Flatten()
self.dens3 = tf.keras.layers.Dense(10)
def call(self, x, **kwargs):
x = self.input_layer(x)
x = self.base(x)
x = self.flat1(x)
x = self.dens3(x)
return x
model = CNN()
model.build(input_shape=(None, 224, 224, 12))
ValueError: Input 0 is in compatible with layer mobilenetv2_1.00_224: expected shape=(None, 224, 224, 3), found shape=(None, 224, 224, 12)
我尝试像其他答案一样弹出第一层。
解决方案
可以加载两种模型,一种具有 12 通道的输入形状,另一种具有正常的 12 通道。然后,只需将 3 通道模型的权重加载到 12 通道模型中,从第 2 层或第 3 层开始。
这是执行重量转移的地方:
for i in range(3, len(self.base.layers)):
self.base.layers[i].set_weights(base_weights.layers[i].get_weights())
这是整个事情:
import tensorflow as tf
h, w, c = 224, 224, 3
class CNNModel(tf.keras.Model):
def __init__(self):
super(CNNModel, self).__init__()
self.base = tf.keras.applications.MobileNetV2(input_shape=(h, w, 12),
include_top=False,
weights=None)
base_weights = tf.keras.applications.MobileNetV2(input_shape=(h, w, c),
include_top=False,
weights='imagenet')
for i in range(3, len(self.base.layers)):
self.base.layers[i].set_weights(base_weights.layers[i].get_weights())
del base_weights
self.pool = tf.keras.layers.GlobalAveragePooling2D()
self.drop1 = tf.keras.layers.Dropout(0.25)
self.out = tf.keras.layers.Dense(1, activation='sigmoid')
def call(self, x, training=None, **kwargs):
x = self.base(x)
x = self.pool(x)
x = self.drop1(x)
x = self.out(x)
return x
model = CNNModel()
model.build(input_shape=(None, h, w, 12))
推荐阅读
- r - expand.grid 和 rep 如何协同工作以显示所有结果组合?
- javascript - JavaScript 函数未在 FormLoad 页面上加载
- python - 将列表作为元素附加到空的 numpy 数组
- python - (python) (selenium) 消息:没有这样的元素:无法找到元素:
- python - 将 Qtable 中的多条记录保存到 Pickle
- plaid - 我可以在格子中为单个银行指定特定帐户,然后只检索这些帐户的数据吗?
- xml - 如果未提供另一个元素,则 XML 验证需要一个元素 - 没有 XSD 1.1
- reactjs - 如何通过 axios 获取 json 数据以与代理做出反应?
- salt-stack - Saltstack:错误:无法从 minion 中的 saltenv 基础缓存文件 'salt://foo/bar/test.sh'
- neo4j - 无法启用 Neo4j 空间扩展