python - 概率向量上的网格
问题描述
我试图得到一个 n 维概率向量的“网格”——每个条目都在 0 和 1 之间的向量,并且所有条目加起来为 1。我希望每个可能的向量都可以在其中坐标取任何由0 到 1 之间均匀间隔的值组成的v 。
为了说明这一点,下面是一个非常低效的实现,对于 n = 3 和 v = 3:
from itertools import product
grid_redundant = product([0, .5, 1], repeat=3)
grid = [point for point in grid_redundant if sum(point)==1]
现在grid
包含[(0, 0, 1), (0, 0.5, 0.5), (0, 1, 0), (0.5, 0, 0.5), (0.5, 0.5, 0), (1, 0, 0)]
.
这种“实现”对于更高维度和更细粒度的网格来说是可怕的。有没有一种好方法可以做到这一点,也许使用numpy
?
我也许可以在动机上加一点:如果只是从随机分布中抽样给我足够的极值点,我会非常高兴,但事实并非如此。看到这个问题。我所追求的“网格”不是随机的,而是系统地扫过单纯形(概率向量的空间)。
解决方案
这是一个递归解决方案。它不使用 NumPy,也不是超级高效,尽管它应该比发布的代码片段更快:
import math
from itertools import permutations
def probability_grid(values, n):
values = set(values)
# Check if we can extend the probability distribution with zeros
with_zero = 0. in values
values.discard(0.)
if not values:
raise StopIteration
values = list(values)
for p in _probability_grid_rec(values, n, [], 0.):
if with_zero:
# Add necessary zeros
p += (0.,) * (n - len(p))
if len(p) == n:
yield from set(permutations(p)) # faster: more_itertools.distinct_permutations(p)
def _probability_grid_rec(values, n, current, current_sum, eps=1e-10):
if not values or n <= 0:
if abs(current_sum - 1.) <= eps:
yield tuple(current)
else:
value, *values = values
inv = 1. / value
# Skip this value
yield from _probability_grid_rec(
values, n, current, current_sum, eps)
# Add copies of this value
precision = round(-math.log10(eps))
adds = int(round((1. - current_sum) / value, precision))
for i in range(adds):
current.append(value)
current_sum += value
n -= 1
yield from _probability_grid_rec(
values, n, current, current_sum, eps)
# Remove copies of this value
if adds > 0:
del current[-adds:]
print(list(probability_grid([0, 0.5, 1.], 3)))
输出:
[(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), (0.5, 0.5, 0.0), (0.0, 0.5, 0.5), (0.5, 0.0, 0.5)]
与发布方法的快速比较:
from itertools import product
def probability_grid_basic(values, n):
grid_redundant = product(values, repeat=n)
return [point for point in grid_redundant if sum(point)==1]
values = [0, 0.25, 1./3., .5, 1]
n = 6
%timeit list(probability_grid(values, n))
1.61 ms ± 20.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit probability_grid_basic(values, n)
6.27 ms ± 186 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
推荐阅读
- node.js - NodeJS 版本在终端不会改变
- python-3.x - 如何从python列表中存在的文件中删除时间戳值
- visual-c++ - 调试可视 C++ 属性(悬停在变量上不显示值)
- vue.js - 在 mouseenter vuetify 上显示对话框
- python-3.x - git clone 致命错误 - 存储库“https”不存在
- certbot - 当我做这个(ubuntu 16.04)时出现 CERTBOT 错误:sudo certbot --apache -d mysite.com
- react-native - React、Redux、Thunk、Persist:typeError store.getState 不是函数。(在 'store.getState()' 中,store.getState' 是未定义的)?
- python-3.x - 如何使用多处理在 Python 中迁移数据库?
- c# - Style.Render 不适用于服务器端(Razor)
- java - 通过 CDI 在单元测试中检索替代 entityManager