首页 > 解决方案 > 如何从 Python 中的点列表中创建具有中心的椭圆体的二进制 3 维矩阵?

问题描述

目标

我有一个 M 长度的 3 维坐标点数组作为浮点数。我想创建一个预定义形状的 3 维 numpy 数组,其中填充以这些点为中心的给定浮点半径的椭圆体。因为这是用于图像处理的,所以我将数组中的每个值称为“像素”。如果这些椭球重叠,我想通过欧几里德距离将像素分配到更近的中心。最终输出将是一个 numpy 数组,其背景为 0,椭球内的像素编号为 1、2、... M,对应于初始坐标列表,类似于 scipy 的 ndimage.label(...) 的输出.

下面,我有一个简单的方法,它考虑输出数组中的每个位置并将其与每个定义的中心进行比较,为任何椭球内的任何像素创建一个值为 1 的二进制数组。然后它使用 scikit-image 来分水岭这个二进制数组。虽然此代码有效,但对我来说它的使用速度非常慢,因为它考虑了每个像素和中心对,并且因为它单独执行分水岭。如何加快此代码的速度?

朴素的例子

def define_centromeres(template_image, centers_of_mass, xradius = 4.5, yradius = 4.5, zradius = 3.5):
    """ Creates a binary N-dimensional numpy array of ellipsoids.

    :param template_image: An N-dimensional numpy array of the same shape as the output array.
    :param centers_of_mass: A list of lists of N floats defining the centers of spots.
    :param zradius: A float defining the radius in pixels in the z direction of the ellipsoids.
    :param xradius: A float defining the radius in pixels in the x direction of the ellipsoids.
    :param yradius: A float defining the radius in pixels in the y direction of the ellipsoids.
    :return: A binary N-dimensional numpy array.
    """
    out = np.full_like(template_image, 0, dtype=int)
    for idx, val in np.ndenumerate(template_image):
        z, x, y = idx
        for point in centers_of_mass:
            pz, px, py = point[0], point[1], point[2]
            if (((z - pz)/zradius)**2 + ((x - px)/xradius)**2 + ((y - py)/yradius)**2) <= 1:
                out[z, x, y] = 1
                break
    return out

Scikit-image 的分水岭函数;通过更改此方法中的代码不太可能找到加速:

def watershed_image(binary_input_image):
    bg_distance = ndi.distance_transform_edt(binary_input_image,
                                             return_distances=True,
                                             return_indices=False)
    local_maxima = peak_local_max(bg_distance, min_distance=1, labels=binary_input_image)
    bg_mask = np.zeros(bg_distance.shape, dtype=bool)
    bg_mask[tuple(local_maxima.T)] = True
    marks, _ = ndi.label(bg_mask)
    output_watershed = watershed(-bg_distance, marks, mask=binary_input_image)
    return output_watershed

小规模示例数据:

zdim, xdim, ydim = 15, 100, 100
example_shape = np.zeros((zdim,xdim,ydim))
example_points = np.random.random_sample(size=(10,3))*np.array([zdim,xdim,ydim])
center_spots_image = define_centromeres(example_shape, example_points)
watershed_spots = watershed_image(center_spots_image)

输出:

center_spots_image,最大投影到 2d

watershed_spots,最大投影到 2d

请注意,这些图像只是最终 3d 输出数组的 2d 表示。

补充说明

输出数组的典型大小为 31x512x512,或总共 8.1e6 个值,输入坐标的典型大小为 40 个 3 维坐标点。我想针对这种规模优化此过程的速度。

我在这个项目中使用了 numpy、scipy 和 scikit-image,我必须坚持使用这些以及其他维护良好和文档化的包。

对上述代码的可访问性错误或我的解释不够清晰深表歉意。我是一名研究科学家,几乎没有接受过正规的计算机科学培训。

标签: pythonarraysnumpyoptimizationscikit-image

解决方案


加速 numpy 代码的黄金法则是尽可能矢量化。这将循环从 Python 移动到更快的 C 例程中。这里的主要减速来自循环数组中的每个元素np.ndenumerate。您可以使用 numpy 完全在其中执行此操作np.indices,加上一些广播以使数组对齐:

