首页 > 解决方案 > 如何使用训练有素的 BERT 模型检查点进行预测?

问题描述

我用 SQUAD 2.0 训练了 BERT,并使用BERT-mastermodel.ckpt.data在输出目录中得到了, model.ckpt.meta, model.ckpt.index(F1 score : 81) 以及predictions.json, 等等/run_squad.py

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

我尝试将model.ckpt.meta, model.ckpt.index,复制model.ckpt.data$BERT_LARGE_DIR目录并更改run_squad.py标志如下,以仅预测答案而不使用数据集进行训练:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
  --do_train=False \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

它抛出 bucket directory/model.ckpt 不存在错误。

如何利用训练后生成的检查点进行预测?

标签: pythontensorflowneural-networkgoogle-cloud-tpubert-language-model

解决方案


通常,训练的检查点是在训练时在--output_dir参数指定的目录中创建的。(gs://some_bucket/squad_large/在你的情况下)。每个检查点都会有一个编号。你必须找出最大的数字;例子:model.ckpt-12345。现在,--init_checkpoint使用输出目录和最后保存的检查点(编号最高的模型)在评估/预测中设置参数。(在你的情况下,它应该是这样的--init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>


推荐阅读