How do I stop training based on the loss when using a pre-trained model and a config file?

Problem description

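I am training an object detector with a Faster R-CNN model, using a pipeline config file. I know I can stop training manually by cancelling it (Ctrl+C), but I would like training to stop automatically based on the loss value. How can I do that? I know Keras callbacks can be used for this when monitoring epochs. Is there a similar option when using a config file and a pre-trained model (monitoring steps)? Thanks.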

Tags: python, tensorflow, keras, pre-trained-model, early-stopping

Solution


It may be a hack, but I found a solution to my own question. The object detector requires the tf_slim package to be installed, and that package contains a module called learning.py. Its full path will look something like this: /usr/local/lib/python3.6/site-packages/tf_slim/learning.py. In learning.py, starting around line 764, the code looks something like this:

try:
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

I added a small if statement that takes the maximum of the last five total_loss values and, if it is below a certain threshold (3 in this case), sets should_stop to True. This is shown below:

try:
  total_loss_list = []
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    total_loss_list.append(total_loss)
    # Stop once the five most recent losses are all below the threshold (3).
    if len(total_loss_list) >= 5 and max(total_loss_list[-5:]) < 3:
      should_stop = True
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

If the loss stays below 3 for five consecutive steps, training stops. The downside is that the installed tf_slim package has to be modified, and the threshold has to be chosen again for every new object detection problem. A better approach would be to supply the threshold through a configuration file, but I'm stopping here for now. If anyone has a better solution, please share. I hope this helps someone. Thank you!
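
One possible way to avoid editing the installed package is to put the same check in a wrapper around the step function instead of in the training loop. The sketch below is only an illustration under a few assumptions: make_early_stopping_step_fn, loss_threshold and window are hypothetical names that are not part of tf_slim, and the wrapped function is assumed to take (sess, train_op, global_step, train_step_kwargs) and return (total_loss, should_stop), as train_step_fn does in the loop above.

```python
from collections import deque


def make_early_stopping_step_fn(base_step_fn, loss_threshold, window=5):
    """Wrap a step function so that it requests a stop once the last
    `window` losses are all below `loss_threshold`.

    Hypothetical helper, not part of tf_slim. `base_step_fn` is assumed to
    have the same signature and return value as train_step_fn in the loop
    shown above.
    """
    # A bounded deque keeps only the newest `window` losses.
    recent_losses = deque(maxlen=window)

    def step_fn(sess, train_op, global_step, train_step_kwargs):
        total_loss, should_stop = base_step_fn(
            sess, train_op, global_step, train_step_kwargs)
        recent_losses.append(total_loss)
        # Only decide once a full window of losses has been collected.
        window_full = len(recent_losses) == window
        if window_full and max(recent_losses) < loss_threshold:
            should_stop = True
        return total_loss, should_stop

    return step_fn
```

As far as I can tell, tf_slim.learning.train accepts a train_step_fn argument, so the wrapped function could be passed there, with loss_threshold read from whatever config file or environment variable you prefer. Whether the object detection training script lets you set that argument without changes of its own is something I have not verified.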

