python - tf.keras.layers.RNN 与 tf.keras.layers.StackedRNNCells:Tensorflow 2
问题描述
我正在尝试在 Tensorflow 2.0 中实现多层 RNN 模型。尝试两者tf.keras.layers.StackedRNNCells
并tf.keras.layers.RNN
得出相同的结果。谁能帮我理解和之间的tf.keras.layers.RNN
区别tf.keras.layers.StackedRNNCells
?
# driving parameters
sz_batch = 128
sz_latent = 200
sz_sequence = 196
sz_feature = 2
n_units = 120
n_layers = 3
多层 RNN tf.keras.layers.RNN
:
inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(cells, stateful=True, return_sequences=True, return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
返回:
Model: "model_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_88 (InputLayer) [(128, 196, 2)] 0
_________________________________________________________________
rnn_61 (RNN) (128, 196, 120) 218880
_________________________________________________________________
dense_19 (Dense) (128, 196, 1) 121
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
多层 RNNtf.keras.layers.RNN
和tf.keras.layers.StackedRNNCells
:
inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells),
stateful=True,
return_sequences=True,
return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
返回:
Model: "model_14"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_89 (InputLayer) [(128, 196, 2)] 0
_________________________________________________________________
rnn_62 (RNN) (128, 196, 120) 218880
_________________________________________________________________
dense_20 (Dense) (128, 196, 1) 121
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
解决方案
tf.keras.layers.RNN 使用 tf.keras.layers.StackedRNNCells 如果你给它一个列表或一个单元格元组。这是在https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/layers/recurrent.py#L390中完成的
推荐阅读
- php - 如何优化 get_option() 的 sql 查询数?
- python - 检查图形路径是否有效时出现 KeyError
- python - 如何打印 mime 有效载荷
- javascript - 使用 FileField 编辑表单会删除文件
- mysql - 带有尺寸和引号的mysql排序字符串
- c# - 在 ASP.NET MVC 5 中,如何从 FormCollection 反序列化复杂对象?
- docker - 在带有 windows/servercore 的 Windows 容器上以无头模式运行 Firefox
- pandas - 使用 isin 选择 Series 的部分元素
- selenium - 使用 xpath 在 selenium 中查找元素
- docker - 如何理解容器状态