diff --git a/base/exports.jl b/base/exports.jl index 0dfd70f27dcf94..a9d3e83bcda440 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -1427,6 +1427,7 @@ export @enum, @label, @goto, + @sub, # SparseArrays module re-exports SparseArrays, diff --git a/base/subarray.jl b/base/subarray.jl index 2505f5eabeffb4..c2caa921e15b1f 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -352,3 +352,72 @@ function parentdims(s::SubArray) end dimindex end + +""" + replace_ref_end!(ex) + +Recursively replace occurences of the symbol :end in a "ref" expression (i.e. A[...]) `ex` with the appropriate function calls (`endof`, `size` or `trailingsize`). Replacement uses the closest enclosing ref, so + + A[B[end]] + +should transform to + + A[B[endof(B)]] + +""" +function replace_ref_end!(ex,withex=nothing) + if isa(ex,Symbol) && ex == :end + withex == nothing && error("Invalid use of end") + return withex + elseif isa(ex,Expr) + if ex.head == :ref + S = ex.args[1] = replace_ref_end!(ex.args[1],withex) + # new :ref, so redefine withex + nargs = length(ex.args)-1 + if nargs == 0 + return ex + elseif nargs == 1 + # replace with endof(S) + ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S))) + else + n = 1 + J = endof(ex.args) + for j = 2:J-1 + exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n))) + if isa(exj,Expr) && exj.head == :... + # splatted object + exjs = exj.args[1] + n = :($n + length($exjs)) + elseif isa(n, Expr) + # previous expression splatted + n = :($n + 1) + else + # an integer + n += 1 + end + end + ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n))) + end + else + # recursive search + for i = eachindex(ex.args) + ex.args[i] = replace_ref_end!(ex.args[i],withex) + end + end + end + ex +end + +""" + @sub A[...] + +Creates `SubArray` from an indexing expression. This can only be applied directly to a reference expression (e.g. `@sub A[1,2:end]`), and should *not* be used as the target of an assignment (e.g. `@sub(A[1,2:end]) = ...`). +""" +macro sub(ex) + if isa(ex, Expr) && ex.head == :ref + ex = replace_ref_end!(ex) + esc(Expr(:call,:(Base.slice),ex.args...)) + else + throw(ArgumentError("Invalid use of @sub macro: argument must be a reference expression A[...].")) + end +end diff --git a/test/subarray.jl b/test/subarray.jl index 26204738c0300a..bcf34527fdec21 100644 --- a/test/subarray.jl +++ b/test/subarray.jl @@ -475,3 +475,29 @@ end # the following segfaults with LLVM 3.8 on Windows, ref #15417 @test collect(sub(sub(reshape(1:13^3, 13, 13, 13), 3:7, 6, :), 1:2:5, :, 1:2:5)) == cat(3,[68,70,72],[406,408,410],[744,746,748]) + + + +# tests @sub (and replace_ref_end!) +X = reshape(1:24,2,3,4) +Y = 4:-1:1 + +@test isa(@sub(X[1:3]), SubArray) + + +@test X[1:end] == @sub X[1:end] +@test X[1:end-3] == @sub X[1:end-3] +@test X[1:end,2,2] == @sub X[1:end,2,2] +@test X[1,1:end-2] == @sub X[1,1:end-2] +@test X[1,2,1:end-2] == @sub X[1,2,1:end-2] +@test X[1,2,Y[2:end]] == @sub X[1,2,Y[2:end]] +@test X[1:end,2,Y[2:end]] == @sub X[1:end,2,Y[2:end]] + +u = (1,2:3) +@test X[u...,2:end] == @sub X[u...,2:end] +@test X[(1,)...,(2,)...,2:end] == @sub X[(1,)...,(2,)...,2:end] + +# test macro hygiene +let size=(x,y)-> error("should not happen") + @test X[1:end,2,2] == @sub X[1:end,2,2] +end