首页 > 解决方案 > 使用 tensorflow 2.1.0 更新字典

问题描述

我面临一个问题,我似乎无法在其他任何地方找到解决方案,所以我决定在这里发布我的问题(我有 tensorflow 的基本知识,但很新):

我用python写了一个简单的代码来说明我想要做什么。

    import tensorflow as tf

    def generated_dict():
       graph = {'input': tf.Variable(2)}
       graph['layer_1'] = tf.square(graph['input'])
       graph['layer_2'] = tf.add(graph['input'], graph['layer_1'])
       return graph

    graph = generated_dict()
    print("boo = " + str(graph['layer_2']))

    graph['input'].assign(tf.constant(3))
    print("far = " + str(graph['layer_2']))

在这个示例代码中,我希望 tensorflow 在我分配新的输入值时更新整个字典graph['input'].assign(tf.constant(3))。基本上,现在我得到

boo = tf.Tensor(6, shape=(), dtype=int32) # 2²+2
far = tf.Tensor(6, shape=(), dtype=int32) # 2²+2

这是正常的,因为我的代码急切地执行。但是,我希望字典用我的新输入更新其值并获得:

boo = tf.Tensor(6, shape=(), dtype=int32) #2²+2
far = tf.Tensor(12, shape=(), dtype=int32) #3²+3

我觉得我应该使用 tf.function() 但我不确定我应该如何进行。我试过graph = tf.function(generated_graph)()但我没有帮助。任何帮助将不胜感激。

标签: pythontensorflowdictionaryeager-execution

解决方案


首先,当尝试将 TF1 代码改编为 TF2 时,我建议阅读指南: 将您的 TensorFlow 1 代码迁移到 TensorFlow 2

TF2 将 tensorflow 的基本设计从在图中运行操作更改为急切执行。这意味着在 TF2 中编写代码与在普通 python 中编写代码非常接近:所有的抽象、图形创建等都是在后台完成的。

你复杂的设计不需要存在于TF2中,只需要写一个简单的python函数即可。您甚至可以使用普通运算符而不是 tensorflow 函数。会有转换成 tensorflow 的算子。或者,您可以使用tf.function装饰器进行表演

@tf.function
def my_graph(x):
     return x**2 + x

现在,如果您想在该函数中提供数据,您只需要使用 value 调用它:

>>> my_graph(tf.constant(3))
<tf.Tensor: shape=(), dtype=int32, numpy=12>
>>> my_graph(tf.constant(2))
<tf.Tensor: shape=(), dtype=int32, numpy=6>

推荐阅读