python-3.x - 加速 Python 程序(自适应中值滤波器)
问题描述
嗨,有人可以改进这段代码吗?该代码是关于自适应中值滤波器的。在处理大图像时,代码非常慢。
import numpy as np
def padding(img,pad):
padded_img = np.zeros((img.shape[0]+2*pad,img.shape[1]+2*pad))
padded_img[pad:-pad,pad:-pad] = img
return padded_img
def AdaptiveMedianFilter(img,s=3,sMax=7):
if len(img.shape) == 3:
raise Exception ("Single channel image only")
H,W = img.shape
a = sMax//2
padded_img = padding(img,a)
f_img = np.zeros(padded_img.shape)
for i in range(a,H+a+1):
for j in range(a,W+a+1):
value = Lvl_A(padded_img,i,j,s,sMax)
f_img[i,j] = value
return f_img[a:-a,a:-a]
def Lvl_A(mat,x,y,s,sMax):
window = mat[x-(s//2):x+(s//2)+1,y-(s//2):y+(s//2)+1]
Zmin = np.min(window)
Zmed = np.median(window)
Zmax = np.max(window)
A1 = Zmed - Zmin
A2 = Zmed - Zmax
if A1 > 0 and A2 < 0:
return Lvl_B(window)
else:
s += 2
if s <= sMax:
return Lvl_A(mat,x,y,s,sMax)
else:
return Zmed
def Lvl_B(window):
h,w = window.shape
Zmin = np.min(window)
Zmed = np.median(window)
Zmax = np.max(window)
Zxy = window[h//2,w//2]
B1 = Zxy - Zmin
B2 = Zxy - Zmax
if B1 > 0 and B2 < 0 :
return Zxy
else:
return Zmed
有没有办法改进这段代码?例如使用矢量化滑动窗口?我不知道如何使用什么 numpy 函数。Ps:对于边界检查它使用填充,所以它不必检查越界。
解决方案
numbanjit
非常适合这种计算。与parallel=True
+混合prange
可以更快。此外,您可以将最小值、最大值和中值传递给Lvl_B
而不是像@CrisLuengo 指出的那样重新计算它们。
这是修改后的代码:
import numpy as np
from numba import njit,prange
@njit
def padding(img,pad):
padded_img = np.zeros((img.shape[0]+2*pad,img.shape[1]+2*pad))
padded_img[pad:-pad,pad:-pad] = img
return padded_img
@njit(parallel=True)
def AdaptiveMedianFilter(img,s=3,sMax=7):
if len(img.shape) == 3:
raise Exception ("Single channel image only")
H,W = img.shape
a = sMax//2
padded_img = padding(img,a)
f_img = np.zeros(padded_img.shape)
for i in prange(a,H+a+1):
for j in range(a,W+a+1):
value = Lvl_A(padded_img,i,j,s,sMax)
f_img[i,j] = value
return f_img[a:-a,a:-a]
@njit
def Lvl_A(mat,x,y,s,sMax):
window = mat[x-(s//2):x+(s//2)+1,y-(s//2):y+(s//2)+1]
Zmin = np.min(window)
Zmed = np.median(window)
Zmax = np.max(window)
A1 = Zmed - Zmin
A2 = Zmed - Zmax
if A1 > 0 and A2 < 0:
return Lvl_B(window, Zmin, Zmed, Zmax)
else:
s += 2
if s <= sMax:
return Lvl_A(mat,x,y,s,sMax)
else:
return Zmed
@njit
def Lvl_B(window, Zmin, Zmed, Zmax):
h,w = window.shape
Zxy = window[h//2,w//2]
B1 = Zxy - Zmin
B2 = Zxy - Zmax
if B1 > 0 and B2 < 0 :
return Zxy
else:
return Zmed
这段代码在我的机器上使用 256x256 随机图像快 500 倍。
请注意,由于(包含)编译时间,第一次调用不会快得多。
另请注意,由于滑动窗口共享许多值,因此无需重新计算每个值的最小/最大/中值,计算速度会更快(参见论文恒定时间中值滤波 (Perreault et al, 2007))。
推荐阅读
- r - 如何将 table 函数与长度为 2 的输入集合一起使用?
- pyqt5 - PyQt5 设置 qt_ntfs_permission_lookup
- javascript - API字符串中的箭头?
- php - 在 Apache2 上设置 php-fpm 的 /status 页面
- r - R Shiny中具有反应数据的动态图数
- iis - 如何在 iis 10 中的 Url 重定向中添加规则
- android - 是否可以使用 HSM 保护 Android KeyChain?
- javascript - 有没有办法让 JavaScript 在乘以数字时将字符串读取为字符串?
- amazon-web-services - 使用 for_each 遍历具有多个值的对象映射
- javascript - 如何在 JavaScript 中打印 BigInt 而不会丢失精度?