Skip to content

Commit

Permalink
Make construct_type more robust
Browse files Browse the repository at this point in the history
As most `StaticArray` has a special constructor `SA(x::Tuple)`. This help us to locate constructor missing.
  • Loading branch information
N5N3 committed Mar 5, 2022
1 parent 26aea89 commit cf16d37
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ _size1(::Type{<:StaticMatrix{M}}) where {M} = M
function construct_type(SA, x::AbstractArray)
ET = has_eltype(SA) ? eltype(SA) : eltype(x)
has_size(SA) || throw_convert(SA)
construct_type(SA, ET, Tuple{size(SA)...})
construct_type(Tuple{SA, AbstractArray}, ET, Tuple{size(SA)...})
end
throw_convert(SA) = throw(DimensionMismatch("No precise constructor for $SA found. Input is not static sized."))

Expand All @@ -52,7 +52,7 @@ function construct_type(::Type{SA}, ::Type{A}, ::Type{SZ}, ::Type{T}) where {SA
ET = has_eltype(SA) ? eltype(SA) : T
if has_size(SA) # SA has size defined, just check its validity.
if length(SA) == len
return construct_type(SA, ET, Tuple{size(SA)...})
return construct_type(Tuple{SA, A}, ET, Tuple{size(SA)...})
elseif A !== Args && length(SA) == 1
# special case for SVector{1}((1,2)), We return nothing.
# Then Parent call will try `construct_type(SA, (x,))`
Expand All @@ -62,17 +62,17 @@ function construct_type(::Type{SA}, ::Type{A}, ::Type{SZ}, ::Type{T}) where {SA
else
if SA <: StaticVector
# For vector, use input length directly
return construct_type(SA, ET, Tuple{len})
return construct_type(Tuple{SA, A}, ET, Tuple{len})
elseif SA <: StaticMatrix && has_size1(SA)
# Similar for n*? matrix, as we have SMatrix{N}(...)
# TODO: is there a better way to extend this branch to, e.g. "?*n" "m*?*n"?
N = _size1(SA)
M = len ÷ N
M * N == len || throw(DimensionMismatch("Incorrect matrix sizes. $len does not divide $N elements"))
return construct_type(SA, ET, Tuple{N, M})
return construct_type(Tuple{SA, A}, ET, Tuple{N, M})
elseif A <: StaticArray
# Here we just try with src's shape.
return construct_type(SA, ET, SZ)
return construct_type(Tuple{SA, A}, ET, SZ)
end
end
throw_convert(SA, SZ)
Expand All @@ -84,11 +84,12 @@ check_parameters(::Type{MArray{S,T,N,L}}) where {S,T,N,L} = (check_array_paramet
check_parameters(::Type{SHermitianCompact{N,T,L}}) where {N,T,L} = (_check_hermitian_parameters(Val(N), Val(L));true)
check_parameters(::Type{Union{}}) = false

function construct_type(SA, ET, SZ::Type{<:Tuple})
function construct_type(::Type{Tuple{SA, A}}, ::Type{ET}, ::Type{SZ}) where {SA <: StaticArray, A <: Union{Tuple, Args, StaticArray, AbstractArray}, ET, SZ <: Tuple}
# Here we use Base.typeintersect to get the most concrete dest type (if valid)
# It's similar to `similar_type`, but not fallback to SArray by default.
T = Base.typeintersect(SA, StaticArray{SZ,ET,tuple_length(SZ)})
check_parameters(T) || throw_convert(SA, SZ)
A === Tuple && SA === T && error("Constructor is missing for $(SA)(::Tuple), please file a bug.")
return T
end
throw_convert(SA, ::Type{Tuple{N}}) where {N} = throw(DimensionMismatch("No precise constructor for $SA found. Length of input was $N."))
Expand Down
12 changes: 12 additions & 0 deletions test/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,16 @@ end
end
mInt = SA[Int16(1) Int16(2) Int16(3); Int16(4) Int16(5) Int16(6)] # SMatrix{3,2,Int16}
@test float(typeof(mInt)) === SMatrix{2, 3, float(Int16), 6}
end

struct BugSArray{S<:Tuple,T,N,L} <: StaticArray{S,T,N}
data::NTuple{L,T}
BugSArray{S,T,N,L}(x::NTuple{L,Any}) where {S<:Tuple,T,N,L} = new{S,T,N,L}(map(T,x))
end
(::Type{BS})(x::Tuple) where {BS<:BugSArray} = StaticArrays.construct_type(BS, x)(x)
BugSVector{N,T,L} = BugSArray{Tuple{N},T,1,L}
@testset "missing constructor" begin
@test_throws DimensionMismatch BugSArray(1,2,3)
@test_throws ErrorException BugSVector(1,2,3) # we catch a missing constructor here.
@test BugSVector{<:Any,<:Any,3}(1,2,3) isa BugSVector{3,Int,3}
end

0 comments on commit cf16d37

Please sign in to comment.