首页 > 解决方案 > 如何在 r 中的 rpart() 中关闭 k 折交叉验证

问题描述

我有比特币时间序列,我使用 11 个技术指标作为特征,我想将回归树拟合到数据中。据我所知,r中有两个函数可以创建回归树,即rpart()和tree(),但这两个函数似乎都不合适。rpart() 使用 k 折交叉验证来验证最优成本复杂度参数 cp,而在 tree() 中,无法指定 cp 的值。

我知道 cv.tree() 通过交叉验证寻找 cp 的最佳值,但同样, cv.tee() 使用 k 折交叉验证。由于我有时间序列,因此有时间依赖性,我不想使用 k 折交叉验证,因为 k 折交叉验证会将数据随机分成 k 折,将模型拟合到 k-1 折并计算左边第k折的MSE,然后我的时间序列的序列显然被破坏了。

我找到了 rpart() 函数的一个参数,即 xval,它应该让我指定交叉验证的数量,但是当我查看当 xval=0 时的 rpart() 函数调用的输出时,它没有似乎交叉验证已关闭。下面你可以看到我的函数调用和输出:

tree.model= rpart(Close_5~ M+ DSMA+ DWMA+ DEMA+ CCI+ RSI+ DKD+ R+ FI+ DVI+ 
OBV, data= train.subset, method= "anova", control= 
rpart.control(cp=0.01,xval= 0, minbucket = 5))

> summary(tree.model)
Call:
rpart(formula = Close_5 ~ M + DSMA + DWMA + DEMA + CCI + RSI + 
DKD + R + FI + DVI + OBV, data = train.subset, method = "anova", 
control = rpart.control(cp = 0.01, xval = 0, minbucket = 5))
n= 590 

           CP nsplit rel error
1  0.35433076      0 1.0000000
2  0.10981049      1 0.6456692
3  0.06070669      2 0.5358587
4  0.04154720      3 0.4751521
5  0.02415633      5 0.3920576
6  0.02265346      6 0.3679013
7  0.02139752      8 0.3225944
8  0.02096500      9 0.3011969
9  0.02086543     10 0.2802319
10 0.01675277     11 0.2593665
11 0.01551861     13 0.2258609
12 0.01388126     14 0.2103423
13 0.01161287     15 0.1964610
14 0.01127722     16 0.1848482
15 0.01000000     18 0.1622937

似乎 rpart() 交叉验证了 15 个不同的 cp 值。如果这些值是用 k 折交叉验证来测试的,那么我的时间序列的顺序又会被破坏,我基本上不能使用这些结果。有谁知道我如何有效地关闭 rpart() 中的交叉验证,或者如何改变 tree() 中 cp 的值?

更新:我听从了我们一位同事的建议并设置了 xval=1,但这似乎并没有解决问题。当 xval=1 here时,您可以看到完整的函数输出。顺便说一句,parameters[j] 是参数向量的第 j 个元素。当我调用这个函数时,parameters[j]= 0.0009765625

提前谢谢了

标签: rtreecross-validationrpart

解决方案


为了证明rpart()通过迭代cp与重采样的递减值来创建树节点,我们将使用包中的Ozone数据mlbench来比较 OP 的结果,rpart()caret::train()在对 OP 的评论中讨论。我们将设置臭氧数据,如支持向量机的 CRAN 文档中所示,它支持非线性回归并且与rpart().

library(rpart)
library(caret)
data(Ozone, package = "mlbench")
# split into test and training
index <- 1:nrow(Ozone)
set.seed(01381708)
testIndex <- sample(index, trunc(length(index) / 3))
testset <- na.omit(Ozone[testIndex,-3])
trainset <- na.omit(Ozone[-testIndex,-3])


# rpart version
set.seed(95014) #reset seed to ensure sample is same as caret version
rpart.model <- rpart(V4 ~ .,data = trainset,xval=0)
# summary(rpart.model)
# calculate RMSE
rpart.pred <- predict(rpart.model, testset[,-3])
crossprod(rpart.pred - testset[,3]) / length(testIndex)

...以及 RMSE 计算的输出:

> crossprod(rpart.pred - testset[,3]) / length(testIndex)
         [,1]
[1,] 18.25507

接下来,我们将caret::train()按照对 OP 的评论中的建议进行相同的分析。

# caret version
set.seed(95014)
rpart.model <- caret::train(x = trainset[,-3],
                            y = trainset[,3],method = "rpart", trControl = trainControl(method = "none"), 
                            metric = "RMSE", tuneGrid = data.frame(cp=0.01), 
                            preProcess = c("center", "scale"), xval = 0, minbucket = 5)
# summary(rpart.model)
# demonstrate caret version did not do resampling
rpart.model
# calculate RMSE, which matches RMSE from rpart() 
rpart.pred <- predict(rpart.model, testset[,-3])
crossprod(rpart.pred - testset[,3]) / length(testIndex)

当我们从中打印模型输出时,caret::train()它清楚地指出没有重新采样。

> rpart.model
CART 

135 samples
 11 predictor

Pre-processing: centered (9), scaled (9), ignore (2) 
Resampling: None

caret::train()版本的 RMSE 与rpart().

> # calculate RMSE, which matches RMSE from rpart() 
> rpart.pred <- predict(rpart.model, testset[,-3])
> crossprod(rpart.pred - testset[,3]) / length(testIndex)
         [,1]
[1,] 18.25507
> 

结论

首先,如上面配置的那样,重采样也不是,caret::train()也不rpart()是重采样。但是,如果打印模型输出,则会看到多个值cp用于通过这两种技术生成 47 个节点的最终树。

插入符号的输出summary(rpart.model)

          CP nsplit rel error
1 0.58951537      0 1.0000000
2 0.08544094      1 0.4104846
3 0.05237152      2 0.3250437
4 0.04686890      3 0.2726722
5 0.03603843      4 0.2258033
6 0.02651451      5 0.1897648
7 0.02194866      6 0.1632503
8 0.01000000      7 0.1413017

rpart 的输出summary(rpart.model)

          CP nsplit rel error
1 0.58951537      0 1.0000000
2 0.08544094      1 0.4104846
3 0.05237152      2 0.3250437
4 0.04686890      3 0.2726722
5 0.03603843      4 0.2258033
6 0.02651451      5 0.1897648
7 0.02194866      6 0.1632503
8 0.01000000      7 0.1413017

month其次,两个模型都通过将和day变量作为自变量来考虑时间值。在Ozone数据集中,V1是月变量,V2是日变量。所有数据都是在 1976 年收集的,因此数据集中不包含年份变量,并且在svm小插图的原始分析中,在分析之前删除了星期几。

第三,为了使用算法来解释其他基于时间的影响,rpart()或者svm()当日期属性不用作模型中的特征时,必须将滞后效应作为模型中的特征包括在内,因为这些算法不直接考虑时间分量。如何使用一系列滞后值使用回归树集合来执行此操作的一个示例是用于时间序列预测的集合回归树


推荐阅读