Tensorflow 2.0: Accessing a batch's tensors from a callback

Problem Description

I'm using Tensorflow 2.0 and trying to write a tf.keras.callbacks.Callback that reads both the inputs and outputs of my model for the batch.

I expected to be able to override on_batch_end and access model.inputs and model.outputs, but they are not EagerTensors with values I can access. Is there any way to access the actual tensor values involved in a batch?

This has many practical uses, such as writing these tensors to TensorBoard for debugging, or serializing them for other purposes. I am aware that I could just run the whole model again using model.predict, but that would force me to run every input through the network twice (and I might also have a non-deterministic data generator). Any idea how to achieve this?

Tags: python, tensorflow, keras, tensorflow2.0, tf.keras

Solution


No, there is no way to access the actual values of the inputs and outputs from a callback. This simply isn't part of what callbacks were designed for. Callbacks only have access to the model, the parameters passed to fit, the epoch number, and some metric values. As you found, model.input and model.output only point to symbolic KerasTensors, not to actual values.

To do what you want, you could take the input, stack it (perhaps with a RaggedTensor) together with the output you care about, and make that an extra output of your model. Then implement your functionality as a custom metric that only reads y_pred. Inside the metric, unstack y_pred to recover the input and the output, then visualize/serialize/etc. them.
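A minimal sketch of that idea, using a toy model: the second output concatenates the batch inputs with the predictions, and a dummy "metric" on that head unstacks both. The model shape (4 features), the layer names "preds"/"debug", and the function name inspect_batch are all illustrative assumptions, not from the original post.

```python
import tensorflow as tf

# Toy functional model; the 4-feature input shape is an assumption.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
preds = tf.keras.layers.Dense(1, name="preds")(hidden)

# Stack the batch inputs next to the predictions so a single tensor
# carries both through to the metric.
debug = tf.keras.layers.Concatenate(name="debug")([inputs, preds])

model = tf.keras.Model(inputs, [preds, debug])


def inspect_batch(y_true, y_pred):
    # y_pred here is the concatenated [inputs, preds] tensor; split it
    # back apart, visualize/serialize as needed, return a dummy scalar.
    x, p = y_pred[:, :4], y_pred[:, 4:]
    tf.print("inputs shape:", tf.shape(x), "preds shape:", tf.shape(p))
    return tf.constant(0.0)


model.compile(
    optimizer="adam",
    loss={"preds": "mse"},            # no loss on the debug head
    metrics={"debug": inspect_batch},
)

x = tf.random.normal((16, 4))
y = tf.random.normal((16, 1))
# A target is still supplied for "debug" so the metric has a y_true,
# but it contributes nothing to the loss.
model.fit(x, {"preds": y, "debug": tf.zeros((16, 5))}, epochs=1, verbose=0)
```

During training, inspect_batch then sees the real per-batch values, with no second forward pass through the network.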

Another approach might be to implement a custom layer that uses py_function to call a Python function. This would be very slow during serious training, but may be good enough during diagnosis/debugging.
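A sketch of that second approach: an identity layer that drops into eager Python on every batch via tf.py_function, where the tensor is a concrete EagerTensor with accessible values. The names DebugLayer and log_batch are hypothetical.

```python
import tensorflow as tf


def log_batch(x):
    # Runs eagerly: x is a concrete EagerTensor here, so .numpy() works.
    print("first rows of this batch:", x.numpy()[:2])
    return x


class DebugLayer(tf.keras.layers.Layer):
    """Identity layer that hands each batch to a Python function."""

    def call(self, inputs):
        out = tf.py_function(log_batch, [inputs], inputs.dtype)
        # py_function loses static shape information; restore it.
        out.set_shape(inputs.shape)
        return out
```

Drop it between any two layers, e.g. `x = DebugLayer()(x)`. The Python call runs once per batch and cannot be compiled or parallelized by the graph runtime, which is why it is slow for real training.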
