pytorch - 4 个维度上的张量均值
问题描述
我有一个 PyTorch 张量
x = torch.arange(1, 601)
x = x.reshape(20, 5, -1).float()
这看起来像
tensor([[[ 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12.],
[ 13., 14., 15., 16., 17., 18.],
[ 19., 20., 21., 22., 23., 24.],
[ 25., 26., 27., 28., 29., 30.]],
[[ 31., 32., 33., 34., 35., 36.],
[ 37., 38., 39., 40., 41., 42.],
[ 43., 44., 45., 46., 47., 48.],
[ 49., 50., 51., 52., 53., 54.],
[ 55., 56., 57., 58., 59., 60.]],
[[ 61., 62., 63., 64., 65., 66.],
[ 67., 68., 69., 70., 71., 72.],
[ 73., 74., 75., 76., 77., 78.],
[ 79., 80., 81., 82., 83., 84.],
[ 85., 86., 87., 88., 89., 90.]],
我想在每个块中添加相同索引的每一行。意思是我想在每个第一行的轴 0 上求和:
[ 1., 2., 3., 4., 5., 6.]
[ 31., 32., 33., 34., 35., 36.]
[ 61., 62., 63., 64., 65., 66.]
:
:
然后每第二行在轴 0 上求和
[ 7., 8., 9., 10., 11., 12.]
[ 37., 38., 39., 40., 41., 42.]
[ 67., 68., 69., 70., 71., 72.]
:
:
等等
我将如何在 PyTorch 中做到这一点?
解决方案
您可以通过展平最后两个轴、在展平的轴上拆分并将结果堆叠回单个张量来实现此目的。以下是步骤:
>>> x = torch.arange(1, 91, dtype=float).reshape(3, 5, -1)
tensor([[[ 1., 2., 3., 4., 5., 6.],
[ 7., 8., 9., 10., 11., 12.],
[13., 14., 15., 16., 17., 18.],
[19., 20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29., 30.]],
[[31., 32., 33., 34., 35., 36.],
[37., 38., 39., 40., 41., 42.],
[43., 44., 45., 46., 47., 48.],
[49., 50., 51., 52., 53., 54.],
[55., 56., 57., 58., 59., 60.]],
[[61., 62., 63., 64., 65., 66.],
[67., 68., 69., 70., 71., 72.],
[73., 74., 75., 76., 77., 78.],
[79., 80., 81., 82., 83., 84.],
[85., 86., 87., 88., 89., 90.]]], dtype=torch.float64)
首先展axis=1
平和axis=2
:
>>> x.flatten(1)
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60.],
[61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90.]], dtype=torch.float64)
我们现在正在寻找拆分为(x.size(2),)*x.size(1)
,这里它对应于(6, 6, 6, 6, 6)
。
>>> x.flatten(1).split((x.size(2),)*x.size(1), dim=1)
上面将返回一个包含拆分的元组,以重建张量使用torch.stack
:
>>> torch.stack(x.flatten(1).split((x.size(2),)*x.size(1), dim=1))
tensor([[[ 1., 2., 3., 4., 5., 6.],
[31., 32., 33., 34., 35., 36.],
[61., 62., 63., 64., 65., 66.]],
[[ 7., 8., 9., 10., 11., 12.],
[37., 38., 39., 40., 41., 42.],
[67., 68., 69., 70., 71., 72.]],
[[13., 14., 15., 16., 17., 18.],
[43., 44., 45., 46., 47., 48.],
[73., 74., 75., 76., 77., 78.]],
[[19., 20., 21., 22., 23., 24.],
[49., 50., 51., 52., 53., 54.],
[79., 80., 81., 82., 83., 84.]],
[[25., 26., 27., 28., 29., 30.],
[55., 56., 57., 58., 59., 60.],
[85., 86., 87., 88., 89., 90.]]], dtype=torch.float64)
推荐阅读
- wordpress - 使用联系表 7 更新用户元数据
- javascript - 从日期中找出月份的周数
- unity3d - 统一光子视图
- c# - 从对象(类)c#创建xml文件
- playback - 电子书中嵌入的 Spotify 播放功能
- react-native-windows - 有没有办法测试在 windows 操作系统中使用 react native 开发的 ios 应用程序?
- c++ - 如何在类外使用公共成员变量?
- java - 数据流:发布订阅消息的字符串
- android - 如何通过在新类中移动代码(Kotlin)来避免 Android 中的意大利面条代码
- linux - NginX 反向代理背后的 Apache 开发的 WebDAV