Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more type-stable reductions #6116

Merged
merged 1 commit into from
Mar 12, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 94 additions & 47 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

###### higher level reduction functions ######

# Note that getting type-stable results from reduction functions,
# or at least having type-stable loops, is nontrivial (#6069).

# reduce

function reduce(op::Callable, itr) # this is a left fold
Expand All @@ -19,17 +22,24 @@ function reduce(op::Callable, itr) # this is a left fold
return op() # empty collection
end
(v, s) = next(itr, s)
while !done(itr, s)
if done(itr, s)
return v
else # specialize for length > 1 to have a hopefully type-stable loop
(x, s) = next(itr, s)
v = op(v,x)
result = op(v, x)
while !done(itr, s)
(x, s) = next(itr, s)
result = op(result, x)
end
return result
end
return v
end

# pairwise reduction, requires n > 1 (to allow type-stable loop)
function r_pairwise(op::Callable, A::AbstractArray, i1,n)
if n < 128
@inbounds v = A[i1]
for i = i1+1:i1+n-1
@inbounds v = op(A[i1], A[i1+1])
for i = i1+2:i1+n-1
@inbounds v = op(v,A[i])
end
return v
Expand All @@ -41,12 +51,12 @@ end

function reduce(op::Callable, A::AbstractArray)
n = length(A)
n == 0 ? op() : r_pairwise(op,A, 1,n)
n == 0 ? op() : n == 1 ? A[1] : r_pairwise(op,A, 1,n)
end

function reduce(op::Callable, v0, A::AbstractArray)
n = length(A)
n == 0 ? v0 : op(v0, r_pairwise(op,A, 1,n))
n == 0 ? v0 : n == 1 ? op(v0, A[1]) : op(v0, r_pairwise(op,A, 1,n))
end

function reduce(op::Callable, v0, itr)
Expand Down Expand Up @@ -169,46 +179,53 @@ end
function sum(itr)
s = start(itr)
if done(itr, s)
return 0
if method_exists(eltype, (typeof(itr),))
T = eltype(itr)
return zero(T) + zero(T)
else
throw(ArgumentError("sum(itr) is undefined for empty collections; instead, do isempty(itr) ? z : sum(itr), where z is the correct type of zero for your sum"))
end
end
(v, s) = next(itr, s)
done(itr, s) && return v + zero(v) # adding zero for type stability
# specialize for length > 1 to have type-stable loop
(x, s) = next(itr, s)
result = v + x
while !done(itr, s)
(x, s) = next(itr, s)
v += x
result += x
end
return v
return result
end

sum(A::AbstractArray{Bool}) = countnz(A)

# a fast implementation of sum in sequential order (from left to right)
# a fast implementation of sum in sequential order (from left to right).
# to allow type-stable loops, requires length > 1
function sum_seq{T}(a::AbstractArray{T}, ifirst::Int, ilast::Int)

@inbounds if ifirst + 3 >= ilast # a has at most four elements
if ifirst > ilast
return zero(T)
else
i = ifirst
s = a[i]
while i < ilast
s += a[i+=1]
end
return s

@inbounds if ifirst + 6 >= ilast # length(a) < 8
i = ifirst
s = a[i] + a[i+1]
i = i+1
while i < ilast
s += a[i+=1]
end
return s

else # a has more than four elements
else # length(a) >= 8

# more effective utilization of the instruction
# pipeline through manually unrolling the sum
# into four-way accumulation. Benchmark shows
# that this results in about 2x speed-up.

s1 = a[ifirst]
s2 = a[ifirst + 1]
s3 = a[ifirst + 2]
s4 = a[ifirst + 3]
s1 = a[ifirst] + a[ifirst + 4]
s2 = a[ifirst + 1] + a[ifirst + 5]
s3 = a[ifirst + 2] + a[ifirst + 6]
s4 = a[ifirst + 3] + a[ifirst + 7]

i = ifirst + 4
i = ifirst + 8
il = ilast - 3
while i <= il
s1 += a[i]
Expand Down Expand Up @@ -243,6 +260,7 @@ end
# Note: sum_seq uses four accumulators, so each accumulator gets at most 256 numbers
const PAIRWISE_SUM_BLOCKSIZE = 1024

# note: requires length > 1, due to sum_seq
function sum_pairwise(a::AbstractArray, ifirst::Int, ilast::Int)
# bsiz: maximum block size

Expand All @@ -254,18 +272,29 @@ function sum_pairwise(a::AbstractArray, ifirst::Int, ilast::Int)
end
end

sum(a::AbstractArray) = sum_pairwise(a, 1, length(a))
sum{T<:Integer}(a::AbstractArray{T}) = sum_seq(a, 1, length(a))
function sum{T}(a::AbstractArray{T})
n = length(a)
n == 0 && return zero(T) + zero(T)
n == 1 && return a[1] + zero(a[1])
sum_pairwise(a, 1, length(a))
end

function sum{T<:Integer}(a::AbstractArray{T})
n = length(a)
n == 0 && return zero(T) + zero(T)
n == 1 && return a[1] + zero(a[1])
sum_seq(a, 1, length(a))
end

