python - 我可以在会话之外或在创建会话之间有挂钩/回调吗?
问题描述
可以执行一个计划,根据train_and_evaluate()
我传递的规范训练和评估模型。我可以向EvalSpec和TrainSpec注册一些钩子,但有限制。
问题是我只能有一个SessionRunHook
,它将作为回调工作,但总是只在会话中。
我的问题是我有一个更复杂的时间表。在评估期间,我还想量化模型并进一步评估该量化模型。如前所述,这里的问题是,如果我使用-like 对象,我总是在会话中。SessionRunHook
所以问题是是否有一种方法可以train_and_evaluate()
在两者之间使用和注册一些回调:
train_and_evaluate(..):
# .. deep down ..
while <condition>:
with tf.Session() as train_sess:
# Do training ..
if the_callback_i_want:
the_callback_i_want()
with tf.Session() as eval_sess:
# Do evaluation ..
这可能吗?
解决方案
我想你可以实现begin
你自己的SessionHook
子类的方法。
import tensorflow as tf
def the_callback_i_want():
# You need to work in a new graph so let's create a new one
g = tf.Graph()
with g.as_default():
x = tf.get_variable("x", ())
x = tf.assign_add(x, 1)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print("I'm here !", sess.run(x))
class MyHook(tf.train.SessionRunHook):
def begin(self):
"""Called once before using the session.
When called, the default graph is the one that will be launched in the
session. The hook can modify the graph by adding new operations to it.
After the `begin()` call the graph will be finalized and the other callbacks
can not modify the graph anymore. Second call of `begin()` on the same
graph, should not change the graph.
"""
the_callback_i_want()
import iris_data
# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
# Build 2 hidden layer DNN with 10, 10 units respectively.
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns, hidden_units=[10, 10], n_classes=3)
# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
train_spec = tf.estimator.TrainSpec(input_fn=lambda:iris_data.train_input_fn(train_x, train_y,
10), max_steps=100)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,
10), hooks=[MyHook()])
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
它打印:
INFO:tensorflow:Saving checkpoints for 100 into /var/folders/***/model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-10-18-20:19:28
I'm here ! 1.9386581
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/***/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-10-18-20:19:28
推荐阅读
- javascript - 未捕获的错误:[$injector:modulerr] 错误:[$injector:nomod] 模块“chart.js”不可用
- python - Python YAML 到 JSON 到 YAML
- amazon-s3 - 无法使用 SSE-KMS 加密(使用 .NET SDK)对 S3 存储桶进行分段上传
- c - 带有两个约束的链表插入
- python-pptx - python pptx PowerPoint幻灯片构建——理解模式/构建方法的麻烦
- html - 响应式表头未显示
- python - 如何使用蓝图在 Flask 中提供静态文件
- dialogflow-es - 服务器上的 Dialogflow 实体列表
- security - Nginx 处理 500 内部服务器错误安全问题
- python - 需要明确默认管道中的哪个组件会修改 Doc 上的 lemma_ 并需要有关提高 spacy 吞吐量的建议