首页 > 解决方案 > 在张量流中重新初始化迭代器后对数据集进行洗牌

问题描述

我正在使用 tensorflow dataset api 将数据输入模型。据我所知,我应该对数据集进行洗牌以从模型中获得最佳性能。但是,由于我正在训练一个时期,然后测试一个时期,依此类推......我不太确定我是否以不同的方式执行洗牌。为了更好地说明,下面是我的代码:

train_dataset = tf.data.TFRecordDataset(filename_train).map(_parse_function).filter(filter_examples)\
            .shuffle(60000, seed=mseed, reshuffle_each_iteration=False) \
            .batch(train_batch_size)
train_iterator = train_dataset.make_initializable_iterator(shared_name="Training_iterator")

因此,每当我使用整个数据集时,我都会将迭代器重新初始化为:

sess.run(train_iterator.initializer)

那安全吗?我在问,因为在训练时我得到了以下损失函数的形状

在此处输入图像描述

因此,不同时期之间的洗牌是确定性的吗?

请注意,我使用种子shuffle只是为了使结果在代码的不同运行之间可重现。

标签: pythontensorflowshuffletensorflow-datasets

解决方案


种子影响整个默认图表。通过设置种子,您可以确定洗牌,这意味着每次洗牌都会保持相同的顺序。所以是的,您将在第二个时期获得相同的订单。您还可以为种子设置占位符并在每个时期更改它,更多信息在这里 https://github.com/tensorflow/tensorflow/issues/13446 没有种子改组变成伪随机


推荐阅读