首页 > 解决方案 > 如何在 ndim > 2 数组中使用`apply_along_axis`?

问题描述

我正在尝试对玩具数字数据集图像应用高斯滤波。它将图像存储在 (1797, 8, 8) 数组中。单独地,我可以使它工作,但是当我尝试将它应用于整个图像集时apply_along_axis,出现了问题。

这是核心示例:

from sklearn.datasets import load_digits
from scipy.ndimage.filters import gaussian_filter
images = load_digits().images

# Filter individually
individual = gaussian_filter(images[0], sigma=1, order=0)

# Use `apply_along_axis`
transformed = np.apply_along_axis(
    func1d=lambda x: gaussian_filter(x, sigma=1, order=0),
    axis=2,
    arr=images
)

# They produce different arrays
(transformed[0] != individual).all()
Out: True

我试图改变轴,但没有帮助。我还检查了,首先,简单地返回图像/平方值。在这些情况下,结果似乎相同。然而,应用点积再次产生不同的结果。

# Squared values
transformed = np.apply_along_axis(
    func1d=lambda x: x ** 2,
    axis=2,
    arr=images
)

# They produce the same arrays
(transformed[0] == images[0] ** 2).all()
Out: True

# Dot product
transformed = np.apply_along_axis(
    func1d=lambda x: np.dot(x, x),
    axis=2,
    arr=images
)

individual = np.dot(images[0], images[0])

# They produce different arrays
(transformed[0] != individual).all()
Out: True

我确定我误解了这些功能的工作方式。我究竟做错了什么?

更新:正如@hpaulj 在评论中指出的那样,func1d参数 inapply_along_axis只接受一维数组。看...

标签: pythonarraysnumpy

解决方案


推荐阅读