python - 在 Python 的二维数组中查找指定值所有出现位置的索引:有没有比 np.where() 更快的方法
问题描述
我正在开发一个 OpenCV 程序来计算视图中红色对象的中心。
在我正在处理的图像矩阵中,我已经对图像进行了过滤,使任何略带红色的东西都显示为 255 作为矩阵元素。
我正在使用 np.where() 找到所有 255 元素的位置
我正在使用 np.mean() 来计算 np.where() 获取的二维数组中索引 [0] 和 [1] 的平均值,以计算“中心”坐标
这是我的代码片段,它计算视图中所有红色对象的“质心”。
# Convert the filtered image to a NumPy array.
red_only_array = np.array(red_only)
# Index arrays (one per axis) of every element equal to 255.
# NOTE(review): this indexes red_only, not the red_only_array built above.
locations = np.where(red_only==255)
# Mean of the second-axis indices (x) and of the first-axis indices (y).
x_avg = np.mean(locations[1])
y_avg = np.mean(locations[0])
我对 cv2.connectedComponents() 标记出的 10 个不同对象的矩阵逐一重复这个过程,这一次是为了获得每个红色物体各自的质心。
以下是我正在使用的代码
# Label the connected regions of the mask with 8-connectivity; the first
# return value is discarded and only the label image is kept.
_,labels = cv2.connectedComponents(red_only_array, connectivity = 8)
b = np.matrix(labels)
# Object 1: boolean mask of label 1 -> uint8 -> nonzero entries set to 255
# -> index arrays of the 255 entries.
Obj1= b==1
Obj1 = np.uint8(Obj1)
Obj1[Obj1>0] =255
c1_max = np.where(Obj1 ==255)
# Centroid stored as [mean of second-axis indices, mean of first-axis indices].
centroid1 = np.array([np.mean(c1_max[1]),np.mean(c1_max[0])])
# Object 2: the same sequence of steps applied to label 2.
Obj2= b==2
Obj2 = np.uint8(Obj2)
Obj2[Obj2>0] =255
c2_max = np.where(Obj2 ==255)
centroid2 = np.array([np.mean(c2_max[1]),np.mean(c2_max[0])])
上述代码重复直到 b ==10
现在,在 8GB RAM 的 Raspberry Pi 4 上,我测得的延迟大约为 160 毫秒。我的同事认为 np.where() 是我代码中的瓶颈。有没有办法进一步优化呢?我的目标循环时间是 50 毫秒。
谢谢
解决方案
剖析1:
当有 100 个连通分量且图像大小为 2000x2000 时,找到质心是最慢的一步。整个程序在笔记本电脑上运行需要 28 秒。
from skimage import measure
from skimage import filters
import numpy as np
import cProfile
def make_blobs(size=256, n_blobs=12):
    """Build a reproducible boolean test image of blurred blobs.

    The NumPy global random state is seeded with 1, ``n_blobs ** 2`` random
    pixels of a ``size`` x ``size`` zero image are set to 1, the image is
    smoothed with a Gaussian filter of sigma ``size / (4 * n_blobs)``, and
    the result is thresholded at 0.7 times its mean.

    Args:
        size: Side length of the square image.
        n_blobs: Controls both the number of seed points (``n_blobs ** 2``)
            and the blur radius.

    Returns:
        Boolean array of shape ``(size, size)``; True where the blurred
        image exceeds the threshold.
    """
    np.random.seed(1)
    im = np.zeros((size, size))
    points = size * np.random.random((2, n_blobs ** 2))
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; using ``int`` directly works on every NumPy version.
    rows = points[0].astype(int)
    cols = points[1].astype(int)
    im[rows, cols] = 1
    im = filters.gaussian(im, sigma=size / (4. * n_blobs))
    blobs = im > 0.7 * im.mean()
    return blobs
def faster_centroid(img):
    """Return the value-weighted mean index of *img* along axis 0 and axis 1.

    For a boolean mask this is the mean axis-0 index and the mean axis-1
    index of the True elements. Note that the first returned value ("x")
    is the axis-0 index and the second ("y") is the axis-1 index.

    Args:
        img: Array with at least two dimensions (typically a boolean mask).

    Returns:
        Tuple ``(x_mean, y_mean)``.
    """
    # Dividing by the mean of img turns the full-array means below into
    # averages over the weighted (set) elements only.
    scale = 1 / np.mean(img)
    dims = img.shape
    axis0_index = np.arange(dims[0])[:, np.newaxis]
    axis1_index = np.arange(dims[1])[np.newaxis, :]
    x_mean = np.mean(img * axis0_index) * scale
    y_mean = np.mean(img * axis1_index) * scale
    return x_mean, y_mean
