首页 > 解决方案 > 在 pytorch 中,是否有内置方法来提取具有给定索引的行?

问题描述

假设我有一个火炬张量

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])

和一份清单

b = [0,2]

是否有内置方法来提取第 0 行和第 2 行并将它们放入新的张量中:

tensor([[1,2,3],
        [7,8,9]])

特别是,是否有一个看起来像这样的函数:

extract_rows(a,b) -> c

其中c包含所需的行。当然,这可以通过 for 循环来完成,但内置方法通常更快。

请注意,示例只是一个示例,列表中可能有几十个索引,张量中可能有数百行。

标签: pythonpytorchtensor

解决方案


看看 torch 内置的index_select()方法。这对你会有帮助。或者您可以使用切片来执行此操作。

tensor = [[1,2,3],
            [4,5,6],
            [7,8,9]]

new_tensor = tensor[0::2]
print(new_tensor)

输出:

[[1, 2, 3], [7, 8, 9]]

推荐阅读