python - 如何将 pytorch 整数张量转换为布尔张量?
问题描述
我想将整数张量转换为布尔张量。
具体来说,我希望能够拥有一个转换tensor([0,10,0,16])
为tensor([0,1,0,1])
这在 Tensorflow 中是微不足道的,只需使用tf.cast(x,tf.bool)
.
我希望强制转换将所有大于 0 的整数更改为 1,并将所有等于 0 的整数更改为 0。这!!
在大多数语言中是等价的。
由于 pytorch 似乎没有专用的布尔类型可以转换,所以这里最好的方法是什么?
编辑:我正在寻找一种矢量化解决方案,而不是循环遍历每个元素。
解决方案
您正在寻找的是为给定的整数张量生成一个布尔掩码。为此,您可以使用简单的比较运算符 ( >
) 或 using简单地检查条件:“张量中的值是否大于 0” torch.gt()
,然后会给我们所需的结果。
# input tensor
In [76]: t
Out[76]: tensor([ 0, 10, 0, 16])
# generate the needed boolean mask
In [78]: t > 0
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)
# sanity check
In [93]: mask = t > 0
In [94]: mask.type()
Out[94]: 'torch.ByteTensor'
注意:在 PyTorch 1.4+ 版本中,上述操作将返回'torch.BoolTensor'
In [9]: t > 0
Out[9]: tensor([False, True, False, True])
# alternatively, use `torch.gt()` API
In [11]: torch.gt(t, 0)
Out[11]: tensor([False, True, False, True])
如果您确实想要单个位(0
s 或1
s),请使用:
In [14]: (t > 0).type(torch.uint8)
Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)
# alternatively, use `torch.gt()` API
In [15]: torch.gt(t, 0).int()
Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)
此更改的原因已在此功能请求问题中进行了讨论:问题/4764 - 引入 torch.BoolTensor ...
TL;DR : 简单的一个班轮
t.bool().int()
推荐阅读
- azure-aks - 无法访问 AKS 中的服务
- python - 这是插入排序吗?
- python - 如何使用 re.compile 和 finditer() 函数获取单词的开始和结束索引
- android - Android Studio 4.1.1 不会在 MacOS Big Sur 上启动
- reactjs - React.FC 泛型之一
- ms-access - 在 .NET 中设置 Microsoft Access 数据库中列的验证规则
- flutter - Flutter - 使用 Firestore 实例的 BuildSuggestions(搜索委托)出错
- java - Java:在另一个线程中执行 showOptionDialog 然后退出它会关闭整个应用程序
- javascript - 为什么在安装组件(NextJS)之后触发useEffect?
- python - 如何在 Django-graphene 中使用带有突变的异步