首页 > 解决方案 > TensorFlow 2.0:Model.fit 报告的 AUC 与在同一开发集上用 model.predict 预测后由 sklearn roc_auc_score 计算的 AUC 不一致

问题描述

我使用的是 TF 2.0,模型如下面的代码所示,训练过程中它的 AUC 约为 0.62。

    def load_dataset(csv_path, shuffle=True):
        """Build a batched tf.data dataset from a CSV file.

        Args:
            csv_path: Path of the CSV file to read.
            shuffle: Forwarded to make_csv_dataset; defaults to True.

        Returns:
            The dataset produced by tf.data.experimental.make_csv_dataset
            with batch_size=256, 'label' as the label column, '?' as the
            NA marker, a single epoch, and ignore_errors enabled.
        """
        reader_options = dict(
            batch_size=256,
            shuffle=shuffle,
            label_name='label',
            na_value='?',
            num_epochs=1,
            ignore_errors=True,
        )
        return tf.data.experimental.make_csv_dataset(csv_path, **reader_options)

    # Build the training and dev pipelines. embedding_train / embedding_dev are
    # mapped over each dataset; they are defined outside this snippet.
    train_data = load_dataset('../data_demo/train/copy1_5train_index.csv')
    train_data = train_data.map(embedding_train)
    dev_data = load_dataset('../data_demo/dev/dev_index.csv')
    dev_data = dev_data.map(embedding_dev)

    # Binary classifier: BatchNorm -> Dense(128, relu) -> BatchNorm -> Dense(1, sigmoid).
    model = tf.keras.Sequential([
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])

    # Class weights handed to model.fit below.
    class_weight = {0: 0.19, 1: 0.81}

    # NOTE(review): the last layer already applies a sigmoid, yet the loss is
    # created with from_logits=True, which tells Keras to treat the outputs as
    # raw logits. One of the two should probably change (from_logits=False, or
    # drop the sigmoid activation) -- confirm before trusting the reported loss.
    # No optimizer is passed, so the Keras default is used.
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()])

    # Resume from an existing checkpoint when its '.index' file is present.
    checkpoint_save_path = 'model/dc.ckpt'
    if os.path.exists(checkpoint_save_path + '.index'):
        print('--------------load the model-------------')
        model.load_weights(checkpoint_save_path)

    # Checkpoint callback: writes weights only to checkpoint_save_path and does
    # not restrict saving to the best result (save_best_only=False).
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=False)

    # Train for 5 epochs, validating on dev_data.
    # NOTE(review): `callbacks` is documented as a list of callbacks; a bare
    # callback object is passed here -- confirm the TF version in use accepts it.
    model.fit(train_data, epochs=5, class_weight=class_weight, validation_data=dev_data, callbacks=cp_callback)

    # Show the layer summary once the model has been built by training.
    model.summary()

但是,加载模型权重后用 model.predict 预测、再用 sklearn 计算得到的 AUC 只有 0.5,代码如下所示:

from sklearn.metrics import roc_auc_score

