python - Numpy 沿轴卷积 2 个二维数组
问题描述
我有 2 个二维数组。我试图沿轴 1 进行卷积。np.convolve
没有提供axis
论点。这里的答案是使用 1 维数组与 1 维数组进行卷积np.apply_along_axis
。但它不能直接应用于我的用例。这里的问题没有答案。
MWE如下。
import numpy as np
a = np.random.randint(0, 5, (2, 5))
"""
a=
array([[4, 2, 0, 4, 3],
[2, 2, 2, 3, 1]])
"""
b = np.random.randint(0, 5, (2, 2))
"""
b=
array([[4, 3],
[4, 0]])
"""
# What I want
c = np.convolve(a, b, axis=1) # axis is not supported as an argument
"""
c=
array([[16, 20, 6, 16, 24, 9],
[ 8, 8, 8, 12, 4, 0]])
"""
我知道我可以使用 来做到这一点np.fft.fft
,但完成一件简单的事情似乎是不必要的步骤。有没有一种简单的方法可以做到这一点?谢谢。
解决方案
在正交维度上循环是否会被认为太丑陋?除非主要维度非常短,否则这不会增加太多开销。提前创建输出数组可确保不需要复制任何内存。
def convolvesecond(a, b):
N1, L1 = a.shape
N2, L2 = b.shape
if N1 != N2:
raise ValueError("Not compatible")
c = np.zeros((N1, L1 + L2 - 1), dtype=a.dtype)
for n in range(N1):
c[n,:] = np.convolve(a[n,:], b[n,:], 'full')
return c
对于一般情况(沿一对多维数组的第 k 个轴进行卷积),我会求助于我经常使用的一对辅助函数将多维问题转换为基本的 2d 情况:
def semiflatten(x, d=0):
'''SEMIFLATTEN - Permute and reshape an array to convenient matrix form
y, s = SEMIFLATTEN(x, d) permutes and reshapes the arbitrary array X so
that input dimension D (default: 0) becomes the second dimension of the
output, and all other dimensions (if any) are combined into the first
dimension of the output. The output is always 2-D, even if the input is
only 1-D.
If D<0, dimensions are counted from the end.
Return value S can be used to invert the operation using SEMIUNFLATTEN.
This is useful to facilitate looping over arrays with unknown shape.'''
x = np.array(x)
shp = x.shape
ndims = x.ndim
if d<0:
d = ndims + d
perm = list(range(ndims))
perm.pop(d)
perm.append(d)
y = np.transpose(x, perm)
# Y has the original D-th axis last, preceded by the other axes, in order
rest = np.array(shp, int)[perm[:-1]]
y = np.reshape(y, [np.prod(rest), y.shape[-1]])
return y, (d, rest)
def semiunflatten(y, s):
'''SEMIUNFLATTEN - Reverse the operation of SEMIFLATTEN
x = SEMIUNFLATTEN(y, s), where Y, S are as returned from SEMIFLATTEN,
reverses the reshaping and permutation.'''
d, rest = s
x = np.reshape(y, np.append(rest, y.shape[-1]))
perm = list(range(x.ndim))
perm.pop()
perm.insert(d, x.ndim-1)
x = np.transpose(x, perm)
return x
(请注意,reshape
不要transpose
创建副本,因此这些功能非常快。)
有了这些,通用形式可以写成:
def convolvealong(a, b, axis=-1):
a, S1 = semiflatten(a, axis)
b, S2 = semiflatten(b, axis)
c = convolvesecond(a, b)
return semiunflatten(c, S1)
推荐阅读
- scala - 写入仅具有键和 1 列模式的数据框时,将空值附加到 HBASE 表中的所有其他列以匹配键
- jquery - 如果在 Wordpress 定制器中添加/删除或更新小部件,请重新加载客户端浏览器?
- typescript - 如何添加 mobx 打字稿支持?
- python - 为大型矩阵找到最小二乘解的更快方法
- python - Python - 如何通过其元素的内容过滤列表?
- powershell - PowerShell,替换名称相同但小写的标头
- tensorflow - tensorflow tf.profile 计算的 FLOPs 是多少?
- android - 我很少发生无法从 onDestroy 访问 ViewModel
- sql - 我可以在 jOOQ SQL 执行器中使用命名参数吗?
- java - Jpanel错误重绘