python - 如何矢量化具有多索引访问的 numpy for 循环
问题描述
unigram
是一个数组形状(N, M, 100)
我想删除for
循环并执行所有计算。
seq
是 size 的一维数组,大小M
可能M
高达 10000。
我想删除 for 循环并将其矢量化以便于计算。
batch_size, seq_len, num_labels = unigram_scores.shape
broadcast = np.broadcast_to(seq, (batch_size, seq_len))
for i in range(0, broadcast.shape[1]):
n_seq[i] = unigram_scores[np.arange(batch_size), i , broadcast[:,i]]
编辑:@hpaulj 的回答完美无缺,并且还具有不必安装任何额外依赖项
的优势,速度比我预期的要低得多
我最终安装了 numba
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def calculate_unigram_probability(unigram_scores,seq):
batch_size, seq_len, num_labels = unigram_scores.shape
broadcast = np.broadcast_to(seq, (batch_size, seq_len))
for i in prange( broadcast.shape[1]):
n_seq[i] = unigram_scores[np.arange(batch_size), i , broadcast[:,i]]
return n_seq
这也需要一点时间,目前我正在尝试将它从 cpu 移动到 cuda,这应该会带来我希望的加速
解决方案
In [129]: N,M = 5,3
In [130]: unigram=np.arange(N*M*4).reshape(N,M,4)
In [131]: seq = np.arange(M)
In [132]: b_seq = np.broadcast_to(seq, (N,M))
对于单个i
:
In [133]: i=0; unigram[np.arange(N),i,b_seq[:,i]]
Out[133]: array([ 0, 12, 24, 36, 48])
对于i
范围内的所有人:
In [136]: i=np.arange(M)[:,None]
In [137]: unigram[np.arange(N),i,b_seq[:,i]]
Out[137]:
array([[[ 0, 12, 24, 36, 48],
[ 5, 17, 29, 41, 53],
[10, 22, 34, 46, 58]],
...
[[ 0, 12, 24, 36, 48],
[ 5, 17, 29, 41, 53],
[10, 22, 34, 46, 58]]])
一个 (5,3,5) 数组。这(5,3)可能会更好)
In [141]: i=np.arange(M); unigram[np.arange(N)[:,None],i,b_seq[:,i]]
Out[141]:
array([[ 0, 5, 10],
[12, 17, 22],
[24, 29, 34],
[36, 41, 46],
[48, 53, 58]])
我们不需要索引b_seq
:unigram[np.arange(N)[:,None],i,b_seq]
甚至使用;让索引broadcast
seq
:
unigram[np.arange(N)[:,None],i,seq]
并在以下帮助下ix_
:
In [145]: I,J=np.ix_(np.arange(N), np.arange(M))
In [146]: unigram[I,J,seq]
要直观地了解此索引的作用,请查看unigram
. 它从连续的块/批次中拉出“对角线”:
In [147]: unigram
Out[147]:
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]],
...
推荐阅读
- python - 分离 \t,但保留 'e+06' 或任何权力
- mysql - 使用 FastAPI 和 MYSQL 在 SQLAlchemy 中自动生成的主列(整数)的默认限制是多少?
- python - Python:在嵌套字典和 .csv 文件之间转换(通用)?
- fullcalendar - 使用私有 Google 日历的 FullCalendar 示例
- python - Prefect 中的循环任务
- angular - 从应用程序导入 .scss,中断库构建,Angular11
- animation - JavaFX图表:启用动画时动态数据会导致视轴故障
- javascript - 如何链接和使用外部 JS 文件并在 ejs 文件中使用它?
- reactive-programming - RxPy:使用可变计数进行缓冲
- r - 如何从值大于 4 的现有变量创建新变量