Skip to content

Commit

Permalink
Merge pull request JuliaLang#66 from JuliaStats/sjk/mean
Browse files Browse the repository at this point in the history
n-dimensional weighted mean
  • Loading branch information
lindahua committed May 26, 2014
2 parents 98c4868 + 1ab08cb commit 875975b
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 38 deletions.
2 changes: 2 additions & 0 deletions src/StatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ module StatsBase
harmmean, # harmonic mean
trimmean, # trimmed mean
wmean, # weighted mean
wsum, # weighted sum with vector as second argument
wsum!, # in-place weighted sum across dimensions

# scalar_stats
skewness, # (standardized) skewness
Expand Down
68 changes: 50 additions & 18 deletions src/means.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,69 @@ end

# Weighted means

# Weighted mean of all elements of `v`: the dot product of `v` with the raw
# weight values of `w`, divided by `sum(w)`.
function mean{T<:Number}(v::AbstractArray{T}, w::WeightVec)
    weighted_total = dot(v, values(w))
    return weighted_total / sum(w)
end
# 1D weighted sum/mean
# Weighted sum over all elements of `v`, i.e. the sum of `w[i] * v[i]` with `v`
# flattened by `vec`.
#
# `dot` conjugates its FIRST argument, so the weights are passed first: with the
# data first, a complex-valued `v` would produce the conjugate of the weighted
# sum. Weights are expected to be real, for which the conjugation is a no-op.
wsum(v::AbstractArray, w::AbstractVector) = dot(w, vec(v))
# Weighted sum over all elements of `v`: unwrap the weight vector and delegate
# to `wsum`. The same method body is defined for three container types, the
# two specific ones first and the generic `AbstractArray` fallback last.
# NOTE(review): the BitArray and SparseMatrixCSC methods presumably exist to
# disambiguate against Base's own two-argument `sum` methods — confirm.
for A in (:BitArray, :SparseMatrixCSC, :AbstractArray)
    @eval Base.sum(v::$A, w::WeightVec) = wsum(v, values(w))
end
# Weighted mean over all elements of `v`: the weighted sum `sum(v, w)` divided
# by the sum of the weights.
function Base.mean(v::AbstractArray, w::WeightVec)
    numerator = sum(v, w)
    denominator = sum(w)
    return numerator / denominator
end

# Deprecated entry point: emits a deprecation warning, wraps `w` with
# `weights`, and forwards to the weighted `mean`.
function wmean{T<:Number}(v::AbstractArray{T}, w::AbstractArray)
    Base.depwarn("wmean is deprecated, use mean(v, weights(w)) instead.", :wmean)
    wrapped = weights(w)
    return mean(v, wrapped)
end

# NOTE(review): this span appears to interleave two definitions — a
# matrix-only `mean!` (opened on the next line) and the N-dimensional `wsum!`
# that begins at the `@ngenerate` line. The statement order shown here should
# be confirmed against the real src/means.jl before relying on it.
#
# `mean!` part: writes a weighted mean of the matrix `v` along `dim` into `r`.
function mean!{T<:Number,W<:Real}(r::AbstractVector, v::AbstractMatrix{T}, w::WeightVec{W}, dim::Int)
m, n = size(v)
# dim == 1: `r` needs one entry per column of `v`, `w` one weight per row.
if dim == 1
(length(r) == n && length(w) == m) || throw(DimensionMismatch("Dimensions mismatch"))
At_mul_B!(r, v, values(w))
# dim == 2: `r` needs one entry per row of `v`, `w` one weight per column.
elseif dim == 2
(length(r) == m && length(w) == n) || throw(DimensionMismatch("Dimensions mismatch"))
A_mul_B!(r, v, values(w))
else
error("Invalid value of dim.")
# General Cartesian-based weighted sum across dimensions
#
# `wsum!` part: accumulates `v[i...] * w[i_dim]` into `r`, where the output
# index equals the input index except along `dim`, where it is fixed to 1.
# Returns `r`.
import Base.Cartesian: @ngenerate, @nloops, @nref
@ngenerate N typeof(r) function wsum!{T,N,S,W<:Real}(r::AbstractArray{T,N}, v::AbstractArray{S,N},
w::AbstractVector{W}, dim::Int)
1 <= dim <= N || error("dim = $dim not in range [1,$N]")
# Shape check: every dimension of `r` must match `v`, except that along `dim`
# `r` may have size 1 provided `size(v, dim) == length(w)`.
# NOTE(review): because the two clauses are OR-ed, an `r` whose size along
# `dim` equals that of `v` passes without `length(w)` being compared.
for i = 1:N
(i == dim && size(r, i) == 1 && size(v, i) == length(w)) || size(r, i) == size(v, i) || error(DimensionMismatch(""))
end
# Scales `r` by the reciprocal of the weight sum (turning a sum into a mean).
return scale!(r, inv(sum(w)))
# Start the accumulation from zero.
fill!(r, 0)
# Current weight; reassigned from `w` inside the loop nest below.
weight = zero(W)
# Loop over every index `i` of `v`. At the `dim` level the weight for that
# index is loaded and the output index `j` is pinned to 1; at every other
# level `j` follows `i`. The innermost statement adds the weighted element.
@nloops N i v d->(if d == dim
weight = w[i_d]
j_d = 1
else
j_d = i_d
end) @inbounds (@nref N r j) += (@nref N v i)*weight
r
end

