首页 > 解决方案 > 如何合并两个(或更多)TensorFlow 数据集?

问题描述

我已经获取了具有 3 个分区的 CelebA 数据集,如下所示

>>> celeba_bldr = tfds.builder('celeb_a')
>>> datasets = celeba_bldr.as_dataset()
>>> datasets.keys()
dict_keys(['test', 'train', 'validation'])

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

现在,我想将它们全部合并到一个数据集中。例如,我需要将训练和验证组合在一起,或者可能将它们合并在一起,然后根据我自己的不同主题不相交标准将它们拆分。有没有办法做到这一点?

我在文档https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset中找不到任何选项来执行此操作

标签: pythontensorflowtensorflow-datasetstensorflow2.0

解决方案


查看您链接的文档,数据集似乎有concatenate方法,所以我认为您可以获得一个联合数据集:

ds_train = datasets['train']
ds_test = datasets['test']
ds_valid = datasets['validation']

ds = ds_train.concatenate(ds_test).concatenate(ds_valid)

请参阅:https ://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#concatenate


推荐阅读