neural-network - 如何在 DiffEqFlux.lj neuroODE 中创建任意参数化层?Julia Julialang Flux.jl
问题描述
我能够使用 Flux.jl 和 DiffEqFlux.jl 在 julia(1.3 和 1.2)中创建和优化神经微积分,但它在一个至关重要的一般情况下失败了。
什么有效:
- 如果神经网络参数是由提供的 Flux.jl 层(如Dense())构建的,我可以训练它。
- 我可以在网络链中包含任意函数作为层,例如x -> x.*x
失败 的原因:但是,如果任意函数有我想训练的参数,那么 Flux。Train 不会调整这些参数导致它失败。
我已经尝试使这些添加的参数被跟踪并包含在提供给训练系统的参数列表中,但它会忽略它们并且它们保持不变。
文档非常神秘地说,可以在一个层上使用 Flux.@functor 以确保它的参数被跟踪。然而,函子直到版本 0.10.0 才包含在 Flux 中,并且与 DiffEqFlux 中的 NeuralODE 兼容的唯一 Flux 版本是 0.9.0
所以这是我想使用的 2 层神经网络的玩具示例
p = param([1.0])
dudt = chain( x -> p[1]*x.*x, Dense(2,2) )
ps = Flux.params(dudt)
然后我在这个上使用通量火车。当我这样做时,参数 p 没有变化,但 Dense 层中的参数是变化的。
我已经尝试过明确包括这样
ps = Flux.Params([p,dudt])
但这具有相同的结果和相同的问题
我认为我需要做的是构建一个带有相关函数的结构,该函数实现
x->p[1]*x*x
然后在此调用@functor。然后可以在链中使用该结构。
但正如我注意到的那样,带有@functor 的 Flux 版本与任何版本的 DiffEqFlux 都不兼容。
所以我需要一种方法让通量关注我的自定义参数,而不仅仅是 Dense() 中的参数
如何???
解决方案
我想我明白你的问题是什么,但请澄清我是否在这里回答了错误的问题。问题在于,p
它只是从全局参考中获取,因此在伴随过程中没有区别。在 2020 年处理此问题的更好方法是使用FastChain
. 该FastChan
接口允许您定义层函数及其参数依赖关系,因此这是使您的神经网络将任意函数与参数结合起来的好方法。看起来是这样的:
using DifferentialEquations
using Flux, Zygote
using DiffEqFlux
x = Float32[2.; 0.]
p = Float32[2.0]
tspan = (0.0f0,1.0f0)
mylayer(x,p) = p[1]*x
DiffEqFlux.paramlength(::typeof(mylayer)) = 1
DiffEqFlux.initial_params(::typeof(mylayer)) = rand(Float32,1)
dudt = FastChain(FastDense(2,50,tanh),FastDense(50,2),mylayer)
p = DiffEqFlux.initial_params(dudt)
function f(u,p,t)
dudt(u,p)
end
ex_neural_ode(x,p) = solve(ODEProblem(f,x,tspan,p),Tsit5())
solve(ODEProblem(f,x,tspan,p),Tsit5())
du0,dp = Zygote.gradient((x,p)->sum(ex_neural_ode(x,p)),x,p)
其中 的最后一个值是inp
的一个参数。或者你可以直接使用 Flux:p
mylayer
using DifferentialEquations
using Flux, Zygote
using DiffEqFlux
x = Float32[2.; 0.]
p2 = Float32[2.0]
tspan = (0.0f0,1.0f0)
dudt = Chain(Dense(2,50,tanh),Dense(50,2))
p,re = Flux.destructure(dudt)
function f(u,p,t)
re(p[1:end-1])(u) |> x-> p[end]*x
end
ex_neural_ode() = solve(ODEProblem(f,x,tspan,[p;p2]),Tsit5())
grads = Zygote.gradient(()->sum(ex_neural_ode()),Flux.params(x,p,p2))
grads[x]
grads[p]
grads[p2]
推荐阅读
- java - 从一个类中启动另一个活动的活动访问方法
- java - 如何在 Quarkus 中获取静态值的配置值
- progressive-web-apps - 是否可以在一个网站上拥有多个 PWA?
- c# - C# MS ClearScript 添加动态宿主对象
- reactjs - 使用 react-ace diff 组件无法显示 diff 变化
- javascript - useEffect 钩子在 setTimeout 和 state 中行为不端
- elasticsearch - elasticsearch:没有已知的主节点,安排重试
- google-cloud-platform - 我可以在自己的虚拟机实例上运行云构建吗
- asp.net-mvc - 如何强制主页运行 IE11 和 iFrame 运行 IE7
- sql - Oracle - 有没有办法重写我的查询以减少我查询表的次数