keras - 我的 ResNet 的 Stage-1 中的 BatchNorm 层连接到所有其他 BatchNorm 层。为什么?
问题描述
这里我给出了我实现的 ResNet 模型的一些截图。使用 TensorBoard 生成的图表。
是 tensorflow 在后端做的某种优化吗?
我已经使用 Keras 实现了代码。
模型中有两个块。身份块和卷积块。添加这些块的代码会导致 StackOverflow 出现问题(您的帖子主要是代码)
在 ResNet 函数 (def ResNet) 中,我使用了 BatchNormalization 并将其命名为“bnl_stg-1”,我只向其传递了一个输入 (X)。但由于某种原因,它连接到标识和卷积块中的所有 BatchNorm 层,如图所示。
这是代码:
def ResNet(input_shape, features):
'''
Implements the ResNet50 Model
[Conv2D -> BatchNorm -> ReLU -> MaxPool2D] --> [ConvBlock -> IdentityBlock * 2] --> [ConvBlock -> IdentityBlock * 3] --> [AveragePool2D -> Flatten -> Dense -> Sigmoid]
'''
X_input = Input(input_shape)
X = ZeroPadding2D((3, 3))(X_input)
# Stage 1
X = Conv2D(filters = 64,
kernel_size = (7, 7),
strides = (2, 2),
name = 'cnl_stg-1',
kernel_initializer = 'glorot_uniform')(X)
X = BatchNormalization(axis = 3,
name = 'bnl_stg-1')(X)
X = Activation('relu')(X)
X = MaxPooling2D(pool_size=(3, 3),
strides=(2, 2))(X)
# Stage 2
X = convolutional_block(X, f = 3, filters = [64, 64, 256], stage = 2, s = 1)
X = identity_block(X, 3, [64, 64, 256], stage=2, block=1)
X = identity_block(X, 3, [64, 64, 256], stage=2, block=2)
# Stage 3
X = convolutional_block(X, f = 3, filters = [128, 128, 512], stage = 3, s = 2)
X = identity_block(X, 3, [128, 128, 512], stage = 3, block = 1)
X = identity_block(X, 3, [128, 128, 512], stage = 3, block = 2)
X = identity_block(X, 3, [128, 128, 512], stage = 3, block = 3)
#Final Stage
X = AveragePooling2D(pool_size = (2, 2),
strides = (2, 2))(X)
X = Flatten()(X)
X = Dense(features, activation='sigmoid', name='fc' + str(features), kernel_initializer = 'glorot_uniform')(X)
# Create model
model = Model(inputs = X_input, outputs = X, name='ResNet')
return model
解决方案
你不应该担心它。Batch Normalization 行为在训练和学习之间发生变化,因此 Keras 添加了一个布尔变量来控制它(如果我没记错的话是 keras_learning_phase)。这就是为什么所有这些层都是连接的。你可以期待 Dropout 层的类似行为。
推荐阅读
- python - 调用随机值以绘制形状和删除形状
- sql - 时态表以一种奇怪的模式检索具有不同顺序的分区行集
- python - 类型错误:不支持的操作数
- haskell - 箭头化 FRP 中流元组与元组流
- scala - 无法解决错误:java.io.NotSerializableException: org.apache.avro.Schema$RecordSchema
- c++ - 内存地址输出而不是值
- reverse-engineering - 关于指针扫描
- python - 如何在 Python 中使用 .upper 函数格式化大写的完整句子?
- node.js - 如何在 ExpressJS 中使用 SendGrid 发送电子邮件
- amazon-ec2 - ec2 实例持续运行