python - TensorFlow 用不同的数据重新训练神经网络
问题描述
例如,我有一个inputs
神经网络列表
list_of_inputs = [inputs1, inputs2, inputs3, ... ,inputsN]
*以及相应的标签列表*
list_of_labels = [label1, label2, label3, ..., labelN]
我想将每一对输入/训练input,label
到神经网络中,记录损失,然后input,label
在同一网络上训练下一对并记录所有对的损失等 input,label
。
注意:我不想每次input,label
添加新的权重时都重新初始化权重,我想使用前一对训练过的权重。网络如下所示(您可以在其中看到我也在打印损失)。我该怎么办?
with tf.name_scope("nn"):
model = tf.keras.Sequential([
tfp.layers.DenseFlipout(64, activation=tf.nn.relu),
tfp.layers.DenseFlipout(64, activation=tf.nn.softmax),
tfp.layers.DenseFlipout(np.squeeze(labels).shape[0])
])
logits = model(inputs)
loss = tf.reduce_mean(tf.square(labels - logits))
train_op_bnn = tf.train.AdamOptimizer().minimize(loss)
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
for i in range(100):
sess.run(train_op_bnn)
print(sess.run(loss))
编辑:
问题是当我尝试在如下函数中格式化网络时:
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
inputs,labels = MEMORY[0]
logits, model_losses = build_graph(inputs)
loss = tf.reduce_mean(tf.square(labels - logits))
train_op_bnn = tf.train.AdamOptimizer().minimize(loss)
sess.run(train_op_bnn)
print(sess.run(loss))
我收到一个错误:
FailedPreconditionError Traceback (most recent call last)
<ipython-input-95-5ca77fa0606a> in <module>()
36 train_op_bnn = tf.train.AdamOptimizer().minimize(loss)
37
---> 38 sess.run(train_op_bnn)
39 print(sess.run(loss))
40
解决方案
logits, model_losses = build_graph(inputs)
loss = tf.reduce_mean(tf.square(labels - logits))
train_op_bnn = tf.train.AdamOptimizer().minimize(loss)
应该在上面
with tf.Session() as sess:
和你的init_op
定义之上
推荐阅读
- python - QTableWidget 验证数范围
- python - 如何将列表项添加到字典中已有值的键中?
- r - postGIS:ST_MakeEnvelope() 为什么只有一个类似的查询有效?
- r - double type object (characters and numbers) to dataframe
- python - Text Translation if column value is equal to language
- python - Take my data from my computer and Verify that data is not stolen
- python - gunicorn and PDF
- java - 无法将“java.lang.String”类型的值转换为所需的“java.time.LocalDate”类型
- sql - oracle 10g float formatting
- php - Laravel 多个观察者和作业