首页 > 解决方案 > 通过 pytorch 中的张量切片,火炬分配不到位

问题描述

我正在尝试将分配张量的当前代码转换为外部操作。
目前的意思是代码是

self.X[:, nc:] = D

其中 D 的形状与self.X[:, nc:]
但我想将其转换为

sliced_index = ~ somehow create an indexed tensor from self.X[:, nc:]
self.X = self.X.scatter(1,sliced_index,mm(S_, Z[:, :n - nc]))

并且不知道如何创建仅表示切片张量中的条目的索引掩码张量

最小的例子:

a = [[0,1,2],[3,4,5]]
D = [[6],[7]]
Not_in_place = [[0,1,6],[3,4,7]]

标签: pythonpytorch

解决方案


蒙面散点图更容易一些。掩码本身可以计算为就地操作,之后您可以使用masked_scatter

mask = torch.zeros(self.X.shape, device=self.X.device, dtype=torch.bool)
mask[:, nc:] = True
self.X = self.X.masked_scatter(mask, D)

一个依赖广播但应该更高效的更专业的版本是

mask = torch.zeros([1, self.X.size(1)], device=self.X.device, dtype=torch.bool)
mask[0, nc:] = True
self.X = self.X.masked_scatter(mask, D)

推荐阅读