首页 > 解决方案 > 使用 TensorFlow 从检查点重新开始训练后保留训练/验证拆分

问题描述

我编写了一个 TensorFlow 训练循环,它在每个 epoch 结束时进行验证。在训练开始时,我将我的数据集拆分为训练和验证子集(大约 85%-15% 的拆分)。我的数据集实际上由存储在磁盘上的小块音频样本组成,我在拆分之前随机打乱整个数据集,因此我在训练和验证子集上得到了完全均匀的分布。问题是,如果我从给定的检查点重新开始训练,随机洗牌会再次发生,我怀疑这会导致数据污染——验证阶段可能会处理网络已经训练过的数据集的位。我想我看到这会影响重新评估后训练的损失和准确性,但这很难说。

我在网上找不到有关此特定问题的任何信息,但我建议的解决方案是将验证拆分中的文件名缓存到文件中,如果重新启动则从那里加载它们。有更好的解决方案吗?

为清楚起见,我使用 tf.data.Dataset API,使用简单的数据集管道构建训练和验证数据集,该管道首先从磁盘上的文件中读取样本。

标签: pythontensorflowmachine-learning

解决方案


如果您设置洗牌的种子,则顺序将是一致的:

import tensorflow as tf

for _ in range(5):
    ds = tf.data.Dataset.range(1, 10).shuffle(4, seed=42).batch(3)
    for i in ds:
        print(i)
    print()
tf.Tensor([4 1 2], shape=(3,), dtype=int64)
tf.Tensor([7 3 6], shape=(3,), dtype=int64)
tf.Tensor([5 9 8], shape=(3,), dtype=int64)

tf.Tensor([4 1 2], shape=(3,), dtype=int64)
tf.Tensor([7 3 6], shape=(3,), dtype=int64)
tf.Tensor([5 9 8], shape=(3,), dtype=int64)

tf.Tensor([4 1 2], shape=(3,), dtype=int64)
tf.Tensor([7 3 6], shape=(3,), dtype=int64)
tf.Tensor([5 9 8], shape=(3,), dtype=int64)

tf.Tensor([4 1 2], shape=(3,), dtype=int64)
tf.Tensor([7 3 6], shape=(3,), dtype=int64)
tf.Tensor([5 9 8], shape=(3,), dtype=int64)

tf.Tensor([4 1 2], shape=(3,), dtype=int64)
tf.Tensor([7 3 6], shape=(3,), dtype=int64)
tf.Tensor([5 9 8], shape=(3,), dtype=int64)

因此,您所需要的只是每次以相同顺序排列的文件列表,您可以使用tf.data.Dataset.list_files, 并设置shuffle=False

ds = tf.data.Dataset.list_files(r'C:\Users\User\Downloads\*', shuffle=False)

推荐阅读