python-3.x - 计算包含矩阵的符号函数的梯度和 Hessian
问题描述
我有以下功能:
我想用它做几件事:
将其转换为 SymPy 函数(完成)
x = MatrixSymbol('x', 2, 1) fx = x.T * MatMul(Matrix([[1, 2], [4, 7]]), x) + x.T * Matrix([3, 5]) + 6*Identity(1)
-
在这一点上,我已经尝试了以下渐变功能:
v = list(ordered(fx.free_symbols)) gradient = lambda f, v: Matrix([f]).jacobian(v) fxd = gradient(fx, v)
但是,这输出
[0, 0]
为不正确的结果。结果应该是:对于粗麻布,我使用 SymPy 的内置函数尝试了以下操作:
v = list(ordered(fx.free_symbols)) fxdd = hessian(fx, v)
但是,这个函数给了我以下错误:
ShapeError: Matrix size mismatch: (2, 2) + (2, 1)
输出应该是:
那么,我的问题是,如何执行第二步中的操作?
解决方案
您可以使用 获得渐变diff
。我不确定如何在不经过的情况下获得 Hessian as_explicit
:
In [49]: x = MatrixSymbol('x', 2, 1)
...: fx = x.T * MatMul(Matrix([[1, 2], [4, 7]]), x) + x.T * Matrix([3, 5]) + 6*Identity(1)
In [50]: fx.diff(x) # gradient
Out[50]:
⎡1 2⎤ ⎡1 4⎤ ⎡3⎤
⎢ ⎥⋅x + ⎢ ⎥⋅x + ⎢ ⎥
⎣4 7⎦ ⎣2 7⎦ ⎣5⎦
In [51]: hessian(fx.as_explicit(), x.as_explicit()) # Hessian
Out[51]:
⎡2 6 ⎤
⎢ ⎥
⎣6 14⎦
推荐阅读
- javascript - 如何通过 html 选择使用 javascript 移动 html 内容
- java - 从多个路径从 Firebase 数据库获取数据
- python - 气象站项目中的KeyError
- maven - TFS 2018 神器 maven 可以支持上游 repo 吗?
- javascript - 预期 2-3 个参数,但在执行 POST 操作时有 4 个角度?
- php - Windows 10 作曲家冻结
- javascript - 开玩笑的错误,需要 Babel "^7.0.0-0",但加载了 "6.26.3"
- javascript - 通过JQuery循环获取json对象的特定索引
- html - 更改边框折叠时,html,css边框消失
- python - CSV数据预处理和重新格式化python