首页 > 解决方案 > 请问如何优化这个 Python 代码?

问题描述

我有一个来自双变量法线的数据集,但是每个数据点都是从家庭的不同成员中提取的,即均值和协方差矩阵不同。我正在尝试优化我的代码,我想知道是否有更好的方法来做到这一点。

目前,我正在使用列表理解,并且尝试了并行性,但没有取得很大成功。对于 100,000 个数据点,列表理解在大约 8 秒内完成,并行性大约需要几秒钟。

我的数据大小可以从几十万到几百万不等,但数据将始终来自二元高斯。同样正如我所说,每个数据点都是从具有不同均值和协方差的双变量高斯中提取的,但是在下面的代码片段中,我采用了最简单的情况,并且我为所有数据传入了相同的协方差和均值。

任何想法都非常感谢!

import numpy as np
import pandas as pd
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import cpu_count
from scipy.stats import multivariate_normal
from timeit import timeit


def mvn_loglik(param):
    out = [multivariate_normal.logpdf(p[0], p[1], p[2]) for i, p in enumerate(param)]
    return out

def mvn_loglik_helper(p):
    out = multivariate_normal.logpdf(p[0], p[1], p[2])
    return out

def mvn_loglik_par(param):
    n = max(1, cpu_count() - 1)
    pool = ThreadPool(n)
    results = pool.map(mvn_loglik_helper, param)
    pool.close()
    pool.join()
    return results

n = 100000
data = np.zeros([n,2])
mu = np.zeros([n,2])
cov = np.tile(np.eye(2,2), (n,1,1))

param = list(zip(*[data, mu, cov]))
assert np.all(mvn_loglik(param) == mvn_loglik_par(param))

def f():
    param = list(zip(*[data, mu, cov]))
    return mvn_loglik(param)

def g():
    param = list(zip(*[data, mu, cov]))
    return mvn_loglik_par(param)


print(timeit(f, number=1))
8.6705456

print(timeit(g, number=1))
10.187561600000002

标签: python

解决方案


推荐阅读