c# - ML.NET 的 MulticlassClassificationMetrics 在测试集上总是为 0
问题描述
我一直在参考 ML.NET 教程中的这个示例:https://github.com/dotnet/samples/tree/master/machine-learning/tutorials/GitHubIssueClassification
并基于它写了自己的版本:从 .xlsx 文件(另一个数据集)读取数据,并将其拆分为训练集和测试集。程序运行正常,也能做出正确的预测,但我怎么也想不明白:为什么把 _testSet 传给评估时,所有评估指标都是 0;而传入 _trainSet 时指标为 1,这符合预期。
即使把 TestFraction 设为 0.5,指标仍然是 0。
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.OleDb;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Test.Repository
{
/// <summary>
/// One labelled input row: a free-text <see cref="Subject"/> and the <see cref="Topic"/> it belongs to.
/// </summary>
public class SearchEntry
{
    /// <summary>The class label; the training pipeline maps it to the key-typed <c>Label</c> column.</summary>
    [LoadColumn(0)] public string Topic { get; set; }

    /// <summary>The text that the pipeline featurizes into <c>Features</c>.</summary>
    [LoadColumn(1)] public string Subject { get; set; }
}
/// <summary>
/// Output row of the trained model.
/// </summary>
public class SearchPrediction
{
    /// <summary>The predicted topic, bound to the model's <c>PredictedLabel</c> output column.</summary>
    [ColumnName("PredictedLabel")] public string Topic;
}
/// <summary>
/// Trains and evaluates a multiclass classifier that predicts <see cref="SearchEntry.Topic"/> from
/// <see cref="SearchEntry.Subject"/>, using rows read from an Excel workbook.
/// Call order: <see cref="LoadModelData"/>, <see cref="ProcessData"/>, <see cref="BuildAndTrainModel"/>,
/// then <see cref="Evaluate"/>; each step relies on state set by the previous one.
/// </summary>
public class Googler
{
    private static string _appPath => Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]);

    /// <summary>Workbook whose <c>data$</c> sheet supplies the Topic/Subject rows.</summary>
    public string SourceExcel { get; set; } = @"..\..\..\..\Test.Repository\model\in_data.xlsx";

    /// <summary>Model save location (not used by the members in this class).</summary>
    public string ModelSavePath { get; set; } = @"..\..\..\..\Test.Repository\model\model";

    /// <summary>Fraction of rows held out as the test set used by <see cref="Evaluate"/>.</summary>
    public double TestFraction { get; set; } = 0.2d;

    // NOTE(review): this state is static although the methods are instance methods, so every Googler
    // instance shares one context/model/dataset. Kept static to preserve existing behaviour.
    private static IDataView _trainingDataView;
    private static MLContext _mlContext;
    private static ITransformer _trainedModel;
    private static IEstimator<ITransformer> pipeline;
    private static PredictionEngine<SearchEntry, SearchPrediction> _predEngine;
    private static List<SearchEntry> _trainSet;
    private static List<SearchEntry> _testSet;

    /// <summary>
    /// Reads the <c>data$</c> sheet of <see cref="SourceExcel"/>, splits the rows into training and test
    /// sets according to <see cref="TestFraction"/>, and builds the training <see cref="IDataView"/>.
    /// </summary>
    public void LoadModelData()
    {
        _mlContext = new MLContext(seed: 0);
        var dt = Heplers.Excel.Query(SourceExcel, "SELECT * FROM [data$]");

        // Field<string> yields null for empty (DBNull) cells instead of throwing InvalidCastException like
        // a direct (string) cast. Rows without a Topic have no label, so they cannot be trained or evaluated.
        var searchEntries = dt.AsEnumerable()
            .Select(r => new SearchEntry
            {
                Topic = r.Field<string>("Topic"),
                Subject = r.Field<string>("Subject") ?? string.Empty
            })
            .Where(e => !string.IsNullOrWhiteSpace(e.Topic))
            .ToList();
        var dataview = _mlContext.Data.LoadFromEnumerable(searchEntries);

        // Do NOT pass samplingKeyColumnName: "Topic". That option guarantees that all rows sharing a value
        // land in the same subset, so each topic ends up entirely in train or entirely in test. The model
        // then never sees the test topics during training, and every test metric comes out as 0.
        var split = _mlContext.Data.TrainTestSplit(dataview, testFraction: TestFraction);

        _trainSet = _mlContext.Data
            .CreateEnumerable<SearchEntry>(split.TrainSet, reuseRowObject: false).ToList();
        _testSet = _mlContext.Data
            .CreateEnumerable<SearchEntry>(split.TestSet, reuseRowObject: false).ToList();
        _trainingDataView = _mlContext.Data.LoadFromEnumerable<SearchEntry>(_trainSet);
    }

    /// <summary>
    /// Builds the feature pipeline: Topic becomes the key-typed <c>Label</c>, Subject is featurized
    /// into <c>Features</c>, and the result is cached for the trainer.
    /// </summary>
    public void ProcessData()
    {
        Console.WriteLine($"=============== Processing Data ===============");
        pipeline = _mlContext.Transforms.Conversion.MapValueToKey(inputColumnName: "Topic", outputColumnName: "Label")
            .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: "Subject", outputColumnName: "SubjectFeaturized"))
            .Append(_mlContext.Transforms.Concatenate("Features", "SubjectFeaturized"))
            .AppendCacheCheckpoint(_mlContext);
        Console.WriteLine($"=============== Finished Processing Data ===============");
    }

    /// <summary>
    /// Appends an SDCA (non-calibrated) multiclass trainer and maps <c>PredictedLabel</c> back from a key
    /// to the original topic string, then fits the whole pipeline on the training data.
    /// </summary>
    public void BuildAndTrainModel()
    {
        var trainingPipeline = pipeline
            .Append(_mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated("Label", "Features"))
            .Append(_mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
        Console.WriteLine($"=============== Training the model ===============");
        _trainedModel = trainingPipeline.Fit(_trainingDataView);
        Console.WriteLine($"=============== Finished Training the model Ending time: {DateTime.Now.ToString()} ===============");
    }

    /// <summary>
    /// Scores the held-out test set with the trained model and prints multiclass metrics.
    /// </summary>
    /// <exception cref="InvalidOperationException">The model has not been trained yet.</exception>
    public void Evaluate()
    {
        if (_trainedModel == null)
        {
            throw new InvalidOperationException("Call LoadModelData(), ProcessData() and BuildAndTrainModel() before Evaluate().");
        }

        Console.WriteLine($"=============== Evaluating to get model's accuracy metrics - Starting time: {DateTime.Now.ToString()} ===============");

        // MapValueToKey maps a topic that was never seen during training to a missing key, so such rows
        // cannot be scored against a real class. Surface this, because it silently drives metrics to 0/NaN.
        var trainedTopics = new HashSet<string>(_trainSet.Select(e => e.Topic));
        var unseenTopicRows = _testSet.Count(e => !trainedTopics.Contains(e.Topic));
        if (_testSet.Count == 0)
        {
            Console.WriteLine("WARNING: the test set is empty; the metrics below are not meaningful.");
        }
        else if (unseenTopicRows > 0)
        {
            Console.WriteLine($"WARNING: {unseenTopicRows} of {_testSet.Count} test rows have a Topic that never appears in the training set.");
        }

        var testDataView = _mlContext.Data.LoadFromEnumerable<SearchEntry>(_testSet);
        var testMetrics = _mlContext.MulticlassClassification.Evaluate(_trainedModel.Transform(testDataView));
        Console.WriteLine($"=============== Evaluating to get model's accuracy metrics - Ending time: {DateTime.Now.ToString()} ===============");
        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"* Metrics for Multi-class Classification model - Test Data ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}");
        Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}");
        // "0.###" rather than "#.###": the latter renders 0 as an empty string.
        Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:0.###}");
        Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:0.###}");
        Console.WriteLine($"*************************************************************************************************************");
    }
}
}
输出如下:
*************************************************************************************************************
* Metrics for Multi-class Classification model - Test Data
*------------------------------------------------------------------------------------------------------------
* MicroAccuracy: 0
* MacroAccuracy: 0
* LogLoss:
* LogLossReduction: NaN
*************************************************************************************************************
解决方案
将
// Original call: samplingKeyColumnName keeps all rows with the same Topic in the same subset,
// so the topics that land in the test set never appear in the training set.
var split = _mlContext.Data
.TrainTestSplit(dataview, testFraction: TestFraction, samplingKeyColumnName: "Topic");
改为
// Fixed call: without a sampling key, each row is assigned to train or test independently at random.
var split = _mlContext.Data
.TrainTestSplit(dataview, testFraction: TestFraction);
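使用 samplingKeyColumnName: "Topic" 时,ML.NET 会保证 Topic 值相同的所有行都落在同一个子集(训练集或测试集)中,所以每个主题要么全部进入训练集,要么全部进入测试集。结果我的测试集只有 2 个唯一主题,而且它们在训练集中从未出现过,模型根本不可能预测正确,因此指标为 0。去掉该参数后,测试集包含 6 个主题。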
但我对这个结果仍不满意。我总共有 10 个不同的主题,希望测试集中每个主题都至少有几条记录,而 Microsoft.ML 的 TrainTestSplit 似乎无法保证这一点。
于是我写了一个按主题分层抽样的自定义拆分器:
/// <summary>
/// Splits <paramref name="searchEntries"/> into training and test rows, stratified by
/// <see cref="SearchEntry.Topic"/>: each topic contributes about <paramref name="testFraction"/> of its
/// rows to the test set. Unless <paramref name="testFraction"/> is 1, at least one row of every topic
/// stays in the training set, because a topic absent from training can never be predicted.
/// </summary>
/// <param name="searchEntries">Rows to split; the list itself is not modified.</param>
/// <param name="testFraction">Fraction of each topic's rows to place in the test set, in [0, 1].</param>
/// <param name="seed">Optional seed for a reproducible shuffle; <c>null</c> uses a time-based seed.</param>
/// <returns>Training rows (in their original order) and test rows (grouped by topic).</returns>
/// <exception cref="ArgumentNullException"><paramref name="searchEntries"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="testFraction"/> is outside [0, 1].</exception>
private (List<SearchEntry> TrainSet, List<SearchEntry> TestSet) TrainTestSplit(List<SearchEntry> searchEntries, double testFraction, int? seed = null)
{
    if (searchEntries == null)
    {
        throw new ArgumentNullException(nameof(searchEntries));
    }
    if (double.IsNaN(testFraction) || testFraction < 0d || testFraction > 1d)
    {
        throw new ArgumentOutOfRangeException(nameof(testFraction), testFraction, "testFraction must be between 0 and 1.");
    }

    var rand = seed.HasValue ? new Random(seed.Value) : new Random();
    var testIndices = new HashSet<int>();
    var testSet = new List<SearchEntry>();

    // Shuffle row indices once, then bucket them by topic; GroupBy keeps the shuffled order within each group.
    var shuffledByTopic = Enumerable.Range(0, searchEntries.Count)
        .OrderBy(_ => rand.Next())
        .GroupBy(i => searchEntries[i].Topic);

    foreach (var group in shuffledByTopic)
    {
        var indices = group.ToList();

        // The tiny epsilon stops floating-point noise (e.g. 100 * 0.07 == 7.000000000000001) from rounding up an extra row.
        var testCount = (int)Math.Ceiling(indices.Count * testFraction - 1e-9);
        if (testFraction < 1d)
        {
            // Leave at least one row of this topic for training (a single-row topic goes to training only).
            testCount = Math.Min(testCount, indices.Count - 1);
        }

        foreach (var i in indices.Take(testCount))
        {
            testIndices.Add(i);
            testSet.Add(searchEntries[i]);
        }
    }

    // Partition by index rather than Except(): no reliance on SearchEntry equality and no set de-duplication.
    var trainSet = searchEntries.Where((_, i) => !testIndices.Contains(i)).ToList();
    return (trainSet, testSet);
}
推荐阅读
- python - VerificationSMSCode() 得到了一个意外的关键字参数“电话”
- python - Pytest:捕获的标准错误设置和捕获的日志设置重复
- c# - 如何检查 `SomeType<>` 是否实现了 `SomeInterface<>`?
- python - 将 numpy 时间戳数组格式化为连接字符串
- java - Java android在右下角的照片中添加时间戳
- javascript - React Query v3 useInfiniteQuery 返回 isLoading,isFetching 始终为 true,isFetchingNextPage 始终为 false
- c++ - 如何创建矢量?
- python-3.x - 如何让 Django 接受我的日期格式
- sql - 带有 IF 条件的 Postgres 触发器
- react-native - React Navigation 自定义标题和自定义返回按钮