python - 为什么 scipy.linalg.LU 重复求解 Ax = b 这么慢?
问题描述
传统观点认为,如果您使用相同的 A 和不同的 b 多次求解 Ax = b,则应该对 LU 使用 LU 分解。如果我p, l, u = scipy.linalg.lu(A)
在一个循环中多次使用和解决
x = scipy.linalg.solve(l, p.T@b)
x = scipy.linalg.solve(u, x)
这最终比仅仅使用慢得多
x = scipy.linalg.solve(A,b)
在循环。是scipy.linalg.solve()
不是优化使用前向和后向替换来解决上下对角线系统?或者,是否有可能存在一些编译技巧,python 认识到它可以对scipy.linalg.solve
零件进行 LU 分解?
我知道 scipy 中有一些linalg.lu_factor
惯例linalg.lu_solve
,但我想远离那些,因为这应该是一个教学示例。
解决方案
我的大多数线性代数研究都是计算机前的(或至少是 MATLAB/python 前的)。但我可以阅读文档。
In [29]: from scipy import linalg as la
从以下示例数组开始la.lu
:
In [30]: A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
In [31]: p, l, u = la.lu(A)
In [32]: p
Out[32]:
array([[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]])
In [33]: l
Out[33]:
array([[ 1. , 0. , 0. , 0. ],
[ 0.28571429, 1. , 0. , 0. ],
[ 0.71428571, 0.12 , 1. , 0. ],
[ 0.71428571, -0.44 , -0.46153846, 1. ]])
In [34]: u
Out[34]:
array([[ 7. , 5. , 6. , 6. ],
[ 0. , 3.57142857, 6.28571429, 5.28571429],
[ 0. , 0. , -1.04 , 3.08 ],
[ 0. , 0. , 0. , 7.46153846]])
In [42]: b=np.arange(4)
In [43]: la.solve(A,b)
Out[43]: array([-0.21649485, 2.54639175, -1.54639175, 0.01030928])
In [44]: timeit la.solve(A,b)
43.5 µs ± 88.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
我看到一个la.solve_triangular
. 经过一番反复试验,我得到了:
In [46]: la.solve_triangular(u,la.solve_triangular(l,p.T@b, lower=True))
Out[46]: array([-0.21649485, 2.54639175, -1.54639175, 0.01030928])
并计时:
In [47]: timeit la.solve_triangular(u,la.solve_triangular(l,p.T@b, lower=True))
83 µs ± 2.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
所以 double 使用solve_trianglar
比 one 慢solve
,但比使用solve
不知道它的数组是三角形的 a 快。
In [48]: la.solve(u,la.solve(l,p.T@b))
Out[48]: array([-0.21649485, 2.54639175, -1.54639175, 0.01030928])
In [49]: timeit la.solve(u,la.solve(l,p.T@b))
137 µs ± 342 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
我不知道这些计算将如何扩展。
测试@Warren 的lu_solve
想法(在已删除的答案中)
https://stackoverflow.com/a/64473976/901925
In [50]: lu_and_piv = la.lu_factor(A)
In [51]: lu_and_piv
Out[51]:
(array([[ 7. , 5. , 6. , 6. ],
[ 0.28571429, 3.57142857, 6.28571429, 5.28571429],
[ 0.71428571, 0.12 , -1.04 , 3.08 ],
[ 0.71428571, -0.44 , -0.46153846, 7.46153846]]),
array([2, 2, 3, 3], dtype=int32))
In [52]: la.lu_solve(lu_and_piv, b)
Out[52]: array([-0.21649485, 2.54639175, -1.54639175, 0.01030928])
In [53]: timeit la.lu_solve(lu_and_piv, b)
7.47 µs ± 14.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
推荐阅读
- excel - 如何使用带有日期的多个条件从另一个工作表中查找单元格
- java - Spring Boot 将属性加载为 java.util.Properties
- javascript - 使用刷新间隔在 GoogleMaps 中刷新 KML
- spring - how to get prometheus webflux r2dbc in spring work together? It gives me error when I try to run actual apis from service
- javascript - How to take image(s) when user happy and face detected using face-api
- java - 如何在连接表上构建 JPQL 查询?
- java - Java,打印列表中两个元素之和的代码不起作用
- r - 在基本 R 中对同一数据集执行相同操作时偶尔会出现数值问题
- arduino - 调用“”(Arduino)没有匹配的功能
- vue.js - babel.config.js babel eslint 关闭/禁用规则