首页 > 解决方案 > 将一个 tf.data.Datasets 与另一个的所有其他元素合并

问题描述

我想合并两个tf.data.Dataset,以便只有第一个的每个其他样本与另一个合并,而不会丢失任何样本。

例如,让我们有两个数字列表:

ds1 = tf.data.Dataset.range(10)
ds10 = tf.data.Dataset.range(10, 60, 10)

我想将它们组合起来,以便将来自第二个的样本添加到第一个,但只能每隔一次:

0, 11, 2, 23, 4, 35, 6, 47, 8, 59

有一种zip方法可以合并两个数据集,但它是通过从每个数据集中抽取一个样本来实现的——不合并样本意味着从 中删除一个样本ds10,这不是我想要的。

我可以从那里继续,并ds10用 zip 压缩期间掉落的“虚拟”样本进行压缩ds1,但它看起来效率不高。

有没有一种有效的方法来做到这一点,而不会丢弃样本(真实的或“虚拟的”)?

标签: pythontensorflowtensorflow-datasets

解决方案


尝试这个:

def combine(pair,to_add):
    combined = [pair[0], pair[1] + to_add]
    return tf.data.Dataset.from_tensor_slices(combined)

ds1 = tf.data.Dataset.range(10)
ds2 = tf.data.Dataset.range(10,60,10)

combined = tf.data.Dataset.zip((ds1.batch(2),ds2)).flat_map(combine)

解释:

首先,批处理ds1.batch(2)。这会产生[(0,1), (2,3), ...].
将此压缩到另一个数据集以获取[((0,1),10), ((2,3),20), ...].
撤消批处理,flat_map并在过程中将 each(a,b)ceach [((a,b),c), ...]like结合起来[(a,b+c), ...]
然后将结果展平以移除大括号并得到[0, 11, 2, 23, 4, 35, 6, 47, 8, 59].
在处理几个tf.data.Datasets 时,像这样的批处理和取消批处理是一种常见的模式。


推荐阅读