python - 我的 PyTorch 转发功能可以做额外的操作吗?
问题描述
通常,一个forward
函数将一堆层串在一起并返回最后一层的输出。在返回之前,我可以在最后一层之后做一些额外的处理吗?例如,一些标量乘法和通过.view
?
我知道 autograd 以某种方式计算出渐变。所以我不知道我的额外处理是否会以某种方式搞砸。谢谢。
解决方案
pytorch通过张量的计算图而不是函数来跟踪梯度。只要您的张量具有属性并且它们不是,您就可以(几乎)做任何您喜欢的事情并且仍然能够进行反向传播。
只要您使用 pytorch 的操作(例如,此处和此处列出的操作),您应该没问题。requires_grad=True
grad
None
有关更多信息,请参阅此。
例如(取自torchvision 的 VGG 实现):
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
# ...
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1) # <-- what you were asking about
x = self.classifier(x)
return x
在torchvision 的 ResNet 实现中可以看到一个更复杂的例子:
class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
# ...
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None: # <-- conditional execution!
identity = self.downsample(x)
out += identity # <-- inplace operations
out = self.relu(out)
return out
推荐阅读
- python - 获取角度范围 0 - 2*pi - python
- python - rpy2安装导致UnsatisfiableError与blas冲突
- context-free-grammar - 给定语言的上下文无关语法和 pda
- android - 使用 Chromecast SDK 确定是否在接收器上为 android 启用了隐藏式字幕
- gitlab - 如何在项目管道中获取顶级组名称?
- events - Electron(v5-v7) webview 不接受击键
- python - 为所有切片器选项导出到 PDF Power BI 仪表板
- django - 如何在表单中动态添加字段?- 实现一个 django 应用程序来跟踪订单状态
- c - C中的char数组到浮点转换错误
- java - 试图用tomcat通过球衣中的希伯来语字符