首页 > 解决方案 > 即使在 R 中进行记忆,递归也很慢

问题描述

我正在尝试解决Project Euler 的问题 #14。所以主要目标是找到 Collat​​z 序列的长度。

首先,我用常规循环解决了问题:

compute <- function(n) {
    result <- 0
    max_chain <- 0
    hashmap <- 1
    for (i in 1:n) {
        chain <- 1
        number <- i
        while (number > 1) {
            if (!is.na(hashmap[number])) {
                chain <- chain + hashmap[number]
                break
            }
            if (number %% 2 == 0) {
                chain <- chain + 1
                number <- number / 2
            } else {
                chain <- chain + 2
                number <- (3 * number + 1) / 2
            }
        }
        hashmap[i] <- chain
        if (chain > max_chain) {
            max_chain <- chain
            result <- i
        }
    }
    return(result)
}

仅 2 秒n = 1000000。我决定将while循环替换为递归

len_collatz_chain <- function(n, hashmap) {
    get_len <- function(n) {
        if (is.na(hashmap[n])) {
            hashmap[n] <<- ifelse(n %% 2 == 0, 1 + get_len(n / 2), 2 + get_len((3 * n + 1) / 2))
        }
        return(hashmap[n])
    }
    get_len(n)
    return(hashmap)
}

compute <- function(n) {
    result <- 0
    max_chain <- 0
    hashmap <- 1
    for (i in 1:n) {
        hashmap <- len_collatz_chain(i, hashmap)
        print(length(hashmap))
        if (hashmap[i] > max_chain) {
            max_chain <- hashmap[i]
            result <- i
        }
    }
    return(result)
}

该解决方案有效,但运行速度很慢。差不多1分钟n = 10000。我想原因之一是 Rhashmap每次调用函数时都会创建对象len_collatz_chain

我知道 Rcpp 包,是的,第一个解决方案工作正常,但我不明白我错在哪里。有小费吗?

例如,我的 Python 递归解决方案可在 1 秒内使用n = 1000000

def len_collatz_chain(n: int, hashmap: dict) -> int:
    if n not in hashmap:
        hashmap[n] = 1 + len_collatz_chain(n // 2, hashmap) if n % 2 == 0 else 2 + len_collatz_chain((3 * n + 1) // 2, hashmap)
    return hashmap[n]

def compute(n: int) -> int:
    result, max_chain, hashmap = 0, 0, {1: 1}
    for i in range(2, n):
        chain = len_collatz_chain(i, hashmap)
        if chain > max_chain:
            result, max_chain = i, chain
    return result

标签: rrecursionmemoizationcollatz

解决方案


R 和 Python 代码之间的主要区别在于,在 R 中,您使用向量作为哈希图,而在 Python 中,您使用字典,并且哈希图作为函数参数多次传输。

在 Python 中,如果您有一个 Dictionary 作为函数参数,那么只有对实际数据的引用才会传输到被调用的函数。这很快。被调用函数与调用者处理相同的数据。

在 R 中,向量在用作函数参数时会被复制。这可能会很慢,但在被调用函数无法更改调用者数据的意义上更安全。

这是 Python 在您的代码中速度如此之快的主要原因。

但是,您可以稍微更改 R 代码,以便不再将 hashmap 作为函数参数传输:

len_collatz_chain <- local({
  
  hashmap <- 1L
  
  get_len <- function(n) {
    if (is.na(hashmap[n])) {
      hashmap[n] <<- ifelse(n %% 2 == 0, 1 + get_len(n / 2), 2 + get_len((3 * n + 1) / 2))
    }
    hashmap[n]
  }
  
  get_len
})


compute <- function(n) {
  result <- rep(NA_integer_, n)
  
  for (i in seq_len(n)) {
    result[i] <- len_collatz_chain(i)
  }
  result
}

compute(n=10000)

这使得 R 代码更快。(尽管 Python 可能会更快)。

请注意,我还删除了returnR 代码中的语句,因为它们不是必需的,并在调用堆栈中添加了一层。


推荐阅读