BatchNormalization blows up the Keras model

Problem description

I'm trying to train a model in Keras, with TensorFlow as the backend, using the following code:

from tensorflow.keras import layers, Model

CHANNEL_AXIS = 3

img_width, img_height = 513, 128
nb_classes = 10
batch_size = 64

input_shape = (img_width, img_height, 3)
inputs = layers.Input(input_shape)

tempModel = layers.Conv2D(filters=256, kernel_size=(4, 513), strides=(1, 1), padding='same')(inputs)

shortcut = tempModel

tempModel = layers.BatchNormalization(axis=CHANNEL_AXIS)(tempModel)
tempModel = layers.Activation('relu')(tempModel)
tempModel = layers.Conv2D(filters=256, kernel_size=(4, 1), strides=(1, 1), padding='same', activation=None)(tempModel)

tempModel = layers.BatchNormalization(axis=CHANNEL_AXIS)(tempModel)
tempModel = layers.Activation('relu')(tempModel)
tempModel = layers.Conv2D(filters=256, kernel_size=(4, 1), strides=(1, 1), padding='same', activation=None)(tempModel)

tempModel = layers.add([shortcut, tempModel])

max_p_layer = layers.GlobalMaxPooling2D(data_format='channels_last')(tempModel)
avg_p_layer = layers.GlobalAveragePooling2D(data_format='channels_last')(tempModel)

tempModel = layers.concatenate([max_p_layer, avg_p_layer])
tempModel = layers.Dense(300, activation='relu')(tempModel)
tempModel = layers.Dropout(0.2)(tempModel)
tempModel = layers.Dense(150, activation='relu')(tempModel)
tempModel = layers.Dropout(0.2)(tempModel)
tempModel = layers.Dense(nb_classes, activation='softmax')(tempModel)

model = Model(inputs=inputs, outputs=tempModel)
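For reference, the trainable parameter counts of the layers above can be tallied by hand (a plain-Python sketch; the helper functions are my own, not part of Keras, and the shapes follow the code above):

```python
# Trainable parameter counts for the layers built above.
def conv2d_params(kh, kw, c_in, c_out):
    # kernel weights (kh * kw * c_in * c_out) plus one bias per filter
    return kh * kw * c_in * c_out + c_out

def batchnorm_trainable_params(channels):
    # gamma + beta are trainable; the moving mean/variance are not
    return 2 * channels

first_conv = conv2d_params(4, 513, 3, 256)   # the (4, 513) kernel on RGB input
inner_conv = conv2d_params(4, 1, 256, 256)   # each (4, 1) kernel in the block
bn = batchnorm_trainable_params(256)

print(first_conv, inner_conv, bn)  # 1576192 262400 512
```

This shows the BatchNormalization layers contribute only 512 trainable parameters each (plus 512 non-trainable moving statistics), so parameter count alone cannot explain the memory footprint.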

Now, when I try to train the model, training is very slow, especially compared with other architectures that have more parameters. The model also needs far more memory (over 30 GB in total, not counting the weights), and I think this is because of the BatchNormalization layers (when I removed them for testing, the model used at least several GB less memory). Did I implement the network incorrectly, or are BatchNormalization layers just very slow?
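A rough back-of-the-envelope check suggests where the memory goes (assuming float32 activations; with `padding='same'`, every Conv2D output keeps the full 513×128 spatial size):

```python
# Size of one intermediate activation tensor: batch * H * W * C * 4 bytes (float32)
batch_size, h, w, c = 64, 513, 128, 256
bytes_per_tensor = batch_size * h * w * c * 4
print(bytes_per_tensor / 2**30)  # ~4.0 GiB per activation tensor
```

Each Conv2D, BatchNormalization, and Activation output in the residual block is a separate tensor of this size, and all of them are kept alive for backpropagation, so roughly eight such tensors already account for ~30 GB, which would match the observed footprint regardless of parameter count.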

Tags: tensorflow, keras, deep-learning, batch-normalization

Solution
