python - PyFFTW 在多维数组上的性能
问题描述
我有一个 nD 数组,比如维度:(144, 522720),我需要计算它的 FFT。
PyFFTW
似乎比numpy
and慢scipy
,这不是预期的。
我在做一些明显错误的事情吗?
下面是我的代码
import numpy
import scipy
import pyfftw
import time
n1 = 144
n2 = 522720
loops = 2
pyfftw.config.NUM_THREADS = 4
pyfftw.config.PLANNER_EFFORT = 'FFTW_ESTIMATE'
# pyfftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'
Q_1 = pyfftw.empty_aligned([n1, n2], dtype='float64')
Q_2 = pyfftw.empty_aligned([n1, n2], dtype='complex_')
Q_ref = pyfftw.empty_aligned([n1, n2], dtype='complex_')
# repeat a few times to see if pyfft planner helps
for i in range(0,loops):
Q_1 = numpy.random.rand(n1,n2)
s1 = time.time()
Q_ref = numpy.fft.fft(Q_1, axis=0)
print('NUMPY - elapsed time: ', time.time() - s1, 's.')
s1 = time.time()
Q_2 = scipy.fft.fft(Q_1, axis=0)
print('SCIPY - elapsed time: ', time.time() - s1, 's.')
print('Equal = ', numpy.allclose(Q_2, Q_ref))
s1 = time.time()
Q_2 = pyfftw.interfaces.numpy_fft.fft(Q_1, axis=0)
print('PYFFTW NUMPY - elapsed time = ', time.time() - s1, 's.')
print('Equal = ', numpy.allclose(Q_2, Q_ref))
s1 = time.time()
Q_2 = pyfftw.interfaces.scipy_fftpack.fft(Q_1, axis=0)
print('PYFFTW SCIPY - elapsed time = ', time.time() - s1, 's.')
print('Equal = ', numpy.allclose(Q_2, Q_ref))
s1 = time.time()
fft_object = pyfftw.builders.fft(Q_1, axis=0)
Q_2 = fft_object()
print('FFTW PURE Elapsed time = ', time.time() - s1, 's')
print('Equal = ', numpy.allclose(Q_2, Q_ref))
解决方案
首先,如果您在主循环之前打开缓存,接口将按预期工作:
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(30)
有趣的是,尽管应该存储智慧pyfftw
,但当缓存关闭时,对象的构建仍然相当缓慢。不管怎样,这正是缓存的目的。在您的情况下,您需要使缓存保持活动时间很长,因为您的循环很长。
fft_object
其次,在最终测试中包含构建时间是不公平的比较。如果您将其移到计时器之外,则调用fft_object
是更好的措施。
第三,有趣的是,即使打开了缓存,调用 tonumpy_fft
也比调用scipy_fft
. 由于代码路径没有明显差异,我建议这是缓存问题。这是一种timeit
寻求缓解的问题。这是我提出的更有意义的时序代码:
import numpy
import scipy
import pyfftw
import timeit
n1 = 144
n2 = 522720
pyfftw.config.NUM_THREADS = 4
pyfftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'
Q_1 = pyfftw.empty_aligned([n1, n2], dtype='float64')
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(30)
times = timeit.repeat(lambda: numpy.fft.fft(Q_1, axis=0), repeat=5, number=1)
print('NUMPY fastest time = ', min(times))
times = timeit.repeat(lambda: scipy.fft.fft(Q_1, axis=0), repeat=5, number=1)
print('SCIPY fastest time = ', min(times))
times = timeit.repeat(
lambda: pyfftw.interfaces.numpy_fft.fft(Q_1, axis=0), repeat=5, number=1)
print('PYFFTW NUMPY fastest time = ', min(times))
times = timeit.repeat(
lambda: pyfftw.interfaces.scipy_fftpack.fft(Q_1, axis=0), repeat=5, number=1)
print('PYFFTW SCIPY fastest time = ', min(times))
fft_object = pyfftw.builders.fft(Q_1, axis=0)
times = timeit.repeat(lambda: fft_object(Q_1), repeat=5, number=1)
print('FFTW PURE fastest time = ', min(times))
在我的机器上,这给出了如下输出:
NUMPY fastest time = 0.6622681759763509
SCIPY fastest time = 0.6572431400418282
PYFFTW NUMPY fastest time = 0.4003451430471614
PYFFTW SCIPY fastest time = 0.40362057799939066
FFTW PURE fastest time = 0.324020683998242
如果您不强制将输入数组复制为复杂数据类型,则可以做得更好,方法是将其更改Q_1
为complex128
:
NUMPY fastest time = 0.6483533839927986
SCIPY fastest time = 0.847397351055406
PYFFTW NUMPY fastest time = 0.3237176960101351
PYFFTW SCIPY fastest time = 0.3199474769644439
FFTW PURE fastest time = 0.2546963169006631
这种有趣scipy
的减速是可以重复的。
也就是说,如果您的输入是真实的,那么您应该进行真正的变换(使用 50% 以上的加速pyfftw
)并处理生成的复杂输出。
这个例子的有趣之处在于(我认为)缓存在结果中的重要性(我认为这就是为什么切换到真正的转换在加快速度方面如此有效)。当您使用将数组大小更改为 524288(下一个 2 的幂,您认为这可能会加快速度,但不会显着减慢速度)时,您也会看到一些戏剧性的东西。在这种情况下,一切都会慢下来,scipy
尤其是。对我来说,它对scipy
缓存更敏感,这可以解释将输入更改为减速的complex128
原因(尽管 522720 对于 FFTing 来说是一个相当不错的数字,所以也许我们应该期待减速)。
最后,如果速度次于准确性,则始终可以使用 32 位浮点数作为数据类型。如果将其与进行真正的变换相结合,则与numpy
上面给出的初始最佳值相比,您将获得 10 倍以上的加速:
PYFFTW NUMPY fastest time = 0.09026529802940786
PYFFTW SCIPY fastest time = 0.1701313250232488
FFTW PURE fastest time = 0.06202622700948268
(numpy 和 scipy 变化不大,因为我认为它们在内部使用 64 位浮点数)。
编辑:我忘记了 Scipyfftpack
真正的 FFT 有一个奇怪的输出结构,它的pyfftw
复制速度有些慢。这在新的 FFT 模块中变得更加明智。
新的 FFT 接口在 pyFFTW 中实现,应该是首选。不幸的是,正在重建的文档存在问题,因此文档已经过时了很长时间并且没有显示新界面 - 希望现在已经修复。
推荐阅读
- javascript - 避免 TO-DO 列表应用程序中的空文本输入
- node.js - 安装 webdriverio 节点时找不到依赖项
- salesforce - Lightning DataTable 中的 keyField 问题
- asp.net - 带有 aspx 扩展名的路由 url 到 mvc 路由
- c++ - 幕后的 C++ 模块
- django - Django - 管理页面登录也登录到用户身份验证
- ios - iOS - 如何检查手电筒/手电筒/闪光灯的开/关状态
- java - 如何在 servlet 中从 Web 服务获取数据
- java - 使用java在json中使用多个键获取键值
- java - 使用 Maven 安装和部署上传 airtifact 到 nexus