python - 如何使用 tf.multiply 执行自定义渐变?
问题描述
我用 tensorflow 包定义了自定义渐变映射器。
当我将 tf.multiply 与自定义渐变一起使用时,它不起作用。
整个代码在这里
import tensorflow as tf
@tf.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
x = op.inputs[0]
return 1000.0 * x
input = tf.Variable([4.0], dtype=tf.float32)
x = tf.constant(5.0)
g = tf.get_default_graph()
with g.gradient_override_map({"Multiply": "MyopGrad"}):
output1 = tf.multiply(input, x , name = 'multiply')
grad1 = tf.gradients(output1, input)
# output without gradient clipping in the backwards pass for comparison:
output1_ori = tf.multiply(input , x)
grad1_ori = tf.gradients(output1_ori, input)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("with custom:", sess.run(grad1)[0])
print("without custom:", sess.run(grad1_ori)[0])
解决方案
TensorFlow 操作名称tf.multiply
只是Mul
,不是Multiply
。此外,tf.multiply
有两个输入,所以它的梯度应该有两个输出。所以你的代码可能看起来像这样:
import tensorflow as tf
@tf.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
x = op.inputs[0]
y = op.inputs[1]
return 1000.0 * x, 1000.0 * y
input = tf.Variable([4.0], dtype=tf.float32)
x = tf.constant(5.0)
g = tf.get_default_graph()
with g.gradient_override_map({"Mul": "MyopGrad"}):
output1 = tf.multiply(input, x , name = 'multiply')
grad1 = tf.gradients(output1, input)
# output without gradient clipping in the backwards pass for comparison:
output1_ori = tf.multiply(input , x)
grad1_ori = tf.gradients(output1_ori, input)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("with custom:", sess.run(grad1)[0])
print("without custom:", sess.run(grad1_ori)[0])
输出:
with custom: [4000.]
without custom: [5.]
推荐阅读
- java - Jackson 深度依赖序列化
- java - 我如何在 android 中使用自定义适配器按字母顺序执行搜索
- flutter - 在 Flutter 中导航到 ListView 后,图像未在 Hero 中显示
- opencv - canny 边缘检测算法是否适用于梯度变化较小的图像?
- java - 在 Maven 中指定 Apache Tomcat“lib”文件夹
- android - java.lang.IllegalStateException:应为 BEGIN_OBJECT 但为 BEGIN_ARRAY Kotlin
- java - 如何用最少的代码比较两个字符串列表的java集合?
- cordova - Ionic 3 错误:找不到资源 xml/network_security_config
- python - CSV 到多值字典?
- haskell - 如果一个字母只使用一次,检查一个字符串 Haskell