首页 > 解决方案 > unsqueeze 在 pytorch 中有什么作用?

问题描述

lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), 
                         set_2[:, :2].unsqueeze(0))   #(n1, n2, 2)

此代码片段unsqueeze(1)用于一个张量,但unsqeeze(0)用于另一个。它们之间有什么区别?

标签: pythonpytorchtorch

解决方案


unsqueeze通过添加一个零深度的额外维度,将一个 n 维张量变成一个 n+1 维张量。然而,由于新维度应该位于哪个轴上(即它应该“不被挤压”的方向)是不明确的,这需要由dim参数指定。

因此,得到的未压缩张量具有相同的信息,但用于访问它们的索引不同。

这是一个有效 2d 矩阵的可视化表示squeezeunsqueeze它从 2d 张量变为 3d 张量,因此新维度的位置有 3 种选择:

在此处输入图像描述


推荐阅读