首页 > 解决方案 > 函数 torch.rfft() 和 torch.irfft() 的问题

问题描述

我需要运行为旧版 PyTorch 编写的代码。我有 1.9 版。从 1.8 版开始,PyTorch 引入了函数torch.fft.rfft()and torch.fft.irfft(),它们的工作方式与旧的torch.rfft()torch.irfft(). 我不知道如何替换这些函数,以便这段代码的工作方式与旧版本完全相同:

1.8版:

    fU = torch.rfft( u, 1, onesided=False)
    U = torch.irfft(fU, 1, onesided=False)
    torch.fft(x, x.ndim)

请帮我

标签: pythonpython-3.xpytorch

解决方案


正如您所提到的,torch.rfft()并且在输入和输出格式方面torch.irfft()存在差异。(参见https://github.com/pytorch/pytorch/wiki/The-torch.fft-module-in-PyTorch-1.7torch.fft.rfft()torch.fft.irfft()

根据PyTorch 的 GitHub 页面中的issue,我对二维操作进行了如下更改:

spectrum = torch.fft.rfft2(signal)
spectrum = torch.complex(spectrum[..., 0], spectrum[..., 1])
signal_recovered = torch.fft.irfft2(spectrum)

推荐阅读