首页 > 解决方案 > Tensorflow 中的预取生成器(序列)

问题描述

我尝试使用 tf.keras.utils.Sequence 实现生成器 - 遵循此 Github 页面的方法: https ://mahmoudyusof.github.io/facial-keypoint-detection/data-generator/

所以我的发电机有以下形式:

class Generator(tf.keras.utils.Sequence):

  def __init__(self, *args, **kwargs):
    self.on_epoch_end()

  def on_epoch_end(self):
    #shuffle indices for batches

  def __len__(self):

  def __getitem__(self, idx):    
  #returning the idxth batch of the shuffled dataset    
  return X, y

不幸的是,我的模型的训练过程用这个生成器变得很长,所以我想预取它。

我试过了

Train_Generator = tf.data.Dataset.from_generator(Generator(Training_Files, batch_size=64, shuffle = True), output_types=(np.array, np.array))

将生成器转换为预取工作的类型。我收到错误消息:

`generator` must be callable.

我知道生成器必须支持 Iter() 协议才能工作。但是我该如何实施呢?或者你们知道提高这类生成器性能的其他方法吗?

提前谢谢!!

标签: pythontensorflowneural-networkdatasetgenerator

解决方案


我建议这样做:

Train_Generator = tf.data.Dataset.from_generator(Generator, args=[Training_Files, 64, True], output_types=(np.array, np.array))

推荐阅读