首页 > 解决方案 > 如何测试用户提供的参数是否与函数默认方法的形式参数匹配?

问题描述

我在 R 中编程,需要检查用户在列表中提供的所有参数是否都是函数params的有效参数fparams如果 的任何命名元素与 的参数名称不匹配,则该函数应返回错误f。目前实现的方式是:

if (is.null(names(params)) || any(!names(params) %in% names(formals(f)))) {
        stop("names of params must match arguments of f")
    }

我在使用函数测试这段代码时遇到了问题caret::train。该函数train只有一个形式参数x,因此names(formals(caret::train))返回c('x','...')。但是,默认 S3 方法caret::train具有额外的形式参数。如果它们只是其中一种方法的参数而不是函数本身的参数,我如何以编程方式测试用户输入的命名列表是否与函数参数匹配?这应该是适用于任何功能的通用解决方案,而不仅仅是train.

可重现的例子

my_fun <- function(f, params) {

    if (is.null(names(params)) || any(!names(params) %in% names(formals(f)))) {
        stop("names of params must match arguments of f")
    }

    do.call(f, params) 

}

library(caret)

# my_fun returns error: names of params must match arguments of f
my_fun(f = caret::train, params = list(x = data.frame(x = 1:10), y = rep(1,10), method = 'rf'))

# caret::train returns error: argument "y" is missing, with no default
my_fun(f = caret::train, params = list(x = data.frame(x = 1:10)))

标签: r

解决方案


我认为这并不容易。您可以检查 S3,就像我在下面所做的那样,这主要是有效的。但是,这并没有考虑 R 拥有的其他 OOP 系统。而且,即使这个实现也不是那么稳定,例如,如果您使用符号,这将不起作用::(即train有效但caret::train不有效)。最后一点是因为getS3method不适用于::符号,不知道为什么。

my_fun <- function(f, params, env = parent.frame()) {
  # check for S3 generic
  if (isS3stdGeneric(f)) {
    s <- deparse(substitute(f))
    dispatch_arg <- formalArgs(f)[1]
    classes_to_check <- c(class(params[[dispatch_arg]]), 'default')
    for (i in seq_along(classes_to_check)) {
      f <- getS3method(s, classes_to_check[i], optional = TRUE, parent.frame(n = 2))
      if (is.function(f)) break
    }
  }
  if (is.null(names(params)) || !all(names(params) %in% formalArgs(f))) {
    stop("names of params must match arguments of f", call. = FALSE)
  }
  do.call(f, params)
}

例子:

library(caret)

my_fun(f = train, params = list(x = data.frame(x = 1:10), y = rep(1,10), method = 'rf'))
# works

my_fun(f = train, params = list(x = data.frame(x = 1:10)))
# argument "y" is missing, with no default

推荐阅读