首页 > 解决方案 > 将张量转换为索引的一个热编码张量

问题描述

我有形状为 (1,1,128,128,128) 的标签张量,其中值的范围可能为 0,24。我想使用该nn.fucntional.one_hot函数将其转换为一个热编码张量

n = 24
one_hot = torch.nn.functional.one_hot(indices, n)

但这需要一个指数的张量,老实说,我不知道如何得到这些。我拥有的唯一张量是上述形状的标签张量,它包含范围为 1-24 的值,而不是索引

如何从我的张量中获取索引张量?提前致谢。

标签: pytorchone-hot-encoding

解决方案


如果你得到的错误是这个:

Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
RuntimeError: one_hot is only applicable to index tensor.

也许您只需要转换为int64

import torch

# random Tensor with the shape you said
indices = torch.Tensor(1, 1, 128, 128, 128).random_(1, 24)
# indices.shape => torch.Size([1, 1, 128, 128, 128])
# indices.dtype => torch.float32

n = 24
one_hot = torch.nn.functional.one_hot(indices.to(torch.int64), n)
# one_hot.shape => torch.Size([1, 1, 128, 128, 128, 24])
# one_hot.dtype => torch.int64

你也可以用indices.long()


推荐阅读