python - 后端 CPU 的预期对象,但获得了参数 #2 'source' 的后端 CUDA
问题描述
我尝试了其他答案,但错误没有被删除。与我得到的另一个问题的不同之处在于,错误使用的最后一个术语是“来源”,我在任何问题中都没有找到。如果可能,还请错误地解释术语“来源”。并且在没有 CPU 的情况下运行代码工作正常。
我正在使用启用了 GPU 的 Google Colab。
import torch
from torch import nn
import syft as sy
hook = sy.TorchHook(torch)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim = 1))
model = model.to(device)
输出 :
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-42-136ec343040a> in <module>()
8 nn.LogSoftmax(dim = 1))
9
---> 10 model = model.to(device)
3 frames
/usr/local/lib/python3.6/dist-packages/syft/frameworks/torch/hook/hook.py in data(self, new_data)
368
369 with torch.no_grad():
--> 370 self.set_(new_data)
371 return self
372
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'source'
解决方案
这个问题与PySyft
. 正如您在此Issue#1893中看到的,当前的解决方法是设置:
import torch
torch.set_default_tensor_type(torch.cuda.FloatTensor)
紧随其后import torch
。
代码:
import torch
from torch import nn
torch.set_default_tensor_type(torch.cuda.FloatTensor) # <-- workaround
import syft as sy
hook = sy.TorchHook(torch)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = nn.Sequential(nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim = 1))
model = model.to(device)
print(model)
输出:
cuda
Sequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=128, bias=True)
(3): ReLU()
(4): Linear(in_features=128, out_features=64, bias=True)
(5): ReLU()
(6): Linear(in_features=64, out_features=10, bias=True)
(7): LogSoftmax()
)
推荐阅读
- php - Is there a particular function to receive the response of a stored procedure with a pivot?
- javascript - Random boolean 70/30 inside function?
- autohotkey - pressing Q should be Q + MB RightClick
- windows - 将 psql 查询输出存储到变量中的批处理文件
- vim - Use AHK to detect the "Vim Mode" for GUI windows and configure conditional mappings of keys
- django - Render Django AuthenticationForm in Jinja2
- spring - IdentityServer4 API Resource & Spring Framework
- delphi - 如何覆盖私有属性设置器?
- azure - Third Party DLLs in Azure Web API
- environment-variables - 在鱼壳中,如何设置具有默认回退的变量?