Skip to content

Commit

Permalink
Merge pull request #8432 from JuliaLang/teh/cartesian_iteration
Browse files Browse the repository at this point in the history
Efficient cartesian iteration (new version of #6437)
  • Loading branch information
timholy committed Nov 15, 2014
2 parents 4f7b787 + 2fa852d commit 37ffa13
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 11 deletions.
13 changes: 12 additions & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ function trailingsize(A, n)
return s
end

## Traits for array types ##

abstract LinearIndexing
immutable LinearFast <: LinearIndexing end
immutable LinearSlow <: LinearIndexing end

linearindexing(::AbstractArray) = LinearSlow()
linearindexing(::Array) = LinearFast()
linearindexing(::Range) = LinearFast()

## Bounds checking ##
checkbounds(sz::Int, i::Int) = 1 <= i <= sz || throw(BoundsError())
checkbounds(sz::Int, i::Real) = checkbounds(sz, to_index(i))
Expand Down Expand Up @@ -241,7 +251,8 @@ zero{T}(x::AbstractArray{T}) = fill!(similar(x), zero(T))

## iteration support for arrays as ranges ##

start(a::AbstractArray) = 1
start(A::AbstractArray) = _start(A,linearindexing(A))
_start(::AbstractArray,::LinearFast) = 1
next(a::AbstractArray,i) = (a[i],i+1)
done(a::AbstractArray,i) = (i > length(a))
isempty(a::AbstractArray) = (length(a) == 0)
Expand Down
2 changes: 1 addition & 1 deletion base/dates/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ function in{T<:TimeType}(x::T, r::StepRange{T})
end

Base.start{T<:TimeType}(r::StepRange{T}) = 0
Base.next{T<:TimeType}(r::StepRange{T}, i) = (r.start+r.step*i,i+1)
Base.next{T<:TimeType}(r::StepRange{T}, i::Int) = (r.start+r.step*i,i+1)
Base.done{T<:TimeType,S<:Period}(r::StepRange{T,S}, i::Integer) = length(r) <= i
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ export
cumsum,
cumsum!,
cumsum_kbn,
eachelement,
eachindex,
extrema,
fill!,
fill,
Expand Down
106 changes: 106 additions & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,109 @@
### Multidimensional iterators
module IteratorsMD

import Base: start, _start, done, next, getindex, setindex!, linearindexing
import Base: @nref, @ncall, @nif, @nexprs, LinearFast, LinearSlow

export eachindex

# Traits for linear indexing
linearindexing(::BitArray) = LinearFast()

# Iterator/state
abstract CartesianIndex{N} # the state for all multidimensional iterators
abstract IndexIterator{N} # Iterator that visits the index associated with each element

stagedfunction Base.call{N}(::Type{CartesianIndex},index::NTuple{N,Int})
indextype,itertype=gen_cartesian(N)
return :($indextype(index))
end
stagedfunction Base.call{N}(::Type{IndexIterator},index::NTuple{N,Int})
indextype,itertype=gen_cartesian(N)
return :($itertype(index))
end

let implemented = IntSet()
global gen_cartesian
function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
# Create the types
indextype = symbol("CartesianIndex_$N")
itertype = symbol("IndexIterator_$N")
if !in(N,implemented)
fieldnames = [symbol("I_$i") for i = 1:N]
fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
extype = Expr(:type, false, Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)), Expr(:block, fields...))
exindices = Expr[:(index[$i]) for i = 1:N]

onesN = ones(Int, N)
infsN = fill(typemax(Int), N)
anyzero = Expr(:(||), [:(iter.dims.$(fieldnames[i]) == 0) for i = 1:N]...)

