首页 > 解决方案 > 处理更大的张量流数据集

问题描述

我对 Tensorflow 比较陌生,并且根据我在 ts 网站上找到的教程进行了一些模型训练。我已经能够将满足我初步要求的功能组合在一起。

我正在本地读取一个 csv 文件,该文件提供了一些指向与写在同一 csv 行上的标签相关联的图像的链接。我的代码大致是这样的:

def map_func(*row):
  img = process_img(img_filename)
  output = read(row)
  return img, output

dataset = tf.data.experimental.CsvDataset(CSV_FILE, default_format, header=True)
dataset = dataset.map(map_func)
dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
dataset = dataset.batch(NB_IMG)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

X, y = next(iter(dataset))

X_train, X_test = tf.split(X, split, axis=0)
y_train, y_test = tf.split(y, split, axis=0)

model = create_model()
model.compile(optimizer=OPTIMIZER, loss='mse')
model.fit(x=X_train, y=y_train, epochs=EPOCHS, validation_data=(X_test, y_test))

NB_IMG 是我拥有的图像总数。EPOCHS 在这里任意固定为给定值(通常为 20 或 40),并且拆分是应用于 NB_IMG 的比率。

我所有的图像都在本地计算机上,使用该代码,我的 GPU 目前可以大致管理多达 50000 张图像。训练因更多图像而失败(GPU 已耗尽)。我可以理解这是因为我一次读取所有数据,但我有点受阻,无法在这里采取下一步以便能够管理更大的数据集。

下面这部分是我猜需要改进的部分:

X, y = next(iter(dataset))

这里有人可以帮助我前进并指导我找到一些可以在更大数据集上训练模型的示例或片段吗?我对下一步行动有点迷失了,不确定在 ts 文档中的重点。我并没有真正在网上找到适合我需要的明确示例。我应该如何循环不同的批次?迭代器是如何编码的?

谢谢!

标签: pythontensorflowiteratordataset

解决方案


process_img那么,你能详细介绍一下这两个函数read吗?

在我的实验中,我注意到shuffle当您有大量数据并且缓冲区大小很大时,该函数可能会很慢。尝试评论该行并检查它是否运行得更快。如果是这样,您可以使用pandas加载您的 CSV 文件,然后将其随机播放并使用tf.data.Dataset.from_tensor_slices

Tensorflow 现在有一个很好的工具来分析模型和数据集管道(https://www.tensorflow.org/tensorboard/tensorboard_profiling_keras)。


推荐阅读