首页 > 解决方案 > 在 Rcpp 中查找向量中所有最大值/最小值的索引

问题描述

假设我有一个向量

v = c(1,2,3)

我可以很容易地找到哪个元素是最大值

cppFunction('int which_maxCpp(NumericVector v) {
  int z = which_max(v);
  return z;
}')

which_maxCpp(v)

2

但是,如果我有一个向量,例如

v2 = c(1,2,3,1,2,3)

得到

which_maxCpp(v2)

2

而我应该发现索引 2 和索引 5(或索引 3 和索引 6,如果使用 1 索引)等于向量中的最大值

有没有办法让 which_max (或 which_min )找到向量的所有最小/最大元素的索引,或者是否需要另一个(我假设本机 C++)函数?

标签: rmaxrcpp

解决方案


我不知道原生函数,但循环编写起来相当简单。

这里有三个版本。

两个找到Rcpp::max()向量的,然后找到与这个最大值匹配的向量的索引。一个使用预分配Rcpp::IntegerVector()来存储结果,然后将其作为子集以删除额外的“未使用”零。另一个使用 astd::vector< int >和 a.push_back()来存储结果。

library(Rcpp)

cppFunction('IntegerVector which_maxCpp1(NumericVector v) {
  double m = Rcpp::max(v);
  Rcpp::IntegerVector res( v.size() );  // pre-allocate result vector

  int i;
  int counter = 0;
  for( i = 0; i < v.size(); ++i) {
    if( v[i] == m ) {
      res[ counter ] = i;
      counter++;
    }
  }
  counter--;
  Rcpp::Range rng(0, counter);  
  return res[rng];
}')

v = c(1,2,3,1,2,3)

which_maxCpp(v)
# [1] 2 5
cppFunction('IntegerVector which_maxCpp2(NumericVector v) {
  double m = Rcpp::max(v);
  std::vector< int > res;

  int i;
  for( i = 0; i < v.size(); ++i) {
    if( v[i] == m ) {
      res.push_back( i );
    }
  }
  Rcpp::IntegerVector iv( res.begin(), res.end() );
  return iv;
}')

which_maxCpp(v)
# [1] 2 5

第三个选项通过找到最大值并同时跟踪一个循环中的索引来避免对向量的双重传递。

cppFunction('IntegerVector which_maxCpp3(NumericVector v) {

  double current_max = v[0];
  int n = v.size();
  std::vector< int > res;
  res.push_back( 0 );
  int i;

  for( i = 1; i < n; ++i) {
    double x = v[i];
    if( x > current_max ) {
      res.clear();
      current_max = x;
      res.push_back( i );
    } else if ( x == current_max ) {
      res.push_back( i );
    }
  }
  Rcpp::IntegerVector iv( res.begin(), res.end() );
  return iv;
}')

基准测试

以下是一些基准,展示了这些函数如何与基本 R 方法叠加。

library(microbenchmark)

x <- sample(1:100, size = 1e6, replace = T)

microbenchmark(
  iv = { which_maxCpp1(x) },
  stl = { which_maxCpp2(x) },
  max = { which_maxCpp3(x) },
  r = { which( x == max(x)) } 
)

# Unit: milliseconds
# expr      min        lq      mean    median       uq        max neval
#   iv 6.638583 10.617945 14.028378 10.956616 11.63981 165.719783   100
#  stl 6.830686  9.506639  9.787291  9.744488 10.17247  11.275061   100
#  max 3.161913  5.690886  5.926433  5.913899  6.19489   7.427020   100
#    r 4.044166  5.558075  5.819701  5.719940  6.00547   7.080742   100

推荐阅读