# Some necessary ambiguity resolution
exrange = N != 1 ? nothing : quote
next(R::StepRange, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
next{T}(R::UnitRange{T}, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
end
exshared = !with_shared ? nothing : quote
getindex{T}(S::SharedArray{T,$N}, I::$indextype) = S.s[I]
setindex!{T}(S::SharedArray{T,$N}, v, I::$indextype) = S.s[I] = v
end
totalex = quote
# type definition
$extype
# extra constructor from tuple
$indextype(index::NTuple{$N,Int}) = $indextype($(exindices...))

immutable $itertype <: IndexIterator{$N}
dims::$indextype
end
$itertype(dims::NTuple{$N,Int})=$itertype($indextype(dims))

# getindex and setindex!
$exshared
getindex{T}(A::AbstractArray{T,$N}, index::$indextype) = @nref $N A d->getfield(index,d)
setindex!{T}(A::AbstractArray{T,$N}, v, index::$indextype) = (@nref $N A d->getfield(index,d)) = v

# next iteration
$exrange
@inline function next{T}(A::AbstractArray{T,$N}, state::$indextype)
@inbounds v = A[state]
newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
v, newstate
end
@inline function next(iter::$itertype, state::$indextype)
newstate = @nif $N d->(getfield(state,d) < getfield(iter.dims,d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
state, newstate
end

# start
start(iter::$itertype) = $anyzero ? $indextype($(infsN...)) : $indextype($(onesN...))
end
eval(totalex)
push!(implemented,N)
end
return indextype, itertype
end
end

# Iteration
eachindex(A::AbstractArray) = IndexIterator(size(A))

# start iteration
_start{T,N}(A::AbstractArray{T,N},::LinearSlow) = CartesianIndex(ntuple(N,n->ifelse(isempty(A),typemax(Int),1))::NTuple{N,Int})

# Ambiguity resolution
done(R::StepRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
done(R::UnitRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
done(R::FloatRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)

done{T,N}(A::AbstractArray{T,N}, I::CartesianIndex{N}) = getfield(I, N) > size(A, N)
done{N}(iter::IndexIterator{N}, I::CartesianIndex{N}) = getfield(I, N) > getfield(iter.dims, N)

end # IteratorsMD

using .IteratorsMD


### From array.jl

@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)
Expand Down
4 changes: 2 additions & 2 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ copy(r::Range) = r
## iteration

start(r::FloatRange) = 0
next{T}(r::FloatRange{T}, i) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
done(r::FloatRange, i) = (length(r) <= i)
next{T}(r::FloatRange{T}, i::Int) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
done(r::FloatRange, i::Int) = (length(r) <= i)

# NOTE: For ordinal ranges, we assume start+step might be from a
# lifted domain (e.g. Int8+Int8 => Int); use that for iterating.
Expand Down
12 changes: 5 additions & 7 deletions base/sharedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,16 @@ end

convert(::Type{Array}, S::SharedArray) = S.s

# # pass through getindex and setindex! - they always work on the complete array unlike DArrays
# pass through getindex and setindex! - they always work on the complete array unlike DArrays
getindex(S::SharedArray) = getindex(S.s)
getindex(S::SharedArray, I::Real) = getindex(S.s, I)
getindex(S::SharedArray, I::AbstractArray) = getindex(S.s, I)
@nsplat N 1:5 getindex(S::SharedArray, I::NTuple{N,Any}...) = getindex(S.s, I...)

setindex!(S::SharedArray, x) = (setindex!(S.s, x); S)
setindex!(S::SharedArray, x, I::Real) = (setindex!(S.s, x, I); S)
setindex!(S::SharedArray, x, I::AbstractArray) = (setindex!(S.s, x, I); S)
@nsplat N 1:5 setindex!(S::SharedArray, x, I::NTuple{N,Any}...) = (setindex!(S.s, x, I...); S)
setindex!(S::SharedArray, x) = setindex!(S.s, x)
setindex!(S::SharedArray, x, I::Real) = setindex!(S.s, x, I)
setindex!(S::SharedArray, x, I::AbstractArray) = setindex!(S.s, x, I)
@nsplat N 1:5 setindex!(S::SharedArray, x, I::NTuple{N,Any}...) = setindex!(S.s, x, I...)

function fill!(S::SharedArray, v)
f = S->fill!(S.loc_subarr_1d, v)
Expand Down Expand Up @@ -377,5 +377,3 @@ end
end

@unix_only shm_open(shm_seg_name, oflags, permissions) = ccall(:shm_open, Int, (Ptr{UInt8}, Int, Int), shm_seg_name, oflags, permissions)


51 changes: 51 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,54 @@ end
b718cbc = 5
@test b718cbc[1.0] == 5
@test_throws InexactError b718cbc[1.1]

# Multidimensional iterators
function mdsum(A)
s = 0.0
for a in A
s += a
end
s
end

function mdsum2(A)
s = 0.0
@inbounds for I in eachindex(A)
s += A[I]
end
s
end

a = [1:5]
@test isa(Base.linearindexing(a), Base.LinearFast)
b = sub(a, :)
@test isa(Base.linearindexing(b), Base.IteratorsMD.LinearSlow)
shp = [5]
for i = 1:10
A = reshape(a, tuple(shp...))
@test mdsum(A) == 15
@test mdsum2(A) == 15
B = sub(A, ntuple(i, i->Colon())...)
@test mdsum(B) == 15
@test mdsum2(B) == 15
unshift!(shp, 1)
end

a = [1:10]
shp = [2,5]
for i = 2:10
A = reshape(a, tuple(shp...))
@test mdsum(A) == 55
@test mdsum2(A) == 55
B = sub(A, ntuple(i, i->Colon())...)
@test mdsum(B) == 55
@test mdsum2(B) == 55
insert!(shp, 2, 1)
end

a = ones(0,5)
b = sub(a, :, :)
@test mdsum(b) == 0
a = ones(5,0)
b = sub(a, :, :)
@test mdsum(b) == 0

0 comments on commit 37ffa13

Please sign in to comment.