tensorflow - 使用 tensorflow 的估计器 API 在 RNN 的每个时期中的权重矩阵和成本
问题描述
我使用 Estimator API 来训练 RNN 模型,我想绘制成本/纪元图并获得最佳模型权重矩阵。Estimator API 有可能吗?这是代码:
classifier.train(input_fn=lambda: input_fn_train(train_x, label_train, batch_size),steps=train_steps)
eval_result = classifier.evaluate(input_fn=lambda: input_fn_eval(test_x, label_test, batch_size))
解决方案
有可能的。您需要做的是配置您的 Estimator 以生成有助于您决定要保留哪些权重的相关信息。这可以通过检查点来完成。这是您模型的“保存”。将一些配置传递给 Estimatorconfig=
会很有用。
下面是一个带有自定义 Estimator 的示例:
def model_fn(features, labels, mode, params):
#Some code is here that gives you the output of your model from where
#you get your predictions.
if mode == tf.estimator.ModeKeys.TRAIN or tf.estimator.ModeKeys.EVAL:
#Some more code is here
loss = #your loss function here
tf.summary.scalar('loss', loss)
if mode == tf.estimator.ModeKeys.TRAIN:
#More code here that train your model
if mode == tf.estimator.ModeKeys.EVAL:
#Again more code that you use to get some evaluation metrics
if mode == tf.estimator.ModeKeys.PREDICT:
#Code...
return tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops)
configuration = tf.estimator.RunConfig(save_summary_steps=10,
keep_checkpoint_max=30,
save_checkpoints_steps=10,
log_step_count_steps=10)
custom_estimator = tf.estimator.Estimator(model_fn=model_fn,
model_dir='model_dir',
config=configuration)
custom_estimator.train(input_fn=input_fn_train, steps=10000)
save_summary_steps
:实际上,您可以这样想,就像您的估算器会在多少步之后更新您的摘要一样。这很有用,因此您可以每 10 步绘制一次损失图。
save_checkpoints_steps
:经过多少步后,您的估算器将在当前状态下保存。
您可以在model_dir
.
如果您使用的是罐装 Estimator,我认为摘要是预定义的,但损失函数已经存在,因此您只需配置打印摘要的频率以及保存模型状态的频率。
推荐阅读
- python - 如何修复:AttributeError:模块“tensorflow”在 JupyterNotebook 中没有属性“优化器”(使用 colab.research)
- javascript - 无法在第一页加载时获取参数
- python-3.x - 命名所有连接的 txt 文件
- amazon-web-services - AWS CodeBuild DOWNLOAD_SOURCE 失败
- c# - URI解析C#有问题吗?
- python - 当 qcombobox 索引更改覆盖 QUiloader 时,PySide2 在小部件上重新绘制
- pandas - 根据给定索引合并来自两个数据帧的元素
- python - 我需要帮助在我的代码中形成 if else 条件
- arrays - Array Sum(求所有对角元素和边界元素的总和)
- python - SoundRecognition 不起作用,找不到 pyaudio 模块