tensorflow - 理解 keras 层中的形状
问题描述
我正在学习 Tensorflow 和 Keras 来实现LSTM
many-to-many
输入序列长度等于输出序列长度的模型。
示例代码:
输入:
voc_size = 10000
embed_dim = 64
lstm_units = 75
size_batch = 30
count_classes = 5
模型:
from tensorflow.keras.layers import ( Bidirectional, LSTM,
Dense, Embedding, TimeDistributed )
from tensorflow.keras import Sequential
def sample_build(embed_dim, voc_size, batch_size, lstm_units, count_classes):
model = Sequential()
model.add(Embedding(input_dim=voc_size,
output_dim=embed_dim,input_length=50))
model.add(Bidirectional(LSTM(units=lstm_units,return_sequences=True),
merge_mode="ave"))
model.add(Dense(200))
model.add(TimeDistributed(Dense(count_classes+1)))
# Compile model
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
model.summary()
return model
sample_model = sample_build(embed_dim,voc_size,
size_batch, rnn_units,
count_classes)
我无法理解每一层的输入和输出的形状。例如,输出的形状Embedding_Layer
是(BATCH_SIZE, time_steps, length_of_input)
,在这种情况下,它是(30, 50, 64)
。
Bidirectional LSTM
类似地, later的输出形状是(30, 50, 75)
。这将是下一个单位的Dense Layer
输入200
。但是权重矩阵的形状Dense Layer
是(units
当前层的个数,前一层的单元数,就是(200,75)
这种情况。那么矩阵的计算是如何在2D
形状Dense Layer
和3D
双向层的形状之间发生的呢?任何关于形状澄清的解释都会有所帮助
解决方案
Dense 可以进行 3D 操作,它将输入展平为形状 (batch_size * time_steps, features),然后应用密集层并将其重新整形回原始 (batch_size, time_steps, units)。在 keras 的密集层文档中,它说:
注意:如果层的输入的秩大于 2,则 Dense 沿输入的最后一个轴和内核的轴 1 计算输入和内核之间的点积(使用 tf.tensordot)。例如,如果输入的维度为 (batch_size, d0, d1),那么我们创建一个形状为 (d1, units) 的内核,并且内核沿输入的轴 2 对形状 (1, 1) 的每个子张量进行操作, d1) (有 batch_size * d0 这样的子张量)。在这种情况下,输出将具有形状 (batch_size, d0, units)。
关于Embedding
图层输出的另一点。正如你所说,它是一个 3D 输出是正确的,但正确的形状对应于 (BATCH_SIZE, input_dim, embeddings_dim)
推荐阅读
- python - 使用 matplotlib 绘制 sin wav 的傅立叶变换
- node.js - 使用 docker compose 时,Discord bot 无法联系不和谐服务器
- performance - 在 chrome 中的 React 应用程序中不断“渲染”屏幕外元素
- azure-aks - 使用自定义子网为 AKS 群集启用应用程序网关入口控制器 (AGIC) 加载项
- javascript - 可以采用递增计数器并使其显得唯一随机的算法或公式
- java - 指定的编译器合规性为 11,但使用 JRE 15 警告 spring boot 套件?
- c# - 在 EF Core 5 中,如何通过仅设置外键 ID 插入具有多对多关系的实体,而无需先查询?
- java - 为什么杰克逊仍然将 long[] 序列化为数字数组?
- html - 调整大小时如何设置背景图像以适应浏览器窗口?
- c# - 在 XNA/Monogame 中旋转然后平移矩阵时抖动