# Kahan (compensated) summation: O(1) error growth, at the expense
# of a considerable increase in computational expense.
function sum_kbn{T<:FloatingPoint}(A::AbstractArray{T})
n = length(A)
if (n == 0)
return zero(T)
return zero(T)+zero(T)
end
s = A[1]
c = zero(T)
s = A[1]+zero(T)
c = zero(T)+zero(T)
for i in 2:n
Ai = A[i]
t = s + Ai
Expand All @@ -286,29 +315,40 @@ end
function prod(itr)
s = start(itr)
if done(itr, s)
return *()
if method_exists(eltype, (typeof(itr),))
T = eltype(itr)
return one(T) * one(T)
else
throw(ArgumentError("prod(itr) is undefined for empty collections; instead, do isempty(itr) ? o : prod(itr), where o is the correct type of identity for your product"))
end
end
(v, s) = next(itr, s)
done(itr, s) && return v * one(v) # multiplying by one for type stability
# specialize for length > 1 to have type-stable loop
(x, s) = next(itr, s)
result = v * x
while !done(itr, s)
(x, s) = next(itr, s)
v = v*x
result *= x
end
return v
return result
end

prod(A::AbstractArray{Bool}) =
error("use all() instead of prod() for boolean arrays")

function prod_rgn{T}(A::AbstractArray{T}, first::Int, last::Int)
if first > last
return one(T)
return one(T) * one(T)
end
i = first
v = A[i]
@inbounds v = A[i]
i == last && return v * one(v)
@inbounds result = v * A[i+=1]
while i < last
@inbounds v *= A[i+=1]
@inbounds result *= A[i+=1]
end
return v
return result
end
prod{T}(A::AbstractArray{T}) = prod_rgn(A, 1, length(A))

Expand Down Expand Up @@ -467,11 +507,17 @@ function mapreduce(f::Callable, op::Callable, itr)
end
(x, s) = next(itr, s)
v = f(x)
while !done(itr, s)
if done(itr, s)
return v
else # specialize for length > 1 to have a hopefully type-stable loop
(x, s) = next(itr, s)
v = op(v,f(x))
result = op(v, f(x))
while !done(itr, s)
(x, s) = next(itr, s)
result = op(result, f(x))
end
return result
end
return v
end

function mapreduce(f::Callable, op::Callable, v0, itr)
Expand All @@ -482,10 +528,11 @@ function mapreduce(f::Callable, op::Callable, v0, itr)
return v
end

# pairwise reduction, requires n > 1 (to allow type-stable loop)
function mr_pairwise(f::Callable, op::Callable, A::AbstractArray, i1,n)
if n < 128
@inbounds v = f(A[i1])
for i = i1+1:i1+n-1
@inbounds v = op(f(A[i1]), f(A[i1+1]))
for i = i1+2:i1+n-1
@inbounds v = op(v,f(A[i]))
end
return v
Expand All @@ -496,11 +543,11 @@ function mr_pairwise(f::Callable, op::Callable, A::AbstractArray, i1,n)
end
function mapreduce(f::Callable, op::Callable, A::AbstractArray)
n = length(A)
n == 0 ? op() : mr_pairwise(f,op,A, 1,n)
n == 0 ? op() : n == 1 ? f(A[1]) : mr_pairwise(f,op,A, 1,n)
end
function mapreduce(f::Callable, op::Callable, v0, A::AbstractArray)
n = length(A)
n == 0 ? v0 : op(v0, mr_pairwise(f,op,A, 1,n))
n == 0 ? v0 : n == 1 ? op(v0, f(A[1])) : op(v0, mr_pairwise(f,op,A, 1,n))
end

# specific mapreduce functions
Expand Down
29 changes: 29 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,35 @@ end

@test sum(z) == sum(z,(1,2,3,4))[1] == 136

# check variants of summation for type-stability and other issues (#6069)
sum2(itr) = invoke(sum, (Any,), itr)
plus(x,y) = x + y
plus() = 0
sum3(A) = reduce(plus, A)
sum4(itr) = invoke(reduce, (Function, Any), plus, itr)
sum5(A) = reduce(plus, 0, A)
sum6(itr) = invoke(reduce, (Function, Int, Any), plus, 0, itr)
sum7(A) = mapreduce(x->x, plus, A)
sum8(itr) = invoke(mapreduce, (Function, Function, Any), x->x, plus, itr)
sum9(A) = mapreduce(x->x, plus, 0, A)
sum10(itr) = invoke(mapreduce, (Function, Function, Int, Any), x->x,plus,0,itr)
for f in (sum2, sum3, sum4, sum5, sum6, sum7, sum8, sum9, sum10)
@test sum(z) == f(z)
@test sum(Int[]) == f(Int[]) == 0
@test sum(Int[7]) == f(Int[7]) == 7
if f == sum3 || f == sum4 || f == sum7 || f == sum8
@test typeof(f(Int8[])) == typeof(f(Int8[1 7]))
else
@test typeof(f(Int8[])) == typeof(f(Int8[1])) == typeof(f(Int8[1 7]))
end
end
@test typeof(sum(Int8[])) == typeof(sum(Int8[1])) == typeof(sum(Int8[1 7]))

prod2(itr) = invoke(prod, (Any,), itr)
@test prod(Int[]) == prod2(Int[]) == 1
@test prod(Int[7]) == prod2(Int[7]) == 7
@test typeof(prod(Int8[])) == typeof(prod(Int8[1])) == typeof(prod(Int8[1 7])) == typeof(prod2(Int8[])) == typeof(prod2(Int8[1])) == typeof(prod2(Int8[1 7]))

v = cell(2,2,1,1)
v[1,1,1,1] = 28.0
v[1,2,1,1] = 36.0
Expand Down