Skip to content

Commit 5c49ff8

Browse files
committed
Create @sub macro for creating SubArrays via indexing.
1 parent 3bed78c commit 5c49ff8

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

base/exports.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1427,6 +1427,7 @@ export
14271427
@enum,
14281428
@label,
14291429
@goto,
1430+
@sub,
14301431

14311432
# SparseArrays module re-exports
14321433
SparseArrays,

base/subarray.jl

+69
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,72 @@ function parentdims(s::SubArray)
352352
end
353353
dimindex
354354
end
355+
356+
"""
357+
replace_ref_end!(ex)
358+
359+
Recursively replace occurences of the symbol :end in a "ref" expression (i.e. A[...]) `ex` with the appropriate function calls (`endof`, `size` or `trailingsize`). Replacement uses the closest enclosing ref, so
360+
361+
A[B[end]]
362+
363+
should transform to
364+
365+
A[B[endof(B)]]
366+
367+
"""
368+
function replace_ref_end!(ex,withex=nothing)
369+
if isa(ex,Symbol) && ex == :end
370+
withex == nothing && error("Invalid use of end")
371+
return withex
372+
elseif isa(ex,Expr)
373+
if ex.head == :ref
374+
S = ex.args[1] = replace_ref_end!(ex.args[1],withex)
375+
# new :ref, so redefine withex
376+
nargs = length(ex.args)-1
377+
if nargs == 0
378+
return ex
379+
elseif nargs == 1
380+
# replace with endof(S)
381+
ex.args[2] = replace_ref_end!(ex.args[2],:(Base.endof($S)))
382+
else
383+
n = 1
384+
J = endof(ex.args)
385+
for j = 2:J-1
386+
exj = ex.args[j] = replace_ref_end!(ex.args[j],:(Base.size($S,$n)))
387+
if isa(exj,Expr) && exj.head == :...
388+
# splatted object
389+
exjs = exj.args[1]
390+
n = :($n + length($exjs))
391+
elseif isa(n, Expr)
392+
# previous expression splatted
393+
n = :($n + 1)
394+
else
395+
# an integer
396+
n += 1
397+
end
398+
end
399+
ex.args[J] = replace_ref_end!(ex.args[J],:(Base.trailingsize($S,$n)))
400+
end
401+
else
402+
# recursive search
403+
for i = eachindex(ex.args)
404+
ex.args[i] = replace_ref_end!(ex.args[i],withex)
405+
end
406+
end
407+
end
408+
ex
409+
end
410+
411+
"""
412+
@sub A[...]
413+
414+
Creates `SubArray` from an indexing expression. This can only be applied directly to a reference expression (e.g. `@sub A[1,2:end]`), and should *not* be used as the target of an assignment (e.g. `@sub(A[1,2:end]) = ...`).
415+
"""
416+
macro sub(ex)
417+
if isa(ex, Expr) && ex.head == :ref
418+
ex = replace_ref_end!(ex)
419+
esc(Expr(:call,:(Base.slice),ex.args...))
420+
else
421+
throw(ArgumentError("Invalid use of @sub macro: argument must be a reference expression A[...]."))
422+
end
423+
end

test/subarray.jl

+26
Original file line numberDiff line numberDiff line change
@@ -475,3 +475,29 @@ end
475475
# the following segfaults with LLVM 3.8 on Windows, ref #15417
476476
@test collect(sub(sub(reshape(1:13^3, 13, 13, 13), 3:7, 6, :), 1:2:5, :, 1:2:5)) ==
477477
cat(3,[68,70,72],[406,408,410],[744,746,748])
478+
479+
480+
481+
# tests @sub (and replace_ref_end!)
482+
X = reshape(1:24,2,3,4)
483+
Y = 4:-1:1
484+
485+
@test isa(@sub(X[1:3]), SubArray)
486+
487+
488+
@test X[1:end] == @sub X[1:end]
489+
@test X[1:end-3] == @sub X[1:end-3]
490+
@test X[1:end,2,2] == @sub X[1:end,2,2]
491+
@test X[1,1:end-2] == @sub X[1,1:end-2]
492+
@test X[1,2,1:end-2] == @sub X[1,2,1:end-2]
493+
@test X[1,2,Y[2:end]] == @sub X[1,2,Y[2:end]]
494+
@test X[1:end,2,Y[2:end]] == @sub X[1:end,2,Y[2:end]]
495+
496+
u = (1,2:3)
497+
@test X[u...,2:end] == @sub X[u...,2:end]
498+
@test X[(1,)...,(2,)...,2:end] == @sub X[(1,)...,(2,)...,2:end]
499+
500+
# test macro hygiene
501+
let size=(x,y)-> error("should not happen")
502+
@test X[1:end,2,2] == @sub X[1:end,2,2]
503+
end

0 commit comments

Comments
 (0)