Tidymodels: Plotting predicted vs. true values using collect_predictions() and ggplot() in R

Problem description

Overview

I built four models using the tidymodels package with the data frame FID (see below):

  1. General linear model
  2. Bagged trees
  3. Random forest
  4. Boosted trees

The data frame contains three predictors:

  1. Year (numeric)
  2. Month (factor)
  3. Days (numeric)

The dependent variable is Frequency (numeric).

I am following this tutorial:

https://smltar.com/mlregression.html#firstregressionevaluation

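Issue

I would like to plot quantitative estimates of how well my models perform, and to see whether these values can be compared across the different types of models. For all of the resampled datasets, I would like to visualize how the predicted Frequency (the dependent variable) compares with the true Frequency (see the example of the plot I want further down).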
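As a rough illustration of comparing resampled metrics across model types, the sketch below fits two models on the same cross-validation folds and stacks their collect_metrics() output into one table. It is only a sketch under stated assumptions: the built-in mtcars data stands in for FID, mpg stands in for Frequency, and the two models (an lm linear regression and an rpart decision tree) are stand-ins for the four models listed above.

```r
library(tidymodels)

set.seed(123)
# Assumption: mtcars is only a stand-in for FID; mpg plays the role of Frequency
folds <- vfold_cv(mtcars, v = 5)

# Two model types, each wrapped in a workflow with the same formula
wf_lm <- workflow() %>%
  add_formula(mpg ~ wt + hp) %>%
  add_model(linear_reg() %>% set_engine("lm"))

wf_tree <- workflow() %>%
  add_formula(mpg ~ wt + hp) %>%
  add_model(decision_tree(mode = "regression") %>% set_engine("rpart"))

# Fit both workflows on the SAME folds so their metrics are computed
# on identical held-out data
fit_lm   <- fit_resamples(wf_lm, folds, metrics = metric_set(rmse))
fit_tree <- fit_resamples(wf_tree, folds, metrics = metric_set(rmse))

# Stack the averaged metrics, labelling each row with its model
comparison <- bind_rows(
  lm   = collect_metrics(fit_lm),
  tree = collect_metrics(fit_tree),
  .id  = "model"
)
comparison
```

Because both workflows are evaluated on the same folds, the RMSE values in the two rows refer to the same held-out observations, which is what makes a side-by-side comparison reasonable.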

I am getting this error message:

Error in FUN(X[[i]], ...) : object 'Frequency' not found

If anyone can help me resolve this error, I would be very grateful.

Many thanks in advance.

R code

##Open the tidymodels package
library(tidymodels)
library(glmnet)
library(parsnip)
library(rpart.plot)
library(rpart)
library(tidyverse) # manipulating data
library(skimr) # summary statistics
library(baguette) # bagged trees
library(future) # parallel processing & decrease computation time
library(xgboost) # boosted trees
library(ranger)
library(yardstick)
library(purrr)
library(forcats)
library(ggplot2) # plotting (the package is named ggplot2, not ggplot)

#split this single dataset into two: a training set and a testing set
data_split <- initial_split(FID)
# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)

# resample the data with 10-fold cross-validation (10-fold by default)
cv <- vfold_cv(train_data, v=10)

###########################################################
##Produce the recipe

rec <- recipe(Frequency ~ ., data = train_data) %>% # define the recipe on the training set only
          step_nzv(all_predictors(), freq_cut = 0, unique_cut = 0) %>% # remove variables with zero variances
          step_novel(all_nominal()) %>% # prepares test data to handle previously unseen factor levels 
          step_impute_median(all_numeric(), -all_outcomes(), -has_role("id vars"))  %>% # replaces missing numeric observations with the median (step_medianimpute() is the deprecated name)
          step_dummy(all_nominal(), -has_role("id vars")) # dummy codes categorical variables

##########################################################
##Produce Models
##########################################################
##General Linear Models
##########################################################

##Produce the glm model
mod_glm<-linear_reg(mode="regression",
                       penalty = 0.1, 
                       mixture = 1) %>% 
                            set_engine("glmnet")

##Create workflow
wflow_glm <- workflow() %>% 
                add_recipe(rec) %>%
                      add_model(mod_glm)

##Fit the model

###########################################################################
##Estimate how well that model performs, let’s fit many times, 
##once to each of these resampled folds, and then evaluate on the heldout 
##part of each resampled fold.
##########################################################################
plan(multisession)

fit_glm <- fit_resamples(
                        wflow_glm,
                        cv,
                        metrics = metric_set(rmse, rsq),
                        control = control_resamples(save_pred = TRUE)
                        )
##Collect model predictions for each fold for the number of blue whale sightings

Predictions<-fit_glm %>% 
                  collect_predictions()

Predictions

##Plot the predicted and true values 
fit_glm %>%
      collect_predictions() %>%
      ggplot(aes(Frequency, .pred, color = id)) +
      geom_abline(lty = 2, color = "gray80", size = 1.5) +
      geom_point(alpha = 0.3) +
      labs(
      x = "Truth",
      y = "Predicted frequency",
      color = NULL,
      title = "Predicted and True Frequency",
      subtitle = "Each Cross-validation Fold is Shown in a Different Color"
      )
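One way to narrow down an "object 'Frequency' not found" error from ggplot() is to check which columns collect_predictions() actually returns before mapping them in aes(). The sketch below is a simplified, self-contained version of the pipeline above, not a confirmed fix: it rebuilds FID from the posted values, and as assumptions made purely for brevity it swaps glmnet for the lm engine, uses only Year and Days as predictors, and uses 5 folds on the full data.

```r
library(tidymodels)

# Rebuild FID from the values posted in the question
FID <- data.frame(
  Year = rep(c(2015, 2016, 2017), each = 12),
  Month = factor(rep(month.name, times = 3), levels = month.name),
  Frequency = c(36, 28, 39, 46, 5, 0, 0, 22, 10, 15, 8, 33,
                33, 29, 31, 23, 8, 9, 7, 40, 41, 41, 30, 30,
                44, 37, 41, 42, 20, 0, 7, 27, 35, 27, 43, 38),
  Days = c(31, 28, 31, 30, 6, 0, 0, 29, 15, 29, 29, 31,
           31, 29, 30, 30, 7, 0, 7, 30, 30, 31, 30, 27,
           31, 28, 30, 30, 21, 0, 7, 26, 29, 27, 29, 29)
)

set.seed(123)
folds <- vfold_cv(FID, v = 5)

# Assumption: lm engine and two predictors, chosen only to keep the sketch small
wf <- workflow() %>%
  add_formula(Frequency ~ Year + Days) %>%
  add_model(linear_reg() %>% set_engine("lm"))

fit_cv <- fit_resamples(
  wf,
  folds,
  control = control_resamples(save_pred = TRUE) # keep out-of-fold predictions
)

preds <- collect_predictions(fit_cv)

# The outcome column keeps its name from the data, so it should appear here;
# if it does not, aes(Frequency, ...) cannot find it
names(preds)

# Plot truth against prediction, one color per fold
p <- ggplot(preds, aes(Frequency, .pred, color = id)) +
  geom_abline(lty = 2, color = "gray80") +
  geom_point(alpha = 0.3)
p
```

If names() on the real predictions does not list Frequency, the name to use in aes() is whatever the outcome column is called in the data that was passed to the resampling step.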

Desired plot

[Image: example of the desired plot]

Data frame - FID

structure(list(Year = c(2015, 2015, 2015, 2015, 2015, 2015, 2015, 
2015, 2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016, 2016, 2016, 
2016, 2016, 2016, 2016, 2016, 2016, 2017, 2017, 2017, 2017, 2017, 
2017, 2017, 2017, 2017, 2017, 2017, 2017), Month = structure(c(1L, 
2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 1L, 2L, 3L, 4L, 
5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 
8L, 9L, 10L, 11L, 12L), .Label = c("January", "February", "March", 
"April", "May", "June", "July", "August", "September", "October", 
"November", "December"), class = "factor"), Frequency = c(36, 
28, 39, 46, 5, 0, 0, 22, 10, 15, 8, 33, 33, 29, 31, 23, 8, 9, 
7, 40, 41, 41, 30, 30, 44, 37, 41, 42, 20, 0, 7, 27, 35, 27, 
43, 38), Days = c(31, 28, 31, 30, 6, 0, 0, 29, 15, 
29, 29, 31, 31, 29, 30, 30, 7, 0, 7, 30, 30, 31, 30, 27, 31, 
28, 30, 30, 21, 0, 7, 26, 29, 27, 29, 29)), row.names = c(NA, 
-36L), class = "data.frame")

Tags: r, machine-learning, ggplot2, regression, tidymodels
