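python - Convolution and classification models inside a main model
Problem description
I have to build a neural network model structured like this: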
convolution --> classification
     \               /
      \             /
      _\|         |/_
         third model
       with one output
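The convolution output is used as the input of the classification model. The convolution output and the classification output are then concatenated and fed into a third model, which outputs a single prediction in the range 0..1 that is used to train the whole network.
- First: in this setup, will backpropagation correctly update the classification model? Or do I need to create three separate models?
- I tried to connect the convolution and classification models, without success: I get a "Graph disconnected" error.
Full error log: "Graph disconnected: cannot obtain value for tensor Tensor("classification_prediction_Input_2:0", shape=(1, 512), dtype=float32) at layer "classification_prediction_Input". The following previous layers were accessed without issue: []".
If the idea is correct, how do I connect the models into a single graph?
My current code: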
# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input', batch_shape=(1, 210, 160, 3))
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)
state_convolution_model = Model(state_input, state_outputs, name='state_convolution_model')
state_convolution_model.compile(optimizer='adam', loss='mean_squared_error', metrics=['acc'])
state_convolution_model_input = Input(shape=INPUT_SHAPE, name='state_convolution_model_input', batch_shape=(1, 210, 160, 3))
state_convolution = state_convolution_model(state_convolution_model_input)
# classification output
classficication_Input = Input(shape=(1, LSTM_OUTPUT_DIM), batch_shape=(1, LSTM_OUTPUT_DIM), name='classification_prediction_Input')
classficication_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(classficication_Input)
classficication_output_raw = Dense(ACTIONS, activation='sigmoid', name='classification_output_raw')(classficication_Dense_1)
classficication_output = Reshape((ACTIONS,), name='classification_output')(classficication_output_raw)
classficication_model = Model(classficication_Input, classficication_output, name='classificationPrediction_model')
classficication_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
classficicationPrediction = classficication_model(state_convolution)
i = keras.layers.concatenate([state_outputs, classficication_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o) # <-- graph error is here
plot_model(model, to_file='model.png', show_shapes=True)
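Solution
Yes, you can build such a structure and train it end to end: the loss on the final output is backpropagated through the classification layers and on into the convolution layers. However, all parts must belong to one model graph with multiple branches. The "Graph disconnected" error occurs because classficication_output is computed from its own Input layer (classification_prediction_Input) rather than from state_input, so Model(state_input, o) has no way to compute it from its only input. Another problem is that you compile the sub-models before the full model is defined; only the final model needs to be compiled. Here is working code: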
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Reshape, concatenate
from tensorflow.keras.models import Model
INPUT_SHAPE = (210, 160, 3)  # height, width, channels (inferred from the summary below)
ACTIONS = 4                  # number of classes (inferred from the summary below)
# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input')
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)
# classification output: built directly on the convolution features, no separate Input
classification_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(state_outputs)
classification_output_raw = Dense(ACTIONS,
                                  activation='sigmoid',
                                  name='classification_output_raw')(classification_Dense_1)
classification_output = Reshape((ACTIONS,), name='classification_output')(classification_output_raw)
# third model: concatenate the convolution features with the class scores
i = concatenate([state_outputs, classification_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)  # <-- no graph error anymore here
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
model.summary()
Output:
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
state_input (InputLayer)        (None, 210, 160, 3)  0
__________________________________________________________________________________________________
state_Conv2D_1 (Conv2D)         (None, 51, 39, 8)    1544        state_input[0][0]
__________________________________________________________________________________________________
state_MaxPooling2D_1 (MaxPoolin (None, 25, 19, 8)    0           state_Conv2D_1[0][0]
__________________________________________________________________________________________________
state_Flatten (Flatten)         (None, 3800)         0           state_MaxPooling2D_1[0][0]
__________________________________________________________________________________________________
classification_prediction_Dense (None, 32)           121632      state_Flatten[0][0]
__________________________________________________________________________________________________
classification_output_raw (Dens (None, 4)            132         classification_prediction_Dense_1
__________________________________________________________________________________________________
classification_output (Reshape) (None, 4)            0           classification_output_raw[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 3804)         0           state_Flatten[0][0]
                                                                 classification_output[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           121760      concatenate[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            33          dense[0][0]
==================================================================================================
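The shapes and parameter counts in this summary follow from plain arithmetic, assuming 'valid' padding (the Keras default for Conv2D and MaxPooling2D). Below is a small plain-Python sketch that reproduces them; the helper names conv_out and dense_params are hypothetical, and the 4 classes are read off the classification_output row.

```python
def conv_out(size, kernel, stride):
    """Output length of a 'valid' (unpadded) convolution or pooling window."""
    return (size - kernel) // stride + 1

def dense_params(n_in, n_out):
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out

height, width, channels = 210, 160, 3   # state_input
filters, actions = 8, 4                 # Conv2D filters, classes (ACTIONS)

conv_h, conv_w = conv_out(height, 8, 4), conv_out(width, 8, 4)   # Conv2D 8x8, stride 4
pool_h, pool_w = conv_out(conv_h, 2, 2), conv_out(conv_w, 2, 2)  # MaxPooling2D 2x2, stride 2
flat = pool_h * pool_w * filters                                 # Flatten
concat = flat + actions                                          # features + class scores

params = [
    8 * 8 * channels * filters + filters,  # state_Conv2D_1: kernel weights + biases
    dense_params(flat, 32),                # classification_prediction_Dense_1
    dense_params(32, actions),             # classification_output_raw
    dense_params(concat, 32),              # dense
    dense_params(32, 1),                   # dense_1
]
print((conv_h, conv_w), (pool_h, pool_w), flat, concat, params)
```

Note that the 4 class scores are what turn the 3800 flattened features into 3804 inputs for the first dense layer of the third model.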
For more examples, see the Keras functional API guide.
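If you prefer to keep the convolution and classification parts as separate Model objects, as in the question, that can also work, because a Keras Model can be called on a tensor like a layer. The sketch below is an assumed variant, not the answer's code: INPUT_SHAPE and ACTIONS are read off the summary above, the names conv_in, cls_in, conv_model and cls_model are hypothetical, and the random images and labels exist only to run one tf.GradientTape step, which checks that the binary loss on the final output produces gradients for the classification sub-model's weights.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, concatenate
from tensorflow.keras.models import Model

INPUT_SHAPE = (210, 160, 3)  # assumed from the summary
ACTIONS = 4                  # assumed from the summary
tf.keras.utils.set_random_seed(0)

# Sub-model 1: convolution feature extractor (no compile needed).
conv_in = Input(shape=INPUT_SHAPE, name="conv_in")
x = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation="relu")(conv_in)
x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
conv_features = Flatten()(x)
conv_model = Model(conv_in, conv_features, name="state_convolution_model")

# Sub-model 2: classification head; its Input width matches the flattened features.
feature_dim = int(conv_features.shape[-1])
cls_in = Input(shape=(feature_dim,), name="cls_in")
c = Dense(32, activation="relu")(cls_in)
cls_model = Model(cls_in, Dense(ACTIONS, activation="sigmoid")(c), name="classification_model")

# Main model: a single Input; both sub-models are called on tensors derived from it,
# so Model(state_input, o) can trace every path back to state_input.
state_input = Input(shape=INPUT_SHAPE, name="state_input")
features = conv_model(state_input)
classes = cls_model(features)
merged = concatenate([features, classes], name="concatenate")
o = Dense(1, activation="sigmoid")(Dense(32, activation="relu")(merged))
model = Model(state_input, o)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["acc"])

# One manual gradient computation on random data: the loss on o reaches
# the weights of the nested classification sub-model.
images = np.random.rand(2, *INPUT_SHAPE).astype("float32")
labels = np.array([[0.0], [1.0]], dtype="float32")
with tf.GradientTape() as tape:
    loss = tf.keras.losses.BinaryCrossentropy()(labels, model(images, training=True))
cls_grads = tape.gradient(loss, cls_model.trainable_variables)
print([float(tf.reduce_sum(tf.abs(g))) for g in cls_grads])
```

The point of this variant is that the sub-models' own Input layers stay internal to them; the outer graph only ever sees tensors that descend from state_input, which is exactly what the question's code was missing.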