Why doesn't the TensorFlow Object Detection API restore the box predictor weights when fine-tuning from a pretrained model?

Problem description

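The TensorFlow Object Detection API includes an example project for fine-tuning an SSD model on new classes. In that example, the author restores the feature extractor and part of the box predictor from a pretrained checkpoint. In the main package used for fine-tuning, however, restoring an SSD model from a checkpoint does not restore any part of the box predictor; only the feature extractor is restored. Why is that? And when would you want to partially restore the box predictor?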

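I noticed this because, when fine-tuning this model on a custom dataset, the model's total loss was roughly 100 times lower when the box predictor was partially restored.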

Checkpoint restore code from the example project

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()
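The comments in the snippet above restore the box regression head but initialize the classification head from scratch. One way to see why that split is plausible is to compare the output sizes of the two heads. The sketch below is only an illustration under stated assumptions, not code from the Object Detection API: it assumes 4 box coordinates per anchor, one extra background slot in the class head, and a hypothetical 9 anchors per location; the function names are made up for this example.

```python
# Illustrative arithmetic only; names and constants here are assumptions,
# not identifiers from the Object Detection API.
BOX_CODE_SIZE = 4  # assumed: 4 box coordinates predicted per anchor


def box_head_channels(anchors_per_location: int) -> int:
    """Output channels of a box regression head (independent of class count)."""
    return anchors_per_location * BOX_CODE_SIZE


def class_head_channels(anchors_per_location: int, num_classes: int,
                        add_background: bool = True) -> int:
    """Output channels of a classification head (depends on class count)."""
    slots = num_classes + (1 if add_background else 0)
    return anchors_per_location * slots


anchors = 9  # hypothetical anchors per feature-map location

# Pretrained on 90 classes, fine-tuned on 1 new class (example numbers).
print(box_head_channels(anchors), box_head_channels(anchors))            # 36 36
print(class_head_channels(anchors, 90), class_head_channels(anchors, 1))  # 819 18
```

Under these assumptions the box head has the same shape before and after changing the number of classes, so its weights can be reused, while the classification head changes shape and has to be re-initialized.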

SSD checkpoint restore code used by the main package

  def restore_from_objects(self, fine_tune_checkpoint_type='detection'):
    """Returns a map of Trackable objects to load from a foreign checkpoint.
    Returns a dictionary of Tensorflow 2 Trackable objects (e.g. tf.Module
    or Checkpoint). This enables the model to initialize based on weights from
    another task. For example, the feature extractor variables from a
    classification model can be used to bootstrap training of an object
    detector. When loading from an object detection model, the checkpoint model
    should have the same parameters as this detection model with exception of
    the num_classes parameter.
    Note that this function is intended to be used to restore Keras-based
    models when running Tensorflow 2, whereas restore_map (above) is intended
    to be used to restore Slim-based models when running Tensorflow 1.x.
    Args:
      fine_tune_checkpoint_type: A string indicating the subset of variables
        to load. Valid values: `detection`, `classification`, `full`. Default
        `detection`.
        An SSD checkpoint has three parts:
        1) Classification Network (like ResNet)
        2) DeConv layers (for FPN)
        3) Box/Class prediction parameters
        The parameters will be loaded using the following strategy:
          `classification` - will load #1
          `detection` - will load #1, #2
          `full` - will load #1, #2, #3
    Returns:
      A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
    """
    if fine_tune_checkpoint_type == 'classification':
      return {
          'feature_extractor':
              self._feature_extractor.classification_backbone
      }
    elif fine_tune_checkpoint_type == 'detection':
      fake_model = tf.train.Checkpoint(
          _feature_extractor=self._feature_extractor)
      return {'model': fake_model}

    elif fine_tune_checkpoint_type == 'full':
      return {'model': self}

    else:
      raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
          fine_tune_checkpoint_type))
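To make the comparison in the question concrete, the following sketch lists which parts each restore path touches. It is a simplified model of the two snippets quoted above, not the library's code: the part names are plain strings chosen for this example, and the `full` case is reduced to "feature extractor plus box predictor" as an assumption about what the whole model contains.

```python
# Simplified, hypothetical model of the two restore paths quoted above.
# Part names are illustrative strings, not real Trackable objects.

def restored_parts(fine_tune_checkpoint_type: str = "detection") -> set:
    """Parts that the quoted restore_from_objects branches hand back."""
    if fine_tune_checkpoint_type == "classification":
        return {"feature_extractor.classification_backbone"}
    if fine_tune_checkpoint_type == "detection":
        return {"feature_extractor"}
    if fine_tune_checkpoint_type == "full":
        # Assumption: the whole model reduces to these two parts.
        return {"feature_extractor", "box_predictor"}
    raise ValueError(
        "Not supported fine_tune_checkpoint_type: {}".format(
            fine_tune_checkpoint_type))


# Parts restored by the example project's hand-built fake checkpoint.
EXAMPLE_PROJECT_PARTS = {
    "feature_extractor",
    "box_predictor.base_tower_layers_for_heads",
    "box_predictor.box_prediction_head",
}

# What the example project restores beyond the 'detection' branch.
extra = EXAMPLE_PROJECT_PARTS - restored_parts("detection")
print(sorted(extra))
```

Read this way, the difference the question asks about is exactly the two box predictor entries in `extra`: the shared tower layers and the box regression head.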

Tags: tensorflow, computer-vision, object-detection, single-shot-detector

Solution

