How to set up a grid search for a neural network in R

Problem description

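I am trying to find the best parameters for a neural network I want to build in R. I am using the h2o package and following the tutorial at https://www.kaggle.com/wti200/deep-neural-network-parameter-search-r/comments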

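My code finishes in about 1 minute. As I understand it, a grid search should train many models until it identifies the best parameters, so it should take a while to run. Please let me know what is going wrong and how to run the grid search so that it actually optimizes my parameters.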

library(h2o)
h2o.init(nthreads=-1,max_mem_size='6G')
testHex = as.h2o(test)
trainHex = as.h2o(training)

predictors <-colnames(training)[!(colnames(training) %in% c("responseVar"))]
response = "responseVar"

hyper_params <- list(
  activation=c("Rectifier","Tanh","Maxout","RectifierWithDropout","TanhWithDropout","MaxoutWithDropout"),
  hidden=list(c(20,20),c(50,50),c(75,75),c(100,100),c(30,30,30),c(25,25,25,25)),
  input_dropout_ratio=c(0,0.03,0.05),
  #rate=c(0.01,0.02,0.05),
  l1=seq(0,1e-4,1e-6),
  l2=seq(0,1e-4,1e-6)
)
h2o.rm("dl_grid_random")

search_criteria = list(strategy = "RandomDiscrete", max_runtime_secs = 360, max_models = 100, seed=1234567, stopping_rounds=5, stopping_tolerance=1e-2)
dl_random_grid <- h2o.grid(
  algorithm="deeplearning",
  grid_id = "dl_grid_random",
  training_frame=trainHex,
  x=predictors, 
  y=response,
  epochs=1,
  stopping_metric="RMSE",
  stopping_tolerance=1e-2,        ## stop when RMSE does not improve by >=1% for 2 scoring events
  stopping_rounds=2,
  score_validation_samples=10000, ## downsample validation set for faster scoring
  score_duty_cycle=0.025,         ## don't score more than 2.5% of the wall time
  max_w2=10,                      ## can help improve stability for Rectifier
  hyper_params = hyper_params,
  search_criteria = search_criteria
)                            

grid <- h2o.getGrid("dl_grid_random",sort_by="mae",decreasing=FALSE)
grid

grid@summary_table[1,]
best_model <- h2o.getModel(grid@model_ids[[1]]) ## model with lowest MAE (grid sorted by "mae", ascending)
best_model

Tags: r, neural-network, h2o, grid-search, hyperparameters

Solution


You have set max_runtime_secs = 360 in the grid's search_criteria, so the longest the grid search can run is 6 minutes.
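To put the 6-minute cap (and max_models = 100) in perspective, it may help to count how many combinations the hyper_params list in the question defines. The sketch below rebuilds that list in plain R (no H2O cluster is needed) and multiplies the number of candidate values per hyperparameter. The helper names n_values, grid_size and max_models are only for illustration. RandomDiscrete samples from this space and stops at whichever limit (max_models, max_runtime_secs, or grid-level early stopping) is hit first, so only a tiny fraction of the grid is ever tried.

```r
# Rebuild the question's hyper_params list (plain R, no h2o needed)
hyper_params <- list(
  activation = c("Rectifier", "Tanh", "Maxout",
                 "RectifierWithDropout", "TanhWithDropout", "MaxoutWithDropout"),
  hidden = list(c(20, 20), c(50, 50), c(75, 75), c(100, 100),
                c(30, 30, 30), c(25, 25, 25, 25)),
  input_dropout_ratio = c(0, 0.03, 0.05),
  l1 = seq(0, 1e-4, 1e-6),  # 101 values from 0 to 1e-4 in steps of 1e-6
  l2 = seq(0, 1e-4, 1e-6)
)

# Number of candidate values for each hyperparameter
n_values <- vapply(hyper_params, length, integer(1))
print(n_values)

# Size of the full Cartesian grid
grid_size <- prod(n_values)
cat("Full grid size:", format(grid_size, big.mark = ","), "\n")

# Upper bound on the fraction of the grid a 100-model random search can cover
max_models <- 100
cat("Fraction explored at most:", max_models / grid_size, "\n")
```

Most of the size comes from the two fine-grained l1/l2 sequences; a coarser set of values for them would shrink the space considerably.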

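If the grid stops before the 6-minute limit, your early stopping settings are triggering the grid to stop early. If you want it to run longer, you can increase stopping_rounds on the grid and/or decrease stopping_tolerance; the search will then be less sensitive to early stopping and run longer (in the code above you have them set to 5 and 1e-2, respectively). You may also want to set stopping_metric = "RMSE" in search_criteria, since the default for regression is mean residual deviance ("deviance"). See the user guide for more information on stopping metrics.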
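As a sketch of what those changes could look like, here is one possible search_criteria list. The specific numbers (3600 seconds, 10 rounds, 1e-3 tolerance) are illustrative assumptions, not tuned recommendations; the keys are the ones the question already uses, plus stopping_metric.

```r
# Illustrative search_criteria (values are assumptions, not tuned):
# - max_runtime_secs raised so the time cap is no longer the binding limit
# - stopping_rounds raised and stopping_tolerance lowered so grid-level
#   early stopping is less eager to end the search
# - stopping_metric set explicitly to RMSE instead of the regression default
search_criteria <- list(
  strategy = "RandomDiscrete",
  max_runtime_secs = 3600,
  max_models = 100,
  seed = 1234567,
  stopping_metric = "RMSE",
  stopping_rounds = 10,
  stopping_tolerance = 1e-3
)
str(search_criteria)

# Then pass it to h2o.grid() exactly as in the question, e.g.:
# dl_random_grid <- h2o.grid(algorithm = "deeplearning", ...,
#                            hyper_params = hyper_params,
#                            search_criteria = search_criteria)
```

Building the list needs no running cluster; only the commented-out h2o.grid() call does.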

I also notice that you have stopping_rounds = 2 for the individual DNN models (you may want to try increasing it if your models are underfitting).
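To make the interaction between stopping_rounds and stopping_tolerance more concrete, here is a simplified, hypothetical illustration of moving-average early stopping for a lower-is-better metric such as RMSE. It is not H2O's exact implementation; it only mimics the documented idea (stop when the moving average of the metric over stopping_rounds scoring events stops improving by at least the relative stopping_tolerance). The function should_stop_early and the rmse series are made up for this example. For an individual model the scores would be successive scoring events; for the grid they would roughly be the metric of successive models.

```r
# Hypothetical, simplified early-stopping rule (NOT H2O's exact code).
# scores: metric history, lower is better (e.g. RMSE per scoring event)
should_stop_early <- function(scores, stopping_rounds, stopping_tolerance) {
  k <- stopping_rounds
  # Need enough history for k + 1 moving averages of length k
  if (length(scores) < 2 * k) return(FALSE)
  # Moving averages over every complete window of length k
  ma <- vapply(seq_len(length(scores) - k + 1),
               function(i) mean(scores[i:(i + k - 1)]),
               numeric(1))
  n <- length(ma)
  reference   <- ma[n - k]                # moving average just before the last k
  best_recent <- min(ma[(n - k + 1):n])   # best of the last k moving averages
  relative_improvement <- (reference - best_recent) / reference
  relative_improvement < stopping_tolerance
}

# A made-up RMSE history that improves quickly, then flattens out
rmse <- c(1.00, 0.90, 0.85, 0.84, 0.839, 0.838, 0.8375, 0.837)

should_stop_early(rmse, stopping_rounds = 2, stopping_tolerance = 1e-2)  # TRUE
should_stop_early(rmse, stopping_rounds = 2, stopping_tolerance = 1e-3)  # FALSE
should_stop_early(rmse, stopping_rounds = 5, stopping_tolerance = 1e-2)  # FALSE
```

In this toy series the last relative improvement is about 0.15%, so a 1% tolerance stops training while a 0.1% tolerance keeps going; with stopping_rounds = 5 there are not yet enough scoring events (8 < 10) for the rule to fire at all.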

