tensorflow - 如何通过在 Tensorflow 中预测另一个权重的模型进行反向传播
问题描述
我目前正在尝试训练一个模型(超网络),该模型可以预测另一个模型(主网络)的权重,从而减少主网络的交叉熵损失。但是,当我使用 tf.assign 将新权重分配给网络时,它不允许反向传播到超网络中,从而使系统不可微分。我已经测试了我的权重是否正确更新,它们似乎是因为从更新的权重中减去初始权重是一个非零和。
这是我想要实现的最小示例。
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import softmax
def random_addition(variables):
addition_update_ops = []
for variable in variables:
update = tf.assign(variable, variable+tf.random_normal(shape=variable.get_shape()))
addition_update_ops.append(update)
return addition_update_ops
def network_predicted_addition(variables, network_preds):
addition_update_ops = []
for idx, variable in enumerate(variables):
if idx == 0:
print(variable)
update = tf.assign(variable, variable + network_preds[idx])
addition_update_ops.append(update)
return addition_update_ops
def dense_weight_update_net(inputs, reuse):
with tf.variable_scope("weight_net", reuse=reuse):
output = tf.layers.conv2d(inputs=inputs, kernel_size=(3, 3), filters=16, strides=(1, 1),
activation=tf.nn.leaky_relu, name="conv_layer_0", padding="SAME")
output = tf.reduce_mean(output, axis=[0, 1, 2])
output = tf.reshape(output, shape=(1, output.get_shape()[0]))
output = tf.layers.dense(output, units=(16*3*3*3))
output = tf.reshape(output, shape=(3, 3, 3, 16))
return output
def conv_net(inputs, reuse):
with tf.variable_scope("conv_net", reuse=reuse):
output = tf.layers.conv2d(inputs=inputs, kernel_size=(3, 3), filters=16, strides=(1, 1),
activation=tf.nn.leaky_relu, name="conv_layer_0", padding="SAME")
output = tf.reduce_mean(output, axis=[1, 2])
output = tf.layers.dense(output, units=2)
output = softmax(output)
return output
input_x_0 = tf.zeros(shape=(32, 32, 32, 3))
target_y_0 = tf.zeros(shape=(32), dtype=tf.int32)
input_x_1 = tf.ones(shape=(32, 32, 32, 3))
target_y_1 = tf.ones(shape=(32), dtype=tf.int32)
input_x = tf.concat([input_x_0, input_x_1], axis=0)
target_y = tf.concat([target_y_0, target_y_1], axis=0)
output_0 = conv_net(inputs=input_x, reuse=False)
target_y = tf.one_hot(target_y, 2)
crossentropy_loss_0 = tf.losses.softmax_cross_entropy(onehot_labels=target_y, logits=output_0)
conv_net_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="conv_net")
weight_net_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="weight_net")
print(conv_net_parameters)
weight_updates = dense_weight_update_net(inputs=input_x, reuse=False)
#updates_0 = random_addition(conv_net_parameters)
updates_1 = network_predicted_addition(conv_net_parameters, network_preds=[weight_updates])
with tf.control_dependencies(updates_1):
output_1 = conv_net(inputs=input_x, reuse=True)
crossentropy_loss_1 = tf.losses.softmax_cross_entropy(onehot_labels=target_y, logits=output_1)
check_sum = tf.reduce_sum(tf.abs(output_0 - output_1))
c_opt = tf.train.AdamOptimizer(beta1=0.9, learning_rate=0.001)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Needed for correct batch norm usage
with tf.control_dependencies(update_ops): # Needed for correct batch norm usage
train_variables = weight_net_parameters #+ conv_net_parameters
c_error_opt_op = c_opt.minimize(crossentropy_loss_1,
var_list=train_variables,
colocate_gradients_with_ops=True)
init=tf.global_variables_initializer()
with tf.Session() as sess:
init = sess.run(init)
loss_list_0 = []
loss_list_1 = []
for i in range(1000):
_, checksum, crossentropy_0, crossentropy_1 = sess.run([c_error_opt_op, check_sum, crossentropy_loss_0,
crossentropy_loss_1])
loss_list_0.append(crossentropy_0)
loss_list_1.append(crossentropy_1)
print(checksum, np.mean(loss_list_0), np.mean(loss_list_1))
有谁知道我如何让 tensorflow 来计算梯度?谢谢你。
解决方案
在这种情况下,您的权重不是变量,它们是基于超网络计算的张量。在训练期间,您真正拥有的只是一个网络。如果我理解正确,那么您建议放弃超网络并能够仅使用主网络来执行预测。
如果是这种情况,那么您可以手动保存权重值并将它们重新加载为常量,或者您可以像在训练期间一样使用tf.cond
并tf.assign
分配它们,但使用tf.cond
选择使用变量或计算张量取决于是否你正在做训练或推理。
在训练期间,您将需要使用来自超网络的计算张量来启用反向传播。
评论中的示例w
是您将使用的权重,您可以在训练期间分配一个变量来跟踪它,然后使用tf.cond
该变量(在推理期间)或来自超网络的计算值(在训练期间)。在此示例中,您需要传入一个布尔占位符is_training_placeholder
以指示您是否正在运行推理训练。
tf.assign(w_variable, w_from_hypernetwork)
w = tf.cond(is_training_placeholder, true_fn=lambda: w_from_hypernetwork, false_fn=lambda: w_variable)
推荐阅读
- r - 多边形集群的分层多边形(例如州/县)
- python - 无法使用 pyodbc 连接到 Azure SQL Server
- reactjs - 将参数传递给 Modal 窗口 Antd React
- java-8 - 在特性标志下维护不同版本的java方法
- realm - 在 Keycloak 中继承/共享用户联合
- python - 在 pandas 的新文章中存储 Tf-idf 矩阵并更新现有矩阵
- json - 来自 SQL Server 上 FOR JSON 的平面 JSON 结果,带有连接查询
- asp.net-core - Asp.Net Core 2.0 跨服务器负载均衡
- c# - C# WinForm 中的 AppDomain.CurrentDomain.GetAssemblies()
- python - 获取您的电子邮件的永久地址