python - How to use threads with Hugging Face transformers
Problem description
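I am trying to run a Hugging Face model on threads; the model is "cardiffnlp/twitter-roberta-base-sentiment". At the same time, I want only a single instance of it, because instantiating it is really costly in terms of time.
In other words, I have multiple CSV files (several thousand), each with roughly 20k-30k rows, and I want every row of every file to be processed by the Hugging Face model. As you can probably imagine, that is why I don't want to instantiate a model for each thread (each thread would only read one row and write it to the database). The problem with my approach is that when I run the code, the Hugging Face model gives me an error.
RuntimeError: Already borrowed
Can anyone help me understand how to fix it?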
The Hugging Face model:
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          TextClassificationPipeline)


class EmotionDetection(object):
    def __init__(self, model_name="cardiffnlp/twitter-roberta-base-sentiment"):
        self.model_name = model_name
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True,
                                                     task="sentiment-analysis", device=0)

    def get_emotion_by_label(self, label: str):
        if label == "LABEL_0":
            return "negative"
        elif label == "LABEL_1":
            return "neutral"
        elif label == "LABEL_2":
            return "positive"
        else:
            print("SOMETHING IS WRONG")
            return ""

    def get_emotion(self, phrase):
        results = self.classifier(phrase)
        res = dict()
        for result in results:
            for emotion in result:
                res.update({self.get_emotion_by_label(emotion['label']): emotion['score']})
        return res
My code for generating the database:
import datetime
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from tqdm import tqdm


class GenerateDbThread(object):
    def __init__(self, text: str, created_at: datetime.datetime, get_emotion_function, cursor, table_name):
        self.table_name = table_name
        self.text = text
        self.created_at = created_at
        emotions = get_emotion_function(self.text)
        self.pos = emotions['positive']
        self.neg = emotions['negative']
        self.neu = emotions['neutral']
        self.cursor = cursor

    def execute(self):
        query = f"INSERT INTO {self.table_name}(date, positive, negative, neutral, tweet) " \
                f"VALUES (datetime('{str(self.created_at)}'),{self.pos},{self.neg},{self.neu}, '{self.text}')"
        self.cursor.execute(query)
        self.cursor.commit()


def get_all_data_files_path(data_dir: str):
    return [f for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]


def run(file: str, table_name: str):
    df = pd.read_csv(os.path.join('data', file), delimiter=',')
    for index, row in df.iterrows():
        text = row['tweet']
        language = row['language']
        split_data = row['created_at'].split(" ")
        GTB_Time = f"{split_data[2]} {split_data[3]} {split_data[4]}"
        created_at = datetime.datetime.strptime(row['created_at'], f"%Y-%m-%d %H:%M:%S {GTB_Time}")
        if language == "en":
            GenerateDbThread(text, created_at, emotion_detector.get_emotion, cursor, table_name)


def init_db(db_name, table_name):
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()
    cursor.execute(f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
        uid INTEGER PRIMARY KEY AUTOINCREMENT,
        date DATETIME NOT NULL,
        positive REAL NOT NULL,
        negative REAL NOT NULL,
        neutral REAL NOT NULL,
        text TEXT NOT NULL
    )""")
    cursor.execute(f"CREATE INDEX IF NOT EXISTS ix_tweets_index ON {table_name}(uid)")
    cursor.close()


ex = ThreadPoolExecutor(max_workers=10)
files = get_all_data_files_path('data')
init_db("DB_NAME.db", "TABLE_NAME")
emotion_detector = EmotionDetection()
conn = sqlite3.connect("DB_NAME.db")
cursor = conn.cursor()
pbar = tqdm(total=len(files))
futures = [ex.submit(run, file, "TABLE_NAME") for file in files]
for future in futures:
    res = future.result()
    pbar.update(1)
pbar.close()
Solution
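The page records no answer, so what follows is only a hedged sketch of one common approach, not a confirmed fix. "Already borrowed" is typically reported when a single fast (Rust-backed) tokenizer is called from several threads at once. One way to keep a single model instance is to serialize every call to the shared pipeline behind a lock. The names `LockedClassifier`, `FakeClassifier`, and `classify_all` are hypothetical, introduced here for illustration. `FakeClassifier` is a stand-in that mimics the failure so the sketch runs without downloading a model; in the question's code the wrapped object would be `self.classifier`.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class LockedClassifier:
    """Wrap a shared, non-thread-safe callable so only one thread uses it at a time."""

    def __init__(self, classifier):
        self._classifier = classifier
        self._lock = threading.Lock()

    def __call__(self, text):
        # Every thread waits here, so the wrapped object never sees concurrent calls.
        with self._lock:
            return self._classifier(text)


class FakeClassifier:
    """Hypothetical stand-in for the pipeline: raises if two threads enter at once."""

    def __init__(self):
        self._busy = threading.Lock()
        self.calls = 0

    def __call__(self, text):
        if not self._busy.acquire(blocking=False):
            raise RuntimeError("Already borrowed")
        try:
            time.sleep(0.001)  # simulate some work inside the critical section
            self.calls += 1
            return [[{"label": "LABEL_1", "score": 1.0}]]
        finally:
            self._busy.release()


def classify_all(classifier, texts, max_workers=10):
    """Run the classifier over all texts from a thread pool, preserving order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map re-raises any worker exception when the results are consumed.
        return list(pool.map(classifier, texts))


if __name__ == "__main__":
    safe = LockedClassifier(FakeClassifier())
    results = classify_all(safe, [f"tweet {i}" for i in range(50)])
    print(len(results))
```

With the lock in place the threads still overlap on CSV reading and database writes, but inference itself runs one call at a time. Another commonly suggested workaround is to load the slow tokenizer with `AutoTokenizer.from_pretrained(model_name, use_fast=False)`; whether either option is sufficient for this workload is an assumption to verify. Note also that a `sqlite3` connection created with the default `check_same_thread=True` cannot be used from threads other than the one that created it.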