diff --git a/math.go b/math.go index 9dce28cf..f5793d49 100644 --- a/math.go +++ b/math.go @@ -82,3 +82,23 @@ func SumBy[T any, R constraints.Float | constraints.Integer | constraints.Comple } return sum } + +// Mean calculates the mean of a collection of numbers. +func Mean[T constraints.Float | constraints.Integer](collection []T) T { + var length T = T(len(collection)) + if length == 0 { + return 0 + } + var sum T = Sum(collection) + return sum / length +} + +// MeanBy calculates the mean of a collection of numbers using the given return value from the iteration function. +func MeanBy[T any, R constraints.Float | constraints.Integer](collection []T, iteratee func(item T) R) R { + var length R = R(len(collection)) + if length == 0 { + return 0 + } + var sum R = SumBy(collection, iteratee) + return sum / length +} diff --git a/math_example_test.go b/math_example_test.go index 00c6d972..e71a9b79 100644 --- a/math_example_test.go +++ b/math_example_test.go @@ -66,3 +66,21 @@ func ExampleSumBy() { fmt.Printf("%v", result) // Output: 6 } + +func ExampleMean() { + list := []int{1, 2, 3, 4, 5} + + result := Mean(list) + + fmt.Printf("%v", result) +} + +func ExampleMeanBy() { + list := []string{"foo", "bar"} + + result := MeanBy(list, func(item string) int { + return len(item) + }) + + fmt.Printf("%v", result) +} diff --git a/math_test.go b/math_test.go index d0bab70f..8e5efeda 100644 --- a/math_test.go +++ b/math_test.go @@ -97,3 +97,33 @@ func TestSumBy(t *testing.T) { is.Equal(result4, uint32(0)) is.Equal(result5, complex128(6_6)) } + +func TestMean(t *testing.T) { + t.Parallel() + is := assert.New(t) + + result1 := Mean([]float32{2.3, 3.3, 4, 5.3}) + result2 := Mean([]int32{2, 3, 4, 5}) + result3 := Mean([]uint32{2, 3, 4, 5}) + result4 := Mean([]uint32{}) + + is.Equal(result1, float32(3.7250001)) + is.Equal(result2, int32(3)) + is.Equal(result3, uint32(3)) + is.Equal(result4, uint32(0)) +} + +func TestMeanBy(t *testing.T) { + t.Parallel() + is := assert.New(t) + + result1 := MeanBy([]float32{2.3, 3.3, 4, 5.3}, func(n float32) float32 { return n }) + result2 := MeanBy([]int32{2, 3, 4, 5}, func(n int32) int32 { return n }) + result3 := MeanBy([]uint32{2, 3, 4, 5}, func(n uint32) uint32 { return n }) + result4 := MeanBy([]uint32{}, func(n uint32) uint32 { return n }) + + is.Equal(result1, float32(3.7250001)) + is.Equal(result2, int32(3)) + is.Equal(result3, uint32(3)) + is.Equal(result4, uint32(0)) +}