Problem creating a confusion matrix

Problem description

fit <- rpart(unacc~., data = carTrain, method = 'class')

I built a decision tree on carTrain.

Then I made predictions:

predict_unseen <- predict(fit,carTest, type = 'class')

Here carTest is the unseen data to predict on.

Now I am creating a confusion matrix:

confusionMatrix(carTest$unacc,predict_unseen)

I get an error:

confusionMatrix(carTest$unacc,predict_unseen)

Error in confusionMatrix.default(carTest$unacc, predict_unseen) : the data cannot have more levels than the reference

Tags: r

Solution


library(rpart)
library(imptree)
data(carEvaluation)

table(carEvaluation$acceptance)
> table(carEvaluation$acceptance)

  acc  good unacc vgood 
  384    69  1210    65

Note that unacc is only one of the classes of the acceptance attribute, not a variable in its own right.

So you could do something like this:

{set.seed(3456)
  train <- caret::createDataPartition(carEvaluation$acceptance, p = .8, # partition 80%~20%
                                      list = FALSE)
  carTrain <- carEvaluation[train,]
  carTest  <- carEvaluation[-train,]
  fit <- rpart::rpart(acceptance~., data = carTrain, method = 'class')
}
df <- data.frame(obs  = carTest$acceptance,
                 pred = predict(fit, newdata = carTest, type = "class"))
# predictions go in `data`, observed classes in `reference`
cfm <- caret::confusionMatrix(data = df$pred, reference = df$obs)
cfm
> cfm
Confusion Matrix and Statistics

          Reference
Prediction acc good unacc vgood
     acc    70    0    10     2
     good    5   12     1     0
     unacc   1    0   231     0
     vgood   0    1     0    11

Overall Statistics

               Accuracy : 0.9419          
                 95% CI : (0.9116, 0.9641)
    No Information Rate : 0.7035          
    P-Value [Acc > NIR] : < 2.2e-16       

                  Kappa : 0.8762          
 Mcnemar's Test P-Value : NA              

Statistics by Class:

                     Class: acc Class: good Class: unacc Class: vgood
Sensitivity              0.9211     0.92308       0.9545      0.84615
Specificity              0.9552     0.98187       0.9902      0.99698
Pos Pred Value           0.8537     0.66667       0.9957      0.91667
Neg Pred Value           0.9771     0.99693       0.9018      0.99398
Prevalence               0.2209     0.03779       0.7035      0.03779
Detection Rate           0.2035     0.03488       0.6715      0.03198
Detection Prevalence     0.2384     0.05233       0.6744      0.03488
Balanced Accuracy        0.9381     0.95248       0.9724      0.92157

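You don't necessarily need to write your code exactly as in this example. I suggest reviewing the documentation of the caret package and improving your rpart code. Alternatively, you could provide a fully reproducible example.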
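If the error persists in your own data, it may help to compare the levels of the two factors directly, since the message says the predicted factor has more levels than the observed one. The following is only a minimal base-R sketch: the toy vectors `obs` and `pred` are made up for illustration and stand in for your observed classes and your predictions.

```r
# Hypothetical toy vectors (not from the car data) standing in for
# the observed classes and the predicted classes
obs  <- factor(c("acc", "unacc", "unacc", "good"))
pred <- factor(c("acc", "unacc", "acc", "vgood"))  # has a level that obs lacks

# Inspect the mismatch: levels present in one factor but not the other
setdiff(levels(pred), levels(obs))
setdiff(levels(obs), levels(pred))

# Give both factors the same level set (the union of the two)
all_levels <- union(levels(obs), levels(pred))
obs  <- factor(obs,  levels = all_levels)
pred <- factor(pred, levels = all_levels)

# Base-R cross table: rows are predictions, columns are observed classes
tab <- table(Prediction = pred, Reference = obs)
print(tab)
```

Once both factors share the same levels, a call such as caret::confusionMatrix(data = pred, reference = obs) should no longer trip the level check. Note that the predictions are passed as data and the observed classes as reference.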

