From 7ee4fd648b24f0cf89c1fd2cafbb8692c628f2ac Mon Sep 17 00:00:00 2001 From: Petr Vana Date: Thu, 2 Sep 2021 21:44:01 +0200 Subject: [PATCH] Fix `sum()` and `prod()` for tuples (#41510) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR aims to fix #39182 and #39183 by using the universal implementation of `prod` and `sum` from https://github.com/JuliaLang/julia/blob/97f817a379b0c3c5f9bb803427fe88a018ebfe18/base/reduce.jl#L588 However, the file `abstractarray.jl` is included way sooner, and it is crucial to have already a simplified version of `prod` function. We can specify a simplified version or `prod` only for a system-wide `Int` type that is sufficient to compile `Base`. ``` julia prod(x::Tuple{}) = 1 # This is consistent with the regular prod because there is no need for size promotion # if all elements in the tuple are of system size. prod(x::Tuple{Int, Vararg{Int}}) = *(x...) ``` Although the implementations are different, they lead to the same binary code for tuples containing ~~`UInt` and~~ `Int`. ``` julia julia> a = (1,2,3) (1, 2, 3) # Simplified version for tuples containing Int only julia> prod_simplified(x::Tuple{Int, Vararg{Int}}) = *(x...) julia> @code_native prod_simplified(a) .text ; ┌ @ REPL[1]:1 within `prod_simplified' ; │┌ @ operators.jl:560 within `*' @ int.jl:88 movq 8(%rdi), %rax imulq (%rdi), %rax imulq 16(%rdi), %rax ; │└ retq nop ; └ ``` ``` julia # Regular prod without the simplification julia> @code_native prod(a) .text ; ┌ @ reduce.jl:588 within `prod` ; │┌ @ reduce.jl:588 within `#prod#247` ; ││┌ @ reduce.jl:289 within `mapreduce` ; │││┌ @ reduce.jl:289 within `#mapreduce#240` ; ││││┌ @ reduce.jl:162 within `mapfoldl` ; │││││┌ @ reduce.jl:162 within `#mapfoldl#236` ; ││││││┌ @ reduce.jl:44 within `mapfoldl_impl` ; │││││││┌ @ reduce.jl:48 within `foldl_impl` ; ││││││││┌ @ tuple.jl:276 within `_foldl_impl` ; │││││││││┌ @ operators.jl:613 within `afoldl` ; ││││││││││┌ @ reduce.jl:81 within `BottomRF` ; │││││││││││┌ @ reduce.jl:38 within `mul_prod` ; ││││││││││││┌ @ int.jl:88 within `*` movq 8(%rdi), %rax imulq (%rdi), %rax ; │││││││││└└└└ ; │││││││││┌ @ operators.jl:614 within `afoldl` ; ││││││││││┌ @ reduce.jl:81 within `BottomRF` ; │││││││││││┌ @ reduce.jl:38 within `mul_prod` ; ││││││││││││┌ @ int.jl:88 within `*` imulq 16(%rdi), %rax ; │└└└└└└└└└└└└ retq nop ; └ ``` --- base/tuple.jl | 15 +++++---------- test/tuple.jl | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/base/tuple.jl b/base/tuple.jl index 2e1db9a316bfd4..77fa6ba0ea1a37 100644 --- a/base/tuple.jl +++ b/base/tuple.jl @@ -488,17 +488,12 @@ reverse(t::Tuple) = revargs(t...) ## specialized reduction ## -# TODO: these definitions cannot yet be combined, since +(x...) -# where x might be any tuple matches too many methods. -# TODO: this is inconsistent with the regular sum in cases where the arguments -# require size promotion to system size. -sum(x::Tuple{Any, Vararg{Any}}) = +(x...) - -# NOTE: should remove, but often used on array sizes -# TODO: this is inconsistent with the regular prod in cases where the arguments -# require size promotion to system size. prod(x::Tuple{}) = 1 -prod(x::Tuple{Any, Vararg{Any}}) = *(x...) +# This is consistent with the regular prod because there is no need for size promotion +# if all elements in the tuple are of system size. +# It is defined here separately in order to support bootstrap, because it's needed earlier +# than the general prod definition is available. +prod(x::Tuple{Int, Vararg{Int}}) = *(x...) all(x::Tuple{}) = true all(x::Tuple{Bool}) = x[1] diff --git a/test/tuple.jl b/test/tuple.jl index bdfaae6bf10328..913f024240e7ae 100644 --- a/test/tuple.jl +++ b/test/tuple.jl @@ -361,6 +361,24 @@ end @test prod(()) === 1 @test prod((1,2,3)) === 6 + # issue 39182 + @test sum((0xe1, 0x1f)) === sum([0xe1, 0x1f]) + @test sum((Int8(3),)) === Int(3) + @test sum((UInt8(3),)) === UInt(3) + @test sum((3,)) === Int(3) + @test sum((3.0,)) === 3.0 + @test sum(("a",)) == sum(["a"]) + @test sum((0xe1, 0x1f), init=0x0) == sum([0xe1, 0x1f], init=0x0) + + # issue 39183 + @test prod((Int8(100), Int8(100))) === 10000 + @test prod((Int8(3),)) === Int(3) + @test prod((UInt8(3),)) === UInt(3) + @test prod((3,)) === Int(3) + @test prod((3.0,)) === 3.0 + @test prod(("a",)) == prod(["a"]) + @test prod((0xe1, 0x1f), init=0x1) == prod([0xe1, 0x1f], init=0x1) + @testset "all" begin @test all(()) === true @test all((false,)) === false