matlab - 具有许多导致内存错误的匿名函数的 Matlab 并行代码
问题描述
我有一个代码可以解决许多不同输入/参数的科学问题。我正在使用并行 for 循环来遍历一系列参数,并遇到内存使用问题。我已尽我所能将一个代表我的代码的 MWE 放在一起。
基本上,对于每个参数组合,我在几个不同的求解器选项上运行一个小循环。在我的真实代码中,这是改变求解器公差和使用的方程(我们有一些不同的变换可以帮助调节)。每次计算实际上都是针对小型 ODE 系统(3 个方程,但每个方程都非常复杂且通常很僵硬)的一种射击方法,并带有一个调用 ODE 求解器的优化例程。每次运行都需要几秒钟/分钟的时间,并行化开销可以忽略不计,并且加速与内核数量几乎完全一致。
要解释下面的代码,请从driver
. 首先定义一些参数(a
和f
在 MWE 中)并将它们保存在一个文件中。文件名在函数之间传递。然后创建 3 组(在本例中)求解器参数,用于选择要使用的 ode 求解器、容差和方程组。然后进入 for 循环,在其他参数上循环c
,在每次迭代中使用每组求解器参数调用优化函数。最后,我保存了一个包含每次迭代结果的临时文件(因此,如果服务器出现故障,我不会丢失所有内容)。这些文件大约 1kB,我只有大约 10,000 个,所以总大小约为 10MB。在主循环之后,我将所有内容重新组合成单个向量。
该equations
函数创建要求解的实际微分方程,这是使用 switch 语句选择要返回的方程来完成的。该objectiveFunction
函数用于str2func
指定 ODE 求解器,调用equations
以获取要求解的方程,然后求解它们并计算目标函数值。
问题是似乎存在某种内存泄漏。一段时间后,大约几天后,代码速度变慢,最后出现内存不足错误(在 48 个内核上运行,可用内存约为 380GB,ode15s
出现错误)。随着时间的推移,内存使用量的增加是相当缓慢的,但肯定存在,我不知道是什么原因造成的。
具有 10,000 个值的 MWEc
需要很长时间才能运行(实际上 1,000 个可能就足够了),并且每个 worker 的内存使用量确实会随着时间的推移而增加。我认为文件加载/保存和作业分配会导致很多开销,与我的实际代码不同,但这不会影响内存使用。
我的问题是,什么可能导致内存使用量缓慢增加?
我对导致问题的想法是:
- 使用
str2func
不是很好,我应该改用 aswitch
并接受必须明确地将求解器写入代码吗? - 所有被调用的匿名函数(在 ODE 求解器中)都保留工作空间数据,而不是在每次
parfor
迭代结束时释放它 - 抑制警告导致问题:我抑制了许多 ODE 步长警告(这不应该是一个因素,因为导致问题的错误已在 2017a 中修复,并且我使用的服务器运行 2017b)
- 内存中的东西
fminbnd
或ode15s
实际上正在泄漏内存
我无法想出一种方法来很好地有效地解决 1 和 2(从代码性能和代码编写的角度来看),而且我怀疑 3 或 4 实际上是问题所在。
下面是驱动函数:
function [xi,mfv] = driver()
% a and f are used in all cases. In actual code these are defined in a
% separate function
paramFile = 'params';
a = 4;
f = @(x) 2*x;
% this filename (params) gets passed around from function to function
save('params.mat','a','f')
% The struct setup has specifc options for the each iteration
setup(1).method = 'ode45'; % any ODE solver can be used here
setup(1).atol = 1e-3; % change the ODE solver tolerance
setup(1).eqs = 'second'; % changes what equations are solved
setup(2).method = 'ode15s';
setup(2).atol = 1e-3;
setup(2).eqs = 'second';
setup(3).method = 'ode15s';
setup(3).atol = 1e-4;
setup(3).eqs = 'first';
c = linspace(0,1);
parfor i = 1:numel(c) % loop over parameter c
xi = 0;
minFVal = inf;
for j = 1:numel(setup) % loop over each set configuration setup
% find optimal initial condition and record corresponding value of
% objective function
[xInitial,fval] = fminsearch(@(x0) objectiveFunction(x0,c(i),...
paramFile,setup(j)),1);
if fval<minFVal % keep the best solution
xi = xInitial;
minFVal = fval;
end
end
% save some variables
saveInParForLoop(['tempresult_' num2str(i)],xi,minFVal);
end
% Now combine temporary files into single vectors
xi = zeros(size(c)); mfv = xi;
for i = 1:numel(c)
S = load(['tempresult_' num2str(i) '.mat'],'xi','minFVal');
xi(i) = S.xi;
mfv(i) = S.minFVal;
end
% delete the temporary files now that the data has been consolidated
for i = 1:numel(c)
delete(['tempresult_' num2str(i) '.mat']);
end
end
function saveInParForLoop(filename,xi,minFVal)
% you can't save directly in a parfor loop, this is the workaround
save(filename,'xi','minFVal')
end
这是定义方程的函数
function [der,transform] = equations(paramFile,setup)
% Defines the differential equation and a transformation for the solution
% used to calculate the objective function
% Note in my actual code I generate these equations earlier
% and pass them around directly, rather than always redefining them
load(paramFile,'a','f')
switch setup.eqs
case 'first'
der = @(x) f(x)*2+a;
transform = @(x) exp(x);
case 'second'
der = @(x) f(x)/2-a;
transform = @(x) sqrt(abs(x));
end
这是评估目标函数的函数
function val = objectiveFunction(x0,c,paramFile,setup)
load(paramFile,'a')
% specify the ODE solver and AbsTol from s
solver = str2func(setup.method);
options = odeset('AbsTol',setup.atol);
% get the differential equation and transform equations
[der,transform] = equations(paramFile,setup);
dxdt = @(t,y) der(y);
% solve the IVP
[~,y] = solver(dxdt,0:.05:1,x0,options);
% calculate the objective function value
val = norm(transform(y)-c*a);
如果您运行此代码,它将创建 100 个临时文件,然后将其删除,它还会创建params
不会被删除的文件。您将需要并行计算工具箱。
解决方案
您可能会遇到这个已知问题:https ://uk.mathworks.com/support/bugreports/1976165 。这在刚刚发布的 R2019b 中被标记为已修复。(由此引起的泄漏很小但持续存在 - 因此可能确实需要几天时间才能显现出来)。
推荐阅读
- go - 在 proto3 文件中使用 IP 字段(IPV4 或 IPV6)以供 Golang 和 C# 使用的更好方法是什么
- sorting - 如何对作为字符串的日期字段进行排序
- oracle - 双重会员运营商是做什么的?
- c - Simulink Embeded Coder 生成的 C 代码中定义的“rtmGetU”的目的是什么
- jquery - 如何隐藏最近
使用 jquery 动态生成且没有任何类或 id 的标签
- javascript - 如何用 reactJS 过滤?
- c# - 如何通过表单传递集合数据?
- r - 如何避免使用 shinyWidgets 下拉列表和数据表折叠
- c# - 每 n 分钟调用一次 Azure Continuous Webjob 中的函数
- r - 如何总结数据框中的相同单元格?