pytorch - 如何在索引处添加到 pytorch 张量?
问题描述
我不得不承认,我对 scatter* 和 index* 操作有点困惑——我不确定它们中的任何一个都完全符合我的要求,这很简单:
给定一些二维张量
z = tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
以及二维索引的列表(或张量?):
inds = tensor([[0, 0],
[1, 1],
[1, 2]])
我想在这些索引处向 z 添加一个标量(并有效地做到这一点):
znew = z.something_add(inds, 3)
->
znew = tensor([[4., 1., 1., 1.],
[1., 4., 4., 1.],
[1., 1., 1., 1.]])
如果必须,我可以使该标量成为任何形状的张量(所有元素 = 3),但我宁愿不...
解决方案
您必须为索引提供两个列表。第一个具有行位置,第二个具有列位置。在您的示例中,它将是:
z[[0, 1, 1], [0, 1, 2]] += 3
torch.Tensor 索引遵循 Numpy。有关更多详细信息,请参阅https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing。
推荐阅读
- python - 网状:AttributeError:模块'importlib'没有属性'util'
- r - 如何使用 BaseR 确定图例的形式?
- python - 数据框字符串列中的字符匹配和替换
- sql - SQL 非重复连接到使用空值作为通配符的表
- linux - 遍历bash中的范围
- wordpress - Wordpress 帖子元查询过滤器
- python - python中的单元测试-如何测试使用`read_sql_query`返回的数据帧中的数据类型?
- jpa - Spring命名+jpa方法
- python - Django:为学生和导师创建模型的最佳方式是什么,并且他们都有不同的配置文件?
- google-tag-manager - 跟踪代码管理器权限中的多个 URI 模式?