首页 > 解决方案 > 在 TensorFlow 中使用 MonitoredTrainingSession 与 Estimator 的原因是什么

问题描述

我看到了许多使用MonitoredTrainingSessiontf.Estimator作为训练框架的例子。但是,尚不清楚为什么我会使用其中一个。两者都可以用SessionRunHooks. 两者都与tf.data.Dataset迭代器集成,并且可以提供训练/验证数据集。我不确定一种设置的好处是什么。

标签: pythontensorflowmachine-learningtensorflow-estimator

解决方案


简短的回答是,MonitoredTrainingSession允许用户访问 Graph 和 Session 对象以及训练循环,同时Estimator向用户隐藏图形和会话的详细信息,并且通常可以更轻松地运行训练,尤其是train_and_evaluate在需要定期评估的情况下。

MonitoredTrainingSession与普通的 tf.Session() 不同之处在于它处理变量初始化、设置文件编写器并且还包含分布式训练的功能。

Estimator API另一方面,它是一个高级构造,就像Keras. 在示例中可能使用较少,因为它是稍后介绍的。它还允许使用 分发训练/评估DistibutedStrategy,并且它有几个允许快速原型制作的罐装估计器。

就模型定义而言,它们非常平等,都允许使用其中之一keras.layers,或者从头开始定义完全自定义的模型。因此,如果出于某种原因,您需要访问图构造或自定义训练循环,请使用MonitoredTrainingSession. 如果您只想定义模型、训练它、运行验证和预测而无需额外的复杂性和样板代码,请使用Estimator


推荐阅读