HuggingFace BartModel "should have a 'get_encoder' function defined"

Problem description

I'm using the HuggingFace pretrained model facebook/bart-large-cnn for text summarization, loaded through AutoModel and AutoTokenizer. Both the model and the tokenizer load fine:

import os
import torch
from transformers import AutoTokenizer, AutoModel

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn",
                                          cache_dir=os.getenv("cache_dir", "model"))

model = AutoModel.from_pretrained("facebook/bart-large-cnn",
                                  cache_dir=os.getenv("cache_dir", "model")).to(torch_device)

FRANCE_ARTICLE = ' Marseille...'  # @noqa

dct = tokenizer.batch_encode_plus(
    [FRANCE_ARTICLE],
    max_length=1024,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

max_length = 140
min_length = 55

hypotheses_batch = model.generate(
    input_ids=dct["input_ids"].to(torch_device),
    attention_mask=dct["attention_mask"].to(torch_device),
    num_beams=4,
    length_penalty=2.0,
    max_length=max_length + 2,
    min_length=min_length + 1,
    no_repeat_ngram_size=3,
    do_sample=False,
    early_stopping=True,
    decoder_start_token_id=model.config.eos_token_id,
)

decoded = [
    tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
]

print(decoded)

But when I pass the encoded input from tokenizer.batch_encode_plus to model.generate, I get this error:

Traceback (most recent call last):
  File "src/summarization/run.py", line 42, in <module>
    summary_ids = model.generate(article_input_ids,num_beams=4,length_penalty=2.0,max_length=142,min_length=56,no_repeat_ngram_size=3)
  File "/usr/local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/transformers/generation_utils.py", line 379, in generate
    assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
AssertionError: BartModel(
  (shared): Embedding(50264, 1024, padding_idx=1)
  (encoder): BartEncoder(
    (embed_tokens): Embedding(50264, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
    (layers): ModuleList(
      (0): EncoderLayer(
...
      )
    )
    (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
) should have a 'get_encoder' function defined

Tags: python, huggingface-transformers, huggingface-tokenizers

Solution
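
The assertion fires because AutoModel.from_pretrained("facebook/bart-large-cnn") returns the bare BartModel (the encoder-decoder without a language-modeling head), and in the installed transformers version that class does not expose the get_encoder hook that generate() relies on. Loading the same checkpoint with a class that carries the generation head, such as BartForConditionalGeneration (or the AutoModelForSeq2SeqLM auto class in more recent transformers releases), provides both get_encoder and full beam-search generation. A minimal sketch of the adjusted loading code, keeping the rest of the script unchanged:

import os
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn",
                                          cache_dir=os.getenv("cache_dir", "model"))

# Load the checkpoint with the conditional-generation head instead of the bare
# BartModel; this class defines get_encoder() and supports generate().
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large-cnn",
    cache_dir=os.getenv("cache_dir", "model"),
).to(torch_device)

The batch_encode_plus, generate, and decode calls from the question should then run as-is. The explicit decoder_start_token_id=model.config.eos_token_id argument can usually be dropped as well, since the checkpoint's config already supplies a decoder start token, but leaving it in place does no harm.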

