python - 具有不同维度索引数组的索引pytorch张量
问题描述
我有以下功能,它可以做我想要使用的功能,但是由于索引错误而numpy.array
在喂食时会中断。torch.Tensor
import torch
import numpy as np
def combination_matrix(arr):
idxs = np.arange(len(arr))
idx = np.ix_(idxs, idxs)
mesh = np.stack(np.meshgrid(idxs, idxs))
def np_combination_matrix():
output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
num_dims = len(output.shape)
output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
return output
def torch_combination_matrix():
output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
num_dims = len(output.shape)
print(arr[mesh].shape) # <-- This is wrong/different to numpy!
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
return output
if isinstance(arr, np.ndarray):
return np_combination_matrix()
elif isinstance(arr, torch.Tensor):
return torch_combination_matrix()
问题在于,这arr[mesh]
会导致不同的维度,具体取决于 numpy 和 torch。显然,pytorch 不支持使用与被索引的数组不同维度的索引数组进行索引。理想情况下,以下应该有效:
features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())
但尺寸不同:
(2, 3, 3, 3)
torch.Size([3, 3])
这会导致错误(逻辑上):
Traceback (most recent call last):
File "/home/XXX/util.py", line 226, in <module>
torch_combs = combination_matrix(features)
File "/home/XXX/util.py", line 218, in combination_matrix
return torch_combination_matrix()
File "/home/XXX/util.py", line 212, in torch_combination_matrix
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute
如何将火炬行为与 numpy 匹配?我已经阅读了火炬论坛上的各种问题(例如这个只有一维的问题),但可以在这里找到如何应用它。同样,index_select仅适用于一维,但我需要它至少适用于二维。
解决方案
这实际上很容易令人尴尬。您只需要展平索引,然后重塑和排列尺寸。这是完整的工作版本:
import torch
import numpy as np
def combination_matrix(arr):
idxs = np.arange(len(arr))
idx = np.ix_(idxs, idxs)
mesh = np.stack(np.meshgrid(idxs, idxs))
def np_combination_matrix():
output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
num_dims = len(output.shape)
output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
return output
def torch_combination_matrix():
output_shape = (2, len(arr), len(arr), *arr.shape[1:]) # Note that this is different to numpy!
return arr[mesh.flatten()].reshape(output_shape).permute(2, 1, 0, *range(3, len(output_shape)))
if isinstance(arr, np.ndarray):
return np_combination_matrix()
elif isinstance(arr, torch.Tensor):
return torch_combination_matrix()
我使用 pytest 在不同维度的随机数组上运行它,它似乎在所有情况下都有效:
import pytest
@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
dim_size = np.random.randint(1, 40, size=random_dims)
elements = np.random.random(size=dim_size)
np_combs = combination_matrix(elements)
features = torch.from_numpy(elements)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())
if __name__ == '__main__':
pytest.main(['-x', __file__])
推荐阅读
- go - 我应该在 Go 中封装切片和地图吗?如果是这样,该怎么做?
- cloud-foundry - 使用 docker 镜像在 Cloud Foundry 中推送 Fortio
- python - 为什么我调用 task_done() 后队列仍然加入?
- ssl - 如何修复 matplotlib 安装错误
- typescript - 如何在 JSX 中扩展 HTMLButtonAttribute 类型
- sql - 在不违反主键和唯一约束的情况下更改表结构
- regex - 哪个更好用:regex.containsMatchIn(String) 或 String.contains(regex),为什么?
- ios - 如何使用 swift 将嵌套字典作为 JSON 正文发布
- java - openjdk docker基础上.deb文件的Java依赖失败
- xml - 如何将 laravel 中的完整 XML 文件从公共文件夹复制到存储文件夹