首页 > 解决方案 > 是否有一个函数可以让我控制 var args 的数量?

问题描述

我有以下代码:

circ(x) = x./sqrt(sum(x .* x))

x -> cat(circ(x), circ(x); dims = 1)

但我希望能够创建一个函数,在其中输入一个数字并将该数量的 circ(x) 连接起来。

例如:

function Ncircs(n)
  #some way to make cat() have as its parameter circ n number of times
end

我可以打电话Ncircs(2)得到 x -> cat(circ(x), circ(x); dims = 1)Ncircs(3)得到 x -> cat(circ(x), circ(x), circ(x); dims = 1)Ncircs(4)得到 x -> cat(circ(x), circ(x), circ(x), circ(x); dims = 1)

等等

有没有办法做到这一点?我必须使用宏吗?

标签: julia

解决方案


你可以写:

Ncircs(n) = x -> cat(Iterators.repeated(circ(x), n)...; dims = 1)

如果你知道你会一直这样做,dims=1那么catvcatreduce

Ncircs(n) = x -> reduce(vcat, Iterators.repeated(circ(x), n))

对于大型n.

附带说明:使用另一个选项 ( vcat) 将产生类型稳定的结果,而第一个选项不是类型稳定的。

编辑

为什么不允许减少空集合?

一般来说,原因是您无法判断减少的结果应该是什么。如果你想允许一个空集合,你应该添加init关键字参数。这是一个例子:

julia> reduce(vcat, [])
ERROR: ArgumentError: reducing over an empty collection is not allowed

julia> reduce(vcat, [], init = [1])
1-element Array{Int64,1}:
 1

julia> reduce(vcat, [[2,3], [4,5]], init = [1])
5-element Array{Int64,1}:
 1
 2
 3
 4
 5

结果是类型稳定的是什么意思

这意味着 Julia 能够在编译时(在执行代码之前)判断函数返回值的类型。类型稳定的代码通常运行得更快(尽管这是一个广泛的话题 - 我建议您阅读 Julia 手册以详细了解它)。@code_warntype您可以使用and来检查函数类型是否稳定Test.@inferred

在这里,让我在您的具体情况下给您一个解释(我截断了一些输出以缩短答案)。

julia> x = [1,2,3]
3-element Array{Int64,1}:
 1
 2
 3

julia> y = [4,5,6]
3-element Array{Int64,1}:
 4
 5
 6

julia> @code_warntype vcat(x,y)
Body::Array{Int64,1}
...

julia> @code_warntype cat(x,y, dims=1)
Body::Any
...

julia> using Test

julia> @inferred vcat(x,y)
6-element Array{Int64,1}:
 1
 2
 3
 4
 5
 6

julia> @inferred cat(x,y, dims=1)
ERROR: return type Array{Int64,1} does not match inferred return type Any

Any以上意味着编译器不知道答案的类型。原因是在这种情况下,这种类型取决于dims参数。如果是1,它将是一个向量,如果是2,它将是一个矩阵。

我怎么知道它对于大型n

您可以运行@which宏:

julia> @which reduce(vcat, [[1,2,3], [4,5,6]])
reduce(::typeof(vcat), A::AbstractArray{#s72,1} where #s72<:(Union{AbstractArray{T,2}, AbstractArray{T,1}} where T)) in Base at abstractarray.jl:1321

而且你看到有一个专门的reduce方法vcat

现在如果你运行:

@edit reduce(vcat, [[1,2,3], [4,5,6]])

将打开一个编辑器,您会看到它调用了一个内部函数_typed_vcat,该函数针对vcat大量数组进行了优化。引入了这种优化是因为使用这样的 splattingvcat([[1,2,3], [4,5,6]]...)在结果中是等效的,但是您必须进行 splatting (the ),这本身就有一些使用该版本...可以避免的成本。reduce

为了确保我所说的是真的,您可以执行以下基准测试:

julia> using BenchmarkTools

julia> y = [[i] for i in 1:10000];

julia> @benchmark vcat($y...)
BenchmarkTools.Trial:
  memory estimate:  156.45 KiB
  allocs estimate:  3
  --------------
  minimum time:     67.200 μs (0.00% GC)
  median time:      77.800 μs (0.00% GC)
  mean time:        102.804 μs (8.50% GC)
  maximum time:     35.179 ms (99.47% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark reduce(vcat, $y)
BenchmarkTools.Trial:
  memory estimate:  78.20 KiB
  allocs estimate:  2
  --------------
  minimum time:     67.700 μs (0.00% GC)
  median time:      69.700 μs (0.00% GC)
  mean time:        82.442 μs (6.39% GC)
  maximum time:     32.719 ms (99.58% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark cat($y..., dims=1)
ERROR: StackOverflowError:

而且您会看到该reduce版本比 splatting 版本稍快vcat,而cat对于非常大的版本则根本失败n(对于较小的版本,n它会起作用,但只是速度较慢)。


推荐阅读