首页 > 解决方案 > 将 numpy.bmat 与 numba 一起使用

问题描述

我正在尝试在我的 numba 优化 python 程序中使用 np.bmat。为此,我必须手动定义一个 jitted 函数 bmat ,因为不支持来自 numpy 的本机函数:

@njit
def _bmat_2d(matrices):
    arr_rows = []
    for row in matrices:
        arr_rows.append(np.concatenate(row, axis=-1))
    return np.array(np.concatenate(arr_rows, axis=0))

(此代码或多或少是 numpy 的简化副本)

然而:

  1. numba 仅接受 np.concatenate [1] 输入中的元组
  2. numba 非常不擅长将任意列表转换为元组 [2]

你对此有什么想法吗?

参考:

标签: pythonnumpynumba

解决方案


以下内容是否适合您的目的?

import numpy as np
import numba as nb

@nb.njit
def _bmat_2d(m):
    out = np.hstack(m[0])
    for row in m[1:]:
        x = np.hstack(row)
        out = np.vstack((out, x))

    return out

A = np.random.randint(10, size=(3,2))
B = np.random.randint(10, size=(3,1))
C = np.random.randint(10, size=(3,3))
D = np.random.randint(10, size=(4,6))

a = np.bmat(((A, B, C), (D,)))
b = _bmat_2d(((A, B, C), (D,)))

print(np.allclose((a, b))  # True

请注意,您必须传入元组而不是列表列表,否则您将收到“反射列表”错误,因为当前版本中的 Numba 无法处理列表列表。


推荐阅读