首页 > 解决方案 > 4 个维度上的张量均值

问题描述

我有一个 PyTorch 张量

x = torch.arange(1, 601)
x = x.reshape(20, 5, -1).float()

这看起来像

tensor([[[  1.,   2.,   3.,   4.,   5.,   6.],
         [  7.,   8.,   9.,  10.,  11.,  12.],
         [ 13.,  14.,  15.,  16.,  17.,  18.],
         [ 19.,  20.,  21.,  22.,  23.,  24.],
         [ 25.,  26.,  27.,  28.,  29.,  30.]],

        [[ 31.,  32.,  33.,  34.,  35.,  36.],
         [ 37.,  38.,  39.,  40.,  41.,  42.],
         [ 43.,  44.,  45.,  46.,  47.,  48.],
         [ 49.,  50.,  51.,  52.,  53.,  54.],
         [ 55.,  56.,  57.,  58.,  59.,  60.]],

        [[ 61.,  62.,  63.,  64.,  65.,  66.],
         [ 67.,  68.,  69.,  70.,  71.,  72.],
         [ 73.,  74.,  75.,  76.,  77.,  78.],
         [ 79.,  80.,  81.,  82.,  83.,  84.],
         [ 85.,  86.,  87.,  88.,  89.,  90.]],

我想在每个块中添加相同索引的每一行。意思是我想在每个第一行的轴 0 上求和:

[  1.,   2.,   3.,   4.,   5.,   6.]
[ 31.,  32.,  33.,  34.,  35.,  36.]
[ 61.,  62.,  63.,  64.,  65.,  66.]
:
:

然后每第二行在轴 0 上求和

[  7.,   8.,   9.,  10.,  11.,  12.]
[ 37.,  38.,  39.,  40.,  41.,  42.]
[ 67.,  68.,  69.,  70.,  71.,  72.]
:
:

等等

我将如何在 PyTorch 中做到这一点?

标签: pytorchtensor

解决方案


您可以通过展平最后两个轴、在展平的轴上拆分并将结果堆叠回单个张量来实现此目的。以下是步骤:

>>> x = torch.arange(1, 91, dtype=float).reshape(3, 5, -1)
tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.],
         [ 7.,  8.,  9., 10., 11., 12.],
         [13., 14., 15., 16., 17., 18.],
         [19., 20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29., 30.]],

        [[31., 32., 33., 34., 35., 36.],
         [37., 38., 39., 40., 41., 42.],
         [43., 44., 45., 46., 47., 48.],
         [49., 50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59., 60.]],

        [[61., 62., 63., 64., 65., 66.],
         [67., 68., 69., 70., 71., 72.],
         [73., 74., 75., 76., 77., 78.],
         [79., 80., 81., 82., 83., 84.],
         [85., 86., 87., 88., 89., 90.]]], dtype=torch.float64)

首先展axis=1平和axis=2

>>> x.flatten(1)
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
        [31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60.],
        [61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90.]], dtype=torch.float64)

我们现在正在寻找拆分为(x.size(2),)*x.size(1),这里它对应于(6, 6, 6, 6, 6)

>>> x.flatten(1).split((x.size(2),)*x.size(1), dim=1)

上面将返回一个包含拆分的元组,以重建张量使用torch.stack

>>> torch.stack(x.flatten(1).split((x.size(2),)*x.size(1), dim=1))
tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.],
         [31., 32., 33., 34., 35., 36.],
         [61., 62., 63., 64., 65., 66.]],

        [[ 7.,  8.,  9., 10., 11., 12.],
         [37., 38., 39., 40., 41., 42.],
         [67., 68., 69., 70., 71., 72.]],

        [[13., 14., 15., 16., 17., 18.],
         [43., 44., 45., 46., 47., 48.],
         [73., 74., 75., 76., 77., 78.]],

        [[19., 20., 21., 22., 23., 24.],
         [49., 50., 51., 52., 53., 54.],
         [79., 80., 81., 82., 83., 84.]],

        [[25., 26., 27., 28., 29., 30.],
         [55., 56., 57., 58., 59., 60.],
         [85., 86., 87., 88., 89., 90.]]], dtype=torch.float64)

推荐阅读