首页 > 解决方案 > 如何将 pytorch 整数张量转换为布尔张量?

问题描述

我想将整数张量转换为布尔张量。

具体来说,我希望能够拥有一个转换tensor([0,10,0,16])tensor([0,1,0,1])

这在 Tensorflow 中是微不足道的,只需使用tf.cast(x,tf.bool).

我希望强制转换将所有大于 0 的整数更改为 1,并将所有等于 0 的整数更改为 0。这!!在大多数语言中是等价的。

由于 pytorch 似乎没有专用的布尔类型可以转换,所以这里最好的方法是什么?

编辑:我正在寻找一种矢量化解决方案,而不是循环遍历每个元素。

标签: pythoncastingbooleanpytorchtensor

解决方案


您正在寻找的是为给定的整数张量生成一个布尔掩码。为此,您可以使用简单的比较运算符 ( >) 或 using简单地检查条件:“张量中的值是否大于 0” torch.gt(),然后会给我们所需的结果。

# input tensor
In [76]: t   
Out[76]: tensor([ 0, 10,  0, 16])

# generate the needed boolean mask
In [78]: t > 0      
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# sanity check
In [93]: mask = t > 0      

In [94]: mask.type()      
Out[94]: 'torch.ByteTensor'

注意:在 PyTorch 1.4+ 版本中,上述操作将返回'torch.BoolTensor'

In [9]: t > 0  
Out[9]: tensor([False,  True, False,  True])

# alternatively, use `torch.gt()` API
In [11]: torch.gt(t, 0)
Out[11]: tensor([False,  True, False,  True])

如果您确实想要单个位(0s 或1s),请使用:

In [14]: (t > 0).type(torch.uint8)   
Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# alternatively, use `torch.gt()` API
In [15]: torch.gt(t, 0).int()
Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)

此更改的原因已在此功能请求问题中进行了讨论:问题/4764 - 引入 torch.BoolTensor ...


TL;DR : 简单的一个班轮

t.bool().int()

推荐阅读