首页 > 解决方案 > 为什么我应该调用 BERT 模块实例而不是 forward 方法?

问题描述

我正在尝试使用变压器库中的 BERT 提取文本的矢量表示,并且偶然发现了“BERTModel”类文档的以下部分:

在此处输入图像描述

任何人都可以更详细地解释这一点吗?前向传递对我来说很直观(毕竟我试图获得最终的隐藏状态),我找不到任何关于“预处理和后处理”在这种情况下意味着什么的额外信息。

预先感谢!

标签: bert-language-modelhuggingface-transformers

解决方案


我认为这只是关于使用 PyTorchModule的一般建议。transformers模块是nn.Modules,它们需要一个方法forward。但是,不应model.forward()手动调用,而应调用model(). 原因是 PyTorch 在调用模块时会在后台做一些事情。您可以在源代码中找到它。

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

你会看到它forward在必要时被调用。


推荐阅读