r - 沿着数组的多个维度优化 which.max
问题描述
我有一些带有 4 维数组的代码,我需要在多个维度上应用 which.max。它很慢,我想找到加快速度的方法。
例子:
library(microbenchmark)
array4d <- array( runif(5*500*50*5 ,-1,0),
dim = c(5, 500, 50, 5) )
microbenchmark(
max_idx <- apply(array4d, c(1,2,3), which.max )
)
任何提示表示赞赏,谢谢!
编辑:我已经设法通过直接在 for 循环中编码来使它稍微快一点(虽然丑陋) - 但我希望那里有人有更好的想法!
method1 <- function(z) {
apply(z, c(1,2,3), which.max)
}
method2 <- function(z){
result <- array( , dim = dim(z)[1:3] )
for(i in 1:dim(z)[1]){
for(j in 1:dim(z)[2]){
for(k in 1:dim(z)[3]){
result[i, j, k] <- which.max(z[i,j,k,])
}
}
}
return(result)
}
microbenchmark(
result1 <- method1(array4d),
result2 <- method2(array4d))
> microbenchmark(
+ result1 <- method1(array4d),
+ result2 <- method2(array4d)
+ )
Unit: milliseconds
expr min lq mean median uq max neval cld
result1 <- method1(array4d) 111.9061 140.1400 165.2441 155.6773 170.3967 384.6425 100 b
result2 <- method2(array4d) 113.4572 123.2429 136.8583 130.8505 141.9620 215.0968 100 a
解决方案
当你想在 R 中增加一点速度并且你已经用尽了使用原生矢量化的简单收益时,你应该考虑Rcpp
.
这是一个使用 C++ 的算法的实现,它包含在对 的调用中Rcpp:cppFunction
:
Rcpp::cppFunction("
NumericVector apply_which_max(NumericVector input){
std::vector<int> dims = input.attr(\"dim\");
int last_dim = dims.back();
int diff = input.size()/last_dim;
std::vector<int> result(diff);
dims.pop_back();
for(int i = 0; i < diff; ++i)
{
double max_val = input[i];
int max_ind = i;
for(int j = i; j < input.size(); j += diff)
{
if(input[j] > max_val) {
max_val = input[j];
max_ind = j;
}
}
result[i] = max_ind / diff + 1;
}
NumericVector arr = wrap(result);
arr.attr(\"dim\") = dims;
return arr ;
}"
)
我们最好证明它具有相同的输出:
max_idx1 <- apply(array4d, c(1 ,2, 3), which.max)
max_idx2 <- apply_which_max(array4d)
all(max_idx1 == max_idx2)
#> [1] TRUE
但它更快吗?
使用基准测试(在我的慢速机器上),我们得到:
microbenchmark(apply(array4d, c(1 ,2, 3), which.max))
#> Unit: milliseconds
#> expr min lq mean median uq
#> apply(array4d, c(1, 2, 3), which.max) 243.2283 276.5381 342.5796 330.0815 403.8358
#> max neval
#> 543.8325 100
microbenchmark(apply_which_max(array4d))
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> apply_which_max(array4d) 3.1761 3.2725 3.713377 3.342 3.4088 12.962 100
也就是说,不仅更快,而且快了大约 100 倍。
推荐阅读
- amazon-cloudformation - 是否可以创建 Cloudfromation 模板以部署到 AWS EKS?
- file - 如何在 Flutter 中下载 mp3 文件并将其保存在用户设备存储中?
- java - 我如何创建 PDF 表单,而不是用 Java 中的用户生成的数据填充它,使用 iText?
- apache - 如何在 Apache 中允许多种身份验证类型?
- java - 如何转换 MultiValueMap
到地图 > 与流? - azure-table-storage - 使用逻辑应用从 Azure 表存储中获取值
- c - 二进制的无效操作数 - (有 'int' 和 'int *')
- java - scala - 打印 26*8 矩阵以获取时间戳的所有格式值
- firebase - Zapier Firebase 标记
- visual-studio - SSIS 错误 - 由于 80040153、注册表值无效,从 SQL 任务中解析查询失败