def label_blobs(blobs):
    """Label the connected regions of *blobs* twice and return both results.

    Args:
        blobs: Image passed unchanged to ``measure.label``.

    Returns:
        Tuple of the labelling made with default arguments and the
        labelling made with ``background=0``.
    """
    return measure.label(blobs), measure.label(blobs, background=0)
def find_all_centroids(all_labels):
    """Compute a centroid for every label value from 0 to the maximum label.

    Args:
        all_labels: Integer label image.

    Returns:
        List with one ``faster_centroid`` result per label value, in
        ascending label order (label 0 included).
    """
    highest_label = np.max(all_labels)
    return [
        faster_centroid(all_labels == label)
        for label in range(highest_label + 1)
    ]
def main():
    """Build a 2000x2000 test image, label it, and print labels and centroids."""
    test_image = make_blobs(2000, n_blobs=100)
    # Label connected regions of an integer array.
    all_labels, _blobs_labels = label_blobs(test_image)
    print(all_labels)
    print(find_all_centroids(all_labels))


# Profile the whole run and store the statistics in "results.cprofile".
cProfile.run("main()", "results.cprofile")
剖析 2:
[['<function get_centroids1 at 0x7f0027ba6280>', 1.3774937389971456],
['<function get_centroids2 at 0x7f0027ba6310>', 2.308947408993845],
['<function get_centroids3 at 0x7f0027ba63a0>', 0.695534451995627]]
4.262 main red3.py:61
├─ 2.245 get_centroids2 red3.py:36
│ ├─ 1.258 [self]
│ └─ 0.954 mean <__array_function__ internals>:2
│ [5 frames hidden] <__array_function__ internals>, numpy...
│ 0.954 ufunc.reduce <built-in>:0
├─ 1.334 get_centroids1 red3.py:25
│ ├─ 1.031 where <__array_function__ internals>:2
│ │ [3 frames hidden] <__array_function__ internals>, <buil...
│ │ 1.031 implement_array_function <built-in>:0
│ ├─ 0.188 [self]
│ └─ 0.080 mean <__array_function__ internals>:2
│ [5 frames hidden] <__array_function__ internals>, numpy...
└─ 0.683 get_centroids3 red3.py:51
├─ 0.333 <dictcomp> red3.py:57
├─ 0.233 nonzero <__array_function__ internals>:2
│ [5 frames hidden] <__array_function__ internals>, numpy...
└─ 0.048 [self]
from skimage import measure
from skimage import filters
import numpy as np
#import cProfile
from pyinstrument import Profiler
import timeit
def make_blobs(size=256, n_blobs=12):
    """Build a reproducible boolean test image of blurred blobs.

    The NumPy global random state is seeded with 1, ``n_blobs ** 2`` random
    pixels of a ``size`` x ``size`` zero image are set to 1, the image is
    smoothed with a Gaussian filter of sigma ``size / (4 * n_blobs)``, and
    the result is thresholded at 0.7 times its mean.

    Args:
        size: Side length of the square image.
        n_blobs: Controls both the number of seed points (``n_blobs ** 2``)
            and the blur radius.

    Returns:
        Boolean array of shape ``(size, size)``; True where the blurred
        image exceeds the threshold.
    """
    np.random.seed(1)
    im = np.zeros((size, size))
    points = size * np.random.random((2, n_blobs ** 2))
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; using ``int`` directly works on every NumPy version.
    rows = points[0].astype(int)
    cols = points[1].astype(int)
    im[rows, cols] = 1
    im = filters.gaussian(im, sigma=size / (4. * n_blobs))
    blobs = im > 0.7 * im.mean()
    return blobs
def label_blobs(blobs):
    """Label the connected regions of *blobs* twice and return both results.

    Args:
        blobs: Image passed unchanged to ``measure.label``.

    Returns:
        Tuple of the labelling made with default arguments and the
        labelling made with ``background=0``.
    """
    return measure.label(blobs), measure.label(blobs, background=0)
