From 93cdc0eb6fdc9db12e801fa3131b9e8f930e8fe0 Mon Sep 17 00:00:00 2001 From: Justin Willmert Date: Wed, 8 Jul 2020 14:53:55 -0500 Subject: [PATCH] Make use of fast sparse outer products from Julia (closes #4) The feature was added to Julia in time for v1.2 in JuliaLang/julia#24980, so get rid of the custom `outer()` method here and rewrite `quadprod()` in terms of just standard matrix methods. Julia v1.2 is the minimum-supported version at this point, so no need to worry about backporting the functionality. In the future, this function may yet still go away since the implementation is nearly trivial at this point, but that can be a follow-up PR. --- src/numerics.jl | 89 ++---------------------------------------------- test/numerics.jl | 6 ++-- 2 files changed, 7 insertions(+), 88 deletions(-) diff --git a/src/numerics.jl b/src/numerics.jl index d651a4e..08fa1b5 100644 --- a/src/numerics.jl +++ b/src/numerics.jl @@ -1,86 +1,5 @@ using SparseArrays -""" -Computes the outer product between a given column of a sparse matrix and a vector. -""" -function outer end - -""" - outer(A::SparseMatrixCSC, n::Integer, w::AbstractVector) - -Performs the equivalent of ``\\vec a_n \\vec w^\\dagger`` where ``\\vec a_n`` is the -column `A[:,n]`. -""" -function outer(A::SparseMatrixCSC{Tv,Ti}, n::Integer, w::AbstractVector{Tv}) where {Tv,Ti} - colptrn = nzrange(A, n) - rowvalA = rowvals(A) - nzvalsA = nonzeros(A) - - nnza = length(colptrn) - nnzw = length(w) - numnz = nnza * nnzw - - colptr = Vector{Ti}(undef, nnzw+1) - rowval = Vector{Ti}(undef, numnz) - nzvals = Vector{Tv}(undef, numnz) - - idx = 0 - @inbounds for jj = 1:nnzw - colptr[jj] = idx + 1 - - wv = conj(w[jj]) - iszero(wv) && continue - - for ii = colptrn - idx += 1 - rowval[idx] = rowvalA[ii] # copy row index from A - nzvals[idx] = wv * nzvalsA[ii] # outer product values - end - end - @inbounds colptr[nnzw+1] = idx + 1 - return SparseMatrixCSC(size(A,1), nnzw, colptr, rowval, nzvals) -end - -""" - outer(w::AbstractVector, A::SparseMatrixCSC, n::Integer) - -Performs the equivalent of ``\\vec w \\vec{a}_n^\\dagger`` where ``\\vec a_n`` is the -column `A[:,n]`. -""" -function outer(w::AbstractVector{Tv}, A::SparseMatrixCSC{Tv,Ti}, n::Integer) where {Tv,Ti} - colptrn = nzrange(A, n) - rowvalA = rowvals(A) - nzvalsA = nonzeros(A) - - nnza = length(colptrn) - nnzw = length(w) - numnz = nnza * nnzw - - colptr = zeros(Ti, size(A,1)+1) - rowval = Vector{Ti}(undef, numnz) - nzvals = Vector{Tv}(undef, numnz) - - idx = 0 - @inbounds colptr[1] = 1 # col 1 always at index 1 - @inbounds for jj = colptrn - av = conj(nzvalsA[jj]) - rv = rowvalA[jj] - - for ii = 1:nnzw - wv = w[ii] - iszero(wv) && continue - - idx += 1 - colptr[rv+1] += 1 # count num of entries in column - rowval[idx] = ii - nzvals[idx] = w[ii] * av # outer product values - end - end - cumsum!(colptr, colptr) # offsets are sum of all previous - - return SparseMatrixCSC(nnzw, size(A,1), colptr, rowval, nzvals) -end - """ quadprod(A, b, n, dir=:col) @@ -88,13 +7,11 @@ Computes the quadratic product ``ABA^T`` efficiently for the case where ``B`` is except for the `n`th column or row vector `b`, for `dir = :col` or `dir = :row`, respectively. """ -function quadprod(A, b, n, dir::Symbol=:col) +@inline function quadprod(A, b, n, dir::Symbol=:col) if dir == :col - w = A * b - return outer(w, A, n) + return (A * sparse(b)) * view(A, :, n)' elseif dir == :row - w = A * b - return outer(A, n, w) + return view(A, :, n) * (A * sparse(b))' else error("Unrecognized direction `dir = $(repr(dir))`.") end diff --git a/test/numerics.jl b/test/numerics.jl index 25a0876..02c31ac 100644 --- a/test/numerics.jl +++ b/test/numerics.jl @@ -16,6 +16,8 @@ Br = sparse(fill(i,n), collect(1:n), b, n, n) bt = convert(Vector{T}, b) Bct = convert(SparseMatrixCSC{T}, Bc) Brt = convert(SparseMatrixCSC{T}, Br) - @test At * Bct * At' == @inferred quadprod(At, bt, i, :col) - @test At * Brt * At' ≈ @inferred quadprod(At, bt, i, :row) + @test At * Bct * At' == quadprod(At, bt, i, :col) + @test @inferred(quadprod(At, bt, i, :col)) isa SparseMatrixCSC + @test At * Brt * At' ≈ quadprod(At, bt, i, :row) + @test @inferred(quadprod(At, bt, i, :row)) isa SparseMatrixCSC end