首页 > 解决方案 > 朱莉娅:避免许多函数参数的有效而整洁的方法

问题描述

我一直在 Julia 中编写随机 PDE 模拟,随着我的问题变得越来越复杂,独立参数的数量也在增加。那么从什么开始,

myfun(N,M,dt,dx,a,b)

最终变成

myfun(N,M,dt,dx,a,b,c,d,e,f,g,h)

它会导致 (1) 混乱的代码,(2) 由于放错函数参数而增加出错的机会,(3) 无法泛化以用于其他函数。

(3) 很重要,因为我已经对我的代码进行了简单的并行化,以评估 PDE 的许多不同运行。所以我想把我的函数转换成一种形式:

myfun(args)

其中 args 包含所有相关参数。我在 Julia 中发现的问题是,创建struct包含所有相关参数作为属性的属性会大大减慢速度。我认为这是由于结构属性的不断访问。作为一个简单的(ODE)工作示例,

function example_fun(N,dt,a,b)
  V = zeros(N+1)
  U = 0
  z = randn(N+1)
  for i=2:N+1
    V[i] = V[i-1]*(1-dt)+U*dt
    U = U*(1-dt/a)+b*sqrt(2*dt/a)*z[i]
  end
  return V
end

如果我尝试将其重写为,

function example_fun2(args)
  V = zeros(args.N+1)
  U = 0
  z = randn(args.N+1)
  for i=2:args.N+1
    V[i] = V[i-1]*(1-args.dt)+U*args.dt
    U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
  end
  return V
end

然后,虽然函数调用看起来很优雅,但重新访问类中的每个属性很麻烦,而且这种持续访问属性会减慢模拟速度。什么是更好的解决方案?有没有办法简单地“解包”结构的属性,这样就不必不断地访问它们?如果是这样,这将如何概括?

编辑:我正在定义我使用的结构如下:

struct Args
  N::Int64
  dt::Float64
  a::Float64
  b::Float64
end

编辑2:我已经意识到,如果您没有在结构定义中指定数组的维度,那么具有 Array{} 属性的结构会导致性能差异。例如,如果 c 是参数的一维数组,

struct Args_1
  N::Int64
  c::Array{Float64}
end

在 f(args) 中的性能会比 f(N,c) 差得多。但是,如果我们在结构定义中指定 c 是一维数组,

struct Args_1
  N::Int64
  c::Array{Float64,1}
end

然后性能损失就消失了。这个问题和我的函数定义中显示的类型不稳定性似乎解释了我在使用结构作为函数参数时遇到的性能差异。

标签: structjulianumerical-methods

解决方案


也许你没有声明args的类型声明的参数类型?

考虑这个小例子:

struct argstype
    N
    dt
end
myfun(args) = args.N * args.dt

myfunis not type-stable 不能推断返回类型的类型:

@code_warntype myfun(argstype(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype

Body:
  begin 
      return ((Core.getfield)(args::argstype, :N)::Any * (Core.getfield)(args::argstype, :dt)::Any)::Any
  end::Any

但是,如果您声明类型,则代码将变为类型稳定的:

struct argstype2
    N::Int
    dt::Float64
end

@code_warntype myfun(argstype2(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype2

Body:
  begin 
      return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype2, :N)::Int64)::Float64, (Core.getfield)(args::argstype2, :dt)::Float64)::Float64
  end::Float64

您会看到 Float64 的推断返回类型。使用参数类型(https://docs.julialang.org/en/v0.6.3/manual/types/#Parametric-Types-1),您的代码仍然同时保持通用和类型稳定:

struct argstype3{T1,T2}
    N::T1
    dt::T2
end

@code_warntype myfun(argstype3(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype3{Int64,Float64}

Body:
  begin 
      return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :N)::Int64)::Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :dt)::Float64)::Float64
  end::Float64

推荐阅读