Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend @SArray (nested cat, 1.7 syntax) #1009

Merged
merged 11 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.4.2"
version = "1.4.3"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
163 changes: 10 additions & 153 deletions src/MArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,160 +110,17 @@ Base.dataids(ma::MArray) = (UInt(pointer(ma)),)
Base.unsafe_convert(Ptr{T}, pointer_from_objref(a))
end

macro MArray(ex)
if !isa(ex, Expr)
error("Bad input for @MArray")
end

if ex.head == :vect # vector
return esc(Expr(:call, MArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...)))
elseif ex.head == :ref # typed, vector
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
elseif ex.head == :hcat # 1 x n
s1 = 1
s2 = length(ex.args)
return esc(Expr(:call, MArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...)))
elseif ex.head == :typed_hcat # typed, 1 x n
s1 = 1
s2 = length(ex.args) - 1
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
elseif ex.head == :vcat
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m
# Validate
s1 = length(ex.args)
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1)
s2 = minimum(s2s)
if maximum(s2s) != s2
throw(ArgumentError("Rows must be of matching lengths"))
end

exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2]
return esc(Expr(:call, MArray{Tuple{s1, s2}}, Expr(:tuple, exprs...)))
else # n x 1
return esc(Expr(:call, MArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...)))
end
elseif ex.head == :typed_vcat
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m
# Validate
s1 = length(ex.args) - 1
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1)
s2 = minimum(s2s)
if maximum(s2s) != s2
error("Rows must be of matching lengths")
end

exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...)))
else # typed, n x 1
return esc(Expr(:call, Expr(:curly, :MArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
end
elseif isa(ex, Expr) && ex.head == :comprehension
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
end
ex = ex.args[1]
n_rng = length(ex.args) - 1
rng_args = [ex.args[i+1].args[1] for i = 1:n_rng]
rngs = [Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng]
rng_lengths = map(length, rngs)

f = gensym()
f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1])))

# TODO figure out a generic way of doing this...
if n_rng == 1
exprs = [:($f($j1)) for j1 in rngs[1]]
elseif n_rng == 2
exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]]
elseif n_rng == 3
exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]]
elseif n_rng == 4
exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]]
elseif n_rng == 5
exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]]
elseif n_rng == 6
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]]
elseif n_rng == 7
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]]
elseif n_rng == 8
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]]
else
error("@MArray only supports up to 8-dimensional comprehensions")
end

return quote
$(esc(f_expr))
$(esc(Expr(:call, Expr(:curly, :MArray, Tuple{rng_lengths...}), Expr(:tuple, exprs...))))
end
elseif isa(ex, Expr) && ex.head == :typed_comprehension
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
end
T = ex.args[1]
ex = ex.args[2]
n_rng = length(ex.args) - 1
rng_args = [ex.args[i+1].args[1] for i = 1:n_rng]
rngs = [Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng]
rng_lengths = map(length, rngs)

f = gensym()
f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1])))

# TODO figure out a generic way of doing this...
if n_rng == 1
exprs = [:($f($j1)) for j1 in rngs[1]]
elseif n_rng == 2
exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]]
elseif n_rng == 3
exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]]
elseif n_rng == 4
exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]]
elseif n_rng == 5
exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]]
elseif n_rng == 6
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]]
elseif n_rng == 7
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]]
elseif n_rng == 8
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]]
else
error("@MArray only supports up to 8-dimensional comprehensions")
end

