How do I apply a custom gradient with tf.multiply?

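Problem description

I defined a custom gradient override using the tensorflow package.

When I use tf.multiply with the custom gradient, it does not take effect.

The full code is here: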

import tensorflow as tf

@tf.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
    x = op.inputs[0] 
    return 1000.0 * x 

input = tf.Variable([4.0], dtype=tf.float32)
x = tf.constant(5.0)
g = tf.get_default_graph()

with g.gradient_override_map({"Multiply": "MyopGrad"}): 
  output1 = tf.multiply(input, x , name = 'multiply')
grad1 = tf.gradients(output1, input)

# same multiply without the gradient override, for comparison:
output1_ori = tf.multiply(input , x)
grad1_ori = tf.gradients(output1_ori, input)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print("with custom:", sess.run(grad1)[0])
  print("without custom:", sess.run(grad1_ori)[0])

Tags: python, tensorflow, gradient

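Solution

The TensorFlow op type that tf.multiply creates is Mul, not Multiply, so an override keyed on "Multiply" never matches the op. In addition, tf.multiply has two inputs, so its gradient function must return two values, one per input. Your code could look like this: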

import tensorflow as tf

@tf.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
    x = op.inputs[0]
    y = op.inputs[1]
    return 1000.0 * x, 1000.0 * y

input = tf.Variable([4.0], dtype=tf.float32)
x = tf.constant(5.0)
g = tf.get_default_graph()

with g.gradient_override_map({"Mul": "MyopGrad"}): 
  output1 = tf.multiply(input, x , name = 'multiply')
grad1 = tf.gradients(output1, input)

# same multiply without the gradient override, for comparison:
output1_ori = tf.multiply(input , x)
grad1_ori = tf.gradients(output1_ori, input)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print("with custom:", sess.run(grad1)[0])
  print("without custom:", sess.run(grad1_ori)[0])

Output:

with custom: [4000.]
without custom: [5.]
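
To see why the key matters, here is a toy model in plain Python of a lookup keyed by op type. It is only a sketch and not TensorFlow's implementation: DEFAULT_GRADS, CUSTOM_GRADS, lookup_gradient and grad_wrt_first_input are hypothetical names made up for illustration, and it assumes that an override key matching no op type is simply ignored.

```python
# Toy model only: every name here is hypothetical, none is a TensorFlow API.

DEFAULT_GRADS = {
    # d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the upstream gradient.
    "Mul": lambda inputs, grad: (grad * inputs[1], grad * inputs[0]),
}

CUSTOM_GRADS = {
    # Mirrors frop_grad above: ignores the upstream gradient and
    # returns one value per input.
    "MyopGrad": lambda inputs, grad: (1000.0 * inputs[0], 1000.0 * inputs[1]),
}


def lookup_gradient(op_type, override_map):
    """Pick the gradient function for op_type, honoring overrides keyed by op type."""
    custom_name = override_map.get(op_type)
    if custom_name is not None:
        return CUSTOM_GRADS[custom_name]
    # Assumption: a key that matches no op type has no effect.
    return DEFAULT_GRADS[op_type]


def grad_wrt_first_input(override_map, inputs=(4.0, 5.0)):
    # tf.multiply creates an op of type "Mul", so that is the key looked up.
    grad_fn = lookup_gradient("Mul", override_map)
    return grad_fn(inputs, 1.0)[0]


wrong_key = grad_wrt_first_input({"Multiply": "MyopGrad"})
right_key = grad_wrt_first_input({"Mul": "MyopGrad"})
print("override keyed by 'Multiply':", wrong_key)
print("override keyed by 'Mul':", right_key)
```

In the real graph you can read the op type from output1.op.type, which is a quick way to find the right key for gradient_override_map.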
