首页 > 解决方案 > Pytorch - Tensorboard - Precision-Recall Curve 只显示一个点

问题描述

我正在训练一个有 6 个类的分类器,并且我的 pr 曲线具有以下功能:

def predict_class_probabilities(model, features):
    model.eval()  # Switch to evaluation mode
    with no_grad():  # Don't want to calculate gradients here, only when training
        predictions = model(features)
        # Get class prediction probabilities
        prediction_class_probabilities = F.softmax(predictions, dim=1)

    return prediction_class_probabilities

def add_pr_curves_tensorboard(summary_writer, model, features, labels, global_step=0):

    prediction_class_probabilities = predict_class_probabilities(model, features)

    # Iterate over each class and add pr curve to summary_writer
    for class_index in range(num_classes):
        # Need binary prediction for class class_index for the add_pr_curve method
        binary_class_label = labels == class_index
        # Prediction probability for class class_index
        predictions_class_probability = prediction_class_probabilities[:, class_index]
        tag = 'class {}'.format(class_index)

    summary_writer.add_pr_curve(tag, binary_class_label, predictions_class_probability,
                                global_step=global_step)

在训练我的模型后,我得到了 93% 的准确率(0 类和 1 类的准确率为 92% 和 96%),但是我的 0 类和 1 类的 pr 曲线看起来像这样(其他曲线看起来相似): pr 曲线类 0pr 曲线类 1。有人可以告诉我我在这里做错了什么吗?最良好的祝愿,托拜厄斯

标签: pythonpytorchclassificationtensorboardprecision-recall

解决方案


推荐阅读