首页 > 解决方案 > 用于令牌分类的 Tensorflow BERT - 在训练和测试时从准确性中排除 pad-token

问题描述

我正在使用用于 tensorflow 的预训练 BERT 模型进行基于标记的分类,以自动标记句子中的因果关系。

要访问 BERT,我正在使用来自 huggingface 的 TFBertForTokenClassification-Interface:https ://huggingface.co/transformers/model_doc/bert.html#tfbertfortokenclassification

我用来训练的句子都是根据BERT-tokenizer转换成token(基本上是单词到数字的映射),然后在训练前填充到一定长度,所以当一个句子只有50个token,另一个只有30个token时第一个填充了 50 个填充令牌,第二个填充了 70 个填充令牌,以获得 100 的通用输入句子长度。

然后我训练我的模型来预测每个标记该标记所属的标签;无论它是原因的一部分,结果还是它们都不是。

但是,在训练和评估期间,我的模型也会对 PAD 令牌进行预测,并且它们也包含在模型的准确性中。由于 PAD 标记很容易为模型预测(它们总是具有相同的标记,并且它们都具有“无”标签,这意味着它们既不属于句子的原因也不属于句子的结果),它们确实扭曲了我的模型的准确性.

例如,如果您有一个包含 30 个单词 -> 30 个标记的句子,并且您将所有句子填充到 100 的长度,那么即使模型没有正确预测“真实”标记,该句子也会获得 70% 的分数. 这样,尽管模型在真正的 pad-tokens 上表现不佳,但我很快就获得了 90% 以上的训练和验证准确率。

我认为注意力面具可以解决这个问题,但事实并非如此。

输入数据集的创建如下:

def example_to_features(input_ids,attention_masks,token_type_ids,label_ids):
  return {"input_ids": input_ids,
          "attention_mask": attention_masks},label_ids

train_ds = tf.data.Dataset.from_tensor_slices((input_ids_train,attention_masks_train,token_ids_train,label_ids_train)).map(example_to_features).shuffle(buffer_size=1000).batch(32)

模型创建:

from transformers import TFBertForTokenClassification

num_epochs = 30

model = TFBertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=3)

model.layers[-1].activation = tf.keras.activations.softmax

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

model.summary()

然后我像这样训练它:

history = model.fit(train_ds, epochs=num_epochs, validation_data=validate_ds)

到目前为止,有没有人遇到过这个问题,或者知道如何在训练和评估期间从模型的准确性中排除对 pad-token 的预测?

标签: pythontensorflownamed-entity-recognitionhuggingface-transformersbert-language-model

解决方案


是的,这很正常。

BERT 的输出也[batch_size, max_seq_len = 100, hidden_size]将包括 [PAD] 令牌的值或嵌入。但是,您还提供attention_masks给 BERT 模型,以便它不考虑这些 [PAD] 令牌。

同样,您需要在将 BERT 结果传递到最终的全连接层之前屏蔽这些 [PAD] 标记,在计算损失时屏蔽它们,以及计算精度和召回等指标。


推荐阅读