python - numpy数组中的like语句如何写一个案例
问题描述
def custom_asymmetric_train(y_true, y_pred):
residual = (y_true - y_pred).astype("float")
grad = np.where(residual>0, -2*10.0*residual, -2*residual)
hess = np.where(residual>0, 2*10.0, 2.0)
return grad, hess
我想写这个声明:
case when residual>=0 and residual<=0.5 then -2*1.2*residual
when residual>=0.5 and residual<=0.7 then -2*1.*residual
when residual>0.7 then -2*2*residual end )
但是np.where
不能写 &(and) 逻辑。np.where
在 python中的逻辑时如何编写这种情况。
谢谢
解决方案
该语句可以使用np.select编写为:
import numpy as np
residual = np.random.rand(10) -0.3 # -0.3 to get some negative values
condlist = [(residual>=0.0)&(residual<=0.5), (residual>=0.5)&(residual<=0.7), residual>0.7]
choicelist = [-2*1.2*residual, -2*1.0*residual,-2*2.0*residual]
residual = np.select(condlist, choicelist, default=residual)
请注意,当满足 中的多个条件时condlist
,使用遇到的第一个条件。当所有条件评估为False
时,它将使用该default
值。此外,为了您的信息,您需要&
在布尔 numpy 数组上使用按位运算符,因为and
python 关键字对它们不起作用。
让我们对这些答案进行基准测试:
residual = np.random.rand(10000) -0.3
def charl_3where(residual):
residual = np.where((residual>=0.0)&(residual<=0.5), -2*1.2*residual, residual)
residual = np.where((residual>=0.5)&(residual<=0.7), -2*1.0*residual, residual)
residual = np.where(residual>0.7, -2*2.0*residual, residual)
return residual
def yaco_select(residual):
condlist = [(residual>=0.0)&(residual<=0.5), (residual>=0.5)&(residual<=0.7), residual>0.7]
choicelist = [-2*1.2*residual, -2*1.0*residual,-2*2.0*residual]
residual = np.select(condlist, choicelist, default=residual)
return residual
%timeit charl_3where(residual)
>>> 112 µs ± 1.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit yaco_select(residual)
>>> 141 µs ± 2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
让我们尝试优化这些numba
from numba import jit
@jit(nopython=True)
def yaco_numba(residual):
out = np.empty_like(residual)
for i in range(residual.shape[0]):
if residual[i]<0.0 :
out[i] = residual[i]
elif residual[i]<=0.5 :
out[i] = -2*1.2*residual[i]
elif residual[i]<=0.7:
out[i] = -2*1.0*residual[i]
else: # residual>0.7
out[i] = -2*2.0*residual[i]
return out
%timeit yaco_numba(residual)
>>> 6.65 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
最终检查
res1 = charl_3where(residual)
res2 = yaco_select(residual)
res3 = yaco_numba(residual)
np.allclose(res1,res3)
>>> True
np.allclose(res2,res3)
>>> True
这个15x
比以前最好的快。希望这可以帮助。
推荐阅读
- asp.net-mvc - 类从模型 asp.net mvc 中消失了
- docker - 在 Kubernetes 集群上部署 zuul 代理
- android - Android 9 asynctask 中的 java.lang.NoClassDefFoundError
- c++ - 在 C++ 中比较无穷大和无穷大
- sql - SQL 将列中的部分字符串与另一列匹配
- php - PHP - 读取可执行文件的数字签名并验证作者,如何使用 PHP 验证可执行文件的数字签名
- java - problems with ACLMessages in JADE
- javascript - Firechat - 无法发送消息
- c# - 动态阅读类
- c# - 如何从目录创建图像集合