python - 有没有办法使用 tensorflow 的 custom_gradient 装饰器定义 3D FFT 的梯度
问题描述
背景与问题
我正在使用tensorflow-probability
模块的 Hamiltonian Monte Carlo (HMC) 方法来探索自写概率函数的最可能状态。我试图拟合的参数包括真实三维场的傅里叶模式。
为了使 HMC 运行,每个计算块都需要实现其梯度。但是,逆实数 FFT 的实现tf.signal.irfft3d
默认没有关联梯度方法。
问题
有没有办法实现函数的梯度irfft3d
?我已经有一个运行的、自我实现irfft3d
的、带有更多基本块的tensorflow
自动微分似乎可以工作,但我想包装tf.signal.irfft3d
使用装饰器的实际优化和稳定的实现@tf.custom_gradient
来使自动微分工作。
我猜
傅里叶变换是线性的,这个问题在理论上是微不足道的。然而,在网格上写出傅里叶变换的雅可比行列在数值上是不可行的(因为它的维度会很大)。幸运的是,tensorflow
只需要一个在输入向量上评估雅可比行列式的函数。我相信这可以通过 FFT 算法有效地完成。不幸的是,在我看来,这tensorflow
需要一个函数来计算应用于“上游梯度”的雅可比行列式的反转,我不明白:
https://www.tensorflow.org/api_docs/python/tf/custom_gradient?version=nightly
函数 f(*x) 返回一个元组 (y, grad_fn) 其中:
- x 是函数的张量输入的(嵌套结构)序列。
- y 是将 f 中的 TensorFlow 操作应用于 x 的 Tensor 输出的(嵌套结构)。
- grad_fn 是一个带有签名 g(*grad_ys) 的函数,它返回一个与(扁平化)x 大小相同的张量列表——y 中的张量相对于 x 中的张量的导数。grad_ys 是与(扁平化)y 大小相同的张量序列,其中包含 y 中每个张量的初始值梯度。
在纯数学意义上,向量自变量向量值函数 f 的导数应该是它的雅可比矩阵 J。这里我们将雅可比 J 表示为函数 grad_fn,它定义了当与向量 grad_ys 左相乘时,J 将如何变换它( grad_ys * J,向量雅可比积或 VJP)。矩阵的这种函数表示便于用于链式规则计算(例如在反向传播算法中)。
遵守文档中给出的尺寸和格式,我无法想象任何其他解决方案:
#!/usr/bin/env python3
# set up
import tensorflow as tf
n = 64
noise = tf.random.normal((n, n, n))
modes = tf.signal.rfft3d(noise)
# the function
@tf.custom_gradient
def irfft3d(x):
def grad_fn(dy):
return (tf.signal.rfft3d(dy))
return (tf.signal.irfft3d(x), grad_fn)
# test
with tf.GradientTape() as gt:
gt.watch(modes)
rec_noise = irfft3d(modes)
dn_dm = gt.gradient(rec_noise, modes)
print(dn_dm)
哪个运行并返回:
tf.Tensor(
[[[262144.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j
0.+0.j]
[ 0.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j
0.+0.j]
[ 0.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j
0.+0.j]
...
[ 0.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j
0.+0.j]
[ 0.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j
0.+0.j]
[ 0.+0.j 0.+0.j 0.+0.j ... 0.+0.j 0.+0.j
0.+0.j]]], shape=(64, 64, 33), dtype=complex64)
我真的无法完全解决它。首先,这将是一个如此简单的解决方案,我不明白为什么它没有被本地实现。但更重要的是,我完全迷失了对tensorflow
这个自写梯度函数的期望,我无法用对我来说有意义的数学方式来表达它的结果。
有没有人了解tensorflow
处理差异化的方式并可以帮助或纠正我?
解决方案
推荐阅读
- vue.js - Highcharts-vue - 调用我自己的工具提示标签格式化程序函数
- c - 将 rand() 用于加密不安全的随机数是否可以接受?
- security - 去中心化的点对点身份验证
- php - 在 Laravel 中活跃的动态 li 类
- android - 在 Toast 之上的 Android Toast - 当顶部消失时底部的底部仍然存在
- sql - 如何解决 UPDATE 语句的条件?
- c++ - 如何忽略 gcc 5.x 中未使用的静态变量?
- react-redux - Redux - 限制一个动作所以不经常触发
- lua - 在 lua 中获取 lua 状态
- python - 删除 Pandas DataFrame 中的每 n 列