How to write a case-like statement with NumPy arrays

Problem description

import numpy as np

def custom_asymmetric_train(y_true, y_pred):
    residual = (y_true - y_pred).astype("float")
    # gradient and Hessian of a squared error where residual > 0 is weighted 10x
    grad = np.where(residual>0, -2*10.0*residual, -2*residual)
    hess = np.where(residual>0, 2*10.0, 2.0)
    return grad, hess

I want to write this statement:

    case when residual>=0 and residual<=0.5 then -2*1.2*residual
    when residual>=0.5 and residual<=0.7 then -2*1.0*residual
    when residual>0.7 then -2*2*residual end
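Written as a piecewise function of r = residual (under SQL CASE semantics the first matching WHEN wins, so r = 0.5 falls in the first branch; with no ELSE, r < 0 matches nothing and the CASE yields NULL):

```latex
g(r) =
\begin{cases}
-2 \cdot 1.2\, r, & 0 \le r \le 0.5 \\
-2 \cdot 1.0\, r, & 0.5 < r \le 0.7 \\
-2 \cdot 2.0\, r, & r > 0.7
\end{cases}
```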

But np.where can't express the & (and) logic. How can I write this kind of case logic with np.where in Python?

Thanks

Tags: python, numpy, numpy-ndarray

Solution


The statement can be written with np.select as:

import numpy as np

residual = np.random.rand(10) - 0.3  # values in [-0.3, 0.7), so some are negative
condlist = [(residual>=0.0)&(residual<=0.5), (residual>=0.5)&(residual<=0.7), residual>0.7]
choicelist = [-2*1.2*residual, -2*1.0*residual, -2*2.0*residual]

# element-wise: take the choice of the first true condition, otherwise keep residual
residual = np.select(condlist, choicelist, default=residual)

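Note that when more than one condition in condlist is satisfied, the first one encountered is used; when all conditions evaluate to False, the default value is used. Also, on boolean NumPy arrays you need the bitwise operator & rather than the Python keyword and, which does not work on them (it raises a ValueError saying the truth value of an array is ambiguous). Each comparison must also be wrapped in parentheses, because & binds more tightly than >= and <=.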
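A small demonstration of these points on a hand-picked array (the values are arbitrary): and raises a ValueError on arrays, & with parentheses works element-wise, and the boundary value 0.5, which satisfies both of the first two conditions, gets the first choice.

```python
import numpy as np

residual = np.array([-0.2, 0.0, 0.5, 0.6, 0.7, 0.9])

# `and` asks each array for a single True/False, which NumPy refuses
# for arrays with more than one element.
and_failed = False
try:
    (residual >= 0.0) and (residual <= 0.5)
except ValueError as err:
    and_failed = True
    print("ValueError:", err)

# `&` combines the boolean arrays element-wise; the parentheses are needed
# because `&` binds more tightly than `>=` and `<=`.
mask = (residual >= 0.0) & (residual <= 0.5)
print(mask.tolist())  # [False, True, True, False, False, False]

# 0.5 satisfies the first two conditions; np.select uses the first one,
# so it becomes -2*1.2*0.5 rather than -2*1.0*0.5.
condlist = [(residual >= 0.0) & (residual <= 0.5),
            (residual >= 0.5) & (residual <= 0.7),
            residual > 0.7]
choicelist = [-2 * 1.2 * residual, -2 * 1.0 * residual, -2 * 2.0 * residual]
result = np.select(condlist, choicelist, default=residual)
print(result)
```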

Let's benchmark these answers:

residual = np.random.rand(10000) -0.3

def charl_3where(residual):
    residual = np.where((residual>=0.0)&(residual<=0.5), -2*1.2*residual, residual)
    residual = np.where((residual>=0.5)&(residual<=0.7), -2*1.0*residual, residual)
    residual = np.where(residual>0.7, -2*2.0*residual, residual)
    return residual

def yaco_select(residual):
    condlist = [(residual>=0.0)&(residual<=0.5), (residual>=0.5)&(residual<=0.7), residual>0.7]
    choicelist = [-2*1.2*residual, -2*1.0*residual,-2*2.0*residual]
    residual = np.select(condlist, choicelist, default=residual)
    return residual


%timeit charl_3where(residual)
>>> 112 µs ± 1.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit yaco_select(residual)
>>> 141 µs ± 2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
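In charl_3where each np.where call is applied to the residual returned by the previous call, so the second and third conditions are tested on already-transformed values. It still agrees with np.select here only because every transformed value is -k*r with r >= 0, i.e. <= 0, and therefore cannot satisfy a later condition. A sketch that evaluates every condition on the original input by nesting np.where instead (the name nested_where is hypothetical, and this version has not been timed here):

```python
import numpy as np

def nested_where(residual):
    # All conditions test the original input array, so no branch ever sees
    # values already produced by another branch. Boundary handling matches
    # np.select: 0.5 goes to the first branch, 0.7 to the second.
    return np.where(
        (residual >= 0.0) & (residual <= 0.5), -2 * 1.2 * residual,
        np.where(
            (residual > 0.5) & (residual <= 0.7), -2 * 1.0 * residual,
            np.where(residual > 0.7, -2 * 2.0 * residual, residual),
        ),
    )

residual = np.random.rand(10000) - 0.3
out = nested_where(residual)
```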

Let's try to optimize this with numba:

from numba import jit

@jit(nopython=True)
def yaco_numba(residual):
    out = np.empty_like(residual)
    # one pass over the array; each element goes to the first matching branch
    for i in range(residual.shape[0]):
        if residual[i] < 0.0:
            out[i] = residual[i]
        elif residual[i] <= 0.5:
            out[i] = -2*1.2*residual[i]
        elif residual[i] <= 0.7:
            out[i] = -2*1.0*residual[i]
        else:  # residual > 0.7
            out[i] = -2*2.0*residual[i]
    return out

%timeit yaco_numba(residual)
>>> 6.65 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Final check

res1 = charl_3where(residual)
res2 = yaco_select(residual)
res3 = yaco_numba(residual)
np.allclose(res1,res3)
>>> True
np.allclose(res2,res3)
>>> True
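np.random.rand samples from [0, 1), so residual = np.random.rand(10000) - 0.3 lies in [-0.3, 0.7) and the residual > 0.7 branch is never exercised by the check above. A minimal sketch that also covers that branch and the boundary points 0.5 and 0.7, comparing the np.select version against a plain-Python scalar reference (reference_case is a hypothetical helper written only for this check):

```python
import numpy as np

def yaco_select(residual):
    condlist = [(residual >= 0.0) & (residual <= 0.5),
                (residual >= 0.5) & (residual <= 0.7),
                residual > 0.7]
    choicelist = [-2 * 1.2 * residual, -2 * 1.0 * residual, -2 * 2.0 * residual]
    return np.select(condlist, choicelist, default=residual)

def reference_case(r):
    # Scalar version of the CASE statement; negative values are returned
    # unchanged, matching default=residual above.
    if r < 0.0:
        return r
    if r <= 0.5:
        return -2 * 1.2 * r
    if r <= 0.7:
        return -2 * 1.0 * r
    return -2 * 2.0 * r

# Uniform samples in [-0.3, 1.2) plus the exact boundary points.
rng = np.random.default_rng(0)
residual = np.concatenate([rng.uniform(-0.3, 1.2, 10000), [0.0, 0.5, 0.7]])
expected = np.array([reference_case(r) for r in residual])
print(np.allclose(yaco_select(residual), expected))  # True
```

The same comparison can be run against yaco_numba if numba is installed.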

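The numba version is roughly 17x faster than the previous best (6.65 µs vs. 112 µs for 10,000 elements). Keep in mind that the first call to a numba function also pays a one-time compilation cost. Hope this helps.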
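To plug the case logic back into the question's custom_asymmetric_train, one option is to select a per-element weight w and derive both outputs from it, following the pattern of the original code (grad = -2*w*residual, hess = 2*w). This is a sketch, not the asker's confirmed loss: the CASE statement leaves residual < 0 undefined, so the weight 1.0 for that case is an assumption taken from the original else branch.

```python
import numpy as np

def custom_asymmetric_train(y_true, y_pred):
    residual = (y_true - y_pred).astype("float")
    # Piecewise weights from the CASE statement; the first matching
    # condition wins, so residual == 0.5 gets weight 1.2.
    # ASSUMPTION: residual < 0 keeps weight 1.0, as in the question's
    # original else branch (-2*residual, hess 2.0).
    condlist = [(residual >= 0.0) & (residual <= 0.5),
                (residual > 0.5) & (residual <= 0.7),
                residual > 0.7]
    weight = np.select(condlist, [1.2, 1.0, 2.0], default=1.0)
    grad = -2.0 * weight * residual
    # w is constant within each branch, so d(grad)/d(y_pred) = 2*w there
    hess = 2.0 * weight
    return grad, hess
```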