# NOTE(review): this span appears to interleave two definitions — a
# matrix-only `mean` (opened on the next line) and the matrix-multiply based
# `wsum!` that begins further down. Confirm the actual statement order against
# the real src/means.jl.
#
# `mean` part: allocating weighted mean of a matrix along `dim`.
function mean{T<:Number,W<:Real}(v::AbstractMatrix{T}, w::WeightVec{W}, dim::Int)
# Result element type: the floating-point type of a `T` times a `W`.
R = typeof(float(one(T) * one(W)))
m, n = size(v)
# The vector result of `mean!` is reshaped to 1×n for dim 1 and m×1 for dim 2;
# any other `dim` is an error.
dim == 1 ? reshape(mean!(Array(R, n), v, w, 1), 1, n) :
dim == 2 ? reshape(mean!(Array(R, m), v, w, 2), m, 1) :
error("Invalid value of dim.")
# Weighted sum via `A_mul_B!`/`At_mul_B!` when reducing over the first or
# last dimension. `vec` and `reshape` are only guaranteed not to copy for
# `Array`s, so arguments that may need those calls must be `Array`s; other
# array types are accepted only as a vector `r` / matrix `v`.
# Any other `dim` is forwarded to the generic N-dimensional method.
# Returns `r`.
function wsum!{W<:Real}(r::Union(Array, AbstractVector), v::Union(Array, AbstractMatrix), w::AbstractVector{W}, dim::Int)
if dim == 1
# Treat `v` as an m×n matrix (trailing dimensions flattened); reducing over
# the first dimension is then the transposed matrix-vector product with `w`.
m = size(v, 1)
n = div(length(v), m)
(length(r) == n && length(w) == m) || throw(DimensionMismatch(""))
At_mul_B!(vec(r), isa(v, AbstractMatrix) ? v : reshape(v, m, n), w)
elseif dim == ndims(v)
# Treat `v` as an m×n matrix (leading dimensions flattened); reducing over
# the last dimension is then the plain matrix-vector product with `w`.
n = size(v, ndims(v))
m = div(length(v), n)
(length(r) == m && length(w) == n) || throw(DimensionMismatch(""))
A_mul_B!(vec(r), isa(v, AbstractMatrix) ? v : reshape(v, m, n), w)
else
# Interior dimension: explicitly invoke the (AbstractArray, AbstractArray, ...)
# method rather than re-dispatching to this one.
invoke(wsum!, (AbstractArray, AbstractArray, typeof(w), Int), r, v, w, dim)
end
r
end

# In-place weighted sum of `v` along `dim`, written into `r`: unwraps the
# weight vector and delegates to `wsum!`.
function Base.sum!{W<:Real}(r::AbstractArray, v::AbstractArray, w::WeightVec{W}, dim::Int)
    raw_weights = values(w)
    return wsum!(r, v, raw_weights, dim)
end

