首页 > 解决方案 > WEKA 中十倍交叉验证的结果不一致

问题描述

昨天在weka中用2种方式实现了10折交叉验证,但是结果不一致。

方式一:直接调用方法eval.crossValidateModel()

 J48 j48 = new J48();
 j48.buildClassifier(ins);  // ins is the Instances object
 Evaluation eval = new Evaluation(ins);
 eval.crossValidateModel(j48, ins, 10, new Random(1)); // 10-fold cross validation
 ... // get results by eval.getXX(0) or eval.getXXX(1)

方式2:使用方法testCV()trainCV()在每个折叠中,

 ins.randomize(new Random(1)); // ins is the Instances object
 ins.stratify(10); // randomize the dataset then split into 10 folds

 for(int i=0; i<10; i++){       
    Instances trainData = ins.trainCV(10, i);
    Instances testData = ins.testCV(10, i);
    J48 j48 = new J48();
    j48.buildClassifier(trainData);

    Evaluation eval = new Evaluation(trainData);
    eval.evaluateModel(j48, testData);
    ... // get results by eval.getXX(0) or eval.getXXX(1)
 }

根据 weka api docs,上述 2 种方式应该有相同的结果,即方式 2 的平均结果(例如,精度,召回率)应该等于方式 1 的结果。但事实是它们不是同样,任何人都可以找出我的代码中的错误,或者提供其他不错的评估方法吗?谢谢你们!

标签: wekacross-validation

解决方案


如果您查看weka.classifiers.Evaluation.crossValidateModel方法的代码(取决于您的版本,委托对象),您会发现它使用了该weka.core.Instances.trainCV(int,int,Random)方法。此外,您需要Evaluation使用完整数据集的类先验来初始化对象。

这是更新的代码:

Evaluation eval = new Evaluation(ins);  // init evaluation
rand = new Random(1);
int numFolds = 10;  // 10-fold CV
ins.randomize(rand); // randomize the data
ins.stratify(numFolds); // stratify the randomized data for 10-fold CV
J48 template = new J48();  // classifier template for evaluation
//template.setOptions(...);  // if further options need to be set

for (int i = 0; i < numFolds; i++) {       
  Instances trainData = ins.trainCV(numFolds, i, rand);
  Instances testData = ins.testCV(numFolds, i);
  Classifier cls = AbstractClassifier.makeCopy(template);  // copy of classifier template
  cls.buildClassifier(trainData);
  eval.evaluateModel(cls, testData);  // accumulate statistics
}

... // get results by eval.getXX(0) or eval.getXXX(1)

推荐阅读