python - 在 Pytorch 中实现 LeNet
问题描述
对不起,如果这个问题非常基本。感觉网上资源挺丰富的,但是大部分都是半成品或者略过我想了解的细节。
我正在尝试用 Pytorch 实现 LeNet 进行练习。
https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
为什么这个例子和网上的许多例子,他们在init中定义了卷积层和fc层,而在forward中定义了子采样和激活函数?
将 torch.nn.functional 用于某些功能,而将 torch.nn 用于其他功能的目的是什么?例如,您有与 torch.nn 的卷积(https://pytorch.org/docs/stable/nn.html#conv1d)和与 torch.nn.functional 的卷积(https://pytorch.org/docs/stable/ nn.functional.html#conv1d)。为什么选择其中一个?
假设我想尝试不同的图像尺寸,例如 28x28 (MNIST)。本教程建议我调整 MNIST 的大小。有没有办法改变 LeNet 的值?如果我不改变它们会怎样?
num_flat_features 的目的是什么?如果你想扁平化特征,你不能只做 x = x.view(-1, 16*5*5) 吗?
解决方案
为什么这个例子和网上的许多例子,他们在init中定义了卷积层和fc层,而在forward中定义了子采样和激活函数?
任何具有可训练参数的层都应在__init__
. __init__
二次采样、某些激活、 dropout等。没有任何可训练的参数,因此可以torch.nn.functional
在forward
.
将 torch.nn.functional 用于某些功能,而将 torch.nn 用于其他功能的目的是什么?
这些torch.nn.functional
函数是在大多数torch.nn
层的核心使用的实际函数,它们调用 C++ 编译代码。例如nn.Conv2d
subclasses nn.Module
,任何包含可训练参数的自定义层或模型也应如此。该类处理注册参数并封装训练和测试所需的其他一些必要功能。在这期间forward
它实际上是nn.functional.conv2d
用来应用卷积操作的。nn.ReLU
正如第一个问题中提到的,当执行像 ReLU 这样的无参数操作时,使用类和nn.functional.relu
函数之间实际上没有区别。
提供它们的原因是它们提供了一些做非常规事情的自由。例如,在我前几天写的这个答案中,提供一个没有的解决方案nn.functional.conv2d
会很困难。
假设我想尝试不同的图像尺寸,例如 28x28 (MNIST)。本教程建议我调整 MNIST 的大小。有没有办法改变 LeNet 的值?如果我不改变它们会怎样?
没有明显的方法可以更改现有的、经过训练的模型以支持不同的图像大小。线性层输入的大小必然是固定的,模型中该点的特征数量通常由网络输入的大小决定。如果输入的大小与模型设计的大小不同,那么当数据进入线性层时,它将具有错误的元素数量并导致程序崩溃。一些模型可以处理一系列输入大小,通常使用类似nn.AdaptiveAvgPool2d
线性层之前的层,以确保线性层的输入形状始终相同。即便如此,如果输入图像尺寸太小,那么网络中的下采样和/或池化操作将导致特征图在某些时候消失,从而导致程序崩溃。
num_flat_features 的目的是什么?如果你想扁平化特征,你不能只做 x = x.view(-1, 16*5*5) 吗?
当你定义线性层时,你需要告诉它权重矩阵有多大。线性层的权重只是一个不受约束的矩阵(和偏置向量)。因此,权重矩阵的形状由输入形状决定,但在向前运行之前您不知道输入形状,因此在初始化模型时需要将其作为附加参数(或硬编码)提供。
要解决实际问题。是的,在forward
你可以简单地使用
x = x.view(-1, 16*5*5)
更好的是,使用
x = torch.flatten(x, start_dim=1)
本教程是在.flatten
函数添加到库之前编写的。作者实际上只是编写了他们自己的扁平化功能,无论x
. 这可能是因为您有一些可移植的代码,可以在您的模型中使用,而无需硬编码大小。从编程的角度来看,概括这些事情很好,因为这意味着如果您决定更改模型的一部分,您不必担心更改这些幻数(尽管这种担忧似乎没有扩展到初始化)。
推荐阅读
- vue.js - vue 文件 + vue cli + npm 的开发名称是什么
- html - 位置跨度彼此相邻
- types - 错误 - “表达式必须具有类类型”
- c# - 当我通过 foreach 循环时,它会将列表中的所有内容更改为添加的新数据?
- jsp - 在 scriptlet 中实例化类的问题
- r - 仅使用 data.table 操作以数据表的形式获取每个组的第一个到最后一个元素
- excel - 如果我清除单元格 B,则 VBA 清除单元格 C
- javascript - 按下按钮时添加点
- nlp - 损失函数负对数似然给出损失尽管完美的准确性
- mysql - nodejs mysql池连接只是空闲