首页 > 解决方案 > tf.data.Dataset - 删除缓存?

问题描述

是否可以删除调用后构建的内存缓存tf.data.Dataset.cache()

这就是我想做的。数据集的扩充是非常昂贵的,所以当前的代码或多或少是:

data = tf.data.Dataset(...) \
       .map(<expensive_augmentation>) \
       .cache() \
       # .shuffle().batch() etc. 

然而,这意味着每次迭代data都会看到相同的数据样本的增强版本。我想做的是使用缓存几个时期,然后重新开始,或者等效地做类似Dataset.map(<augmentation>).fleeting_cache().repeat(8). 这有可能实现吗?

标签: pythontensorflow

解决方案


缓存生命周期与数据集相关联,因此您可以通过重新创建数据集来实现:

def create_dataset():
  dataset = tf.data.Dataset(...)
  dataset = dataset.map(<expensive_augmentation>)
  dataset = dataset.shuffle(...)
  dataset = dataset.batch(...)
  return dataset

for epoch in range(num_epochs):
  # Drop the cache every 8 epochs.
  if epoch % 8 == 0: dataset = create_dataset()
  for batch in dataset:
    train(batch)

推荐阅读