diff --git a/REQUIRE b/REQUIRE index a5b224e..fc37aee 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,2 +1,3 @@ julia 1.0 StaticArrays +Reexport diff --git a/src/LabelledArrays.jl b/src/LabelledArrays.jl index d1c80c8..b8393c9 100644 --- a/src/LabelledArrays.jl +++ b/src/LabelledArrays.jl @@ -1,10 +1,12 @@ module LabelledArrays -using StaticArrays, LinearAlgebra +using LinearAlgebra +using Reexport + +@reexport using StaticArrays -include("slvector.jl") include("lvector.jl") -export SLVector, LVector, @SLVector, @LVector +export LVector, @SLVector, @LVector end # module diff --git a/src/lvector.jl b/src/lvector.jl index b23879d..c0479fd 100644 --- a/src/lvector.jl +++ b/src/lvector.jl @@ -90,3 +90,49 @@ macro LVector(vals,syms) end end end + +##################################### +# SLVector +##################################### +const SLVector = LVector{T,SVector{N,T},Syms} where {T,N,Syms} + +""" + @SLVector ElementType Names + +Creates an anonymous function that builds a labelled static vector with eltype +`ElementType` with names determined from the `Names`. The labbeled static vector +is just an `LVector` whose entries are `SVector{ElementType}`. + +For example: + +```julia +ABC = @SLVector Float64 (:a,:b,:c) +x = ABC(1.0,2.5,3.0) +x.a == 1.0 +x.b == 2.5 +x.c == x[3] +``` + +""" +macro SLVector(E,syms) + return quote + function (vals...,) + v = SVector{$(length(syms.args)), $E}(vals...) + T = typeof(v) + return LVector{$(esc(E)),T,$syms}(v) + end + end +end + +Base.copy(x::SLVector{T,N,Syms}) where {T,N,Syms} = LVector{Syms}(x.__x) +function Base.similar(::SLVector{T,N,Syms}, ::Type{S}) where {T,N,Syms,S} + tmp = Vector{S}(undef, N) + LVector{Syms}(SVector{N}(tmp)) +end +function Base.AbstractVector{S}(x::SLVector{T,N,Syms}) where {S,T,N,Syms} + LVector{Syms}(S.(x.__x)) +end +function Base.broadcast(f, xs::SLVector{T,N,Syms}...) where {T,N,Syms} + result = broadcast(f, (x.__x for x in xs)...) + LVector{Syms}(result) +end \ No newline at end of file diff --git a/src/slvector.jl b/src/slvector.jl deleted file mode 100644 index 51c8e7b..0000000 --- a/src/slvector.jl +++ /dev/null @@ -1,37 +0,0 @@ -abstract type SLVector{N,T} <: FieldVector{N,T} end - -# SLVector Macro - -""" - @SLVector TypeName ElementType Names - -Creates a static vector type with name TypeName and eltype ElementType -with names determined from the `Names`. - -For example: - -```julia -@SLVector ABC Float64 [a,b,c] -x = ABC(1.0,2.5,3.0) -x.a == 1.0 -x.b == 2.5 -x.c == x[3] -``` - -""" -macro SLVector(tname,T,_names) - names = Symbol.(_names.args) - quote - struct $(tname) <: SLVector{$(length(names)),$T} - $((:($n::$T) for n in names)...) - $(tname)($((:($n) for n in names)...)) = new($((:($n) for n in names)...)) - $(tname)(x::Tuple{Any}) = new(first(x)) - end - end -end - -# Fix broadcast https://github.com/JuliaArrays/StaticArrays.jl/issues/314 -function StaticArrays.similar_type(::Type{V}, ::Type{T}, ::Size{N}) where - {V<:SLVector,T,N} - V -end diff --git a/test/diffeq.jl b/test/diffeq.jl index 8f00c2f..74351f3 100644 --- a/test/diffeq.jl +++ b/test/diffeq.jl @@ -1,7 +1,7 @@ using LabelledArrays, OrdinaryDiffEq, Test -@SLVector LorenzVector Float64 [x,y,z] -@SLVector LorenzParameterVector Float64 [σ,ρ,β] +LorenzVector = @SLVector Float64 (:x,:y,:z) +LorenzParameterVector = @SLVector Float64 (:σ,:ρ,:β) function f(u,p,t) x = p.σ*(u.y-u.x) diff --git a/test/slvectors.jl b/test/slvectors.jl index d244cef..298da2b 100644 --- a/test/slvectors.jl +++ b/test/slvectors.jl @@ -1,7 +1,7 @@ using LabelledArrays using Test -@SLVector ABC Int [a,b,c] +ABC = @SLVector Int (:a,:b,:c) b = ABC(1,2,3) @test b.a == 1 @@ -12,5 +12,26 @@ b = ABC(1,2,3) @test b[3] == b.c @test_throws UndefVarError fill!(a,1) -@test typeof(b) <: SLVector{3,Int} -typeof(b.+b) <: ABC +@test typeof(b.__x) <: SVector{3,Int} +bb = b.+b +@test bb isa LVector +@test eltype(bb) == Int +b1 = b.+1.0 +@test b1 isa LVector +@test eltype(b1) == Float64 + +# Type stability tests +ABC_int = @SLVector Int (:a,:b,:c) +ABC_float = @SLVector Float64 (:a, :b, :c) +x = ABC_int(1,2,3) +y = ABC_float(4.,5.,6.) + +@test typeof(copy(x)) == typeof(x) +@test typeof(similar(x)) == typeof(x) +@test typeof(similar(x, Float64)) == typeof(y) +@test typeof(convert(AbstractVector{Float64}, x)) == typeof(y) +@test_broken typeof(x .+ x) == typeof(x) # degrades to LVector of Vector{Int} +@test typeof(broadcast(+, x, x)) == typeof(x) # why does this work then? +@test_broken typeof(Float64.(x)) == typeof(y) # degrades to LVector of Vector{Float} +@test typeof(broadcast(Float64, x)) == typeof(y) # why does this work then? +@test_broken broadcast(+, x, y) # ERROR: conflicting broadcast rules defined