Skip to content

Commit

Permalink
Made Size use more consistent.
Browse files Browse the repository at this point in the history
Removed type unstable behaviour of similar, fixed #22
  • Loading branch information
Andy Ferris committed Nov 2, 2016
1 parent 1f035b4 commit 01ea146
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 10 deletions.
30 changes: 29 additions & 1 deletion src/SizedArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,42 @@ array may be reshaped.
immutable SizedArray{S,T,N,M} <: StaticArray{T,N}
data::Array{T,M}

function SizedArray(a)
function SizedArray(a::Array)
if length(a) != prod(S)
error("Dimensions $(size(a)) don't match static size $S")
end
new(a)
end

function SizedArray()
new(Array{T,M}(S))
end
end

@inline (::Type{SizedArray{S,T,N}}){S,T,N,M}(a::Array{T,M}) = SizedArray{S,T,N,M}(a)
@inline (::Type{SizedArray{S,T}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,_ndims(S),M}(a)
@inline (::Type{SizedArray{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,_ndims(S),M}(a)

@inline (::Type{SizedArray{S,T,N}}){S,T,N}() = SizedArray{S,T,N,N}()
@inline (::Type{SizedArray{S,T}}){S,T}() = SizedArray{S,T,_ndims(S),_ndims(S)}()

@generated function (::Type{SizedArray{S,T,N,M}}){S,T,N,M,L}(x::NTuple{L})
if L != prod(S)
error("Dimension mismatch")
end
exprs = [:(a[$i] = x[$i]) for i = 1:L]
return quote
$(Expr(:meta, :inline))
a = SizedArray{S,T,N,M}()
@inbounds $(Expr(:block, exprs...))
return a
end
end

@inline (::Type{SizedArray{S,T,N}}){S,T,N}(x::Tuple) = SizedArray{S,T,N,N}(x)
@inline (::Type{SizedArray{S,T}}){S,T}(x::Tuple) = SizedArray{S,T,_dims(S),_dims(S)}(x)
@inline (::Type{SizedArray{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,_dims(S),_dims(S)}(x)

# Overide some problematic default behaviour
@inline convert{SA<:SizedArray}(::Type{SA}, sa::SizedArray) = SA(sa.data)

Expand All @@ -38,10 +62,14 @@ end
@propagate_inbounds setindex!(a::SizedArray, v, i::Int) = setindex!(a.data, v, i)

typealias SizedVector{S,T,M} SizedArray{S,T,1,M}
@pure size{S}(::Type{SizedVector{S}}) = S
@inline (::Type{SizedVector{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,1,M}(a)
@inline (::Type{SizedVector{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,1,1}(x)

typealias SizedMatrix{S,T,M} SizedArray{S,T,2,M}
@pure size{S}(::Type{SizedMatrix{S}}) = S
@inline (::Type{SizedMatrix{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,2,M}(a)
@inline (::Type{SizedMatrix{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,2,2}(x)


"""
Expand Down
88 changes: 80 additions & 8 deletions src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ end
@pure function similar_type{SA<:StaticArray,T}(::Union{SA,Type{SA}}, ::Type{T}, size::Int)
similar_type(similar_type(SA, T), size)
end
@pure function similar_type{SA<:StaticArray,T,S}(::Union{SA,Type{SA}}, ::Type{T}, size::Size{S})
similar_type(similar_type(SA, T), size)
end
@generated function similar_type{SA<:StaticArray,T}(::Union{SA,Type{SA}}, ::Type{T})
# This function has a strange error (on tests) regarding double-inference, if it is marked @pure
if T == eltype(SA)
Expand Down Expand Up @@ -115,6 +118,25 @@ end

@pure similar_type{SA<:StaticArray,N}(::Union{SA,Type{SA}}, sizes::Tuple{Vararg{Int,N}}) = SArray{sizes, eltype(SA), N, prod(sizes)}

@generated function similar_type{SA <: StaticArray,S}(::Union{SA,Type{SA}}, ::Size{S})
if length(S) == 1
return quote
$(Expr(:meta, :inline))
SVector{$(S[1]), $(eltype(SA))}
end
elseif length(S) == 2
return quote
$(Expr(:meta, :inline))
SMatrix{$(S[1]), $(S[2]), $(eltype(SA))}
end
else
return quote
$(Expr(:meta, :inline))
SArray{S, $(eltype(SA)), $(length(S)), $(prod(S))}
end
end
end

# Some specializations for the mutable case
@pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray}}(::Union{MA,Type{MA}}, size::Int) = MVector{size, eltype(MA)}
@pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray}}(::Union{MA,Type{MA}}, sizes::Tuple{Int}) = MVector{sizes[1], eltype(MA)}
Expand All @@ -123,23 +145,73 @@ end

@pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray},N}(::Union{MA,Type{MA}}, sizes::Tuple{Vararg{Int,N}}) = MArray{sizes, eltype(MA), N, prod(sizes)}

@generated function similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray},S}(::Union{MA,Type{MA}}, ::Size{S})
if length(S) == 1
return quote
$(Expr(:meta, :inline))
MVector{$(S[1]), $(eltype(MA))}
end
elseif length(S) == 2
return quote
$(Expr(:meta, :inline))
MMatrix{$(S[1]), $(S[2]), $(eltype(MA))}
end
else
return quote
$(Expr(:meta, :inline))
MArray{S, $(eltype(MA)), $(length(S)), $(prod(S))}
end
end
end

# And also similar() returning mutable StaticArrays
@inline similar{SV <: StaticVector}(::SV) = MVector{length(SV),eltype(SV)}()
@inline similar{SV <: StaticVector, T}(::SV, ::Type{T}) = MVector{length(SV),T}()
@inline similar{SA <: StaticArray}(::SA, sizes::Tuple{Int}) = MVector{sizes[1], eltype(SA)}()
@inline similar{SA <: StaticArray}(::SA, size::Int) = MVector{size, eltype(SA)}()
@inline similar{T}(::StaticArray, ::Type{T}, sizes::Tuple{Int}) = MVector{sizes[1],T}()
@inline similar{T}(::StaticArray, ::Type{T}, size::Int) = MVector{size,T}()

@inline similar{SM <: StaticMatrix}(m::SM) = MMatrix{size(SM,1),size(SM,2),eltype(SM),length(SM)}()
@inline similar{SM <: StaticMatrix, T}(::SM, ::Type{T}) = MMatrix{size(SM,1),size(SM,2),T,length(SM)}()
@inline similar{SA <: StaticArray}(::SA, sizes::Tuple{Int,Int}) = MMatrix{sizes[1], sizes[2], eltype(SA), sizes[1]*sizes[2]}()
@inline similar(a::StaticArray, T::Type, sizes::Tuple{Int,Int}) = MMatrix{sizes[1], sizes[2], T, sizes[1]*sizes[2]}()

@inline similar{SA <: StaticArray}(m::SA) = MArray{size(SA),eltype(SA),ndims(SA),length(SA)}()
@inline similar{SA <: StaticArray,T}(m::SA, ::Type{T}) = MArray{size(SA),T,ndims(SA),length(SA)}()
@inline similar{SA <: StaticArray,N}(m::SA, sizes::NTuple{N, Int}) = MArray{sizes,eltype(SA),N,prod(sizes)}()
@inline similar{SA <: StaticArray,N,T}(m::SA, ::Type{T}, sizes::NTuple{N, Int}) = MArray{sizes,T,N,prod(sizes)}()

@generated function similar{SA <: StaticArray,S}(::SA, ::Size{S})
if length(S) == 1
return quote
$(Expr(:meta, :inline))
MVector{$(S[1]), $(eltype(SA))}()
end
elseif length(S) == 2
return quote
$(Expr(:meta, :inline))
MMatrix{$(S[1]), $(S[2]), $(eltype(SA))}()
end
else
return quote
$(Expr(:meta, :inline))
MArray{S, $(eltype(SA))}()
end
end
end

@generated function similar{SA <: StaticArray, T, S}(::SA, ::Type{T}, ::Size{S})
if length(S) == 1
return quote
$(Expr(:meta, :inline))
MVector{$(S[1]), T}()
end
elseif length(S) == 2
return quote
$(Expr(:meta, :inline))
MMatrix{$(S[1]), $(S[2]), T}()
end
else
return quote
$(Expr(:meta, :inline))
MArray{S, T}()
end
end
end


# This is used in Base.LinAlg quite a lot, and it impacts type stability
# since some functions like expm() branch on a check for Hermitian or Symmetric
Expand Down
2 changes: 1 addition & 1 deletion src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ end
if can_blas && size(A,1)*size(A,2)*size(B,2) >= 14*14*14
return quote
$(Expr(:meta, :inline))
C = similar(A, $T, $s)
C = similar(A, $T, $(Size(s)))
A_mul_B_blas!(C, A, B)
return C
end
Expand Down

0 comments on commit 01ea146

Please sign in to comment.