首页 > 解决方案 > 运行时保存 TensorFlow 状态(权重)

问题描述

我有一个在模型上运行的 python 程序。我没有实现保护程序,所以我想知道是否有一种方法可以在配件运行时直接从内存中恢复权重(也许从临时文件中?)

标签: pythontensorflow

解决方案


我认为应该可以使用命令,tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES).

在不保存模型的情况下提取/恢复模型权重的代码如下所示。

import tensorflow as tf
import numpy as np

X_ = tf.placeholder(tf.float64, [None, 5], name="Input")
Y_ = tf.placeholder(tf.float64, [None, 1], name="Output")

X = np.random.randint(1,10,[10,5])
Y = np.random.randint(0,2,[10,1])

with tf.variable_scope("LogReg"):
    pred = tf.layers.dense(X_, 1, activation=tf.nn.sigmoid, name = 'fc1')
    loss = tf.losses.mean_squared_error(labels=Y_, predictions=pred)
    training_ops = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:

    all_vars= tf.global_variables()
    def get_var(name):
        for i in range(len(all_vars)):
            if all_vars[i].name.startswith(name):
                return all_vars[i]
        return None

    sess.run(tf.global_variables_initializer())    
    for i in range(200):
        sess.run([training_ops], feed_dict={X_: X,Y_: Y})
        Weight_Vars = sess.run([tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)], feed_dict={X_: X,Y_: Y})
        print(Weight_Vars)

推荐阅读