首页 > 解决方案 > 将自定义渐变定义为 Tensorflow 中的类方法

问题描述

我需要将一个方法定义为自定义渐变,如下所示:

class CustGradClass:

    def __init__(self):
        pass

    @tf.custom_gradient
    def f(self,x):
      fx = x
      def grad(dy):
        return dy * 1
      return fx, grad

我收到以下错误:

ValueError:尝试将具有不受支持的类型 ()的值(< main .CustGradClass object at 0x12ed91710>)转换为张量。

原因是自定义梯度接受函数f(*x),其中 x 是张量序列。传递的第一个参数是对象本身,即self

文档中:

f:函数 f(*x),它返回一个元组 (y, grad_fn) 其中:
x 是函数的张量输入序列。y 是将 f 中的 TensorFlow 操作应用于 x 的张量或张量输出序列。grad_fn 是一个签名为 g(*grad_ys) 的函数

我如何使它工作?我需要继承一些 python tensorflow 类吗?

我正在使用 tf 版本 1.12.0 和渴望模式。

标签: pythontensorflowgradientautodiff

解决方案


这是一种可能的简单解决方法:

import tensorflow as tf

class CustGradClass:

    def __init__(self):
        self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))

    @staticmethod
    def _f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.constant(1.0)
    c = CustGradClass()
    y = c.f(x)
    print(tf.gradients(y, x))
    # [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]

编辑:

如果你想在不同的类上多次这样做,或者只是想要一个更可重用的解决方案,你可以使用像这样的一些装饰器,例如:

import functools
import tensorflow as tf

def tf_custom_gradient_method(f):
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if not hasattr(self, '_tf_custom_gradient_wrappers'):
            self._tf_custom_gradient_wrappers = {}
        if f not in self._tf_custom_gradient_wrappers:
            self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
        return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
    return wrapped

然后你可以这样做:

class CustGradClass:

    def __init__(self):
        pass

    @tf_custom_gradient_method
    def f(self, x):
        fx = x * 1
        def grad(dy):
            return dy * 1
        return fx, grad

    @tf_custom_gradient_method
    def f2(self, x):
        fx = x * 2
        def grad(dy):
            return dy * 2
        return fx, grad

推荐阅读