Skip to content

Commit

Permalink
implement N-d generators
Browse files Browse the repository at this point in the history
[ci skip]
  • Loading branch information
JeffBezanson committed Jan 27, 2016
1 parent 9345d67 commit 6bdaa9b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
13 changes: 8 additions & 5 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,7 @@ function map!{F}(f::F, dest::AbstractArray, A::AbstractArray)
return dest
end

function map_to!{T,F}(f::F, offs, st, dest::AbstractArray{T}, A::AbstractArray)
function map_to!{T,F}(f::F, offs, st, dest::AbstractArray{T}, A)
# map to dest array, checking the type of each result. if a result does not
# match, widen the result type and re-dispatch.
i = offs
Expand All @@ -1318,9 +1318,12 @@ function map_to!{T,F}(f::F, offs, st, dest::AbstractArray{T}, A::AbstractArray)
return dest
end

_default_eltype(f::Type) = f
_default_eltype(f) = Union{}

function map(f, A::AbstractArray)
if isempty(A)
return isa(f,Type) ? similar(A,f) : similar(A,Union{})
return similar(A, _default_eltype(f))
end
st = start(A)
A1, st = next(A, st)
Expand All @@ -1338,7 +1341,7 @@ function map!{F}(f::F, dest::AbstractArray, A::AbstractArray, B::AbstractArray)
return dest
end

function map_to!{T,F}(f::F, offs, dest::AbstractArray{T}, A::AbstractArray, B::AbstractArray)
function map_to!{T,F}(f::F, offs, st, dest::AbstractArray{T}, A, B)
for i = offs:length(A)
@inbounds Ai, Bi = A[i], B[i]
el = f(Ai, Bi)
Expand All @@ -1358,12 +1361,12 @@ end
function map(f, A::AbstractArray, B::AbstractArray)
shp = promote_shape(size(A),size(B))
if prod(shp) == 0
return similar(A, Union{}, shp)
return similar(A, _default_eltype(f), shp)
end
first = f(A[1], B[1])
dest = similar(A, typeof(first), shp)
dest[1] = first
return map_to!(f, 2, dest, A, B)
return map_to!(f, 2, nothing, dest, A, B)
end

## N argument
Expand Down
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ unsafe_convert{T}(::Type{T}, x::T) = x
(::Type{Array{T}}){T}(m::Int, n::Int, o::Int) = Array{T,3}(m, n, o)

# TODO: possibly turn these into deprecations
Array{T,N}(::Type{T}, d::NTuple{N,Int}) = Array{T}(d)
Array{T}(::Type{T}, d::Int...) = Array{T}(d)
Array{T}(::Type{T}, m::Int) = Array{T,1}(m)
Array{T}(::Type{T}, m::Int,n::Int) = Array{T,2}(m,n)
Expand Down
52 changes: 52 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,55 @@ eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2))
x = prod_next(p, st)
((x[1][1],x[1][2]...), x[2])
end

immutable GeneratorND{F,I<:AbstractProdIterator}
f::F
iter::I

(::Type{GeneratorND}){F}(f::F, iters...) = (P = product(iters...); new{F,typeof(P)}(f, P))
end

start(g::GeneratorND) = start(g.iter)
done(g::GeneratorND, s) = done(g.iter, s)
function next(g::GeneratorND, s)
v, s2 = next(g.iter, s)
g.f(v...), s2
end

_size(p::Prod2) = (length(p.a), length(p.b))
_size(p::Prod) = (length(p.a), _size(p.b)...)

size(g::GeneratorND) = _size(g.iter)

function collect(g::GeneratorND)
sz = size(g)
if prod(sz) == 0
return Array(Union{}, sz)
end
st = start(g.iter)
A1, st = next(g.iter, st)
first = g.f(A1...)
dest = Array(typeof(first), sz)
dest[1] = first
return map_to!(xs->g.f(xs...), 2, st, dest, g.iter)
end

# special case for 2d
function collect{F,P<:Prod2}(g::GeneratorND{F,P})
f = g.f
a = g.iter.a
b = g.iter.b
sz = size(g)
if prod(sz) == 0
return Array(Union{}, sz)
end
fst = f(first(a), first(b)) # TODO: don't recompute this in the loop
dest = Array(typeof(fst), sz)
for j in b
for i in a
val = f(i, j) # TODO: handle type changes
@inbounds dest[i, j] = val
end
end
return dest
end
4 changes: 3 additions & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1859,7 +1859,9 @@
vars names))))
(expand-forms
(expand-binding-forms
`(call (top Generator) (-> (tuple ,@names) (block ,@stmts ,expr)) ,@ranges))))))
`(call (top ,(if (length> ranges 1) 'GeneratorND 'Generator))
(-> (tuple ,@names) (block ,@stmts ,expr))
,@ranges))))))

'comprehension
(lambda (e) (expand-forms
Expand Down

0 comments on commit 6bdaa9b

Please sign in to comment.