pytorch - 将张量转换为索引的一个热编码张量
问题描述
我有形状为 (1,1,128,128,128) 的标签张量,其中值的范围可能为 0,24。我想使用该nn.fucntional.one_hot
函数将其转换为一个热编码张量
n = 24
one_hot = torch.nn.functional.one_hot(indices, n)
但这需要一个指数的张量,老实说,我不知道如何得到这些。我拥有的唯一张量是上述形状的标签张量,它包含范围为 1-24 的值,而不是索引
如何从我的张量中获取索引张量?提前致谢。
解决方案
如果你得到的错误是这个:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: one_hot is only applicable to index tensor.
也许您只需要转换为int64
:
import torch
# random Tensor with the shape you said
indices = torch.Tensor(1, 1, 128, 128, 128).random_(1, 24)
# indices.shape => torch.Size([1, 1, 128, 128, 128])
# indices.dtype => torch.float32
n = 24
one_hot = torch.nn.functional.one_hot(indices.to(torch.int64), n)
# one_hot.shape => torch.Size([1, 1, 128, 128, 128, 24])
# one_hot.dtype => torch.int64
你也可以用indices.long()
。
推荐阅读
- javascript - 如何使标题元素的宽度缓慢向左移动
- swift - UICollectionViewCell 阴影和圆角不起作用
- powerapps - 从 PowerApps 创建页面时的 OneNote 嵌套列表问题
- flutter - 检查表单条目是否为 10 的倍数
- graphics - 未解析的外部符号“类 Fl_Graphics_Driver * fl_graphics_driver”
- javascript - 如何在javascript中将文本输入变量添加到URL
- css - 我应该如何在我的项目中构建 sass 和 css
- python - 如何使用物理按钮在热敏打印机上打印内容?
- python - 通过字符串生成器的字符串时出现新字符串时最清晰和 Pythonic 的跟踪方式
- python - 长时间的文件进度