首页 > 解决方案 > 如何使用 TensorFlow 2.0 打乱两个 numpy 数据集?

问题描述

我希望在TensorFlow 2.0中编写一个函数,而不是在每次训练迭代之前对数据及其目标标签进行洗牌。

假设我有两个 numpy 数据集,X 和 y,代表分类的数据和标签。我怎样才能同时洗牌?

使用sklearn它非常简单:

from sklearn.utils import shuffle
X, y = shuffle(X, y)

我怎样才能在TensorFlow 2.0中做同样的事情?我在文档中找到的唯一工具是tf.random.shuffle,但一次只需要一个对象,我需要喂两个。

标签: pythonnumpytensorflowtensorflow2.0

解决方案


而不是洗牌 x 和 y ,更容易洗牌他们的索引,所以首先生成一个索引列表

indices = tf.range(start=0, limit=tf.shape(x_data)[0], dtype=tf.int32)

然后洗牌这些指数

idx = tf.random.shuffle(indices)

并使用这些索引来打乱数据

x_data = tf.gather(x_data, idx)
y_data = tf.gather(y_data, idx)

你会洗牌的数据


推荐阅读