python - 全局步骤不随批量规范和自定义估计器增加
问题描述
我有一个客户估计器,它有几个层,在模型函数中如下所示:
natural_layer = tf.layers.dense(inputs = natural_layer,
units = units,
activation = None,
use_bias = False,
kernel_regularizer = params['regularizer'],
name = 'pre_batch_norm_layer_' + str(i + 1))
natural_layer = tf.layers.batch_normalization(natural_layer,
axis = 1,
center = True,
scale = True,
training = (mode == tf.estimator.ModeKeys.TRAIN),
name = 'batch_norm_layer_' + str(i + 1))
natural_layer = params['natural_layer_activation'](natural_layer, name = 'activation_layer_' + str(i + 1))
因为我使用的是批量规范,所以训练操作设置如下:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
optimizer = tf.contrib.opt.MultitaskOptimizerWrapper(params['optimization_algorithm'](params['training_rate']))
train_op = optimizer.minimize(loss, global_step = tf.train.get_global_step())
其中优化器通常是 tf.train.AdamOptimizer。
但是,当我去训练估计器时,全局步骤永远不会增加(所以训练将永远运行),我得到了这个:
WARNING:tensorflow:似乎全局步长(tf.train.get_global_step)没有增加。当前值(可能是稳定的):0 vs 之前的值:0。您可以通过将 tf.train.get_global_step() 传递给 Optimizer.apply_gradients 或 Optimizer.minimize 来增加全局步长。
我正在传递 tf.train.get_global_step() 以最小化,所以我不确定为什么它永远不会更新。我的预感是它与批处理规范化有关,因为当我删除它或将其替换为 dropout 时,一切正常(即使根据文档保留批处理规范化所需的更新操作行)。
有谁知道发生了什么?如果有帮助,很高兴发布更多代码。
解决方案
我无法弄清楚为什么全局步骤没有自动增加,但是通过将全局步骤添加到带有 tf.group 的 train_op 来手动增加全局步骤如下是一个很好的解决方法。
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
optimizer = tf.contrib.opt.MultitaskOptimizerWrapper(params['optimization_algorithm'](params['training_rate']))
train_op = optimizer.minimize(loss)
global_step = tf.train.get_global_step()
update_global_step = tf.assign(global_step, global_step + 1, name = 'update_global_step')
return tf.estimator.EstimatorSpec(mode, loss = loss, train_op = tf.group(train_op, update_global_step))
推荐阅读
- azure-logic-apps - 如何区分 PowerAutomate 中的共享邮箱和用户邮箱?
- .net - 在 Azure 应用服务上托管 WCF 不响应客户端应用服务
- kdb - 如何在单个查询中更改命名空间?
- authentication - 尝试以用户身份登录时:目标类 [Laravel\Fortify\Http\Controllers\AuthenticateSessionController] 不存在
- loops - 无法在 Kotlin 中重新分配变量的值
- javascript - 得到网格的孩子根据内容增长和缩小
- python - 结束请求 Python
- wordpress - 从 woocommerce 结帐帐单运输字段中删除默认选定状态,但仅适用于第一次来宾用户填写信息
- javascript - 是否可以在 getServerSideProps() 中获取搜索词参数?
- javascript - 页面开始添加 DOM 元素之前的 Chrome.storage.sync 用法