首页 > 解决方案 > 如何创建自定义模型(在插入符号中使用循环/子模型技巧)

问题描述

我与这个问题争论了很长时间。我觉得自己像个白痴,因为答案可能很明显,但我找不到一个解释如何做到这一点的线程。

关于自定义模型创建的文档部分对我来说是这样的。我觉得我在教育期间的某个地方错过了一些非常具体的课程,现在每个人都记得,但我,因为我发现的只是“是的,只需创建一个自定义模型,然后完成”。

这里的实际问题:

我想获得对 in 的每一次迭代的gbm预测caret。例如,gbm我可以只使用n.treesin predict(..., n.trees = 1:100),它就完成了。

显然,caret我需要使用称为子模型技巧的东西,这意味着 - 如果我理解正确 - 我必须创建自己的自定义模型。

但我可以看到getModelInfo('gbm'),有某种循环功能!

$gbm$loop
function (grid) 
{
    loop <- plyr::ddply(grid, c("shrinkage", "interaction.depth", 
        "n.minobsinnode"), function(x) c(n.trees = max(x$n.trees)))
    submodels <- vector(mode = "list", length = nrow(loop))
    for (i in seq(along = loop$n.trees)) {
        index <- which(grid$interaction.depth == loop$interaction.depth[i] & 
            grid$shrinkage == loop$shrinkage[i] & grid$n.minobsinnode == 
            loop$n.minobsinnode[i])
        trees <- grid[index, "n.trees"]
        submodels[[i]] <- data.frame(n.trees = trees[trees != 
            loop$n.trees[i]])
    }
    list(loop = loop, submodels = submodels)

我该如何使用它?为什么默认情况下它不起作用?我是否真的需要创建一个自定义模型 - 或者可能不需要?

免责声明 1:我不想使用任何交叉验证。对于单个 gbm 运行的每次迭代,我只想提取预测。

免责声明 2:我不想使用predict.gbm()on $finalModel,因为我还想测试一些其他算法,这些算法也利用了那个子模型技巧。我不想使用所有不同的算法特定predict()功能,因为那我为什么还要打扰插入符号。

我什至不知道我应该把什么作为可复制的例子。代码没有问题。我只是不知道这东西应该如何工作。

标签: rmachine-learningr-caretgbm

解决方案


这是一个关于如何为每棵树的测试数据提取所需预测的示例:

library(caret)
library(mlbench) #for the data set
data(Sonar) #some data set I always use on stack overflow

res <- train(Class~.,
             data = Sonar,
             method = "gbm",
             trControl = trainControl(method = "cv", #some evaluations scheme
                                      number = 5,
                                      savePredictions = "all"), #tell caret you would like to save all,
             tuneGrid = expand.grid(shrinkage = 0.01,
                                    interaction.depth = 2, 
                                    n.minobsinnode = 10,
                                    n.trees = 1:100)) #some random values and all the trees

res$pred #results are stored in here

基本上,您在帖子中显示的代码告诉插入符号不要调整所有 n.tree 模型,而只是调整max(n.trees)每个超参数组合的模型,然后使用它来获得预测n.trees < max(n.trees)

一些情节

library(ggplot2)

ggplot(res$results)+
  geom_line(aes(x = n.trees, y = Accuracy))

在此处输入图像描述

您也可以选择不savePredictions = "all"这样做,因为这会导致内存不足的火车对象。而是使用res$results它来计算所有所需的指标。


推荐阅读