From 5bf8e41d2f85b3b856121b81b86cc952d19478bc Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Tue, 24 May 2016 18:44:50 +0100 Subject: [PATCH] Create `@sub` macro for creating SubArrays via indexing. --- base/exports.jl | 1 + base/subarray.jl | 73 +++++++++++++++++++++++++++++++++++++++++++ doc/stdlib/arrays.rst | 6 ++++ test/subarray.jl | 26 +++++++++++++++ 4 files changed, 106 insertions(+) diff --git a/base/exports.jl b/base/exports.jl index d5dc7801aa060..cd76c3778e6a0 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -1412,6 +1412,7 @@ export @enum, @label, @goto, + @view, # SparseArrays module re-exports SparseArrays, diff --git a/base/subarray.jl b/base/subarray.jl index 06027d4ad5d7d..adf7807b321ec 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -318,3 +318,76 @@ function parentdims(s::SubArray) end dimindex end + +""" + replace_ref_end!(ex) + +Recursively replace occurences of the symbol :end in a "ref" expression (i.e. A[...]) `ex` +with the appropriate function calls (`endof`, `size` or `trailingsize`). Replacement uses +the closest enclosing ref, so + + A[B[end]] + +should transform to + + A[B[endof(B)]] + +""" +function replace_ref_end!(ex,withex=nothing) + if isa(ex,Symbol) && ex == :end + withex == nothing && error("Invalid use of end") + return withex + elseif isa(ex,Expr) + if ex.head == :ref + S = ex.args[1] = replace_ref_end!(ex.args[1],withex) + # new :ref, so redefine withex + nargs = length(ex.args)-1 + if nargs == 0 + return ex + elseif nargs == 1 + # replace with endof(S) + ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S))) + else + n = 1 + J = endof(ex.args) + for j = 2:J-1 + exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n))) + if isa(exj,Expr) && exj.head == :... + # splatted object + exjs = exj.args[1] + n = :($n + length($exjs)) + elseif isa(n, Expr) + # previous expression splatted + n = :($n + 1) + else + # an integer + n += 1 + end + end + ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n))) + end + else + # recursive search + for i = eachindex(ex.args) + ex.args[i] = replace_ref_end!(ex.args[i],withex) + end + end + end + ex +end + +""" + @view A[inds...] + +Creates a `SubArray` from an indexing expression. This can only be applied directly to a +reference expression (e.g. `@view A[1,2:end]`), and should *not* be used as the target of +an assignment (e.g. `@view(A[1,2:end]) = ...`). +""" +macro view(ex) + if isa(ex, Expr) && ex.head == :ref + ex = replace_ref_end!(ex) + Expr(:&&, true, esc(Expr(:call,:(Base.view),ex.args...))) + else + throw(ArgumentError("Invalid use of @view macro: argument must be a reference expression A[...].")) + end +end diff --git a/doc/stdlib/arrays.rst b/doc/stdlib/arrays.rst index c258443fd465c..2cc45514e1a82 100644 --- a/doc/stdlib/arrays.rst +++ b/doc/stdlib/arrays.rst @@ -383,6 +383,12 @@ Indexing, Assignment, and Concatenation Like :func:`getindex`\ , but returns a view into the parent array ``A`` with the given indices instead of making a copy. Calling :func:`getindex` or :func:`setindex!` on the returned :obj:`SubArray` computes the indices to the parent array on the fly without checking bounds. +.. function:: @view A[inds...] + + .. Docstring generated from Julia source + + Creates a ``SubArray`` from an indexing expression. This can only be applied directly to a reference expression (e.g. ``@view A[1,2:end]``\ ), and should *not* be used as the target of an assignment (e.g. ``@view(A[1,2:end]) = ...``\ ). + .. function:: parent(A) .. Docstring generated from Julia source diff --git a/test/subarray.jl b/test/subarray.jl index 14ebec68c3553..771e0c963f1e0 100644 --- a/test/subarray.jl +++ b/test/subarray.jl @@ -465,3 +465,29 @@ end # the following segfaults with LLVM 3.8 on Windows, ref #15417 @test collect(view(view(reshape(1:13^3, 13, 13, 13), 3:7, 6:6, :), 1:2:5, :, 1:2:5)) == cat(3,[68,70,72],[406,408,410],[744,746,748]) + + + +# tests @view (and replace_ref_end!) +X = reshape(1:24,2,3,4) +Y = 4:-1:1 + +@test isa(@view(X[1:3]), SubArray) + + +@test X[1:end] == @view X[1:end] +@test X[1:end-3] == @view X[1:end-3] +@test X[1:end,2,2] == @view X[1:end,2,2] +@test X[1,1:end-2] == @view X[1,1:end-2] +@test X[1,2,1:end-2] == @view X[1,2,1:end-2] +@test X[1,2,Y[2:end]] == @view X[1,2,Y[2:end]] +@test X[1:end,2,Y[2:end]] == @view X[1:end,2,Y[2:end]] + +u = (1,2:3) +@test X[u...,2:end] == @view X[u...,2:end] +@test X[(1,)...,(2,)...,2:end] == @view X[(1,)...,(2,)...,2:end] + +# test macro hygiene +let size=(x,y)-> error("should not happen") + @test X[1:end,2,2] == @view X[1:end,2,2] +end