python - 在 GLUE 任务上微调 BERT 时,如何监控训练和评估损失?
问题描述
我正在运行https://github.com/huggingface/transformers/blob/master/examples/run_glue.py对二进制分类任务 (CoLA) 执行微调。我想监控训练和评估损失以防止过度拟合。
目前该库为 2.8.0,我从源代码进行了安装。
当我运行示例时
python run_glue.py --model_name_or_path bert-base-uncased
--task_name CoLA
--do_train
--do_eval
--data_dir my_dir
--max_seq_length 128
--per_gpu_train_batch_size 8
--per_gpu_eval_batch_size 8
--learning_rate 2e-5
--num_train_epochs 3.0
--output_dir ./outputs
--logging_steps 5
在标准输出日志中,我看到只有一个损失值的行,例如
{“learning_rate”:3.3333333333333333e-06,“loss”:0.47537623047828675,“step”:25}
通过查看https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py,我看到那里计算了训练和评估损失(在我看来,代码最近被重构)。
因此,我已经用
cr_loss = self._training_step(model, inputs, optimizer)
tr_loss += cr_loss
logs["training loss"] = cr_loss
有了这个我得到:
0502 14:12:18.644119 23632 summary.py:47] Summary name training loss is illegal; using training_loss instead.
| 4/10 [00:02<00:04, 1.49it/s]
{"learning_rate": 3.3333333333333333e-06, "loss": 0.47537623047828675, "training loss": 0.5451719760894775, "step": 25}
这可以吗,还是我在这里做错了什么?
在微调期间,在标准输出中监控给定记录间隔的平均训练和评估损失的最佳方法是什么?
解决方案
如果安装更新的版本(我通过 pip 尝试了 2.9.0),代码中可能不需要更改:只需使用附加标志触发微调,--evaluate_during_training
输出就可以了
0506 12:11:30.021593 34540 trainer.py:551] ***** Running Evaluation *****
I0506 12:11:30.022596 34540 trainer.py:552] Num examples = 140
I0506 12:11:30.023634 34540 trainer.py:553] Batch size = 8 Evaluation:
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:19<00:00, 1.10s/it]
{"eval_mcc": 0.0, "eval_loss": 0.6600487811697854, "learning_rate": 3.3333333333333333e-06, "loss": 0.50044886469841, "step": 25}
请注意示例脚本更改非常频繁,因此完成此操作的标志可能会更改名称......另请参见此处https://discuss.huggingface.co/t/how-to-monitor-both-train-and-validation-metrics -在同一步骤/1301
推荐阅读
- arrays - Ruby - 比较数组和交换索引
- python - Python 脚本的任务计划程序未运行
- c++ - How do I fix this warning - "control reaches end of non-void function [-Wreturn-type]"
- javascript - RxJS Observable 中的循环在调用完成或错误后不会停止
- python-3.x - 将字典键移动到特定值的列表
- android - 无法滚动到 NestedScrollView 内的 RecyclerView 中的项目
- r - 在 R 中使用 GTmetrix REST API v2.0
- javascript - 行为类似于内置的自定义 HTML 元素元素
- android - Kotlin addSnapshotListener() 如何求和总量?
- export - 如何将 .dm3 文件(带注释和比例尺)转换为 .jpg/jpeg 图像?