Is it possible to override the progress bar of TensorFlow's keras?

Problem description

Over the last few days, I have been observing weird behavior in the loss printed in the progress bar. It turned out that this was because the default Keras progress bar displays a moving average of the losses rather than the actual loss at every epoch.
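To make the difference concrete, here is a small worked example in plain Python. It assumes equal-sized batches, so the averaged value is simply the cumulative mean of the batch losses seen so far in the epoch; the loss values are made up for illustration and `running_means` is a hypothetical helper, not a Keras function.

```python
# Made-up per-batch losses within one epoch (illustrative only, not real training output).
batch_losses = [0.9, 0.5, 0.1]


def running_means(values):
    """Return the cumulative mean after each step (assumes equal-sized batches)."""
    means = []
    total = 0.0
    for step, value in enumerate(values, start=1):
        total += value
        means.append(total / step)
    return means


displayed = running_means(batch_losses)
for step, (raw, shown) in enumerate(zip(batch_losses, displayed), start=1):
    # The averaged value lags behind the actual batch loss as training improves.
    print(f"step {step}: batch loss = {raw:.2f}, averaged = {shown:.2f}")
```

After the third batch the actual loss is 0.1, but the averaged value is still 0.5, which is the kind of discrepancy described above.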

So, is it possible to override the progress bar of TensorFlow's keras? I don't think so.

There's the class tf.keras.utils.Progbar that contains the parameter stateful_metrics, which is probably what I need, but fit doesn't seem to provide an option to override the progress bar or to change the behaviour from moving average to actual loss of the epoch/step. What alternative do you suggest? Feel free to write an answer below with some reproducible code.
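The `tf.keras.utils.Progbar` documentation describes `stateful_metrics` as names of metrics that should not be averaged over time: metrics in that list are displayed as-is, and all others are averaged by the progbar before display. The sketch below is a simplified plain-Python model of that rule, not the actual Progbar code; `progbar_display` is a hypothetical helper, and it assumes one update per step with equal weighting.

```python
def progbar_display(updates, stateful_metrics=()):
    """Simplified model of how Progbar picks the value it displays.

    `updates` is a list of dicts, one per step (assumed: one update per step).
    Metrics named in `stateful_metrics` show their latest value;
    all other metrics show the mean of every value seen so far.
    """
    stateful = set(stateful_metrics)
    history = {}
    for values in updates:
        for name, value in values.items():
            history.setdefault(name, []).append(value)

    shown = {}
    for name, seen in history.items():
        if name in stateful:
            shown[name] = seen[-1]  # displayed as-is
        else:
            shown[name] = sum(seen) / len(seen)  # averaged before display
    return shown


updates = [{"loss": 0.9}, {"loss": 0.5}, {"loss": 0.1}]
print(progbar_display(updates))                             # loss averaged over steps
print(progbar_display(updates, stateful_metrics=["loss"]))  # latest loss only
```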

Tags: tensorflow, keras, tensorflow2.0, tf.keras

Solution


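It sounds like what you want should be possible via tf.keras.callbacks.ProgbarLogger. In theory, it should work as outlined in the example below; however, there is currently a known issue with tf.keras.callbacks.ProgbarLogger.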

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten

# Load MNIST and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10)  # raw logits, hence from_logits=True below
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# stateful_metrics takes an iterable of metric names; those metrics are
# displayed as-is instead of being averaged by the progress bar.
progbar_callback = tf.keras.callbacks.ProgbarLogger(stateful_metrics=["accuracy"])
model.fit(x_train, y_train, callbacks=[progbar_callback])
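One detail worth checking when passing `stateful_metrics`: it is treated as an iterable of metric names. As far as I can tell from the TF 2.x source, both `Progbar` and `ProgbarLogger` convert it with `set(...)`, so a bare string such as `"accuracy"` is split into single characters and never matches the metric name. The snippet below reproduces only that conversion in plain Python (it does not call TensorFlow); passing a list like `["accuracy"]` avoids the problem.

```python
# Reproduces the set(...) conversion assumed to be applied to `stateful_metrics`.
as_string = set("accuracy")    # a bare string is iterated character by character
as_list = set(["accuracy"])    # a list yields one metric name

print(sorted(as_string))
print(as_list)
print("accuracy" in as_string, "accuracy" in as_list)
```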
