From e015b3e66def8f0d4e0cb3b116572cad2d5d46a0 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 3 Jan 2022 11:52:07 +0800 Subject: [PATCH] simplify min/max init --- base/reducedim.jl | 47 +++++++++++------------------------------------ test/reduce.jl | 20 +++++++------------- 2 files changed, 18 insertions(+), 49 deletions(-) diff --git a/base/reducedim.jl b/base/reducedim.jl index fcf93ed027394c..d9571a575969b0 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -91,7 +91,6 @@ end # reducedim_initarray is called by reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init) reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init::T) where {T} = reducedim_initarray(A, region, init, T) -# TODO: extend this to minimum and maximum reducedim_initarray(A::AbstractArrayOrBroadcasted, region, ::UndefInitializer, ::Type{R}) where {R} = similar(A,R,reduced_indices(A,region)) # TODO: better way to handle reducedim initialization # @@ -126,45 +125,21 @@ function _reducedim_init(f, op, fv, fop, A, region) end # initialization when computing minima and maxima requires a little care -for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin)) - @eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region) - # First compute the reduce indices. This will throw an ArgumentError - # if any region is invalid - ri = reduced_indices(A, region) +function reducedim_init(f::F, ::Union{typeof(min),typeof(max)}, A::AbstractArray, region) where {F} + # First compute the reduce indices. This will throw an ArgumentError + # if any region is invalid + ri = reduced_indices(A, region) - # Next, throw if reduction is over a region with length zero - any(i -> isempty(axes(A, i)), region) && _empty_reduce_error() + # Next, throw if reduction is over a region with length zero + any(i -> isempty(axes(A, i)), region) && _empty_reduce_error() - # Make a view of the first slice of the region - A1 = view(A, ri...) + # Make a view of the first slice of the region + A1 = view(A, ri...) - if isempty(A1) - # If the slice is empty just return non-view version as the initial array - return copy(A1) - else - # otherwise use the min/max of the first slice as initial value - v0 = mapreduce(f, $f2, A1) - - T = _realtype(f, promote_union(eltype(A))) - Tr = v0 isa T ? T : typeof(v0) - - # but NaNs and missing need to be avoided as initial values - if (v0 == v0) === false - # v0 is NaN - v0 = $initval - elseif isunordered(v0) - # v0 is missing or a third-party unordered value - Tnm = nonmissingtype(Tr) - # TODO: Some types, like BigInt, don't support typemin/typemax. - # So a Matrix{Union{BigInt, Missing}} can still error here. - v0 = $typeextreme(Tnm) - end - # v0 may have changed type. - Tr = v0 isa T ? T : typeof(v0) + # calculate the output type + T = promote_typejoin_union(_return_type(f, Tuple{eltype(A)})) - return reducedim_initarray(A, region, v0, Tr) - end - end + map!(f, reducedim_initarray(A,region,undef,T), A1) end reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(max), A::AbstractArray{T}, region) where {T} = reducedim_initarray(A, region, zero(f(zero(T))), _realtype(f, T)) diff --git a/test/reduce.jl b/test/reduce.jl index 86f12fec745cbc..c1506670d37504 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -378,16 +378,13 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1)) @test size(extrema(A,dims=(1,2,3))) == size(maximum(A,dims=(1,2,3))) @test extrema(x->div(x, 2), A, dims=(2,3)) == reshape([(0,11),(1,12)],2,1,1) -# TODO: drop `a′` once `minimum` and `maximum` is fixed -# (the following test_broken pass) -function test_extrema(a, a′ = a; dims_test = ((), 1, 2, (1,2), 3)) +function test_extrema(a; dims_test = ((), 1, 2, (1,2), 3)) for dims in dims_test vext = extrema(a; dims) - vmin, vmax = minimum(a′; dims), maximum(a′; dims) - @test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax)) || foreach(i -> display(i),(eltype(a), vext,vmin,vmax)) + vmin, vmax = minimum(a; dims), maximum(a; dims) + @test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax)) end end -@test_broken minimum([missing BigInt(1)], dims = 2)[1] === missing @testset "0.0,-0.0 test for extrema with dims" begin @test extrema([-0.0;0.0], dims = 1)[1] === (-0.0,0.0) @test tuple(extrema([-0.0;0.0], dims = 2)...) === ((-0.0, -0.0), (0.0, 0.0)) @@ -395,18 +392,15 @@ end @testset "NaN/missing test for extrema with dims #43599" begin for sz = (3, 10, 100) for T in (Int, BigInt, Float64, BigFloat) - Aₘ = Matrix{Union{Float64, Missing}}(rand(-sz:sz, sz, sz)) + Aₘ = Matrix{Union{T, Missing}}(rand(-sz:sz, sz, sz)) Aₘ[rand(1:sz*sz, sz)] .= missing - ATₘ = Matrix{Union{T, Missing}}(Aₘ) - test_extrema(ATₘ, Aₘ) + test_extrema(Aₘ) if T <: AbstractFloat Aₙ = map(i -> ismissing(i) ? T(NaN) : i, Aₘ) - ATₙ = map(i -> ismissing(i) ? T(NaN) : i, ATₘ) - test_extrema(ATₙ, Aₙ) + test_extrema(Aₙ) p = rand(1:sz*sz, sz) Aₘ[p] .= NaN - ATₘ[p] .= NaN - test_extrema(ATₘ, Aₘ) + test_extrema(Aₘ) end end end