首页 > 解决方案 > Keras fit_generator:生成器内的随机增强+洗牌

问题描述

我创建了一个生成器以将其输入到fit_generatorkeras 的功能中。生成器创建一些随机值。这就是我的做法:

class DataGenerator(object):
    def __init__(self, X_Y_file_path, batch_size, N):
        self.X_Y_file_path = X_Y_file_path
        self.batch_size = size
        self.N = N

    def initialize_zeros(self):
        X = np.zeros((self.batch_size, 1), dtype='int32')
        Y = np.zeros((self.batch_size, 1), dtype='int32')
        Y_neg = np.zeros((self.batch_size, self.N))
        return X, Y, Y_neg

     def generate(self):
        while True:
            i = 0 
            X, Y, Y_neg = initialize_zeros()
            for row in load_data_per_line(self.X_Y_file_path): # load_data_per_line is generator function which goes each line at a time from one file.
                x, y = row
                y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
                X[i] = x
                Y[i] = y
                Y_neg[i] = y_neg
                if i == self.batch_size:
                    yield ([X, Y_neg], Y) # Y_neg goes as input in the model.(not important here. just mentioning)
                    X, Y, Y_neg = initialize_zeros()
                    i = 0

所以这是我的发电机。使用相同的样本数据,它似乎可以正常工作。

我想知道如何在这个生成器中实现一个 shuffle 函数来在每个 epoch 之后进行shuffle?

搜索了一下,我发现了可以覆盖方法的序列on_epoch_end,但不清楚如何使用Sequence继承实现上述生成器。有什么帮助吗?(顺便说一句,上面的函数“安全”可以用use_multiprocessingfit_generator吗?)

编辑

X_Y_file_path是一个文件(已知长度)。这load_data_per_line是一个生成器函数,每行产生一个。

标签: pythontensorflowkeras

解决方案


您在使用序列的正确轨道上。当与多处理一起使用时,它将保证每个数据点都被看到一次。构建序列的一种简单方法是预加载原始数据,每次请求批处理时进行动态处理。

class MySeq(Sequence):
    def __init__(self, X_Y_file_path, batch_size, N):
        self.X_Y_file_path = X_Y_file_path
        self.batch_size = size
        self.N = N
        self.data = load_data_per_line(self.X_Y_file_path)

     def __len__(self):
        return int(np.ceil(len(self.data) / self.batch_size))

     def __getitem__(self, idx):
        # Just slice the data based on batch index (idx)
        batch_data = self.data[idx*self.batch_size:(idx+1)*self.batch_size]
        X = np.zeros((len(batch_data), 1), dtype='int32')
        Y = np.zeros((len(batch_data), 1), dtype='int32')
        Y_neg = np.zeros((len(batch_data), self.N))
        for row, i in enumerate(data):
            x, y = row
            y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
            X[i] = x
            Y[i] = y
            Y_neg[i] = y_neg
        return [X, Y_neg], Y # This is a single batch

现在您可以对您的self.data使用进行任何处理on_epoch_end()


推荐阅读