python-3.x - Tensorflow2 数据流水线如何优化?
问题描述
我使用一个大型图像数据集,我在第一步将其转换为 tfrecords,然后在下一步加载到 tf.data.dataset。
但是数据集太大了,尽管有 12 GB GPU,但我无法获得比 10 更大的批量大小。现在问题出现了,如何优化图像的加载,以便达到更大的 batch_size。
有没有办法使用 .fit_generator() 来优化这个过程?
这是我当前加载训练数据的过程(验证数据以相同的方式转换,因此这里也没有显示):
train_dataset = dataset.load_tfrecord_dataset(dataset_path, class_names_path, image_size)
train_dataset = train_dataset.shuffle(buffer_size=shuffle_buffer)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.map(lambda x, y: (
dataset.transform_images(x, image_size),
dataset.transform_targets(y, anchors, anchor_masks, image_size)))
train_dataset = train_dataset.prefetch(batch_size)
我的培训阶段开始:
history = model.fit(train_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=val_dataset)
解决方案
不幸的是,无论我们从软件角度优化多少,都有一些依赖于硬件架构的限制。
在您的情况下,可以增加批量大小的唯一方法是降低图像的尺寸;否则您将无法增加批量大小。
tf.data.Dataset()
是一个优秀的数据处理库,使用正确/必要的预处理步骤prefetch
确实可以让你处理得更快。
然而,由于硬件限制,您无法增加批量大小。要么减小图像大小以便能够增加批量大小,要么您需要选择更大的 GPU >=16 GB VRAM。
推荐阅读
- hyperledger-fabric - IBM Blockchain Platform VS Code GOPATH 错误
- python - 如何从父类调用属性?
- android-studio - 如何将现有的android项目转换为gradle依赖
- sql - 其他列相等的 SQL 组列
- ios - 使用swift杀死应用程序时如何存储ios推送通知
- xamarin.forms - 将焦点设置在 Xamarin.Forms MVVM (Prism) 中的 Entry 字段上
- mysql - MySQL排序表合并字段上具有相同值的行,但保持相互顺序
- powershell - 将自定义变量属性传递到 foreach 语句
- php - 如何使用动态 id 制作多个 ajax 触发器
- php - 来自 Firebase 的 PHP 读取值