tensorflow - tf.keras 中带有 tfrecords 或 numpy 的数据管道
问题描述
我想用比我的内存大的数据在 Tensorflow 2.0 的 tf.keras 中训练一个模型,但教程只显示了带有预定义数据集的示例。
我遵循了本教程:
使用 tf.data 加载图像,我无法对 numpy 数组或 tfrecords 上的数据进行此操作。
这是一个将数组转换为 tensorflow 数据集的示例。我想要的是使这项工作适用于多个 numpy 数组文件或多个 tfrecords 文件。
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
model.fit(train_dataset, epochs=3)
解决方案
如果您有tfrecords
文件:
path = ['file1.tfrecords', 'file2.tfrecords', ..., 'fileN.tfrecords']
dataset = tf.data.Dataset.list_files(path, shuffle=True).repeat()
dataset = dataset.interleave(lambda filename: tf.data.TFRecordDataset(filename), cycle_length=len(path))
dataset = dataset.map(parse_function).batch()
parse_function 处理解码和任何类型的扩充。
如果使用 numpy 数组,您可以从文件名列表或数组列表中构造数据集。标签只是一个列表。或者它们可以在解析单个示例时从文件中获取。
path = #list of numpy arrays
或者
path = os.listdir(path_to files)
dataset = tf.data.Dataset.from_tensor_slices((path, labels))
dataset = dataset.map(parse_function).batch()
parse_function 处理解码:
def parse_function(filename, label): #Both filename and label will be passed if you provided both to from_tensor_slices
f = tf.read_file(filename)
image = tf.image.decode_image(f))
image = tf.reshape(image, [H, W, C])
label = label #or it could be extracted from, for example, filename, or from file itself
#do any augmentations here
return image, label
要解码 .npy 文件,最好的方法是不使用reshape
or read_file
,decode_raw
但首先使用 numpys 加载np.load
:
paths = [np.load(i) for i in ["x1.npy", "x2.npy"]]
image = tf.reshape(filename, [2])
或尝试使用decode_raw
f = tf.io.read_file(filename)
image = tf.io.decode_raw(f, tf.float32)
然后只需将批处理数据集传递给model.fit(dataset)
. TensorFlow 2.0 允许对数据集进行简单的迭代。无需使用迭代器。即使在更高版本的 1.x API 中,您也可以将数据集传递给.fit
方法
for example in dataset:
func(example)
推荐阅读
- servlets - 尽管传递了参数的值,但 Java Servlet 中的 request.getParameter 间歇性地变为 null
- gremlin - 查找节点结果集的所有边
- windows-10 - 3D Studio Max 9 - 未打开(在欢迎屏幕处冻结)
- javascript - Electron.js 以及如何让 Node Serial 工作
- json - 传递 LIST (Flutter) 时 json 原始正文中的 HTTP POST 出错
- javascript - 如何为java脚本中数组json对象的新属性赋值?
- mongodb - 聚合多个字段并计算平均mondodb
- c++ - 我写的合并排序的时间复杂度是多少
- c# - c#如何从它的派生类中取消订阅订阅事件的私有方法?
- python - 独角兽(uvicorn)。以编程方式或通过命令行运行?