首页 > 解决方案 > keras.layers.concatenate 的输出形状

问题描述

我有一个具有相同输出形状的密集层列表 [batch, 1]。如果我将这些层的输出与 keras.layers.concatenate() 结合起来,形状会是什么?

dense_layers = [Dense(1), Dense(1), Dense(1)] #some dense layers
merged_output = keras.layers.concatenate([dense_layers])

合并输出的形状是 (batch, 3) 还是 (3, 1)?

标签: keras

解决方案


答案是(批次,3)。要看到这一点,您可以构建一个模型并打印 model.summary():

from keras.layers import Input, Dense
from keras.models import Model
from keras.layers import concatenate

 batch = 30

# define three sets of inputs
input1 = Input(shape=(batch,1))
input2 = Input(shape=(batch,1))
input3 = Input(shape=(batch,1))

# define three dense layers
layer1 = Dense(1)(input1)
layer2 = Dense(1)(input2)
layer3 = Dense(1)(input3)

# concatenate layers
dense_layers = [layer1, layer2, layer3]
merged_output = concatenate(dense_layers)

# create a model and check for output shape
model = Model(inputs=[input1, input2, input3], outputs=merged_output)
model.summary()

Layer (type)                    Output Shape         Param #     Connected to                     

=============================================================================
input_1 (InputLayer)            (None, 30, 1)        0                                            
_______________________________________________________________________________
input_2 (InputLayer)            (None, 30, 1)        0                                            

_______________________________________________________________________________
input_3 (InputLayer)            (None, 30, 1)        0                                            

_______________________________________________________________________________
dense_1 (Dense)                 (None, 30, 1)        2           input_1[0][0]                    
_______________________________________________________________________________
dense_2 (Dense)                 (None, 30, 1)        2           input_2[0][0]                    
_______________________________________________________________________________
dense_3 (Dense)                 (None, 30, 1)        2           input_3[0][0]                    

_______________________________________________________________________________
concatenate_1 (Concatenate)     (None, 30, 3)        0           dense_1[0][0]                    
                                                                 dense_2[0][0]                    
                                                                 dense_3[0][0]                    

==============================================================================
Total params: 6
Trainable params: 6
Non-trainable params: 0
______________________________________________________________________________

推荐阅读