首页 > 解决方案 > PyTorch 何时自动转换 Tensor dtype?

问题描述

PyTorch 何时自动转换 Tensor dtype?为什么它有时会自动执行,有时会抛出错误?

例如,这会自动转换c为浮点数:

a = torch.tensor(5)    
b = torch.tensor(5.)
c = a*b 

a.dtype
>>> torch.int64

b.dtype
>>> torch.float32

c.dtype
>>> torch.float32

但这会引发错误:

a = torch.ones(2, dtype=torch.float)   
b = torch.ones(2, dtype=torch.long)    
c = torch.matmul(a,b)

Traceback (most recent call last):

  File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
    torch.matmul(a,b)

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'

我很困惑,因为 Numpy 似乎会根据需要自动转换所有数组,例如

a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)

np.matmul(a,b)
>>> 2.0

a*b
>>> array([1., 1.])

标签: pythonnumpypytorch

解决方案


看起来 PyTorch 团队正在解决这些类型的问题,请参阅此问题。根据您的示例,似乎已经在 1.0.0 中实现了一些基本的向上转换(可能对于重载的运算符,尝试了其他一些类似 '//' 或加法并且它们工作正常),但没有找到任何证据(比如github 问题或文档中的信息)。如果有人发现它(torch.Tensor各种操作的隐式转换),请发表评论或其他答案。

这个问题是关于类型提升的提案,你可以看到所有这些仍然是开放的。


推荐阅读