首页 > 解决方案 > 给定索引ndarray和标志ndarray的任何numpy/torch样式设置值?

问题描述

我正在研究 PyTorch,目前我遇到了一个问题,我不知道如何以 Torch/numpy 风格解决它。例如,假设我有三个 PyTorch 张量

import torch
import numpy as np

indices = torch.from_numpy(np.array([[2, 1, 3, 0], [1, 0, 3, 2]]))
flags = torch.from_numpy(np.array([[False, False, False, True], [False, False, True, True]]))
tensor = torch.from_numpy(np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]]))

这是一个布尔标志张量,用于显示应提取flags的元素。indices给定提取的索引,我想将相应的元素设置tensor为指定的常量(比如 1e-30)。基于上面显示的示例,我想要

>>> sub_indices = indices.op1(flags)
>>> sub_indices
tensor([[0], [3, 2]])
>>> tensor.op2(sub_indices, 1e-30)
>>> tensor
tensor([[1e-30, 0.5, 1.2, 0.9], [3.1, 2.8, 1e-30, 1e-30]])

任何人都可以帮助提供解决方案吗?我正在使用列表理解,但我认为这种方式有点难看。我试过indices[flags]了,但它只返回一个一维数组[0, 3, 2],所以应用它会改变同一列 0、2、3 上的所有行

一些补充说明:

为了方便复制粘贴,下面是示例代码的 numpy 版本。我怀疑这是否可以以纯粹的 numpy 方式完成

import numpy as np

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

标签: pythonnumpypytorch

解决方案


您可以flags根据排序indices创建一个mask,然后将mask用作多路复用器。这是一个示例代码:

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

indices_sorted = indices.argsort(axis=1)
mask = np.take_along_axis(flags, indices_sorted, axis=1)
result = tensor * (1 - mask) + 1e-30 * mask

我对 pytorch 不太熟悉,但我想收集一个参差不齐的张量并不是一个好主意。虽然,即使在最坏的情况下,您也可以转换为/从 numpy 数组。


推荐阅读