首页 > 解决方案 > 在 PyTorch 中设置浮点类型时,张量类型和 dtype 有什么区别,什么时候应该设置一个而不是另一个?

问题描述

我使用双打作为我的模型输入和输出,所以我试图将torch设置为使用float64而不是float32。到底有什么区别

torch.set_default_tensor_type(torch.DoubleTensor)

将默认的 torch.Tensor 类型设置为浮点张量类型 t。此类型也将用作 torch.tensor() 中类型推断的默认浮点类型。

torch.set_default_dtype(torch.float64)

将默认浮点 dtype 设置为 d。此类型将用作 torch.tensor() 中类型推断的默认浮点类型。

文档告诉我,设置张量类型也会设置 dtype,但我不确定何时会使用其中一个。

我应该提到,这两个语句都修复了我在从浮点数更改为双精度数后看到的错误:

Traceback (most recent call last):
  File "train.py", line 122, in train_model  
    output = net(action)  
  File "/opt/anaconda3/lib/python3.7/site- packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "models.py", line 25, in forward
    return self.fc2(F.relu(self.fc1(x)))
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 67, in forward
    return F.linear(input, self.weight, self.bias)
  File "/opt/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1352, in linear
    ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'

标签: numpyruntime-errorpytorch

解决方案


推荐阅读