def get_centroids1(all_labels):
    """Centroids via ``np.where``: one index lookup per label value.

    Args:
        all_labels: Integer label image with at least two dimensions.

    Returns:
        List of ``[x_avg, y_avg]`` pairs for label values 0..max, where
        ``x_avg`` is the mean second-axis index and ``y_avg`` the mean
        first-axis index of the elements carrying that label.
    """
    label_count = np.max(all_labels) + 1
    centroids = []
    for label in range(label_count):
        hits = np.where(all_labels == label)
        # hits[1] holds second-axis (x) indices, hits[0] first-axis (y) indices.
        centroids.append([np.mean(hits[1]), np.mean(hits[0])])
    return centroids
def get_centroids2(all_labels):
    """Centroids via masked index means, without ``np.where``.

    Args:
        all_labels: Integer label image with at least two dimensions.

    Returns:
        List of ``[x_mean, y_mean]`` pairs for label values 0..max. Here
        ``x_mean`` is the mean first-axis index and ``y_mean`` the mean
        second-axis index of the elements carrying that label (the
        opposite axis order from ``[x_avg, y_avg]`` in the np.where variant).
    """
    label_count = np.max(all_labels) + 1
    centroids = []
    for label in range(label_count):
        mask = (all_labels == label)
        # Dividing by the mask mean converts whole-image means into means
        # over the masked elements only.
        scale = 1 / np.mean(mask)
        dims = mask.shape
        axis0_index = np.arange(dims[0])[:, np.newaxis]
        axis1_index = np.arange(dims[1])[np.newaxis, :]
        x_mean = np.mean(mask * axis0_index) * scale
        y_mean = np.mean(mask * axis1_index) * scale
        centroids.append([x_mean, y_mean])
    return centroids
def get_centroids3(x):
    """Group the coordinates of all nonzero elements of *x* by label value.

    Despite the name, this returns the member coordinates of each label,
    not their mean.

    Args:
        x: Two-dimensional integer label image; 0 is skipped because only
            nonzero elements are collected.

    Returns:
        Dict mapping each label ``k`` in ``1..max(x)`` to an ``(m, 2)``
        array of ``[axis0, axis1]`` coordinates of the elements equal to
        ``k``. An all-zero image yields an empty dict.
    """
    # https://stackoverflow.com/questions/32748950/
    n_blobs = np.max(x) + 1
    nz = np.nonzero(x)
    coords = np.column_stack(nz)
    nzvals = x[nz[0], nz[1]]
    # n_blobs already equals max(x) + 1, so it is the exclusive upper bound;
    # going one further would add a key for a label that cannot exist.
    res = {k: coords[nzvals == k] for k in range(1, n_blobs)}
    return res
def main():
    """Time each centroid implementation over 10 runs under pyinstrument.

    Prints the last result of every implementation, the list of
    ``[function repr, elapsed seconds]`` timings, and the profiler report.
    """
    candidates = [get_centroids1, get_centroids2, get_centroids3]
    test_image = make_blobs(2000, n_blobs=5)
    # Label connected regions of an integer array.
    all_labels, _blobs_labels = label_blobs(test_image)
    profiler = Profiler()
    profiler.start()
    timings = []
    for func in candidates:
        started = timeit.default_timer()
        for _ in range(10):
            result = func(all_labels)
        finished = timeit.default_timer()
        # Printing happens after the clock is stopped so it is not timed.
        print(result)
        timings.append([str(func), finished - started])
    print(timings)
    profiler.stop()
    print(profiler.output_text(unicode=True, color=True))


main()
推荐阅读
- android - macOS Mojave 上 Android 模拟器的授权问题?
- c++ - 文件读写 - 是否可以在不调用文件“导航”函数的情况下顺序和连续地 I/O 缓冲区的内容?
- jenkins-pipeline - SAP Cloud SDK Jenkins 管道 s4sdk-pipeline.groovy - 跳过生产部署步骤
- c++ - 获取数组长度的指针数学
- dictionary - SWI-Prolog YALL 与 dicts 冲突
- java - 没有任何堆栈跟踪的“本地主机上的服务器 Tomcat v9.0 服务器无法启动”
- python - xarray选择具有多维坐标的最近纬度/经度
- firebase - Flutter 未连接到 Cloud Firestore
- python - 为什么我用python制作的数学游戏打不开?
- redux-form - redux-form 可以接受来自 react-select 组件的多个值吗?