MXNET: evaluating predictions on the MNIST dataset

Problem description

I have trained an MNIST model with a LeNet CNN, and now I am trying to evaluate some input images against the trained network.

Training and evaluation both work fine: after 100 epochs the model reaches an accuracy of 0.963241.

[03:20:24] /home/greg/dev/matchbox/src/Lenet.hpp:246: EPOCH [99] Val Accuracy: 0.963241
[03:20:24] /home/greg/dev/matchbox/src/Lenet.hpp:247: EPOCH [99] Val LogLoss: 0.147613

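Now I am evaluating input images that each contain a single digit, but my predictions are wrong, even though they come with very high accuracy.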

The model predicts the input image to be a [8 ] with Accuracy = 0.99845
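The value logged as "Accuracy" is the output of the winning class unit for this one image. Assuming the network ends in a softmax layer (the MXNet LeNet examples use SoftmaxOutput, and the LogLoss metric above implies probabilistic outputs), that number is the model's confidence in its top class, which can be close to 1 even when the class is wrong. The plain C++ sketch below, with hypothetical helper names Softmax and TopClass and no MXNet dependency, shows how such a confidence is derived from raw class scores.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper: converts raw class scores (logits) into probabilities
// with a numerically stable softmax (subtract the max score before exp).
std::vector<float> Softmax(const std::vector<float> &logits) {
    assert(!logits.empty());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float &p : probs) {
        p /= sum;  // probabilities now sum to 1
    }
    return probs;
}

// Hypothetical helper: index of the most probable class. probs[TopClass(probs)]
// is the model's confidence in that class, not a measure of accuracy.
std::size_t TopClass(const std::vector<float> &probs) {
    return static_cast<std::size_t>(
        std::distance(probs.begin(), std::max_element(probs.begin(), probs.end())));
}
```

A confidence of 0.99845 therefore only says that one output dominates the others; whether that class is correct depends on the input matching what the network was trained on.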

I suspect the problem is in how I load the images with NDArray image_data = LoadInputImage(image_file), and that the NDArray shape is incorrect.

NDArray Predictor::LoadInputImage(const std::string &image_file) {
    if (!FileExists(image_file)) {
        LG << "Image file " << image_file << " does not exist";
        throw std::runtime_error("Image file does not exist");
    }
    LG << "Loading the image " << image_file << std::endl;

    cv::Mat mat = cv::imread(image_file, cv::IMREAD_GRAYSCALE);
    mat.convertTo(mat, CV_32F);

    /*resize pictures to (28, 28) according to the pretrained model*/
    int channels = input_shape_[1];
    int height = input_shape_[2];
    int width = input_shape_[3];

    cv::resize(mat, mat, cv::Size(width, height));
    std::vector<float> array((float *) mat.data, (float *) mat.data + mat.rows * mat.cols);

    std::cout << mat;

    NDArray image_data = NDArray(input_shape_, global_ctx_, false);
    image_data.SyncCopyFromCPU(array.data(), input_shape_.Size());
    NDArray::WaitAll();
    return image_data;
}
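One detail worth checking in LoadInputImage: SyncCopyFromCPU is asked to copy input_shape_.Size() floats, but the source vector only holds mat.rows * mat.cols values. That matches when input_shape_ is (1, 1, 28, 28), but with a batch size above 1 or more than one channel the copy would read past the end of the vector. A minimal sketch, assuming an NCHW layout and using hypothetical names (Nchw, ElementCount, PackSingleImage) with no MXNet dependency, builds a buffer whose length always equals N*C*H*W.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for input_shape_: batch, channels, height, width.
using Nchw = std::array<std::size_t, 4>;

std::size_t ElementCount(const Nchw &shape) {
    return shape[0] * shape[1] * shape[2] * shape[3];
}

// Hypothetical helper: places one H x W grayscale image into batch slot 0 of
// an NCHW buffer, copying it into every channel. Other batch slots stay zero,
// so the buffer length always equals ElementCount(shape).
std::vector<float> PackSingleImage(const std::vector<float> &pixels, const Nchw &shape) {
    const std::size_t plane = shape[2] * shape[3];
    if (pixels.size() != plane) {
        throw std::invalid_argument("image size does not match H x W of the input shape");
    }
    std::vector<float> buffer(ElementCount(shape), 0.0f);
    for (std::size_t c = 0; c < shape[1]; ++c) {
        std::copy(pixels.begin(), pixels.end(),
                  buffer.begin() + static_cast<std::ptrdiff_t>(c * plane));
    }
    assert(buffer.size() == ElementCount(shape));
    return buffer;
}
```

The packed buffer could then be passed as image_data.SyncCopyFromCPU(buffer.data(), buffer.size()), so the copy length and the buffer length can never disagree.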

This is what the image looks like when dumped to the console with std::cout << mat. In the output below, I have replaced 255 with ___ for readability.

___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 249, 226, ___, 142, 100, 113, 198, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 187, 102, 50, 139, ___, 175, 133, 111, 71, 92, 238, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 190, 125, 197, 251, ___, ___, ___, ___, ___, 47, 125, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 201, 235, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 97, 130, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 64, 158, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 106, 68, 252, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 236, 150, 98, 194, 195, 245, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 117, 10, 0, 109, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 238, 244, 200, 48, 105, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 244, 160, 173, 236, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 46, 136, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 167, 115, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 244, 197, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 170, 60, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 211, 235, ___, ___, ___, ___, ___, ___, ___, 250, 31, 118, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, 150, 31, 251, ___, ___, ___, ___, ___, ___, 223, 149, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, 119, 52, 227, ___, ___, 235, 177, 44, 186, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 121, 190, 214, 76, 56, 70, 162, 252, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, 224, 177, 220, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___;
___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___, ___

The shape of the data looks correct: 28x28.
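The shape may be right while the values are not. In the dump, the background is 255 and the strokes are darker, whereas in MNIST the background pixels are 0 and the strokes have high values; the dumped values are also raw 0-255 floats. A minimal sketch, assuming the training data iterator fed the network pixels scaled to [0, 1] (verify what your training pipeline actually did), inverts and rescales a flattened image. ToMnistRange is a hypothetical name.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: maps a dark-on-light 8-bit grayscale image (0-255,
// stored as float) to MNIST's polarity, scaled to [0, 1]:
//   255 (white background) -> 0.0, 0 (black stroke) -> 1.0.
// Assumes the network was trained on inputs in that range.
std::vector<float> ToMnistRange(const std::vector<float> &pixels) {
    std::vector<float> out;
    out.reserve(pixels.size());
    for (float v : pixels) {
        assert(v >= 0.0f && v <= 255.0f);
        out.push_back((255.0f - v) / 255.0f);
    }
    return out;
}
```

In LoadInputImage this would run on the array vector just before SyncCopyFromCPU. With OpenCV the same mapping could instead be done in one step by replacing mat.convertTo(mat, CV_32F) with mat.convertTo(mat, CV_32F, -1.0 / 255.0, 1.0), since convertTo computes alpha * src + beta.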

Prediction method

void Predictor::Score(const std::string &image_file) {
    // Load the input image
    NDArray image_data = LoadInputImage(image_file);
    LG << "Running the forward pass on model to predict the image";

    /*
     * The executor->arg_arrays represent the arguments to the model.
     *
     * Copying the image_data that contains the NDArray of input image
     * to the arg map of the executor. The input is stored with the key "data" in the map.
     */
    double ms = ms_now();

    image_data.CopyTo(&args_map_["data"]);
    NDArray::WaitAll();

    // Run the forward pass.
    executor_->Forward(false);
    NDArray::WaitAll();
    auto array = executor_->outputs[0].Copy(global_ctx_);

    /*
    * Find out the maximum accuracy and the index associated with that accuracy.
    * This is done by using the argmax operator on NDArray.
    */
    auto predicted = array.ArgmaxChannel();

    /*
     * Wait until all previous write operations on the 'predicted'
     * NDArray are complete before reading it.
     * This method guarantees that all previous write operations pushed into the backend engine
     * for execution have actually finished.
     */
    predicted.WaitToRead();
    NDArray::WaitAll();

    auto best_idx = predicted.At(0);
    auto best_accuracy = array.At(0, best_idx);
    LG << "best_idx, best_accuracy = " << best_idx << " : " << best_accuracy;

    if (output_labels.empty()) {
        LG << "The model predicts the highest accuracy of " << best_accuracy << " at index "
           << best_idx;
    } else {
        LG << "The model predicts the input image to be a [" << output_labels[best_idx]
           << " ] with Accuracy = " << best_accuracy << std::endl;
    }

    mx_uint len = output_labels.size();
    std::vector<mx_float> pred_data(len);
    std::vector<mx_float> label_data(len);

    predicted.SyncCopyToCPU(&pred_data, len);

    // Display all candidates
    for (mx_uint i = 0; i < len; ++i) {
        auto val = pred_data[i];  // predicted
        auto label = label_data[i]; // expected

        auto best_idx = predicted.At(i);
        auto best_accuracy = array.At(0, best_idx);
        LG << "best_idx, best_accuracy = " << best_idx << " : " << best_accuracy;
        auto accuracy = array.At(0, i);
        LG << "Found, Expected, Accuracy  :: " << i << " : " << val << " = " << label << " : " << accuracy << " == "
           << best_accuracy;
    }

    ms = ms_now() - ms;

    auto args_name = net_.ListArguments();
    LG << "INFO:" << "label_name = " << args_name[args_name.size() - 1];
    LG << "INFO:" << "rgb_mean: " << "(" << rgb_mean_[0] << ", " << rgb_mean_[1]
       << ", " << rgb_mean_[2] << ")";
    LG << "INFO:" << "rgb_std: " << "(" << rgb_std_[0] << ", " << rgb_std_[1]
       << ", " << rgb_std_[2] << ")";
    LG << "INFO:" << "Image shape: " << "(" << input_shape_[1] << ", "
       << input_shape_[2] << ", " << input_shape_[3] << ")";
    LG << "INFO:" << "Batch size = " << input_shape_[0] << " for inference";
    LG << "INFO:" << "Throughput: " << (1000.0 * input_shape_[0] / ms)
       << " images per second";
}
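In the "Display all candidates" loop, predicted holds one argmax index per image in the batch, not one per label. With a batch size of 1, predicted.At(i) for i > 0 and predicted.SyncCopyToCPU(&pred_data, len) with len = output_labels.size() read beyond its single element. A sketch, assuming the network output has been copied to the CPU as a flat, row-major batch_size x num_classes buffer (ArgmaxPerRow is a hypothetical name), shows indexing that pairs each image with its own best class and probability.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: for each row of a row-major [batch x num_classes]
// probability buffer, returns (best class index, probability of that class).
std::vector<std::pair<std::size_t, float>> ArgmaxPerRow(const std::vector<float> &probs,
                                                        std::size_t num_classes) {
    assert(num_classes > 0 && probs.size() % num_classes == 0);
    const std::size_t batch = probs.size() / num_classes;
    std::vector<std::pair<std::size_t, float>> result;
    result.reserve(batch);
    for (std::size_t row = 0; row < batch; ++row) {
        const float *scores = probs.data() + row * num_classes;
        std::size_t best = 0;
        for (std::size_t k = 1; k < num_classes; ++k) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        result.emplace_back(best, scores[best]);
    }
    return result;
}
```

With mxnet-cpp, such a flat buffer could be filled by calling SyncCopyToCPU on executor_->outputs[0] with a std::vector<mx_float> sized to that output's Size(), and every per-class probability for image 0 is then probs[0 * num_classes + k].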

The input image is a 28x28 color image of the digit '3'. I suspect the problem is in LoadInputImage(const std::string &image_file), but I have not been able to pin it down yet.

Any ideas would be helpful.

Tags: c++, machine-learning, multidimensional-array, mxnet
