python - 切片 jax.numpy 数组时性能下降
问题描述
在尝试对大型数组进行 SVD 压缩时,我在 Jax 中遇到了一些我不理解的行为。这是示例代码:
@jit
def jax_compress(L):
U, S, _ = jsc.linalg.svd(L,
full_matrices = False,
lapack_driver = 'gesvd',
check_finite=False,
overwrite_a=True)
maxS=jnp.max(S)
chi = jnp.sum(S/maxS>1E-1)
return chi, jnp.asarray(U)
在考虑这段代码时,Jax/jit 比 SciPy 提供了巨大的性能提升,但最终我想减少 U 的维数,我通过将它包装在函数中来做到这一点:
def jax_process(A):
chi, U = jax_compress(A)
return U[:,0:chi]
这一步在计算时间方面的成本令人难以置信,比 SciPy 的等价物更昂贵,从这个比较中可以看出:
sc_compress
并且sc_process
是上面 jax 代码的 SciPy 等价物。如您所见,在 SciPy 中对数组进行切片几乎不需要任何成本,但在应用于 hit 函数的输出时却非常昂贵。有人对这种行为有一些了解吗?
解决方案
我对 JAX 和 PyTorch 之间的切片速度进行了类似的比较。dynamic_slice
比普通切片快得多,但仍然比火炬中的同等切片要慢得多。由于我是 JAX 新手,我不确定原因是什么,但这可能与复制与引用有关,因为 JAX 数组是不可变的。
JAX(没有@jit)
key = random.PRNGKey(0)
j = random.normal(key, (32, 2, 1024, 1024, 3))
%timeit j[..., 100:600, 100:600, :].block_until_ready()
%timeit dynamic_slice(j, [0, 0, 100, 100, 0], [32, 2, 500, 500, 3]).block_until_ready()
2.78 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
993 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
PyTorch
t = torch.randn((32, 2, 1024, 1024, 3)).cuda()
%%timeit
t[..., 100:600, 100:600, :]
torch.cuda.synchronize()
7.63 µs ± 22.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
推荐阅读
- ios - .compact 样式的 UIDatePicker 不尊重内容拥抱优先级
- c# - 如何使用 Entity Framework Core 3.1 在一个事务中删除不超过 X 行
- mongodb - Bitnami mongodb cluster con't access from mongo3t client kubernetes
- failover - Ceph MDS 会在“up:replay”中停留数小时。MDS 故障转移需要 10-15 小时
- html - 在网格中悬停时显示整个截断的文本 - 有更好的解决方案吗?
- git - Git:如何提出仅包含文件子集(而不是提交)的选择性拉取请求
- mysql - 具有排名变量顺序的分数表并保持排名
- html - 在 Adobe Animate CC HTML5 画布中加载 SWF
- python - 如何将先前计算的总和添加到新输入中?
- go - 在 Go 二进制文件中打包 Vue 前端