diff --git a/src/weights.jl b/src/weights.jl index cf535d408..d17ec9bde 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -682,6 +682,32 @@ function mean(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:) return mean(A, dims=dims) end +""" + mean(f, A::AbstractArray, w::AbstractWeights) + +Compute the weighted mean of array `A`, after transforming it'S +contents with the function `f`, with weight vector `w` (of type +`AbstractWeights`). + +# Examples +```julia +n = 20 +x = rand(n) +w = rand(n) +mean(√, x, weights(w)) +``` +""" +function mean(f, A::AbstractArray, w::AbstractWeights) + return sum(Broadcast.instantiate(Broadcast.broadcasted(A, w) do a_i, wg + return f(a_i) * wg + end)) / sum(w) +end + +function mean(f, A::AbstractArray, w::UnitWeights) + length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension.")) + return mean(f, A) +end + ##### Weighted quantile ##### """ diff --git a/test/weights.jl b/test/weights.jl index 52142efd8..e80ca96a0 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -270,6 +270,27 @@ end @test mean(a, f(wt), dims=3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) @test_throws ErrorException mean(a, f(wt), dims=4) end + + @test mean(√, [1:3;], f([1.0, 1.0, 0.5])) ≈ 1.3120956 + @test mean(√, 1:3, f([1.0, 1.0, 0.5])) ≈ 1.3120956 + @test mean(√, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.60824421 + 0.88948688im + + @test mean(log, [1:3;], f([1.0, 1.0, 0.5])) ≈ 0.49698133 + @test mean(log, 1:3, f([1.0, 1.0, 0.5])) ≈ 0.49698133 + @test mean(log, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 1.155407982 + 1.03678427im + + @test mean(x -> x^2, [1:3;], f([1.0, 1.0, 0.5])) ≈ 3.8 + @test mean(x -> x^2, 1:3, f([1.0, 1.0, 0.5])) ≈ 3.8 + @test mean(x -> x^2, [1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ -5.0 + 16.0im + + c = 1.0:9.0 + w = UnitWeights{Float64}(9) + @test mean(√, c, w) ≈ sum(sqrt.(c)) / length(c) + @test_throws DimensionMismatch mean(√, c, UnitWeights{Float64}(6)) + @test mean(log, c, w) ≈ sum(log.(c)) / length(c) + @test_throws DimensionMismatch mean(log, c, UnitWeights{Float64}(6)) + @test mean(x -> x^2, c, w) ≈ sum(c.^2) / length(c) + @test_throws DimensionMismatch mean(x -> x^2, c, UnitWeights{Float64}(6)) end @testset "Quantile fweights" begin