python - 使用 numpy.where 创建掩码
问题描述
我有一个用于创建掩码(布尔数组)的函数,希望让它运行得快得多。
def get_validity_1(ts, times):
    """Return a float mask marking the samples of ``ts`` that fall in ``times``.

    Reference (brute-force) version: every interval boundary is located with a
    full O(N) scan of ``ts``.

    Parameters
    ----------
    ts : numpy.ndarray
        1-D array of sample times. When it is sorted ascending, the result is
        1.0 exactly where ``start <= ts < end`` for some interval.
    times : iterable of (start, end) pairs
        Intervals to mark, e.g. a ``(k, 2)`` array.

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(ts)``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    n = len(ts)
    validity = numpy.zeros(n)

    def first_index_at_or_after(value):
        # argmax over a boolean array returns the first True index, but it also
        # returns 0 when nothing is True. That would turn an interval whose end
        # lies past ts[-1] into an empty slice, so use len(ts) in that case.
        hits = ts >= value
        return int(numpy.argmax(hits)) if hits.any() else n

    for start, end in times:
        validity[first_index_at_or_after(start):first_index_at_or_after(end)] = 1
    return validity
# Usage example: mask two intervals over 1e8 evenly spaced samples in [0, 1].
res_1 = get_validity_1(
    numpy.linspace(0.0, 1.0, 10**8),
    numpy.array([[0.01, 0.1], [0.5, 0.8]]),
)
我的问题是:如何用 numpy.where 条件来实现它?我尝试了这样写:
def get_validity_2(ts, times):
    """Return a float mask built with ``numpy.where``: 1.0 where ``start <= ts < end``.

    Each interval is tested against the whole array in a vectorized pass, so the
    cost is O(N * k) for ``k`` intervals.

    Parameters
    ----------
    ts : array-like
        1-D array of sample times.
    times : iterable of (start, end) pairs
        Half-open intervals ``[start, end)`` to mark.

    Returns
    -------
    numpy.ndarray
        float64 array shaped like ``ts``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    ts = numpy.asarray(ts)
    inside = numpy.zeros(ts.shape, dtype=bool)
    for start, end in times:
        # Elementwise test. ``&`` combines the two boolean arrays; Python's
        # ``and`` / chained ``a < ts < b`` would try to reduce an array to one bool.
        inside |= (ts >= start) & (ts < end)
    # Three-argument numpy.where selects elementwise. The one-argument form
    # returns a tuple of index arrays instead of a mask.
    return numpy.where(inside, 1.0, 0.0)
但 Python 抛出了异常:
ValueError: invalid number of arguments
输入满足以下前提条件(断言):
- ts[n-1] < ts[n]
- times[n][0] < times[n][1]
- times[n-1][1] < times[n][0]
下面是用于测试的完整脚本:
import time, numpy
def get_validity_1(ts, times):
    """Return a float mask marking the samples of ``ts`` that fall in ``times``.

    Reference (brute-force) version: every interval boundary is located with a
    full O(N) scan of ``ts``.

    Parameters
    ----------
    ts : numpy.ndarray
        1-D array of sample times. When it is sorted ascending, the result is
        1.0 exactly where ``start <= ts < end`` for some interval.
    times : iterable of (start, end) pairs
        Intervals to mark, e.g. a ``(k, 2)`` array.

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(ts)``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    n = len(ts)
    validity = numpy.zeros(n)

    def first_index_at_or_after(value):
        # argmax over a boolean array returns the first True index, but it also
        # returns 0 when nothing is True. That would turn an interval whose end
        # lies past ts[-1] into an empty slice, so use len(ts) in that case.
        hits = ts >= value
        return int(numpy.argmax(hits)) if hits.any() else n

    for start, end in times:
        validity[first_index_at_or_after(start):first_index_at_or_after(end)] = 1
    return validity
def get_validity_2(ts, times):
    """Return a float mask built with ``numpy.where``: 1.0 where ``start <= ts < end``.

    Each interval is tested against the whole array in a vectorized pass, so the
    cost is O(N * k) for ``k`` intervals.

    Parameters
    ----------
    ts : array-like
        1-D array of sample times.
    times : iterable of (start, end) pairs
        Half-open intervals ``[start, end)`` to mark.

    Returns
    -------
    numpy.ndarray
        float64 array shaped like ``ts``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    ts = numpy.asarray(ts)
    inside = numpy.zeros(ts.shape, dtype=bool)
    for start, end in times:
        # Elementwise test. ``&`` combines the two boolean arrays; Python's
        # ``and`` / chained ``a < ts < b`` would try to reduce an array to one bool.
        inside |= (ts >= start) & (ts < end)
    # Three-argument numpy.where selects elementwise. The one-argument form
    # returns a tuple of index arrays instead of a mask.
    return numpy.where(inside, 1.0, 0.0)
