python - 将自定义渐变定义为 Tensorflow 中的类方法
问题描述
我需要将一个方法定义为自定义渐变,如下所示:
class CustGradClass:
def __init__(self):
pass
@tf.custom_gradient
def f(self,x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
我收到以下错误:
ValueError:尝试将具有不受支持的类型 ()的值(< main .CustGradClass object at 0x12ed91710>)转换为张量。
原因是自定义梯度接受函数f(*x),其中 x 是张量序列。传递的第一个参数是对象本身,即self。
从文档中:
f:函数 f(*x),它返回一个元组 (y, grad_fn) 其中:
x 是函数的张量输入序列。y 是将 f 中的 TensorFlow 操作应用于 x 的张量或张量输出序列。grad_fn 是一个签名为 g(*grad_ys) 的函数
我如何使它工作?我需要继承一些 python tensorflow 类吗?
我正在使用 tf 版本 1.12.0 和渴望模式。
解决方案
这是一种可能的简单解决方法:
import tensorflow as tf
class CustGradClass:
def __init__(self):
self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))
@staticmethod
def _f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.constant(1.0)
c = CustGradClass()
y = c.f(x)
print(tf.gradients(y, x))
# [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
编辑:
如果你想在不同的类上多次这样做,或者只是想要一个更可重用的解决方案,你可以使用像这样的一些装饰器,例如:
import functools
import tensorflow as tf
def tf_custom_gradient_method(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if not hasattr(self, '_tf_custom_gradient_wrappers'):
self._tf_custom_gradient_wrappers = {}
if f not in self._tf_custom_gradient_wrappers:
self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
return wrapped
然后你可以这样做:
class CustGradClass:
def __init__(self):
pass
@tf_custom_gradient_method
def f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
@tf_custom_gradient_method
def f2(self, x):
fx = x * 2
def grad(dy):
return dy * 2
return fx, grad
推荐阅读
- php - 用户名隐藏 html + php
- python - POSTGRES:创建表不保存
- python - 为什么这个初学者 python 程序中途停止?
- python - TypeError: __init__() got multiple values for argument 'kernel_size' 在 tensorflow 2.0 中发生错误
- c++ - 将向量按降序排序并从前面添加有效,但按升序排序并从后面添加则不行吗?
- c - 为什么 gcc 会生成 ARM 程序集,它将全局数据的地址放在标签中并在每个函数中重复
- python-3.x - TypeError:“int”类型的参数在索引中不可迭代
- python - 寻找 SumKSmallest 的算法
- python - 如何为 SVM 机器学习算法转换字符串数据
- snowflake-cloud-data-platform - 雪花数据阅读器显示错误的仓库数量