python - Pytorch 出现 CUDA 错误:训练 conv1d 分类器时触发设备端断言
问题描述
我正在使用 PyTorch 实现 CNN 多标签分类器,但它总是显示此错误:
“CUDA 错误:触发了设备端断言”。在错误消息中它指出了损失函数,但是当我更改损失函数时,它仍然会发生并指向其他部分(真的就像随机挑选的一样)。当我换成 CPU 时,它说“索引超出自身范围”,但是当我调查我的数据加载器时,这并不奇怪。
我有 15 个类,59462 个唯一令牌,每个文档的 len 是 30,000(令牌)。
我的模型和损失函数是这样的:
class model(nn.Module):
def __init__(self, num_classes=15):
super(model, self).__init__()
self.embedding = nn.Sequential(nn.Embedding(59462,400),nn.Dropout(0.15))
self.features = nn.Sequential(
nn.Conv1d(400, 500, kernel_size=3, stride=1, padding=False), nn.ReLU(),
nn.Dropout(0.05), nn.MaxPool1d(kernel_size=2), nn.Dropout(0.15))
self.linear = nn.Linear(500*14999, 15)
def forward(self, x):
x = self.embedding(x)
x = x.permute(0,2,1)
x = self.features(x)
x = x.view(x.size(0), 500*14999)
x = self.linear(x)
return x
model = model()
model = model.to(device)
def loss_fn(outputs, targets):
return torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))(outputs, targets).to(device)
#pos_weight is for my unbalanced data
这是使用 CPU 时的错误消息:
/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py in run_cell_magic(self, magic_name, line, cell)
2115 magic_arg_s = self.var_expand(line, stack_depth)
2116 with self.builtin_trap:
-> 2117 result = fn(magic_arg_s, cell)
2118 return result
2119
<decorator-gen-60> in time(self, line, cell, local_ns)
/usr/local/lib/python3.6/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
186 # but it's overkill for just that one bit of state.
187 def magic_deco(arg):
--> 188 call = lambda f, *a, **k: f(*a, **k)
189
190 if callable(arg):
/usr/local/lib/python3.6/dist-packages/IPython/core/magics/execution.py in time(self, line, cell, local_ns)
1191 else:
1192 st = clock2()
-> 1193 exec(code, glob, local_ns)
1194 end = clock2()
1195 out = None
<timed exec> in <module>()
<ipython-input-67-09829392983d> in train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples)
27 targets = d["targets"].to(device)
28 outputs = model(
---> 29 tokens
30 )
31
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-59-2c2ec554cb05> in forward(self, x)
16
17 def forward(self, x):
---> 18 x = self.embedding(x)
19 x = x.permute(0,2,1)
20 x = self.features(x)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py in forward(self, input)
115 def forward(self, input):
116 for module in self:
--> 117 input = module(input)
118 return input
119
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
124 return F.embedding(
125 input, self.weight, self.padding_idx, self.max_norm,
--> 126 self.norm_type, self.scale_grad_by_freq, self.sparse)
127
128 def extra_repr(self) -> str:
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
1850 # remove once script supports set_grad_enabled
1851 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1852 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
1853
1854
IndexError: index out of range in self
有谁知道这个问题是什么,我该如何解决?提前致谢。
解决方案
推荐阅读
- python-3.x - 我需要用特殊参数制作类
- javascript - a .__ proto __ 和 b() 的比率
- docker - 服务“nginx-proxy”使用未定义的网络“nginx-proxy”
- ios - iOS Swift:Firebase 存储上传错误 - 只有初始文件上传有效
- flutter - 使用 Flutter TDD 基于 json 文件返回有效的名片模型
- css - WordPress中按字母分组的分类术语链接
- javascript - Mongoose - findByIdAndUpdate runValidators 基于其他属性
- sql - 连接表 - 如何找出要添加的 ID?
- python - 如何使箱线图更具可读性和比例数字?
- javascript - 子 componentDidMount() 中未定义道具