Convolution and classification models inside a main model

Problem description

I need to build a neural network model like this:

convolution --> classification
       \            /
        \          /
        _\|      |/_
         third model
       with one output

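The output of the convolution is used as the input of the classification model. The convolution output and the classification output are then concatenated and fed into a third model. The third model outputs a prediction in the range 0..1, which is used to train the whole network.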
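The intended architecture can be written as a composition. The symbols here are introduced only for illustration: x is the input frame, f is the convolution branch, g is the classification model, h is the third model, and [z; c] denotes concatenation:

$$
z = f(x), \qquad c = g(z), \qquad \hat{y} = h\big([\,z;\,c\,]\big) \in (0, 1)
$$

$$
\mathcal{L}(y, \hat{y}) = -\big[\, y \log \hat{y} + (1 - y) \log (1 - \hat{y}) \,\big]
$$

Because ŷ depends on the parameters of f, g and h, a single loss on ŷ sends gradients into all three parts. One example is the binary cross-entropy shown above, which is the loss the answer's compile call uses. This is what makes end-to-end training possible, as long as all three parts belong to one connected graph.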

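Full error log: Graph disconnected: cannot obtain value for tensor Tensor("classification_prediction_Input_2:0", shape=(1, 512), dtype=float32) at layer "classification_prediction_Input". The following previous layers were accessed without issue: []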
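A toy sketch can show why this error appears. It is not Keras's internal implementation. The idea is that a functional model is defined by walking back from its output tensor, and every source the walk reaches must be one of the inputs passed to Model(...). In the question's code, classficication_output is computed from classficication_Input, which is a second Input that is never passed to Model(state_input, o). The names Node, find_sources and build_model are hypothetical helpers, and the convolution and pooling layers are collapsed into a single node.

```python
# Toy illustration only: this is NOT how Keras is implemented internally.
# Each "tensor" remembers the tensors it was computed from. Building a model
# walks back from the output and checks that every source is a declared input.


class Node:
    """Hypothetical stand-in for a symbolic tensor in a functional graph."""

    def __init__(self, name, parents=()):
        self.name = name
        self.parents = list(parents)


def find_sources(output):
    """Return names of all parentless nodes reachable backwards from output."""
    seen, stack, sources = set(), [output], []
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if not node.parents:
            sources.append(node.name)
        stack.extend(node.parents)
    return sources


def build_model(input_names, output):
    """Hypothetical constructor: fail if the output needs an undeclared input."""
    missing = [name for name in find_sources(output) if name not in input_names]
    if missing:
        raise ValueError(f"Graph disconnected: output depends on undeclared input(s) {missing}")
    return input_names, output


# Shared convolution branch (Conv2D, pooling and Flatten collapsed into one node).
state_input = Node("state_input")
state_outputs = Node("state_Flatten", [state_input])

# Question's wiring: the classification output hangs off its own Input.
cls_input = Node("classification_prediction_Input")
cls_broken = Node("classification_output", [cls_input])
broken_output = Node("dense_1", [Node("concatenate", [state_outputs, cls_broken])])

# Answer's wiring: the classification layers are applied to state_outputs.
cls_fixed = Node("classification_output", [state_outputs])
fixed_output = Node("dense_1", [Node("concatenate", [state_outputs, cls_fixed])])

try:
    build_model(["state_input"], broken_output)
except ValueError as err:
    print(err)

build_model(["state_input"], fixed_output)
print("fixed wiring: OK")
```

The point of the sketch is that the classification layers themselves are not the problem. The problem is which tensor they are applied to. Once they consume state_outputs, the only source the walk reaches is state_input.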

If the idea is correct, how do I connect the models into one graph?

My current code:

# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input', batch_shape=(1, 210, 160, 3))
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)
state_convolution_model = Model(state_input, state_outputs, name='state_convolution_model')
state_convolution_model.compile(optimizer='adam', loss='mean_squared_error', metrics=['acc'])

state_convolution_model_input = Input(shape=INPUT_SHAPE, name='state_convolution_model_input', batch_shape=(1, 210, 160, 3))
state_convolution = state_convolution_model(state_convolution_model_input)

# classification output
classficication_Input = Input(shape=(1, LSTM_OUTPUT_DIM), batch_shape=(1, LSTM_OUTPUT_DIM), name='classification_prediction_Input')
classficication_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(classficication_Input)
classficication_output_raw = Dense(ACTIONS, activation='sigmoid', name='classification_output_raw')(classficication_Dense_1)
classficication_output = Reshape((ACTIONS,), name='classification_output')(classficication_output_raw)
classficication_model = Model(classficication_Input, classficication_output, name='classificationPrediction_model')
classficication_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

classficicationPrediction = classficication_model(state_convolution)

i = keras.layers.concatenate([state_outputs, classficication_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- graph error is here
plot_model(model, to_file='model.png', show_shapes=True)

Tags: python, machine-learning, keras, neural-network

Solution


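Yes, you can build this structure and train it end to end. However, it has to be a single model with several branches, all traced from the same Input. The "Graph disconnected" error occurs because classficication_output is computed from classficication_Input. That is a second Input that is never passed to Model(state_input, o), so Keras cannot reach it from state_input. Another problem is that you compile the sub-models before the full model is defined. Only the final model needs to be compiled, once it is complete. Here is working code: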

# Imports assume tf.keras
from tensorflow.keras.layers import (Conv2D, Dense, Flatten, Input,
                                     MaxPooling2D, Reshape, concatenate)
from tensorflow.keras.models import Model

# Assumed values: they reproduce the shapes shown in the summary below
INPUT_SHAPE = (210, 160, 3)
ACTIONS = 4

# state convolution
state_input = Input(shape=INPUT_SHAPE, name='state_input')
state_Conv2D_1 = Conv2D(8, kernel_size=(8, 8), strides=(4, 4), activation='relu', name='state_Conv2D_1')(state_input)
state_MaxPooling2D_1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='state_MaxPooling2D_1')(state_Conv2D_1)
state_outputs = Flatten(name='state_Flatten')(state_MaxPooling2D_1)

# classification output: applied directly to state_outputs, no separate Input
classification_Dense_1 = Dense(32, activation='relu', name='classification_prediction_Dense_1')(state_outputs)
classification_output_raw = Dense(ACTIONS,
                                  activation='sigmoid',
                                  name='classification_output_raw')(classification_Dense_1)
classification_output = Reshape((ACTIONS,), name='classification_output')(classification_output_raw)

# third model: concatenate both branches and predict a single value in 0..1
i = concatenate([state_outputs, classification_output], name='concatenate')
d = Dense(32, activation='relu')(i)
o = Dense(1, activation='sigmoid')(d)
model = Model(state_input, o)                  # <-- no graph error anymore here
# compile only after the full model is defined
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
model.summary()

Output:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
state_input (InputLayer)        (None, 210, 160, 3)  0                                            
__________________________________________________________________________________________________
state_Conv2D_1 (Conv2D)         (None, 51, 39, 8)    1544        state_input[0][0]                
__________________________________________________________________________________________________
state_MaxPooling2D_1 (MaxPoolin (None, 25, 19, 8)    0           state_Conv2D_1[0][0]             
__________________________________________________________________________________________________
state_Flatten (Flatten)         (None, 3800)         0           state_MaxPooling2D_1[0][0]       
__________________________________________________________________________________________________
classification_prediction_Dense (None, 32)           121632      state_Flatten[0][0]              
__________________________________________________________________________________________________
classification_output_raw (Dens (None, 4)            132         classification_prediction_Dense_1
__________________________________________________________________________________________________
classification_output (Reshape) (None, 4)            0           classification_output_raw[0][0]  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 3804)         0           state_Flatten[0][0]              
                                                                 classification_output[0][0]      
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           121760      concatenate[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            33          dense[0][0]                      
==================================================================================================
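The numbers in this summary can be checked by hand. The sketch below recomputes each output shape with the "valid" padding formula, floor((size - kernel) / stride) + 1. It computes each parameter count as weights plus biases. The helper names are hypothetical. INPUT_SHAPE = (210, 160, 3) and ACTIONS = 4 are inferred from the summary and are not stated in the post. Keras's default padding="valid" is assumed for both Conv2D and MaxPooling2D.

```python
# Recomputes the Output Shape and Param # columns of the summary by hand.
# Assumptions: INPUT_SHAPE = (210, 160, 3) and ACTIONS = 4 (read off the
# summary), and Keras' default padding="valid" for Conv2D and MaxPooling2D.


def valid_out(size, kernel, stride):
    """Length of one spatial axis after an unpadded ('valid') sliding window."""
    return (size - kernel) // stride + 1


def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """One kernel_h x kernel_w x in_channels kernel per filter, plus one bias each."""
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_units, units):
    """Weight matrix of in_units x units, plus one bias per unit."""
    return in_units * units + units


H, W, C = 210, 160, 3
ACTIONS = 4

# state_Conv2D_1: 8 filters, 8x8 kernel, stride 4
conv_shape = (valid_out(H, 8, 4), valid_out(W, 8, 4), 8)
# state_MaxPooling2D_1: 2x2 pool, stride 2, channels unchanged
pool_shape = (valid_out(conv_shape[0], 2, 2), valid_out(conv_shape[1], 2, 2), 8)
# state_Flatten and concatenate
flat = pool_shape[0] * pool_shape[1] * pool_shape[2]
concat = flat + ACTIONS

rows = [
    ("state_Conv2D_1", conv_shape, conv2d_params(8, 8, C, 8)),
    ("state_MaxPooling2D_1", pool_shape, 0),
    ("state_Flatten", (flat,), 0),
    ("classification_prediction_Dense_1", (32,), dense_params(flat, 32)),
    ("classification_output_raw", (ACTIONS,), dense_params(32, ACTIONS)),
    ("classification_output", (ACTIONS,), 0),
    ("concatenate", (concat,), 0),
    ("dense", (32,), dense_params(concat, 32)),
    ("dense_1", (1,), dense_params(32, 1)),
]

for name, shape, params in rows:
    print(f"{name:<35} {str(shape):<16} {params}")
```

The arithmetic also shows where almost all of the weights live. Both 32-unit Dense layers read the roughly 3800 flattened convolution features. The convolution itself has only 1544 parameters.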

For more examples, see the Keras functional API guide.