# Allocating weighted sum of `v` along `dim` with plain vector weights `w`.
# The output array has the reduced shape of `v` along `dim` and is filled by
# `wsum!`.
function wsum{T<:Number,W<:Real}(v::AbstractArray{T}, w::AbstractVector{W}, dim::Int)
    # Element type of a sum of products of a `T` and a `W`.
    R = typeof(zero(T)*zero(W) + zero(T)*zero(W))
    out = Array(R, Base.reduced_dims(size(v), dim))
    return wsum!(out, v, w, dim)
end

# Allocating weighted sum of `v` along `dim`: unwraps the weight vector and
# delegates to the three-argument `wsum`.
function Base.sum{T<:Number,W<:Real}(v::AbstractArray{T}, w::WeightVec{W}, dim::Int)
    raw_weights = values(w)
    return wsum(v, raw_weights, dim)
end

# In-place weighted mean of `v` along `dim`: computes the weighted sum with
# `Base.sum!` and scales its result by the reciprocal of the weight sum.
function Base.mean!(r::AbstractArray, v::AbstractArray, w::WeightVec, dim::Int)
    summed = Base.sum!(r, v, w, dim)
    factor = inv(sum(w))
    return scale!(summed, factor)
end

# Allocating weighted mean of `v` along `dim`. The output array has the
# reduced shape of `v` along `dim` and is filled by `mean!`.
function Base.mean{T<:Number,W<:Real}(v::AbstractArray{T}, w::WeightVec{W}, dim::Int)
    # Element type of a sum of `T`-by-`W` products divided by a `W`.
    R = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W))
    out = Array(R, Base.reduced_dims(size(v), dim))
    return mean!(out, v, w, dim)
end

62 changes: 42 additions & 20 deletions test/means.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,53 @@ using Base.Test
# trimmean on a vector with extreme values, at two trimming fractions
@test_approx_eq trimmean([-100, 2, 3, 7, 200], 0.4) 4.0
@test_approx_eq trimmean([-100, 2, 3, 7, 200], 0.8) 3.0

# Weighted mean of a Float64 vector
@test_approx_eq mean([1.0, 2.0, 3.0], weights([1/3, 1/3, 1/3])) 2.0
@test_approx_eq mean([1.0, 2.0, 3.0], weights([1.0, 0.0, 0.0])) 1.0
@test_approx_eq mean([1.0, 2.0, 3.0], weights([0.0, 1.0, 0.0])) 2.0
@test_approx_eq mean([1.0, 2.0, 3.0], weights([0.0, 0.0, 1.0])) 3.0
@test_approx_eq mean([1.0, 2.0, 3.0], weights([0.5, 0.0, 0.5])) 2.0
@test_approx_eq mean([1.0, 2.0, 3.0], weights([0.5, 0.5, 0.0])) 1.5
@test_approx_eq mean([1.0, 2.0, 3.0], weights([0.0, 0.5, 0.5])) 2.5

# Weighted mean of an integer range
@test_approx_eq mean(1:3, weights([1/3, 1/3, 1/3])) 2.0
@test_approx_eq mean(1:3, weights([1.0, 0.0, 0.0])) 1.0
@test_approx_eq mean(1:3, weights([0.0, 1.0, 0.0])) 2.0
@test_approx_eq mean(1:3, weights([0.0, 0.0, 1.0])) 3.0
@test_approx_eq mean(1:3, weights([0.5, 0.0, 0.5])) 2.0
@test_approx_eq mean(1:3, weights([0.5, 0.5, 0.0])) 1.5
@test_approx_eq mean(1:3, weights([0.0, 0.5, 0.5])) 2.5
# Weighted sum of a Float64 vector; each weight vector here sums to one, so
# the expected sums coincide with the weighted means
@test_approx_eq sum([1.0, 2.0, 3.0], weights([1/3, 1/3, 1/3])) 2.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([1.0, 0.0, 0.0])) 1.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 1.0, 0.0])) 2.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 0.0, 1.0])) 3.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.5, 0.0, 0.5])) 2.0
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.5, 0.5, 0.0])) 1.5
@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 0.5, 0.5])) 2.5

