python - NumPy, how to efficiently do element-wise operations involving other elements of an ndarray (without loops)
Problem description
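Hi :) I'm trying to refactor my code from loops to NumPy operations to make it faster. Any hints on how to do this? This code assigns a value to each element of a 2D ndarray based on the values of its neighbouring elements, and I couldn't find any answer for this specific case.
This is an implementation of the 6-neighbour method for finding saddle points in a photo, described here: https://documentcloud.adobe.com/link/track?uri=urn:aaid:scds:US:978c30d2-4888-491c-85c1-3949ea6166e9
It takes the differences between the current element and each of its neighbours. Then it counts the sign changes between consecutive differences, going around the neighbours in a circle; if there are 4 or more, the point is a saddle point.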
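To make the rule concrete, here is a minimal sketch of the per-pixel test for a single neighbourhood. The helper names count_circular_sign_changes and is_saddle are hypothetical (they are not from the question), and the sketch assumes the six neighbours are listed in order around the center. A sign change is counted only when two consecutive differences have strictly opposite signs, so a zero difference never produces a change, which matches the strict < 0 / > 0 tests in the loop below.

```python
import numpy as np

def count_circular_sign_changes(diffs):
    """Count strict sign changes between consecutive entries of a cyclic sequence.

    A pair containing a zero is not counted, because zero is neither
    positive nor negative.
    """
    d = np.asarray(diffs)
    # Pair each entry with the next one; np.roll wraps the last entry back to the first
    nxt = np.roll(d, -1)
    # Two values have strictly opposite signs exactly when their product is negative
    return int(np.count_nonzero(d * nxt < 0))

def is_saddle(center, neighbours, min_changes=4):
    # Hypothetical helper: neighbours must be given in order around the center
    diffs = np.asarray(neighbours) - center
    return count_circular_sign_changes(diffs) >= min_changes

# Differences 4, -4, 3, -3, 2, -2 alternate in sign all the way round: 6 changes
print(count_circular_sign_changes([4, -4, 3, -3, 2, -2]))  # → 6
print(is_saddle(5, [9, 1, 8, 2, 7, 3]))  # → True
print(is_saddle(5, [6, 7, 8, 1, 2, 3]))  # → False (only 2 changes)
```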
Is it possible to do this without loops?
Sorry if the question is unclear or badly formatted, this is my first question on StackOverflow.
import numpy as np

def findSaddlePoints6neibours(gray):
    gray = gray.astype(int)
    h = gray.shape[0]
    w = gray.shape[1]
    number = 0
    result = np.zeros((h, w))
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            center = gray[y][x]
            neiboursDiff = []
            neiboursDiff.append(gray[y-1][x] - center)
            neiboursDiff.append(gray[y-1][x+1] - center)
            neiboursDiff.append(gray[y][x+1] - center)
            neiboursDiff.append(gray[y+1][x] - center)
            neiboursDiff.append(gray[y+1][x-1] - center)
            neiboursDiff.append(gray[y][x-1] - center)
            changes = 0
            for i in range(5):
                if (neiboursDiff[i] < 0 and neiboursDiff[i+1] > 0) or (neiboursDiff[i] > 0 and neiboursDiff[i+1] < 0):
                    changes += 1
            if (neiboursDiff[0] < 0 and neiboursDiff[5] > 0) or (neiboursDiff[0] > 0 and neiboursDiff[5] < 0):
                changes += 1
            if changes > 3:
                number += 1
                result[y][x] = 1
    return [result, number]
Solution
Here is a vectorized solution:
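import numpy as np

def findSaddlePoints6neibours_vec(gray):
    gray = np.asarray(gray, dtype=int)
    # Interior pixels only (border pixels do not have all 6 neighbours)
    center = gray[1:-1, 1:-1]
    # Shifted slices: each one holds one neighbour of every interior pixel,
    # in the same order as in the loop version
    diffs = [
        gray[:-2, 1:-1],  # (y-1, x)
        gray[:-2, 2:],    # (y-1, x+1)
        gray[1:-1, 2:],   # (y, x+1)
        gray[2:, 1:-1],   # (y+1, x)
        gray[2:, :-2],    # (y+1, x-1)
        gray[1:-1, :-2],  # (y, x-1)
    ]
    # Repeat the first neighbour so the last-vs-first comparison closes the cycle
    diffs.append(diffs[0])
    diffs = np.stack(diffs)
    diffs -= center
    # Consecutive differences change sign exactly when their product is negative
    sign_changes = np.count_nonzero(diffs[:-1] * diffs[1:] < 0, axis=0)
    is_saddle = sign_changes > 3
    number = np.count_nonzero(is_saddle)
    # Pad back to the full image size, with zeros on the border
    result = np.pad(is_saddle, ((1, 1), (1, 1)), mode='constant').astype(int)
    return result, number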
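The six slices above all follow one pattern, which can be written as a function of the neighbour offset (dy, dx). In the sketch below, the helper name shifted and the OFFSETS list are illustrative, not part of the answer. It reproduces exactly the six slices used above and shows that they are views into gray, so no data is copied until np.stack builds the 7-layer array.

```python
import numpy as np

# Neighbour offsets (dy, dx) in the same cyclic order as the question's loop
OFFSETS = [(-1, 0), (-1, 1), (0, 1), (1, 0), (1, -1), (0, -1)]

def shifted(a, dy, dx):
    """Return a view of `a` whose [i, j] entry is the (dy, dx) neighbour
    of interior pixel (i + 1, j + 1). Valid for |dy| <= 1 and |dx| <= 1."""
    h, w = a.shape
    return a[1 + dy:h - 1 + dy, 1 + dx:w - 1 + dx]

a = np.arange(12).reshape(3, 4)
# Interior pixels of a are 5 and 6; their upper neighbours are 1 and 2
print(shifted(a, -1, 0))  # → [[1 2]]
# Basic slicing returns a view, not a copy
print(np.shares_memory(shifted(a, 0, 1), a))  # → True
```

For example, shifted(gray, -1, 0) is gray[:-2, 1:-1] and shifted(gray, 1, -1) is gray[2:, :-2], the first and fifth entries of the list in the answer.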
Quick test:
import numpy as np
# Make example input
np.random.seed(100)
gray = np.random.randint(-10, 10, size=(80, 100))
# The original function
result1, number1 = findSaddlePoints6neibours(gray)
# The vectorized function
result2, number2 = findSaddlePoints6neibours_vec(gray)
# Check results match
print(number1 == number2)
# True
print(np.all(result1 == result2))
# True
# Compare run times
%timeit findSaddlePoints6neibours(gray)
# 31.1 ms ± 682 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit findSaddlePoints6neibours_vec(gray)
# 247 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
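EDIT:
The downside of the function above is that it uses quite a lot of memory, since np.stack materializes seven shifted copies of the interior of the array. If you can use Numba, you can compile the function and make it faster still with parallelization: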
import numba as nb
import numpy as np

@nb.njit(parallel=True)
def findSaddlePoints6neibours_nb(gray):
    gray = gray.astype(np.int32)
    h = gray.shape[0]
    w = gray.shape[1]
    result = np.zeros((h, w), dtype=np.int32)
    # Only the outer loop runs in parallel; the inner loop is a plain range
    for y in nb.prange(1, h - 1):
        # Scratch buffer allocated inside the parallel loop, so parallel
        # iterations never write into the same array
        neiboursDiff = np.empty(7, dtype=np.int32)
        for x in range(1, w - 1):
            neiboursDiff[0] = gray[y - 1, x]
            neiboursDiff[1] = gray[y - 1, x + 1]
            neiboursDiff[2] = gray[y, x + 1]
            neiboursDiff[3] = gray[y + 1, x]
            neiboursDiff[4] = gray[y + 1, x - 1]
            neiboursDiff[5] = gray[y, x - 1]
            # Repeat the first neighbour to close the cycle
            neiboursDiff[6] = neiboursDiff[0]
            neiboursDiff -= gray[y, x]
            changes = np.sum(neiboursDiff[:-1] * neiboursDiff[1:] < 0)
            result[y, x] = 1 if changes > 3 else 0
    # Count saddle points after the parallel loop rather than in a shared counter
    number = result.sum()
    return result, number
Continuing the small benchmark above:
%timeit findSaddlePoints6neibours_nb(gray)
# 114 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
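If Numba is not available, the memory overhead of the stacked version can also be reduced in plain NumPy by comparing one pair of neighbours at a time and accumulating the count, instead of materializing all seven shifted differences at once. This is a sketch that has not been benchmarked here, and the name find_saddle_points_lowmem is hypothetical. It applies the same rule (strict sign changes around the cycle of six neighbours, more than 3 changes means a saddle point) and returns the same (result, number) pair, while keeping only a few (h-2)×(w-2) temporaries alive at any time.

```python
import numpy as np

def find_saddle_points_lowmem(gray):
    """Same rule as the stacked version, but sign changes are accumulated
    one neighbour pair at a time instead of stacking seven shifted arrays."""
    gray = np.asarray(gray, dtype=int)
    h, w = gray.shape
    center = gray[1:-1, 1:-1]
    # Neighbour offsets (dy, dx) in cyclic order around the center
    offsets = [(-1, 0), (-1, 1), (0, 1), (1, 0), (1, -1), (0, -1)]

    def diff(dy, dx):
        # Neighbour minus center for every interior pixel
        return gray[1 + dy:h - 1 + dy, 1 + dx:w - 1 + dx] - center

    # At most 6 changes per pixel, so int8 is enough
    changes = np.zeros(center.shape, dtype=np.int8)
    first = prev = diff(*offsets[0])
    for dy, dx in offsets[1:]:
        cur = diff(dy, dx)
        changes += (prev * cur) < 0
        prev = cur
    # Close the cycle: last neighbour against the first one
    changes += (prev * first) < 0

    is_saddle = changes > 3
    result = np.zeros((h, w), dtype=int)
    result[1:-1, 1:-1] = is_saddle
    return result, int(np.count_nonzero(is_saddle))
```

With the benchmark input above, its output can be compared with the other versions in the same way as in the quick test, e.g. np.all(find_saddle_points_lowmem(gray)[0] == result2).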