首页 > 解决方案 > 将 CSR 矩阵乘以向量

问题描述

我正在关注this stackoverflow post on csr matrix multiplication to a vector 并在python中实现它并获取列表超出范围错误。

这是我的代码:

def MatrixMultiplication(data,row_ptr,col_ptr,vec):
  ResultMatrix =[]
  vec_len = len(vec)
  for i in range(0,vec_len):
    ResultMatrix.insert(i,0)
  for i in range(0,vec_len):
    start, end = row_ptr[i], row_ptr[i + 1]
    for k in range(start, end):
      ResultMatrix[i] = ResultMatrix[i]+data[k]*vec[col_ptr[k]]
  return ResultMatrix

data = [2, 4, 7, 1, 3, 2]
row_ptr =  [2,3 ,5, 5 ,6]
col_ptr = [1 ,3, 4, 0, 3, 3]
vec = [2,3, 5, 4, 2]

MatrixMultiplication(data,row_ptr,col_ptr,vec)

请帮我解决我哪里出错了。

输出应该是:[22 14 14 0 8]

错误 :

IndexError: list index out of range
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<command-338158343473691> in <module>()
----> 1 MatrixMultiplication(data,row_ptr,col_ptr,vec)

<command-3658506804172571> in MatrixMultiplication(data, row_ptr, col_ptr, vec)
      5     ResultMatrix.insert(i,0)
      6   for i in range(0,vec_len):
----> 7     start, end = row_ptr[i], row_ptr[i + 1]
      8     for k in range(start, end):
      9       ResultMatrix[i] = ResultMatrix[i]+data[k]*vec[col_ptr[k]]

IndexError: list index out of range

供参考:

row_ptr 的最后一个元素将是数据列表的大小

标签: pythonmatrix-multiplication

解决方案


错误消息非常不言自明:您尝试row_ptr[i + 1]在一个上升到 的 for 循环中访问vec_len,这是您的列表的长度。当您到达 for 循环的最后一次迭代时i = vec_len - 1,然后i + 1 = vec_len, 超出了列表的范围(请记住,Python 列表是 0 初始化的)。

为了防止这个错误,你的范围应该只vec_len - 1在你的第二个 for 循环中上升。


推荐阅读