首页 > 解决方案 > 正在传递给 scipy.optimize.minimze 的函数的未知额外位置参数

问题描述

我应该制作一个脚本来创建一个用于检测两个刺激 A 和 B 的模型,其中包含刺激 A 和 B 的值的参数(实际上,一个表示 A 和 B 发生预测的值),两者的学习率stimuli,要模拟的试验次数。

但我坚持使用 scipy.optimize.minimize。我认为不管这个模型/模拟的描述如何,这里最重要的是我得到的错误消息似乎是在说我试图传递给最小化的函数有太多的位置参数。我明白了

TypeError: rw_simulation() takes 5 positional arguments but 6 were given

我传递给 scipy.optimize.minimize 的 args 参数的元组显然有 5 个元素。学习率应该是浮点数,v_stimulus 值也应该是浮点数...... trial_total 应该是一个 int。关于我如何尝试使用最小化,这里有些东西不正确,但是这条声称有 6 个位置参数的错误消息并不能真正帮助我理解我可以解决的问题。

def rw_simulation(v_stimulusA, v_stimulusB, learning_rateA, learning_rateB, trial_total): 
   # this code for the most part isn't important

    reinforcement_val = 100
    stimA_predict_vals = []
    stimB_predict_vals = []
    stimA_predict_vals.append(v_stimulusA)
    stimB_predict_vals.append(v_stimulusB)

    for trial in range(trial_count):
        # unimportant computation
        v_stimulusA += learning_rateA*(reinforcement_val - delta_diff)
        v_stimulusB += learning_rateB*(reinforcement_val - delta_diff)
        stimA_predict_vals.append(v_stimulusA) # stores the change in value for stimA at the end of each trial.
        stimB_predict_vals.append(v_stimulusB) # also stores the change in value for stimB at the end of each trial
    return stimA_predict_vals, stimB_predict_vals

trial_count = 200
start_predictA = 0
start_predictB = 0
learning_rateA = 0.3
learning_rateB = 0.3
hw1_data = np.load("hw1_data.npy") # np array containing initial guesses, though it is possible that this var is causing problems. However, the error message at hand clearly isn't related to this.
stimulus_prediction_vals_tuple = ()
stimulus_prediction_vals_tuple = rw_simulation(start_predictA, start_predictB, learning_rateA, learning_rateB, trial_count)

minimize(rw_simulation, hw1_data, args=(start_predictA, start_predictB, learning_rateA, learning_rateB, trial_count,), method='nelder-mead')

我现在只想至少解决这个关于额外位置参数的错误消息。我实际上有些担心,我用于 scipy.optimize.minimize 的 x0 参数的“hw1_data”参数可能不是我应该使用的变量。但是,如果那是导致此错误消息的原因...显然根本不明显。另外,我将把它留给另一个问题。

真的,我如何传递 6 个位置参数?第6个是从哪里来的?

编辑:在某些时候,当我取出 args 参数或将元组设置为 ()/空时,我收到一条错误消息,基本上说除了我的 rw_simulation 函数的第一个参数之外的每个参数都没有被设置。所以似乎有些东西被传递到第一个参数中,这导致了额外的参数。我实际上不知道第一个参数是什么或可能来自什么。也许函数名?但是……那我该怎么办?

标签: pythonnumpyscipy

解决方案


推荐阅读