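text - Fine-tuning a Hugging Face encoder-decoder model for text summarization raises "RuntimeError: CUDA error: device-side assert triggered"
Problem description
While trying to fine-tune a pretrained encoder-decoder model, I get RuntimeError: CUDA error: device-side assert triggered. BERT is used here as both the encoder and the decoder. I have also gone through this, but it did not help. Please go through the code below.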
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # run CUDA kernels synchronously so the traceback points at the failing call

import pytorch_lightning as pl
from torch.optim import AdamW  # assumed import; the post does not show where AdamW comes from
from transformers import EncoderDecoderModel

# MODEL_NAME, tokenizer and trainer are defined elsewhere in the notebook.

class SummaryModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # BERT serves as both encoder and decoder, with tied weights.
        self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(MODEL_NAME, MODEL_NAME, tie_encoder_decoder=True, return_dict=True)
        self.model.config.decoder_start_token_id = tokenizer.bos_token_id
        self.model.config.eos_token_id = tokenizer.eos_token_id
        self.model.config.early_stopping = True

    def forward(self, batch):
        # The labels are also passed as the decoder inputs.
        output = self.model(decoder_input_ids=batch['labels'], **batch)
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        loss, outputs = self(batch)
        self.log('train_loss', loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, outputs = self(batch)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, outputs = self(batch)
        self.log('test_loss', loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=0.001)

model = SummaryModel()
trainer.fit(model, datam)
This returns the following:
----> 1 trainer.fit(model, datam)
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2041 # remove once script supports set_grad_enabled
2042 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2043 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
2044
2045
RuntimeError: CUDA error: device-side assert triggered
Other relevant details: BATCH_SIZE = 20, TOKENIZER = BertTokenizer, MODEL_NAME = 'bert-base-uncased', TEXT_MAX_TOKEN_LEN = 200, SUMMARY_MAX_TOKEN_LEN = 50.
The data module is as follows:
datam = SummaryDataModule(train_df, val_df, test_df, batch_size=BATCH_SIZE)
datam.setup()
dl = datam.train_dataloader()
print(' batch-len, input_length')
# Inspect the first batch only.
for d in dl:
    print(d['input_ids'].shape)
    print(d['attention_mask'].shape)
    print(d['decoder_attention_mask'].shape)
    print(d['labels'].shape)
    print(type(d))
    break
This returns the following:
batch-len, input_length
torch.Size([20, 200])
torch.Size([20, 200])
torch.Size([20, 50])
torch.Size([20, 50])
<class 'dict'>
Solution
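The traceback ends in torch.embedding, and a device-side assert there generally means some index lies outside the embedding table. One thing worth ruling out is therefore whether any id in the batch is negative or not smaller than the vocabulary size. The following is only a minimal diagnostic sketch, not a confirmed fix: the helper name out_of_range_ids is hypothetical, the toy batch stands in for d['labels'], and it is an assumption (the post does not show SummaryDataModule) that padded label positions are masked with -100.

```python
def out_of_range_ids(ids, num_embeddings):
    """Return the sorted distinct ids that a table with num_embeddings rows cannot look up."""
    # Accept a torch tensor (via .tolist()) or a plain nested list of ints.
    rows = ids.tolist() if hasattr(ids, "tolist") else ids
    flat = [tok for row in rows for tok in row]
    return sorted({tok for tok in flat if tok < 0 or tok >= num_embeddings})

# Toy stand-in for d['labels']. ASSUMPTION: padding in the labels is masked
# with -100, a common convention; the real data module is not shown.
toy_labels = [[101, 2023, 102, -100], [101, 30522, 102, 0]]

# 30522 is the vocabulary size of bert-base-uncased.
bad = out_of_range_ids(toy_labels, num_embeddings=30522)
print(bad)  # [-100, 30522]
```

On the real batch this could be called as out_of_range_ids(d['labels'], 30522), and likewise for d['input_ids']. If it reports -100, note that forward passes batch['labels'] as decoder_input_ids, so those masked positions would reach the decoder's embedding layer. It may also be worth printing tokenizer.bos_token_id and tokenizer.eos_token_id, since BertTokenizer does not define those tokens by default, and running a single batch on CPU, where an out-of-range index raises a readable IndexError instead of a device-side assert.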