python - 使用急切执行优化向量
问题描述
我想使用 TensorFlow 的急切执行功能来优化向量的组件。在所有记录的示例中,每个可训练变量只是一个标量,其集合由这些列表表示。但是,我想到的损失函数涉及对这些组件执行矢量操作,所以这很不方便。
例如,让我们使用 Adam 优化器来归一化一个 3 分量向量:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import numpy as np
tf.enable_eager_execution()
def normalize(din=[2.0,1.0,0.0], lr=0.001,
nsteps=100):
d = tfe.Variable(din)
def loss(dvec):
return tf.sqrt((1.0 - tf.tensordot(dvec, dvec, 1))**2)
def grad(dvec):
with tf.GradientTape() as tape:
loss_val = loss(dvec)
return tape.gradient(loss_val, dvec)
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
for i in range(nsteps):
grads = grad(d)
optimizer.apply_gradients(zip(grads, d)) #Throws error
return d
此代码正确计算所需的梯度。但是,无论我做什么,“optimizer.apply_gradients”行都会引发某种错误,主要是因为 tfe.Variable 不是可迭代的。
在这个具体的例子中,错误是“AttributeError: Tensor.name is senseless when eager execution is enabled”。我们也可以尝试,例如,
zip(grads, [d[i] for i in range(3)])
而不是 d,但随后解释器抱怨 d 不可迭代。
将 grads 与 d 配对的正确方法是什么?
解决方案
Optimizer.apply_gradients
要求它的第一个参数是(梯度,变量)对的列表。
在上面的代码中,既不是列表grads
也不d
是列表(print(type(grads))
例如尝试),因此错误来自对zip
. 我认为你想要的是:
optimizer.apply_gradients(zip([grads], [d]))
或者,更简单地说:
optimizer.apply_gradients([(grads, d)])
此外,仅供参考,因为急切执行正在稳定更多的东西正在从实验性的“contrib”命名空间中移出,所以如果您使用的是最新版本的 TensorFlow tfe
(tf.Variable
1.11、1.12 等)。使您的整个程序看起来像:
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
def normalize(din=[2.0,1.0,0.0], lr=0.001,
nsteps=100):
d = tf.Variable(din)
def loss(dvec):
return tf.sqrt((1.0 - tf.tensordot(dvec, dvec, 1))**2)
def grad(dvec):
with tf.GradientTape() as tape:
loss_val = loss(dvec)
return tape.gradient(loss_val, dvec)
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
for i in range(nsteps):
dd = grad(d)
optimizer.apply_gradients([(dd, d)])
return d
希望有帮助!
推荐阅读
- java - java - 如何在java类路径中强制设置maven依赖项?
- django - Django 文件上传(获取文件数据,而不仅仅是文件路径)
- cassandra - 如何在 Cassandra 中处理分区键上的 BETWEEN 子句
- c# - 在一个 c# azure 函数会话中处理来自 IoT 中心的多条消息
- php - 方法vue js中的for循环并替换caract
- python - Keras ValueError:预期输入的形状为(2,),但数组的形状为(16,)
- python - 迭代 O(n^2 / 2) 一个字典
- jquery - 输入数字掩码
- visual-studio - 在 Visual Studio 中创建无需重定位的可执行文件
- java - 在 moto 设备和三星 s6 中第一次无法打开相机