首页 > 解决方案 > Cython 字符串支持

问题描述

我正在尝试优化一些代码。我已经设法使用 Numpy 和 Numba 优化了我的大部分项目,但是还有一些我无法使用这些工具优化的剩余字符串处理代码。因此,我想尝试使用 Cython 优化这部分。

此处的代码采用运行长度编码的字符串(一个字母,可选地后跟一个数字,指示该字母重复多少次)并将其扩展。然后,它使用字典查找将扩展的字符串转换为 0 和 1 的数组,以将字母与 0 和 1 的序列匹配。

是否可以使用 Cython 优化此代码?

import numpy as np
import re

vector_list = ["A22gA5BA35QA17gACA3QA7gA9IAAgEIA3wA3gCAAME@EACRHAQAAQBACIRAADQAIA3wAQEE}rm@QfpT}/Mp-.n?",
                "A64IA13CA5RA13wAABA5EAECA5EA4CEgEAABGCAAgAyAABolBCA3WA4GADkBOA?QQgCIECmth.n?"]


_base64chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz@}]^+-*/?,._"
_bin2base64 = {"{:06b}".format(i): base64char for i, base64char in enumerate(_base64chars)}
_base642bin = {v: k for k, v in _bin2base64.items()}

_n_vector_ranks_only = np.arange(1023,-1,-1)


def _decompress_get(data):
    for match in re.finditer(r"(?P<char>.)((?P<count>\d+))?", data):
        if not match.group("count"): yield match.group("char")
        else: yield match.group("char") * int(match.group("count"))


def _n_apply_weights(vector):
    return np.multiply(vector, _n_vector_ranks_only)

def n_decompress(compressed_vector):
    decompressed_b64 = "".join(_decompress_get(compressed_vector))
    vectorized = "".join(_base642bin[c] for c in decompressed_b64)[:-2]
    as_binary = np.fromiter(vectorized, int)
    return as_binary


def test(x, y):
    if len(x) != 1024:
        x = n_decompress(x)
    vector_a = _n_apply_weights(x)
    if len(y) != 1024:
        y = n_decompress(y)
    vector_b = _n_apply_weights(y)
    maxPQ = np.sum(np.maximum(vector_a, vector_b))
    return np.sum(np.minimum(vector_a, vector_b))/maxPQ

v1 = vector_list[0]
v2= vector_list[1]
print(test(v1, v2))

标签: pythonnumpycython

解决方案


单独使用 Numpy 可以很好地加快问题的第二部分(通过字典查找)。我已经通过索引到 Numpy 数组来替换字典查找。

我在开始时生成 Numpy 数组。一个技巧是意识到可以使用 . 将字母转换为表示它们的基础数字ord。对于 ASCII 字符串,它始终介于 0 和 127 之间:

_base642bin_array = np.zeros((128,),dtype=np.uint8)
for i in range(len(_base64chars)):
    _base642bin_array[ord(_base64chars[i])] = i

n_decompress使用内置的 numpy 函数在函数中转换为 1 和 0。

def n_decompress2(compressed_vector):
    # encode is for Python 3: str -> bytes
    decompressed_b64 = "".join(_decompress_get(compressed_vector)).encode()
    # byte string into the underlying numeric data
    decompressed_b64 = np.fromstring(decompressed_b64,dtype=np.uint8)
    # conversion done by numpy indexing rather than dictionary lookup
    vectorized = _base642bin_array[decompressed_b64]
    # convert to a 2D array of 1s and 0s
    as_binary = np.unpackbits(vectorized[:,np.newaxis],axis=1)
    # remove the two digits you don't care about (always 0) from binary array
    as_binary = as_binary[:,2:]
    # reshape to 1D (and chop off two at the end)
    return as_binary.ravel()[:-2]

仅使用 Numpy(没有 Cython/Numba,我怀疑它们不会有太大帮助),这使我的速度比您的版本快 2.4 倍(请注意,我根本没有改变_decompress_get,所以两个时间都包括您的)。_decompress_get我认为主要优点是与字典查找相比,用数字索引到数组中更快。


_decompress_get可能可以使用 Cython 进行改进,但这是一个非常困难的问题......


推荐阅读