Skip to content

Commit

Permalink
improve cat design (#384)
Browse files Browse the repository at this point in the history
This is type-piracy, but we cannot change that (until
JuliaLang/julia#2326), so at least do not make these method
intersections unnecessary slow and complicated for everyone who does not
care about SparseArrays and does not load it, and unreliable for
everyone who does load it.
  • Loading branch information
vtjnash authored Apr 20, 2023
1 parent 8145759 commit 5fc5771
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
7 changes: 7 additions & 0 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,13 @@ const _Triangular_SparseKronArrays = UpperOrLowerTriangular{<:Any,<:_SparseKronA
const _Annotated_SparseKronArrays = Union{_Triangular_SparseKronArrays, _Symmetric_SparseKronArrays, _Hermitian_SparseKronArrays}
const _SparseKronGroup = Union{_SparseKronArrays, _Annotated_SparseKronArrays}

const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A}
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}

@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
mA, nA = size(A); mB, nB = size(B)
mC, nC = mA*mB, nA*nB
Expand Down
73 changes: 45 additions & 28 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import Base: sort!, findall, copy!
import LinearAlgebra: promote_to_array_type, promote_to_arrays_
using LinearAlgebra: _SpecialArrays, _DenseConcatGroup

### The SparseVector

Expand Down Expand Up @@ -1175,24 +1174,10 @@ function _absspvec_vcat(X::AbstractSparseVector{Tv,Ti}...) where {Tv,Ti}
SparseVector(len, rnzind, rnzval)
end

hcat(Xin::Union{Vector, AbstractSparseVector}...) = hcat(map(sparse, Xin)...)
vcat(Xin::Union{Vector, AbstractSparseVector}...) = vcat(map(sparse, Xin)...)

### Concatenation of un/annotated sparse/special/dense vectors/matrices

const _SparseArrays = Union{AbstractSparseVector,
AbstractSparseMatrixCSC,
Adjoint{<:Any,<:AbstractSparseVector},
Transpose{<:Any,<:AbstractSparseVector}}
const _SparseConcatArrays = Union{_SpecialArrays, _SparseArrays}

const _Symmetric_SparseConcatArrays = Symmetric{<:Any,<:_SparseConcatArrays}
const _Hermitian_SparseConcatArrays = Hermitian{<:Any,<:_SparseConcatArrays}
const _Triangular_SparseConcatArrays = UpperOrLowerTriangular{<:Any,<:_SparseConcatArrays}
const _Annotated_SparseConcatArrays = Union{_Triangular_SparseConcatArrays, _Symmetric_SparseConcatArrays, _Hermitian_SparseConcatArrays}
# It's important that _SparseConcatGroup is a larger union than _DenseConcatGroup to make
# sparse cat-methods less specific and to kick in only if there is some sparse array present
const _SparseConcatGroup = Union{_DenseConcatGroup, _SparseConcatArrays, _Annotated_SparseConcatArrays}
# by type-pirating and subverting the Base.cat design by making these a subtype of the normal methods for it
# and re-defining all of it here. See https://github.com/JuliaLang/julia/issues/2326
# for what would have been a more principled way of doing this.

# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays

Expand All @@ -1204,23 +1189,55 @@ _sparse(A) = _makesparse(A)
_makesparse(x::Number) = x
_makesparse(x::AbstractVector) = convert(SparseVector, issparse(x) ? x : sparse(x))::SparseVector
_makesparse(x::AbstractMatrix) = convert(SparseMatrixCSC, issparse(x) ? x : sparse(x))::SparseMatrixCSC
anysparse() = false
anysparse(X) = X isa AbstractArray && issparse(X)
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)

function hcat(X::Union{Vector, AbstractSparseVector}...)
if anysparse(X...)
X = map(sparse, X)
end
return cat(X...; dims=Val(2))
end
function vcat(X::Union{Vector, AbstractSparseVector}...)
if anysparse(X...)
X = map(sparse, X)
end
return cat(X...; dims=Val(1))
end

# type-pirate the Base.cat design by making this a subtype of the existing method for it
# in future versions of Julia (v1.10+), in which https://github.com/JuliaLang/julia/issues/2326 is not fixed yet, the <:Number constraint could be relaxed
# but see also https://github.com/JuliaSparse/SparseArrays.jl/issues/71
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}

# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
Base.@constprop :aggressive function Base._cat(dims, Xin::_SparseConcatGroup...)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
T = promote_eltype(Xin...)
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
T = promote_eltype(X...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
end
return Base._cat_t(dims, T, X...)
end
function hcat(Xin::_SparseConcatGroup...)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
function hcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
end
return cat(X..., dims=Val(2))
end
function vcat(Xin::_SparseConcatGroup...)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
function vcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
end
return cat(X..., dims=Val(1))
end
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
vcat(_hvcat_rows(rows, X...)...)
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if anysparse(X...)
vcat(_hvcat_rows(rows, X...)...)
else
typed_hvcat(promote_eltypeof(X...), rows, X...)
end
end
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if row1 0
throw(ArgumentError("length of block row must be positive, got $row1"))
Expand All @@ -1237,7 +1254,7 @@ end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)

"""
Expand Down

0 comments on commit 5fc5771

Please sign in to comment.