首页 > 解决方案 > 如何指定pytorch nn.Linear的输入维度?

问题描述

例如,我定义了如下所示的模型:

class Net(nn.module):
   def __init__():
      self.conv11 = nn.Conv2d(in_channel,out1_channel,3)
      self.conv12 = nn.Conv2d(...)
      self.conv13 = nn.Conv2d(...)
      self.conv14 = nn.Conv2d(...)
      ...
      #Here is the point
      flat = nn.Flatten()
      #I don't want to compute the size of data after flatten, but I need a linear layer.

      fc_out = nn.Linear(???,out_dim)

问题是线性层,我不想计算线性层输入的大小,但定义模型需要指定它。我怎么解决这个问题?

标签: pythondeep-learningpytorch

解决方案


我这样做的方法是输入一些任意值并让模型抛出错误。您将能够在错误描述中看到输入特征的数量。

还有其他方法可以做到这一点。您可以手动计算大小,并在每层旁边写下nn.Conv2d描述层输出的注释。在使用 之前nn.Flatten(),您将获得输出,只需将除 bacthsize 之外的所有维度相乘即可。结果值是nn.Linear()图层的输入特征数。

如果您不想这样做,可以尝试torchlayers。一个方便的包,可让您定义像 Keras 这样的 pytorch 模型。


推荐阅读