python - 在pytorch中分解一批会导致不同的结果,为什么?
问题描述
我正在尝试在 pytorch 中进行批处理。在我下面的代码中,您可能会认为x
一批批大小为 2(每个样本是一个 10d 向量)。我x_sep
用来表示 中的第一个样本x
。
import torch
import torch.nn as nn
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.fc1 = nn.Linear(10,10)
def forward(self, x):
x = self.fc1(x)
return x
f = net()
x = torch.randn(2,10)
print(f(x[0])==f(x)[0])
理想情况下,f(x[0])==f(x)[0]
应该给出一个包含所有真实条目的张量。但我电脑上的输出是
tensor([False, False, True, True, False, False, False, False, True, False])
为什么会这样?是计算错误吗?或者它与如何在pytorch中实现批处理有关?
更新:我稍微简化了代码。问题还是一样的。
我的推理:
我相信f(x)[0]==f(x[0])
应该有它的所有条目True
,因为矩阵乘法是这样说的。让我们将其x
视为 2x10 矩阵,并将线性变换f()
视为由矩阵表示B
(暂时忽略偏差)。然后f(x)=xB
通过我们的符号。矩阵乘法告诉我们,xB
等于先把右边的两行分别乘以B
,然后再把两行放在一起。翻译回代码,是f(x[0])==f(x)[0]
和f(x[1])==f(x)[1]
。
即使我们考虑偏差,每一行都应该有相同的偏差,并且等式应该仍然成立。
另请注意,这里没有进行任何培训。因此,如何初始化权重应该无关紧要。
解决方案
TL;博士
在引擎盖下,它使用一个名为的函数,该函数addmm
具有一些优化,并且可能以稍微不同的方式相乘向量
我刚刚明白真正的问题是什么,我编辑了答案。
在尝试在我的机器上重现和调试它之后。我发现:
f(x)[0].detach().numpy()
>>>array([-0.5386441 , 0.4983463 , 0.07970242, 0.53507525, 0.71045876,
0.7791027 , 0.29027492, -0.07919329, -0.12045971, -0.9111403 ],
dtype=float32)
f(x[0]).detach().numpy()
>>>array([-0.5386441 , 0.49834624, 0.07970244, 0.53507525, 0.71045876,
0.7791027 , 0.29027495, -0.07919335, -0.12045971, -0.9111402 ],
dtype=float32)
f(x[0]).detach().numpy() == f(x)[0].detach().numpy()
>>>array([ True, False, False, True, True, True, False, False, True,
False])
如果你仔细观察,你会发现所有为 False 的索引,在第 5 个浮点数上有轻微的数字变化。
经过更多调试,我在它使用的线性函数中看到addmm
:
def linear(input, weight, bias=None):
if input.dim() == 2 and bias is not None:
# fused op is marginally faster
ret = torch.addmm(bias, input, weight.t())
else:
output = input.matmul(weight.t())
if bias is not None:
output += bias
ret = output
return ret
当 addmm 时addmm
,实现beta*mat + alpha*(mat1 @ mat2)
并且据说速度更快(例如,请参见此处)。
推荐阅读
- ssl - 如何获取服务器的 SSL 证书?
- python - 使用 Python 从 html 的标题标签中提取字符串
- c# - 可选地将 T 类型的参数传递给 Action
- python - 将字符串转换为日期并删除数据框一列中的非日期
- javascript - 在 Chrome 中进行 CSS 转换缩放时,如何停止边界模糊?
- swift - Mac 触控栏应用程序 - 添加快捷方式以将触控栏置于最前面
- git - 如何使用 git2-rs Rust 板条箱完成“git pull”?
- uwp - 应用程序的日历不再是 Windows 日历应用程序中新约会的选项
- java - Java 嵌套正则表达式组不捕获内部组
- sql - 从字符串吐出的 SQL 查询