首页 > 解决方案 > 稀疏矩阵:如果行的总和低于 X (Scipy),则删除行

问题描述

假设,我有以下稀疏矩阵:

from scipy.sparse import coo_matrix
m = coo_matrix(([1,1,1,3,2], ([1,2,2,3,4],[1,1,2,3,3])))
print(m.toarray())

>>> array([[0, 0, 0, 0],
>>>       [0, 1, 0, 0],
>>>       [0, 1, 1, 0],
>>>       [0, 0, 0, 3],
>>>       [0, 0, 0, 2]])

而且我只想保留总和大于 1 的行。我认为以下方法可行。

csr = m.tocsr()
csr[(csr.sum(1) > 1)]

但它没有。相反,我必须对 numpy 数组进行转换(使用squeeze):

csr = m.tocsr()
csr = csr[np.asarray(csr.sum(1) > 1).squeeze()]
csr.toarray()

所以,我得到了我想要的:

array([[0, 1, 1, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 2]], dtype=int64)

有没有更直接的方法来做到这一点?

我知道有类似的答案,例如在检查了类似this one的其他答案之后,但在他们的情况下(with M.getnnz(1)>0),该函数直接返回一个数组。

标签: pythonnumpyscipysparse-matrix

解决方案


看看细节:

In [803]: m = sparse.csr_matrix(([1,1,1,3,2], ([1,2,2,3,4],[1,1,2,3,3])))                              
In [804]: m                                                                                            
Out[804]: 
<5x4 sparse matrix of type '<class 'numpy.longlong'>'
    with 5 stored elements in Compressed Sparse Row format>
In [805]: m.A                                                                                          
Out[805]: 
array([[0, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 1, 1, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 2]], dtype=int64)
In [806]: m.sum(axis=1)                                                                                
Out[806]: 
matrix([[0],
        [1],
        [2],
        [3],
        [2]])

sumonndarray减小尺寸(除非keepdims设置)。但是sparse复制np.matrix, 并保留尺寸。所以结果是一个 (5,1) 矩阵。

np.matrix步骤的缩写array/ravel

In [807]: m.sum(axis=1).A1                                                                             
Out[807]: array([0, 1, 2, 3, 2])

和索引:

In [811]: m[m.sum(axis=1).A1>1,:]                                                                      
Out[811]: 
<3x4 sparse matrix of type '<class 'numpy.longlong'>'
    with 4 stored elements in Compressed Sparse Row format>
In [812]: _.A                                                                                          
Out[812]: 
array([[0, 1, 1, 0],
       [0, 0, 0, 3],
       [0, 0, 0, 2]], dtype=int64)

我在别处提到csr矩阵索引(通常)使用“提取矩阵”和矩阵乘法。考虑到数据的存储方式,这是稳健且合理的,但它不如密集数组索引那么快或强大。

有时我们通过作用于矩阵的基本属性来获得速度dataindicesindptr。但这需要对这种表示有更多的了解,所以我不会在这里详细介绍。


推荐阅读