Skip to content

Commit

Permalink
Create @sub macro for creating SubArrays via indexing.
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed May 24, 2016
1 parent 3bed78c commit 5c49ff8
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,7 @@ export
@enum,
@label,
@goto,
@sub,

# SparseArrays module re-exports
SparseArrays,
Expand Down
69 changes: 69 additions & 0 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,72 @@ function parentdims(s::SubArray)
end
dimindex
end

"""
replace_ref_end!(ex)
Recursively replace occurences of the symbol :end in a "ref" expression (i.e. A[...]) `ex` with the appropriate function calls (`endof`, `size` or `trailingsize`). Replacement uses the closest enclosing ref, so
A[B[end]]
should transform to
A[B[endof(B)]]
"""
function replace_ref_end!(ex,withex=nothing)
if isa(ex,Symbol) && ex == :end
withex == nothing && error("Invalid use of end")
return withex
elseif isa(ex,Expr)
if ex.head == :ref
S = ex.args[1] = replace_ref_end!(ex.args[1],withex)
# new :ref, so redefine withex
nargs = length(ex.args)-1
if nargs == 0
return ex
elseif nargs == 1
# replace with endof(S)
ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S)))
else
n = 1
J = endof(ex.args)
for j = 2:J-1
exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n)))
if isa(exj,Expr) && exj.head == :...
# splatted object
exjs = exj.args[1]
n = :($n + length($exjs))
elseif isa(n, Expr)
# previous expression splatted
n = :($n + 1)
else
# an integer
n += 1
end
end
ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n)))
end
else
# recursive search
for i = eachindex(ex.args)
ex.args[i] = replace_ref_end!(ex.args[i],withex)
end
end
end
ex
end

"""
@sub A[...]
Creates `SubArray` from an indexing expression. This can only be applied directly to a reference expression (e.g. `@sub A[1,2:end]`), and should *not* be used as the target of an assignment (e.g. `@sub(A[1,2:end]) = ...`).
"""
macro sub(ex)
if isa(ex, Expr) && ex.head == :ref
ex = replace_ref_end!(ex)
esc(Expr(:call,:(Base.slice),ex.args...))
else
throw(ArgumentError("Invalid use of @sub macro: argument must be a reference expression A[...]."))
end
end
26 changes: 26 additions & 0 deletions test/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,29 @@ end
# the following segfaults with LLVM 3.8 on Windows, ref #15417
@test collect(sub(sub(reshape(1:13^3, 13, 13, 13), 3:7, 6, :), 1:2:5, :, 1:2:5)) ==
cat(3,[68,70,72],[406,408,410],[744,746,748])



# tests @sub (and replace_ref_end!)
X = reshape(1:24,2,3,4)
Y = 4:-1:1

@test isa(@sub(X[1:3]), SubArray)


@test X[1:end] == @sub X[1:end]
@test X[1:end-3] == @sub X[1:end-3]
@test X[1:end,2,2] == @sub X[1:end,2,2]
@test X[1,1:end-2] == @sub X[1,1:end-2]
@test X[1,2,1:end-2] == @sub X[1,2,1:end-2]
@test X[1,2,Y[2:end]] == @sub X[1,2,Y[2:end]]
@test X[1:end,2,Y[2:end]] == @sub X[1:end,2,Y[2:end]]

u = (1,2:3)
@test X[u...,2:end] == @sub X[u...,2:end]
@test X[(1,)...,(2,)...,2:end] == @sub X[(1,)...,(2,)...,2:end]

# test macro hygiene
let size=(x,y)-> error("should not happen")
@test X[1:end,2,2] == @sub X[1:end,2,2]
end

0 comments on commit 5c49ff8

Please sign in to comment.