首页 > 解决方案 > predict.xgb.Booster "multi:softprob":理解维度矩阵列中输入类之间的关系

问题描述

我以这种方式使用predict.xgb.Booster目标"multi:softprob"和 3 个类:

library(xgboost)
data(iris)

iris$Species <- as.factor(iris$Species)

# extract 80% random samples as training set
ix <- sample(nrow(iris), 0.8 * nrow(iris))
# all
all <- xgb.DMatrix(data.matrix(iris[, 1:ncol(iris)-1]),
                   label = as.numeric(iris$Species)-1)
# training set
train <- xgb.DMatrix(data.matrix(iris[ix, 1:ncol(iris)-1]),
                     label = as.numeric(iris$Species[ix])-1)
# test set (20% of the dataset)
test <- xgb.DMatrix(data.matrix(iris[-ix, 1:ncol(iris)-1]),
                    label = as.numeric(iris$Species[-ix])-1)

params <- list(
  objective = "multi:softprob",
  learning_rate = 0.05,
  subsample = 0.9,
  colsample_bynode = 1,
  reg_lambda = 2,
  max_depth = 35,
  num_class = length(unique(iris$Species))
)

# https://www.rdocumentation.org/packages/xgboost/versions/1.4.1.1/topics/xgb.train
mod <- xgb.train(
  params,
  data = train,
  watchlist = list(valid = test),
  early_stopping_rounds = 50,
  print_every_n = 100,
  nrounds = 10000 # early stopping
)

pred <- predict(mod, newdata = all, reshape = TRUE)

使用上面的代码,pred看起来像:

            V1         V2         V3
1   0.95375967 0.02518489 0.02105547
2   0.95375967 0.02518489 0.02105547
3   0.95375967 0.02518489 0.02105547
...

我需要创建一个向量来存储每行中具有最高值的类。

例如,在上面的示例中,它将V1适用于所有三行。

我的疑问是,我怎么知道 V1 到 V3 指的是我的 3 个课程中的哪一个?

标签: rxgboostpredict

解决方案


推荐阅读