python - PyTorch 和 Numpy 中的张量入口选择逻辑发散
问题描述
描述
我正在设置一个torch.Tensor
用于掩蔽目的。当尝试通过索引选择条目时,使用numpy.ndarray
和torch.Tensor
保存索引数据之间的行为是不同的。我希望能够访问解释差异的框架和相关文档中的设计。
复制步骤
环境
官方发布的容器中的 Pytorch 1.3:pytorch/pytorch:1.3-cuda10.1-cudnn7-devel
例子
假设我需要设置mask
为torch.Tensor
具有形状的对象[3,3,3]
并在条目(0,0,1)
&处设置值(1,2,0)
to 1
。下面的代码解释了差异。
mask = torch.zeros([3,3,3])
indices = torch.tensor([[0, 1],
[0, 2],
[1, 0]])
mask[indices.numpy()] = 1 # Works
# mask[indices] = 1 # Incorrect result
我注意到,当使用mask[indices.numpy()]
new torch.Tensor
of shape[2]
时,whilemask[indices]
返回一个 new torch.Tensor
of shape [3, 2, 3, 3]
,这表明张量切片逻辑存在差异。
解决方案
您会得到不同的结果,因为这就是在 Pytorch 中实现索引的方式。如果您将数组作为索引传递,那么它会被“解包”。例如:
indices = torch.tensor([[0, 1], [0, 2], [1, 0]])
mask = torch.arange(1,28).reshape(3,3,3)
# tensor([[[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9]],
# [[10, 11, 12],
# [13, 14, 15],
# [16, 17, 18]],
# [[19, 20, 21],
# [22, 23, 24],
# [25, 26, 27]]])
mask[indices.numpy()]
等价于,即第mask[[0, 1], [0, 2], [1, 0]]
i 行的indices.numpy()
元素用于选择mask
沿第 i 轴的元素。所以它返回tensor([mask[0,0,1], mask[1,2,0]])
,即tensor([2, 16])
。
另一方面,当将张量作为索引传递时(我不知道数组和张量用于索引的这种区别的确切原因),它不像数组那样“解包”,并且第 i 行的元素张量indices
的 用于选择mask
沿轴 0 的元素。也就是说,mask[indices]
相当于mask[[[0, 1], [0, 2], [1, 0]], :, :]
>>> mask[ind]
tensor([[[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]]],
[[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[19, 20, 21],
[22, 23, 24],
[25, 26, 27]]],
[[[10, 11, 12],
[13, 14, 15],
[16, 17, 18]],
[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]]]])
这基本上是tensor(mask[[0,1], :, :], mask[[0,2],: ,:], mask[[1,0], :, :])
有形状indices.shape + mask[0,:,:].shape == (3,2,3,3)
的。因此,整个“工作表”被选择并堆叠成新的维度。请注意,这不是一个新的张量,而是 的一个特殊视图mask
。因此,如果您指定mask[indices] = 1
, 带有这个 special indices
,那么 的所有元素都mask
将变为 1。
推荐阅读
- python - Python、英特尔 Python 和多核处理器
- python-3.x - 仅将可见点写入重叠散点图的磁盘
- python - 在 Pandas 中随着时间的推移绘制带有另一列标签的列
- python - 从 numpy 理解 _r
- javascript - 单个 html 元素的滚动指示器
- swift - 如何使用完成按钮确认我在 UIPickerView 中的选择?
- spring-boot - SpringBoot自定义ConstraintValidator没有被触发
- javascript - 按目录获取 azure 容器 blob 列表
- r - 在 R 中添加列,其值基于其他列的变量“名称”
- node.js - 网络请求失败 - Expo Fetch - React Native