tensorflow - 仅优化使用 TensorFlow Estimator API 的模型的某些变量
问题描述
我需要冻结部分模型并只训练某些变量。现在,使用低级 API,我可以传递var_list
给tf.train.Optimizer.minimize
方法。但是,当我使用 TensorFlow Estimator 时,我只能传递优化器本身,然后将其用于最小化 Estimator 内部循环内的损失。
我想到的唯一解决方案是定义一个自定义优化器并覆盖该Optimizer.minimize
方法。像这样的东西:
def minimize(self, *args, **kwargs):
print("Inside...")
if not kwargs['var_list']:
kwargs['var_list'] = self.var_list
return super(MyOptimizer, self).minimize(*args, **kwargs)
现在,我希望在每个训练步骤中都能看到“Inside...”短语打印在屏幕上;尤其是当我看到模型训练得很好时。这有点表明我的minimize
功能被完全忽略了,我似乎无法弄清楚为什么。
那么,重写是否正确,minimize
或者是否有更好的方法来使用 Estimator 呢?
解决方案
您可以通过指定 model_fn 函数来简单地制作自定义估算器
def model_fn(features, labels, mode):
logits = model_architecture(features)
loss = loss_function(logits, labels)
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = optimizer
train_op = ontimizer.minimize(loss=loss,
global_step=global_step,
var_list=variables_to_minimize)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
推荐阅读
- javascript - 如何将选定的索引数据分配给我的工具提示数据?
- swift - Swift 4.2 Setter Getter, All paths through this function will call itself
- sql-server - 在 sql server db 上引入延迟以测试 API 中的超时
- uuid - Longest Version 4 GUID
- sql - MS Access:查询以列出一系列文件夹中的文件数
- laravel - PHPUNIT 测试 - 预期状态 200 但收到 500
- python - 覆盆子主机名更改但不是真的
- javascript - 如果输入字符串有任何参数,Faker.fake() 不起作用
- firebase - Firestore 导入 - 没有错误,但没有更改
- c# - 如何将 XSD 转换为 C#,包括 XSD 中的注释?