首页 > 解决方案 > nn.Linear层在pytorch中附加维度的应用

问题描述

pytorch中的全连接层(nn.Linear)如何应用于“附加维度”?该文档说,它可以应用于将张量连接(N,*,in_features)(N,*,out_features),其中N批次中的示例数量,因此它是无关紧要的,并且*是那些“附加”维度。这是否意味着使用附加维度中的所有可能切片训练单个层,或者是为每个切片训练单独的层或其他不同的层?

标签: pytorchtensor

解决方案


in_features * out_features学习的参数linear.weightout_features学习的参数linear.bias。你可以把nn.Linear工作想象成

  1. 将张量重塑为一些(N', in_features),其中N'是 的乘积N和用 描述的所有维度*input_2d = input.reshape(-1, in_features)
  2. 应用标准矩阵-矩阵乘法output_2d = linear.weight @ input_2d
  3. 添加偏差output_2d += linear.bias.reshape(1, in_features)(注意我们在所有N'维度上广播它)
  4. input除了最后一个之外,将输出重塑为具有与 相同的尺寸:output = output_2d.reshape(*input.shape[:-1], out_features)
  5. return output

因此,前导维度N与维度相同*。文档N明确让您知道输入必须至少是2d,但可以是任意多维。


推荐阅读