struct - 朱莉娅:避免许多函数参数的有效而整洁的方法
问题描述
我一直在 Julia 中编写随机 PDE 模拟,随着我的问题变得越来越复杂,独立参数的数量也在增加。那么从什么开始,
myfun(N,M,dt,dx,a,b)
最终变成
myfun(N,M,dt,dx,a,b,c,d,e,f,g,h)
它会导致 (1) 混乱的代码,(2) 由于放错函数参数而增加出错的机会,(3) 无法泛化以用于其他函数。
(3) 很重要,因为我已经对我的代码进行了简单的并行化,以评估 PDE 的许多不同运行。所以我想把我的函数转换成一种形式:
myfun(args)
其中 args 包含所有相关参数。我在 Julia 中发现的问题是,创建struct
包含所有相关参数作为属性的属性会大大减慢速度。我认为这是由于结构属性的不断访问。作为一个简单的(ODE)工作示例,
function example_fun(N,dt,a,b)
V = zeros(N+1)
U = 0
z = randn(N+1)
for i=2:N+1
V[i] = V[i-1]*(1-dt)+U*dt
U = U*(1-dt/a)+b*sqrt(2*dt/a)*z[i]
end
return V
end
如果我尝试将其重写为,
function example_fun2(args)
V = zeros(args.N+1)
U = 0
z = randn(args.N+1)
for i=2:args.N+1
V[i] = V[i-1]*(1-args.dt)+U*args.dt
U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
end
return V
end
然后,虽然函数调用看起来很优雅,但重新访问类中的每个属性很麻烦,而且这种持续访问属性会减慢模拟速度。什么是更好的解决方案?有没有办法简单地“解包”结构的属性,这样就不必不断地访问它们?如果是这样,这将如何概括?
编辑:我正在定义我使用的结构如下:
struct Args
N::Int64
dt::Float64
a::Float64
b::Float64
end
编辑2:我已经意识到,如果您没有在结构定义中指定数组的维度,那么具有 Array{} 属性的结构会导致性能差异。例如,如果 c 是参数的一维数组,
struct Args_1
N::Int64
c::Array{Float64}
end
在 f(args) 中的性能会比 f(N,c) 差得多。但是,如果我们在结构定义中指定 c 是一维数组,
struct Args_1
N::Int64
c::Array{Float64,1}
end
然后性能损失就消失了。这个问题和我的函数定义中显示的类型不稳定性似乎解释了我在使用结构作为函数参数时遇到的性能差异。
解决方案
也许你没有声明args的类型声明的参数类型?
考虑这个小例子:
struct argstype
N
dt
end
myfun(args) = args.N * args.dt
myfun
is not type-stable 不能推断返回类型的类型:
@code_warntype myfun(argstype(10,0.1))
Variables:
#self# <optimized out>
args::argstype
Body:
begin
return ((Core.getfield)(args::argstype, :N)::Any * (Core.getfield)(args::argstype, :dt)::Any)::Any
end::Any
但是,如果您声明类型,则代码将变为类型稳定的:
struct argstype2
N::Int
dt::Float64
end
@code_warntype myfun(argstype2(10,0.1))
Variables:
#self# <optimized out>
args::argstype2
Body:
begin
return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype2, :N)::Int64)::Float64, (Core.getfield)(args::argstype2, :dt)::Float64)::Float64
end::Float64
您会看到 Float64 的推断返回类型。使用参数类型(https://docs.julialang.org/en/v0.6.3/manual/types/#Parametric-Types-1),您的代码仍然同时保持通用和类型稳定:
struct argstype3{T1,T2}
N::T1
dt::T2
end
@code_warntype myfun(argstype3(10,0.1))
Variables:
#self# <optimized out>
args::argstype3{Int64,Float64}
Body:
begin
return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :N)::Int64)::Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :dt)::Float64)::Float64
end::Float64
推荐阅读
- node.js - 抛出新的 ERR_HTTP_INVALID_STATUS_CODE(originalStatusCode);
- c - 带有 makefile 和多个 .c 源文件的“未定义引用”
- javascript - Node.js 以管理员身份启动程序
- swiftui - 如何在 SwiftUI 上推送到 DatePicker 选择更改的另一个视图
- java - 构建标准 API 查询以避免 MultipleBagFetchException
- python - 继续尝试安装 ecapture 但每次都给我错误
- c - C 在函数原型中使用宏结构
- python - 循环直到按键被按下并重复
- python - 如何根据出现来排列 Dataframe 值
- c# - 以百分比/概率在 Unity 中生成随机游戏对象