Modifying TensorBoard in TensorFlow 2.0

Problem Description

I am following Sentdex's DQN tutorial and have been trying to rewrite its custom TensorBoard callback for TF 2.0. The point is to write the **stats keyword arguments to the log file, e.g.: {'reward_avg': -99.0, 'reward_min': -200, 'reward_max': 2, 'epsilon': 1}. Original code:

class ModifiedTensorBoard(TensorBoard):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        self.writer = tf.summary.FileWriter(self.log_dir)

    # Custom method for saving own metrics
    # Creates writer, writes custom metrics and closes writer
    def update_stats(self, **stats):
        self._write_logs(stats, self.step)

My attempt:

def update_stats(self, **stats):
    for name, value in stats.items():
        with self.writer.as_default():
            tf.summary.scalar(name, value, self.step)

With this I get: TypeError: unsupported operand type(s) for +: 'ModifiedTensorBoard' and 'list'

Tags: tensorboard, tensorflow2.0

Solution


I followed the same tutorial; here is what I did to get it working.
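
The fix is built on the TF 2.x summary API: a writer created with tf.summary.create_file_writer takes the place of the old tf.summary.FileWriter, and scalars are written inside a with writer.as_default() block. A minimal standalone sketch of that pattern (the log directory and the values are just placeholders):

import tensorflow as tf

# Create one writer up front, then log scalars against an explicit step
writer = tf.summary.create_file_writer("logs/example")  # placeholder path
with writer.as_default():
    tf.summary.scalar("reward_avg", -99.0, step=1)
    tf.summary.scalar("epsilon", 1.0, step=1)
writer.flush()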

Here is the ModifiedTensorBoard class:

import os
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard  # the tutorial may import this from standalone Keras instead

class ModifiedTensorBoard(TensorBoard):

    # Overriding init to set initial step and writer (we want one log file for all .fit() calls)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        self.writer = tf.summary.create_file_writer(self.log_dir)
        self._log_write_dir = os.path.join(self.log_dir, MODEL_NAME)  # MODEL_NAME is defined elsewhere in the tutorial

    # Overriding this method to stop creating default log writer
    def set_model(self, model):
        pass

    # Overridden, saves logs with our step number
    # (otherwise every .fit() will start writing from 0th step)
    def on_epoch_end(self, epoch, logs=None):
        self.update_stats(**logs)

    # Overridden
    # We train for one batch only, so there is no need to save anything at batch end
    def on_batch_end(self, batch, logs=None):
        pass

    # Overridden, so the writer won't be closed
    def on_train_end(self, _):
        pass

    # Overridden for the same reason as on_batch_end (newer tf.keras hook name)
    def on_train_batch_end(self, batch, logs=None):
        pass

    # Custom method for saving our own metrics
    # Writes the given stats with the writer created in __init__
    def update_stats(self, **stats):
        self._write_logs(stats, self.step)

    def _write_logs(self, logs, index):
        with self.writer.as_default():
            for name, value in logs.items():
                tf.summary.scalar(name, value, step=index)
            # Advance the step once per update and flush after all stats are written
            self.step += 1
            self.writer.flush()
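
For context, a minimal usage sketch in the spirit of the tutorial's training loop might look like the following; MODEL_NAME, the episode loop, and the stat values are placeholders standing in for the tutorial's own variables:

import time

MODEL_NAME = "dqn_example"  # placeholder; the tutorial defines its own model name

tensorboard = ModifiedTensorBoard(log_dir=f"logs/{MODEL_NAME}-{int(time.time())}")

for episode in range(1, 11):    # stand-in for the real episode loop
    tensorboard.step = episode  # keep all episodes in one log file
    # ... run the episode and train with model.fit(..., callbacks=[tensorboard]) ...
    tensorboard.update_stats(reward_avg=-99.0, reward_min=-200,
                             reward_max=2, epsilon=1)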
