nlp - 如何使用转换器模型改进代码以加快词嵌入?
问题描述
我需要为一堆具有不同语言模型的文档计算词嵌入。没问题,脚本做得很好,除了我正在使用笔记本电脑,没有 GPU 并且每个文本需要大约 1.5 秒来处理,这太长了(我有成千上万的文本要处理)。
以下是我使用 pytorch 和变压器库的方法:
import torch
from transformers import CamembertModel, CamembertTokenizer
docs = [text1, text2, ..., text20000]
tok = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertModel.from_pretrained('camembert-base', output_hidden_states=True)
# let try with a batch size of 64 documents
docids = [tok.encode(
doc, max_length=512, return_tensors='pt', pad_to_max_length=True) for doc in docs[:64]]
ids=torch.cat(tuple(docids))
device = 'cuda' if torch.cuda.is_available() else 'cpu' # cpu in my case...
model = model.to(device)
ids = ids.to(device)
model.eval()
with torch.no_grad():
out = model(input_ids=ids)
# 103s later...
有人对提高速度有任何想法或建议吗?
解决方案
我不认为有一种简单的方法可以在不使用 GPU 的情况下显着提高速度。
我能想到的一些方法包括Sentence-Transformers使用的智能批处理,您基本上将相似长度的输入排序在一起,以避免填充到完整的 512 个令牌限制。我不确定这会给您带来多少加速,但这是您可以在短时间内显着改善它的唯一方法。
否则,如果您可以访问Google colab,您也可以使用他们的 GPU 环境,如果处理可以在合理的时间内完成。
推荐阅读
- python - Python Gsheets insert_row 正在跳过数据
- gitlab-ci - 允许作业运行“重启”命令而不会导致失败
- blockchain - 将 esplora 配置为指向 bitcoind 服务器的 ip 地址
- swift - Swift - 为什么“do”语句之后的“while”语句会导致编译器错误?
- angular - 模块 '"*/node_modules/ngx-echarts/ngx-echarts"' 没有导出成员 'Ngx Echarts Service'
- typescript - 打字稿界面集
- opengl - DirectDraw Surface - 数据布局
- apache-spark - 有没有办法只打印 Spark SQL 输出?它在执行时打印有关环境的所有其他信息
- opengl - 标准化点云(用于在 OpenGL 中显示)的最有效方法是什么?
- jquery - 如何确定物品是否被丢弃