SVM has not been trained using `probability = TRUE`, probabilities not available for predictions

Problem description

I ran into a problem when trying to get predicted probabilities from an SVM with mlr3.

library(mlr3)
library(mlr3learners)  # provides the "classif.svm" learner
task = mlr_tasks$get("iris")
svm_learner = mlr_learners$get("classif.svm")
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

svm_learner$train(task, row_ids = task$row_ids[train_set])
svm_learner$predict_type <- "prob"
prediction <- svm_learner$predict(task, row_ids = task$row_ids[test_set])
prediction
Warning message:
In predict.svm(self$model, newdata = newdata, probability = (self$predict_type ==  :
  SVM has not been trained using `probability = TRUE`, probabilities not available for predictions.



Session info
> sessionInfo(package = NULL)
R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 17763)

Matrix products: default

Random number generation:
 RNG:     Mersenne-Twister 
 Normal:  Inversion 
 Sample:  Rounding 

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252   
[3] LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] precrec_0.10.1     forcats_0.4.0      stringr_1.4.0      purrr_0.3.3        readr_1.3.1       
 [6] tidyr_1.0.0        tibble_2.1.3       tidyverse_1.2.1    dplyr_0.8.3        mlr3learners_0.1.5
[11] GGally_1.4.0       ggplot2_3.2.1      mlr3_0.1.6         mlr3viz_0.1.0      e1071_1.7-3       
[16] biomaRt_2.38.0    

loaded via a namespace (and not attached):
 [1] Biobase_2.42.0       httr_1.4.1           bit64_0.9-7          jsonlite_1.6        
 [5] modelr_0.1.4         assertthat_0.2.1     lgr_0.3.3            stats4_3.6.2        
 [9] blob_1.2.0           cellranger_1.1.0     mlr3misc_0.1.6       progress_1.2.2      
[13] pillar_1.4.3         RSQLite_2.1.2        backports_1.1.5      lattice_0.20-38     
[17] glue_1.3.1           uuid_0.1-2           digest_0.6.23        RColorBrewer_1.1-2  
[21] checkmate_1.9.4      rvest_0.3.3          colorspace_1.4-1     plyr_1.8.5          
[25] XML_3.98-1.20        pkgconfig_2.0.3      mlr3measures_0.1.1   broom_0.5.2         
[29] haven_2.1.0          scales_1.0.0         generics_0.0.2       IRanges_2.16.0      
[33] withr_2.1.2          BiocGenerics_0.28.0  lazyeval_0.2.2       cli_2.0.0           
[37] magrittr_1.5         crayon_1.3.4         readxl_1.3.1         paradox_0.1.0       
[41] memoise_1.1.0        fansi_0.4.0          nlme_3.1-142         xml2_1.2.0          
[45] class_7.3-15         tools_3.6.2          data.table_1.12.8    prettyunits_1.0.2   
[49] hms_0.5.2            lifecycle_0.1.0      S4Vectors_0.20.1     munsell_0.5.0       
[53] AnnotationDbi_1.44.0 compiler_3.6.2       rlang_0.4.1          grid_3.6.2          
[57] RCurl_1.95-4.12      rstudioapi_0.10      bitops_1.0-6         labeling_0.3        
[61] gtable_0.3.0         DBI_1.0.0            reshape_0.8.8        reshape2_1.4.3      
[65] R6_2.4.1             lubridate_1.7.4      bit_1.1-14           zeallot_0.1.0       
[69] stringi_1.4.3        parallel_3.6.2       Rcpp_1.0.2           vctrs_0.2.1         
[73] tidyselect_0.2.5

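I know that an SVM does not output probabilities, but it can evaluate the separating-hyperplane function on the prediction data and return a signed distance from the hyperplane. I want to retrieve those signed distances and use them to compute the AUC. However, with predict_type <- "response" I only get the predicted classes, not the signed distances, and with predict_type <- "prob" I get the warning shown above.

Tags: mlr3

Solution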


Your code has the steps in the wrong order. Change it as follows:

library(mlr3)
library(mlr3learners)  # provides the "classif.svm" learner
task = mlr_tasks$get("iris")
svm_learner = mlr_learners$get("classif.svm")
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

# Set the predict type first, so the SVM is fitted with probability support
svm_learner$predict_type <- "prob"
svm_learner$train(task, row_ids = task$row_ids[train_set])
prediction <- svm_learner$predict(task, row_ids = task$row_ids[test_set])
prediction

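Note that predict_type has to be changed before training. As the warning indicates, the underlying SVM can only return probabilities if it was fitted with `probability = TRUE`, and that only happens when the learner's predict_type is already "prob" at training time.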
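
Since the question is ultimately about computing an AUC, here is a minimal sketch of how that could look once probabilities are available. It is not part of the original answer. It assumes a binary task, because "classif.auc" is only defined for two classes, so it swaps the three-class iris task for the "sonar" task that ships with mlr3. The seed and the 80/20 split are arbitrary choices.

```r
library(mlr3)
library(mlr3learners)  # provides the "classif.svm" learner

set.seed(1)  # arbitrary seed, only to make the split reproducible

# Assumption: a binary task, since classif.auc needs exactly two classes
task = mlr_tasks$get("sonar")
svm_learner = mlr_learners$get("classif.svm")

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

# Set the predict type BEFORE training
svm_learner$predict_type <- "prob"
svm_learner$train(task, row_ids = task$row_ids[train_set])
prediction <- svm_learner$predict(task, row_ids = task$row_ids[test_set])

# Score the probability predictions with the AUC measure
auc <- prediction$score(mlr_measures$get("classif.auc"))
auc
```

For a multiclass task such as iris, a plain AUC is not defined, so a multiclass generalization would be needed instead.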

