python - 在 keras 中使用数据生成器时如何输出 ypred 和 ytrue
问题描述
我正在使用数据生成器训练 keras 模型,该数据生成器从目录中批量读取数据。这在 model.fit() 中运行正常。但是在使用 model.predict() 时,我希望同时返回 ypred 和 ytrue 值。
我可以启用/修改 model.predict() 来执行此操作(可能使用自定义回调)吗?
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves one pre-saved batch per ``.npy`` file.

    Each id in ``ids`` names a file ``data/<id>.npy`` that unpacks into an
    ``(X, y)`` pair; the batch at position ``i`` is read from the file
    belonging to ``ids[i]``.
    """

    def __init__(self, ids, batch_size=256):
        """Remember which file ids this generator serves.

        Args:
            ids: Sequence of ids; id ``k`` refers to ``data/k.npy``.
            batch_size: Stored on the instance. It is not used when loading,
                because every file is returned unchanged as one batch.
        """
        super().__init__()
        self.batch_size = batch_size
        self.ids = ids

    def __len__(self):
        """Return the number of batches, i.e. the number of ids."""
        # Must be an int: returning the id array itself breaks len().
        return len(self.ids)

    def __getitem__(self, index):
        """Load and return the ``(X, y)`` batch at position ``index``."""
        # Map the batch position to its file id so that the per-epoch shuffle
        # takes effect and id ranges that do not start at 0 (e.g. validation
        # ids) read the intended files.
        file_id = self.ids[index]
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from a trusted source.
        X, y = np.load(f'data/{file_id}.npy', allow_pickle=True)
        return X, y

    def on_epoch_end(self):
        """Shuffle ids in each epoch."""
        self.ids = np.random.permutation(self.ids)
# Build the model and declare which data files belong to each split.
model = buildModel() # builds a multilayer perceptron (buildModel is not shown here)
train_ids = np.arange(10000) # training data are in data/0.npy, data/1.npy, ... data/9999.npy
val_ids = np.arange(10000, 12000) # validation ids 10000 .. 11999
train_generator = DataGenerator(train_ids)
val_generator = DataGenerator(val_ids)

# Train model
history = model.fit(x=train_generator, epochs=100)

# Validate model (but I don't have ytrue): only the predictions come back,
# flattened to 1-D by reshape(-1).
ypred = model.predict(x=val_generator).reshape(-1)

# What I would like to achieve.
# Illustrative only: some_custom_callback is a placeholder, not defined here.
(ypred, ytrue) = model.predict(x=val_generator, callbacks=[some_custom_callback])

# Or
# Illustrative only: some_fancy_method is a placeholder, not defined here.
ypred = model.predict(x=val_generator)
ytrue = some_fancy_method(val_generator)
解决方案
这可以通过向您的 DataGenerator 类添加一个方法来完成:该方法以已训练(拟合)好的模型作为输入,将其应用于生成器产生的各个数据批次,然后返回 ytrue 和 ypred。
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves one pre-saved batch per ``.npy`` file.

    Each id in ``ids`` names a file ``data/<id>.npy`` that unpacks into an
    ``(X, y)`` pair. Besides feeding ``model.fit``, the generator can run a
    model over all of its files and return the targets next to the
    predictions (see ``predict``).
    """

    def __init__(self, ids, batch_size=256):
        """Remember which file ids this generator serves.

        Args:
            ids: Sequence of ids; id ``k`` refers to ``data/k.npy``.
            batch_size: Stored on the instance. It is not used when loading,
                because every file is returned unchanged as one batch.
        """
        super().__init__()
        self.batch_size = batch_size
        self.ids = ids

    def __len__(self):
        """Return the number of batches, i.e. the number of ids."""
        # Must be an int: returning the id array itself breaks len().
        return len(self.ids)

    def __getitem__(self, index):
        """Load and return the ``(X, y)`` batch at position ``index``."""
        # load_data expects a file id, not a batch position; translating here
        # keeps __getitem__ consistent with predict(), which passes ids.
        return self.load_data(self.ids[index])

    def load_data(self, index):
        """Read ``data/<index>.npy`` and return its ``(X, y)`` pair.

        Args:
            index: File id (the number in the file name).
        """
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load files from a trusted source.
        X, y = np.load(f'data/{index}.npy', allow_pickle=True)
        return X, y

    def predict(self, model):
        """Run ``model`` over every file and collect targets and predictions.

        Args:
            model: Object whose ``predict(X)`` result supports ``reshape(-1)``.

        Returns:
            Tuple ``(ytrue, ypred)`` of flat lists, filled file by file in the
            current order of ``self.ids``.
        """
        ytrue, ypred = [], []
        for file_id in self.ids:
            X, y = self.load_data(file_id)
            ytrue.extend(y)
            # Flatten so each prediction becomes one list element.
            ypred.extend(model.predict(X).reshape(-1))
        return ytrue, ypred

    def on_epoch_end(self):
        """Shuffle ids in each epoch."""
        self.ids = np.random.permutation(self.ids)
train_generator = DataGenerator(train_ids)
val_generator = DataGenerator(val_ids)

# Train model
history = model.fit(x=train_generator, epochs=100)

# Validate model. DataGenerator.predict returns (ytrue, ypred) in that order,
# so the targets must be unpacked first.
ytrue, ypred = val_generator.predict(model)