return quote
$(esc(f_expr))
$(esc(Expr(:call, Expr(:curly, :MArray, Tuple{rng_lengths...}, T), Expr(:tuple, exprs...))))
end
elseif isa(ex, Expr) && ex.head == :call
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
if length(ex.args) == 1
error("@MArray got bad expression: $(ex.args[1])()")
else
return quote
if isa($(esc(ex.args[2])), DataType)
$(ex.args[1])($(esc(Expr(:curly, MArray, Expr(:curly, Tuple, ex.args[3:end]...), ex.args[2]))))
else
$(ex.args[1])($(esc(Expr(:curly, MArray, Expr(:curly, Tuple, ex.args[2:end]...)))))
end
end
end
elseif ex.args[1] == :fill
if length(ex.args) == 1
error("@MArray got bad expression: $(ex.args[1])()")
elseif length(ex.args) == 2
error("@MArray got bad expression: $(ex.args[1])($(ex.args[2]))")
else
return quote
$(esc(ex.args[1]))($(esc(ex.args[2])), MArray{$(esc(Expr(:curly, Tuple, ex.args[3:end]...)))})
end
end
else
error("@MArray only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
end
else
error("Bad input for @MArray")
end
"""
@MArray [a b; c d]
@MArray [[a, b];[c, d]]
@MArray [i+j for i in 1:2, j in 1:2]
@MArray ones(2, 2, 2)

A convenience macro to construct `MArray` with arbitrary dimension.
See [`@SArray`](@ref) for detailed features.
"""
macro MArray(ex)
esc(static_array_gen(MArray, ex, __module__))
end

function promote_rule(::Type{<:MArray{S,T,N,L}}, ::Type{<:MArray{S,U,N,L}}) where {S,T,U,N,L}
Expand Down
123 changes: 11 additions & 112 deletions src/MMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,116 +62,15 @@ end
## MMatrix methods ##
#####################

macro MMatrix(ex)
if !isa(ex, Expr)
error("Bad input for @MMatrix")
end
if ex.head == :vect && length(ex.args) == 1 # 1 x 1
return esc(Expr(:call, MMatrix{1, 1}, Expr(:tuple, ex.args[1])))
elseif ex.head == :ref && length(ex.args) == 2 # typed, 1 x 1
return esc(Expr(:call, Expr(:curly, :MMatrix, 1, 1, ex.args[1]), Expr(:tuple, ex.args[2])))
elseif ex.head == :hcat # 1 x n
s1 = 1
s2 = length(ex.args)
return esc(Expr(:call, MMatrix{s1, s2}, Expr(:tuple, ex.args...)))
elseif ex.head == :typed_hcat # typed, 1 x n
s1 = 1
s2 = length(ex.args) - 1
return esc(Expr(:call, Expr(:curly, :MMatrix, s1, s2, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
elseif ex.head == :vcat
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m
# Validate
s1 = length(ex.args)
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1)
s2 = minimum(s2s)
if maximum(s2s) != s2
throw(ArgumentError("Rows must be of matching lengths"))
end

exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2]
return esc(Expr(:call, MMatrix{s1, s2}, Expr(:tuple, exprs...)))
else # n x 1
return esc(Expr(:call, MMatrix{length(ex.args), 1}, Expr(:tuple, ex.args...)))
end
elseif ex.head == :typed_vcat
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m
# Validate
s1 = length(ex.args) - 1
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1)
s2 = minimum(s2s)
if maximum(s2s) != s2
throw(ArgumentError("Rows must be of matching lengths"))
end

exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2]
return esc(Expr(:call, Expr(:curly, :MMatrix,s1, s2, ex.args[1]), Expr(:tuple, exprs...)))
else # typed, n x 1
return esc(Expr(:call, Expr(:curly, :MMatrix, length(ex.args)-1, 1, ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
end
elseif isa(ex, Expr) && ex.head == :comprehension
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
end
ex = ex.args[1]
if length(ex.args) != 3
error("Use a 2-dimensional comprehension for @MMatrx")
end

rng1 = Core.eval(__module__, ex.args[2].args[2])
rng2 = Core.eval(__module__, ex.args[3].args[2])
f = gensym()
f_expr = :($f = (($(ex.args[2].args[1]), $(ex.args[3].args[1])) -> $(ex.args[1])))
exprs = [:($f($j1, $j2)) for j1 in rng1, j2 in rng2]

return quote
$(esc(f_expr))
$(esc(Expr(:call, Expr(:curly, :MMatrix, length(rng1), length(rng2)), Expr(:tuple, exprs...))))
end
elseif isa(ex, Expr) && ex.head == :typed_comprehension
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]")
end
T = ex.args[1]
ex = ex.args[2]
if length(ex.args) != 3
error("Use a 2-dimensional comprehension for @MMatrx")
end

rng1 = Core.eval(__module__, ex.args[2].args[2])
rng2 = Core.eval(__module__, ex.args[3].args[2])
f = gensym()
f_expr = :($f = (($(ex.args[2].args[1]), $(ex.args[3].args[1])) -> $(ex.args[1])))
exprs = [:($f($j1, $j2)) for j1 in rng1, j2 in rng2]
"""
@MMatrix [a b c d]
@MMatrix [[a, b];[c, d]]
@MMatrix [i+j for i in 1:2, j in 1:2]
@MMatrix ones(2, 2, 2)

return quote
$(esc(f_expr))
$(esc(Expr(:call, Expr(:curly, :MMatrix, length(rng1), length(rng2), T), Expr(:tuple, exprs...))))
end
elseif isa(ex, Expr) && ex.head == :call
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
if length(ex.args) == 3
return quote
$(ex.args[1])(MMatrix{$(esc(ex.args[2])),$(esc(ex.args[3]))})
end
elseif length(ex.args) == 4
return quote
$(ex.args[1])(MMatrix{$(esc(ex.args[3])), $(esc(ex.args[4])), $(esc(ex.args[2]))})
end
else
error("@MMatrix expected a 2-dimensional array expression")
end
elseif ex.args[1] == :fill
if length(ex.args) == 4
return quote
$(esc(ex.args[1]))($(esc(ex.args[2])), MMatrix{$(esc(ex.args[3])), $(esc(ex.args[4]))})
end
else
error("@MMatrix expected a 2-dimensional array expression")
end
else
error("@MMatrix only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
end
else
error("Bad input for @MMatrix")
end
end
A convenience macro to construct `MMatrix`.
See [`@SArray`](@ref) for detailed features.
"""
macro MMatrix(ex)
esc(static_matrix_gen(MMatrix, ex, __module__))
end
76 changes: 8 additions & 68 deletions src/MVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,76 +28,16 @@ const MVector{S, T} = MArray{Tuple{S}, T, 1, S}
#####################
## MVector methods ##
#####################
"""
@MVector [a, b, c, d]
@MVector [i for i in 1:2]
@MVector ones(2)

A convenience macro to construct `MVector`.
See [`@SArray`](@ref) for detailed features.
"""
macro MVector(ex)
if isa(ex, Expr) && ex.head == :vect
return esc(Expr(:call, MVector{length(ex.args)}, Expr(:tuple, ex.args...)))
elseif isa(ex, Expr) && ex.head == :ref
return esc(Expr(:call, Expr(:curly, :MVector, length(ex.args[2:end]), ex.args[1]), Expr(:tuple, ex.args[2:end]...)))
elseif isa(ex, Expr) && ex.head == :comprehension
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator
error("Expected generator in comprehension, e.g. [f(i) for i = 1:3]")
end
ex = ex.args[1]
if length(ex.args) != 2
error("Use a one-dimensional comprehension for @MVector")
end

rng = Core.eval(__module__, ex.args[2].args[2])
f = gensym()
f_expr = :($f = ($(ex.args[2].args[1]) -> $(ex.args[1])))
exprs = [:($f($j)) for j in rng]

return quote
$(esc(f_expr))
$(esc(Expr(:call, Expr(:curly, :MVector, length(rng)), Expr(:tuple, exprs...))))
end
elseif isa(ex, Expr) && ex.head == :typed_comprehension
if length(ex.args) != 2 || !isa(ex.args[2], Expr) !! ex.args[2].head != :generator
error("Expected generator in typed comprehension, e.g. Float64[f(i) for i = 1:3]")
end
T = ex.args[1]
ex = ex.args[2]
if length(ex.args) != 2
error("Use a one-dimensional comprehension for @MVector")
end

rng = Core.eval(__module__, ex.args[2].args[2])
f = gensym()
f_expr = :($f = ($(ex.args[2].args[1]) -> $(ex.args[1])))
exprs = [:($f($j)) for j in rng]

return quote
$(esc(f_expr))
$(esc(Expr(:call, Expr(:curly, :MVector, length(rng), T), Expr(:tuple, exprs...))))
end
elseif isa(ex, Expr) && ex.head == :call
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp
if length(ex.args) == 2
return quote
$(esc(ex.args[1]))(MVector{$(esc(ex.args[2]))})
end
elseif length(ex.args) == 3
return quote
$(esc(ex.args[1]))(MVector{$(esc(ex.args[3])), $(esc(ex.args[2]))})
end
else
error("@MVector expected a 1-dimensional array expression")
end
elseif ex.args[1] == :fill
if length(ex.args) == 3
return quote
$(esc(ex.args[1]))($(esc(ex.args[2])), MVector{$(esc(ex.args[3]))})
end
else
error("@MVector expected a 1-dimensional array expression")
end
else
error("@MVector only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
end
else
error("Use @MVector [a,b,c] or @MVector([a,b,c])")
end
esc(static_vector_gen(MVector, ex, __module__))
end

# Named field access for the first four elements, using the conventional field
Expand Down
Loading