word-embedding - 文档嵌入的最后一层 longformer
问题描述
使用 longformer API 返回有限数量层的正确方法是什么?
与基本BERT中的这种情况不同,从返回类型中我不清楚如何仅获取最后 N 层。
所以,我运行这个:
from transformers import LongformerTokenizer, LongformerModel
text = "word " * 4096 # long document!
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
encoded_input = tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
output = model(**encoded_input)
我从返回中得到这样的尺寸:
>>> output[0].shape
torch.Size([1, 4096, 768])
>>> output[1].shape
torch.Size([1, 768])
你可以看到 [0] 的形状与我的令牌数量非常相似。我相信切片只会给我更少的令牌,而不仅仅是最后 N 层。
从下面的答案更新
即使要求output_hidden_states
,尺寸仍然看起来不正确,我不清楚如何将它们减少到矢量大小的一维嵌入。这就是我的意思:
encoded_input = tokenizer(text, return_tensors="pt", max_length=4096, truncation=True)
output = model(**encoded_input, output_hidden_states=True)
好的,现在让我们看看 output[2],元组的第三项:
>>> len(output[2])
13
假设我们想查看 13 层中的最后 3 层:
>>> [pair[0].shape for pair in output[2][-3:]]
[torch.Size([4096, 768]), torch.Size([4096, 768]), torch.Size([4096, 768])]
所以我们看到 13 层中的每一层都是成形的 (4096 x 768),它们看起来像:
>>> [pair[0] for pair in output[2][-3:]]
[tensor([[-0.1494, 0.0190, 0.0389, ..., -0.0470, 0.0259, 0.0609],
我们仍然有 4096 的大小,因为它对应于我的令牌计数:
>>> np.mean(np.stack([pair[0].detach().numpy() for pair in output[2][-3:]]), axis=0).shape
(4096, 768)
将这些平均在一起似乎不会给出有效的嵌入(用于余弦相似度等比较)。
解决方案
output
是一个由两个元素组成的元组:
- sequence_output(即最后一个编码器块)
- pooled_output
为了获取所有隐藏层,您需要将参数设置output_hidden_states
为true:
output = model(**encoded_input, output_hidden_states=True)
输出现在有 3 个元素,第三个元素包含嵌入层和每个编码层的输出。
推荐阅读
- angular - Angular http post响应处理
- javascript - 在使用它之前检测 window.close() 是否可以工作(JavaScript)
- android - 在 Firebase Recycler 的 OnBindViewHolder 中设置 View Holder 的高度
- neo4j - 使用密码查询时,螺栓连接会选择性地变慢,而基于 Web 的 GUI 总是很快
- r - 如何从函数参数定义 R data.table/data.frame 列?
- angular - Angular 2 重置 Observable 计时器
- html - 除非我在浏览器的开发工具中编辑文件,否则不会为规则解析 CSS 样式表
- xml - XSLT - text() 连接单词
- java - AuthenticationProvider 未进行身份验证
- javascript - 为什么'addEventListener'只适用于'for循环'