python - Keras generator keeps shuffling even though it is told not to
Problem description
I use a Keras data generator whose shuffle argument defaults to False:
class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()
    ...
    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)
    ...
This is how I create the instances for training and validation:
train_generator = data_generator(x_train[:num_train_examples], y_train[:num_train_examples], batch_size, data_dir)
val_generator = data_generator(x_train[num_train_examples:], y_train[num_train_examples:], batch_size, data_dir)
Then I train the model:
model.fit_generator(train_generator,
                    validation_data=val_generator,
                    callbacks=[history],
                    epochs=num_epochs)
But the generator keeps producing random indices:
starting training
Epoch 1/1
batch start: 0, batch end: 2
batch start: 24, batch end: 26
batch start: 2, batch end: 4
batch start: 114, batch end: 116
batch start: 4, batch end: 6
batch start: 60, batch end: 62
batch start: 6, batch end: 8
batch start: 68, batch end: 70
batch start: 8, batch end: 10
batch start: 94, batch end: 96
What can I do to stop it from shuffling?
The generator class's __getitem__ function:
    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)
        return x_batch, y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))
        return np.array(x_batch), self.labels[index*self.batch_size:limit]
Edit: now I can see a pattern: non-random batches appear to alternate with random ones.
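Solution
I suspect the problem may lie in your __len__(self) function (if you have defined one). I added a __len__(self) method to your code and tried it, and it no longer shuffles. Here is the code: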
import numpy as np
import keras  # assumption: standalone Keras; with TensorFlow use `from tensorflow import keras`

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(self.size/self.batch_size))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)
        return x_batch, y_batch

    # Alternative version (commented out): it goes through self.indexes,
    # so the shuffle flag actually takes effect.
    # def __data_generation(self, index):
    #     'Generates data containing batch_size samples'
    #     current_indices = self.indexes[index*self.batch_size:(index + 1)*self.batch_size]
    #     x_batch = []
    #     y_batch = []
    #     for idx in current_indices:
    #         # video_array = np.load(self.data_dir + '/' + self.frames[idx] + '.npy')
    #         # x_batch.append(np.array(video_array))
    #         y_batch.append(self.labels[idx])
    #     return np.array(x_batch), y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))
        return np.array(x_batch), self.labels[index*self.batch_size:limit]
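As a quick check of the batch arithmetic used by __len__ and __data_generation above, here is a small sketch that needs neither Keras nor the data files. The helper name batch_bounds and the sizes (7 samples, batches of 2) are made up for illustration; only the two formulas come from the code above.

```python
import math


def batch_bounds(size, batch_size):
    """Return (start, end) for every batch, using the same formulas as the class above."""
    # Same count as __len__: ceil(size / batch_size).
    num_batches = math.ceil(size / batch_size)
    bounds = []
    for index in range(num_batches):
        start = index * batch_size
        # Same clamp as __data_generation: the last batch stops at `size`.
        limit = min(size, (index + 1) * batch_size)
        bounds.append((start, limit))
    return bounds


bounds = batch_bounds(size=7, batch_size=2)
print(bounds)  # [(0, 2), (2, 4), (4, 6), (6, 7)]
```

With an unshuffled generator, the "batch start / batch end" lines printed during one epoch should walk through these pairs in order.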
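The data_generator class above works as you expect: it does not shuffle. However, because your __data_generation function slices self.frames and self.labels directly and never uses self.indexes, it would not shuffle even if you wanted it to. So I wrote my own __data_generation function, which you can see commented out. If you use it, you get the behavior you want: it shuffles when shuffle is True and does not when shuffle is False. Hope this helps.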
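The index-based idea can be sketched without Keras or any .npy files. IndexedBatches below is a hypothetical stand-in for a keras.utils.Sequence subclass, not the answerer's class: it returns frame names instead of loaded arrays, and the seed argument is added only to make the shuffle repeatable.

```python
import math
import random


class IndexedBatches:
    """Hypothetical stand-in for a keras.utils.Sequence subclass (no Keras needed)."""

    def __init__(self, frames, labels, batch_size, shuffle=False, seed=None):
        self.frames = list(frames)
        self.labels = list(labels)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = random.Random(seed)  # assumption: seeded only for repeatability
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.frames) / self.batch_size)

    def on_epoch_end(self):
        # Rebuild the sample order and shuffle it only when asked to.
        self.indexes = list(range(len(self.frames)))
        if self.shuffle:
            self._rng.shuffle(self.indexes)

    def __getitem__(self, index):
        # Slice the index list, not the data, so the shuffle flag takes effect.
        start = index * self.batch_size
        batch_ids = self.indexes[start:start + self.batch_size]
        x_batch = [self.frames[i] for i in batch_ids]
        y_batch = [self.labels[i] for i in batch_ids]
        return x_batch, y_batch


frames = ["clip_%d" % i for i in range(7)]
labels = list(range(7))

ordered = IndexedBatches(frames, labels, batch_size=3, shuffle=False)
ordered_labels = [ordered[i][1] for i in range(len(ordered))]
print(ordered_labels)  # [[0, 1, 2], [3, 4, 5], [6]]

shuffled = IndexedBatches(frames, labels, batch_size=3, shuffle=True, seed=0)
shuffled_labels = [y for i in range(len(shuffled)) for y in shuffled[i][1]]
print(sorted(shuffled_labels) == labels)  # True: every sample appears exactly once
```

Separately from the generator's own flag, fit_generator takes a shuffle argument that defaults to True and, for Sequence inputs, shuffles the order in which batches are requested each epoch. If batch indices still arrive out of order, passing shuffle=False to fit_generator may also be worth trying.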