keras - keras.layers.concatenate 的输出形状
问题描述
我有一个具有相同输出形状的密集层列表 [batch, 1]。如果我将这些层的输出与 keras.layers.concatenate() 结合起来,形状会是什么?
dense_layers = [Dense(1), Dense(1), Dense(1)] #some dense layers
merged_output = keras.layers.concatenate([dense_layers])
合并输出的形状是 (batch, 3) 还是 (3, 1)?
解决方案
答案是(批次,3)。要看到这一点,您可以构建一个模型并打印 model.summary():
from keras.layers import Input, Dense
from keras.models import Model
from keras.layers import concatenate
batch = 30
# define three sets of inputs
input1 = Input(shape=(batch,1))
input2 = Input(shape=(batch,1))
input3 = Input(shape=(batch,1))
# define three dense layers
layer1 = Dense(1)(input1)
layer2 = Dense(1)(input2)
layer3 = Dense(1)(input3)
# concatenate layers
dense_layers = [layer1, layer2, layer3]
merged_output = concatenate(dense_layers)
# create a model and check for output shape
model = Model(inputs=[input1, input2, input3], outputs=merged_output)
model.summary()
Layer (type) Output Shape Param # Connected to
=============================================================================
input_1 (InputLayer) (None, 30, 1) 0
_______________________________________________________________________________
input_2 (InputLayer) (None, 30, 1) 0
_______________________________________________________________________________
input_3 (InputLayer) (None, 30, 1) 0
_______________________________________________________________________________
dense_1 (Dense) (None, 30, 1) 2 input_1[0][0]
_______________________________________________________________________________
dense_2 (Dense) (None, 30, 1) 2 input_2[0][0]
_______________________________________________________________________________
dense_3 (Dense) (None, 30, 1) 2 input_3[0][0]
_______________________________________________________________________________
concatenate_1 (Concatenate) (None, 30, 3) 0 dense_1[0][0]
dense_2[0][0]
dense_3[0][0]
==============================================================================
Total params: 6
Trainable params: 6
Non-trainable params: 0
______________________________________________________________________________
推荐阅读
- node.js - NodeJS TypeError:sessionsMap[userId].push 不可迭代(无法读取属性 Symbol(Symbol.iterator))
- reactjs - 具有本地依赖的 Dockrize React App
- azure-ad-b2c - B2C 租户中的 AD 用户
- numpy - 如何过滤给定特定字母的ndarray?
- python - 字典到另一个字典的索引
- c# - Microsoft.ML 训练视频数据模型
- html - 从 django 模型表单中删除字段会破坏 HTML 中的格式
- r - 如何在 ggplot2 中创建具有多个计数变量的线图?
- r - 如何使用带有数据框的 dplyr 在 R 中创建百分位数?
- amazon-web-services - 如何从jenkins管道将命令行参数传递给aws beanstalk中的jar?