From 96d5482643f324a2f450da153a2431b0efd61b3a Mon Sep 17 00:00:00 2001 From: Samuel Berthe Date: Tue, 28 Jan 2025 12:35:04 +0100 Subject: [PATCH] fix(product): fix empty slice behavior (#583) (#584) --- math.go | 8 ++++---- math_test.go | 8 ++++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/math.go b/math.go index 5d2825ff..e3f42892 100644 --- a/math.go +++ b/math.go @@ -89,11 +89,11 @@ func SumBy[T any, R constraints.Float | constraints.Integer | constraints.Comple // Play: https://go.dev/play/p/2_kjM_smtAH func Product[T constraints.Float | constraints.Integer | constraints.Complex](collection []T) T { if collection == nil { - return 0 + return 1 } if len(collection) == 0 { - return 0 + return 1 } var product T = 1 @@ -107,11 +107,11 @@ func Product[T constraints.Float | constraints.Integer | constraints.Complex](co // Play: https://go.dev/play/p/wadzrWr9Aer func ProductBy[T any, R constraints.Float | constraints.Integer | constraints.Complex](collection []T, iteratee func(item T) R) R { if collection == nil { - return 0 + return 1 } if len(collection) == 0 { - return 0 + return 1 } var product R = 1 diff --git a/math_test.go b/math_test.go index 3c8e9016..d23029e8 100644 --- a/math_test.go +++ b/math_test.go @@ -108,14 +108,16 @@ func TestProduct(t *testing.T) { result5 := Product([]uint32{2, 3, 4, 5}) result6 := Product([]uint32{}) result7 := Product([]complex128{4_4, 2_2}) + result8 := Product[uint32](nil) is.Equal(result1, float32(160.908)) is.Equal(result2, int32(120)) is.Equal(result3, int32(0)) is.Equal(result4, int32(-126)) is.Equal(result5, uint32(120)) - is.Equal(result6, uint32(0)) + is.Equal(result6, uint32(1)) is.Equal(result7, complex128(96_8)) + is.Equal(result8, uint32(1)) } func TestProductBy(t *testing.T) { @@ -128,14 +130,16 @@ func TestProductBy(t *testing.T) { result5 := ProductBy([]uint32{2, 3, 4, 5}, func(n uint32) uint32 { return n }) result6 := ProductBy([]uint32{}, func(n uint32) uint32 { return n }) result7 := ProductBy([]complex128{4_4, 2_2}, func(n complex128) complex128 { return n }) + result8 := ProductBy(nil, func(n uint32) uint32 { return n }) is.Equal(result1, float32(160.908)) is.Equal(result2, int32(120)) is.Equal(result3, int32(0)) is.Equal(result4, int32(-126)) is.Equal(result5, uint32(120)) - is.Equal(result6, uint32(0)) + is.Equal(result6, uint32(1)) is.Equal(result7, complex128(96_8)) + is.Equal(result8, uint32(1)) } func TestMean(t *testing.T) {