python - GPT2 on Hugging face(pytorch transformers) RuntimeError: grad can be implicitly created only for scalar outputs
问题描述
我正在尝试使用我的自定义数据集微调 gpt2。我使用来自拥抱脸转换器的文档创建了一个基本示例。我收到提到的错误。我知道这意味着什么:(基本上它是在非标量张量上向后调用)但由于我几乎只使用 API 调用,我不知道如何解决这个问题。有什么建议么?
from pathlib import Path
from absl import flags, app
import IPython
import torch
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments
from data_reader import GetDataAsPython
# this is my custom data, but i get the same error for the basic case below
# data = GetDataAsPython('data.json')
# data = [data_point.GetText2Text() for data_point in data]
# print("Number of data samples is", len(data))
data = ["this is a trial text", "this is another trial text"]
train_texts = data
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
special_tokens_dict = {'pad_token': '<PAD>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
train_encodigs = tokenizer(train_texts, truncation=True, padding=True)
class BugFixDataset(torch.utils.data.Dataset):
def __init__(self, encodings):
self.encodings = encodings
def __getitem__(self, index):
item = {key: torch.tensor(val[index]) for key, val in self.encodings.items()}
return item
def __len__(self):
return len(self.encodings['input_ids'])
train_dataset = BugFixDataset(train_encodigs)
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=3,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
logging_steps=10,
)
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)
model.resize_token_embeddings(len(tokenizer))
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
)
trainer.train()
解决方案
我终于弄明白了。问题是数据样本不包含目标输出。即使是严格的 gpt 也是自我监督的,这必须明确告知模型。
您必须添加以下行:
item['labels'] = torch.tensor(self.encodings['input_ids'][index])
到 Dataset 类的getitem函数,然后它运行正常!
推荐阅读
- python - Xpath 只给我第一个项目,而我想要所有项目(使用 Scrapy)
- javascript - Chartjs - 工具提示 - 不同的圆角半径
- android - Android Studio中文件名旁边的勾号
- vb6 - 如何在 DTPICKER vb6 中计算分钟
- webpack - 如何在动态导入块名称中添加哈希?
- swift - 有什么办法可以暂时隐藏蓝点,然后让它在 GMSMapView 中再次消失?
- graphql - Appsync graphql:如何根据数组字段中的条目进行过滤
- javascript - Javascript用匹配的数组键替换字符串中的单词
- python - Python 3.7“由于环境错误而无法安装软件包:无法解析:主机:端口
- amazon-web-services - 在 Cloudformation 中将 Cloudfront“基于所选请求标头的缓存”设置为全部