python - 为什么我保存的 Tensorflow 模型在恢复时会预测废话?
问题描述
我在 Tensorflow 中训练了一个卷积神经网络,它分析图像并计算其中的对象,并将其保存以备后用。现在我正在尝试恢复模型并预测切割成图块的图像的值。不过,我得到的是无意义的值,并且每个图块的数字几乎相同。每个加载的模型都会给出围绕特定值的数字,每个图像相同,但每个模型不同。我在想,也许我使用了恢复模型中的错误张量?这是我的代码的摘录:
x = tf.placeholder(tf.float32, [None, 98, 98, 3], name='x')
y = tf.placeholder(tf.float32, [None, ], name='y')
# create two convolutional layers: layer1 and layer2
s3 = create_conv_layer_for_sum(layer2, f2, f3, [5, 5], 2, outf_sum, name='s_layer3')
y_pred = s3
error = tf.pow((y - y_pred), 2)
# other error measures also present
optimiser = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(error)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
# train the model here
saver = tf.train.Saver()
save_path = saver.save(sess, "models/model"+str(num)+"/model.ckpt")
def create_conv_layer_for_sum(input_data, num_input_channels, num_filters, filter_shape, stride, out_fction, name):
# ...
sum = tf.reduce_sum(transformed, axis=[1, 2, 3], name=name+'_output')
return sum
这部分是训练和保存。然后我恢复模型:
sess = tf.Session()
saver = tf.train.import_meta_graph('models/' + model + '/model.ckpt.meta')
saver.restore(sess, 'models/' + model + '/model.ckpt')
inputData = CNNutils.load_photo(photo, 98) # cuts photo into squares and stacks those as a numpy array
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('x:0')
s3 = graph.get_tensor_by_name('s_layer3_output:0')
y_pred = tf.reduce_sum(s3)
pred, sum3 = sess.run([y_pred, s3], feed_dict={x: inputData})
print(pred)
print(sum3)
s3
应该是最后一层的输出,然后y_pred
将来自单个图块的整个图像的预测相加。
我将不胜感激任何帮助。
解决方案
您可以使用 保存模型model.save
并使用 恢复模型吗load_model
?
Keras 支持更简单的界面,可以将模型权重和模型架构一起保存到单个 H5 文件中。
使用方式保存模型model.save
包括我们需要了解的有关模型的所有信息,包括:
- 模型权重。
- 模型架构。
- 模型编译细节(损失和指标)。
- 模型优化器状态。
然后可以通过调用load_model()
函数并传递文件名来加载保存的模型。该函数返回具有相同架构和权重的模型。
您可以在此处找到有关模型保存和加载的示例。
推荐阅读
- python - 矩阵减法 | ValueError:操作数无法与形状一起广播 (1,30) (30,455)
- apache-spark - 通过列名中包含 peroid(.) 的 Spark-Scala-Phoenix 读取 HBase 表
- python - django.db.utils.IntegrityError:NOT NULL 约束失败:accounts_user.user_id
- ssh - 远程服务器使用智能卡身份验证时如何设置 SQL Developer 数据库连接
- python - 如何创建包含当天汇总的表格,转置会生成重复的列以分解预订状态
- php - 有没有办法从两个不同的 html 文件中使用 PHP?
- flutter - Flutter 填充值的常量值无效
- php - Symantec Messaging 网关添加域肥皂请求
- python - 用熊猫数据框中的剩余行迭代每一行
- javascript - 查找两个或多个数组中的重复项数