Skip to content

Commit

Permalink
Fix #26488: don't map over values not provided (#26521)
Browse files Browse the repository at this point in the history
This is a symptom of the good old how-to-allocate-a-result-array-of-an-arbitrary-transform-of-its-elements problem. Eventually it'd be nice to solve this with `collect` of a lazy implementation, but for now this papers over the egregious problem.
  • Loading branch information
mbauman authored and JeffBezanson committed Mar 20, 2018
1 parent a7104ac commit aa7b3f9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 8 additions & 4 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,14 @@ let
[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}

global reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region) =
reducedim_initarray(A, region, mapreduce_first(f, op, zero(eltype(A))))
global reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::T, region) =
reducedim_initarray(A, region, mapreduce_first(f, op, one(eltype(A))))
global function reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region)
z = zero(f(zero(eltype(A))))
reducedim_initarray(A, region, op(z, z))
end
global function reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::T, region)
u = one(f(one(eltype(A))))
reducedim_initarray(A, region, op(u, u))
end
end

## generic (map)reduction
Expand Down
12 changes: 12 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,18 @@ for region in Any[-1, 0, (-1, 2), [0, 1], (1,-2,3), [0 1;
@test_throws ArgumentError minimum(abs, Areduc, dims=region)
end

# issue #26488
@testset "don't map over initial values not provided" begin
@test sum(x->x+1, [1], dims=1)[1] === sum(x->x+1, [1]) === 2
@test prod(x->x+1, [1], dims=1)[1] === prod(x->x+1, [1]) === 2
@test mapreduce(x->x+1, +, [1], dims=1)[1] === mapreduce(x->x+1, +, [1]) === 2
@test mapreduce(x->x+1, *, [1], dims=1)[1] === mapreduce(x->x+1, *, [1]) === 2
@test mapreduce(!, &, [false], dims=1)[1] === mapreduce(!, &, [false]) === true
@test mapreduce(!, |, [true], dims=1)[1] === mapreduce(!, |, [true]) === false
@test mapreduce(x->1/x, max, [1], dims=1)[1] === mapreduce(x->1/x, max, [1]) === 1.0
@test mapreduce(x->-1/x, min, [1], dims=1)[1] === mapreduce(x->-1/x, min, [1]) === -1.0
end

# check type of result
@testset "type of sum(::Array{$T}" for T in [UInt8, Int8, Int32, Int64, BigInt]
result = sum(T[1 2 3; 4 5 6; 7 8 9], dims=2)
Expand Down

0 comments on commit aa7b3f9

Please sign in to comment.