首页 > 解决方案 > 使用数据生成器时,Keras 自定义指标 self.validation_data 为 none

问题描述

我一直在尝试训练模型并在每个时期结束时计算精度和召回率。

自定义指标

class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.precision = []
        self.recall = []

    def on_epoch_end(self, epoch, logs={}):
        print(type(self.validation_data))
        print(self.validation_data)
        predict = np.round(np.asarray(self.model.predict(self.validation_data[0])))
        targ = self.validation_data[1]

        precision_score = sklm.precision_score(targ, predict)
        recall = sklm.recall_score(targ, predict)
        self.precision.append(precision_score)
        self.recall.append(recall)

    def avg_precision_score(self):
        return np.mean(self.precision_score)

    def avg_recall_score(self):
        return np.mean(self.recall)

并且在训练时我正在使用数据生成器。

   training_set = train_datagen.flow_from_directory('train/',
                                                     target_size=(dim_x,dim_y),
                                                     batch_size=8, # 16 32
                                                     class_mode='categorical')

    test_set = test_datagen.flow_from_directory('test/',
                                                 target_size=(dim_x,dim_y),
                                                 batch_size=8, # 16 32
                                                 class_mode='categorical')
    metrics = Metrics()
    history = classifier.fit_generator(
                training_set,
                steps_per_epoch=2,#50,
                epochs=1, # 25
                validation_data=test_set,
                validation_steps=10,
                callbacks=[metrics]
                )

但这将 self.validation 提供为 None 类型。我究竟做错了什么 ?

标签: pythontensorflowmachine-learningkerastf.keras

解决方案


找到了解决这个问题的方法。在问题评论中提及 https://github.com/keras-team/keras/issues/10472

class Metrics(Callback):

    def __init__(self, val_data, batch_size = 20):
        super().__init__()
        self.validation_data = val_data
        self.batch_size = batch_size

初始化验证数据解决了使用数据生成器时的问题。


推荐阅读