python - RuntimeError: CUDA error: no kernel image is available for execution on the device after model.cuda()
Problem description
I am working on this model:
    import torch
    from torch.nn import LSTM  # assumption: LSTM in the original refers to torch.nn.LSTM

    class Model(torch.nn.Module):
        def __init__(self, sizes, config):
            super(Model, self).__init__()
            self.lstm = []
            # one 8-layer LSTM per consecutive pair of sizes, followed by a linear output layer
            for i in range(len(sizes) - 2):
                self.lstm.append(LSTM(sizes[i], sizes[i+1], num_layers=8))
            self.lstm.append(torch.nn.Linear(sizes[-2], sizes[-1]).cuda())
            self.lstm = torch.nn.ModuleList(self.lstm)
            self.config_mel = config.mel_features

        def forward(self, x):
            # convert to log-domain
            x = x.clip(min=1e-6).log10()
            for layer in self.lstm[:-1]:
                x, _ = layer(x)
                x = torch.relu(x)
            #x = torch_unpack_seq(x)[0]
            x = self.lstm[-1](x)
            mask = torch.sigmoid(x)
            return mask
Then:
    model = Model(model_width, config)
    model.cuda()
But I get this error:
      File "main.py", line 29, in <module>
        Model.train(args)
      File ".../src/model.py", line 57, in train
        model.cuda()
      File ".../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 637, in cuda
        return self._apply(lambda t: t.cuda(device))
      File ".../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 530, in _apply
        module._apply(fn)
      File "/.../.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 530, in _apply
        module._apply(fn)
      File ".../.local/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 189, in _apply
        self.flatten_parameters()
      File ".../.local/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 175, in flatten_parameters
        torch._cudnn_rnn_flatten_weight(
    RuntimeError: CUDA error: no kernel image is available for execution on the device
    CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
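I don't know why this happens. I am moving both the model and the input to CUDA, and I understand that an error can occur when part of the model is on the CPU and part is on the GPU. That is not the case here. I found some pip install solutions here: Pytorch CUDA error: no kernel image is available for execution on the device on RTX 3090 with cuda 11.1
But I cannot use them, because I am trying to get this work done in a remote repo where I have no access to pip install.
Is there a way to solve this?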
Solution
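This error message usually means that the installed PyTorch binary contains no compiled kernels for the compute capability of the GPU it is running on. The snippet below is a minimal diagnostic sketch, not a fix: it prints what the installed build was compiled for and compares that with the GPU. It assumes a PyTorch version that provides `torch.cuda.get_arch_list()`. `build_supports_gpu` is a hypothetical helper written for this illustration, and its compatibility rule is a simplification (it ignores architecture-specific entries such as `sm_90a`).

```python
def build_supports_gpu(capability, arch_list):
    """Return True if a build with `arch_list` has code usable on `capability`.

    capability: (major, minor) tuple, e.g. (8, 6)
    arch_list:  strings such as "sm_80" (binary code) or "compute_70" (PTX)
    Hypothetical helper; simplified rule, see the comments below.
    """
    major, minor = capability
    for arch in arch_list:
        kind, _, number = arch.partition("_")
        if not number.isdigit():
            continue  # skip entries this sketch does not understand, e.g. "sm_90a"
        arch_major, arch_minor = divmod(int(number), 10)
        if kind == "sm":
            # Binary code runs on the same major version with an equal or higher minor.
            if arch_major == major and arch_minor <= minor:
                return True
        elif kind == "compute":
            # PTX can be JIT-compiled for any equal or newer capability.
            if (arch_major, arch_minor) <= (major, minor):
                return True
    return False


def main():
    try:
        import torch  # imported here so the helper above works without PyTorch
    except ImportError:
        print("PyTorch is not installed in this environment")
        return
    print("torch:", torch.__version__, "built for CUDA:", torch.version.cuda)
    if not torch.cuda.is_available():
        print("torch.cuda.is_available() is False")
        return
    capability = torch.cuda.get_device_capability(0)
    arch_list = torch.cuda.get_arch_list()
    print("GPU:", torch.cuda.get_device_name(0), "capability:", capability)
    print("architectures in this build:", arch_list)
    print("build supports this GPU:", build_supports_gpu(capability, arch_list))


if __name__ == "__main__":
    main()
```

If the last line prints `False`, the build in use has no kernels for this GPU, and the model cannot be moved to it until a PyTorch build that targets the GPU's architecture is available in that environment. The script itself needs no installation step, so it can be run where pip install is not accessible.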