首页 > 解决方案 > 将 arma::cube 子视图转换为 NumericVector 以使用糖

问题描述

我将一个 3D 数组从 R 传递到 C++ 并遇到了类型转换问题。我们如何arma::cube subviews从 RcppArmadillo 转换NumericVectors为使用 Rcpp 中的糖函数对它们进行操作which_min

假设您有一个Q带有一些数字条目的 3D 立方体。我的目标是获取每行i和每个第三维的列条目最小值的索引k。在 R 语法中,这是which.min(Q[i,,k]).

例如对于i = 1k = 1

cube Q = randu<cube>(3,3,3);
which_min(Q.slice(1).row(1)); // this fails

我认为转换为 NumericVector 可以解决问题,但是这种转换失败了

which_min(as<NumericVector>(Q.slice(1).row(1))); // conversion failed

我怎样才能让它工作?谢谢您的帮助。

标签: c++rcpparmadillo

解决方案


你有几个选择:

  1. 您可以为此使用 Armadillo 函数,即成员函数(请参阅此处.index_min()的Armadillo 文档)。
  2. 您可以使用Rcpp::wrap()将任意对象转换为 SEXP”arma::cube subviews转换为 aRcpp::NumericVector并使用糖函数Rcpp::which_min()

最初我只是将第一个选项作为答案,因为它似乎是实现目标的更直接的方法,但我添加了第二个选项(在答案的更新中),因为我现在认为任意转换可能是什么的一部分你很好奇。

我将以下 C++ 代码放在一个文件中so-answer.cpp

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
Rcpp::List index_min_test() {
    arma::cube Q = arma::randu<arma::cube>(3, 3, 3);
    int whichmin = Q.slice(1).row(1).index_min();
    Rcpp::List result = Rcpp::List::create(Rcpp::Named("Q") = Q,
                                           Rcpp::Named("whichmin") = whichmin);
    return result;
}

// [[Rcpp::export]]
Rcpp::List which_min_test() {
    arma::cube Q = arma::randu<arma::cube>(3, 3, 3);
    Rcpp::NumericVector x = Rcpp::wrap(Q.slice(1).row(1));
    int whichmin = Rcpp::which_min(x);
    Rcpp::List result = Rcpp::List::create(Rcpp::Named("Q") = Q,
                                           Rcpp::Named("whichmin") = whichmin);
    return result;
}

我们有一个使用犰狳的函数.index_min()和一个用于Rcpp::wrap()启用Rcpp::which_min().

然后我用Rcpp::sourceCpp()它来编译它,使函数可用于 R,并演示用几个不同的种子调用它们:

Rcpp::sourceCpp("so-answer.cpp")
set.seed(1)
arma <- index_min_test()
set.seed(1)
wrap <- which_min_test()
arma$Q[2, , 2]
#> [1] 0.2059746 0.3841037 0.7176185
wrap$Q[2, , 2]
#> [1] 0.2059746 0.3841037 0.7176185
arma$whichmin
#> [1] 0
wrap$whichmin
#> [1] 0
set.seed(2)
arma <- index_min_test()
set.seed(2)
wrap <- which_min_test()
arma$Q[2, , 2]
#> [1] 0.5526741 0.1808201 0.9763985
wrap$Q[2, , 2]
#> [1] 0.5526741 0.1808201 0.9763985
arma$whichmin
#> [1] 1
wrap$whichmin
#> [1] 1
library(microbenchmark)
microbenchmark(arma = index_min_test(), wrap = which_min_test())
#> Unit: microseconds
#>  expr    min      lq     mean  median      uq    max neval cld
#>  arma 12.981 13.7105 15.09386 14.1970 14.9920 62.907   100   a
#>  wrap 13.636 14.3490 15.66753 14.7405 15.5415 64.189   100   a

reprex 包(v0.2.1)于 2018 年 12 月 21 日创建


推荐阅读