首页 > 解决方案 > 如何从常规矩阵继承

问题描述

using ShiftedArrays

struct CircularMatrix{T} <: AbstractArray{T,2}
    data::Array{T,2}
    view::CircShiftedArray
    currentIndex::Int
    function CircularMatrix{T}(dims...) where T
        data = zeros(T, dims...)
        CircularMatrix(data, ShiftedArrays.circshift(data, (0, -1)), 1)
    end
end

Base.size(M::CircularMatrix) = size(M.data)
Base.eltype(::Type{CircularMatrix{T}}) where {T} = T

function shift_forward!(M::CircularMatrix)
    M.shift_forward!(1)
end

function shift_forward!(M::CircularMatrix, n)
    # replace the view with a view shifted forwards.
    M.currentIndex += n
    M.view = ShiftedArrays.circshift(M.data, (n, M.currentIndex))
end

@inline Base.@propagate_inbounds function Base.getindex(M::CircularMatrix, i) = M.view[i]
@inline Base.@propagate_inbounds function Base.setindex!(M::CircularMatrix, data, i) = M.view[i] = data

如何使 CircularMatrix 像常规矩阵一样工作。这样我就可以像访问它一样

m = CircularMatrix{Int}(4,4)
m[1, 1] = 5
x = view(m, 1, :)

标签: julia

解决方案


您的矩阵类型被定义为AbstractArray{T, 2}. 您需要在 Julia 的非正式数组接口中为您的类型实现一些方法,以使适用于您的自定义类型的函数和特性AbstractArray{T, 2}也适用于您的自定义类型,也就是说,使您CircularMatrix成为一个可迭代、可索引、功能齐全的矩阵。

实现的方法是

  1. size(M::CircularMatrix)
  2. getindex(M::CircularMatrix, i::Int)
  3. getindex(M::CircularMatrix, I::Vararg{Int, N})
  4. setindex!(M::CircularMatrix, v, i::Int)
  5. setindex!(M::CircularMatrix, v, I::Vararg{Int, N})

您已经实现了 1、2 和 4,但尚未设置索引样式。如果您选择线性索引样式,您可能不需要 3 和 5 。您只需要设置IndexStyleIndexLinear()并且可能进行一些修改,那么一切都应该适用于您的矩阵。

1.size(M::CircularMatrix)

第一个是sizesize(A::CircularMatrix)返回 aTuple的尺寸A。我相信您的矩阵可能类似于以下内容

Base.size(M::CircularMatrix) = size(M.data)

2.getindex(M::CircularMatrix, i::Int)

如果您选择线性索引样式,则需要此方法。getindex(M, i::Int)应该给你线性索引的值i。您已经在代码中实现了它。如果您选择线性索引,您需要IndexStyle为您的类型设置,然后您只需跳过 3 和 5。Julia 会自动转换多个索引访问,例如a[3, 5],转换为线性索引访问。

Base.IndexStyle(::Type{<:CircularMatrix}) = IndexLinear()

Base.@propogate_inbounds function Base.getindex(M::CircularMatrix, i::Int)
    @boundscheck checkbounds(M, i)
    @inbounds M.view[i]
end

@inbounds在第二行使用这里可能会更好。如果调用者不使用@inbounds,我们首先检查边界,这有望使后续的边界检查变得不必要。不过,您可能希望在开发过程中忽略这一点。

3.getindex(M::CircularMatrix, I::Vararg{Int, N})

第三个是笛卡尔索引样式。如果你选择这种风格,你需要实现这个方法。Vararg{Int, N}在签名中代表“确切的N Int论点”。这里N应该等于 的维数CircularMatrix。由于这是一个矩阵,N 应该是 2。如果你选择这种风格,你需要定义如下

Base.@propogate_inbounds function Base.getindex(A::CircularMatrix, I::Vararg{Int, 2})
    @boundscheck checkbounds(A, I...)
    @inbounds A.view[# convert I[1]` and `I[2]` to a linear index in `view`]
end

或者因为你的维度不是参数并且矩阵是二维的,所以

 Base.@propogate_inbounds function Base.getindex(A::CircularMatrix, i::Int, j::Int)
    @boundscheck checkbounds(A, i, j)
    @inbounds A.view[# convert i` and `j` to a linear index in `view`]
end

4.setindex!(M::CircularMatrix, v, i::Int)

第四个与第二个类似。i如果您选择线性索引样式,此方法应将值设置为线性索引。

5.setindex!(M::CircularMatrix, v, I::Vararg{Int, N})

如果您选择笛卡尔索引样式,第五个应该与第三个类似。


在实现 1、2 和 4 和设置IndexStyle之后,您应该有一个可以正常工作的自定义矩阵类型。

m[1, 1] = 5
x = view(m, 1, :)

for e in 
  ...
end

for i in eachindex(m)
  ...
end

display(m)
println(m)
length(m)
ndims(m)
map(f, A)
....

这些都应该有效。

几点注意事项

  • 这里有一个抽象数组接口的文档,其中包含一些示例。您还可以查看实现的可选方法

  • GitHub 上有一个JuliaArray组织,它提供了许多有用的自定义数组实现,包括StaticArraysOffsetArrays等,还有一个提供自定义矩阵类型的JuliaMatrices组织。你可能想看看他们的实现。

  • @inline如果您使用Base.@propogate_inbounds.

@propagate_inbounds

告诉编译器内联函数,同时保留调用者的入站上下文。

  • 您不需要eltype为您的矩阵定义,因为已经有一个定义AbstractArray{T, N}返回T

推荐阅读