首页 > 解决方案 > Tensorflow:使用 tf.function 可能的张量列表?

问题描述

我正在尝试构建以急切执行的 tensorflow 构建简单的 RNN。

这是我有问题的代码

@tf.function
def rnn(self,X):
    outputs=[]
    state=self.state
    for x in X:
        output,state=self.basic_rnn_cell(x,state)
        outputs.append(output)
    return outputs

如您所见,我想创建张量列表,然后可以将其提供给优化器。此代码无需tf.function修饰即可正常工作。装饰签名很难使用 python 列表。我收到这样的错误

tensorflow.python.framework.errors_impl.InaccessibleTensorError: The tensor 'Tensor("while/StatefulPartitionedCall:0", shape=(2, 1), dtype=float32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_51, id=2300683156496); accessed from: FuncGraph(name=rnn, id=2300681517808).

我已经尝试过,tf.concat但是循环中的可变长度变量有问题。

有什么办法可以解决这个问题吗?

标签: pythontensorflowtensorflow2.0

解决方案


推荐阅读