python - 在不计算整个句子的情况下估计给定句子的标记概率/logits
问题描述
我有这样一句话: "I like sitting in my new chair and _____ about life"
。
而且我有一组特定的令牌,例如["watch", "run", "think", "apple", "light"]
我想计算每个标记在该不完整句子中作为下一个单词出现的概率。希望我应该得到例如的概率"think"
更高"apple"
。
我正在使用 pytorch-transformers(特别是 GPT2LMHeadModel),一个可能的解决方案是使用每个标记评估整个句子的分数,但是当要评估的标记数量约为 100 或 1000 时,计算时间开始太长了。
必须可以只处理一次句子并以某种方式使用隐藏状态来计算标记集的概率,但我不知道该怎么做。
有任何想法吗?提前致谢
编辑:
实际代码如下所示(估计每次完整句子的概率)。对于每个句子,运行该方法大约需要 0.1 秒,score()
如果我想评估数千个单词,这会变成几个小时。
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
def score(sentence):
tokenize_input = tokenizer.tokenize(sentence)
tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
loss = model(tensor_input, labels=tensor_input)
return -loss[0].item()
candidates = ["watch", "run", "think", "apple", "light"]
sent_template = "I like sitting in my new chair and {} about life"
print({candidate: score(sent_template.format(candidate)) for candidate in candidates})
解决方案
您的示例产生了以下输出,并且在我的环境中完成了 282 名候选人大约需要 48.5 秒(我只进行了 3 次运行):
{'watch': -5.406847953796387
, 'run': -5.533411502838135
, 'think': -4.525279521942139
, 'apple': -6.158637046813965
, 'light': -5.835141658782959}
正如评论中提到的,我认为您可以使用过去的参数和快速标记器进行一些计算,如下面的评论示例所示:
import torch
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from torch.nn import CrossEntropyLoss
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
###We calculate the hidden_states and the past of the common left part of the sentence
past = "I like sitting in my new chair and"
past_tokenize_input = tokenizer.tokenize(past)
past_tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(past_tokenize_input)])
past_last_hidden_state, past = model.transformer(past_tensor_input)
def score(sentence, past, past_last_hidden_state, past_tensor_input):
tokenize_input = tokenizer.tokenize(sentence, )
tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
###the following code is slightly modified from https://github.com/huggingface/transformers/blob/09a2f40684f77e62d0fd8485fe9d2d610390453f/src/transformers/modeling_gpt2.py#L604
###now we calculate the right part of the sentence with the already calculated past
transformer_outputs = model.transformer(
tensor_input,
past=past,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
)
###and concatenate the output of with the hidden_state of the left part of the sentence
hidden_states = torch.cat((past_last_hidden_state, transformer_outputs[0]), dim=1)
###the following part is exactly the same as https://github.com/huggingface/transformers/blob/09a2f40684f77e62d0fd8485fe9d2d610390453f/src/transformers/modeling_gpt2.py#L604
lm_logits = model.lm_head(hidden_states)
labels_input = torch.cat((past_tensor_input, tensor_input), dim=1)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels_input[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return -loss.item()
candidates = ["watch", "run", "think", "apple", "light"]
sent_template = " {} about life"
print({candidate: score(sent_template.format(candidate), past, past_last_hidden_state, past_tensor_input) for candidate in candidates})
输出:
{'watch': -5.406846046447754
, 'run': -5.533413887023926
, 'think': -4.525280952453613
, 'apple': -6.158637046813965
, 'light': -5.835141181945801}
The runtime here was 40.5 seconds with 282 candidates (3 cycles again). You also see that I lost some precision.
Many thanks to patrickvonplaten who gave me a good explanation about the past implementation.
推荐阅读
- vue.js - 将路由器链接添加到 Vuetify 树视图
- html - Elastic Beanstalk 上对 index.html 的更改
- ios - 使用firebaseUI ios swift进行电话身份验证
- c - strcpy 在特定内存地址返回值加垃圾
- python - 执行 NER(命名实体识别)的过程 - NLP
- uwp - 在外围设备端断开连接后(重新)连接 - UWP
- python - 如何在 Pandas 中过滤带有条件空白字段的记录
- video - FFmpeg 丢失视频文件的转换
- tensorflow.js - 如何检查 tfjs 模型是否正确加载到浏览器
- r - 一个运行完全相同的数据,另一个不在 R 中类似 lm() 的函数中运行