tensorflow - 如何结合两个 RNN 模型
问题描述
我对 tensorflow 有疑问,现在我使用 tensorflow 构建了两个 RNN 模型。
RNN1(编码器)是 3 层 LSTM 单元,单元大小为 256
RNN2(解码器)是 3 层 LSTM 单元,单元大小为 512。
现在我想按状态组合两个RNN,这意味着RNN1的最后一个状态是RNN2的第一个状态。我想问如何实现?
我尝试使用相同的单元格大小(512)设置 RNN1 和 RNN2 并使用以下代码:
decoder_initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)` #this code is from Internet.
但我认为这是不对的。
任何帮助表示感谢谢谢
解决方案
从您发布的示例代码中,我很确定它是正确的。
2)我认为这是不对的,因为 RNN1 编码器和 RNN2 解码器假设是按状态连接的,但在这段代码中,它只是由 RNN1 状态初始化所有 RNN2 状态
虽然它被初始化为那个,但如果你查看build_decoder
函数,你会看到decoder_initial_state
解码器中使用了 。所以,encoder_state
-> decoder_initial_state
->build_decoder
高水平
当 Tensorflow 构建计算图时,由于计算图的创建方式,它会自动将结果RNN1
输入RNN2
。当您在 Tensorflow 中定义操作时,您不是“直接”提供数据,而是说“将这些操作链接在一起”
向前进
您应该从一个更简单的示例开始是Seq2Seq 教程,它将引导您完成一个简单的编码器-解码器。这应该让您更好地理解,以便您可以处理您发布的代码:)
一切顺利!希望这可以帮助
推荐阅读
- python - 比较包装函数是否是装饰器的实例
- python - 如何删除熊猫中属于同一组的某些 k 条目,例如余额类?
- python - 如何确定嵌套列表结构是否与另一个相同,但元素交换为新的
- postgresql - 错误:类型时间戳的无效输入语法:“20-MAR-17 08.30.41.453267 AM”
- loops - 如何创建动态 while 循环?
- php - 获取值列表,然后获取具有该值的所有帖子标题
- java - 如何在 BasicInterpreter 的 GETSTATIC 和 PUTSTATIC 命令中唯一标识静态变量
- python - 如何使用 Pandas 保留我的组中使用的列名
- python - 如何通过 tesseract OCR 读取黑色背景图像上的黑色文本?
- javascript - 如何打印 HTML 地理位置?