tensorflow - Why does my TensorFlow segmentation network return empty output when the session parameter is_training is set to False for the batchNorm layers?
Problem description
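I'm doing image segmentation with a neural network in TensorFlow. If the is_training argument of the slim.batch_norm layers is set to True, both training and inference work fine.
However, when I run the session with is_training set to False, which (as far as I understand) should simply run the data forward through the network for inference, the resulting segmentation image is empty. I believe it has something to do with the batchNorm layers, but I've been losing my mind over this and just can't get it to work.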
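To make the difference between the two modes concrete, here is a minimal single-channel sketch of what a batch-norm layer computes. It is a toy for illustration only: the class `ToyBatchNorm` is hypothetical, not slim code, and it assumes gamma = 1, beta = 0, slim's default `decay=0.999` and `epsilon=0.001`, and slim's initial moving mean 0 / moving variance 1. With `is_training=True` it normalizes with the current batch's statistics; with `is_training=False` it uses the moving averages, so inference is only as good as those averages.

```python
import math


class ToyBatchNorm:
    """Hypothetical single-channel batch norm mimicking the two modes of
    slim.batch_norm (gamma fixed at 1, beta fixed at 0)."""

    def __init__(self, decay=0.999, epsilon=0.001):
        self.decay = decay
        self.epsilon = epsilon
        # slim initializes moving_mean to 0 and moving_variance to 1
        self.moving_mean = 0.0
        self.moving_var = 1.0

    def __call__(self, xs, is_training, update=True):
        if is_training:
            # Training mode: normalize with the statistics of this batch.
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            if update:
                # Roughly what the ops in tf.GraphKeys.UPDATE_OPS do: move the
                # running averages a small step toward the batch statistics.
                self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
                self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Inference mode: use the accumulated moving averages instead.
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.epsilon) for x in xs]


bn = ToyBatchNorm()
batch = [48.0, 50.0, 52.0]
print("training: ", bn(batch, is_training=True))
print("inference:", bn(batch, is_training=False))
```

Right after initialization, training mode centers the batch around zero, while inference mode barely changes it (the outputs stay close to the raw inputs), because the moving averages have moved only one small step.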
I'm using code based on the Semantic Segmentation Suite in TensorFlow. Below is a simplified version showing what works and what fails.
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
.....
def ConvBlock(inputs, n_filters, kernel_size=[3, 3], is_training=True):
    net = slim.conv2d(inputs, n_filters, kernel_size=[1, 1], activation_fn=None)
    net = slim.batch_norm(net, fused=True, is_training=is_training)
    net = tf.nn.relu(net)
    return net

def DepthwiseSeparableConvBlock(inputs, n_filters, kernel_size=[3, 3], is_training=True):
    net = slim.separable_convolution2d(inputs, num_outputs=None, depth_multiplier=1, kernel_size=[3, 3], activation_fn=None)
    net = slim.batch_norm(net, fused=True, is_training=is_training)
    net = tf.nn.relu(net)
    ....
    return net

def ConvTransposeBlock(inputs, n_filters, kernel_size=[3, 3], is_training=True):
    net = slim.conv2d_transpose(inputs, n_filters, kernel_size=[3, 3], stride=[2, 2], activation_fn=None)
    net = slim.batch_norm(net, is_training=is_training)
    net = tf.nn.relu(net)
    return net

def build_mobile_unet(inputs, .... , is_training=True):
    net = ConvBlock(inputs, 64, is_training=is_training)
    net = DepthwiseSeparableConvBlock(net, 64, is_training=is_training)
    net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
    ....
    net = ConvTransposeBlock(net, 64, is_training=is_training)
    net = DepthwiseSeparableConvBlock(net, 64, is_training=is_training)
    net = DepthwiseSeparableConvBlock(net, 64, is_training=is_training)
    net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, scope='logits')
    return net

# Define the param placeholders
net_input_image = tf.placeholder(tf.float32,shape=[None,None,None,3], name="input")
net_input_label = tf.placeholder(tf.int32, [None,None,None])
# Training phase placeholder
net_training = tf.placeholder(tf.bool, name='phase_train')
# build_mobile_unet returns a single tensor of raw (unnormalized) logits
model_logits = build_mobile_unet(
    inputs=net_input_image,
    ....
    is_training=net_training)
# The softmax output is only used for inference
model = tf.nn.softmax(model_logits, name="softmax_output")

with tf.name_scope('loss'):
    # sparse_softmax_cross_entropy applies softmax internally, so it must be
    # given the raw logits, not softmax probabilities; it already returns a scalar
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(logits=model_logits, labels=net_input_label)

# use RMSProp to optimize
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.995)

# create train OP; the batch-norm moving mean/variance updates live in
# tf.GraphKeys.UPDATE_OPS and must run together with every training step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
total_loss = tf.losses.get_total_loss()
train_op = slim.learning.create_train_op(total_loss, optimizer, update_ops=update_ops)
# Do the training here
for epoch in range(args.epoch_start_i, args.num_epochs):
    input_image_batch = ...
    label_image_batch = ...

    # Do the training
    train_dict = {
        net_input_image: input_image_batch,
        net_input_label: label_image_batch,
        net_training: True
    }
    train_loss = sess.run(train_op, feed_dict=train_dict)

    # Do the validation on a small set of validation images
    for ind in val_indices:
        input_image = np.expand_dims(np.float32(utils.load_image(val_input_names[ind])[:args.crop_height, :args.crop_width]), axis=0) / 255.0
        gt = utils.load_image(val_output_names[ind])[:args.crop_height, :args.crop_width]
        gt = helpers.reverse_one_hot(helpers.one_hot_it(gt, label_values))

        # THIS WORKS: the segmentation result is OK
        # (batch_norm normalizes with this single image's own statistics)
        output_image = sess.run(
            model,
            feed_dict={
                net_input_image: input_image,
                net_training: True
            })

        # THIS FAILS: the segmentation result is all zeros...
        # (batch_norm uses the moving_mean / moving_variance accumulated in training)
        output_image = sess.run(
            model,
            feed_dict={
                net_input_image: input_image,
                net_training: False
            })
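The two validation calls above differ only in which statistics the batch_norm layers use. The toy simulation below is plain Python with hypothetical numbers (one feature channel whose training activations have mean 50 and standard deviation 5, batches of 16); it applies the same moving-average update as slim's default `decay=0.999` and compares an inference-mode output with a training-mode one after different numbers of updates. It does not reproduce the all-zeros output; it only illustrates how large the mismatch can be.

```python
import math
import random


def simulate(steps, decay=0.999, seed=0):
    """Run `steps` hypothetical training updates of the moving averages and
    return (moving_mean, moving_var)."""
    rng = random.Random(seed)
    moving_mean, moving_var = 0.0, 1.0  # slim.batch_norm's initial values
    for _ in range(steps):
        batch = [rng.gauss(50.0, 5.0) for _ in range(16)]
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
        moving_mean = decay * moving_mean + (1 - decay) * mean
        moving_var = decay * moving_var + (1 - decay) * var
    return moving_mean, moving_var


x = 55.0  # one activation, one standard deviation above the training mean
train_like = (x - 50.0) / math.sqrt(25.0 + 0.001)  # approx. training-mode output
for steps in (100, 10000):
    mm, mv = simulate(steps)
    infer = (x - mm) / math.sqrt(mv + 0.001)  # inference-mode output
    print(steps, round(train_like, 2), round(infer, 2))
```

With only 100 updates the inference-mode value is more than ten times the training-mode one; after 10,000 updates the two roughly agree.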
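Training works well and the network converges... If I always keep the placeholder net_training set to True, everything is fine.
However, if I call sess.run(model, ... net_training: False), as in the code above, the output is empty when I test some images.
What am I doing wrong, guys?
Any help would be greatly appreciated. Thanks for your time.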
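The same lag can be written in closed form. Assuming the batch mean stays constant at \mu_B, the moving mean starts at 0 (slim's initializer), and each executed update op applies decay d, the moving mean after n updates is:

```latex
\hat{\mu}_n = d\,\hat{\mu}_{n-1} + (1-d)\,\mu_B, \qquad \hat{\mu}_0 = 0
\;\;\Longrightarrow\;\;
\hat{\mu}_n = (1-d)\,\mu_B \sum_{k=0}^{n-1} d^{\,k} = \mu_B\,\bigl(1 - d^{\,n}\bigr)
```

With slim's default d = 0.999, the moving mean has covered only about 9.5% of the way to \mu_B after 100 updates, about 63% after 1,000, and about 99.3% after 5,000. A smaller decay such as 0.9 gets there within a few dozen updates.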