首页 > 解决方案 > 使用 BERT 编码器的二元分类模型准确率达 50%

问题描述

我正在尝试为 Yelp 二进制分类任务训练一个简单的模型。

加载 BERT 编码器:

gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12"
bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())
bert_config = bert.configs.BertConfig.from_dict(config_dict)
_, bert_encoder = bert.bert_models.classifier_model(
    bert_config, num_labels=2)
checkpoint = tf.train.Checkpoint(model=bert_encoder)
checkpoint.restore(
    os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()

加载数据:

data, info = tfds.load('yelp_polarity_reviews', with_info=True, batch_size=-1, as_supervised=True)
train_x_orig, train_y_orig = tfds.as_numpy(data['train'])
train_x = encode_examples(train_x_orig)
train_y = train_y_orig 

使用 BERT 嵌入数据:

encoder_output = bert_encoder.predict(train_x)

设置模型:

inputs = keras.Input(shape=(768,))
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(8, activation='relu')(x)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
sgd = SGD(lr=0.0001)
model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])

火车:

model.fit(encoder_output[0], train_y, batch_size=64, epochs=3)
# encoder_output[0].shape === (10000, 1, 768)
# y_train.shape === (100000,)

训练结果:

Epoch 1/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6921 - accuracy: 0.5455
Epoch 2/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6918 - accuracy: 0.5455
Epoch 3/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6915 - accuracy: 0.5412
Epoch 4/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6913 - accuracy: 0.5407
Epoch 5/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6911 - accuracy: 0.5358

我尝试了不同的学习率,但主要问题似乎是训练需要 1 秒并且准确度保持在 ~0.5。我没有正确设置输入/模型吗?

标签: pythontensorflowkeras

解决方案


你的 BERT 模型没有训练。它必须放在密集层之前并作为模型的一部分进行训练。输入层必须不采用 BERT 向量,而是将标记序列裁剪为 max_length 并进行填充。这是示例代码:https ://keras.io/examples/nlp/text_extraction_with_bert/ ,请参见函数的开头create_model

或者,您可以使用Trainer变压器。


推荐阅读