tensorflow - tf.gradients,如何理解 `grad_ys` 并使用它?
问题描述
在tf.gradients
中,有一个关键字参数grad_ys
grad_ys
是一个长度相同的张量列表,ys
其中包含每个y
in的初始梯度ys
。当为无时,我们为每个ingrad_ys
填充一个形状为 '1' 的张量。用户可以提供他们自己的初始值,以使用每个 y 的不同初始梯度来计算导数(例如,如果想要为每个 y 中的每个值不同地加权梯度)。y
y
ys
grad_ys
为什么grads_ys
这里需要?这里的文档是隐含的。你能给出一些具体的目的和代码吗?
我的示例代码tf.gradients
是
In [1]: import numpy as np
In [2]: import tensorflow as tf
In [3]: sess = tf.InteractiveSession()
In [4]: X = tf.placeholder("float", shape=[2, 1])
In [5]: Y = tf.placeholder("float", shape=[2, 1])
In [6]: W = tf.Variable(np.random.randn(), name='weight')
In [7]: b = tf.Variable(np.random.randn(), name='bias')
In [8]: pred = tf.add(tf.multiply(X, W), b)
In [9]: cost = 0.5 * tf.reduce_sum(tf.pow(pred-Y, 2))
In [10]: grads = tf.gradients(cost, [W, b])
In [11]: sess.run(tf.global_variables_initializer())
In [15]: W_, b_, pred_, cost_, grads_ = sess.run([W, b, pred, cost, grads],
feed_dict={X: [[2.0], [3.]], Y: [[3.0], [2.]]})
解决方案
grad_ys
只有高级用例才需要。这是您可以考虑的方法。
tf.gradients
允许您计算tf.gradients(y, x, grad_ys) = grad_ys * dy/dx
. 换句话说,grad_ys
是每个 的乘数y
。在这种表示法中,提供这个论点似乎很愚蠢,因为一个人应该能够只对自己进行复数,即tf.gradients(y, x, grad_ys) = grad_ys * tf.gradients(y, x)
. 不幸的是,这种等式不成立,因为当反向计算梯度时,我们在每一步之后执行归约(通常是求和)以获得“中间损失”。
此功能在许多情况下都很有用。文档字符串中提到了一个。这是另一个。记住链式法则 - dz/dx = dz/dy * dy/dx
。假设我们想要计算dz/dx
但dz/dy
不可微分,我们只能近似它。假设我们以某种方式计算近似值并将其称为approx
。那么,dz/dx = tf.gradients(y, x, grad_ys=approx)
。
另一个用例可能是当您有一个带有“巨大扇入”的模型时。假设您有 100 个输入源,它们经过几层(称为“100 个分支”),在 处合并y
,然后再经过 10 个层,直到到达loss
. 一次计算整个模型的所有梯度(需要记住许多激活)可能不适合内存。一种方法是先计算d(loss)/dy
。然后,计算branch_i
相对于loss
using的变量的梯度tf.gradients(y, branch_i_variables, grad_ys=d(loss)/dy)
。使用这个(以及我跳过的更多细节),您可以减少峰值内存需求。
推荐阅读
- javascript - 将 javascript 变量发送到 django 视图
- c++ - 制作一个 QPushButton 圆形 C++
- ios - 如何初始化所有视图控制器?
- python - TypeError:标签不是 numpy 数组,也不是标量
- angular - 重定向到路由后,它立即转到上一页
- xmpp - 无法使用 IP 地址访问 ejabberd 管理面板
- javascript - 从 dist 加载 index.html 中的资源失败,这与工作 npm run dev cmd 相反
- ruby - The recipient does not receive any push message even though
- thymeleaf - How to stop character encoding in thymeleaf from changing url get parameter '?' character to changing to %3F
- ios - 如何在 Swift 中按日期对数组进行分组?