python - 如何在pytorch中为nn.Transformer编写一个前向钩子函数?
问题描述
我了解到前向挂钩函数的形式为hook_fn(m,x,y)
. m 指模型,x 指输入,y 指输出。我想为nn.Transformer
.
但是,变压器层需要输入 src 和 tgt。例如,>>> out = transformer_model(src, tgt)
。那么我怎样才能区分这些输入呢?
解决方案
您的钩子将使用tuple s forx
和调用您的回调函数y
。如文档页面中所述torch.nn.Module.register_forward_hook
(它确实很好地解释了类型x
和y
虽然)。
输入仅包含给模块的位置参数。关键字参数不会传递给钩子,而只会传递给转发。[...]。
model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand(10, 32, 512)
tgt = torch.rand(20, 32, 512)
定义你的回调:
def hook(module, x, y):
print(f'is tuple={isinstance(x, tuple)} - length={len(x)}')
src, tgt = x
print(f'src: {src.shape}')
print(f'tgt: {tgt.shape}')
钩到你的nn.Module
:
>>> model.register_forward_hook(hook)
做一个推断:
>>> out = model(src, tgt)
is tuple=True - length=2
src: torch.Size([10, 32, 512])
tgt: torch.Size([20, 32, 512])
推荐阅读
- c - STM32F103芯片大约每500ms保持一次复位
- python - 如何使用 python/bash 脚本在 Centos7 中列出所有早于“X”小时的文件?
- unix - Unix命令通过按文件类型和名称分组和排序列出所有文件
- mongodb - 如何对 MongoDB 聚合进行分组
- angular - 使用 Jasmine 和 await 测试基于 HTTP 的异步函数:预期有一个对条件“...”的匹配请求,但没有找到
- algorithm - 为什么 Dijkstras 算法的时间复杂度为 O(V^2)
- c++ - 如何在函数中增加 x 的值?
- drools - kContainer.newKieSession 在 drools 7.37 版本中返回 NULL
- vue.js - 如何在 NuxtJS 中设置全局 $axios 标头
- java - 如何从透明的 java swing table 中删除边距/边框