if __name__ == "__main__":
    # Benchmark size: 1e8 float64 samples (about 800 MB for ``ts`` alone).
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # time.perf_counter() is monotonic and high-resolution, so it is the
    # appropriate clock for measuring elapsed time; time.time() follows the
    # wall clock and can jump.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    # ``res_1 == res_2`` is an elementwise array; asserting it directly raises
    # "The truth value of an array ... is ambiguous". Compare whole arrays.
    assert numpy.array_equal(res_1, res_2)
    # NOTE: timing-based assertion; it can be flaky depending on machine load.
    assert t_1 > t_2
有谁知道如何完成函数 “get_validity_2” 并让它通过断言?或者有没有现成的库函数可以解决这个问题?
解决方案
# [start, end) per interval, OR-ed with .reduce so any number (>= 1) of intervals works.
numpy.logical_or.reduce([(t1 <= ts) & (ts < t2) for t1, t2 in times])
如果你想要像你尝试的那样写成一行,可以这样写。不过这仍然低效,因为它要以 O(N) 的代价对整个大数组逐元素比较。
由于 ts 已排序,这里有一种更快的方法:用二分查找以 O(log(N)) 找到起止索引(之后对切片的赋值仍与区间长度成正比,但只是一次连续写入):
def get_validity_3(ts, times):
    """Return a float mask marking ``start <= ts < end``, using binary search.

    Requires ``ts`` to be sorted ascending. Each boundary is found with
    ``numpy.searchsorted`` in O(log N); only the slice fill touches the data.

    Parameters
    ----------
    ts : numpy.ndarray
        1-D array of sample times, sorted ascending.
    times : iterable of (start, end) pairs
        Half-open intervals ``[start, end)`` to mark.

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(ts)``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # side="left" yields the first index i with ts[i] >= value, and len(ts)
        # when every element is smaller, so out-of-range ends are handled.
        index_start = numpy.searchsorted(ts, start, side="left")
        index_end = numpy.searchsorted(ts, end, side="left")
        validity[index_start:index_end] = 1
    return validity
完整代码:
import time, numpy
def get_validity_1(ts, times):
    """Return a float mask marking the samples of ``ts`` that fall in ``times``.

    Reference (brute-force) version: every interval boundary is located with a
    full O(N) scan of ``ts``.

    Parameters
    ----------
    ts : numpy.ndarray
        1-D array of sample times. When it is sorted ascending, the result is
        1.0 exactly where ``start <= ts < end`` for some interval.
    times : iterable of (start, end) pairs
        Intervals to mark, e.g. a ``(k, 2)`` array.

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(ts)``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    n = len(ts)
    validity = numpy.zeros(n)

    def first_index_at_or_after(value):
        # argmax over a boolean array returns the first True index, but it also
        # returns 0 when nothing is True. That would turn an interval whose end
        # lies past ts[-1] into an empty slice, so use len(ts) in that case.
        hits = ts >= value
        return int(numpy.argmax(hits)) if hits.any() else n

    for start, end in times:
        validity[first_index_at_or_after(start):first_index_at_or_after(end)] = 1
    return validity