def define_centromeres_vec(template_image, centers_of_mass, xradius=4.5, yradius=4.5, zradius=3.5):
    """Creates a binary N-dimensional numpy array of ellipsoids.

    :param template_image: An N-dimensional numpy array of the same shape as the output array.
    :param centers_of_mass: A list of lists of N floats defining the centers of spots.
    :param zradius: A float defining the radius in pixels in the z direction of the ellipsoids.
    :param xradius: A float defining the radius in pixels in the x direction of the ellipsoids.
    :param yradius: A float defining the radius in pixels in the y direction of the ellipsoids.
    :return: A binary N-dimensional numpy array.
    """
    out = np.zeros_like(template_image, dtype=int)
    # indices[:, z, x, y] = [z, x, y]
    indices = np.indices(template_image.shape, dtype=float)
    radii = np.asarray([zradius, xradius, yradius], dtype=float)
    for point in centers_of_mass:
        mask = np.sum(((indices - point[:,None,None,None]) / radii[:,None,None,None]) ** 2, axis=0) <= 1
        out[mask] = 1
    return out

这使我在示例数据上的速度提高了大约 60 倍:2.83s ± 116ms 到 44.4ms ± 555µs,包括分水岭。

radii如果我们将广播移到循环外,我们可以获得更快的速度:

    # ...
    radii_bcast = np.broadcast_to(radii[:, None, None, None], shape=indices.shape)
    for point in centers_of_mass:
        mask = np.sum(((indices - point[:,None,None,None]) / radii_bcast) ** 2, axis=0) <= 1
        out[mask] = 1
    return out

这得到了另一个 2 左右的因子define_centromeres,但现在很多总时间都花在了分水岭例程中(21ms ± 60µs)。

最后,我们可以通过对同一循环中的像素进行分类来消除分水岭步骤。out[mask]我们可以只存储 COM 的索引(加 1),而不是设置为 1。在此之前,我们检查与之前的椭球是否有任何重叠mask & out(这将是True当前椭球与之前的椭球重叠的任何地方)。对于任何重叠的部分,我们可以使用scipy.spatial.KDTree来获取每个点的最近 COM:

from scipy.spatial import KDTree

def define_centromeres_labeled(template_image, centers_of_mass, xradius=4.5, yradius=4.5, zradius=3.5):
    """Creates a labeled N-dimensional numpy array of ellipsoids.

    :param template_image: An N-dimensional numpy array of the same shape as the output array.
    :param centers_of_mass: A list of lists of N floats defining the centers of spots.
    :param zradius: A float defining the radius in pixels in the z direction of the ellipsoids.
    :param xradius: A float defining the radius in pixels in the x direction of the ellipsoids.
    :param yradius: A float defining the radius in pixels in the y direction of the ellipsoids.
    :return: An N-dimensional numpy array, with label `n` for the ellipsoid at index `n-1`.
    """
    out = np.zeros_like(template_image, dtype=int)
    # indices[:, z, x, y] = [z, x, y]
    indices = np.indices(template_image.shape, dtype=float)
    radii = np.asarray([zradius, xradius, yradius], dtype=float)
    radii_bcast = np.broadcast_to(radii[:, None, None, None], shape=indices.shape)
    tree = KDTree(centers_of_mass)
    for i, point in enumerate(centers_of_mass, start=1):
        mask = np.sum(((indices - point[:,None,None,None]) / radii_bcast) ** 2, axis=0) <= 1
        # check for overlap
        if np.any(mask & out):
            # get the overlapping points before modifying out
            overlap_mask = mask & out.astype(bool)
            overlap_idx = np.array(np.where(overlap_mask)).T
            out[mask] = i
            # get the closest center for all overlapping points
            out[overlap_mask] = tree.query(overlap_idx)[1] + 1
        else:
            out[mask] = i
    return out

这大约是前一种方法的两倍(20.2ms ± 384µs),另外还有一个优点是椭球根据它们的索引进行标记,并且相邻的椭球不会合并在一起(这是分水岭的问题)。

在我的机器上,它在 4.38s ± 43.9ms(31x512x512,40 个椭球体)内运行全尺寸示例数据。


推荐阅读