首页 > 解决方案 > 仅优化使用 TensorFlow Estimator API 的模型的某些变量

问题描述

我需要冻结部分模型并只训练某些变量。现在,使用低级 API,我可以传递var_listtf.train.Optimizer.minimize方法。但是,当我使用 TensorFlow Estimator 时,我只能传递优化器本身,然后将其用于最小化 Estimator 内部循环内的损失。

我想到的唯一解决方案是定义一个自定义优化器并覆盖该Optimizer.minimize方法。像这样的东西:

def minimize(self, *args, **kwargs):
    print("Inside...")
    if not kwargs['var_list']:
       kwargs['var_list'] = self.var_list

    return super(MyOptimizer, self).minimize(*args, **kwargs)

现在,我希望在每个训练步骤中都能看到“Inside...”短语打印在屏幕上;尤其是当我看到模型训练得很好时。这有点表明我的minimize功能被完全忽略了,我似乎无法弄清楚为什么。

那么,重写是否正确,minimize或者是否有更好的方法来使用 Estimator 呢?

标签: tensorflowmachine-learningtensorflow-estimator

解决方案


您可以通过指定 model_fn 函数来简单地制作自定义估算器

    def model_fn(features, labels, mode):
      logits = model_architecture(features)
      loss = loss_function(logits, labels)
      if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = optimizer
        train_op = ontimizer.minimize(loss=loss, 
                                      global_step=global_step,
                                      var_list=variables_to_minimize)

      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

推荐阅读