Keras generator keeps shuffling even though it is told not to

Question

I have a Keras data generator whose shuffle argument defaults to False:

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

  ...

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

   ...

This is how I create the instances for training and validation:

train_generator = data_generator(x_train[:num_train_examples], y_train[:num_train_examples], batch_size, data_dir)
val_generator = data_generator(x_train[num_train_examples:], y_train[num_train_examples:], batch_size, data_dir)

Then I train the model:

model.fit_generator(train_generator,
                        validation_data=val_generator,
                        callbacks=[history],
                        epochs=num_epochs)

But the generator keeps producing batches in a random order:

starting training
Epoch 1/1

batch start: 0, batch end: 2

batch start: 24, batch end: 26

batch start: 2, batch end: 4

batch start: 114, batch end: 116

batch start: 4, batch end: 6

batch start: 60, batch end: 62

batch start: 6, batch end: 8

batch start: 68, batch end: 70

batch start: 8, batch end: 10

batch start: 94, batch end: 96

What can I do to stop it from shuffling?

The generator class's __getitem__ function:

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)

        return x_batch, y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))

        return np.array(x_batch), self.labels[index*self.batch_size:limit]

Edit: I can now see a pattern: sequential batches appear to alternate with random ones.

Tags: python, machine-learning, keras

Answer


I suspect the problem is in your __len__(self) method (if you have defined one). I added a __len__(self) method to your code and ran it, and it no longer shuffles. Here is the code:

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

    def __len__(self):
        return int(np.ceil(self.size/self.batch_size))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)
        return x_batch, y_batch

    # def __data_generation(self, index):
    #     'Generates data containing batch_size samples'
    #     current_indices = self.indexes[index*self.batch_size:(index + 1)*self.batch_size]
    #     x_batch = []
    #     y_batch = []
    #     for idx in current_indices:
    #         # video_array = np.load(self.data_dir + '/' + self.frames[idx] + '.npy')
    #         # x_batch.append(np.array(video_array))
    #         y_batch.append(self.labels[idx])

    #     return np.array(x_batch), y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))
        return np.array(x_batch), self.labels[index*self.batch_size:limit]

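The code above works as you expect: it does not shuffle. However, the way you defined __data_generation, it would not shuffle even if you wanted it to, because it slices self.frames and self.labels directly and never reads self.indexes. So I wrote my own __data_generation, which you can see commented out above (its two np.load lines are commented out as well and need to be re-enabled to actually load the frames). If you use it instead, you get the behavior you want: it shuffles when shuffle is True and does not shuffle when shuffle is False. Hope this helps.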
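
To make the index-based approach concrete, here is a minimal, Keras-free sketch of the same idea. It is not the answerer's exact code: the class name IndexedBatches and the seed parameter are hypothetical, the standard-library random module stands in for np.random, and frame names are returned as-is where the real generator would call np.load. In actual use the class would subclass keras.utils.Sequence.

```python
import random


class IndexedBatches:
    """Keras-free stand-in for data_generator (hypothetical name)."""

    def __init__(self, frames, labels, batch_size, shuffle=False, seed=None):
        self.frames = frames
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.size = len(frames)
        # seed is an added convenience for reproducibility, not in the original
        self._rng = random.Random(seed)
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch, rounded up to cover a final partial batch.
        return -(-self.size // self.batch_size)

    def on_epoch_end(self):
        # Rebuild the sample order; shuffle it only when asked to.
        self.indexes = list(range(self.size))
        if self.shuffle:
            self._rng.shuffle(self.indexes)

    def __getitem__(self, index):
        # Slice the index list, not the data, so the shuffle flag takes effect.
        start = index * self.batch_size
        current = self.indexes[start:start + self.batch_size]
        x_batch = [self.frames[i] for i in current]  # original loads '<frame>.npy' here
        y_batch = [self.labels[i] for i in current]
        return x_batch, y_batch


frames = ['f%d' % i for i in range(7)]
labels = list(range(7))

ordered = IndexedBatches(frames, labels, batch_size=3)
ordered_labels = [y for b in range(len(ordered)) for y in ordered[b][1]]

shuffled = IndexedBatches(frames, labels, batch_size=3, shuffle=True, seed=0)
shuffled_labels = [y for b in range(len(shuffled)) for y in shuffled[b][1]]
```

With shuffle=False the labels come back in their original order. With shuffle=True every sample is still visited exactly once per epoch, and each frame stays paired with its own label because both are looked up through the same index.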

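As a quick check of the batch arithmetic used by __len__ and __data_generation in the answer's code, the sketch below lists the start and end bounds each batch index would print. The helper name batch_bounds is hypothetical and only illustrates the calculation.

```python
import math


def batch_bounds(size, batch_size):
    """Return (start, end) slice bounds for every batch, clipping the last one."""
    n_batches = math.ceil(size / batch_size)  # same value the answer's __len__ returns
    return [(i * batch_size, min(size, (i + 1) * batch_size))
            for i in range(n_batches)]


bounds = batch_bounds(size=7, batch_size=2)
```

Without __len__, the caller has no way to know how many valid batch indices exist; with it, indices run from 0 to len - 1 and the final batch is simply shorter when the sample count is not a multiple of the batch size.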
