首页 > 解决方案 > 如何将线程用于拥抱面变压器

问题描述

我正在尝试在线程上运行一个拥抱脸模型,模式正是“cardiffnlp/twitter-roberta-base-sentiment”。但与此同时,我只想要它的一个实例,因为它在时间方面确实很昂贵。

换句话说,我有多个 CSV 文件(几千个),每个文件都有大约 20k-30k 行,我希望所有这些文件中的每一行都由 huggingface 模型执行,正如您可能已经想象的那样我不想为每个线程实例化模型的原因(每个线程仅用于读取一行并将其写入数据库)。我的方法的问题是,当我运行代码时,Huggingface 模型会给我一个错误。

RuntimeError: 已经借用

你们中的任何人都可以帮助我了解如何解决它吗?

拥抱脸模型:

class EmotionDetection(object):
    def __init__(self, model_name="cardiffnlp/twitter-roberta-base-sentiment"):
        self.model_name = model_name
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True,
                                                     task="sentiment-analysis", device=0)

    def get_emotion_by_label(self, label: str):
        if label == "LABEL_0":
            return "negative"
        elif label == "LABEL_1":
            return "neutral"
        elif label == "LABEL_2":
            return "positive"
        else:
            print("SOMETHING IS WRONG")
            return ""

    def get_emotion(self, phrase):
        results = self.classifier(phrase)
        res = dict()
        for result in results:
            for emotion in result:
                res.update({self.get_emotion_by_label(emotion['label']): emotion['score']})
        return res

我生成数据库的代码:

class GenerateDbThread(object):
    def __init__(self, text: str, created_at: datetime.datetime, get_emotion_function, cursor, table_name):
        self.table_name = table_name

        self.text = text
        self.created_at = created_at
        emotions = get_emotion_function(self.text)

        self.pos = emotions['positive']
        self.neg = emotions['negative']
        self.neu = emotions['neutral']

        self.cursor = cursor

    def execute(self):
        query = f"INSERT INTO {self.table_name}(date, positive, negative, neutral, tweet) " \
                f"VALUES (datetime('{str(self.created_at)}'),{self.pos},{self.neg},{self.neu}, '{self.text}')"
        self.cursor.execute(query)
        self.cursor.commit()


def get_all_data_files_path(data_dir: str):
    return [f for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]


def run(file: str, table_name: str):
    df = pd.read_csv(os.path.join('data', file), delimiter=',')
    for index, row in df.iterrows():
        text = row['tweet']
        language = row['language']
        split_data = row['created_at'].split(" ")
        GTB_Time = f"{split_data[2]} {split_data[3]} {split_data[4]}"
        created_at = datetime.datetime.strptime(row['created_at'], f"%Y-%m-%d %H:%M:%S {GTB_Time}")
        if language == "en":
            GenerateDbThread(text, created_at, emotion_detector.get_emotion, cursor, table_name)


def init_db(db_name, table_name):
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()
    cursor.execute(f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
        uid INTEGER PRIMARY KEY AUTOINCREMENT,
        date DATETIME NOT NULL,
        positive REAL NOT NULL,
        negative REAL NOT NULL,
        neutral REAL NOT NULL,
        text TEXT NOT NULL
    )""")
    cursor.execute(f"CREATE INDEX IF NOT EXISTS ix_tweets_index ON {table_name}(uid)")
    cursor.close()


ex = ThreadPoolExecutor(max_workers=10)

files = get_all_data_files_path('data')

init_db("DB_NAME.db", "TABLE_NAME")

emotion_detector = EmotionDetection()
conn = sqlite3.connect("DB_NAME.db")
cursor = conn.cursor()

pbar = tqdm(total=len(files))
futures = [ex.submit(run, file, "TABLE_NAME") for file in files]
for future in futures:
    res = future.result()
    pbar.update(1)
pbar.close()

标签: pythonmultithreadingthreadpoolhuggingface-transformershuggingface-tokenizers

解决方案


推荐阅读