python - 我的 PyTorch 模型中不同层的名称是什么?
问题描述
我在 PyTorch 中有以下模型:
UNet3D(
(encoders): ModuleList(
(0): Encoder(
(basic_module): DoubleConv(
(SingleConv1): SingleConv(
(groupnorm): GroupNorm(1, 5, eps=1e-05, affine=True)
(conv): Conv3d(5, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(ReLU): ReLU(inplace=True)
)
(SingleConv2): SingleConv(
(groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
(conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(ReLU): ReLU(inplace=True)
)
)
)
(1): Encoder(
(pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(basic_module): DoubleConv(
(SingleConv1): SingleConv(
(groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
(conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(ReLU): ReLU(inplace=True)
)
(SingleConv2): SingleConv(
(groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
(conv): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(ReLU): ReLU(inplace=True)
)
)
)
(2): Encoder(
(pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(basic_module): DoubleConv(
(SingleConv1): SingleConv(
(groupnorm): GroupNorm(8, 128, eps=1e-05, affine=True)
(conv): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(ReLU): ReLU(inplace=True)
)
(SingleConv2): SingleConv(
(groupnorm): GroupNorm(8, 128, eps=1e-05, affine=True)
(conv): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
(ReLU): ReLU(inplace=True)
)
)
)
有人可以告诉我这里不同层的名称是什么吗?例如,“编码器 (0)”?我想从模型中提取中间层输出,所以我需要每一层的名称。
解决方案
名称由括号内的内容给出。请注意,ModuleList 是一个列表类型,因此其中的模块由索引寻址。
pytorch 论坛通常对此非常有用。这篇文章描述了如何访问和更改层,但它同样适用于注册前向挂钩。例如,在你的情况下
model.encoders[0].basic_module
将在第一个编码器中为您提供 basic_module。
推荐阅读
- python - 使用 Python 将数据 Excel 导出到谷歌表格
- kubernetes - 是否可以为 Kubernetes Jobs 提供一个工作池以避免创建 Pod 时间?
- laravel - laravel如何在存储firebase中上传图像
- reactjs - 如何在我的子组件中配置路由?反应路由器dom
- python - Python - 似乎无法找到如何通过其路径获取文件
- scala - 设置 UDF 返回的 DecimalType 的精度
- lua - GtkSourceView等Gtk类型如何注册到lgi lua
- scala - Spark GraphX:如何从边列表中加载折线图?
- blazor - 在同一方法调用中使其可见时如何聚焦元素?
- firebase - Flutter - 如何在 Firebase Firestore 上以正确格式存储长值?