python - ValueError:没有为任何变量提供渐变 - Keras Tensorflow 2.0
问题描述
我正在尝试在 TensorFlow 网站上遵循此示例,但它不起作用。
这是我的代码:
import tensorflow as tf
def vectorize(vector_like):
return tf.convert_to_tensor(vector_like)
def batchify(vector):
'''Make a batch out of a single example'''
return vectorize([vector])
data = [(batchify([0]), batchify([0, 0, 0])), (batchify([1]), batchify([0, 0, 0])), (batchify([2]), batchify([0, 0, 0]))]
num_hidden = 5
num_classes = 3
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
loss_fn = lambda: tf.keras.backend.cast(tf.keras.losses.mse(model(input), output), tf.float32)
var_list_fn = lambda: model.trainable_weights
for input, output in data:
opt.minimize(loss_fn, var_list_fn)
有一段时间,我收到有关损失函数具有错误数据类型(int 而不是 float)的警告,这就是我将强制转换添加到损失函数的原因。
我没有进行网络培训,而是收到了错误:
ValueError:没有为任何变量提供梯度:['sequential/dense/kernel:0', 'sequential/dense/bias:0', 'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'] .
为什么渐变没有通过?我究竟做错了什么?
解决方案
GradientTape
如果要在 TF2 中操作渐变,则需要使用。例如,以下作品。
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
with tf.GradientTape() as tape:
loss = tf.keras.backend.mean(tf.keras.losses.mse(model(input),tf.cast(output, tf.float32)))
gradients = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(gradients, model.trainable_variables))
编辑:
您实际上可以通过执行以下更改来使您的示例正常工作。
- 仅将 cast 用于输出而不是完整的
loss_fn
(注意我也在做一个mean
,因为我们优化了损失的平均值)
通过“工作”,我的意思是它不会抱怨。但是您需要进一步调查以确保它按预期工作。
loss_fn = lambda: tf.keras.backend.mean(tf.keras.losses.mse(model(input), tf.cast(output, tf.float32)))
var_list_fn = lambda: model.trainable_weights
opt.minimize(loss_fn, var_list_fn)
推荐阅读
- grapesjs - Grapejs将字段添加到设置组件
- sql - 如何在 PL/SQL 中增量集成数据
- python - 如何将新记录添加到多对多字段
- javascript - 未捕获的 TypeError:动态加载 A-Frame 时系统 [e] 不是构造函数
- php - 如何在 Windows 7 的 XAMPP 上运行的 PHP 应用程序中启用 SSL?
- javascript - 对基于另一个数组的数组进行排序,并用 0 或空字符串 javascript 替换不匹配的数据
- dart - 飞镖:什么
在函数名之后 - tomcat - 如何在 Jakarta / Java EE 中连续运行一个进程(没有网站请求)?
- haskell - 如何在haskell中仅导入特定实例
- azure-devops - 从 REST API 传递到构建管道 yml 的变量未正确获取