tensorflow - BERT 调用函数中的关键字参数
问题描述
在 HuggingFace TensorFlow 2.0 BERT 库中,文档指出:
TF 2.0 模型接受两种格式作为输入:
将所有输入作为关键字参数(如 PyTorch 模型),或
在第一个位置参数中将所有输入作为列表、元组或字典。
我正在尝试使用这两个中的第一个来调用我创建的 BERT 模型:
from transformers import BertTokenizer, TFBertModel
import tensorflow as tf
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = TFBertModel.from_pretrained('bert-base-uncased')
text = ['This is a sentence',
'The sky is blue and the grass is green',
'More words are here']
labels = [0, 1, 0]
tokenized_text = bert_tokenizer.batch_encode_plus(batch_text_or_text_pairs=text,
pad_to_max_length=True,
return_tensors='tf')
dataset = tf.data.Dataset.from_tensor_slices((tokenized_text['input_ids'],
tokenized_text['attention_mask'],
tokenized_text['token_type_ids'],
tf.constant(labels))).batch(3)
sample = next(iter(dataset))
result1 = bert_model(inputs=(sample[0], sample[1], sample[2])) # works fine
result2 = bert_model(inputs={'input_ids': sample[0],
'attention_mask': sample[1],
'token_type_ids': sample[2]}) # also fine
result3 = bert_model(input_ids=sample[0],
attention_mask=sample[1],
token_type_ids=sample[2]) # raises an error
但是当我执行最后一行时,我得到一个错误:
TypeError: __call__() missing 1 required positional argument: 'inputs'
有人可以解释如何正确使用输入的关键字参数样式吗?
解决方案
似乎在内部,他们正在解释inputs
as input_ids
,如果您不将多个张量作为第一个参数。你可以在里面看到这个,TFBertModel
然后寻找TFBertMainLayer
's 的call
功能。
result1
对我来说,result2
如果我执行以下操作,我会得到完全相同的结果:
result3 = bert_model(inputs=sample[0],
attention_mask=sample[1],
token_type_ids=sample[2])
或者,您也可以只删除 , 也可以inputs=
。
推荐阅读
- c++ - 打包到安装包时,我可以删除/排除哪些 cpp 文件类型?
- azure - 从 ASP.NET Core 访问 Azure App Service ConnectionString
- c# - UserControl 的标签隐藏 MouseEnter 事件
- oracle - Oracle 唯一约束冲突
- sql - 获取几个表的每个租户的行数
- hql - 在 hql 脚本中,我们使用 "!sh echo ---new line---" 来表示相同的 . 想知道在 impala 中打印 impala 脚本中任何行的替代方法吗?
- lstm - RNN 与简单神经网络
- database - 如果没有它自己的表中已经存在 FK,我就无法插入新记录
- python - Python Datetime - 查找日期的半个月期间或日期的 2 周期间
- php - 将变量从控制器引用到所有刀片 Laravel 5