python - 优化生成变量的拒绝方法
问题描述
我对生成连续随机变量的拒绝方法的优化有疑问。我有一个密度:f(x) = 3/2 (1-x^2)
。这是我的代码:
import random
import matplotlib.pyplot as plt
import numpy as np
import time
import scipy.stats as ss
a=0 # xmin
b=1 # xmax
m=3/2 # ymax
variables = [] #list for variables
def f(x):
return 3/2 * (1 - x**2) #probability density function
reject = 0 # number of rejections
start = time.time()
while len(variables) < 100000: #I want to generate 100 000 variables
u1 = random.uniform(a,b)
u2 = random.uniform(0,m)
if u2 <= f(u1):
variables.append(u1)
else:
reject +=1
end = time.time()
print("Time: ", end-start)
print("Rejection: ", reject)
x = np.linspace(a,b,1000)
plt.hist(variables,50, density=1)
plt.plot(x, f(x))
plt.show()
ss.probplot(variables, plot=plt)
plt.show()
我的第一个问题:我的概率图制作正确吗?第二,标题中的内容。如何优化该方法?我想得到一些建议来优化代码。现在该代码大约需要 0.5 秒,并且大约有 50 000 次拒绝。是否可以减少拒绝的时间和次数?如果需要,我可以使用不同的生成变量的方法进行优化。
解决方案
我的第一个问题:我的概率图制作正确吗?
不,它是根据默认正态分布制作的。您必须将函数打包f(x)
到派生自 stats.rv_continuous 的类中,将其放入 _pdf 方法,并将其传递给probplot
第二,标题中的内容。如何优化该方法?是否可以减少拒绝的时间和次数?
当然,您拥有 NumPy 矢量功能的强大功能。永远不要写显式循环——vectoriz、vectorize 和 vectorize!
看看下面修改过的代码,不是一个循环,一切都是通过 NumPy 向量完成的。我的计算机上 100000 个样本(Xeon、Win10 x64、Anaconda Python 3.7)的时间从 0.19 下降到 0.003。
import numpy as np
import scipy.stats as ss
import matplotlib.pyplot as plt
import time
a = 0. # xmin
b = 1. # xmax
m = 3.0/2.0 # ymax
def f(x):
return 1.5 * (1.0 - x*x) # probability density function
start = time.time()
N = 100000
u1 = np.random.uniform(a, b, N)
u2 = np.random.uniform(0.0, m, N)
negs = np.empty(N)
negs.fill(-1)
variables = np.where(u2 <= f(u1), u1, negs) # accepted samples are positive or 0, rejected are -1
end = time.time()
accept = np.extract(variables>=0.0, variables)
reject = N - len(accept)
print("Time: ", end-start)
print("Rejection: ", reject)
x = np.linspace(a, b, 1000)
plt.hist(accept, 50, density=True)
plt.plot(x, f(x))
plt.show()
ss.probplot(accept, plot=plt) # against normal distribution
plt.show()
关于减少拒绝的数量,您可以使用 0 拒绝进行反向采样,它是三次方程,因此可以轻松使用
更新
这是用于 probplot 的代码:
class my_pdf(ss.rv_continuous):
def _pdf(self, x):
return 1.5 * (1.0 - x*x)
ss.probplot(accept, dist=my_pdf(a=a, b=b, name='my_pdf'), plot=plt)
你应该得到类似的东西
推荐阅读
- vue.js - Vue.js 2 + WP REST API“TypeError:无法读取 null 的属性‘过滤器’”
- c# - 更改 NumericUpDown 的边框颜色
- javascript - 如何使用 expo 或 expo-image-picker-multiple 从图库中选择多个图像
- javascript - 为什么该函数在第二次运行后才返回一个值,即使console.log返回了值?
- ruby-on-rails - 如何在 ruby 中模拟 Finder/Mac OS 复制?
- performance - 隐含的 DO 循环编译性能
- html - 在命令行中渲染 rmarkdown 文件出现错误:未提供 html_dependency 的路径
- c# - 我该如何解决这样的错误,例如不能将类型“int”隐式转换为“byte”。存在显式转换(您是否缺少演员表?)
- julia - 如何在 Julia 中找到数组最小值的值和索引?
- ios - 覆盖 UIMenu 上的用户界面样式(即浅色/深色)