首页 > 解决方案 > 通过广播简化这个三重循环

问题描述

我有一个用纯 numpy 编写的函数,我在其中计算了一些统计数据大量的时间,而且它太长了。此函数包含一个三重循环,但我找不到如何将其转换为广播。

由于我的实际数据很难理解,但随机数据在这里没有意义,我在下面的代码中为您提供了这些数据的示例,并在最后绘制了预期输出的图表。

import numpy as np
import seaborn as sns

# Setting parameters and data
z_init = np.array([[5.61293390e-01, 9.97100450e-01, 4.23530180e-01, 6.08808896e-01],
       [1.22563280e-01, 1.72015130e-01, 8.71145720e-01, 5.40745844e-01],
       [8.51194500e-02, 1.18289130e-01, 8.90346540e-01, 7.22859351e-01],
       [9.83241310e-01, 9.57282690e-01, 7.22347100e-02, 5.43527399e-02],
       [5.49211550e-01, 3.97858250e-01, 6.86380990e-01, 7.91494336e-01],
       [9.94878920e-01, 6.39160920e-01, 2.01045170e-01, 9.86840712e-01],
       [5.04337540e-01, 5.69995040e-01, 3.99087430e-01, 4.32140476e-01],
       [9.28230540e-01, 9.32143440e-01, 1.02748280e-01, 9.92666867e-01],
       [1.77513660e-01, 1.83466350e-01, 7.99027540e-01, 6.30800256e-01],
       [5.14663640e-01, 6.34361690e-01, 5.33889110e-01, 7.90899958e-01],
       [2.58006640e-01, 2.88319290e-01, 7.09604700e-01, 9.02145588e-01],
       [2.04811730e-01, 1.58717810e-01, 8.41421970e-01, 8.84068574e-01],
       [3.11875950e-01, 2.46353420e-01, 7.58289460e-01, 9.65660849e-01],
       [9.36622730e-01, 8.01263020e-01, 9.83931900e-02, 5.44281251e-01],
       [3.45077880e-01, 2.88884330e-01, 6.99352220e-01, 9.71301027e-01],
       [8.18323020e-01, 8.42968360e-01, 2.68607890e-01, 1.52418342e-01],
       [9.07517590e-01, 7.41841580e-01, 3.94466860e-01, 3.33215046e-01],
       [3.84328210e-01, 2.85716010e-01, 7.73390420e-01, 4.79702183e-01],
       [4.74390660e-01, 7.16142340e-01, 4.52819790e-01, 7.98200958e-01],
       [1.91063150e-01, 3.11021770e-01, 8.17990970e-01, 3.65566296e-02],
       [3.84454330e-01, 5.31087070e-01, 6.46125720e-01, 4.26028784e-01],
       [7.76256310e-01, 9.82152950e-01, 2.24877600e-01, 9.03596071e-01],
       [5.21782070e-01, 6.09994810e-01, 5.49719820e-01, 3.77052128e-01],
       [2.44654800e-01, 2.22705930e-01, 6.92217350e-01, 8.04524950e-01],
       [6.26568380e-01, 6.85160450e-01, 5.05651420e-01, 1.06857236e-01],
       [4.57300600e-02, 3.19478800e-02, 9.56385030e-01, 5.62236853e-01],
       [9.85249100e-02, 9.00401200e-02, 9.15879990e-01, 9.37933793e-01],
       [1.13619610e-01, 1.08790380e-01, 8.86476490e-01, 9.07453097e-01],
       [8.61160760e-01, 9.26073490e-01, 1.51885700e-02, 5.61689264e-01],
       [4.72355650e-01, 8.90940390e-01, 4.79858960e-01, 1.31270153e-01],
       [3.29944340e-01, 3.87579980e-01, 5.10804930e-01, 5.19551698e-01],
       [5.29880000e-03, 7.01797000e-03, 9.93625440e-01, 1.85747216e-01],
       [6.23029300e-01, 4.84573370e-01, 2.68151590e-01, 5.60921564e-03],
       [8.03967200e-01, 7.84008540e-01, 2.37276020e-01, 9.47798098e-01],
       [4.58525700e-01, 6.45049070e-01, 5.76664620e-01, 5.75709041e-01],
       [8.65658770e-01, 8.99023760e-01, 2.40967370e-01, 5.56589158e-01],
       [9.67638170e-01, 8.83972000e-01, 3.44920000e-04, 9.17016810e-01],
       [7.24319710e-01, 8.63911350e-01, 2.64988220e-01, 2.13078474e-01],
       [9.91532810e-01, 9.85368470e-01, 6.58391400e-02, 4.58927178e-04],
       [1.17465420e-01, 1.14261700e-01, 9.05365980e-01, 1.80863318e-01],
       [8.63220080e-01, 7.91506140e-01, 5.18878970e-01, 5.49666344e-01],
       [3.12470000e-02, 3.19607200e-02, 9.72981330e-01, 8.60758375e-01],
       [4.41205810e-01, 4.87292710e-01, 5.39217100e-01, 5.65980037e-01],
       [1.39334500e-02, 1.40076300e-02, 9.84327060e-01, 9.38894626e-01],
       [2.43659030e-01, 2.09662260e-01, 7.81243000e-01, 1.14951150e-01],
       [9.45841720e-01, 9.45075580e-01, 1.37627400e-02, 5.39213927e-01],
       [7.92013050e-01, 6.55037130e-01, 2.08627580e-01, 1.50823215e-01],
       [6.04095200e-02, 4.57398400e-02, 9.53590740e-01, 6.32755639e-01],
       [5.67334500e-01, 2.75674320e-01, 6.47657510e-01, 4.68491101e-01],
       [1.58600060e-01, 1.22128390e-01, 8.47935330e-01, 4.94281577e-01],
       [8.26576000e-03, 4.07989000e-03, 9.95207870e-01, 8.02447365e-02],
       [7.25499790e-01, 7.05574910e-01, 3.94566850e-01, 8.90077195e-01],
       [8.30398180e-01, 7.65006390e-01, 1.04508490e-01, 9.44908637e-01],
       [6.84425700e-01, 8.13177160e-01, 2.55194010e-01, 1.71608600e-01],
       [3.69045830e-01, 4.10810940e-01, 6.39276590e-01, 9.22700243e-01],
       [1.42119190e-01, 1.45086850e-01, 8.87464990e-01, 2.35533293e-01],
       [8.51399930e-01, 8.63513030e-01, 4.28106000e-02, 7.49796027e-01],
       [7.26388760e-01, 9.23435870e-01, 3.21152410e-01, 5.59389176e-01],
       [2.68165680e-01, 2.21699530e-01, 7.13336850e-01, 8.28847266e-01],
       [4.67212100e-02, 6.18397600e-02, 9.08459550e-01, 1.73109978e-01],
       [8.12353540e-01, 6.14787930e-01, 2.36200930e-01, 6.70979632e-01],
       [3.56200600e-01, 2.86300900e-01, 6.87996620e-01, 7.68872468e-01],
       [4.27617260e-01, 4.08906890e-01, 4.65987670e-01, 1.67199623e-01],
       [6.63373240e-01, 9.66214910e-01, 1.39582640e-01, 9.85382902e-01],
       [5.51993350e-01, 4.93202560e-01, 5.63663960e-01, 1.69990831e-01],
       [8.04742160e-01, 7.23388830e-01, 1.97937550e-01, 5.06756753e-01],
       [1.07240370e-01, 1.15115720e-01, 9.07925810e-01, 3.46134208e-01],
       [3.61709450e-01, 2.16649010e-01, 7.91721970e-01, 5.22621049e-01],
       [9.83195600e-01, 9.35189250e-01, 1.09384140e-01, 4.87989100e-01],
       [1.07405620e-01, 1.05033440e-01, 8.76795260e-01, 2.44237928e-01],
       [6.75897130e-01, 6.50329960e-01, 3.04297580e-01, 3.60810270e-01],
       [7.02020600e-02, 4.96392100e-02, 9.33498520e-01, 7.17513612e-01],
       [4.84155500e-01, 6.88098980e-01, 3.46669530e-01, 2.16784063e-01],
       [6.04164790e-01, 7.48494480e-01, 9.49017500e-02, 2.69127829e-03],
       [5.92501140e-01, 7.18188940e-01, 4.79787090e-01, 4.72203718e-01],
       [6.47244640e-01, 9.12962170e-01, 3.94908800e-02, 1.89967176e-02],
       [7.52063710e-01, 8.36582980e-01, 2.56381510e-01, 1.82552057e-01],
       [7.33809600e-01, 5.88942430e-01, 3.17564930e-01, 4.83186793e-02],
       [6.37782580e-01, 7.91589180e-01, 3.08634220e-01, 1.83951279e-01],
       [7.32009020e-01, 9.14051250e-01, 1.80915920e-01, 2.45163585e-01],
       [1.53493780e-01, 1.90967590e-01, 8.19005590e-01, 7.55056039e-01],
       [5.36161820e-01, 5.13641150e-01, 5.01637010e-01, 3.47079632e-01],
       [6.06637230e-01, 6.67565790e-01, 3.33999130e-01, 2.51786198e-01],
       [7.25650010e-01, 8.41152620e-01, 2.36374270e-01, 2.61322095e-01],
       [6.52008490e-01, 8.66015010e-01, 1.90032370e-01, 5.14531432e-01],
       [2.59336300e-02, 3.60464100e-02, 9.42735970e-01, 8.76251330e-01],
       [3.91414850e-01, 3.16164320e-01, 6.36344310e-01, 2.11938819e-01],
       [6.43722130e-01, 5.38235890e-01, 1.13523690e-01, 3.54529909e-01],
       [7.90799970e-01, 7.44277280e-01, 3.24458070e-01, 1.60302427e-01],
       [7.10510700e-02, 8.50407900e-02, 9.08863250e-01, 9.18056054e-02],
       [8.27656880e-01, 7.68024600e-01, 6.35402600e-02, 1.39203186e-01],
       [4.22585470e-01, 4.66851210e-01, 5.36839920e-01, 9.51087042e-02],
       [8.74929100e-02, 9.12235300e-02, 8.91159090e-01, 3.88725280e-02],
       [9.36443830e-01, 8.16299420e-01, 2.90021130e-01, 2.89175878e-01],
       [9.26354550e-01, 9.51074570e-01, 8.49412000e-03, 1.76092602e-01],
       [5.72510800e-02, 3.56584500e-02, 9.67113730e-01, 7.74680782e-01],
       [1.10646890e-01, 1.10709490e-01, 8.85004310e-01, 8.08970193e-01],
       [9.30983140e-01, 9.85525760e-01, 6.47764700e-02, 3.51535913e-01],
       [9.16176650e-01, 8.13787300e-01, 2.80715970e-01, 7.69428516e-01],
       [9.34773620e-01, 8.73895270e-01, 1.23538120e-01, 5.72796569e-01]])
