tensorflow - tensorflow 数据集 shuffle 然后批处理或批处理然后 shuffle
问题描述
我最近开始学习 tensorflow。
我不确定是否有区别
x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.shuffle(buffer_size=4)
ds.batch(4)
和
x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.batch(4)
ds.shuffle(buffer_size=4)
另外,我不确定为什么我不能使用
dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
因为它给出了错误
dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'
谢谢!
解决方案
TL; DR:是的,有区别。几乎总是,你会想在之前打电话。类上没有方法,您必须分别调用这两个方法来对数据集进行混洗和批处理。Dataset.shuffle()
Dataset.batch()
shuffle_batch()
tf.data.Dataset
a 的变换以tf.data.Dataset
它们被调用的相同顺序应用。Dataset.batch()
将其输入的连续元素组合成输出中的单个批处理元素。通过考虑以下两个数据集,我们可以看到操作顺序的效果:
tf.enable_eager_execution() # To simplify the example code.
# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)
for elem in dataset:
print(elem)
# Prints:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)
# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)
for elem in dataset:
print(elem)
# Prints:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)
在第一个版本中(洗牌前的批次),每批次的元素是输入中的 3 个连续元素;而在第二个版本中(批处理前洗牌),它们是从输入中随机采样的。通常,当通过(某些变体)小批量随机梯度下降进行训练时,每个批次的元素应该从总输入中尽可能均匀地采样。否则,网络可能会过度拟合输入数据中的任何结构,并且生成的网络将无法达到如此高的精度。
推荐阅读
- python - 将 N 位整数拆分为 M 块
- three.js - 为什么计算机上的 THREE.WebGLRenderer.setPixelRatio() 比移动设备慢?
- xlib - XTest 是否支持游戏手柄模拟摇杆的伪造功能?
- javascript - VSCode 扩展 Prettier 在方法参数后强制换行
- java - Java - 有没有办法从 base64 输入确定文件大小?
- html - 单击按钮时表单未提交 - onclick="this.form.submit()"
- r - 如何循环浏览书签状态输入并以正确的顺序恢复?
- wonderware - Wonderware - Intouch:演示许可证
- image - 如何计算k-means聚类中两点之间的颜色距离?
- sitefinity - 更新到开发服务器时,sitefinity 自定义小部件未显示任何内容