python-3.x - 在 TensorFlow 中保存神经网络的学习进度
问题描述
在 Python 中,Tensorflow:我使用 Tensorflow 训练并应用了一个神经网络,现在我需要保存它的进度以便在以后的时间点进一步训练它。
我使用了很多配置
saver = tf.train.Saver()
saver.restore()
tf.train.import_meta_graph('model/model.ckpt.meta')
tf.train.export_meta_graph('model/model.ckpt.meta')
ETC...
但它总是会产生错误。
这是我的代码。它类似于 Mnist 示例代码,但使用自定义生成的输入并具有单个连续的输出神经元。
x = tf.placeholder('float', [None, 4000]) # 4000 is my structure, just an example
y = tf.placeholder('float') # And I need a single, continuous output
def train_neural_network(x):
testdata_images, testdata_labels = generate_training_data_batch(size = 10)
#Generates test data
data=[]
for i in range(how_many_batches):
data.append(generate_training_data_batch(size = 10))
#Generates training data
prediction = neural_network_model(x)
# neural_network_model() is defined as a 4000x15x15x10x1 neural network
cost = tf.reduce_mean( tf.square( tf.subtract(y, prediction) ) )
optimizer = tf.train.AdamOptimizer(0.01).minimize(cost)
hm_epochs = 10
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(hm_epochs):
epoch_loss = 0
for i in range(how_many_batches):
epoch_x, epoch_y = data[i]
_, c = sess.run([optimizer, cost], feed_dict = {x: epoch_x, y: epoch_y})
epoch_loss += c
accuracy1 = tf.subtract(y, prediction)
result = sess.run(accuracy1, feed_dict={x: epoch_x, y: epoch_y})
print(result)
# This is just here so I can see what is going on
saver.save(sess, 'model/model.ckpt')
tf.train.export_meta_graph('model/model.ckpt.meta')
tf.reset_default_graph()
稍后在同一个文件中,我想使用保存的神经网络进行一些预测:
train_neural_network(x)
X, Y = generate_training_data_batch(size = 1)
prediction = neural_network_model(x)
with tf.Session() as sess:
tf.train.import_meta_graph('model/model.ckpt.meta')
sess.run(tf.global_variables_initializer())
thought = sess.run(prediction, feed_dict={x: X})
print(Y, thought)
使用此版本,我收到错误消息
ValueError: Tensor("Variable:0", shape=(4000, 15), dtype=float32_ref) must be from the same graph as Tensor("Placeholder_35:0", shape=(?, 4000), dtype=float32).
我也收到错误消息,例如
ValueError: At least two variables have the same name: Variable/Adam
我正在寻找这个几周的解决方案,所以我会很欣慰地最终解决这个问题。
解决方案
推荐阅读
- flutter - 从重启的 App Flutter 运行集成测试
- javascript - 如何将水平网格滚动到页面加载的特定位置 - 需要响应式解决方案
- ios - 在静态方法中快速设置委托
- php - (已解决)以编程方式更改 Yii2 ActiveForm/Widget Input 的值
- javascript - 在属性中组合两个字符串对象以在 .map 循环中使用
- variables - Mindstorms EV3 编程问题 - 颜色传感器和变量
- apache - 在 Apache 服务器上扫描文本文件以获取数据
- python - 在 VSCode 中跨环境自动化 PyTest
- symfony - SwiftMailer 附件 pdf 由 Dompdf 动态生成
- javascript - 如何在 AJax 调用的回调中连接到 signalR 服务器并将我的客户端方法注册到我的 signalR 服务器