julia - 将 ForwardDiff.jl 用于包含许多变量和参数的函数 Julia
问题描述
ForwardDiff.jl 的 github 存储库有一些示例。我试图扩展这个例子,除了一个变量向量,一个参数。我无法让它工作。
这是示例(它很短,所以我将展示它而不是链接)
using ForwardDiff
x = rand(5)
f(x::Vector) = sum(sin, x) .+ prod(tan, x) * sum(sqrt, x);
g = x -> ForwardDiff.gradient(f, x);
g(x) # this outputs the gradient.
我想修改它,因为我使用具有多个参数和变量的函数。作为一个简单的修改,我尝试添加一个参数。
f(x::Vector, y) = (sum(sin, x) .+ prod(tan, x) * sum(sqrt, x)) * y;
我尝试了以下方法无济于事:
fp = x -> ForwardDiff.gradient(f, x);
fp = x -> ForwardDiff.gradient(f, x, y);
y = 1
println("test grad: ", fp(x, y))
我收到以下错误消息:
ERROR: LoadError: MethodError: no method matching (::var"#73#74")(::Array{Float64,1}, ::Int64)
2017 年没有回答类似的问题。评论将我带到这里,似乎该功能只能接受一个输入?
目标函数必须是一元的(即只接受一个参数)。ForwardDiff.jacobian 是此规则的一个例外。
这有改变吗?只能区分一元函数似乎非常有限。
一种可能的解决方法是将变量和参数列表连接起来,然后将返回的渐变分割为不包括与参数相关的渐变,但这似乎很愚蠢。
解决方案
我个人认为对 ForwardDiff 使用这种仅一元的语法是有意义的。在您的情况下,您可以打包/解包x
并y
放入单个向量(名义上x2
如下):
julia> using ForwardDiff
julia> x = rand(5)
5-element Array{Float64,1}:
0.4304735670747184
0.3939269364431113
0.7912705403776603
0.8942024934250143
0.5724373306715196
julia> f(x::Vector, y) = (sum(sin, x) .+ prod(tan, x) * sum(sqrt, x)) * y;
julia> y = 1
1
julia> f(x2::Vector) = f(x2[1:end-1], x2[end]) % unpacking in f call
f (generic function with 2 methods)
julia> fp = x -> ForwardDiff.gradient(f, x);
julia> println("test grad: ", fp([x; y])) % packing in fp call
test grad: [2.6105844240785796, 2.741442601659502, 1.9913192377198885, 1.9382805843854594, 2.26202717745402, 3.434350946190029]
但我的偏好是明确地以不同的方式命名偏导数:
julia> ∂f∂x(x,y) = ForwardDiff.gradient(x -> f(x,y), x)
∂f∂x (generic function with 1 method)
julia> ∂f∂y(x,y) = ForwardDiff.derivative(y -> f(x,y), y)
∂f∂y (generic function with 1 method)
julia> ∂f∂x(x, y)
5-element Array{Float64,1}:
2.6105844240785796
2.741442601659502
1.9913192377198885
1.9382805843854594
2.26202717745402
julia> ∂f∂y(x, y)
3.434350946190029
推荐阅读
- typescript - 将“唯一符号”传递给函数
- java - 如何使用 Jackson ObjectMapper 反序列化 Spring 的 ResponseEntity(可能使用 @JsonCreator 和 Jackson mixins)
- javascript - 从我的 .json 文件/第一次 js 节点用户获取未定义的值
- python-3.x - 在启用 R-Tree 模块的情况下安装 SQLITE?
- google-smart-home - 如何在谷歌智能家居中调用“真空厨房”?
- html - CSS定位 - 如何将一个输入放在另一个旁边
- python-3.x - 将字符串转换为数字时出现问题
- tensorflow - Facenet 和单图像前向传播
- ios - nfc 代码在 ios 11 中有效,但在 iOS13 中无效
- java - 如何正确配置successHandler?