Creating a mask with numpy.where

Problem description

I have this function that creates a mask (a boolean array), and I would like it to be much faster.

import numpy

def get_validity_1(ts, times):
    validity = numpy.zeros(len(ts))
    indexes = []
    for start, end in times:
        # argmax on a boolean array returns the index of the first True,
        # i.e. the first element of ts that is >= the bound.
        index_start = numpy.argmax(ts >= start)
        index_end = numpy.argmax(ts >= end)
        indexes.append([index_start, index_end])
    for start, end in indexes:
        validity[start:end] = 1
    return validity

res_1 = get_validity_1(numpy.linspace(0, 1, 100000000), numpy.array([[0.01, 0.1], [0.5, 0.8]]))

My question is how to build this mask with a numpy.where condition instead. I tried this:

def get_validity_2(ts, times):
    return numpy.where(numpy.logical_or([t1<ts.all()<t2 for t1, t2 in times]))

But Python raises:

ValueError: invalid number of arguments

Here is a script with sample inputs and assertions:

import time, numpy

def get_validity_1(ts, times):
    validity = numpy.zeros(len(ts))
    indexes = []
    for start, end in times:
        index_start = numpy.argmax(ts >= start)
        index_end = numpy.argmax(ts >= end)
        indexes.append([index_start, index_end])
    for start, end in indexes:
        validity[start:end] = 1
    return validity
    
def get_validity_2(ts, times):
    return numpy.where(numpy.logical_or([t1<ts.all()<t2 for t1, t2 in times]))

if __name__ == "__main__":
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])
    
    t0 = time.time()
    res_1 = get_validity_1(ts, times)
    t_1 = time.time() - t0
    
    t0 = time.time()
    res_2 = get_validity_2(ts, times)
    t_2 = time.time() - t0
    
    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    
    assert res_1 == res_2
    assert t_1 > t_2

Does anyone know how to complete the function "get_validity_2" so that the assertions pass? Or is there a function in some package that solves this?

Tags: python, numpy, optimization, mask

Solution


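Use

numpy.logical_or.reduce([numpy.logical_and(t1 <= ts, ts < t2) for t1, t2 in times])

if you want a one-liner like the one you were trying to write. Passing the list to numpy.logical_or.reduce, rather than unpacking it with *, keeps this correct for any number of intervals, and t1 <= ts matches the ts >= start test in get_validity_1. However, this is still inefficient, because every comparison scans the whole large array in O(N).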
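To make the one-liner concrete, here is a minimal sketch on a small array. The helper name interval_mask and the sample values are illustrative assumptions, not part of the original answer.

```python
import numpy

def interval_mask(ts, times):
    # Hypothetical helper: True where ts lies in any half-open interval [t1, t2).
    per_interval = [numpy.logical_and(t1 <= ts, ts < t2) for t1, t2 in times]
    # reduce ORs together any number of per-interval masks.
    return numpy.logical_or.reduce(per_interval)

ts = numpy.linspace(0.0, 1.0, 11)  # 0.0, 0.1, ..., 1.0
times = numpy.array([[0.15, 0.35], [0.55, 0.85], [0.95, 1.5]])
mask = interval_mask(ts, times)
print(mask.astype(int))  # [0 0 1 1 0 0 1 1 1 0 1]
```

Each interval still costs two full passes over ts, which is why this approach scales with the length of the array.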

Since ts is sorted, here is a faster approach that uses binary search to find the start/end indices in O(log(N)):

def get_validity_3(ts, times):
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # Binary search: first index whose value is >= the bound (ts must be sorted).
        index_start = numpy.searchsorted(ts, start)
        index_end = numpy.searchsorted(ts, end)
        validity[index_start:index_end] = 1
    return validity
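As a small illustration of what numpy.searchsorted returns, the sketch below uses a six-element array whose values are made up for the example. It also shows one case where the result differs from the argmax approach, which is worth keeping in mind if an interval can extend past the last timestamp.

```python
import numpy

ts = numpy.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])  # must be sorted ascending

# With the default side="left", searchsorted returns the first index i
# such that ts[i] >= value, found by binary search.
i_start = numpy.searchsorted(ts, 0.15)
i_end = numpy.searchsorted(ts, 0.35)
print(i_start, i_end)  # 2 4

# For a value larger than every element the two approaches differ:
# searchsorted returns len(ts), while argmax over an all-False array returns 0.
past_end = numpy.searchsorted(ts, 9.0)
argmax_result = numpy.argmax(ts >= 9.0)
print(past_end, argmax_result)  # 6 0
```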

Full code:

import time, numpy

def get_validity_1(ts, times):
    validity = numpy.zeros(len(ts))
    indexes = []
    for start, end in times:
        index_start = numpy.argmax(ts >= start)
        index_end = numpy.argmax(ts >= end)
        indexes.append([index_start, index_end])
    for start, end in indexes:
        validity[start:end] = 1
    return validity
    
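def get_validity_2(ts, times):
    # reduce ORs any number of per-interval masks; the script imports numpy, not np.
    return numpy.logical_or.reduce([numpy.logical_and(t1 <= ts, ts < t2) for t1, t2 in times])
    
def get_validity_3(ts, times):
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # Binary search for the first index whose value is >= the bound.
        index_start = numpy.searchsorted(ts, start)
        index_end = numpy.searchsorted(ts, end)
        validity[index_start:index_end] = 1
    return validity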
    

if __name__ == "__main__":
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])
    
    t0 = time.time()
    res_1 = get_validity_1(ts, times)
    t_1 = time.time() - t0
    
    t0 = time.time()
    res_2 = get_validity_2(ts, times)
    t_2 = time.time() - t0
    
    t0 = time.time()
    res_3 = get_validity_3(ts, times)
    t_3 = time.time() - t0
    
    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    print("t_3: " + str(t_3))
    
    assert (res_1 == res_2).all()
    assert (res_1 == res_3).all()

Output:

t_1: 0.4412200450897217
t_2: 0.3446168899536133
t_3: 0.14597129821777344
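One further possible variant, offered as a sketch rather than as part of the original answer: numpy.searchsorted accepts an array of values, so all interval bounds can be looked up in a single call, and a boolean output array uses one byte per element instead of eight for float64. The function name get_validity_vectorized and the sample inputs are assumptions made for this example.

```python
import numpy

def get_validity_vectorized(ts, times):
    # Hypothetical variant: assumes ts is sorted ascending and times has shape (k, 2).
    bounds = numpy.searchsorted(ts, numpy.asarray(times))  # (k, 2) array of indices
    validity = numpy.zeros(len(ts), dtype=bool)
    for i_start, i_end in bounds:
        validity[i_start:i_end] = True
    return validity

ts = numpy.linspace(0.0, 1.0, 1001)  # 0.000, 0.001, ..., 1.000
times = numpy.array([[0.0105, 0.1005], [0.5005, 0.8005]])
mask = get_validity_vectorized(ts, times)
print(mask.sum())  # 390
```

The remaining Python loop runs once per interval, not once per timestamp, so it stays cheap as long as the number of intervals is small.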

