首页 > 解决方案 > 使文字常量的类型依赖于其他变量

问题描述

我在 Julia 中有以下代码,其中文字常量2.对数组元素进行乘法运算。我现在将文字常量设为单精度 ( 2.f0),但我想让类型依赖于其他变量(这些是 allFloat64或 all Float32)。我如何以优雅的方式做到这一点?

function diff!(
        at, a,
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot)
​
    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - 2.f0 * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - 2.f0 * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - 2.f0 * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end

标签: julialiterals

解决方案


一般来说,如果你有一个 scalarx或一个 array A,你可以分别用T = typeof(x)or 获取类型T = eltype(A),然后使用它来将文字转换为等效类型,例如

julia> A = [1.0]
1-element Vector{Float64}:
 1.0

julia> T = eltype(A)
Float64

julia> T(2)
2.0

所以你原则上可以在函数中使用它,如果一切都是类型稳定的,这实际上应该是无开销的:

julia> @code_native 2 * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ promotion.jl:322 within `*'
; │┌ @ promotion.jl:292 within `promote'
; ││┌ @ promotion.jl:269 within `_promote'
; │││┌ @ number.jl:7 within `convert'
; ││││┌ @ float.jl:94 within `Float32'
    vcvtsi2ss   %rdi, %xmm1, %xmm1
; │└└└└
; │ @ promotion.jl:322 within `*' @ float.jl:331
    vmulss  %xmm0, %xmm1, %xmm0
; │ @ promotion.jl:322 within `*'
    retq
    nopw    (%rax,%rax)
; └

julia> @code_native 2.0f0 * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
    vmulss  %xmm1, %xmm0, %xmm0
    retq
    nopw    %cs:(%rax,%rax)
; └

julia> @code_native Float32(2) * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
    vmulss  %xmm1, %xmm0, %xmm0
    retq
    nopw    %cs:(%rax,%rax)
; └

然而,碰巧的是,Julia 中有一种更优雅的模式来编写函数签名,这样它将参数化地专门用于您传递给此函数的数组的元素类型,然后您应该能够在没有开销的情况下使用它确保您的文字是适当的类型,如下所示:

function diff!(at::AbstractArray{T}, a::AbstractArray{T},
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot) where T <: Number

    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - T(2) * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - T(2) * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - T(2) * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end

这种方法在有关 Julia中的参数方法的文档中进行了一定程度的讨论


推荐阅读