首页 > 解决方案 > 优化 R 中的矩阵计算 - 可以避免 aperm?

问题描述

我有一个矩阵计算,我想加快速度。

一些玩具数据和示例代码:

n = 2 ; d = 3
mu <- matrix(runif(n*d), nrow=n, ncol=d)
sig <- matrix(runif(n*d), nrow=n, ncol=d)
x_i <- c(0, 0, 1, 1)
not_missing <- !is.na(x_i)

calc1 <-function(n, d, mu, sig, x_i, not_missing){
    z <- array( rep(0, length(x_i)*n*d),
                dim = c(length(x_i), n, d))
    
    subtract_term <- 0.5*log(2*pi*sig)
    
    for(i in 1:length(x_i)){
        if( not_missing[i] ){
            z[i, , ] <- ((-(x_i[i] - mu)^2 / (2*sig)) - subtract_term  )
        }
    }
    z <- aperm(z, c( 2, 1, 3))
    return(z)
}

microbenchmark(
    z1 <- calc1(n, d, mu, sig, x_i, not_missing)
)

在使用真实数据进行剖析时,z[i, , ] <-线和aperm()线都是慢点。我一直在尝试aperm通过更早地转置 2D 矩阵以避免 3D 转置来优化它以避免完全调用,但是我无法正确地将 3D 数组放在一起。非常感谢任何帮助。

编辑:我有来自@G 的部分解决方案。Grothendieck 取消了 aperm,但由于某种原因,它并没有带来太多的速度提升。他回答的新解决方案是:

calc2 <-function(n, d, mu, sig, x_i, not_missing){
    nx <- length(x_i)
    z <- array( 0, dim = c(n, nx, d))
    
    subtract_term <- 0.5*log(2*pi*sig)
    
    for(i in 1:nx){
        if( not_missing[i] ) {
          z[, i, ] <- ((-(x_i[i] - mu)^2 / (2*sig)) - subtract_term  )
        }
    }
    return(z)
} 

速度比较:

> microbenchmark(
+         z1 <- calc1(n, d, mu, sig, x_i, not_missing),
+         z2 <- calc2(n, d, mu, sig, x_i, not_missing), times = 1000
+ )
Unit: microseconds
                                         expr    min      lq     mean  median     uq      max neval cld
 z1 <- calc1(n, d, mu, sig, x_i, not_missing) 13.586 14.2975 24.41132 14.5020 14.781 9125.591  1000   a
 z2 <- calc2(n, d, mu, sig, x_i, not_missing)  9.094  9.5615 19.98271  9.8875 10.202 9655.254  1000   a

标签: rmultidimensional-array

解决方案


这消除了 aperm。

calc2 <-function(n, d, mu, sig, x_i, not_missing){
    nx <- length(x_i)
    z <- array( 0, dim = c(n, nx, d))
    
    subtract_term <- 0.5*log(2*pi*sig)
    
    for(i in 1:nx){
        if( not_missing[i] ) {
          z[, i, ] <- ((-(x_i[i] - mu)^2 / (2*sig)) - subtract_term  )
        }
    }
    return(z)
} 

z1 <- calc1(n, d, mu, sig, x_i, not_missing)
z2 <- calc2(n, d, mu, sig, x_i, not_missing)

identical(z1, z2)
## [1] TRUE

推荐阅读