Python code optimization techniques

bregman 2016-09-27 16:33

Optimizing with numba (JIT compilation)

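import numpy as np
from numba import jit

@jit(nopython=True)
def mcc(tp, tn, fp, fn):
    # Matthews correlation coefficient from confusion-matrix counts
    sup = tp * tn - fp * fn
    inf = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    return 0.0 if inf == 0 else sup / np.sqrt(inf)

@jit(nopython=True)
def _best_mcc(y_true, y_prob):
    # Sort by probability and start with every sample predicted positive
    idx = np.argsort(y_prob)
    nump = 1.0 * np.sum(y_true)       # number of positives
    numn = idx.size - nump            # number of negatives
    tp, tn, fp, fn = nump, 0.0, numn, 0.0
    best_mcc, best_proba, prev_proba = 0.0, -1.0, -1.0
    # Raise the threshold one sample at a time, updating the counts in place
    for proba, y_i in zip(y_prob[idx], y_true[idx]):
        if proba != prev_proba:
            prev_proba = proba
            new_mcc = mcc(tp, tn, fp, fn)
            if new_mcc >= best_mcc:
                best_mcc, best_proba = new_mcc, proba
        if y_i == 1:
            tp -= 1.0
            fn += 1.0
        else:
            fp -= 1.0
            tn += 1.0
    return best_proba, best_mcc

def eval_mcc(y_true, y_prob, threshold=False):
    # The return-type choice stays in Python: nopython mode cannot return
    # a tuple on one branch and a float on the other
    best_proba, best_mcc = _best_mcc(y_true, y_prob)
    return (best_proba, best_mcc) if threshold else best_mcc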
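`eval_mcc` sorts the probabilities once and then sweeps the decision threshold in a single linear pass, updating the confusion-matrix counts incrementally instead of recomputing them. When checking a jitted version like this, a slow brute-force reference that recounts everything for each candidate threshold is useful on small inputs. The sketch below is plain Python (no numba, no NumPy); `mcc_ref` and `best_mcc_bruteforce` are hypothetical names, and it assumes 0/1 labels and the cut "predict positive iff probability >= p", which is the cut the sweep evaluates at each new probability value.

```python
import math


def mcc_ref(tp, tn, fp, fn):
    # Matthews correlation coefficient; 0.0 when any row/column sum is zero
    denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if denom == 0:
        return 0.0
    return (tp * tn - fp * fn) / math.sqrt(denom)


def best_mcc_bruteforce(y_true, y_prob):
    # Hypothetical O(n^2) reference: for every distinct probability p, predict
    # positive iff prob >= p, count the confusion matrix from scratch and keep
    # the best MCC (ties go to the larger p, matching the `>=` in the sweep)
    best_mcc, best_proba = 0.0, -1.0
    for p in sorted(set(y_prob)):
        tp = tn = fp = fn = 0
        for t, q in zip(y_true, y_prob):
            if q >= p:
                if t == 1:
                    tp += 1
                else:
                    fp += 1
            elif t == 1:
                fn += 1
            else:
                tn += 1
        m = mcc_ref(tp, tn, fp, fn)
        if m >= best_mcc:
            best_mcc, best_proba = m, p
    return best_proba, best_mcc


print(best_mcc_bruteforce([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]))  # (0.8, 1.0)
```

On small random inputs, `eval_mcc(np.array(y_true), np.array(y_prob), threshold=True)` should return the same pair up to floating-point rounding.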

Optimizing simple expressions with numexpr

Test 1

In [1]: import numpy as np
In [2]: import numexpr as ne
In [3]: a = np.arange(1e6)
In [4]: b = np.arange(1e6)
In [5]: timeit a**2 + b**2 + 2*a*b
100 loops, best of 3: 10.8 ms per loop

In [6]: timeit ne.evaluate("a**2 + b**2 + 2*a*b")
100 loops, best of 3: 2.4 ms per loop
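NumPy evaluates `a**2 + b**2 + 2*a*b` one operator at a time, so most intermediate results are full-length temporary arrays (8 MB each for 1e6 float64 values) that are written to memory and read back. numexpr instead compiles the expression string and evaluates it over small blocks that stay in the CPU cache, and it can also use several threads. The hypothetical `blocked_eval` below only illustrates the blocking idea in NumPy, reusing two small scratch buffers; it is not how numexpr works internally (numexpr runs the expression on its own virtual machine), and the 4096-element block size and the 1-D input assumption are arbitrary choices for this sketch.

```python
import numpy as np


def blocked_eval(a, b, block=4096):
    # Hypothetical helper: computes a**2 + b**2 + 2*a*b for 1-D arrays of
    # equal length, block by block, reusing two small scratch buffers instead
    # of allocating a full-length temporary for every intermediate result
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    out = np.empty_like(a)
    s1 = np.empty(block)
    s2 = np.empty(block)
    for start in range(0, a.size, block):
        stop = min(start + block, a.size)
        x, y = a[start:stop], b[start:stop]
        t1, t2 = s1[: stop - start], s2[: stop - start]
        np.multiply(x, x, out=t1)              # a**2
        np.multiply(y, y, out=t2)              # b**2
        np.add(t1, t2, out=t1)                 # a**2 + b**2
        np.multiply(x, y, out=t2)              # a*b
        np.multiply(t2, 2.0, out=t2)           # 2*a*b
        np.add(t1, t2, out=out[start:stop])    # final sum for this block
    return out


a = np.arange(1e6)
b = np.arange(1e6)
print(np.allclose(blocked_eval(a, b), a**2 + b**2 + 2*a*b))  # True
```

Its speed is not measured here; the point is the memory-access pattern, which numexpr implements in compiled code rather than in a Python loop.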

Test 2

In [8]: A = np.random.rand(5000,5000)
In [9]: B = np.random.rand(5000,5000)
In [10]: timeit A*B
10 loops, best of 3: 137 ms per loop

In [11]: timeit ne.evaluate("A*B")
10 loops, best of 3: 101 ms per loop
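The gain here (137 ms to 101 ms, about 1.4x) is much smaller than in test 1 (10.8 ms to 2.4 ms, 4.5x) because `A*B` is a single operation: there are no intermediate temporaries for numexpr to eliminate, and whatever advantage remains most likely comes from multithreading. A 5000×5000 float64 matrix holds 25e6 × 8 bytes = 200 MB, so one elementwise product reads 400 MB and writes 200 MB. The hypothetical helper below turns the measured times into an effective memory throughput; counting each of the three arrays exactly once is a simplifying assumption that ignores caching and the cost of allocating the result.

```python
def effective_bandwidth_gbs(n_elements, itemsize, n_arrays, seconds):
    # Bytes touched per call (each array read or written once) divided by
    # the elapsed time, reported in GB/s with 1 GB = 1e9 bytes
    return n_elements * itemsize * n_arrays / seconds / 1e9


n = 5000 * 5000                                            # elements per matrix
print(round(effective_bandwidth_gbs(n, 8, 3, 0.137), 1))  # NumPy A*B: 4.4
print(round(effective_bandwidth_gbs(n, 8, 3, 0.101), 1))  # numexpr:   5.9
```

Both figures land in the same few-GB/s range, which is consistent with this operation being limited by memory traffic rather than by arithmetic.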
