首页 > 解决方案 > BERT 调用函数中的关键字参数

问题描述

在 HuggingFace TensorFlow 2.0 BERT 库中,文档指出:

TF 2.0 模型接受两种格式作为输入:

  • 将所有输入作为关键字参数(如 PyTorch 模型),或

  • 在第一个位置参数中将所有输入作为列表、元组或字典。

我正在尝试使用这两个中的第一个来调用我创建的 BERT 模型:

from transformers import BertTokenizer, TFBertModel
import tensorflow as tf

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = TFBertModel.from_pretrained('bert-base-uncased')

text = ['This is a sentence', 
        'The sky is blue and the grass is green', 
        'More words are here']
labels = [0, 1, 0]
tokenized_text = bert_tokenizer.batch_encode_plus(batch_text_or_text_pairs=text,
                                                  pad_to_max_length=True,
                                                  return_tensors='tf')
dataset = tf.data.Dataset.from_tensor_slices((tokenized_text['input_ids'],
                                              tokenized_text['attention_mask'],
                                              tokenized_text['token_type_ids'],
                                              tf.constant(labels))).batch(3)
sample = next(iter(dataset))

result1 = bert_model(inputs=(sample[0], sample[1], sample[2]))  # works fine
result2 = bert_model(inputs={'input_ids': sample[0], 
                             'attention_mask': sample[1], 
                             'token_type_ids': sample[2]})  # also fine
result3 = bert_model(input_ids=sample[0], 
                     attention_mask=sample[1], 
                     token_type_ids=sample[2])  # raises an error

但是当我执行最后一行时,我得到一个错误:

TypeError: __call__() missing 1 required positional argument: 'inputs'

有人可以解释如何正确使用输入的关键字参数样式吗?

标签: tensorflownlpargumentshuggingface-transformers

解决方案


似乎在内部,他们正在解释inputsas input_ids,如果您不将多个张量作为第一个参数。你可以在里面看到这个,TFBertModel然后寻找TFBertMainLayer's 的call功能。

result1对我来说,result2如果我执行以下操作,我会得到完全相同的结果:

result3 = bert_model(inputs=sample[0], 
                     attention_mask=sample[1], 
                     token_type_ids=sample[2])

或者,您也可以只删除 , 也可以inputs=


推荐阅读