首页 > 解决方案 > 矢量化 for 循环 - 需要平均不同大小的切片

问题描述

我正在尝试平均子词嵌入以形成词级表示。每个单词都有一个相应的开始和结束索引,指示哪些子词组成了该词。

sequence_output是 B * 3 * 2 的张量,其中 3 是最大序列长度,2 是特征数。

all_token_mapping是 B * 3 * 2 的张量,其中包含开始和结束索引。

initial_reps是 num_nodes * 2 的张量,num_nodes 是不同样本中所有词(不是子词)数量的总和。

sequence_output = torch.arange(2*3*2).float().reshape(2, 3, 2)
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]]])
all_token_mapping = torch.tensor([[[0,0],[1,2],[-1,-1]], [[0,2],[-1,-1],[-1,-1]]])
tensor([[[ 0,  0],
         [ 1,  2],
         [-1, -1]],

        [[ 0,  2],
         [-1, -1],
         [-1, -1]]])
num_nodes = 0
for sample in all_token_mapping:
  for mapping in sample:
    if mapping[0] != -1:
      num_nodes += 1
3
initial_reps = torch.empty((num_nodes, 2), dtype=torch.float32)
current_idx = 0
for i, feature_tokens_mapping in enumerate(all_token_mapping):
    for j, token_mapping in enumerate(feature_tokens_mapping):
        if token_mapping[0] == -1: # reached the end for this particular sequence
            break
        initial_reps[current_idx] = torch.mean(sequence_output[i][token_mapping[0]:token_mapping[-1] + 1], 0, keepdim=True)                                           
        current_idx += 1
initial_reps
tensor([[0., 1.],
        [3., 4.],
        [8., 9.]])

在上面的例子中,initial_reps[0] 将是 sequence_output[0][0:1] 的均值,initial_reps[1] 将是 sequence_output[0][1:3] 的均值,initial_reps[2] 将是sequence_output[1][0:3] 的平均值。

我当前的代码将创建一个长度为 num_nodes 的空张量,并且 for 循环将通过检查 token_mapping[0] 和 token_mapping[1] 来计算每个索引处的值,以获取要平均的 sequence_output 的正确切片。

有没有办法对这段代码进行矢量化?

此外,我有一个列表,其中包含每个样本的单词数。即列表中所有元素的总和 == num_nodes

标签: pythonnumpymachine-learningpytorchvectorization

解决方案


我会像我在下面的代码中显示的那样做。该案例非常简化,因此我可以向您展示输入和输出的示例。但是这个概念可以扩展到任何数组大小或维度。我将“768”更改为设置为 5 的变量“num_features”,并将源节点的数量从 384 个减少到 4 个。

进口火炬

B = 3
num_nodes0 = 4
num_nodes = 3
num_features = 5

sequence_output = torch.arange(B * num_nodes0 * num_features).float()
sequence_output = sequence_output.reshape(B, num_nodes0, num_features)

all_token_mapping = torch.randint(0, num_nodes0, (B, num_nodes))

idx0 = torch.arange(B).reshape(-1, 1).repeat(1, num_nodes).flatten().long()
idx1 = all_token_mapping.flatten().long()
initial_reps = sequence_output[idx0, idx1, :].reshape(B, num_nodes, num_features)
initial_reps = torch.mean(initial_reps, axis = 1)

输出:

sequence_output = 
tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]]])
all_token_mapping = 
tensor([[0, 2, 1],
        [2, 3, 0],
        [2, 0, 0]])
initial_reps = 
tensor([[ 5.0000,  6.0000,  7.0000,  8.0000,  9.0000],
        [28.3333, 29.3333, 30.3333, 31.3333, 32.3333],
        [43.3333, 44.3333, 45.3333, 46.3333, 47.3333]])

推荐阅读