z_init[:,3] = np.random.uniform(size=100)

d_init = z_init.shape[1]
n_init = z_init.shape[0]
N = 999
bp = np.array([0.5,0.2,0.55,0.7])

需要简化的实际循环如下:

random_numbers = np.random.uniform(size=(N,n_init,d_init))
vals = np.zeros(shape=(N,d_init))
for index in range(2 ** d_init):
    binary_repr = np.array(list(np.binary_repr(index, width=d_init)), dtype="float64")
    min = bp * binary_repr
    max = bp ** (1 - binary_repr)
    lambda_l = np.prod(max - min)

    for d in np.arange(d_init):

        mask = np.arange(d_init) != d
        min_without_d = np.array(min)[mask]
        max_without_d = np.array(max)[mask]
        lambda_k = np.prod(max_without_d - min_without_d)

        z_rep = np.repeat(z_init[None,:,:],N,axis=0)
        z_rep[:,:,d] = random_numbers[:,:,d]

        f_l = np.mean(np.all(np.logical_and(z_rep >= min, z_rep < max), axis=2),axis=1)
        f_k = np.mean(np.all(np.logical_and(z_rep[:,:, mask] >= min_without_d, z_rep[:,:, mask] < max_without_d), axis=2),axis=1)

        vals[:,d] += 1 / 2 * f_k ** 2 / lambda_k + f_l ** 2 / lambda_l - 2 * f_k * f_l / lambda_k

绘制结果:

sns.distplot(vals[:,0],hist=False, rug=True)
sns.distplot(vals[:,1],hist=False, rug=True)
sns.distplot(vals[:,2],hist=False, rug=True)
sns.distplot(vals[:,3],hist=False, rug=True)

问题是循环使用 2^d 案例的讨厌的二进制表示,这给我带来了简化的麻烦。但也许第二个循环仍然可以矢量化?

谢谢您的帮助 :)

编辑:当人们发表相关评论时,我编辑了代码以包含它们。

标签: pythonnumpyarray-broadcasting

解决方案


推荐阅读