首页 > 解决方案 > 如何通过每个补丁/时期的输出来可视化训练过程?

问题描述

我在 Keras 中的神经网络学习了我的原始数据的表示。为了准确了解它是如何学习的,我认为绘制每个训练批次(或时期)的数据并将这些图转换为视频会很有趣。

我被困在如何在训练阶段获得模型的输出。

我想过做这样的事情(伪代码):

epochs = 200
plt_outputs = []
for i in range(epochs):
    model.fit(x_train,y_train, epochs = 1)
    plt_outputs.append(output_layer(x_test))

其中 output_layer 是我感兴趣的神经网络中的层。之后我将使用 plot_data 生成每个图并将其转换为视频。(那部分我还不关心..)

但这并不是一个好的解决方案,而且我不知道如何获得每批的输出。对此有什么想法吗?

标签: pythontensorflowkerasplotneural-network

解决方案


您可以自定义测试步骤中发生的事情,就像这个官方教程一样:

import tensorflow as tf
import numpy as np

class CustomModel(tf.keras.Model):
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)

        test_outputs.append(y_pred) # ADD THIS HERE

        # Updates the metrics tracking the loss
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}


# Construct an instance of CustomModel
inputs = tf.keras.Input(shape=(8,))
x = tf.keras.layers.Dense(8, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"], run_eagerly=True)

test_outputs = list() # ADD THIS HERE

# Evaluate with our custom test_step
x = np.random.random((1000, 8))
y = np.random.random((1000, 1))
model.evaluate(x, y)

我添加了一个列表,现在在测试步骤中,它将将此列表附加到输出中。您需要添加run_eagerly=True才能model.compile()使其正常工作。这将输出此类输出的列表:

<tf.Tensor: shape=(32, 1), dtype=float32, numpy=
array([[ 0.10866462],
       [ 0.2749035 ],
       [ 0.08196291],
       [ 0.25862294],
       [ 0.30985728],
       [ 0.20230596],
            ...
       [ 0.17108777],
       [ 0.29692617],
       [-0.03684975],
       [ 0.03525433],
       [ 0.26774448],
       [ 0.21728781],
       [ 0.0840873 ]], dtype=float32)>

推荐阅读