From 26a6bfcae091b02bd9833f0ef5e1ba41efaa9d46 Mon Sep 17 00:00:00 2001 From: Davide Lasagna Date: Sun, 29 May 2016 19:45:03 +0100 Subject: [PATCH] RFC: Change iteratorsize trait of `product(itr1, itr2)` (#16437) change iteratorsize trait of `product(itr1, itr2)` fixes #16436 - Adds many tests to product function and tests more thoroughly the iterator traits - Adds a Prod1 type - Adds ndims(::Base.Prod*) - Change state of Prod1 iterator from tuple to integer --- base/iterator.jl | 58 ++++++++++++--- test/functional.jl | 180 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 221 insertions(+), 17 deletions(-) diff --git a/base/iterator.jl b/base/iterator.jl index 53fd02f564390..c3479f8fab74b 100644 --- a/base/iterator.jl +++ b/base/iterator.jl @@ -302,10 +302,47 @@ done(it::Repeated, state) = false iteratorsize{O}(::Type{Repeated{O}}) = IsInfinite() iteratoreltype{O}(::Type{Repeated{O}}) = HasEltype() -# product + +# Product -- cartesian product of iterators abstract AbstractProdIterator +length(p::AbstractProdIterator) = prod(size(p)) +size(p::AbstractProdIterator) = _prod_size(p.a, p.b, iteratorsize(p.a), iteratorsize(p.b)) +ndims(p::AbstractProdIterator) = length(size(p)) + +# generic methods to handle size of Prod* types +_prod_size(a, ::HasShape) = size(a) +_prod_size(a, ::HasLength) = (length(a), ) +_prod_size(a, A) = + throw(ArgumentError("Cannot compute size for object of type $(typeof(a))")) +_prod_size(a, b, ::HasLength, ::HasLength) = (length(a), length(b)) +_prod_size(a, b, ::HasLength, ::HasShape) = (length(a), size(b)...) +_prod_size(a, b, ::HasShape, ::HasLength) = (size(a)..., length(b)) +_prod_size(a, b, ::HasShape, ::HasShape) = (size(a)..., size(b)...) +_prod_size(a, b, A, B) = + throw(ArgumentError("Cannot construct size for objects of types $(typeof(a)) and $(typeof(b))")) + +# one iterator +immutable Prod1{I} <: AbstractProdIterator + a::I +end +product(a) = Prod1(a) + +eltype{I}(::Type{Prod1{I}}) = Tuple{eltype(I)} +size(p::Prod1) = _prod_size(p.a, iteratorsize(p.a)) + +@inline start(p::Prod1) = start(p.a) +@inline function next(p::Prod1, st) + n, st = next(p.a, st) + (n, ), st +end +@inline done(p::Prod1, st) = done(p.a, st) + +iteratoreltype{I}(::Type{Prod1{I}}) = iteratoreltype(I) +iteratorsize{I}(::Type{Prod1{I}}) = iteratorsize(I) + +# two iterators immutable Prod2{I1, I2} <: AbstractProdIterator a::I1 b::I2 @@ -327,11 +364,11 @@ changes the fastest. Example: (1,5) (2,5) """ -product(a) = Zip1(a) product(a, b) = Prod2(a, b) + eltype{I1,I2}(::Type{Prod2{I1,I2}}) = Tuple{eltype(I1), eltype(I2)} + iteratoreltype{I1,I2}(::Type{Prod2{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2)) -length(p::AbstractProdIterator) = length(p.a)*length(p.b) iteratorsize{I1,I2}(::Type{Prod2{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2)) function start(p::AbstractProdIterator) @@ -359,13 +396,15 @@ end @inline next(p::Prod2, st) = prod_next(p, st) @inline done(p::AbstractProdIterator, st) = st[4] +# n iterators immutable Prod{I1, I2<:AbstractProdIterator} <: AbstractProdIterator a::I1 b::I2 end - product(a, b, c...) = Prod(a, product(b, c...)) + eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2)) + iteratoreltype{I1,I2}(::Type{Prod{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2)) iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2)) @@ -374,12 +413,13 @@ iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),it ((x[1][1],x[1][2]...), x[2]) end -prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasLength() -prod_iteratorsize(a, ::IsInfinite) = IsInfinite() # products can have an infinite last iterator (which moves slowest) +prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape() +# products can have an infinite iterator +prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite() +prod_iteratorsize(a, ::IsInfinite) = IsInfinite() +prod_iteratorsize(::IsInfinite, b) = IsInfinite() prod_iteratorsize(a, b) = SizeUnknown() -_size(p::Prod2) = (length(p.a), length(p.b)) -_size(p::Prod) = (length(p.a), _size(p.b)...) """ IteratorND(iter, dims) @@ -400,7 +440,7 @@ immutable IteratorND{I,N} end new{I,N}(iter, shape) end - (::Type{IteratorND}){I<:AbstractProdIterator}(p::I) = IteratorND(p, _size(p)) + (::Type{IteratorND}){I<:AbstractProdIterator}(p::I) = IteratorND(p, size(p)) end start(i::IteratorND) = start(i.iter) diff --git a/test/functional.jl b/test/functional.jl index 55c762aee1f12..754442d8bfbeb 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -188,14 +188,178 @@ end # product # ------- -@test isempty(Base.product(1:2,1:0)) -@test isempty(Base.product(1:2,1:0,1:10)) -@test isempty(Base.product(1:2,1:10,1:0)) -@test isempty(Base.product(1:0,1:2,1:10)) -@test collect(Base.product(1:2,3:4)) == [(1,3),(2,3),(1,4),(2,4)] -@test isempty(collect(Base.product(1:0,1:2))) -@test length(Base.product(1:2,1:10,4:6)) == 60 -@test Base.iteratorsize(Base.product(1:2, countfrom(1))) == Base.IsInfinite() +# empty? +for itr in [Base.product(1:0), + Base.product(1:2, 1:0), + Base.product(1:0, 1:2), + Base.product(1:0, 1:1, 1:2), + Base.product(1:1, 1:0, 1:2), + Base.product(1:1, 1:2 ,1:0)] + @test isempty(itr) + @test isempty(collect(itr)) +end + +# collect a product - first iterators runs faster +@test collect(Base.product(1:2)) == [(i,) for i=1:2] +@test collect(Base.product(1:2, 3:4)) == [(i, j) for i=1:2, j=3:4] +@test collect(Base.product(1:2, 3:4, 5:6)) == [(i, j, k) for i=1:2, j=3:4, k=5:6] + +# iteration order +let + expected = [(1,3,5), (2,3,5), (1,4,5), (2,4,5), (1,3,6), (2,3,6), (1,4,6), (2,4,6)] + actual = Base.product(1:2, 3:4, 5:6) + for (exp, act) in zip(expected, actual) + @test exp == act + end +end + +# collect multidimensional array +let + a, b = 1:3, [4 6; + 5 7] + p = Base.product(a, b) + @test size(p) == (3, 2, 2) + @test length(p) == 12 + @test ndims(p) == 3 + @test eltype(p) == NTuple{2, Int} + cp = collect(p) + for i = 1:3 + @test cp[i, :, :] == [(i, 4) (i, 6); + (i, 5) (i, 7)] + end +end + +# with 1D inputs +let + a, b, c = 1:2, 1.0:10.0, Int32(1):Int32(0) + + # length + @test length(Base.product(a)) == 2 + @test length(Base.product(a, b)) == 20 + @test length(Base.product(a, b, c)) == 0 + + # size + @test size(Base.product(a)) == (2, ) + @test size(Base.product(a, b)) == (2, 10) + @test size(Base.product(a, b, c)) == (2, 10, 0) + + # eltype + @test eltype(Base.product(a)) == Tuple{Int} + @test eltype(Base.product(a, b)) == Tuple{Int, Float64} + @test eltype(Base.product(a, b, c)) == Tuple{Int, Float64, Int32} + + # ndims + @test ndims(Base.product(a)) == 1 + @test ndims(Base.product(a, b)) == 2 + @test ndims(Base.product(a, b, c)) == 3 +end + +# with multidimensional inputs +let + a, b, c = randn(4, 4), randn(3, 3, 3), randn(2, 2, 2, 2) + args = Any[(a,), + (a, a), + (a, b), + (a, a, a), + (a, b, c)] + sizes = Any[(4, 4), + (4, 4, 4, 4), + (4, 4, 3, 3, 3), + (4, 4, 4, 4, 4, 4), + (4, 4, 3, 3, 3, 2, 2, 2, 2)] + for (method, fun) in zip([size, ndims, length], [x->x, length, prod]) + for i in 1:length(args) + @test method(Base.product(args[i]...)) == method(collect(Base.product(args[i]...))) == fun(sizes[i]) + end + end +end + +# more tests on product with iterators of various type +let + iters = (1:2, + rand(2, 2, 2), + take(1:4, 2), + Base.product(1:2, 1:3), + Base.product(rand(2, 2), rand(1, 1, 1)) + ) + for method in [size, length, ndims, eltype] + for i = 1:length(iters) + args = iters[i] + @test method(Base.product(args...)) == method(collect(Base.product(args...))) + for j = 1:length(iters) + args = iters[i], iters[j] + @test method(Base.product(args...)) == method(collect(Base.product(args...))) + for k = 1:length(iters) + args = iters[i], iters[j], iters[k] + @test method(Base.product(args...)) == method(collect(Base.product(args...))) + end + end + end + end +end + +# product of finite length and infinite length iterators +let + a = 1:2 + b = countfrom(1) + ab = Base.product(a, b) + ba = Base.product(b, a) + abexp = [(1, 1), (2, 1), (1, 2), (2, 2), (1, 3), (2, 3)] + baexp = [(1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)] + for (expected, actual) in zip([abexp, baexp], [ab, ba]) + for (i, el) in enumerate(actual) + @test el == expected[i] + i == length(expected) && break + end + @test_throws ArgumentError length(actual) + @test_throws ArgumentError size(actual) + @test_throws ArgumentError ndims(actual) + end + + # size infinite or unknown raises an error + for itr in Any[countfrom(1), Filter(i->0, 1:10)] + @test_throws ArgumentError length(Base.product(itr)) + @test_throws ArgumentError size(Base.product(itr)) + @test_throws ArgumentError ndims(Base.product(itr)) + end +end + +# iteratorsize trait business +let f1 = Filter(i->i>0, 1:10) + @test Base.iteratorsize(Base.product(f1)) == Base.SizeUnknown() + @test Base.iteratorsize(Base.product(1:2, f1)) == Base.SizeUnknown() + @test Base.iteratorsize(Base.product(f1, 1:2)) == Base.SizeUnknown() + @test Base.iteratorsize(Base.product(f1, f1)) == Base.SizeUnknown() + @test Base.iteratorsize(Base.product(f1, countfrom(1))) == Base.IsInfinite() + @test Base.iteratorsize(Base.product(countfrom(1), f1)) == Base.IsInfinite() +end +@test Base.iteratorsize(Base.product(1:2, countfrom(1))) == Base.IsInfinite() +@test Base.iteratorsize(Base.product(countfrom(2), countfrom(1))) == Base.IsInfinite() +@test Base.iteratorsize(Base.product(countfrom(1), 1:2)) == Base.IsInfinite() +@test Base.iteratorsize(Base.product(1:2)) == Base.HasShape() +@test Base.iteratorsize(Base.product(1:2, 1:2)) == Base.HasShape() +@test Base.iteratorsize(Base.product(take(1:2, 1), take(1:2, 1))) == Base.HasShape() +@test Base.iteratorsize(Base.product(take(1:2, 2))) == Base.HasLength() +@test Base.iteratorsize(Base.product([1 2; 3 4])) == Base.HasShape() + +# iteratoreltype trait business +let f1 = Filter(i->i>0, 1:10) + @test Base.iteratoreltype(Base.product(f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any + @test Base.iteratoreltype(Base.product(1:2, f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any + @test Base.iteratoreltype(Base.product(f1, 1:2)) == Base.HasEltype() # FIXME? eltype(f1) is Any + @test Base.iteratoreltype(Base.product(f1, f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any + @test Base.iteratoreltype(Base.product(f1, countfrom(1))) == Base.HasEltype() # FIXME? eltype(f1) is Any + @test Base.iteratoreltype(Base.product(countfrom(1), f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any +end +@test Base.iteratoreltype(Base.product(1:2, countfrom(1))) == Base.HasEltype() +@test Base.iteratoreltype(Base.product(countfrom(1), 1:2)) == Base.HasEltype() +@test Base.iteratoreltype(Base.product(1:2)) == Base.HasEltype() +@test Base.iteratoreltype(Base.product(1:2, 1:2)) == Base.HasEltype() +@test Base.iteratoreltype(Base.product(take(1:2, 1), take(1:2, 1))) == Base.HasEltype() +@test Base.iteratoreltype(Base.product(take(1:2, 2))) == Base.HasEltype() +@test Base.iteratoreltype(Base.product([1 2; 3 4])) == Base.HasEltype() + + # flatten # -------