首页 > 解决方案 > 具有不同维度索引数组的索引pytorch张量

问题描述

我有以下功能,它可以做我想要使用的功能,但是由于索引错误而numpy.array在喂食时会中断。torch.Tensor

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
        num_dims = len(output.shape)
        print(arr[mesh].shape)  # <-- This is wrong/different to numpy!
        output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
        return output

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

问题在于,这arr[mesh]会导致不同的维度,具体取决于 numpy 和 torch。显然,pytorch 不支持使用与被索引的数组不同维度的索引数组进行索引。理想情况下,以下应该有效:

features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())

但尺寸不同:

(2, 3, 3, 3)
torch.Size([3, 3])

这会导致错误(逻辑上):

Traceback (most recent call last):
  File "/home/XXX/util.py", line 226, in <module>
    torch_combs = combination_matrix(features)
  File "/home/XXX/util.py", line 218, in combination_matrix
    return torch_combination_matrix()
  File "/home/XXX/util.py", line 212, in torch_combination_matrix
    output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute

如何将火炬行为与 numpy 匹配?我已经阅读了火炬论坛上的各种问题(例如这个只有一维的问题),但可以在这里找到如何应用它。同样,index_select仅适用于一维,但我需要它至少适用于二维。

标签: pythonnumpyindexingpytorch

解决方案


这实际上很容易令人尴尬。您只需要展平索引,然后重塑和排列尺寸。这是完整的工作版本:

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output_shape = (2, len(arr), len(arr), *arr.shape[1:])  # Note that this is different to numpy!
        return arr[mesh.flatten()].reshape(output_shape).permute(2, 1, 0, *range(3, len(output_shape)))

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

我使用 pytest 在不同维度的随机数组上运行它,它似乎在所有情况下都有效:

import pytest

@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
    dim_size = np.random.randint(1, 40, size=random_dims)
    elements = np.random.random(size=dim_size)
    np_combs = combination_matrix(elements)
    features = torch.from_numpy(elements)
    torch_combs = combination_matrix(features)

    assert np.array_equal(np_combs, torch_combs.numpy())

if __name__ == '__main__':
    pytest.main(['-x', __file__])

推荐阅读