首页 > 解决方案 > 寻找 TensorFlow 等效的 Pytorch GRU 功能

问题描述

我对如何在 TensorFlow 中重建以下 Pytorch 代码感到困惑。它同时使用输入大小x和隐藏大小h来创建 GRU 层

import torch
torch.nn.GRU(64, 64*2, batch_first=True, return_state=True) 

本能地,我首先尝试了以下方法:

import tensorflow as tf
tf.keras.layers.GRU(64, return_state=True)

但是,我意识到它并没有真正考虑h或隐藏大小。在这种情况下我该怎么办?

标签: tensorflowdeep-learningpytorchrecurrent-neural-networkgated-recurrent-unit

解决方案


在您的 tensorflow 示例中,隐藏大小为 64。要获得等价物,您应该使用

import tensorflow as tf
tf.keras.layers.GRU(64*2, return_state=True)

这是因为 keras 层不需要您指定输入大小(本例中为 64);它是在您第一次构建或运行模型时决定的。


推荐阅读