def build_predict_result():
    """Rebuild the model, load the saved weights, and write dev-set predictions to CSV.

    Reads ../data_demo/dev/dev_index.csv through a tf.data pipeline, runs
    model.predict on it, attaches the scores as a 'predict' column to a pandas
    read of the same CSV, and saves the result as dev_index_predict.csv.

    NOTE(review): load_dataset is called below without shuffle=False, so the
    dataset is built with shuffle=True, while the DataFrame is read in file
    order. The 'predict' column may therefore not line up with the rows it is
    attached to -- confirm, and pass shuffle=False if row alignment is required.
    """

    def load_dataset(csv_path, shuffle=True):
        # Batched CSV reader (batch_size=1024); label_name='label' selects the
        # label column, and `shuffle` is forwarded unchanged.
        return tf.data.experimental.make_csv_dataset(
            csv_path,
            batch_size=1024, 
            shuffle=shuffle,  
            label_name='label', 
            na_value='?',
            num_epochs=1,
            ignore_errors=True)

    def load_vec():
        # Load the precomputed news and user vectors from .npy files and return
        # them as float32 constants, news first and user second.
        news_vec = np.load('../data_demo/dev/dev_news_vec.npy')
        user_vec = np.load('../data_demo/dev/dev_user_vec.npy')
        return tf.constant(news_vec, dtype=tf.float32), tf.constant(user_vec, dtype=tf.float32)

    nvec, uvec = load_vec()

    def embedding(x, y):
        # Look up the user vector by x['uindex'] and the news vector by
        # x['nindex'], concatenate them along axis 1 (user part first), and
        # pass the label y through unchanged.
        return tf.concat([tf.nn.embedding_lookup(uvec, x['uindex']), tf.nn.embedding_lookup(nvec, x['nindex'])],
                         axis=1), y

    # Uses the default shuffle=True -- see the NOTE(review) in the docstring.
    dev_data = load_dataset('../data_demo/dev/dev_index.csv')
    dev_data = dev_data.map(embedding)

    # BatchNorm -> Dense(128, relu) -> BatchNorm -> Dense(1, sigmoid).
    model = tf.keras.Sequential([
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])

    # NOTE(review): sigmoid output combined with from_logits=True -- the loss
    # will treat the sigmoid outputs as logits. Only predict() is used in this
    # function, so this matters for the loss/metrics rather than the scores;
    # confirm which of the two settings is intended.
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=['accuracy', 'AUC'])

    # Load the saved weights when the checkpoint's '.index' file is present;
    # otherwise the model keeps its freshly initialised weights.
    checkpoint_save_path = 'model/dc.ckpt'
    if os.path.exists(checkpoint_save_path + '.index'):
        print('--------------load the model-------------')
        model.load_weights(checkpoint_save_path)

    result = model.predict(dev_data)
    print(result.shape)

    # Attach the predictions to the CSV rows as read by pandas (file order)
    # and save the combined table.
    dev_df = pd.read_csv('../data_demo/dev/dev_index.csv')
    dev_df['predict'] = result
    dev_df.to_csv('dev_index_predict.csv', index=False)

# Run the prediction step (writes dev_index_predict.csv).
build_predict_result()

def evaluate():
    """Print the ROC AUC of the saved dev-set predictions.

    Loads dev_index_predict.csv and scores its 'predict' column against its
    'label' column with sklearn's roc_auc_score.

    :return: None; the score is only printed.
    """
    predictions = pd.read_csv('dev_index_predict.csv')
    labels, scores = predictions['label'], predictions['predict']
    auc_value = roc_auc_score(labels, scores)
    print('auc: %s' % auc_value)

# Report the AUC of the predictions stored in dev_index_predict.csv.
evaluate()

我没有找到 TF 2.0 中计算 AUC 指标的源代码。不过,用相同的数据(例如标签 [1, 0, 0, 1]、预测值 [0.4, 0.2, 0.3, 0.5])分别测试 TF 2.0 和 sklearn 的 AUC 函数,两者得到的结果是相等的。

标签: scikit-learn tensorflow2.0 auc

解决方案


感谢大家关注我的问题,我的问题已经解决了。函数 tf.data.experimental.make_csv_dataset 的参数 shuffle 默认值为 True,因此 model.predict 得到的预测结果的顺序与原始 csv 数据的行不对应。将 shuffle 改为 False 后,问题就解决了。

    def load_dataset(csv_path, shuffle=False):
        """Build a batched tf.data dataset from a CSV file without shuffling by default.

        Args:
            csv_path: Path of the CSV file to read.
            shuffle: Forwarded to make_csv_dataset. Defaults to False so that
                the dataset order can be matched against the CSV rows, e.g.
                when model.predict output is attached to a DataFrame read
                from the same file.

        Returns:
            The dataset produced by tf.data.experimental.make_csv_dataset
            with batch_size=1024, 'label' as the label column, '?' as the
            NA marker, a single epoch, and ignore_errors enabled.
        """
        reader_options = dict(
            batch_size=1024,
            shuffle=shuffle,
            label_name='label',
            na_value='?',
            num_epochs=1,
            ignore_errors=True,
        )
        return tf.data.experimental.make_csv_dataset(csv_path, **reader_options)

推荐阅读