def get_validity_2(ts, times):
    """Return a boolean mask that is True where ``t1 <= ts < t2`` for some interval.

    Vectorized comparison of the whole array against each interval (O(N * k)
    for ``k`` intervals). Intervals are half-open so the result agrees with the
    slice-based ``validity[start:end]`` construction.

    Parameters
    ----------
    ts : array-like
        1-D array of sample times.
    times : iterable of (t1, t2) pairs
        Intervals to mark.

    Returns
    -------
    numpy.ndarray
        bool array shaped like ``ts``.
    """
    ts = numpy.asarray(ts)
    mask = numpy.zeros(ts.shape, dtype=bool)
    # Accumulate in a loop rather than ``logical_or(*masks)``: the ufunc takes
    # exactly two inputs, so a third mask would be used as ``out`` and a single
    # mask would raise TypeError.
    for t1, t2 in times:
        mask |= (t1 <= ts) & (ts < t2)
    return mask
def get_validity_3(ts, times):
    """Return a float mask marking ``start <= ts < end``, using binary search.

    Requires ``ts`` to be sorted ascending. Each boundary is found with
    ``numpy.searchsorted`` in O(log N); only the slice fill touches the data.

    Parameters
    ----------
    ts : numpy.ndarray
        1-D array of sample times, sorted ascending.
    times : iterable of (start, end) pairs
        Half-open intervals ``[start, end)`` to mark.

    Returns
    -------
    numpy.ndarray
        float64 array of length ``len(ts)``: 1.0 inside the intervals, 0.0 elsewhere.
    """
    validity = numpy.zeros(len(ts))
    for start, end in times:
        # side="left" yields the first index i with ts[i] >= value, and len(ts)
        # when every element is smaller, so out-of-range ends are handled.
        index_start = numpy.searchsorted(ts, start, side="left")
        index_end = numpy.searchsorted(ts, end, side="left")
        validity[index_start:index_end] = 1
    return validity
if __name__ == "__main__":
    # Benchmark size: 1e8 float64 samples (about 800 MB for ``ts`` alone).
    n = 100000000
    ts = numpy.linspace(0, 1, n)
    times = numpy.array([[0.01, 0.1], [0.5, 0.8]])

    # time.perf_counter() is monotonic and high-resolution, so it is the
    # appropriate clock for measuring elapsed time; time.time() follows the
    # wall clock and can jump.
    t0 = time.perf_counter()
    res_1 = get_validity_1(ts, times)
    t_1 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_2 = get_validity_2(ts, times)
    t_2 = time.perf_counter() - t0

    t0 = time.perf_counter()
    res_3 = get_validity_3(ts, times)
    t_3 = time.perf_counter() - t0

    print("t_1: " + str(t_1))
    print("t_2: " + str(t_2))
    print("t_3: " + str(t_3))
    # array_equal also verifies that the shapes match, not just the values.
    assert numpy.array_equal(res_1, res_2)
    assert numpy.array_equal(res_1, res_3)
输出:
t_1: 0.4412200450897217
t_2: 0.3446168899536133
t_3: 0.14597129821777344
推荐阅读
- oracle - 如何在修改另一个字段时更新一个字段
- visual-studio-code - Visual Studio Code 面包屑导航折叠
- firebase - 检查 auth.id 是否在带有 firestore.rules 的资源数据映射中?
- powershell - 如何从外部获取 Powershell 执行策略
- r - 使用 dplyr 将数据框和列表转换为长格式
- python - 我可以让工具提示在 plotly-dash 中保持点击吗?
- python - 如何修复“无法连接到 'localhost:3306' 上的 MySQL 服务器”错误
- python - 命名空间包的正确设置是什么,无论是安装的还是作为源包的?
- java - 是否可以读取 txt 文件的内容并使用 System.out.print(); 显示它?
- typescript - 在 Typescript 中使用具有泛型参数的类型作为构造函数