How to make a pseudo function and a function prime for SciPy fmin_l_bfgs_b?

Problem Description

I want to use scipy.optimize.fmin_l_bfgs_b to find the minimum of a cost function.

To do this, I first want to create an instance of one_batch (the code for one_batch is below) to specify a batch of training examples, as well as the parameters that are not included in the loss function but are necessary for computing the loss.

Because the module loss_calc is designed to return the loss and the loss prime (its gradient) at the same time, I face the problem of separating the loss function from its derivative for scipy.optimize.fmin_l_bfgs_b.

As can be seen from the code of one_batch, given a batch of training examples, [loss, dloss/dParameters] is computed in parallel for each example. I don't want to run exactly the same computation twice, once in get_loss and once in get_loss_prime.

So how should I design the methods get_loss and get_loss_prime so that the parallel computation only needs to run once?

Here is the code for one_batch:

import multiprocessing

from calculator import loss_calc

class one_batch:

    def __init__(self, 
                 auxiliary_model_parameters, 
                 batch_example):

        # auxiliary_model_parameters are parameters needed to specify
        # the loss calculator but are not included in the loss function.

        self.auxiliary_model_parameters = auxiliary_model_parameters 
        self.batch_example = batch_example

    def parallel(self, func, args):
        pool = multiprocessing.Pool(multiprocessing.cpu_count())
        result = pool.map(func, args)
        return result 

    def one_example(self, example):
        temp_instance = loss_calc(self.auxiliary_model_parameters, 
                                  self.model_vector)
        loss, dloss = temp_instance(example).calculate()
        return [loss, dloss]

    def main(self, model_vector):
        self.model_vector = model_vector

        # model_vector and auxiliary_model_parameters are necessary
        # for creating an instance of the loss calculator

        result_list = self.parallel(self.one_example,
                                    self.batch_example)

        # result_list is a list of sublists, each sublist is 
        # [loss, dloss/dParameter] for each training example 

    def get_loss(self):
        ?

    def get_loss_prime(self):
        ?

Tags: python, algorithm, scipy

Solution


You can pass fmin_l_bfgs_b an objective function that directly returns both values, the function value and its gradient:

from scipy.optimize import fmin_l_bfgs_b
import numpy as np

def obj_fun(x):
    fx = 2*x**2 + 2*x + 1
    grad = np.array([4*x + 2])
    return fx, grad

fmin_l_bfgs_b(obj_fun, x0=[12])

(array([-0.5]), array([0.5]), {'grad': array([[-3.55271368e-15]]),
 'task': b'CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL',
 'funcalls': 4, 'nit': 2, 'warnflag': 0})
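
Applied to the one_batch class from the question, a minimal sketch might look like the following. The combined method name loss_and_grad, and the assumption that the per-example losses and gradients can simply be summed over the batch, are illustrative choices, not part of the original code:

import multiprocessing

import numpy as np
from scipy.optimize import fmin_l_bfgs_b
from calculator import loss_calc

class one_batch:

    def __init__(self, auxiliary_model_parameters, batch_example):
        self.auxiliary_model_parameters = auxiliary_model_parameters
        self.batch_example = batch_example

    def one_example(self, example):
        temp_instance = loss_calc(self.auxiliary_model_parameters,
                                  self.model_vector)
        loss, dloss = temp_instance(example).calculate()
        return loss, dloss

    def loss_and_grad(self, model_vector):
        # One objective returning (total loss, total gradient), so the
        # parallel pass over the batch runs only once per optimizer call.
        self.model_vector = model_vector
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            results = pool.map(self.one_example, self.batch_example)
        total_loss = sum(loss for loss, _ in results)
        total_grad = np.sum([dloss for _, dloss in results], axis=0)
        return total_loss, total_grad

# Usage (illustrative):
# batch = one_batch(aux_params, training_examples)
# x_opt, f_min, info = fmin_l_bfgs_b(batch.loss_and_grad, x0=initial_model_vector)

Because the combined method is passed directly as the objective, no separate fprime argument and no duplicated get_loss / get_loss_prime computation are needed.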

