tensorflow - 如何为训练和验证创建两个图?
问题描述
当我阅读有关图形和会话(图形和会话)的tensorflow 指南时,我发现他们建议为训练和验证创建两个图形。
我认为这是合理的,我想使用它,因为我的训练和验证模型是不同的(对于编码器-解码器模式或 dropout)。但是,我不知道如何在不使用 tf.saver() 的情况下使训练图中的变量可用于测试图。
当我创建两个图表并在每个图表中创建变量时,我发现这两个变量完全不同,因为它们属于不同的图表。我用谷歌搜索了很多,我知道有关于这个问题的问题,例如question1。但仍然没有有用的答案。如果有任何代码示例或任何人知道如何分别为训练和验证创建两个图,例如:
def train_model():
g_train = tf.graph()
with g_train.as_default():
train_models
def validation_model():
g_test = tf.graph()
with g_test.as_default():
test_models
解决方案
一种简单的方法是创建一个“前向函数”来定义模型并根据额外参数改变行为。
这是一个例子:
def forward_pass(x, is_training, reuse=tf.AUTO_REUSE, name='model_forward_pass'):
# Note the reuse attribute as it tells the getter to either create the graph or get the weights
with tf.variable_scope(name=name, reuse=reuse):
x = tf.layers.conv(x, ...)
...
x = tf.layers.dense(x, ...)
x = tf.layers.dropout(x, rate, training=is_training) # Note the is_training attribute
...
return x
现在您可以在代码中的任何位置调用“forward_pass”函数。例如,您只需提供 is_training 属性即可使用正确的退出模式。只要 'variable_scope' 的 'name' 相同,'reuse' 参数就会自动获得正确的权重值。
例如:
train_logits_model1 = forward_pass(x_train, is_training=True, name='model1')
# Graph is defined and dropout is used in training mode
test_logits_model1 = forward_pass(x_test, is_training=False, name='model1')
# Graph is reused but the dropout behaviour change to inference mode
train_logits_model2 = forward_pass(x_train2, is_training=True, name='model2')
# Name changed, model2 is added to the graph and dropout is used in training mode
如您所说,要添加到此答案中,您希望有 2 个单独的图表,您可以使用分配函数来添加:
train_graph = forward_pass(x, is_training=True, reuse=False, name='train_graph')
...
test_graph = forward_pass(x, is_training=False, reuse=False, name='test_graph')
...
train_vars = tf.get_collection('variables', 'train_graph/.*')
test_vars = tf.get_collection('variables','test_graph/.*')
test_assign_ops = []
for test, train in zip(test_vars, train_vars):
test_assign_ops += [tf.assign(test, train)]
assign_op = tf.group(*test_assign_ops)
sess.run(assign_op) # Replace vars in the test_graph by the one in train_graph
我是方法 1 的大力倡导者,因为它更干净并减少了内存使用量。
推荐阅读
- python - 新样本的因果推理
- pandas - 删除熊猫列中的字符编码
- scala - Spark中rtrim函数的意外结果
- json - 如何将字符串形式的 JSON 数组转换为 Lua 表或将其解析为 JSON
- sas - 从外部文件读取
- pandas - TensorFlow 数据集:无法将 NumPy 数组转换为张量(不支持的对象类型 numpy.ndarray)
- flutter - 将行定位在堆栈 Flutter 的底部
- docusignapi - SOAP API RequestTemplate 信封收件人没有签名组信息
- c# - 如何设置 Bing 地图样式?
- javascript - 函数无法识别来自另一个函数的条件