tensorflow - 在训练有素的网络 keras 中编辑图层
问题描述
我有一个训练有素的图像去马赛克模型,我想通过删除超规格层中的过滤器来使其更小。
例如,我想采用以下模型(摘录):
conv1 = Conv2D(32, self.kernel_size, activation='relu', padding='same')(chnl4_input)
conv2 = Conv2D(32, self.kernel_size, strides=(2, 2), activation='relu', padding='same')(conv1)
conv5 = Conv2D(64, self.kernel_size, activation='relu', padding='same')(conv2)
conv6 = Conv2D(64, self.kernel_size, activation='relu', padding='same')(conv5)
up1 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv1], axis=-1)
conv7 = Conv2D(64, self.kernel_size, activation='relu', padding='same')(up1)
我想将 conv5 和 conv6 层更改为:
conv1 = Conv2D(32, self.kernel_size, activation='relu', padding='same')(chnl4_input)
conv2 = Conv2D(32, self.kernel_size, strides=(2, 2), activation='relu', padding='same')(conv1)
conv5 = Conv2D(32, self.kernel_size, activation='relu', padding='same')(conv2)
conv6 = Conv2D(32, self.kernel_size, activation='relu', padding='same')(conv5)
up1 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv1], axis=-1)
conv7 = Conv2D(64, self.kernel_size, activation='relu', padding='same')(up1)
我环顾四周,但没有看到任何明显的方法来做到这一点。 我发现了这个类似问题的例子,但解决方案特别提到新层必须与旧层具有相同数量的过滤器,这对我没有好处。
如果有人知道我该如何做到这一点,我将不胜感激。
[编辑]:澄清一下,我有一个现有的模型,比如“模型 A”。我想创建一个新模型,“模型 B”。这两个模型将是相同的,除了我上面提到的层。我正在寻找一种方法来初始化新模型,其中所有层的旧模型权重都已更改,但已更改的层除外。然后将像往常一样训练新模型以收敛。
解决方案
建立一个新模型(结构完全相同,只改变过滤器的数量)并正确传递权重:
transferLayers = [0,1,2,3,4,8,9] #indices must be chosen by you properly
for layer in transferLayers:
newModel.layers[layer].set_weights(oldModel.layers[layer].get_weights())
会有一个问题conv7
,它将接收不同的输入蚂蚁,因此它的权重矩阵也有不同的大小。
如果改变模型结构怎么办
然后您可能应该创建两个索引列表,一个用于旧模型,一个用于新模型。
或者您可以重新创建旧模型,为其层添加名称:
- 重新创建完全相同的训练模型,但为每一层添加名称
- 转移重量:
namedTrainingModel.set_weights(unnamedTraininModel.get_weights())
- 然后创建更改后的模型,为未更改的图层添加相同的名称,为更改的图层添加新名称
按名称转移权重:
namedTrainingModel.save_weights('filename')
changedModel.load_weights('filename', by_name=True)
推荐阅读
- php - SwiftMailer - php 在邮件中附加 txt 文件
- javascript - yyyy-mm-dd 的正则表达式
- react-native - react-native-svg中的LinearGradient不起作用
- javascript - 几个 axios 请求后 HTTP 请求冻结几分钟
- scala - 通过来自 Scala 的 Cats 中的类型别名创建 Validation.valid
- python - 如何只打开一次python软件
- spring-boot - Tomcat 上的 Logback 外部配置
- android - VideoView Android上的黑色背景
- pandas - 在带有线图的 seaborn FacetGrid 中使用单位
- r - 根据 DAG 中的顶点名称获取“顶级”边