首页 > 解决方案 > 如何分别运行多输入 keras 模型?

问题描述

我有一个多输入 keras 模型,如下图所示。输入是 (244, 244, 3)。由于内存不足,我无法在计算机上运行它。有没有办法单独运行每个通道,然后连接它们的输出并继续训练?

这里是我用来创建模型的代码:

def create_multi_channel_cnn():
    channel_1 = create_model(input_shape=(244, 244, 3))
    channel_2 = create_model(input_shape=(244, 244, 3))
           
    # combine the output of the two branches
    combined = concatenate([channel_1.output, channel_2.output])

    x = Dense(64, activation="relu")(combined)
    x = Dense(2, activation="softmax")(x)

    model = Model(inputs=[channel_1.input,  channel_2.input], outputs=x)

    return model 

每个通道都有以下层:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv4 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________

标签: pythontensorflowkerasconv-neural-network

解决方案


推荐阅读