首页 > 解决方案 > 关于 R 中的 K 折交叉验证

问题描述

我创建了这个代码的函数来执行逻辑回归的 5 折交叉验证。

  require(ISLR)
    folds <- cut(seq(1,nrow(Smarket)),breaks=5,labels=FALSE)



    log_cv=sapply(1:5,function(x)
    {
      set.seed(123)           

      testIndexes <- which(folds==x,arr.ind=TRUE)
      testData <- Smarket[testIndexes, ]
      trainData <- Smarket[-testIndexes, ]
      glm_log=glm(Direction ~ Lag1 + Lag2 + Lag3 + 
Lag4 + Lag5 + Volume ,family = "binomial",  data = trainData)
      glm.prob <- predict(glm_log, testData, "response")
      glm.pred <- ifelse(glm.prob >= 0.5, 1, 0)
      return(glm.pred)

    }
    )

上述函数的输出给出了每一折的预测值。

> head(log_cv)
  [,1] [,2] [,3] [,4] [,5]
1    1    1    0    1    1
2    0    1    1    1    1
3    0    1    1    1    0
4    1    1    0    1    1
5    1    1    1    1    1
6    1    1    1    0    1
> 

有什么方法可以结合上述结果使用 5 折交叉验证得到混淆矩阵?

标签: rcross-validationconfusion-matrix

解决方案


混淆矩阵由真阳性、假阳性、真阴性、假阴性的数量组成。从交叉验证中,您需要每个折叠的平均值。您有一个预测矩阵,log_cv需要将其与您的testData.

一种方法,虽然我确信这里的其他人会推荐 tidyverse,但是将您的测试数据转换为矩阵:

truth <- matrix(testData$response, ncol = 5, nrow = nrow(testData))

然后使用逻辑运算符来查找真阳性等:

真阳性:

mean(apply(truth & testData, 2, sum))

真正的否定:

mean(apply(!truth & !testData, 2, sum))

误报:

mean(apply(truth & !testData, 2, sum))

假阴性:

mean(apply(!truth & testData, 2, sum))

推荐阅读