首页 > 解决方案 > 如何矢量化具有多索引访问的 numpy for 循环

问题描述

unigram是一个数组形状(N, M, 100)

我想删除for循环并执行所有计算。

seq是 size 的一维数组,大小M可能M高达 10000。

我想删除 for 循环并将其矢量化以便于计算。

batch_size, seq_len, num_labels = unigram_scores.shape
broadcast = np.broadcast_to(seq, (batch_size, seq_len))


for i in range(0, broadcast.shape[1]):
    n_seq[i] = unigram_scores[np.arange(batch_size), i , broadcast[:,i]]

编辑:@hpaulj 的回答完美无缺,并且还具有不必安装任何额外依赖项
的优势,速度比我预期的要低得多

我最终安装了 numba

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def calculate_unigram_probability(unigram_scores,seq):
    
    batch_size, seq_len, num_labels = unigram_scores.shape
    broadcast = np.broadcast_to(seq, (batch_size, seq_len))

    for i in prange( broadcast.shape[1]):
       n_seq[i] = unigram_scores[np.arange(batch_size), i , broadcast[:,i]]


    return n_seq

这也需要一点时间,目前我正在尝试将它从 cpu 移动到 cuda,这应该会带来我希望的加速

标签: pythonnumpy

解决方案


In [129]: N,M = 5,3
In [130]: unigram=np.arange(N*M*4).reshape(N,M,4)
In [131]: seq = np.arange(M)
In [132]: b_seq = np.broadcast_to(seq, (N,M))

对于单个i

In [133]: i=0; unigram[np.arange(N),i,b_seq[:,i]]
Out[133]: array([ 0, 12, 24, 36, 48])

对于i范围内的所有人:

In [136]: i=np.arange(M)[:,None]
In [137]: unigram[np.arange(N),i,b_seq[:,i]]
Out[137]: 
array([[[ 0, 12, 24, 36, 48],
        [ 5, 17, 29, 41, 53],
        [10, 22, 34, 46, 58]],

        ...
       [[ 0, 12, 24, 36, 48],
        [ 5, 17, 29, 41, 53],
        [10, 22, 34, 46, 58]]])

一个 (5,3,5) 数组。这(5,3)可能会更好)

In [141]: i=np.arange(M); unigram[np.arange(N)[:,None],i,b_seq[:,i]]
Out[141]: 
array([[ 0,  5, 10],
       [12, 17, 22],
       [24, 29, 34],
       [36, 41, 46],
       [48, 53, 58]])

我们不需要索引b_sequnigram[np.arange(N)[:,None],i,b_seq]

甚至使用;让索引broadcast seq

unigram[np.arange(N)[:,None],i,seq]

并在以下帮助下ix_

In [145]: I,J=np.ix_(np.arange(N), np.arange(M))
In [146]: unigram[I,J,seq]

要直观地了解此索引的作用,请查看unigram. 它从连续的块/批次中拉出“对角线”:

In [147]: unigram
Out[147]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]],
        ...

推荐阅读