python - 用于令牌分类的 Tensorflow BERT - 在训练和测试时从准确性中排除 pad-token
问题描述
我正在使用用于 tensorflow 的预训练 BERT 模型进行基于标记的分类,以自动标记句子中的因果关系。
要访问 BERT,我正在使用来自 huggingface 的 TFBertForTokenClassification-Interface:https ://huggingface.co/transformers/model_doc/bert.html#tfbertfortokenclassification
我用来训练的句子都是根据BERT-tokenizer转换成token(基本上是单词到数字的映射),然后在训练前填充到一定长度,所以当一个句子只有50个token,另一个只有30个token时第一个填充了 50 个填充令牌,第二个填充了 70 个填充令牌,以获得 100 的通用输入句子长度。
然后我训练我的模型来预测每个标记该标记所属的标签;无论它是原因的一部分,结果还是它们都不是。
但是,在训练和评估期间,我的模型也会对 PAD 令牌进行预测,并且它们也包含在模型的准确性中。由于 PAD 标记很容易为模型预测(它们总是具有相同的标记,并且它们都具有“无”标签,这意味着它们既不属于句子的原因也不属于句子的结果),它们确实扭曲了我的模型的准确性.
例如,如果您有一个包含 30 个单词 -> 30 个标记的句子,并且您将所有句子填充到 100 的长度,那么即使模型没有正确预测“真实”标记,该句子也会获得 70% 的分数. 这样,尽管模型在真正的 pad-tokens 上表现不佳,但我很快就获得了 90% 以上的训练和验证准确率。
我认为注意力面具可以解决这个问题,但事实并非如此。
输入数据集的创建如下:
def example_to_features(input_ids,attention_masks,token_type_ids,label_ids):
return {"input_ids": input_ids,
"attention_mask": attention_masks},label_ids
train_ds = tf.data.Dataset.from_tensor_slices((input_ids_train,attention_masks_train,token_ids_train,label_ids_train)).map(example_to_features).shuffle(buffer_size=1000).batch(32)
模型创建:
from transformers import TFBertForTokenClassification
num_epochs = 30
model = TFBertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=3)
model.layers[-1].activation = tf.keras.activations.softmax
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
model.summary()
然后我像这样训练它:
history = model.fit(train_ds, epochs=num_epochs, validation_data=validate_ds)
到目前为止,有没有人遇到过这个问题,或者知道如何在训练和评估期间从模型的准确性中排除对 pad-token 的预测?
解决方案
是的,这很正常。
BERT 的输出也[batch_size, max_seq_len = 100, hidden_size]
将包括 [PAD] 令牌的值或嵌入。但是,您还提供attention_masks
给 BERT 模型,以便它不考虑这些 [PAD] 令牌。
同样,您需要在将 BERT 结果传递到最终的全连接层之前屏蔽这些 [PAD] 标记,在计算损失时屏蔽它们,以及计算精度和召回等指标。
推荐阅读
- swift - 删除数组元素时,SwiftUI 的 ForEach 崩溃
- php - 如何在整个 PHP 站点中更改图像位置
- android - 使用不同的包和应用程序 ID 制作 Android 构建变体
- java - 覆盖具有相同限定符名称的 spring beans
- gitlab - 根据环境设置 CI/CD 变量
- android - 以编程方式合并不同种类的 android 渐变
- java - 如何验证放心的 api 响应
- amazon-ec2 - 如何在 squid.conf 中创建多个 IP?
- sql - listagg() 空间不足(> 4k)后有没有办法继续下一行?
- android - 编写这个简单的 sqlite rawquery 的正确方法是什么?