首页 > 解决方案 > 在不丢失基数信息的情况下对 TensorFlow 数据集进行窗口化?

问题描述

tf.data.Dataset.window返回一个新数据集,其元素是数据集,而这些嵌套数据集的元素是所需大小的窗口。如果你有一个数据集(比如说,Dataset.range(10)并且想要一个类似的窗口数据集[0 1 2] [1 2 3] ... [7 8 9]),那么使用windowplus可以做到这一点flat_map

>>> d = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(3))
>>> print(list(d))
[<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 3, 4])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 4, 5])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([5, 6, 7])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([6, 7, 8])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([7, 8, 9])>]

但是,flat_map导致数据集丢失基数信息的原因:

>>> d.cardinality.numpy()
<tf.Tensor: shape=(), dtype=int64, numpy=-2>

(-2 是UNKNOWN_CARDINALITY;请参阅Tensorflow 2.0:flat_map() 以展平数据集的数据集返回基数 -2

我想创建此类窗口的数据集,同时保留基数信息。处理未知基数的数据集的一个小烦恼是,Keras 训练进度条需要先在一个 epoch 上运行,然后才能生成 ETA。我尝试.take(n_windows)了自己计算的地方n_windows,但仍然返回了一个带有UNKNOWN_CARDINALITY.

有没有办法在不丢失基数信息的情况下对数据集进行窗口化?

标签: pythontensorflowkerastensorflow-datasets

解决方案


推荐阅读