From d114e11c93c416d35da059d4b6b5f6799628dd79 Mon Sep 17 00:00:00 2001 From: "Steven G. Johnson" Date: Mon, 10 Mar 2014 17:49:06 -0400 Subject: [PATCH] more type-stable reductions (fix #6069) --- base/reduce.jl | 141 +++++++++++++++++++++++++++++++---------------- test/arrayops.jl | 29 ++++++++++ 2 files changed, 123 insertions(+), 47 deletions(-) diff --git a/base/reduce.jl b/base/reduce.jl index af19f92a92504..48577c45f04a8 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -2,6 +2,9 @@ ###### higher level reduction functions ###### +# Note that getting type-stable results from reduction functions, +# or at least having type-stable loops, is nontrivial (#6069). + # reduce function reduce(op::Callable, itr) # this is a left fold @@ -19,17 +22,24 @@ function reduce(op::Callable, itr) # this is a left fold return op() # empty collection end (v, s) = next(itr, s) - while !done(itr, s) + if done(itr, s) + return v + else # specialize for length > 1 to have a hopefully type-stable loop (x, s) = next(itr, s) - v = op(v,x) + result = op(v, x) + while !done(itr, s) + (x, s) = next(itr, s) + result = op(result, x) + end + return result end - return v end +# pairwise reduction, requires n > 1 (to allow type-stable loop) function r_pairwise(op::Callable, A::AbstractArray, i1,n) if n < 128 - @inbounds v = A[i1] - for i = i1+1:i1+n-1 + @inbounds v = op(A[i1], A[i1+1]) + for i = i1+2:i1+n-1 @inbounds v = op(v,A[i]) end return v @@ -41,12 +51,12 @@ end function reduce(op::Callable, A::AbstractArray) n = length(A) - n == 0 ? op() : r_pairwise(op,A, 1,n) + n == 0 ? op() : n == 1 ? A[1] : r_pairwise(op,A, 1,n) end function reduce(op::Callable, v0, A::AbstractArray) n = length(A) - n == 0 ? v0 : op(v0, r_pairwise(op,A, 1,n)) + n == 0 ? v0 : n == 1 ? op(v0, A[1]) : op(v0, r_pairwise(op,A, 1,n)) end function reduce(op::Callable, v0, itr) @@ -169,46 +179,53 @@ end function sum(itr) s = start(itr) if done(itr, s) - return 0 + if method_exists(eltype, (typeof(itr),)) + T = eltype(itr) + return zero(T) + zero(T) + else + throw(ArgumentError("sum(itr) is undefined for empty collections; instead, do isempty(itr) ? z : sum(itr), where z is the correct type of zero for your sum")) + end end (v, s) = next(itr, s) + done(itr, s) && return v + zero(v) # adding zero for type stability + # specialize for length > 1 to have type-stable loop + (x, s) = next(itr, s) + result = v + x while !done(itr, s) (x, s) = next(itr, s) - v += x + result += x end - return v + return result end sum(A::AbstractArray{Bool}) = countnz(A) -# a fast implementation of sum in sequential order (from left to right) +# a fast implementation of sum in sequential order (from left to right). +# to allow type-stable loops, requires length > 1 function sum_seq{T}(a::AbstractArray{T}, ifirst::Int, ilast::Int) - - @inbounds if ifirst + 3 >= ilast # a has at most four elements - if ifirst > ilast - return zero(T) - else - i = ifirst - s = a[i] - while i < ilast - s += a[i+=1] - end - return s + + @inbounds if ifirst + 6 >= ilast # length(a) < 8 + i = ifirst + s = a[i] + a[i+1] + i = i+1 + while i < ilast + s += a[i+=1] end + return s - else # a has more than four elements + else # length(a) >= 8 # more effective utilization of the instruction # pipeline through manually unrolling the sum # into four-way accumulation. Benchmark shows # that this results in about 2x speed-up. - s1 = a[ifirst] - s2 = a[ifirst + 1] - s3 = a[ifirst + 2] - s4 = a[ifirst + 3] + s1 = a[ifirst] + a[ifirst + 4] + s2 = a[ifirst + 1] + a[ifirst + 5] + s3 = a[ifirst + 2] + a[ifirst + 6] + s4 = a[ifirst + 3] + a[ifirst + 7] - i = ifirst + 4 + i = ifirst + 8 il = ilast - 3 while i <= il s1 += a[i] @@ -243,6 +260,7 @@ end # Note: sum_seq uses four accumulators, so each accumulator gets at most 256 numbers const PAIRWISE_SUM_BLOCKSIZE = 1024 +# note: requires length > 1, due to sum_seq function sum_pairwise(a::AbstractArray, ifirst::Int, ilast::Int) # bsiz: maximum block size @@ -254,18 +272,29 @@ function sum_pairwise(a::AbstractArray, ifirst::Int, ilast::Int) end end -sum(a::AbstractArray) = sum_pairwise(a, 1, length(a)) -sum{T<:Integer}(a::AbstractArray{T}) = sum_seq(a, 1, length(a)) +function sum{T}(a::AbstractArray{T}) + n = length(a) + n == 0 && return zero(T) + zero(T) + n == 1 && return a[1] + zero(a[1]) + sum_pairwise(a, 1, length(a)) +end + +function sum{T<:Integer}(a::AbstractArray{T}) + n = length(a) + n == 0 && return zero(T) + zero(T) + n == 1 && return a[1] + zero(a[1]) + sum_seq(a, 1, length(a)) +end # Kahan (compensated) summation: O(1) error growth, at the expense # of a considerable increase in computational expense. function sum_kbn{T<:FloatingPoint}(A::AbstractArray{T}) n = length(A) if (n == 0) - return zero(T) + return zero(T)+zero(T) end - s = A[1] - c = zero(T) + s = A[1]+zero(T) + c = zero(T)+zero(T) for i in 2:n Ai = A[i] t = s + Ai @@ -286,14 +315,23 @@ end function prod(itr) s = start(itr) if done(itr, s) - return *() + if method_exists(eltype, (typeof(itr),)) + T = eltype(itr) + return one(T) * one(T) + else + throw(ArgumentError("prod(itr) is undefined for empty collections; instead, do isempty(itr) ? o : prod(itr), where o is the correct type of identity for your product")) + end end (v, s) = next(itr, s) + done(itr, s) && return v * one(v) # multiplying by one for type stability + # specialize for length > 1 to have type-stable loop + (x, s) = next(itr, s) + result = v * x while !done(itr, s) (x, s) = next(itr, s) - v = v*x + result *= x end - return v + return result end prod(A::AbstractArray{Bool}) = @@ -301,14 +339,16 @@ prod(A::AbstractArray{Bool}) = function prod_rgn{T}(A::AbstractArray{T}, first::Int, last::Int) if first > last - return one(T) + return one(T) * one(T) end i = first - v = A[i] + @inbounds v = A[i] + i == last && return v * one(v) + @inbounds result = v * A[i+=1] while i < last - @inbounds v *= A[i+=1] + @inbounds result *= A[i+=1] end - return v + return result end prod{T}(A::AbstractArray{T}) = prod_rgn(A, 1, length(A)) @@ -467,11 +507,17 @@ function mapreduce(f::Callable, op::Callable, itr) end (x, s) = next(itr, s) v = f(x) - while !done(itr, s) + if done(itr, s) + return v + else # specialize for length > 1 to have a hopefully type-stable loop (x, s) = next(itr, s) - v = op(v,f(x)) + result = op(v, f(x)) + while !done(itr, s) + (x, s) = next(itr, s) + result = op(result, f(x)) + end + return result end - return v end function mapreduce(f::Callable, op::Callable, v0, itr) @@ -482,10 +528,11 @@ function mapreduce(f::Callable, op::Callable, v0, itr) return v end +# pairwise reduction, requires n > 1 (to allow type-stable loop) function mr_pairwise(f::Callable, op::Callable, A::AbstractArray, i1,n) if n < 128 - @inbounds v = f(A[i1]) - for i = i1+1:i1+n-1 + @inbounds v = op(f(A[i1]), f(A[i1+1])) + for i = i1+2:i1+n-1 @inbounds v = op(v,f(A[i])) end return v @@ -496,11 +543,11 @@ function mr_pairwise(f::Callable, op::Callable, A::AbstractArray, i1,n) end function mapreduce(f::Callable, op::Callable, A::AbstractArray) n = length(A) - n == 0 ? op() : mr_pairwise(f,op,A, 1,n) + n == 0 ? op() : n == 1 ? f(A[1]) : mr_pairwise(f,op,A, 1,n) end function mapreduce(f::Callable, op::Callable, v0, A::AbstractArray) n = length(A) - n == 0 ? v0 : op(v0, mr_pairwise(f,op,A, 1,n)) + n == 0 ? v0 : n == 1 ? op(v0, f(A[1])) : op(v0, mr_pairwise(f,op,A, 1,n)) end # specific mapreduce functions diff --git a/test/arrayops.jl b/test/arrayops.jl index 08cf04dfd82bc..7c925cf663428 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -367,6 +367,35 @@ end @test sum(z) == sum(z,(1,2,3,4))[1] == 136 +# check variants of summation for type-stability and other issues (#6069) +sum2(itr) = invoke(sum, (Any,), itr) +plus(x,y) = x + y +plus() = 0 +sum3(A) = reduce(plus, A) +sum4(itr) = invoke(reduce, (Function, Any), plus, itr) +sum5(A) = reduce(plus, 0, A) +sum6(itr) = invoke(reduce, (Function, Int, Any), plus, 0, itr) +sum7(A) = mapreduce(x->x, plus, A) +sum8(itr) = invoke(mapreduce, (Function, Function, Any), x->x, plus, itr) +sum9(A) = mapreduce(x->x, plus, 0, A) +sum10(itr) = invoke(mapreduce, (Function, Function, Int, Any), x->x,plus,0,itr) +for f in (sum2, sum3, sum4, sum5, sum6, sum7, sum8, sum9, sum10) + @test sum(z) == f(z) + @test sum(Int[]) == f(Int[]) == 0 + @test sum(Int[7]) == f(Int[7]) == 7 + if f == sum3 || f == sum4 || f == sum7 || f == sum8 + @test typeof(f(Int8[])) == typeof(f(Int8[1 7])) + else + @test typeof(f(Int8[])) == typeof(f(Int8[1])) == typeof(f(Int8[1 7])) + end +end +@test typeof(sum(Int8[])) == typeof(sum(Int8[1])) == typeof(sum(Int8[1 7])) + +prod2(itr) = invoke(prod, (Any,), itr) +@test prod(Int[]) == prod2(Int[]) == 1 +@test prod(Int[7]) == prod2(Int[7]) == 7 +@test typeof(prod(Int8[])) == typeof(prod(Int8[1])) == typeof(prod(Int8[1 7])) == typeof(prod2(Int8[])) == typeof(prod2(Int8[1])) == typeof(prod2(Int8[1 7])) + v = cell(2,2,1,1) v[1,1,1,1] = 28.0 v[1,2,1,1] = 36.0