首页 > 解决方案 > Tensorflow Estimator 打印损失时使用的是什么数据集

问题描述

使用 Tensorflow Estimator 时,它会在 python 控制台(每 100 步)上打印全局步长和损失(此外,它还会打印学习率、交叉熵和 MAE,这是我的评估指标,并将这 3 个值打印在一条不同的线,我认为这是由于一些包装函数不是原始 Estimator API 的一部分,因为我正在使用谷歌开发人员的 ResNet 实现)。它看起来像这样:

    I0530 19:20:42.748463 10964 tf_logging.py:116] learning_rate = 3.552962e-05, cross_entropy = 2.2080934, MAE = 5.135024 (62.295 sec)   
    I0530 19:20:42.749458 10964 tf_logging.py:116] loss = 2.2080934, step = 76066 (62.295 sec)

我的问题是,正在计算什么损失(或正在计算什么 MAE)?
发生日志记录时,是否仅在特定步骤中丢失了一个示例?
是发生记录时特定步骤的批次损失吗?
或者可能是整个火车的损失?

另外,如果我假设有问题,请纠正我。我在这个领域很新。
谢谢。

标签: loggingtensorflowmachine-learningdeep-learning

解决方案


tf.Estimator自动LoggingTensorHook为损失和全局步骤设置一个。据推测,您运行的代码为其他值(学习率、交叉熵(这只是损失)和 MAE)设置了一个单独的钩子,这就是为什么它们被打印在不同的行上。

至于使用什么数据来产生值:它是“当前”批次的数据,即在完成记录的步骤中使用的批次。所以,在你提出的三个选项中,第二个是正确的。

这可以通过源代码确认,因为钩子会在“运行后”记录,接收最后一次session.run()调用的结果(一次只能得到一批)run_values


推荐阅读