Skip to content

Commit

Permalink
Improve performance of weighted sum (#778)
Browse files Browse the repository at this point in the history
The current code is calling the `AbstractArray` matrix multiplication fallback,
which is slower than BLAS.
  • Loading branch information
nalimilan authored Mar 31, 2022
1 parent 5c011db commit e8ab265
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
18 changes: 12 additions & 6 deletions src/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,18 @@ Compute the weighted sum of an array `v` with weights `w`, optionally over the d
"""
wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = transpose(w) * vec(v)

# Optimized methods (to ensure we use BLAS when possible)
for W in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
@eval begin
wsum(v::AbstractArray, w::$W, dims::Colon) = transpose(w.values) * vec(v)
end
end

function wsum(A::AbstractArray, w::UnitWeights, dims::Colon)
length(A) != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return sum(A)
end

## wsum along dimension
#
# Brief explanation of the algorithm:
Expand Down Expand Up @@ -605,12 +617,6 @@ optionally over the dimension `dims`.
Base.sum(A::AbstractArray, w::AbstractWeights{<:Real}; dims::Union{Colon,Int}=:) =
wsum(A, w, dims)

function Base.sum(A::AbstractArray, w::UnitWeights; dims::Union{Colon,Int}=:)
a = (dims === :) ? length(A) : size(A, dims)
a != length(w) && throw(DimensionMismatch("Inconsistent array dimension."))
return sum(A, dims=dims)
end

##### Weighted means #####

function wmean(v::AbstractArray{<:Number}, w::AbstractVector)
Expand Down
2 changes: 1 addition & 1 deletion test/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ end
@testset "Sum, mean, quantiles and variance for unit weights" begin
wt = uweights(Float64, 3)

@test sum([1.0, 2.0, 3.0], wt) 6.0
@test sum([1.0, 2.0, 3.0], wt) wsum([1.0, 2.0, 3.0], wt) 6.0
@test mean([1.0, 2.0, 3.0], wt) 2.0

@test sum(a, wt, dims=1) sum(a, dims=1)
Expand Down

0 comments on commit e8ab265

Please sign in to comment.