python - 如何在不使用 Sequential() 的情况下在 Keras 中堆叠图层?
问题描述
如果我有一个 keras 层 L,并且我想在 keras 模型中堆叠该层的 N 个版本(具有不同的权重),那么最好的方法是什么?请注意,这里的 N 很大并且由超参数控制。如果 N 很小,那么这不是问题(我们可以手动重复一行 N 次)。所以让我们假设 N > 10 例如。
如果该层只有一个输入和一个输出,我可以执行以下操作:
m = Sequential()
for i in range(N):
m.add(L)
但如果我的层实际上需要多个输入,这将不起作用。例如,如果我的层具有 z = L(x, y) 的形式,并且我希望我的模型执行以下操作:
x_1 = L(x_0, y)
x_2 = L(x_1, y)
...
x_N = L(x_N-1, y)
然后 Sequential 就无法完成这项工作。我想我可以子类化一个 keras 模型,但我不知道将 N 层放入类中的最干净的方法是什么。我可以使用一个列表,例如:
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.layers = []
for i in range(N):
self.layers.append(L)
def call(self, inputs):
x = inputs[0]
y = inputs[1]
for i in range(N):
x = self.layers[i](x, y)
return x
但这并不理想,因为 keras 无法识别这些层(似乎没有将层列表视为“可检查点”)。例如, MyModel.variables 将为空,而 MyModel.Save() 不会保存任何内容。
我还尝试使用功能 API 定义模型,但在我的情况下它也不起作用。事实上,如果我们这样做
def MyModel():
input = Input(shape=...)
output = SomeLayer(input)
return Model(inputs=input, outputs=output)
如果 SomeLayer 本身是自定义模型,它将不会运行(它会引发 NotImplementedError)。
有什么建议么?
解决方案
不确定我的问题是否正确,但我想您可以使用 Keras 应用程序中显示的功能 API 和concatenate
/或add
层,例如ResNet50或InceptionV3来构建“非序列”网络。
更新
在我的一个项目中,我正在使用这样的东西。我有一个自定义层(它没有在我的 Keras 版本中实现,所以我只是手动将代码“反向移植”到我的笔记本中)。
class LeakyReLU(Layer):
"""Leaky version of a Rectified Linear Unit backported from newer Keras
version."""
def __init__(self, alpha=0.3, **kwargs):
super(LeakyReLU, self).__init__(**kwargs)
self.supports_masking = True
self.alpha = K.cast_to_floatx(alpha)
def call(self, inputs):
return tf.maximum(self.alpha * inputs, inputs)
def get_config(self):
config = {'alpha': float(self.alpha)}
base_config = super(LeakyReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
然后,模型:
def create_model(input_shape, output_size, alpha=0.05, reg=0.001):
inputs = Input(shape=input_shape)
x = Conv2D(16, (3, 3), padding='valid', strides=(1, 1),
kernel_regularizer=l2(reg), kernel_constraint=maxnorm(3),
activation=None)(inputs)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=alpha)(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(32, (3, 3), padding='valid', strides=(1, 1),
kernel_regularizer=l2(reg), kernel_constraint=maxnorm(3),
activation=None)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=alpha)(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(64, (3, 3), padding='valid', strides=(1, 1),
kernel_regularizer=l2(reg), kernel_constraint=maxnorm(3),
activation=None)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=alpha)(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(128, (3, 3), padding='valid', strides=(1, 1),
kernel_regularizer=l2(reg), kernel_constraint=maxnorm(3),
activation=None)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=alpha)(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(256, (3, 3), padding='valid', strides=(1, 1),
kernel_regularizer=l2(reg), kernel_constraint=maxnorm(3),
activation=None)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=alpha)(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(500, activation='relu', kernel_regularizer=l2(reg))(x)
x = Dense(500, activation='relu', kernel_regularizer=l2(reg))(x)
x = Dense(500, activation='relu', kernel_regularizer=l2(reg))(x)
x = Dense(500, activation='relu', kernel_regularizer=l2(reg))(x)
x = Dense(500, activation='relu', kernel_regularizer=l2(reg))(x)
x = Dense(500, activation='relu', kernel_regularizer=l2(reg))(x)
x = Dense(output_size, activation='linear', kernel_regularizer=l2(reg))(x)
model = Model(inputs=inputs, outputs=x)
return model
最后,自定义指标:
def root_mean_squared_error(y_true, y_pred):
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
我正在使用以下代码段来创建和编译模型:
model = create_model(input_shape=X.shape[1:], output_size=y.shape[1])
model.compile(loss=root_mean_squared_error, optimizer='adamax')
像往常一样,我使用检查点回调来保存模型。要加载模型,您需要将自定义层类和指标也传递给load_model
函数:
def load_custom_model(path):
return load_model(path, custom_objects={
'LeakyReLU': LeakyReLU,
'root_mean_squared_error': root_mean_squared_error
})
它有帮助吗?
推荐阅读
- gradle - 带有文件分隔符的依赖项在 gradle 6.0 版本中无法解析
- django - django runserver:关系“django_migrations”已经存在
- powershell - PowerShell - 解压缩大型档案 - 运行空间工厂
- python - 带优先队列的线程
- reference - 在 Latex 中用逗号替换引用中的括号
- angularjs - 安装 angularx-social-login 后出现弃用错误
- node.js - 我如何通过节点 js 在模型中使用函数查找值来获取
- ethereum - 我们可以在不同的区块链上部署相同的 ERC20 代币吗?
- c# - 无法解析符号
- php - PHPUnit Laravel 没有运行下一个测试用例