首页 > 解决方案 > 将不在字典中的单词替换为

问题描述

我有一本字典(10k+ 字)和一篇文章(10M+ 字)。我想将字典中没有出现的所有单词替换为<unk>.

我试过str.maketrans了,但它的键应该是一个字符。

然后我尝试了这个https://stackoverflow.com/a/40348578/5634636但正则表达式非常慢。

有更好的解决方案吗?

标签: pythonstring

解决方案


我们将问题分为两部分:

  • 给定单词列表passage,找​​到索引 ipassage[i]不在另一个单词列表中dictionary
  • 然后简单地放在<unk>这些索引上。

1 中需要做的主要工作。为此,我们首先将字符串列表转换为 2D numpy 数组,以便我们可以有效地执行操作。此外,我们对二进制搜索中所需的字典进行排序。此外,我们用 0 填充字典以具有与 相同的列数passage_enc

# assume passage, dictionary are initially lists of words
passage = np.array(passage)  # np array of dtype='<U4'
passage_enc = passage.view(np.uint8).reshape(-1, passage.itemsize)[:, ::4]  # 2D np array of size len(passage) x max(len(x) for x in passage), with ords of chars

dictionary = np.array(dictionary)
dictionary = np.sort(dictionary)    
dictionary_enc = dictionary.view(np.uint8).reshape(-1, dictionary.itemsize)[:, ::4]
pad = np.zeros((len(dictionary), passage_enc.shape[1] - dictionary_enc.shape[1]))    
dictionary_enc = np.hstack([dictionary_enc, pad]).astype(np.uint8)

然后我们只是遍历段落,并检查字符串(现在是一个数组)是否在字典中。它需要 O(n * m), n, m 分别是段落和字典的大小。但是,我们可以通过事先对字典进行排序并在其中进行二进制搜索来改进这一点。因此,它变为 O(n * logm)。

此外,我们 JIT 编译代码以使其更快。下面,我使用numba

import numba as nb
import numpy as np

@nb.njit(cache=True)  # cache as being used multiple times
def smaller(a, b):
    n = len(a)
    i = 0
    while(i<n and a[i] == b[i]):
        i+=1
    if(i==n):
        return False
    return a[i] < b[i]

@nb.njit(cache=True)
def bin_index(array, item):
    first, last = 0, len(array) - 1

    while first <= last:
        mid = (first + last) // 2
        if np.all(array[mid] == item):
            return mid

        if smaller(item, array[mid]):
            last = mid - 1
        else:
            first = mid + 1

    return -1

@nb.njit(cache=True)
def replace(dictionary, passage):
    unknown_indices = []
    n = len(passage)
    for i in range(n):
        ind = bin_index(dictionary, passage[i])
        if(ind == -1):
            unknown_indices.append(i)
    return unknown_indices

检查样本数据

import nltk
emma = nltk.corpus.gutenberg.words('austen-emma.txt')
passage = np.array(emma)
passage = np.repeat(passage, 50)  # bloat coprus to have around 10mil words
passage_enc = passage.view(np.uint8).reshape(-1, passage.itemsize)[:, ::4]

persuasion = nltk.corpus.gutenberg.words('austen-persuasion.txt')
dictionary = np.array(persuasion)
dictionary = np.sort(dictionary)  # sort for binary search

dictionary_enc = dictionary.view(np.uint8).reshape(-1, dictionary.itemsize)[:, ::4]
pad = np.zeros((len(dictionary), passage_enc.shape[1] - dictionary_enc.shape[1]))

dictionary_enc = np.hstack([dictionary_enc, pad]).astype(np.uint8)  # pad with zeros so as to make dictionary_enc and passage_enc of same shape[1]

出于计时目的,段落和字典的大小最终都符合 OP 要求的顺序。这个电话:

unknown_indices = replace(dictionary_enc, passage_enc)

在我的 8 核 16G 系统上耗时 17.028s(包括预处理时间,显然不包括加载语料库的时间)。

然后,很简单:

passage[unknown_indices] = "<unk>"

PS:我想,我们可以通过parallel=True在 njit 装饰器中使用来获得更快的速度replace。我遇到了一些奇怪的错误,如果我能够解决它,我会进行编辑。


推荐阅读