python - scipy 行搜索的问题
问题描述
我正在尝试使用 scipy 函数line_search,但我收到一个错误,这对我来说没有意义。可能我做错了什么。我的目标是对将矩阵映射到实数的函数 f 进行线搜索。
我在这里发现了一个类似的问题,但不幸的是,在那里应用解决方案并不能解决我的问题,尽管那里的作者想用向量而不是矩阵做类似的事情。
我准备了一个简单的玩具示例。
import numpy
from scipy.optimize import line_search
def fct(x):
return x[0,0]**2+x[1,1]**2
def grad(x):
return np.array([[2*x[0,0],0],[0,2*x[1,1]]])
start = np.array([[1,2],[3,4]])
direction = np.array([[1,1],[0,1]])
result = line_search(fct,grad,start,direction)
错误消息如下:
File "C:\Users\myname\AppData\Local\Continuum\anaconda3\envs\MyEnv\lib\site-packages\scipy\optimize\linesearch.py", line 429, in scalar_search_wolfe2
if (phi_a1 > phi0 + c1 * alpha1 * derphi0) or \
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
似乎错误发生在 line_search_wolfe2 的子例程中,当我调用 line_search 时会调用该子例程。显然,发生错误的行中的值不是标量值。我查看了源代码,但找不到任何理由说明该值不是标量的,如预期的那样。
我的输入有什么问题?我在 line_search 函数的定义中有什么错误吗?
预先感谢您的帮助!
解决方案
推荐阅读
- forms - 错误“函数或接口标记为受限,或函数使用不支持的自动化类型”分配表单对象
- r - dplyr 错误 rlang 0.3.0。filter_impl(.data, quo) 中的错误
- hadoop - Sqoop - 服务器时区值“马来半岛标准时间”无法识别或代表多个时区
- python - Regex I want to match until certain characters but still be able to match strings if it doesn't have these characters
- node.js - 如何从 NodeJS 中的 API 捕获响应正文
- c# - how to pass dropdown selected name value and its id value using BeginForm to controller method
- java - Why is Locale final in Java?
- android - 无法删除文件(.delete() 返回 false)
- python - 使用 selenium 和 python 从 DOM 中查找第二个元素
- python - Numpy.array(500x50),我如何跨行连接,使其变为(500x1)