首页 > 解决方案 > 通过连接从元组的多个元素创建火炬张量的简单方法

问题描述

输入

我的清单如下

r1 = [([[[1, 2, 3], [1, 2, 3]], 
        [[4, 5, 6], [4, 5, 6]]],
       [[[7, 8], [7, 8]], 
        [[9, 10], [9, 10]]]),

      ([[[11, 12, 13], [11, 12, 13]], 
        [[14, 15, 16], [14, 15, 16]]],
       [[[17, 18], [17, 18]], 
        [[19, 20], [19, 20]]])]

我将从上面的输入中制作 2 个火炬张量。

我想要的输出如下

输出

output = 
[tensor([[[ 1,  2,  3],
          [ 1,  2,  3]],
 
         [[ 4,  5,  6],
          [ 4,  5,  6]],
 
         [[11, 12, 13],
          [11, 12, 13]],
 
         [[14, 15, 16],
          [14, 15, 16]]]), 

 tensor([[[ 7,  8],
          [ 7,  8]],
 
         [[ 9, 10],
          [ 9, 10]],
 
         [[17, 18],
          [17, 18]],
 
         [[19, 20],
          [19, 20]]])]

我的代码如下。

output = []
for i in range(len(r1[0])):
    templates = []
    for j in range(len(r1)):
        templates.append(torch.tensor(r1[j][i]))
        template = torch.cat(templates)
    output.append(template)

有没有更简单或更容易的方法来获得我想要的结果?

标签: pythonlistpytorchtensor

解决方案


这将做:

output = [torch.Tensor([*a, *b]) for a, b in zip(*r1)]

它首先连接两个列表的相应项目,然后创建张量


推荐阅读