python - 如何通过每个补丁/时期的输出来可视化训练过程?
问题描述
我在 Keras 中的神经网络学习了我的原始数据的表示。为了准确了解它是如何学习的,我认为绘制每个训练批次(或时期)的数据并将这些图转换为视频会很有趣。
我被困在如何在训练阶段获得模型的输出。
我想过做这样的事情(伪代码):
epochs = 200
plt_outputs = []
for i in range(epochs):
model.fit(x_train,y_train, epochs = 1)
plt_outputs.append(output_layer(x_test))
其中 output_layer 是我感兴趣的神经网络中的层。之后我将使用 plot_data 生成每个图并将其转换为视频。(那部分我还不关心..)
但这并不是一个好的解决方案,而且我不知道如何获得每批的输出。对此有什么想法吗?
解决方案
您可以自定义测试步骤中发生的事情,就像这个官方教程一样:
import tensorflow as tf
import numpy as np
class CustomModel(tf.keras.Model):
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_pred = self(x, training=False)
test_outputs.append(y_pred) # ADD THIS HERE
# Updates the metrics tracking the loss
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Update the metrics.
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel
inputs = tf.keras.Input(shape=(8,))
x = tf.keras.layers.Dense(8, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"], run_eagerly=True)
test_outputs = list() # ADD THIS HERE
# Evaluate with our custom test_step
x = np.random.random((1000, 8))
y = np.random.random((1000, 1))
model.evaluate(x, y)
我添加了一个列表,现在在测试步骤中,它将将此列表附加到输出中。您需要添加run_eagerly=True
才能model.compile()
使其正常工作。这将输出此类输出的列表:
<tf.Tensor: shape=(32, 1), dtype=float32, numpy=
array([[ 0.10866462],
[ 0.2749035 ],
[ 0.08196291],
[ 0.25862294],
[ 0.30985728],
[ 0.20230596],
...
[ 0.17108777],
[ 0.29692617],
[-0.03684975],
[ 0.03525433],
[ 0.26774448],
[ 0.21728781],
[ 0.0840873 ]], dtype=float32)>
推荐阅读
- angular - 对 ngrx 效果的操作类型不匹配
- r - 如何在 R Markdown 的投影仪演示文稿的标题页中添加更多信息?
- java-8 - JDK 或 JRE 的 Java 核心 API 部分?
- selenium - 无效的 --log-level 值。无法初始化日志记录。Exiting... 启动 Selenium Grid 节点时出错
- android - 使用 BillingClient 进行应用内计费。响应代码 = -1。服务连接断开
- python - 如何使用 Python 获取文本中的值 [n]?
- android - 如何在我的应用程序中设置我的 DatePicker 对话框
- linux - 如何修复oracle TNS - 远程连接数据库时连接超时错误?
- c# - 为标签报告的对象列表重新排序,然后向下
- javascript - 编辑:如何使用 Bootstrap className 来使用开关?