首页 > 解决方案 > 从函数中有效地填充数组

问题描述

我想以我可以利用的方式从函数构造一个二维数组jax.jit

我通常使用的方法numpy是创建一个空数组,然后就地填充该数组。

xx = jnp.empty((num_a, num_b))
yy = jnp.empty((num_a, num_b))
zz = jnp.empty((num_a, num_b))

for ii_a in range(num_a):
    for ii_b in range(num_b):
        a = aa[ii_a, ii_b]
        b = bb[ii_a, ii_b]

        xyz = self.get_coord(a, b)

        xx[ii_a, ii_b] = xyz[0]
        yy[ii_a, ii_b] = xyz[1]
        zz[ii_a, ii_b] = xyz[2]

为了使这项工作在jax我尝试使用jax.opt.index_update.

        xx = xx.at[ii_a, ii_b].set(xyz[0])
        yy = yy.at[ii_a, ii_b].set(xyz[1])
        zz = zz.at[ii_a, ii_b].set(xyz[2])

这运行没有错误,但是当我尝试使用@jax.jit装饰器时非常慢(至少比纯 python/numpy 版本慢一个数量级)。

从函数中填充多维数组的最佳方法是什么jax

标签: pythonnumpyjax

解决方案


JAX 具有专门为此类应用程序设计的vmap转换。

只要您的get_coords函数与 JAX 兼容(即是一个没有副作用的纯函数),您就可以在一行中完成此操作:

from jax import vmap
xx, yy, zz = vmap(vmap(get_coord))(aa, bb)

推荐阅读