python - 保存使用 BatchNorm 的 TensorFlow 模型
问题描述
我正在尝试使用 Tensorflow 从 GAN 中保存生成器模型。我使用的模型有几个批处理规范层。当我保存权重时,只有运行全局变量初始化程序才能成功恢复它们,我不应该这样做,因为正在恢复所有变量。如果我在恢复之前运行全局变量初始化程序,当我使用加载的权重运行推理并为批量规范参数设置 is_training=False 时,模型的性能非常差。但是,如果 is_training=True,则模型按预期执行。这种行为应该完全相反。
为了节省重量,我这样做:
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'g_' in var.name]
g_saver = tf.train.Saver(g_vars)
... train model ...
g_saver.save(sess, "weights/generator/gen.ckpt")
当我恢复权重时,我使用相同的模型定义并执行以下操作:
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'g_' in var.name]
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
g_saver = tf.train.Saver(g_vars)
g_saver.restore(sess, "./weights/generator/gen.ckpt")
您是否需要执行特殊程序来说明批量标准权重?我是否缺少一些变量集合?
编辑:
我使用以下方法定义批处理规范层:
conv1_norm = tf.contrib.layers.batch_norm(conv1, is_training=training
我发现将 variables_collections=["g_batch_norm_non_trainable"] 添加到 batch_norm 函数中,然后执行
g_vars = list(set([var for var in t_vars if 'g_' in var.name] + tf.get_collection("g_batch_norm_non_trainable")))
有效,但这对于应该是一个简单的减肥指令来说似乎相当复杂。
解决方案
当您使用tf.contrib.layers.batch_norm和默认参数定义批处理规范化时,会创建三个变量:beta
、moving_mean
和moving_variance
。第一个是唯一可训练的变量,另外两个包含在tf.GraphKeys.GLOBAL_VARIABLES
集合中。
这就是为什么g_vars
使用以下行中的可训练变量定义并没有同时moving_mean
在moving_variance
列表中获得的原因:
g_vars = [var for var in t_vars if 'g_' in var.name]
由于您似乎只想保存生成器变量,因此我建议使用变量范围来定义您的生成器网络。
对随机张量进行上采样并使用批量归一化的示例:
import tensorflow as tf
import numpy as np
input_layer = tf.placeholder(tf.float32, (2, 7, 7, 64)) # (batch, height, width, in_channels)
with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
# define your generator network here ...
t_conv_layer = tf.layers.conv2d_transpose(input_layer,
filters=32, kernel_size=[3, 3], strides=(2, 2), padding='SAME', name='t_conv_layer')
batch_norm = tf.contrib.layers.batch_norm(t_conv_layer, is_training=True, scope='my_batch_norm')
print(batch_norm) # Tensor("generator/my_batch_norm/FusedBatchNorm:0", shape=(2, 14, 14, 32), dtype=float32)
您可以检查两者的变量列表tf.trainable_variables()
并tf.global_variables()
打印它们。由于可训练变量在此处描述的全局变量列表中,我们可以定义g_vars
为:
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
如果我们检查这个列表,我们将拥有我们想要的批处理规范的所有变量:
for var in g_vars:
print("variable_name: {:45}, nb_parameters: {}".format(var.name, np.prod(var.get_shape().as_list())))
产生输出:
variable_name: generator/t_conv_layer/kernel:0 , nb_parameters: 18432
variable_name: generator/t_conv_layer/bias:0 , nb_parameters: 32
variable_name: generator/my_batch_norm/beta:0 , nb_parameters: 32
variable_name: generator/my_batch_norm/moving_mean:0 , nb_parameters: 32
variable_name: generator/my_batch_norm/moving_variance:0 , nb_parameters: 32
推荐阅读
- reactjs - 具有自定义数据提供者的自定义管理组件
- html - 如何解决 iframe 的填充问题?
- c# - Web API 空白参数值被转换为 null
- eslint - gulp-eslint 未输出到文件 - 无法正确配置 writableStream
- brain.js - 从json导入LSTM网络的brain.js问题
- sql - 使用 MAX 和 GROUP BY 时如何只得到一个结果
- python - 当我尝试为线性回归运行此代码时,输入包含 NaN、无穷大或对于 dtype('float64') 错误而言太大的值
- amazon-web-services - AWS - 如何使用 Javascript SDK 获取 Cloudfront 指标
- javascript - 终端输出:manpath:无法设置语言环境;确保 $LC_* 和 $LANG 是正确的
- sql - DBATools - 从 Export-DBAScript -PATH 中删除时间戳