python - 如何使用训练有素的 BERT 模型检查点进行预测?
问题描述
我用 SQUAD 2.0 训练了 BERT,并使用BERT-mastermodel.ckpt.data
在输出目录中得到了, model.ckpt.meta
, model.ckpt.index
(F1 score : 81) 以及predictions.json
, 等等/run_squad.py
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
我尝试将model.ckpt.meta
, model.ckpt.index
,复制model.ckpt.data
到$BERT_LARGE_DIR
目录并更改run_squad.py
标志如下,以仅预测答案而不使用数据集进行训练:
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
--do_train=False \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
它抛出 bucket directory/model.ckpt 不存在错误。
如何利用训练后生成的检查点进行预测?
解决方案
通常,训练的检查点是在训练时在--output_dir
参数指定的目录中创建的。(gs://some_bucket/squad_large/
在你的情况下)。每个检查点都会有一个编号。你必须找出最大的数字;例子:model.ckpt-12345
。现在,--init_checkpoint
使用输出目录和最后保存的检查点(编号最高的模型)在评估/预测中设置参数。(在你的情况下,它应该是这样的--init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>
)
推荐阅读
- java - 通过电子邮件发送文件andoid(无法附加文件)
- r - 如何创建包含不同颜色评级的箱线图?
- class - 递归函数和类中的 Raku 类型约束
- sql-server - 如何更改现有分区的 FILEGROUP (SQL Server)
- node.js - npm i 导致许多 ERESOLVE 问题
- c++ - 如何输出数组中的第一个、最后一个、第二个、倒数第二个等元素?
- java - Clojure/Java.regex:为什么重新查找带括号的正则表达式(如下)的抛出异常,而不是像“\ d +”这样的简单表达式?
- javascript - 如何在本机反应中从警报按钮传递值以发挥作用
- excel - 根据单元格 ID(例如 A2、D15 等)定位相应的列标题和行 ID
- python - 此代码查找单个字母并检查其是否在输入的单词中。如何将我正在搜索的字母作为参数?