首页 > 解决方案 > 如何通过改组将张量流数据集拆分为 N 个数据集

问题描述

我有一个 tensorflow 数据集ds,我想将其拆分为 N 个数据集,它们的并集是原始数据集,并且它们之间不共享样本。我试过了:

ds_list = [ds.shard(N,index=i) for i in range(N)]

但不幸的是,这不是随机的:每个新数据集总是会从原始数据集中获得相同的样本。例如,ds_list[0] 的样本编号为 0,N,2N,3N...,而 ds_list[1] 的样本编号为 1,N+1,2N+1,3N+1... 有什么办法可以将原始数据集随机细分为相同大小的数据集?

不幸的是,简单地改组之前不会解决问题:

import tensorflow as tf
import math

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])

N=2
ds = ds.shuffle(20)
ds_list = [ds.shard(N,index=i) for i in range(N)]


for ds in ds_list:
    shard_set = sorted(set(list(ds.as_numpy_iterator())))
    print(shard_set)

输出:

    [3, 5, 6, 8, 11, 12, 14, 15, 19, 20]
    [1, 2, 4, 5, 6, 7, 8, 14, 15, 20]

如同:

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])
N=2
ds_list = []
ds = ds.shuffle(20)
size = ds.__len__()
sub = math.floor(size/N)
for n in range(N):
    ds_sub = ds.take(sub)
    remainder = ds.skip(sub)
    ds_list.append(ds_sub)
    ds = remainder  

for ds in ds_list:
    shard_set = sorted(set(list(ds.as_numpy_iterator())))
    print(shard_set)

标签: tensorflowtensorflow2.0

解决方案


也许(对于 N 个分片):

ds_list = []
ds = ds.shuffle()
size = ds.__len__()
sub = floor(size/N)
for n in range(N):
    ds_sub = ds.take(sub)
    remainder = ds.skip(sub)
    ds_list.append(ds_sub)
    ds = remainder  

推荐阅读