python - 如何保存表示在 Tensorflow 中构建的神经网络的对象
问题描述
我是 Tensorflow 的新手,正在 github 上玩一些代码。此代码为神经网络创建一个类,其中包括构建网络、制定损失函数、训练网络、执行预测等方法。
骨架代码看起来像这样:
class NeuralNetwork:
def __init__(...):
def initializeNN():
def trainNN():
def predictNN():
等等。神经网络是用 Tensorflow 构建的,因此,类定义及其方法使用 tensorflow 语法。
现在在我的脚本的主要部分,我通过
model = NeuralNetwork(...)
并使用model.predict等模型的方法来生成绘图。
由于训练神经网络需要很长时间,我想保存对象“模型”以供将来使用并有可能调用其方法。我试过泡菜和莳萝,但都失败了。对于泡菜,我得到了错误:
类型错误:无法腌制 _thread.RLock 对象
而对于莳萝,我得到了:
TypeError:无法腌制 SwigPyObject 对象
有什么建议我可以保存对象并仍然能够调用它的方法吗?这是必不可少的,因为我可能想在未来对一组不同的点进行预测。
谢谢!
解决方案
你应该做的是:
# Build the graph
model = NeuralNetwork(...)
# Create a train saver/loader object
saver = tf.train.Saver()
# Create a session
with tf.Session() as sess:
# Train the model in the same way you are doing it currently
model.train_model()
# Once you are done training, just save the model definition and it's learned weights
saver.save(sess, save_path)
而且,你完成了。然后,当您想再次使用该模型时,您可以做的是:
# Build the graph
model = NeuralNetwork()
# Create a train saver/loader object
loader = tf.train.Saver()
# Create a session
with tf.Session() as sess:
# Load the model variables
loader.restore(sess, save_path)
# Train the model again for example
model.train_model()
推荐阅读
- gcc - 构建 GCC 编译器时出现问题
- c - 将数据块转换为与结构对齐
- macos - Mac PyCharm 2018.2:从搜索栏中切换正则表达式的键盘快捷键?
- javascript - 按钮将输入添加到表单,但只有第一个。不在多个表单/div上
- matlab - 在matlab中移动目录中的所有文件和文件夹
- conv-neural-network - Leela Chess Zero:输出层的概率向量有多大?
- ubuntu - 命令 hadoop dfs -mkdir -p /user/flume/tweets/ 有什么问题
- c# - 如何从 requestTelemetry 获取会话 ID?
- coq - 归纳类型的构造函数何时穷举?
- gdb - 使用 gdb 在内存中搜索字符串/数字?