tensorflow - tf-slim batch norm:训练/推理模式之间的不同行为
问题描述
我正在尝试基于流行的苗条实现来训练张量流模型,mobilenet_v2
并且正在观察我无法解释(我认为)与批量标准化相关的行为。
问题总结
推理模式下的模型性能最初有所提高,但在很长一段时间后开始产生微不足道的推理(全部接近零)。在训练模式下运行时,即使在评估数据集上也能保持良好的性能。评估性能受批标准化衰减/动量率的影响......不知何故。
下面有更广泛的实现细节,但我可能会因为文字墙而失去大多数人,所以这里有一些图片让你感兴趣。
下面的曲线来自我bn_decay
在训练时调整了参数的模型。
0-370k:(bn_decay=0.997
默认)
370k-670k:bn_decay=0.9
670k+:bn_decay=0.5
(橙色)训练(训练模式)和(蓝色)评估(推理模式)的损失。低是好的。
我试图制作一个最小的例子来演示这个问题 - MNIST 上的分类 - 但失败了(即分类效果很好,我遇到的问题没有表现出来)。我很抱歉不能进一步减少事情。
实施细节
我的问题是 2D 姿态估计,以关节位置为中心的高斯人为目标。它本质上与语义分割相同,除了不使用softmax_cross_entropy_with_logits(labels, logits)
我使用tf.losses.l2_loss(sigmoid(logits) - gaussian(label_2d_points))
的 a(我使用术语“logits”来描述我的学习模型的未激活输出,尽管这可能不是最好的术语)。
推理模型
在预处理我的输入之后,我的 logits 函数是对基本 mobilenet_v2 的范围调用,然后是一个未激活的卷积层,以使过滤器的数量合适。
from slim.nets.mobilenet import mobilenet_v2
def get_logtis(image):
with mobilenet_v2.training_scope(
is_training=is_training, bn_decay=bn_decay):
base, _ = mobilenet_v2.mobilenet(image, base_only=True)
logits = tf.layers.conv2d(base, n_joints, 1, 1)
return logits
训练操作
我已经尝试tf.contrib.slim.learning.create_train_op
过以及自定义培训操作:
def get_train_op(optimizer, loss):
global_step = tf.train.get_or_create_global_step()
opt_op = optimizer.minimize(loss, global_step)
update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
update_ops.add(opt_op)
return tf.group(*update_ops)
我正在tf.train.AdamOptimizer
使用learning rate=1e-3
.
训练循环
我正在使用tf.estimator.Estimator
API 进行培训/评估。
行为
培训最初进展顺利,预期性能会急剧提高。这与我的预期一致,因为最后一层被快速训练以解释预训练基础模型输出的高级特征。
然而,经过很长一段时间(60k 步,batch_size 8,在 GTX-1070 上约 8 小时)我的模型在推理模式下运行时开始输出接近零的值(~1e-11),即is_training=False
。在 *training mode is_training=True` 下运行时,完全相同的模型会继续改进, i.e.
,即使在评估集上也是如此。我已经在视觉上验证了这是。
经过一些实验后,我将bn_decay
(批量归一化衰减/动量速率)从默认值0.997
更改0.9
为 ~370k 步(也尝试过0.99
,但这并没有太大区别),并观察到准确性的即时改进。在推理模式下对推理的目视检查显示,~1e-1
在预期位置的推断顺序值中有明显的峰值,与训练模式中峰值的位置一致(尽管值要低得多)。这就是为什么准确度显着提高,但损失 - 虽然更具波动性 - 并没有太大改善。
这些效果在更多的训练后下降,并恢复到全零推理。
我bn_decay
在步长约 670k 处进一步将 0.5 降至 0.5。这导致损失和准确性的改进。我可能要等到明天才能看到长期效果。
下面给出了损失和评估指标图。请注意,评估指标基于 logits 的 argmax,高为佳。损失以实际值为准,低为好。橙色用于is_training=True
训练集,蓝色用于is_training=False
评估集。大约 8 的损失与所有零输出一致。
其他注意事项
- 我还尝试过关闭 dropout(即始终使用 运行 dropout 层
is_training=False
),并没有观察到差异。 1.7
我已经尝试过从到的所有 tensorflow 版本1.10
。没有不同。- 我
bn_decay=0.99
从一开始就使用预训练的检查点训练模型。与使用 default 的行为相同bn_decay
。 - 批量大小为 16 的其他实验导致质量相同的行为(尽管由于内存限制我无法同时评估和训练,因此定量分析批量大小为 8)。
- 我已经使用相同的损失和
tf.layers
API 训练了不同的模型,并从头开始训练。他们工作得很好。 - 从头开始训练(而不是使用预训练的检查点)会产生类似的行为,但需要更长的时间。
总结/我的想法:
- 我相信这不是过度拟合/数据集问题。
is_training=True
当使用 运行时,该模型在峰值位置和幅度方面对评估集做出了合理的推断。 - 我相信这不是不运行更新操作的问题。我以前没有使用
slim
过,但是除了使用它之外,它与我广泛使用arg_scope
的 API 并没有太大的不同。tf.layers
我还可以检查移动平均值并观察它们随着训练的进行而变化。 - 改变
bn_decay
值暂时显着影响结果。我接受价值0.5
低得离谱的说法,但我的想法已经不多了。 - 我尝试将
slim.layers.conv2d
层换成tf.layers.conv2d
withmomentum=0.997
(即与默认衰减值一致的动量)并且行为是相同的。 - 使用预训练权重和框架的最小示例可用于 MNIST
Estimator
分类,无需修改bn_decay
参数。
我已经查看了有关 tensorflow 和模型 github 存储库的问题,但除此之外没有发现太多问题。我目前正在尝试使用较低的学习率和更简单的优化器MomentumOptimizer
(
可能的解释
- 我最好的解释是我的模型参数以某种方式快速循环,以至于移动统计数据无法跟上批量统计数据。我从来没有听说过这种行为,它并不能解释为什么模型会在更多时间后恢复到不良行为,但这是我所拥有的最好的解释。
- 移动平均代码中可能存在错误,但在其他所有情况下它对我来说都非常有效,包括一个简单的分类任务。在我能提出一个更简单的例子之前,我不想提出问题。
反正我的想法不多了,调试周期很长,我已经花太多时间在这上面了。很高兴提供更多细节或按需运行实验。也很高兴发布更多代码,尽管我担心这会吓跑更多人。
提前致谢。
解决方案
使用 Adam降低学习率1e-4
和使用 Momentum 优化器(使用learning_rate=1e-3
和momentum=0.9
)都解决了这个问题。我还发现这篇文章表明这个问题跨越了多个框架,并且由于优化器和批量标准化之间的相互作用,它是一些网络的未记录病理。我不认为这是优化器由于学习率太高而未能找到合适的最小值的简单案例(否则训练模式下的性能会很差)。
我希望这可以帮助其他遇到同样问题的人,但我离满意还有很长的路要走。我很高兴听到其他解释。
推荐阅读
- python - 在 postgres 和 sqlalchemy 中限制同时执行两个查询
- node.js - 如何导入 discord.js?
- kendo-asp.net-mvc - 剑道:启用浮动标签的全宽
- flutter - 如何打开/查找 path_provider 创建的文件?
- asp.net - 指定 ASP.NET MVC 应用程序根目录的正确语法是什么?
- sql - 加入 Fragment CTE 来构建我们合适的表
- reactjs - 通用反应组件道具
- stripe-payments - 创建付款意图时如何设置默认卡
- c++ - 什么情况下会出现Linux信号函数的SIG_ERR?
- sql - 将标识列重置为从 1 开始,其中 identity_columns.last_value 为 NULL