How can I compute the products between two sets of features in PyTorch using a single loop?

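Question

I want to compute the products between two sets of feature matrices, X and Y, each of dimension (H, W, 12).

The inefficient way to do it would be: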

H = []
for i in range(12):
    for j in range(12):
        h = X[:,:,i]*Y[:,:,j]
        H.append(h)

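This yields an output H with dimensions (H, W, 144).

How can I do this in PyTorch without iterating over two loops?

I tried a tensordot-based solution, but could not replicate this behavior.

Tags: pytorch, tensordot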

Answer


I'm not sure this is the most efficient approach, but you could do something like this (warning: ugly code ahead =]):

import torch

# I choose not to use random -- easier to verify, IMO
a = torch.Tensor([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]])
b = torch.Tensor([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]])

c = torch.bmm(
    a.view(-1, a.size(-1), 1),
    b.view(-1, 1, b.size(-1))
).view(*(a.shape[:2]), -1)

print(c)

print(a.shape)
print(b.shape)
print(c.shape)

Output:

tensor([[[ 1.,  2.,  2.,  4.],
         [ 9., 12., 12., 16.],
         [25., 30., 30., 36.]],

        [[ 1.,  2.,  2.,  4.],
         [ 9., 12., 12., 16.],
         [25., 30., 30., 36.]]])

torch.Size([2, 3, 2])  # a
torch.Size([2, 3, 2])  # b
torch.Size([2, 3, 4])  # c

Basically, it is an outer product taken at every (H, W) position. Let me know if you need me to explain it.
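To make the outer-product reading concrete, here is a minimal sketch (not part of the original answer) that writes the same computation with broadcasting and with `torch.einsum`, then checks both against the double loop from the question. The names `c_broadcast`, `c_einsum`, and `c_loop` are illustrative only.

```python
import torch

# Same small, easy-to-check inputs as in the answer above
a = torch.tensor([[[1., 2.], [3., 4.], [5., 6.]],
                  [[1., 2.], [3., 4.], [5., 6.]]])
b = a.clone()

# Outer product at each (H, W) position via broadcasting:
# (H, W, C, 1) * (H, W, 1, C) -> (H, W, C, C), then flatten to (H, W, C*C)
c_broadcast = (a.unsqueeze(-1) * b.unsqueeze(-2)).flatten(start_dim=2)

# The same operation spelled out with einsum
c_einsum = torch.einsum('hwi,hwj->hwij', a, b).flatten(start_dim=2)

# Reference: the double loop from the question, stacked along the last axis
c_loop = torch.stack(
    [a[:, :, i] * b[:, :, j]
     for i in range(a.size(-1))
     for j in range(b.size(-1))],
    dim=-1,
)

print(torch.equal(c_broadcast, c_loop))
print(torch.equal(c_einsum, c_loop))
print(c_loop.shape)
```

Flattening the last two axes puts the product of channel i of `a` and channel j of `b` at index i * C + j, which is the same order the nested loop appends in.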


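Timing

torch.bmm used 16 of the 32 cores. I ran the CUDA version on a GeForce RTX 2080 Ti (GPU utilization was about 97% during execution). Note that the dimensions used for the GPU timings are 10x larger; otherwise it was too fast to measure.

Script: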

import timeit

setup = '''
import torch
a = torch.randn(({H}, {W}, 12))
b = torch.randn(({H}, {W}, 12))
'''

setup_cuda = setup.replace("))", ")).to(torch.device('cuda'))")

bmm = '''
c = torch.bmm(
    a.view(-1, a.size(-1), 1),
    b.view(-1, 1, b.size(-1))
).view(*(a.shape[:2]), -1)
'''

loop = '''
c = []
for i in range(a.size(-1)):
    for j in range(b.size(-1)):
        c.append(a[:, :, i] * b[:, :, j])
c = torch.stack(c).permute(1, 2, 0)
'''

min_dim = 10
max_dim = 100
num_repeats = 10

print('BMM')
for d in range(min_dim, max_dim+1, 10):
    print(d, min(timeit.Timer(bmm, setup=setup.format(H=d, W=d)).repeat(num_repeats, 1000)))

print('LOOP')
for d in range(min_dim, max_dim+1, 10):
    print(d, min(timeit.Timer(loop, setup=setup.format(H=d, W=d)).repeat(num_repeats, 1000)))

print('BMM - CUDA')
for d in range(min_dim*10, (max_dim*10)+1, 100):
    print(d, min(timeit.Timer(bmm, setup=setup_cuda.format(H=d, W=d)).repeat(num_repeats, 1000)))

Output (best of 10 repeats, in seconds per 1000 executions; the first column is H = W):

BMM
10 0.022082214010879397
20 0.034024904016405344
30 0.08957623899914324
40 0.1376199919031933
50 0.20248223491944373
60 0.2657837320584804
70 0.3533527449471876
80 0.42361779196653515
90 0.6103016039123759
100 0.7161333339754492

LOOP
10 1.7369094720343128
20 1.8517447559861466
30 1.9145489090587944
40 2.0530637570191175
50 2.2066439649788663
60 2.394576688995585
70 2.6210166650125757
80 2.9242434420157224
90 3.5709626079769805
100 5.413458575960249

BMM - CUDA
100 0.014253990724682808
200 0.015094103291630745
300 0.12792395427823067
400 0.307440347969532
500 0.541196970269084
600 0.8697826713323593
700 1.2538292426615953
800 1.6859236396849155
900 2.2016236428171396
1000 2.764942280948162
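
One caveat when reading GPU numbers in general: CUDA kernels are launched asynchronously, so a timer that stops without synchronizing may not include work that is still in flight. The sketch below is a hedged variant, not the script that produced the figures above. It calls `torch.cuda.synchronize()` around the timed region and falls back to the CPU when no GPU is available. The helper name `time_bmm` and its `repeats` parameter are hypothetical.

```python
import time
import torch

def time_bmm(h, w, c=12, repeats=100, device='cpu'):
    """Time `repeats` runs of the bmm-based outer product on `device`."""
    a = torch.randn(h, w, c, device=device)
    b = torch.randn(h, w, c, device=device)

    if device == 'cuda':
        torch.cuda.synchronize()  # make sure setup work has finished
    start = time.perf_counter()
    for _ in range(repeats):
        out = torch.bmm(a.view(-1, c, 1), b.view(-1, 1, c)).view(h, w, -1)
    if device == 'cuda':
        torch.cuda.synchronize()  # wait for all queued kernels to complete
    return time.perf_counter() - start, out.shape

# Assumption: use the GPU if one is present, otherwise run on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
elapsed, shape = time_bmm(100, 100, device=device)
print(device, elapsed, shape)
```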
