python - 输入到 nn.Linear(in_features=16*4*4, out_features=100)
问题描述
我正在使用以下模型在 MNIST 数据集上执行 CNN:
class ConvNet(nn.Module):
def __init__(self, mode):
super(ConvNet, self).__init__()
# Define various layers here, such as in the tutorial example
# self.conv1 = nn.Conv2D(...)
#First Convolution Kayer
#input size (28,28), output size = (24,24)
self.conv1 = nn.Conv2d(1,6,5)
self.reLU1 = nn.ReLU(inplace=True)
self.MaxPool1 = nn.MaxPool2d(kernel_size=2)
#Second Convolution Layer
#input size (12,12), output_size = (8,8)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
self.reLU2 = nn.ReLU(inplace=True)
self.MaxPool2 = nn.MaxPool2d(kernel_size=2)
#Affine operations
self.fc1 = nn.Linear(in_features = 16*4*4, out_features = 100)
self.sig = torch.nn.Sigmoid()
self.fc2 = nn.Linear(in_features=100, out_features=10)
我的前向传球定义如下。
def forward_pass(self, X):
#Conv Layer #1
X = self.conv1(X)
X = self.reLU1(X)
X = self.MaxPool1(X)
#Conv Layer #2
X = self.conv2(X)
X = self.reLU2(X)
X = self.MaxPool2(X)
print(Tensor.size(X))
#X = X.view()
X = self.fc1(X)
X = self.sig(X)
X = self.fc2(X)
return X
尝试将张量传递到完全连接的layer #1 (fc1)
. 这是由于in_features
我上一个卷积层的不匹配。
当我在全连接层之前打印出张量 X 的大小时,tensor.Size([10,16,4,4]).
谁能向我解释计算第一个全连接层的输入的正确方法是什么?
解决方案
您的分类器的输入是 shape (10, 16, 4, 4)
,丢弃与批量大小相对应的第一个维度,最终得到16*4*4
元素。所以这是正确的,但形状不是:在将张量馈送到 之前,您需要展平空间维度fc1
。您可以使用nn.Flatten
:
class ConvNet(nn.Module):
def __init__(self, mode):
super(ConvNet, self).__init__()
## layer definitions
self.flatten = nn.Flatten()
def forward(self, X):
## inference on CNN
X = self.flatten(X)
## inference on fully-connected layers
这是一个推理示例:
>>> model = ConvNet(mode=None)
>>> model(torch.rand(10, 1, 24, 24))
torch.Size([10, 10])
旁注请命名您的函数forward
而不是forward_pass
,这是标准做法。
推荐阅读
- javascript - 2个相同形式的模态,但1个按钮没有关闭
- python - 单击时如何禁用按钮?
- hadoop - 更改表列名 parquet 格式 Hadoop
- mysql - 更新非常大的数据库中的字符串
- c - 如何从相对二进制文件中删除记录?
- c++ - 有没有办法证明下面的第二个片段在函数声明之前插入了一个不可见的声明`struct S;`?
- z3 - 指定 Z3 sum 类型的构造函数的无量词方式
- php - 如何组成一个包含多个部分的视图,这些部分具有需要跨越的变量?
- angular - 尝试设置最后一个活动的 MatTab 时出现 ExpressionChangedAfterItHasBeenCheckedError?
- python - numpy中的迭代器协议