首页 > 解决方案 > 如何在钩子函数中检索图层的名称?

问题描述

我有一个神经网络

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.trunk = nn.ModuleList()
    self.trunk.add_module('conv1', nn.Conv2d(3, 10, 3))

    self.classifier = nn.Linear(30, 2)
  def forward(self, x):
    out = self.classifier(self.trunk.conv1(x))
    return out

model = ConvNet()

我注册了前向钩子

def hook(module, input, output):
    print(module, input[0].shape, output.shape)

x =  model.trunk.conv1.register_forward_hook(hook)

如何在钩子函数中检索“conv1”层的名称,module._get_namereturns Conv2dmodule.__class__returns <class 'torch.nn.modules.conv.Conv2d'>,如何获得“conv1”?

标签: pythonpytorch

解决方案


推荐阅读