tensorflow - Tensorflow saver.save() 无法正常工作
问题描述
调用 saver.save() 会在 google colab 上保存四个文件,但在我的本地,我只是得到一个检查点文件。(使用 macOS High Sierra)
def train(model,model_dir,batch_szie,tr_x,tr_y,val_x,val_y,lr_rate):
saver = tf.train.Saver()
epoch = 0
max_epoch = 300
print ("starting training")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
while epoch < max_epoch:
for br_x,br_y in shuffle_batch(tr_x,tr_y,batch_size):
sess.run(model.opt,feed_dict={model.inputs:br_x,model.outputs:br_y,model.is_training:True,model.lr_rate:lr_rate})
epoch = epoch + 1
saver.save(sess,model_dir,write_meta_graph=True)
print ('completed training, model saved')
这就是我所说的训练功能。custom_dnn() 函数创建一个自定义类的对象,然后调用 train 函数。模型训练正确,我已经可视化了训练图,但是在保存时,我的本地只保存了一个检查点文件。
model_dir = "/Users/proj/model/"
batch_size = 50
lr_rate = 0.001
tf.reset_default_graph()
model_sub1 = custom_dnn()
train(model_sub1,model_dir,batch_size,tr_x,tr_y,val_x,val_y,lr_rate)
我究竟做错了什么?
解决方案
推荐阅读
- laravel - Laravel 以巨大的延迟发送密码重置电子邮件
- java - 搜索3个元素是否出现在同一行
- haskell - How to find out if it's summer or winter time?
- ios - 尝试为新闻提要创建 timeAgo 函数,但它只显示“0 秒前”
- python - Python - 不确定如何将元组附加到字典值
- azure - 通过 HTTPS 将 Azure AppService 连接到 VM 上的应用程序
- c - 设置区域设置不起作用
- rest - 如何使用rest api写入firebase数据库?
- java - 存储、检索、编辑和删除数据
- c# - 如何执行异步等待函数导入/SP EF6