首页 > 解决方案 > 如何在pytorch中为nn.Transformer编写一个前向钩子函数?

问题描述

我了解到前向挂钩函数的形式为hook_fn(m,x,y). m 指模型,x 指输入,y 指输出。我想为nn.Transformer.
但是,变压器层需要输入 src 和 tgt。例如,>>> out = transformer_model(src, tgt)。那么我怎样才能区分这些输入呢?

标签: pythonmachine-learningdeep-learningpytorchhook

解决方案


您的钩子将使用tuple s forx和调用您的回调函数y。如文档页面中所述torch.nn.Module.register_forward_hook它确实很好地解释了类型xy虽然)。

输入仅包含给模块的位置参数。关键字参数不会传递给钩子,而只会传递给转发。[...]。

model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)

定义你的回调:

def hook(module, x, y):
    print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')      
    src, tgt = x
  
    print(f'src: {src.shape}')
    print(f'tgt: {tgt.shape}')

钩到你的nn.Module

>>> model.register_forward_hook(hook)

做一个推断:

>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])

推荐阅读