python - 从函数中有效地填充数组
问题描述
我想以我可以利用的方式从函数构造一个二维数组jax.jit
。
我通常使用的方法numpy
是创建一个空数组,然后就地填充该数组。
xx = jnp.empty((num_a, num_b))
yy = jnp.empty((num_a, num_b))
zz = jnp.empty((num_a, num_b))
for ii_a in range(num_a):
for ii_b in range(num_b):
a = aa[ii_a, ii_b]
b = bb[ii_a, ii_b]
xyz = self.get_coord(a, b)
xx[ii_a, ii_b] = xyz[0]
yy[ii_a, ii_b] = xyz[1]
zz[ii_a, ii_b] = xyz[2]
为了使这项工作在jax
我尝试使用jax.opt.index_update
.
xx = xx.at[ii_a, ii_b].set(xyz[0])
yy = yy.at[ii_a, ii_b].set(xyz[1])
zz = zz.at[ii_a, ii_b].set(xyz[2])
这运行没有错误,但是当我尝试使用@jax.jit
装饰器时非常慢(至少比纯 python/numpy 版本慢一个数量级)。
从函数中填充多维数组的最佳方法是什么jax
?
解决方案
JAX 具有专门为此类应用程序设计的vmap
转换。
只要您的get_coords
函数与 JAX 兼容(即是一个没有副作用的纯函数),您就可以在一行中完成此操作:
from jax import vmap
xx, yy, zz = vmap(vmap(get_coord))(aa, bb)
推荐阅读
- node.js - mongodb查询事件完成到一个项目并将它们分组到一个数组中
- javascript - Javascript ajax如何只显示数组的值
- mysql - MySQL 中的小数点后两位,但在单词 SELECT 处显示错误
- python - 如何在烧瓶中看到预览图像
- google-sheets - 在数组中查找字符串并返回列的第一个值
- java - Intellij 无法加载 ModelBuilderService$Ex 类
- c++ - 如何从字符串中删除重复的单词,并且只用它们的字数显示一次
- java - 如何按升序对二维数组进行排序?
- playframework - 为大量并发 SSE 连接扩展 Play 2.7.x(使用 Akka HTTP 服务器)
- java - 如何管理 Dataflow Java 管道中的属性和秘密值?