python - 使用 vmap 时,Jax 不支持不可散列的静态参数
问题描述
这与这个问题有关。经过一番工作,我设法将其更改为最后一个错误。代码现在看起来像这样。
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp
def p_tau(z, tau, alpha=1.5):
return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))
def get_tau(tau, tau_max, tau_min, z_value):
return lax.cond(z_value < 1,
lambda _: (tau, tau_min),
lambda _: (tau_max, tau),
operand=None
)
def body(kwargs, x):
tau_min = kwargs['tau_min']
tau_max = kwargs['tau_max']
z = kwargs['z']
alpha = kwargs['alpha']
tau = (tau_min + tau_max) / 2
z_value = p_tau(z, tau, alpha).sum()
taus = get_tau(tau, tau_max, tau_min, z_value)
tau_max, tau_min = taus[0], taus[1]
return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None
@jax.partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
z = (alpha - 1) * z_input
tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
length=T)
tau = (result['tau_max'] + result['tau_min']) / 2
result = p_tau(z, tau, alpha)
return result / result.sum()
@jax.partial(jax.jit, static_argnums=(1,3,))
def _entmax(input, axis=-1, alpha=1.5, T=20):
result = vmap(jax.partial(map_row, alpha, T), axis)(input)
return result
@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
return _entmax(input, axis, alpha, T)
@jax.partial(jax.jit, static_argnums=(0,2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
input = primals[0]
Y = entmax(input, axis, alpha, T)
gppr = Y ** (2 - alpha)
grad_output = tangents[0]
dX = grad_output * gppr
q = dX.sum(axis=axis) / gppr.sum(axis=axis)
q = jnp.expand_dims(q, axis=axis)
dX -= q * gppr
return Y, dX
@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
return _entmax_jvp_impl(axis, alpha, T, primals, tangents)
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()
def toy(input, weight):
return (weight*entmax(input, 0, 1.5, 20)).sum()
jax.jit(value_and_grad(toy))(input, weight)
这导致(我希望)是最终错误,即
Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.batching.BatchTracer'> for function map_row is non-hashable.
这很奇怪,因为我认为我已经标记了每一个地方axis
似乎都是静态的,但它仍然告诉我它是被追踪的。
解决方案
当您编写partial
带有位置参数的函数时,首先传递这些参数。所以这:
jax.partial(map_row, alpha, T)
本质上等同于:
lambda z_input: map_row(alpha, T, z_input)
请注意参数的错误顺序——这就是导致错误的原因:您正在将z_input
一个不可散列的跟踪器传递给一个预期为静态的参数。
partial
您可以通过将上面的语句替换为以下内容来解决此问题:
lambda z: map_row(z, alpha, T)
然后您的代码将正确运行。
推荐阅读
- python - python3 json响应的Windows CMD输出在黑色背景上打印黑色字体
- r-markdown - 如何使用 Markdown 使图像居中?
- java - 在模块描述符中是否有订购要求的约定?
- javascript - /*#__PURE__*/ 在一些javascript源代码中是什么意思?
- java - 创建、写入和读取同一个文件返回一个 0 数组
- azure - 部署/托管的 asp.net 核心中的 WebAssembly launchSettings.json 文件
- svelte - 如何在 SvelteKit 中以编程方式路由?
- javascript - 基于浏览器的脏克隆(创建一个新的脚本元素,然后根据 function/class.toString() 更新其 .text)是否存在安全风险?
- reactjs - 反应数据表值重复
- rest - 如果用户无权访问某些请求的资源,那么使用标识符作为查询参数的 GET 调用的 REST 响应应该是什么?