python - TensorFlow 中切片输入的梯度为无
问题描述
以下是我的代码
import tensorflow as tf
import numpy as np
def forward(x):
z = tf.Variable(tf.zeros_like(x), trainable=False)
s = tf.shape(x)[0]
for i in range(s):
z[i].assign(x[i]**i)
return z
a = tf.Variable(np.ones([5])*3)
with tf.GradientTape() as tape:
b = forward(a)
grad = tape.gradient(b, a)
我有一个输入,我必须对其进行切片然后计算输出。在此基础上,我需要计算梯度。但是,上述代码的输出为 None。
如何获得渐变?有什么方法可以对输入进行切片以获得渐变。
PS我必须只使用EagerExecution。没有图表模式。
解决方案
使用 gradientTape 时,如果您将其视为一个函数,这将很有帮助。假设您的成本函数是 y = x ** 2。可以计算 y(您的函数)相对于 x(您的变量)的梯度。
在您的代码中,您没有计算梯度的函数。您试图针对变量计算梯度,但这是行不通的。
我做了一个小改动。检查下面代码中的可变成本
import tensorflow as tf
import numpy as np
def forward(x):
cost = []
z = tf.Variable(tf.zeros_like(x), trainable=False)
s = tf.shape(x)[0]
for i in range(s):
z[i].assign(x[i]**i)
cost.append(x[i]**i)
return cost
a = tf.Variable(np.ones([5])*3)
with tf.GradientTape() as tape:
b = forward(a)
grad = tape.gradient(b, a)
print(grad)
推荐阅读
- user-interface - 使用 xvfb 在 GitHub 上运行 Java GUI 测试
- javascript - 如何在 D3 中使用 selection.join 更新标签?
- css - Vuetify 框架引入的 CSS 道具根据浏览器控制台进行复制
- sql-server - SQL中将一个表分成两个表的最干净的方法是什么?
- javascript - 如何在 javascript fetch() 方法上传递带有 url 的变量?
- opencl - 代码永远不会为大于 8000 个条目的数组运行并出现错误
- sql - Postgres 在组中的多行中识别字符串中的模式
- latex - Latex中的2列列表
- .net - 当由不同的 dbContexts 跟踪时,我“应该”如何在 WinForms 中保持对象同步?
- python - 与 x 轴上的日期时间变量一起使用时不显示线图