首页 > 解决方案 > Tensorflow 2.0 将训练后的参数保存到新文件中

问题描述

我需要使用 TF 的内置函数之一(如 tf.train.Checkpoint 或任何其他函数)保存 TensorFlow 2.0 模型的训练变量,并希望在新文件中调用它们。我没有使用 tf.Keras.Sequantial 也不想使用类似 model.save_weights()

我试过 tf.train.Checkpoint 来保存变量,但不知道如何恢复它们。我曾经在 TF 1.0 中使用 tf.train.Saver() 来使用会话保存变量并使用 tf.train.import_meta_graph 和 tf.train.latest_checkpoint 恢复它们。但是,到目前为止,我还没有在 TF 2.0 文档中找到等效的功能。

尝试使用 tensorflow 2.0 格式的检查点保护程序来保存训练参数 W、b_v、b_h

saver = tf.train.Checkpoint()

saver.listed = [W, b_v, b_h]

saver.mapped = {'W':saver.listed[0],'b_v':saver.listed[1],'b_h':saver.listed[2]}

save_path = saver.save('trained_pa​​rameters')

在一个新文件中:

恢复器 = tf.train.Checkpoint()

restorer.restore('trained_pa​​rameters')

通过之前映射的名称调用参数不起作用,不知道该怎么做

标签: pythontensorflow

解决方案


推荐阅读