pytorch - 使用 pytorch 验证卷积定理
问题描述
基本上这个定理表述如下:
F(f*g) = F(f)xF(g)
我知道这个定理,但我只是无法使用 pytorch 重现结果。
以下是可重现的代码:
import torch
import torch.nn.functional as F
# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)
# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)
# calculate F x G
f = f.squeeze()
g = g.squeeze()
# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1
f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))
f_new[1:6,1:6] = f
g_new[2:5,2:5] = g
F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)
print(FxG - F_fg)
这是print(FxG - F_fg)的结果
tensor([[[[[ 0.0000e+00, 0.0000e+00],
[ 4.1426e+02, 1.7270e+02],
[-3.6546e+01, 4.7600e+01],
[-1.0216e+01, -4.1198e+01],
[-1.0216e+01, -2.0223e+00],
[-3.6546e+01, -6.2804e+01],
[ 4.1426e+02, -1.1427e+02]],
...
[[ 4.1063e+02, -2.2347e+02],
[-7.6294e-06, 2.2817e+01],
[-1.9024e+01, -9.0105e+00],
[ 7.1708e+00, -4.1027e+00],
[-2.6739e+00, -1.1121e+01],
[ 8.8471e+00, 7.1710e+00],
[ 4.2528e+01, 9.7559e+01]]]]])
你可以看到差异并不总是0。
有人可以告诉我为什么以及如何正确执行此操作吗?
谢谢
解决方案
所以我仔细看看你到目前为止做了什么。我在您的代码中确定了三个错误来源。我将尝试在这里充分解决它们中的每一个问题。
1. 复杂算术
PyTorch 目前不支持复数乘法 (AFAIK)。FFT 操作只是返回一个具有实数和虚数维度的张量。我们需要显式编码复数乘法,而不是使用torch.mul
or运算符。*
(a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)
2.卷积的定义
CNN文献中经常使用的“卷积”的定义,实际上与讨论卷积定理时使用的定义不同。我不会详细介绍,但是理论上的定义是在滑动和乘法之前翻转内核。相反,pytorch、tensorflow、caffe等中的卷积操作……并没有做这种翻转。
考虑到这一点,我们可以在应用 FFT 之前简单地翻转g
(水平和垂直)。
3. 锚点位置
假设使用卷积定理时的锚点是填充的左上角g
。同样,我不会对此进行详细介绍,但这就是数学运算的方式。
第二点和第三点通过一个例子可能更容易理解。假设您使用了以下内容g
[1 2 3]
[4 5 6]
[7 8 9]
而不是g_new
成为
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 2 3 0 0]
[0 0 4 5 6 0 0]
[0 0 7 8 9 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
实际上应该是
[5 4 0 0 0 0 6]
[2 1 0 0 0 0 3]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[8 7 0 0 0 0 9]
我们垂直和水平翻转内核,然后应用循环移位,使内核的中心位于左上角。
我最终重写了您的大部分代码并对其进行了概括。最复杂的操作是g_new
正确定义。我决定使用网格网格和模运算来同时翻转和移动索引。如果这里的某些内容对您没有意义,请发表评论,我会尽力澄清。
import torch
import torch.nn.functional as F
def conv2d_pyt(f, g):
assert len(f.size()) == 2
assert len(g.size()) == 2
f_new = f.unsqueeze(0).unsqueeze(0)
g_new = g.unsqueeze(0).unsqueeze(0)
pad_y = (g.size(0) - 1) // 2
pad_x = (g.size(1) - 1) // 2
fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
return fcg[0, 0, :, :]
def conv2d_fft(f, g):
assert len(f.size()) == 2
assert len(g.size()) == 2
# in general not necessary that inputs are odd shaped but makes life easier
assert f.size(0) % 2 == 1
assert f.size(1) % 2 == 1
assert g.size(0) % 2 == 1
assert g.size(1) % 2 == 1
size_y = f.size(0) + g.size(0) - 1
size_x = f.size(1) + g.size(1) - 1
f_new = torch.zeros((size_y, size_x))
g_new = torch.zeros((size_y, size_x))
# copy f to center
f_pad_y = (f_new.size(0) - f.size(0)) // 2
f_pad_x = (f_new.size(1) - f.size(1)) // 2
f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f
# anchor of g is 0,0 (flip g and wrap circular)
g_center_y = g.size(0) // 2
g_center_x = g.size(1) // 2
g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
g_new[g_new_y, g_new_x] = g[g_y, g_x]
# take fft of both f and g
F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
# complex multiply
FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
FxG = torch.stack([FxG_real, FxG_imag], dim=2)
# inverse fft
fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)
# crop center before returning
return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]
# calculate f*g
f = torch.randn(11, 7)
g = torch.randn(5, 3)
fcg_pyt = conv2d_pyt(f, g)
fcg_fft = conv2d_fft(f, g)
avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
print('Average difference:', avg_diff)
这给了我
Average difference: 4.6866085767760524e-07
这非常接近于零。我们没有得到完全为零的原因仅仅是由于浮点错误。
推荐阅读
- asp.net-core - ASP.NET CORE 中的角色管理器
- json - MongoDB Compass->Bad PartialFilterExpression: SyntaxError: Unexpected token e in JSON at position 1
- r - 根据列表名称(即 A1)对嵌套列表进行排序
- javascript - JavaScript - EventListener 在外部 js 文件中不起作用
- ios - 如何将具有多个单元格的 UITableView 嵌入到 UITableViewCell 中并让 automaticDimension 正常工作?
- javascript - 从javascript中的对象获取多个不同的值
- javascript - 当我尝试使用 Firebase 注销用户时,为什么 Redux 会抛出 null 错误?
- spring-security - Spring Security - OAuth 2.0 客户端 - 客户端凭证授予
- python - 如何在熊猫中拆分 str.split() 列的输出?
- java - 在 OpenJDK11/OpenJFX 上运行较旧的 Swing 代码 - 如何修复 IllegalAccessError?