首页 > 解决方案 > 重塑 Pytorch 张量

问题描述

(24, 2, 224, 224)在 Pytorch 中有一个大小的张量。

这是执行二进制分割的 CNN 的输出。在 2 个矩阵的每个单元格中存储该像素为前景或背景的概率:[n][0][h][w] + [n][1][h][w] = 1对于每个坐标

我想将它重塑为 size 的张量(24, 1, 224, 224)。新层中的值应该是01根据概率较高的矩阵。

我怎样才能做到这一点?我应该使用哪个功能?

标签: pythonarraysnumpymatrixpytorch

解决方案


使用torch.argmax()(对于 PyTorch +0.4):

prediction = torch.argmax(tensor, dim=1) # with 'dim' the considered dimension 
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)

如果 PyTorch 版本低于 0.4.0,则可以使用tensor.max()它返回最大值和它们的索引(但不能通过索引值区分):

_, prediction = tensor.max(dim=1)
prediction = prediction.unsqueeze(1) # to reshape from (24, 224, 224) to (24, 1, 224, 224)

推荐阅读