首页 > 解决方案 > 在 ML.NET 中执行 ITransformer.Transform 后从 IDataView 中提取 MultiClass 结果

问题描述

我正在尝试一般使用 ML.NET,而不必创建一个类作为模型的输入和输出。为此,在创建模型后:

        public static (ITransformer model, double accuracy) TrainMultiClassModel(MulticlassExperimentSettings experimentSettings, MLContext mlContext, IDataView myview, string LabelName)
    {
        ITransformer trainedModel;
        MulticlassClassificationExperiment experiment = mlContext.Auto().CreateMulticlassClassificationExperiment(experimentSettings);

        ExperimentResult<MulticlassClassificationMetrics> experimentResult = experiment.Execute(myview, LabelName);
        RunDetail<MulticlassClassificationMetrics> best = experimentResult.BestRun;

        trainedModel = best.Model;

        return (trainedModel, best.ValidationMetrics.MacroAccuracy);
    }

其中 myView 包含正确设置 DataKinds 的 CSV 文件。

数据示例: 在此处输入图像描述

然后我通过运行这样的东西来执行该模型:

            MemoryStream modelStream = new MemoryStream(ModelData);
            ITransformer trainedModel = mlContext.Model.Load(modelStream, out var modelInputSchema);
            var predictions = trainedModel.Transform(myview);

同样,myView 包含来自 CSV 文件的数据,只是预测列为空。

现在我们有了 IDataView 类型的“预测”。

对于回归结果,这很容易。查找名为“Score”的模式并将其加载为浮点数:

float[] scoreColumn = predictions.GetColumn<float>("Score").ToArray();

但它如何用于 MultiClass 实验?有一个名为“PredictedLabel”的“String”类型的模式,但它包含 0 到 1 之间的数字,如下所示:

var labelColumn = predictions.Schema.FirstOrDefault(s => s.Name == "PredictedLabel" && s.IsHidden == false);
string[] scoreColumn = predictions.GetColumn<string>(labelColumn).ToArray();

我如何获得(在这种情况下)物种的实际名称?还是我必须以某种方式将数字映射到名称?我为此使用哪个映射表?

先感谢您。

编辑:埃里克的代码给出了这个列表:

1.4
1.9
0.2
0.4
 0.3
0.1
0.5
0.6
1.5
1.3
1.6
1.0
1.1
1.8
1.2
1.7
2.5
2.1
2.2
2.0
2.4
2.3

那些是 22,这很奇怪:没有一个正确的物种确实有 22 个字符(如果那是名字中的字符),我确实只输入了 4 行数据来解决。“PredictedLabel”同时输出 4 个值,但仍然是数字:在此处输入图像描述

但现在我想知道:我如何阅读这个领域?也许它包含答案: 在此处输入图像描述

标签: c#.netmachine-learningml.net

解决方案


您要使用的是一种名为GetKeyValues. 这将为您提供VBuffer<ReadOnlyMemory<char>>,其中缓冲区中的每个字符串都是多类分类模型中“键”或“类”中相应索引的“值”。

var predictions = trainedModel.Transform(myview);

var labelColumn = predictions.Schema[labelName]; // this is "Species" in your example above

VBuffer<ReadOnlyMemory<char>> keys = default;
labelColumn.GetKeyValues(ref keys);

foreach (var key in keys.DenseValues())
{
    Console.WriteLine(key);
}

推荐阅读