From d6a36af7d55437987573967d5f31db95e1d0b75e Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Tue, 1 Aug 2023 08:42:56 +0530 Subject: [PATCH 1/8] Weighted mean with function Adds the method for a weighted mean of elements transformed by a function. - Added `mean(f, itr, weights)` - Added tests for the method --- src/weights.jl | 25 +++++++++++++++++++++++++ test/weights.jl | 11 +++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/weights.jl b/src/weights.jl index cf535d408..80c507967 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -682,6 +682,31 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) return mean(A, dims=dims) end +""" + mean(f, A::AbstractArray, w::AbstractWeights[, dims::Int]) + +Compute the weighted mean of array `A`, after transforming it'S +contents with the function `f`, with weight vector `w` (of type +`AbstractWeights`). If `dim` is provided, compute the +weighted mean along dimension `dims`. + +# Examples +```julia +n = 20 +x = rand(n) +w = rand(n) +mean(√, x, weights(w)) +``` +""" +mean(f, A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) = + _mean(f.(A), w, dims) + +function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) + a = (dims === :) ? 
length(A) : size(A, dims) + a != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) + return mean(f.(A), dims=dims) +end + ##### Weighted quantile ##### """ diff --git a/test/weights.jl b/test/weights.jl index 52142efd8..0466805d8 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -270,6 +270,17 @@ end @test mean(a, f(wt), dims=3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) @test_throws ErrorException mean(a, f(wt), dims=4) end + + @test mean(√, [1:3;], f([1.0, 1.0, 0.5])) ≈ 1.3120956 + @test mean(√, 1:3, f([1.0, 1.0, 0.5])) ≈ 1.3120956 + @test mean(√, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.60824421 + 0.88948688im + + for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) + @test mean(√, a, f(wt), dims=1) ≈ sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(√, a, f(wt), dims=2) ≈ sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(√, a, f(wt), dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test_throws ErrorException mean(√, a, f(wt), dims=4) + end end @testset "Quantile fweights" begin From 2acfbfa665b639348094cd080dd5b4c1757293e5 Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Tue, 1 Aug 2023 20:02:47 +0530 Subject: [PATCH 2/8] Added more tests for coverage Added tests for UnitWeights --- test/weights.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/weights.jl b/test/weights.jl index 0466805d8..a5b6c0669 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -281,6 +281,16 @@ end @test mean(√, a, f(wt), dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) @test_throws ErrorException mean(√, a, f(wt), dims=4) end + + b = reshape(1.0:9.0, 3, 3) + w = UnitWeights{Float64}(3) + @test mean(√, b, w; dims=1) ≈ reshape(w, :, 3) * sqrt.(b) / sum(w) + @test mean(√, b, w; dims=2) ≈ sqrt.(b) * w / sum(w) + + c = 1.0:9.0 + w = UnitWeights{Float64}(9) + @test mean(√, c, w) ≈ sum(sqrt.(c)) / length(c) + @test_throws 
DimensionMismatch mean(√, c, UnitWeights{Float64}(6)) end @testset "Quantile fweights" begin From b8be0a536a3261270bd80f5dac9673e2f66c2961 Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Thu, 3 Aug 2023 22:45:34 +0530 Subject: [PATCH 3/8] Minor modifications - Add keyword arguments for the weights - Modified functions to use `Iterators.map` - Add more tests --- src/weights.jl | 10 +++++----- test/weights.jl | 32 +++++++++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 80c507967..274213dd0 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -683,11 +683,11 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) end """ - mean(f, A::AbstractArray, w::AbstractWeights[, dims::Int]) + mean(f, A::AbstractArray, w::AbstractWeights[; dims]) Compute the weighted mean of array `A`, after transforming it'S contents with the function `f`, with weight vector `w` (of type -`AbstractWeights`). If `dim` is provided, compute the +`AbstractWeights`). If `dims` is provided, compute the weighted mean along dimension `dims`. # Examples @@ -698,13 +698,13 @@ w = rand(n) mean(√, x, weights(w)) ``` """ -mean(f, A::AbstractArray, w::AbstractWeights; dims::Union{Colon,Int}=:) = - _mean(f.(A), w, dims) +mean(f, A::AbstractArray, w::AbstractWeights; kwargs...) = + mean(collect(Iterators.map(f, A)), w; kwargs...) function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) a = (dims === :) ? 
length(A) : size(A, dims) a != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) - return mean(f.(A), dims=dims) + return mean(collect(Iterators.map(f, A)), dims=dims) end ##### Weighted quantile ##### diff --git a/test/weights.jl b/test/weights.jl index a5b6c0669..8c855fe23 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -275,22 +275,48 @@ end @test mean(√, 1:3, f([1.0, 1.0, 0.5])) ≈ 1.3120956 @test mean(√, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.60824421 + 0.88948688im + @test mean(log, [1:3;], f([1.0, 1.0, 0.5])) ≈ 0.49698133 + @test mean(log, 1:3, f([1.0, 1.0, 0.5])) ≈ 0.49698133 + @test mean(log, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.155407982 + 1.03678427im + + @test mean(x -> x^2, [1:3;], f([1.0, 1.0, 0.5])) ≈ 3.8 + @test mean(x -> x^2, 1:3, f([1.0, 1.0, 0.5])) ≈ 3.8 + @test mean(x -> x^2, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ -5.0 + 16.0im + for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) - @test mean(√, a, f(wt), dims=1) ≈ sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) - @test mean(√, a, f(wt), dims=2) ≈ sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) - @test mean(√, a, f(wt), dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test mean(√, a, f(wt); dims=1) ≈ sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(√, a, f(wt); dims=2) ≈ sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(√, a, f(wt); dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) @test_throws ErrorException mean(√, a, f(wt), dims=4) + + @test mean(log, a, f(wt); dims=1) ≈ sum(log.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(log, a, f(wt); dims=2) ≈ sum(log.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(log, a, f(wt); dims=3) ≈ sum(log.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test_throws ErrorException mean(log, a, f(wt), dims=4) + + @test mean(x -> x^2, a, f(wt); dims=1) ≈ 
sum((a.^2).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(x -> x^2, a, f(wt); dims=2) ≈ sum((a.^2).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(x -> x^2, a, f(wt); dims=3) ≈ sum((a.^2).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test_throws ErrorException mean(log, a, f(wt), dims=4) end b = reshape(1.0:9.0, 3, 3) w = UnitWeights{Float64}(3) @test mean(√, b, w; dims=1) ≈ reshape(w, :, 3) * sqrt.(b) / sum(w) @test mean(√, b, w; dims=2) ≈ sqrt.(b) * w / sum(w) + @test mean(log, b, w; dims=1) ≈ reshape(w, :, 3) * log.(b) / sum(w) + @test mean(log, b, w; dims=2) ≈ log.(b) * w / sum(w) + @test mean(x -> x^2, b, w; dims=1) ≈ reshape(w, :, 3) * (b.^2) / sum(w) + @test mean(x -> x^2, b, w; dims=2) ≈ (b.^2) * w / sum(w) c = 1.0:9.0 w = UnitWeights{Float64}(9) @test mean(√, c, w) ≈ sum(sqrt.(c)) / length(c) @test_throws DimensionMismatch mean(√, c, UnitWeights{Float64}(6)) + @test mean(log, c, w) ≈ sum(log.(c)) / length(c) + @test_throws DimensionMismatch mean(log, c, UnitWeights{Float64}(6)) + @test mean(x -> x^2, c, w) ≈ sum(c.^2) / length(c) + @test_throws DimensionMismatch mean(x -> x^2, c, UnitWeights{Float64}(6)) end @testset "Quantile fweights" begin From 7653f2c5f66015b0e9ffedc27e7f58538384332a Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Wed, 23 Aug 2023 07:20:23 +0530 Subject: [PATCH 4/8] Corrections --- src/weights.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 274213dd0..20175ea9f 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -698,13 +698,13 @@ w = rand(n) mean(√, x, weights(w)) ``` """ -mean(f, A::AbstractArray, w::AbstractWeights; kwargs...) = - mean(collect(Iterators.map(f, A)), w; kwargs...) +mean(f, A, w::AbstractWeights; kwargs...) = + mean(broadcast(f, A), w; kwargs...) function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) a = (dims === :) ? 
length(A) : size(A, dims) a != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) - return mean(collect(Iterators.map(f, A)), dims=dims) + return mean(broadcast(f, A), dims=dims) end ##### Weighted quantile ##### From 448e6eee1f6562c6d0789b32d68be4b0a433a140 Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Sun, 3 Sep 2023 19:41:38 +0530 Subject: [PATCH 5/8] Try to use Broadcast Checks are not passing yet. --- src/weights.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 20175ea9f..f6bbce73e 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -698,8 +698,12 @@ w = rand(n) mean(√, x, weights(w)) ``` """ -mean(f, A, w::AbstractWeights; kwargs...) = - mean(broadcast(f, A), w; kwargs...) +function mean(f, A, w::AbstractWeights; kwargs...) + functionweightedsum = sum(Broadcast.instantiate(Broadcast.broadcasted(f, A, w) do f, x_i, w + return f(x_i) * w + end); kwargs...) + return functionweightedsum / sum(w) +end function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) a = (dims === :) ? 
length(A) : size(A, dims) From 6a6997ab73a3a37e6c3edd053003d5932f50cccb Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Fri, 6 Oct 2023 09:06:15 +0530 Subject: [PATCH 6/8] Fixes and finalisation - Removed implementation for multi-dimensional arrays - Updated documentation - Updated tests --- src/weights.jl | 24 ++++++++++++------------ test/weights.jl | 26 -------------------------- 2 files changed, 12 insertions(+), 38 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index f6bbce73e..db664584c 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -683,12 +683,11 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) end """ - mean(f, A::AbstractArray, w::AbstractWeights[; dims]) + mean(f, A::AbstractArray, w::AbstractWeights) Compute the weighted mean of array `A`, after transforming it'S contents with the function `f`, with weight vector `w` (of type -`AbstractWeights`). If `dims` is provided, compute the -weighted mean along dimension `dims`. +`AbstractWeights`). # Examples ```julia @@ -698,17 +697,18 @@ w = rand(n) mean(√, x, weights(w)) ``` """ -function mean(f, A, w::AbstractWeights; kwargs...) - functionweightedsum = sum(Broadcast.instantiate(Broadcast.broadcasted(f, A, w) do f, x_i, w - return f(x_i) * w - end); kwargs...) - return functionweightedsum / sum(w) +mean(f, A::AbstractArray, w::AbstractWeights) = +_funcweightedmean(f, A, w) + +function _funcweightedmean(f, A::AbstractArray, w::AbstractWeights) + return sum(Broadcast.broadcasted(f, A, w) do f, a_i, wg + return f(a_i) * wg + end) / sum(w) end -function mean(f, A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) - a = (dims === :) ? 
length(A) : size(A, dims) - a != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) - return mean(broadcast(f, A), dims=dims) +function mean(f, A::AbstractArray, w::UnitWeights) + length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) + return mean(f, A) end ##### Weighted quantile ##### diff --git a/test/weights.jl b/test/weights.jl index 8c855fe23..e80ca96a0 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -283,32 +283,6 @@ end @test mean(x -> x^2, 1:3, f([1.0, 1.0, 0.5])) ≈ 3.8 @test mean(x -> x^2, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ -5.0 + 16.0im - for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) - @test mean(√, a, f(wt); dims=1) ≈ sum(sqrt.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) - @test mean(√, a, f(wt); dims=2) ≈ sum(sqrt.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) - @test mean(√, a, f(wt); dims=3) ≈ sum(sqrt.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) - @test_throws ErrorException mean(√, a, f(wt), dims=4) - - @test mean(log, a, f(wt); dims=1) ≈ sum(log.(a).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) - @test mean(log, a, f(wt); dims=2) ≈ sum(log.(a).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) - @test mean(log, a, f(wt); dims=3) ≈ sum(log.(a).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) - @test_throws ErrorException mean(log, a, f(wt), dims=4) - - @test mean(x -> x^2, a, f(wt); dims=1) ≈ sum((a.^2).*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) - @test mean(x -> x^2, a, f(wt); dims=2) ≈ sum((a.^2).*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) - @test mean(x -> x^2, a, f(wt); dims=3) ≈ sum((a.^2).*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) - @test_throws ErrorException mean(log, a, f(wt), dims=4) - end - - b = reshape(1.0:9.0, 3, 3) - w = UnitWeights{Float64}(3) - @test mean(√, b, w; dims=1) ≈ reshape(w, :, 3) * sqrt.(b) / sum(w) - @test mean(√, b, w; dims=2) ≈ sqrt.(b) * w / sum(w) - @test mean(log, b, w; dims=1) ≈ reshape(w, :, 3) * log.(b) / 
sum(w) - @test mean(log, b, w; dims=2) ≈ log.(b) * w / sum(w) - @test mean(x -> x^2, b, w; dims=1) ≈ reshape(w, :, 3) * (b.^2) / sum(w) - @test mean(x -> x^2, b, w; dims=2) ≈ (b.^2) * w / sum(w) - c = 1.0:9.0 w = UnitWeights{Float64}(9) @test mean(√, c, w) ≈ sum(sqrt.(c)) / length(c) From f9984f8173854c77ab61d097f5caf580dd9324c7 Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Sun, 3 Dec 2023 21:57:30 +0530 Subject: [PATCH 7/8] Changes as requested by @devmotion Used `Broadcast.instantiate` as requested to overcome falling back to Cartesian indexing --- src/weights.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index db664584c..cd51c129b 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -701,9 +701,9 @@ mean(f, A::AbstractArray, w::AbstractWeights) = _funcweightedmean(f, A, w) function _funcweightedmean(f, A::AbstractArray, w::AbstractWeights) - return sum(Broadcast.broadcasted(f, A, w) do f, a_i, wg + return sum(Broadcast.instantiate(Broadcast.broadcasted(A, w) do a_i, wg return f(a_i) * wg - end) / sum(w) + end)) / sum(w) end function mean(f, A::AbstractArray, w::UnitWeights) From c1768a69150ecfd83dc84ccf45d62173c2636862 Mon Sep 17 00:00:00 2001 From: Debartha Paul Date: Sun, 17 Dec 2023 14:52:27 +0530 Subject: [PATCH 8/8] Remove internal method `_funcweightedmean` Instead deploy it as a method for `mean` --- src/weights.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index cd51c129b..d17ec9bde 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -697,10 +697,7 @@ w = rand(n) mean(√, x, weights(w)) ``` """ -mean(f, A::AbstractArray, w::AbstractWeights) = -_funcweightedmean(f, A, w) - -function _funcweightedmean(f, A::AbstractArray, w::AbstractWeights) +function mean(f, A::AbstractArray, w::AbstractWeights) return sum(Broadcast.instantiate(Broadcast.broadcasted(A, w) do a_i, wg return f(a_i) * wg end)) / sum(w)