python - 给定索引ndarray和标志ndarray的任何numpy/torch样式设置值?
问题描述
我正在研究 PyTorch,目前我遇到了一个问题,我不知道如何以 Torch/numpy 风格解决它。例如,假设我有三个 PyTorch 张量
import torch
import numpy as np
indices = torch.from_numpy(np.array([[2, 1, 3, 0], [1, 0, 3, 2]]))
flags = torch.from_numpy(np.array([[False, False, False, True], [False, False, True, True]]))
tensor = torch.from_numpy(np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]]))
这是一个布尔标志张量,用于显示应提取flags
的元素。indices
给定提取的索引,我想将相应的元素设置tensor
为指定的常量(比如 1e-30)。基于上面显示的示例,我想要
>>> sub_indices = indices.op1(flags)
>>> sub_indices
tensor([[0], [3, 2]])
>>> tensor.op2(sub_indices, 1e-30)
>>> tensor
tensor([[1e-30, 0.5, 1.2, 0.9], [3.1, 2.8, 1e-30, 1e-30]])
任何人都可以帮助提供解决方案吗?我正在使用列表理解,但我认为这种方式有点难看。我试过indices[flags]
了,但它只返回一个一维数组[0, 3, 2]
,所以应用它会改变同一列 0、2、3 上的所有行
一些补充说明:
flags
无法确定每一行的“真”值的数量- 的每一行都
indices
保证是序列的一个排列0 ... N - 1
为了方便复制粘贴,下面是示例代码的 numpy 版本。我怀疑这是否可以以纯粹的 numpy 方式完成
import numpy as np
indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
解决方案
您可以flags
根据排序indices
创建一个mask
,然后将mask
用作多路复用器。这是一个示例代码:
indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
indices_sorted = indices.argsort(axis=1)
mask = np.take_along_axis(flags, indices_sorted, axis=1)
result = tensor * (1 - mask) + 1e-30 * mask
我对 pytorch 不太熟悉,但我想收集一个参差不齐的张量并不是一个好主意。虽然,即使在最坏的情况下,您也可以转换为/从 numpy 数组。
推荐阅读
- java - 如何在 Java Swing 中找到最后按下的按钮
- docker - 由 Dockerfile 或 Docker-compose 文件创建的 Docker 卷
- reactjs - 反应,在选择选项上获取数据属性
- reactjs - React Native Firebase 的侦听器会优化 Auth 和 FireStore 查询,还是建议创建自定义缓存?
- verilog - 连续赋值中的非法赋值表达式。Verilog 八进制 2 到 1 多路复用器
- android - 中国手机上的Android硬重置
- python - 从泛欧交易所网站(欧洲证券交易所)导入“上市股票”
- php - finfo_file():CodeIgniter 4.0.4 中的空文件名或路径
- azure - 如何使用 ansible 在一个循环中挂载多个磁盘
- php - 如何制作(用于 VS 代码的 PHP 工具)不制作(emmet.excludeLanguages)到 PHP 语言