python - 如何从 Jax 中的联合累积密度函数计算联合概率密度函数?
问题描述
我在 python 中定义了一个联合累积密度函数作为 jax 数组的函数并返回单个值。就像是:
def cumulative(inputs: array) -> float:
...
要获得梯度,我知道我可以做到grad(cumulative)
,但这只是给我累积相对于输入变量的一阶偏导数。相反,我想做的是计算这个,假设 F 是我的函数, f 是联合概率密度函数:
偏导数的顺序无关紧要。
所以,我有几个问题:
- 如何在 Jax 中有效地计算这个?我想我不能只打电话给 grad n 次
- 一旦计算出结果函数,结果函数是否会比原始函数具有更高的调用复杂度(是否增加了 O(n),或者它是恒定的,还是其他什么)?
- 或者,如何仅针对输入数组的一个变量而不是整个数组计算单个偏导数?(我将重复此 n 次,每个变量一次)
解决方案
JAX 通常将渐变视为相对于单个参数,而不是参数中的元素。在这种情况下,一个与您想要做的类似(但不完全相同)的内置函数是jax.hessian
,它计算二阶导数的 hessian 矩阵;例如:
import jax
import jax.numpy as jnp
def f(x):
return jnp.prod(x ** 2)
x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
# [72. 18. 24.]
# [48. 24. 8.]]
对于数组的各个元素的高阶导数,我认为您必须手动嵌套渐变。您可以使用如下所示的辅助函数来执行此操作:
def grad_all(f):
def gradfun(x):
args = tuple(x)
f_args = lambda *args: f(jnp.array(args))
for i in range(len(args)):
f_args = jax.grad(f_args, argnums=i)
return f_args(*args)
return gradfun
print(grad_all(f)(x))
# 48.0
推荐阅读
- python-3.x - 如何在 Windows 子系统(Ubuntu 18.04)上运行 python tkinter?
- arrays - 在Apple Swift中将json转换为数组的字符串数组
- julia - Julia 1.2 下无法安装 Genie 框架
- arrays - 如何在laravel中将数组插入多行
- python - 重置 Python/Matplotlib 中绘图的默认字体/颜色
- python - sum 每次都可以被列表的一个元素整除
- c - C - 将错误的指针类型传递给函数
- c++ - 我不能在 C++ 上使用 fmt 库头文件
- java - 更新集合内的列表 dbref
- python - 如何在 Python 中使用 BeautifulSoup 从 html 中提取特定文本?