python - Keras fit_generator:生成器内的随机增强+洗牌
问题描述
我创建了一个生成器以将其输入到fit_generator
keras 的功能中。生成器创建一些随机值。这就是我的做法:
class DataGenerator(object):
def __init__(self, X_Y_file_path, batch_size, N):
self.X_Y_file_path = X_Y_file_path
self.batch_size = size
self.N = N
def initialize_zeros(self):
X = np.zeros((self.batch_size, 1), dtype='int32')
Y = np.zeros((self.batch_size, 1), dtype='int32')
Y_neg = np.zeros((self.batch_size, self.N))
return X, Y, Y_neg
def generate(self):
while True:
i = 0
X, Y, Y_neg = initialize_zeros()
for row in load_data_per_line(self.X_Y_file_path): # load_data_per_line is generator function which goes each line at a time from one file.
x, y = row
y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
X[i] = x
Y[i] = y
Y_neg[i] = y_neg
if i == self.batch_size:
yield ([X, Y_neg], Y) # Y_neg goes as input in the model.(not important here. just mentioning)
X, Y, Y_neg = initialize_zeros()
i = 0
所以这是我的发电机。使用相同的样本数据,它似乎可以正常工作。
我想知道如何在这个生成器中实现一个 shuffle 函数来在每个 epoch 之后进行shuffle?
搜索了一下,我发现了可以覆盖方法的序列on_epoch_end
,但不清楚如何使用Sequence
继承实现上述生成器。有什么帮助吗?(顺便说一句,上面的函数“安全”可以用use_multiprocessing
在fit_generator
吗?)
编辑
这X_Y_file_path
是一个文件(已知长度)。这load_data_per_line
是一个生成器函数,每行产生一个。
解决方案
您在使用序列的正确轨道上。当与多处理一起使用时,它将保证每个数据点都被看到一次。构建序列的一种简单方法是预加载原始数据,每次请求批处理时进行动态处理。
class MySeq(Sequence):
def __init__(self, X_Y_file_path, batch_size, N):
self.X_Y_file_path = X_Y_file_path
self.batch_size = size
self.N = N
self.data = load_data_per_line(self.X_Y_file_path)
def __len__(self):
return int(np.ceil(len(self.data) / self.batch_size))
def __getitem__(self, idx):
# Just slice the data based on batch index (idx)
batch_data = self.data[idx*self.batch_size:(idx+1)*self.batch_size]
X = np.zeros((len(batch_data), 1), dtype='int32')
Y = np.zeros((len(batch_data), 1), dtype='int32')
Y_neg = np.zeros((len(batch_data), self.N))
for row, i in enumerate(data):
x, y = row
y_neg = random.sample(id_list, self.N) # a list of id to pick randomly
X[i] = x
Y[i] = y
Y_neg[i] = y_neg
return [X, Y_neg], Y # This is a single batch
现在您可以对您的self.data
使用进行任何处理on_epoch_end()
推荐阅读
- ios - iOS 测试:iOS Sandbox 设置中没有出现“Reset Eligibility”
- java - 三元如果包含函数结果
- python - 如何从 AWS S3 嵌套目录中读取泡菜文件?
- vba - 两张表过滤数据结果仅选择性列复制粘贴到另一张表数据粘贴时不应替换
- mysql - 如何查询具有单引号的 json 数据类型列
- mongodb - 有没有更快的方法来克隆有限的 mongo 数据库
- ios - 在 tableView didSelectRowAt indexPath iOS 之前准备 segue 调用
- html - a href 不适用于带有 :after content:''" 的图像;
- c# - 当 SetParent() 不起作用时如何在面板中运行外部应用程序 C#
- visual-studio-code - VS Code Elixir F2“重命名符号”不起作用