首页 > 解决方案 > scipy 行搜索的问题

问题描述

我正在尝试使用 scipy 函数line_search,但我收到一个错误,这对我来说没有意义。可能我做错了什么。我的目标是对将矩阵映射到实数的函数 f 进行线搜索。

我在这里发现了一个类似的问题,但不幸的是,在那里应用解决方案并不能解决我的问题,尽管那里的作者想用向量而不是矩阵做类似的事情。

我准备了一个简单的玩具示例。

import numpy
from scipy.optimize import line_search

def fct(x):
    return x[0,0]**2+x[1,1]**2

def grad(x):
    return np.array([[2*x[0,0],0],[0,2*x[1,1]]])   

start = np.array([[1,2],[3,4]])
direction = np.array([[1,1],[0,1]])

result = line_search(fct,grad,start,direction)

错误消息如下:

File "C:\Users\myname\AppData\Local\Continuum\anaconda3\envs\MyEnv\lib\site-packages\scipy\optimize\linesearch.py", line 429, in scalar_search_wolfe2
    if (phi_a1 > phi0 + c1 * alpha1 * derphi0) or \
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

似乎错误发生在 line_search_wolfe2 的子例程中,当我调用 line_search 时会调用该子例程。显然,发生错误的行中的值不是标量值。我查看了源代码,但找不到任何理由说明该值不是标量的,如预期的那样。

我的输入有什么问题?我在 line_search 函数的定义中有什么错误吗?

预先感谢您的帮助!

标签: pythonpython-3.xscipymathematical-optimization

解决方案


推荐阅读