首页 > 解决方案 > 拆分和重组 TensorFlow 数据集

问题描述

我目前有一个Dataset带有多个批次的张量流(批次数是可变的,但可以被 4 整除)。我想取出每 4 批用作测试,其余的用作训练,但我还没有遇到一个优雅的解决方案。所需结果的简化视觉示例:

Dataset = [b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12]
train = [b1,b2,b3,b5,b6,b7,b9,b10,b11]
test = [b4,b8,b12]

大多数关于Datasets 的 train-validation-test 拆分的解决方案都使用 和 的组合Dataset.take()Dataset.skip()因为他们不介意将数据拆分到中间的某个位置。然而,如果我要使用这个解决方案,它需要我计算数据集的大小,用多个take()s 和skip()s 在它上面运行一个丑陋的循环,然后收集结果并将它们连接到一个新的Dataset. 没有更好的方法来选择张量流数据集中的批次间隔吗?

标签: pythontensorflowkerastensorflow-datasets

解决方案


该解决方案可以通过组合enumerate()filter()和来实现map(),类似于此处提供的答案。

玩具示例:

list(
    Dataset.from_tensor_slices(np.arange(12))
    .batch(2)
    .as_numpy_iterator()
)

输出:

[array([0, 1]),
 array([2, 3]),
 array([4, 5]),
 array([6, 7]),
 array([8, 9]),
 array([10, 11])]

玩具示例的解决方案:

list(
    Dataset.from_tensor_slices(np.arange(12))
    .batch(2)
    #solution starts here
    .enumerate() 
    .filter(lambda i, data: (i+1)%4 !=0)
    .map(lambda i,data: data)
    #solution ends here
    .as_numpy_iterator()
)

出去:

[array([0, 1]), 
 array([2, 3]), 
 array([4, 5]),
 array([8, 9]),
 array([10, 11])]

推荐阅读