python - PyTorch - 在 torch.sort 之后取回原始张量顺序的更好方法
问题描述
我想在操作和对排序张量进行一些其他修改后取回原始张量顺序torch.sort
,以便不再对张量进行排序。最好用一个例子来解释这一点:
x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices)
# final must be equal to torch.tanh(x)
我以这种方式实现了该功能:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
for i in range(ordered.size(0)):
z[indices[i]] = ordered[i]
return z
有一个更好的方法吗?特别是,是否可以避免循环并更有效地计算操作?
在我的情况下,我有一个大小的张量,我通过一次调用单独torch.Size([B, N])
对每一行进行排序。所以,我必须用另一个循环来调用时间。B
torch.sort
original_order
B
任何,更多pytorch-ic,想法?
编辑 1 - 摆脱内循环
我通过以这种方式简单地用索引索引 z 解决了部分问题:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
z[indices] = ordered
return z
现在,我只需要了解如何避免B
维度上的外循环。
编辑 2 - 摆脱外循环
def original_order(ordered, indices, batch_size):
# produce a vector to shift indices by lenght of the vector
# times the batch position
add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)
indices = indices + add.long().view(-1,1)
# reduce tensor to single dimension.
# Now the indices take in consideration the new length
long_ordered = ordered.view(-1)
long_indices = indices.view(-1)
# we are in the previous case with one dimensional vector
z = torch.zeros_like(long_ordered).float()
z[long_indices] = long_ordered
# reshape to get back to the correct dimension
return z.view(batch_size, -1)
解决方案
def original_order(ordered, indices):
return ordered.gather(1, indices.argsort(1))
例子
original = torch.tensor([
[20, 22, 24, 21],
[12, 14, 10, 11],
[34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))
为什么有效
为简单起见,想象一下t = [30, 10, 20]
,省略张量符号。
t.sort()
免费为我们提供排序张量s = [10, 20, 30]
以及排序索引i = [1, 2, 0]
。i
实际上是 的输出t.argsort()
。
i
告诉我们如何从t
到s
。t
“要对进行排序s
,请从“中获取元素 1,然后是 2,然后是 0 t
。Argsortingi
为我们提供了另一个排序索引j = [2, 0, 1]
,它告诉我们如何从i
自然数的规范序列转到自然数的规范序列[0, 1, 2]
,实际上是反转排序。另一种看待它的方式是j
告诉我们如何从s
到t
。“要排序s
到t
,请从“中获取元素 2,然后是 0,然后是 1 s
。对排序索引进行 Argsorting 为我们提供了它的“逆索引”,反之亦然。
现在我们有了逆索引,我们将其转储到torch.gather()
正确的dim
中,这会取消对张量的排序。
来源
在研究这个问题时我找不到这个确切的解决方案,所以我认为这是一个原始答案。
推荐阅读
- android - 降低 ble for android 的 Tx 功率
- visual-studio - 项目xamarin中不存在目标“CssG”
- spring-boot - RestTemplate 模拟给出 NullPointerException
- azure - Azure函数队列触发器执行多次
- python - 从数据框的不同列显示堆积条形图的每种颜色的值
- node.js - pg-promise 将整数读取为字符串,即使在被解析之后也是如此
- sql - SQL中的左右连接
- node.js - Angular 5 - ag-grid 18.0.1 - 边缘崩溃
- ios - 向下滚动(或从 ipad 更改为 iphone 时)的 TableViewCell 边框和阴影
- vba - 根据条件查找范围内最常用的文本