python - 基于旧张量和二维索引的新张量
问题描述
我之前问过: PyTorch 张量:基于旧张量和索引的新张量
我现在有同样的问题,但需要使用 2d 索引张量。
我有一个大小为 [batch_size, k] 的张量 col,其值介于 0 和 k-1 之间:
idx = tensor([[0,1,2,0],
[0,3,2,2],
...])
和以下矩阵:
x = tensor([[[0, 9],
[1, 8],
[2, 3],
[4, 9]],
[[0, 0],
[1, 2],
[3, 4],
[5, 6]]])
我想创建一个新的张量,其中包含索引中指定的行,按该顺序。所以我想要:
tensor([[[0, 9],
[1, 8],
[2, 3],
[0, 9]],
[[0, 0],
[5, 6],
[3, 4],
[3, 4]]])
目前我正在这样做:
for i, batch in enumerate(t):
t[i] = batch[col[i]]
我怎样才能更有效地做到这一点?
解决方案
你应该使用火炬收集来实现这一点。它实际上也适用于您链接的其他问题,但这留给读者作为练习:p
让我们称idx
您的第一个张量和source
第二个张量。它们各自的尺寸是(B,N)
和(B, K, p)
(p=2
在你的例子中),所有的值idx
都在0
和之间K-1
。
所以要使用torch gather,我们首先需要将您的操作表示为嵌套的for循环。就您而言,您真正想要实现的是:
for b in range(B):
for i in range(N):
for j in range(p):
# This kind of nested for loops is what torch.gether actually does
target[b,i,j] = source[b, idx[b,i,j], j]
但这不起作用,因为idx
它是 2D 张量,而不是 3D 张量。好吧,没什么大不了的,让我们把它变成一个 3D 张量。我们希望它具有形状(B, N, p)
并且沿最后一个维度实际上是恒定的。然后我们可以用调用来替换 for 循环gather
:
reshaped_idx = idx.unsqueeze(-1).repeat(1,1,2)
target = source.gather(1, reshaped_idx)
# or : target = torch.gather(source, 1, reshaped_idx)
推荐阅读
- django - django SimpleListFilter中parameter_name的目的是什么?
- sql-server - 如何在从 docker-compose 调用的 .sh 文件中使用 .NET Core 机密
- c# - 如何创建其构造函数需要委托函数参数的泛型类型实例?
- r - R - 过滤数据中的负峰值和正峰值
- node.js - VS Code 中的终端无法执行节点命令行
- javascript - React Native - 来自 JSON 的地图数据可以显示为文本,但不能显示为地图标记坐标
- reactjs - 我想创建反应应用程序,但我不能。任何帮助将不胜感激
- r - 在 Shiny 中创建一个带有增量的 valueBox?
- postgresql - Postgres 持续显示高 CPU 使用率
- reporting-services - 同步 SSRS 中图表的系列组颜色