Gradient-based saliency of input words in a PyTorch model from the transformers library

Problem description

The following code is used to determine the influence of the input words on the most probable output unit.

import numpy as np
import torch

def _register_embedding_list_hook(model, embeddings_list):
    # Forward hook: store a copy of the word-embedding output (one row per token).
    def forward_hook(module, inputs, output):
        embeddings_list.append(output.squeeze(0).clone().cpu().detach().numpy())
    embedding_layer = model.bert.embeddings.word_embeddings
    handle = embedding_layer.register_forward_hook(forward_hook)
    return handle

def _register_embedding_gradient_hooks(model, embeddings_gradients):
    # Backward hook: store the gradient flowing into the word-embedding output.
    def hook_layers(module, grad_in, grad_out):
        embeddings_gradients.append(grad_out[0])
    embedding_layer = model.bert.embeddings.word_embeddings
    hook = embedding_layer.register_backward_hook(hook_layers)
    return hook

def saliency_map(model, input_ids, segment_ids, input_mask):
    torch.enable_grad()
    model.eval()
    embeddings_list = []
    handle = _register_embedding_list_hook(model, embeddings_list)
    embeddings_gradients = []
    hook = _register_embedding_gradient_hooks(model, embeddings_gradients)

    model.zero_grad()
    A = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
    pred_label_ids = np.argmax(A.logits[0].detach().numpy())
    # Backpropagate from the logit of the predicted class.
    A.logits[0][pred_label_ids].backward()
    handle.remove()
    hook.remove()

    # Gradient x input, summed over the hidden dimension, then L1-normalized over the sequence.
    saliency_grad = embeddings_gradients[0].detach().cpu().numpy()
    saliency_grad = np.sum(saliency_grad[0] * embeddings_list[0], axis=1)
    norm = np.linalg.norm(saliency_grad, ord=1)
    saliency_grad = [e / norm for e in saliency_grad]

    return saliency_grad
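
For reference, the score this code is meant to compute is the usual gradient × input saliency: with $e_{i,d}$ the embedding of token $i$ and $g_{i,d} = \partial\,\mathrm{logit}_{\hat c}/\partial e_{i,d}$ the gradient of the predicted-class logit, the per-token score is summed over the hidden dimension and then L1-normalized over the sequence (the notation here is mine, just to spell out what the NumPy lines do):

$$
s_i \;=\; \frac{\sum_d g_{i,d}\, e_{i,d}}{\sum_j \bigl|\sum_d g_{j,d}\, e_{j,d}\bigr|}
$$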

It is used in the following way (with a sentiment-analysis model):

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")

tokens = tokenizer('A really bad movie')

input_ids = torch.tensor([tokens['input_ids']], dtype=torch.long)
token_type_ids = torch.tensor([tokens['token_type_ids']], dtype=torch.long)
attention_ids = torch.tensor([tokens['attention_mask']], dtype=torch.long)

saliency_scores = saliency_map(model, input_ids, 
                                token_type_ids, 
                                attention_ids)
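
To make the numbers easier to read, the scores can be lined up with the word pieces from the tokenizer. This is just a small inspection sketch of my own (convert_ids_to_tokens is a standard transformers tokenizer method; token_strings is a name introduced here, and the token list in the comment is roughly what bert-base-uncased should produce):

token_strings = tokenizer.convert_ids_to_tokens(tokens['input_ids'])
# e.g. ['[CLS]', 'a', 'really', 'bad', 'movie', '[SEP]']
for token, score in zip(token_strings, saliency_scores):
    print(f"{token:>10s}  {score:+.4f}")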

But it produces the scores below, which make no sense for the tokens, because, for example, "bad" comes out with a negative influence on the predicted class (which is the negative class). What is wrong with this code?

[image: saliency scores for the example sentence]

Here are some more examples:

[images: saliency scores for two more example sentences]

Tags: pytorch, gradient, huggingface-transformers

Solution