# Weighted sum of an integer range
@test_approx_eq sum(1:3, weights([1/3, 1/3, 1/3])) 2.0
@test_approx_eq sum(1:3, weights([1.0, 0.0, 0.0])) 1.0
@test_approx_eq sum(1:3, weights([0.0, 1.0, 0.0])) 2.0
@test_approx_eq sum(1:3, weights([0.0, 0.0, 1.0])) 3.0
@test_approx_eq sum(1:3, weights([0.5, 0.0, 0.5])) 2.0
@test_approx_eq sum(1:3, weights([0.5, 0.5, 0.0])) 1.5
@test_approx_eq sum(1:3, weights([0.0, 0.5, 0.5])) 2.5
# Weights that do not sum to one: the sum is 1 + 2 + 1.5 = 4.5 and the mean
# is 4.5 / 2.5 = 1.8
@test_approx_eq sum(1:3, weights([1.0, 1.0, 0.5])) 4.5
@test_approx_eq mean(1:3, weights([1.0, 1.0, 0.5])) 1.8

# 2x3 matrix used by the dimension-wise tests below
a = [1. 2. 3.; 4. 5. 6.]

# Reduction along dimension 1: one weight per row, result of size (1, 3)
@test size(mean(a, weights(ones(2)), 1)) == (1, 3)
@test_approx_eq sum(a, weights([1.0, 1.0]), 1) [5.0, 7.0, 9.0]
@test_approx_eq mean(a, weights([1.0, 1.0]), 1) [2.5, 3.5, 4.5]
@test_approx_eq mean(a, weights([1.0, 0.0]), 1) [1.0, 2.0, 3.0]
@test_approx_eq mean(a, weights([0.0, 1.0]), 1) [4.0, 5.0, 6.0]
@test_approx_eq sum(a, weights([1.0, 0.0]), 1) [1.0, 2.0, 3.0]
@test_approx_eq sum(a, weights([0.0, 1.0]), 1) [4.0, 5.0, 6.0]

# Reduction along dimension 2: one weight per column, result of size (2, 1);
# exercised through wsum!, wsum, sum!, sum and mean
@test size(mean(a, weights(ones(3)), 2)) == (2, 1)
@test_approx_eq mean(a, weights([1.0, 1.0, 1.0]), 2) [2.0, 5.0]
@test_approx_eq mean(a, weights([1.0, 0.0, 0.0]), 2) [1.0, 4.0]
@test_approx_eq mean(a, weights([0.0, 0.0, 1.0]), 2) [3.0, 6.0]
@test_approx_eq wsum!(zeros(1, 2), a, [1.0, 1.0, 1.0], 2) [6.0 15.0]
@test_approx_eq wsum(a, [1.0, 1.0, 1.0], 2) [6.0 15.0]
@test_approx_eq sum!(zeros(1, 2), a, weights([1.0, 1.0, 1.0]), 2) [6.0 15.0]
@test_approx_eq sum(a, weights([1.0, 1.0, 1.0]), 2) [6.0 15.0]
@test_approx_eq mean(a, weights([1.0, 1.0, 1.0]), 2) [2.0 5.0]
@test_approx_eq sum(a, weights([1.0, 0.0, 0.0]), 2) [1.0 4.0]
@test_approx_eq sum(a, weights([0.0, 0.0, 1.0]), 2) [3.0 6.0]

# Error cases: dim beyond the matrix's dimensions, a weight vector whose
# length does not match the reduced dimension, and an output of the wrong size
@test_throws ErrorException mean(a, weights(ones(3)), 3)
@test_throws DimensionMismatch mean(a, weights(ones(2)), 2)
@test_throws DimensionMismatch mean!(ones(1, 1), a, weights(ones(3)), 2)

# 3x3x3 array: compare the weighted sum/mean along each dimension against an
# explicit broadcast-multiply-then-sum reference.
a = reshape(1.0:27.0, 3, 3, 3)

for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
    total = sum(wt)
    for dim in 1:3
        # Shape that places the weights along `dim` and broadcasts elsewhere.
        shape = ones(Int, 3)
        shape[dim] = length(wt)
        expected = sum(a.*reshape(wt, shape...), dim)
        @test_approx_eq sum(a, weights(wt), dim) expected
        @test_approx_eq mean(a, weights(wt), dim) expected/total
    end
    # A dimension beyond ndims(a) must be rejected.
    @test_throws ErrorException mean(a, weights(wt), 4)
end

0 comments on commit 875975b

Please sign in to comment.