首页 > 解决方案 > 如何从 Jax 中的联合累积密度函数计算联合概率密度函数?

问题描述

我在 python 中定义了一个联合累积密度函数作为 jax 数组的函数并返回单个值。就像是:

def cumulative(inputs: array) -> float:
    ...

要获得梯度,我知道我可以做到grad(cumulative),但这只是给我累积相对于输入变量的一阶偏导数。相反,我想做的是计算这个,假设 F 是我的函数, f 是联合概率密度函数:

公式

偏导数的顺序无关紧要。

所以,我有几个问题:

标签: pythonprobability-densityprobability-distributionautomatic-differentiationjax

解决方案


JAX 通常将渐变视为相对于单个参数,而不是参数中的元素。在这种情况下,一个与您想要做的类似(但不完全相同)的内置函数是jax.hessian,它计算二阶导数的 hessian 矩阵;例如:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.prod(x ** 2)

x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
#  [72. 18. 24.]
#  [48. 24.  8.]]

对于数组的各个元素的高阶导数,我认为您必须手动嵌套渐变。您可以使用如下所示的辅助函数来执行此操作:

def grad_all(f):
  def gradfun(x):
    args = tuple(x)
    f_args = lambda *args: f(jnp.array(args))
    for i in range(len(args)):
      f_args = jax.grad(f_args, argnums=i)
    return f_args(*args)
  return gradfun

print(grad_all(f)(x))
# 48.0

推荐阅读