diff --git a/src/SizedArray.jl b/src/SizedArray.jl index 34bf9e67..06ab2f6b 100644 --- a/src/SizedArray.jl +++ b/src/SizedArray.jl @@ -12,18 +12,42 @@ array may be reshaped. immutable SizedArray{S,T,N,M} <: StaticArray{T,N} data::Array{T,M} - function SizedArray(a) + function SizedArray(a::Array) if length(a) != prod(S) error("Dimensions $(size(a)) don't match static size $S") end new(a) end + + function SizedArray() + new(Array{T,M}(S)) + end end @inline (::Type{SizedArray{S,T,N}}){S,T,N,M}(a::Array{T,M}) = SizedArray{S,T,N,M}(a) @inline (::Type{SizedArray{S,T}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,_ndims(S),M}(a) @inline (::Type{SizedArray{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,_ndims(S),M}(a) +@inline (::Type{SizedArray{S,T,N}}){S,T,N}() = SizedArray{S,T,N,N}() +@inline (::Type{SizedArray{S,T}}){S,T}() = SizedArray{S,T,_ndims(S),_ndims(S)}() + +@generated function (::Type{SizedArray{S,T,N,M}}){S,T,N,M,L}(x::NTuple{L}) + if L != prod(S) + error("Dimension mismatch") + end + exprs = [:(a[$i] = x[$i]) for i = 1:L] + return quote + $(Expr(:meta, :inline)) + a = SizedArray{S,T,N,M}() + @inbounds $(Expr(:block, exprs...)) + return a + end +end + +@inline (::Type{SizedArray{S,T,N}}){S,T,N}(x::Tuple) = SizedArray{S,T,N,N}(x) +@inline (::Type{SizedArray{S,T}}){S,T}(x::Tuple) = SizedArray{S,T,_dims(S),_dims(S)}(x) +@inline (::Type{SizedArray{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,_dims(S),_dims(S)}(x) + # Overide some problematic default behaviour @inline convert{SA<:SizedArray}(::Type{SA}, sa::SizedArray) = SA(sa.data) @@ -38,10 +62,14 @@ end @propagate_inbounds setindex!(a::SizedArray, v, i::Int) = setindex!(a.data, v, i) typealias SizedVector{S,T,M} SizedArray{S,T,1,M} +@pure size{S}(::Type{SizedVector{S}}) = S @inline (::Type{SizedVector{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,1,M}(a) +@inline (::Type{SizedVector{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,1,1}(x) typealias SizedMatrix{S,T,M} SizedArray{S,T,2,M} +@pure size{S}(::Type{SizedMatrix{S}}) = S @inline (::Type{SizedMatrix{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,2,M}(a) +@inline (::Type{SizedMatrix{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,2,2}(x) """ diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 6def83c1..3cd91aff 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -41,6 +41,9 @@ end @pure function similar_type{SA<:StaticArray,T}(::Union{SA,Type{SA}}, ::Type{T}, size::Int) similar_type(similar_type(SA, T), size) end +@pure function similar_type{SA<:StaticArray,T,S}(::Union{SA,Type{SA}}, ::Type{T}, size::Size{S}) + similar_type(similar_type(SA, T), size) +end @generated function similar_type{SA<:StaticArray,T}(::Union{SA,Type{SA}}, ::Type{T}) # This function has a strange error (on tests) regarding double-inference, if it is marked @pure if T == eltype(SA) @@ -115,6 +118,25 @@ end @pure similar_type{SA<:StaticArray,N}(::Union{SA,Type{SA}}, sizes::Tuple{Vararg{Int,N}}) = SArray{sizes, eltype(SA), N, prod(sizes)} +@generated function similar_type{SA <: StaticArray,S}(::Union{SA,Type{SA}}, ::Size{S}) + if length(S) == 1 + return quote + $(Expr(:meta, :inline)) + SVector{$(S[1]), $(eltype(SA))} + end + elseif length(S) == 2 + return quote + $(Expr(:meta, :inline)) + SMatrix{$(S[1]), $(S[2]), $(eltype(SA))} + end + else + return quote + $(Expr(:meta, :inline)) + SArray{S, $(eltype(SA)), $(length(S)), $(prod(S))} + end + end +end + # Some specializations for the mutable case @pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray}}(::Union{MA,Type{MA}}, size::Int) = MVector{size, eltype(MA)} @pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray}}(::Union{MA,Type{MA}}, sizes::Tuple{Int}) = MVector{sizes[1], eltype(MA)} @@ -123,23 +145,73 @@ end @pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray},N}(::Union{MA,Type{MA}}, sizes::Tuple{Vararg{Int,N}}) = MArray{sizes, eltype(MA), N, prod(sizes)} +@generated function similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray},S}(::Union{MA,Type{MA}}, ::Size{S}) + if length(S) == 1 + return quote + $(Expr(:meta, :inline)) + MVector{$(S[1]), $(eltype(MA))} + end + elseif length(S) == 2 + return quote + $(Expr(:meta, :inline)) + MMatrix{$(S[1]), $(S[2]), $(eltype(MA))} + end + else + return quote + $(Expr(:meta, :inline)) + MArray{S, $(eltype(MA)), $(length(S)), $(prod(S))} + end + end +end + # And also similar() returning mutable StaticArrays @inline similar{SV <: StaticVector}(::SV) = MVector{length(SV),eltype(SV)}() @inline similar{SV <: StaticVector, T}(::SV, ::Type{T}) = MVector{length(SV),T}() -@inline similar{SA <: StaticArray}(::SA, sizes::Tuple{Int}) = MVector{sizes[1], eltype(SA)}() -@inline similar{SA <: StaticArray}(::SA, size::Int) = MVector{size, eltype(SA)}() -@inline similar{T}(::StaticArray, ::Type{T}, sizes::Tuple{Int}) = MVector{sizes[1],T}() -@inline similar{T}(::StaticArray, ::Type{T}, size::Int) = MVector{size,T}() @inline similar{SM <: StaticMatrix}(m::SM) = MMatrix{size(SM,1),size(SM,2),eltype(SM),length(SM)}() @inline similar{SM <: StaticMatrix, T}(::SM, ::Type{T}) = MMatrix{size(SM,1),size(SM,2),T,length(SM)}() -@inline similar{SA <: StaticArray}(::SA, sizes::Tuple{Int,Int}) = MMatrix{sizes[1], sizes[2], eltype(SA), sizes[1]*sizes[2]}() -@inline similar(a::StaticArray, T::Type, sizes::Tuple{Int,Int}) = MMatrix{sizes[1], sizes[2], T, sizes[1]*sizes[2]}() @inline similar{SA <: StaticArray}(m::SA) = MArray{size(SA),eltype(SA),ndims(SA),length(SA)}() @inline similar{SA <: StaticArray,T}(m::SA, ::Type{T}) = MArray{size(SA),T,ndims(SA),length(SA)}() -@inline similar{SA <: StaticArray,N}(m::SA, sizes::NTuple{N, Int}) = MArray{sizes,eltype(SA),N,prod(sizes)}() -@inline similar{SA <: StaticArray,N,T}(m::SA, ::Type{T}, sizes::NTuple{N, Int}) = MArray{sizes,T,N,prod(sizes)}() + +@generated function similar{SA <: StaticArray,S}(::SA, ::Size{S}) + if length(S) == 1 + return quote + $(Expr(:meta, :inline)) + MVector{$(S[1]), $(eltype(SA))}() + end + elseif length(S) == 2 + return quote + $(Expr(:meta, :inline)) + MMatrix{$(S[1]), $(S[2]), $(eltype(SA))}() + end + else + return quote + $(Expr(:meta, :inline)) + MArray{S, $(eltype(SA))}() + end + end +end + +@generated function similar{SA <: StaticArray, T, S}(::SA, ::Type{T}, ::Size{S}) + if length(S) == 1 + return quote + $(Expr(:meta, :inline)) + MVector{$(S[1]), T}() + end + elseif length(S) == 2 + return quote + $(Expr(:meta, :inline)) + MMatrix{$(S[1]), $(S[2]), T}() + end + else + return quote + $(Expr(:meta, :inline)) + MArray{S, T}() + end + end +end + # This is used in Base.LinAlg quite a lot, and it impacts type stability # since some functions like expm() branch on a check for Hermitian or Symmetric diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index 52046474..6b9ec030 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -280,7 +280,7 @@ end if can_blas && size(A,1)*size(A,2)*size(B,2) >= 14*14*14 return quote $(Expr(:meta, :inline)) - C = similar(A, $T, $s) + C = similar(A, $T, $(Size(s))) A_mul_B_blas!(C, A, B) return C end