首页 > 技术文章 > 水池抽样算法

r1-12king 2022-01-16 18:07 原文

引入

有这样一类问题, 就是大数据流中的随机抽样问题,即:

当内存无法加载全部数据时,如何从包含未知大小的数据流中随机选取k个数据,并且要保证每个数据被抽取到的概率相等。

这道题有两个限制:

  1. 高效,即节省内存的使用
  2. 尽量随机地返回值

假如我们去掉限制1,可以很简单地做出来:我们将所有数据加载进内存,计算链表长度,然后通过random函数来求取几个随机数。

但是这样的效率并不高,把所有数据加载到内存,如果数据非常大可能会导致无法计算。

 

注意题目中有一个说明,就是链表。链表这种数据结构是通过数据节点首尾相连形成的链式存储结构。

既然是链表,那么可以一个一个节点处理,不需要将所有数据加载到内存,一个节点一个节点去处理,这还不够形象,将题目换个形式来表述:

我们有大量的文本文件存在硬盘中,想随机抽取几行,保证尽可能少得使用内存并且能够完全随机.

之前想到的加载到内存就不太适合了,但是还可以想到别的办法,比如每次读取一行记录加载到内存,记数+1,清空内存中行数据,直到最后统计一共多少行,然后根据总行数来计算K个随机数。

如何再取回行对应的数据呢?我们可以再遍历一遍,一边遍历一边记录这一行的行数是不是在k个随机数中,如果是,则将该行内容保留。

这样的话遍历两次应该可以做到,但是数据量大的时候遍历两次的时间消耗是非常高的。

所以还有更好的方案吗,那就是水塘抽样算法。

 

 

水塘抽样算法

对于复杂问题一定要学会归纳总结,即从小例子入手,然后分析,得出结论,然后在证明。不然遇到一个抽象问题,不举例感觉这个问题,直接解还是比较难的。

接下来从 k = 1 开始说明

k = 1

首先考虑最简单的情况,当 k = 1 时,如何选取:

  • 假设数据流含有N个数据,要保证每条数据被抽取到的概率相等,那么每个数被抽取的概率应该为 1/N
    • 遇到第一个数为 n1 的时候, 保留它, p(n1) = 1 
    • 遇到第二个数为 n2 的时候,以 1/2 的概率保留它,则有 p(n1) = 1 * 1/2 = 1/2 , p(n2) = 1/2 
    • 遇到第三个数为 n3 的时候,以 1/3 的概率保留它,则有 p(n1) = 1/2 *(1-1/3) = 1/3 = p(n2),  p(n3) = 1/3。
    • ······
    • 遇到第 i 个数为 ni 的时候, 以 1/i 的概率保留它,则有 p(n1) = p(n2) = ···  = p(ni-1)  = 1/(n-1) * (1 - 1/n) = 1/n, p(ni) = 1/n

这样就可以看出,对于k=1的情况,我们可以制定这样简单的抽样策略:

在取第n个数据的时候,我们生成一个0到1的随机数p,如果p小于等于1/n,保留第n个数。大于1/n,继续保留前面的数。直到数据流结束,返回此数。

数据流中第 i 个数被保留的概率为 1/i 。只要采取这种策略,只需要遍历一遍数据流就可以得到采样值,并且保证所有数被选取的概率均为 1/n 。

下面用数学归纳法证明此结论。

1) 当n=1时,第一个元素以1/1的概率返回,符合条件。

2) 假设当n=k时成立,即每个元素都以1/k的概率返回,先证明n=k+1时,是否成立。

对于最后一个元素显然以1/k+1的概率返回,符合条件,对于前k个数据,被返回的概率为1/k * (1- 1/k+1)=1/k+1,满足题意。

综上所述,结论成立。

代码

 1 class Solution:
 2     def __init__(self, head: Optional[ListNode]):
 3         self.head = head
 4 
 5     def getRandom(self) -> int:
 6         node, i, ans = self.head, 1, 0
 7         while node:
 8             if randrange(i) == 0:  # 1/i 的概率选中(替换为答案)
 9                 ans = node.val
10             i += 1
11             node = node.next
12         return ans

ps: k=1 时代码中生成的不是概率而是区间数的原因见 k>1 的策略。

k >1

当 k>1时,即为水塘抽样。

对于k>1的情况,我们可以采用类似的思考策略:

有了k =1 的理解,我们可以直接替换结论,只需把上面的 1/n 变为 k/n 即可。策略如下:

在取第n个数据的时候,我们生成一个0到1的随机数p,如果p小于等于 k/n,替换池中任意一个为第n个数。大于k/n,继续保留前面的数。直到数据流结束,返回此k个数。

但是为了保证计算机计算分数额准确性,一般是生成一个0到n的随机数,跟k相比,道理是一样的。

可以以同样的方法证明。

1)初始情况k<=n,出现在水库中的k个元素的出现概率都是一致的,都是1

2)第一步。第一步就是指,处理第k+1个元素的情况。分两种情况:元素全部都没有被替换;其中某个元素被第k+1个元素替换掉。

我们先看情况2:第k+1个元素被选中的概率是k/(k+1)(根据公式k/i),所以这个新元素在水塘中出现的概率就一定是k/(k+1)(不管它替换掉哪个元素,反正肯定它是以这个概率出现在水塘中)。

  下面来看水塘中剩余的元素出现的概率,也就是1-P(这个元素被替换掉的概率)。水塘中任意一个元素被替换掉的概率是:(k/k+1)*(1/k)=1/(k+1),意即首先要第k+1个元素被选中,然后自己在集合的k个元素中被选中。

  那它出现的概率就是1-1/(k+1)=k/(k+1)。可以看出来,旧元素和新元素出现的概率是相等的。

情况1:当元素全部都没有替换掉的时候,每个元素的出现概率肯定是一样的,这很显然。但具体是多少呢?就是1-P(第k+1个元素被选中)=1-k/(k+1)=1/(k+1)。

3)归纳法:重复上面的过程,只要证明第i步到第i+1步,所有元素出现的概率是相等的即可。

代码

 1 vector<int> ReservoirSampling(vector<int>& results, vector<int>& nums, int k)
 2 {
 3     // results.size(): k
 4     // nums.size(): N
 5     int N = nums.size();
 6 
 7     for (int i=0; i<k; ++i) {
 8         results[i] = nums[i];
 9     }
10 
11     for (int i=k; i<N; ++i) {
12         int random = rand()%i;
13         if (random<k) {
14             results[random] = nums[i];
15         }
16     }
17 
18     return results;
19 }

 



推荐阅读