首页 > 解决方案 > 通过有状态对象获取渐变

问题描述

我已经开始通过官方指南学习 TensorFlow:https ://www.tensorflow.org/guide 。

我的理解是在指南中名为“自动微分”的部分,尤其是“通过有状态对象获取渐变”的一部分。

我不明白他们为什么说有状态的对象停止渐变。该指南给出了这段代码

x0 = tf.Variable(3.0)
x1 = tf.Variable(0.0)

with tf.GradientTape() as tape:
  # Update x1 = x1 + x0.
  x1.assign_add(x0)
  # The tape starts recording from x1.
  y = x1**2   # y = (x1 + x0)**2

# This doesn't work.
print(tape.gradient(y, x0))   #dy/dx0 = 2*(x1 + x0)

为什么渐变不记录x0?!是这个函数.assign_add(x0)增加x1了 overshadowx0吗?是因为assign_add会选择x0并窃取其分配的内存吗?这是正确的原因还是我看不到的其他原因?

预先感谢您的回答。

标签: python-3.xtensorflow2.0gradient-descent

解决方案


推荐阅读