python - numpy - 有效地过滤选择具有随机约束的随机样本
问题描述
我想用 numpy 获得 N 个随机样本,过滤后满足一个标准。我对我目前的实施不满意;对于较大的 N 值(例如 100,000),它太慢了。如何更有效地过滤这些样本以满足相关标准均匀随机样本小于 f/g 的条件?必须有一种更快的方法来实现此代码。
import numpy as np
from scipy.special import gamma
import matplotlib.pyplot as plt
def f(x): return 1. / gamma(3) * x * np.exp(-1 * x)
lambd = .2
c = 1 / lambd / gamma(3) * (2./(1-lambd)) ** 2 * np.exp(-1 * (1 - lambd) * (2. / (lambd - 1)))
def g(x): return c * lambd * np.exp(-1 * lambd * x)
x = np.linspace(0, 50, 1000)
samples = []
N = 100
while len(samples) < N:
randou = np.random.uniform(0, 1)
randoh = c * np.random.exponential(0.2)
if randou <= f(randoh) / g(randoh): samples.append(randoh)
plt.hist(samples, 100, normed=True, label='Simulated PDF')
plt.plot(x, f(x), label='True PDF', lw=2)
plt.xlim(0, 10)
plt.show()
我还尝试一次性生成样本,然后在 while 循环中过滤这些样本,但我不确定这种方法实际上有多快:
samples = np.random.uniform(0, 1, 100000)
hsamps = c * np.random.exponential(0.2, 100000)
N = 100
idx = np.array([True, False])
while len(idx[idx==True]) > 0:
idx = samples > ( f(hsamps) / g(hsamps))
samples[idx] = np.random.uniform(0, 1, len(idx[idx==True]))
hsamps[idx] = c * np.random.exponential(0.2, len(idx[idx==True]))
解决方案
为了利用 NumPy 的速度,您需要使用大型数组,而不是在循环中处理的单个标量。例如,您可以生成N
如下示例:
randous = np.random.uniform(0, 1, size=N)
randohs = c * np.random.exponential(0.2, size=N)
然后选择那些通过你的过滤器的人,如下所示:
mask = randous <= f(randohs) / g(randohs)
return randohs[mask]
唯一的问题是无法保证randohs[mask]
具有所需数量的值(或任何值)。所以我们可能会重复这个,直到我们生成足够的样本:
while len(samples) < N:
randohs = generate_samples()
samples.extend(randohs)
samples = samples[:N]
尽管使用了一个while循环,但这仍然比一次生成一个样本要快得多。
import numpy as np
from scipy.special import gamma
import matplotlib.pyplot as plt
def f(x):
return 1. / gamma(3) * x * np.exp(-1 * x)
def g(x):
return c * lambd * np.exp(-1 * lambd * x)
def generate_samples(N=10**5):
randous = np.random.uniform(0, 1, size=N)
randohs = c * np.random.exponential(0.2, size=N)
mask = randous <= f(randohs) / g(randohs)
return randohs[mask]
lambd = .2
c = (1 / lambd / gamma(3) * (2./(1-lambd)) ** 2
* np.exp(-1 * (1 - lambd) * (2. / (lambd - 1))))
x = np.linspace(0, 50, 1000)
samples = []
N = 10**5
while len(samples) < N:
randohs = generate_samples()
samples.extend(randohs)
samples = samples[:N]
plt.hist(samples, 100, density=True, label='Simulated PDF')
plt.plot(x, f(x), label='True PDF', lw=2)
plt.xlim(0, 10)
plt.show()
推荐阅读
- java - SpringFox:枚举数组显示为字符串数组
- html - 在 HTML 中调整图像相对于彼此的大小
- azure-node-sdk - ms-rest-nodeauth (azure-sdk-for-js) :错误:凭据参数需要实现 signRequest 方法
- python - 使用迭代的 Kivy 下拉列表
- python - Matplotlib - 图形图例不适用于具有多个轴的多个子图
- c# - 如何在不丢失输入数据的情况下在两个表单之间来回移动
- keycloak - 在 keycloak 中设置两层角色模型
- flutter - 如何在 google_maps_flutter 中隐藏 3d 建筑物?
- 3d - 将对象坐标与图像坐标匹配
- android - Discord bot 在 ios 与 android 和桌面上的功能不同