python-3.x - 具有 2x2 输入的双向 GRU
问题描述
我正在构建一个网络,它将字符串拆分为单词,将单词拆分为字符,嵌入每个字符,然后通过将字符聚合为单词并将单词聚合为字符串来计算该字符串的向量表示。使用双向 gru 层执行聚合并注意。
为了测试这个东西,假设我对这个字符串中的 5 个单词和 5 个字符感兴趣。在这种情况下,我的转变是:
["Some string"] -> ["Some","strin","","",""] ->
["Some_","string","_____","_____","_____"] where _ is the padding symbol ) ->
[[1,2,3,4,0],[1,5,6,7,8],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]] (shape 5x5)
接下来我有一个嵌入层,它将每个字符变成一个长度为 6 的嵌入向量。所以我的特征变成了一个 5x5x6 矩阵。然后我将此输出传递给双向 gru 层并执行一些其他操作,这些操作在这种情况下并不重要,我相信。
问题是当我用迭代器运行它时,比如
for string in strings:
output = model(string)
它似乎工作得很好(字符串是从 5x5 的切片创建的 tf 数据集),所以它是一堆 5 x 5 矩阵。
但是,当我转到训练或使用预测等功能在数据集级别工作时,模型会失败:
model.predict(strings.batch(1))
ValueError: Input 0 of layer bidirectional is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 5, 5, 6)
据我从文档中了解到,双向层将 3d 张量作为输入:[batch, timesteps, feature],因此在这种情况下,我的输入形状应如下所示:[batch_size,timesteps,(5,5,6)]
所以问题是我应该对输入数据应用哪种转换来获得这种形状?
解决方案
对于双向输入层,如果您使用 GRU,请使用return_sequences=True
, 来获得 3 维输出。由于 GRU 输出是 2D,return_sequences 将为您提供 3D 输出。对于堆叠的双向层输入应该是 3D 形状。
示例代码
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential()
model.add(
layers.Bidirectional(layers.GRU(64, return_sequences=True), input_shape=(5, 10))
)
model.add(layers.Bidirectional(layers.GRU(32)))
model.add(layers.Dense(10))
model.summary()
输出
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bidirectional_3 (Bidirection (None, 5, 128) 38400
_________________________________________________________________
bidirectional_4 (Bidirection (None, 64) 41216
_________________________________________________________________
dense_2 (Dense) (None, 10) 650
=================================================================
Total params: 80,266
Trainable params: 80,266
Non-trainable params: 0
___________________________
推荐阅读
- javascript - 使用 React 的自定义鼠标光标:未捕获的 TypeError:无法读取 null 的属性“clientWidth”
- firebase - 如何在 Vue 中设置身份验证状态更改?
- sql - 比较来自同一 sqlite 列的两行
- r - R 抱怨缺少 .so 文件,但它的包名
- razor - 如何仅输出选择的文本而不是 2sxc 剃须刀模板中的值?
- json - _TypeError(类型'_InternalLinkedHashMap
' 不是类型 'List 的子类型 ') - kotlin - 无法在 Gradle 项目中使用 pluginManagement 块应用 Kotlin JVM 插件
- linux - 将屏幕会话配置为与 xterm 终端完全相同
- regex - 如何使用自定义分隔符删除模式之间的线条
- qt - Qt Quick Controls 2.14 如何设置ScrollView的样式