首页 > 解决方案 > Tensorflow:在 tf.gradients() 期间忽略特定依赖项

问题描述

给定变量 y 和 z,它们都依赖于张量 x。根据产品规则,如果我做 tf.gradients(y z,x),它会给我 y'(x)z(x) + z'(x)y(x)。有没有办法可以将 y 指定为相对于 x 的常数,这样 tf.gradients(y z,x) 只给我 z'(x)y(x)?

我知道 y_=tf.constant(sess.run(y)) 会给我 y 作为常数,但我不能在我的代码中使用该解决方案。

标签: tensorflowtensorflow-datasets

解决方案


您可以使用tf.stop_gradient()来阻止反向传播。在您的示例中阻止渐变:

y = function1(x)
z = function2(x)

blocked_y = tf.stop_gradient(y)

product = blocked_y * z

在你通过反向传播之后product,反向传播将继续z而不是y


推荐阅读