Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve cat design #384

Merged
merged 1 commit into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,13 @@ const _Triangular_SparseKronArrays = UpperOrLowerTriangular{<:Any,<:_SparseKronA
const _Annotated_SparseKronArrays = Union{_Triangular_SparseKronArrays, _Symmetric_SparseKronArrays, _Hermitian_SparseKronArrays}
const _SparseKronGroup = Union{_SparseKronArrays, _Annotated_SparseKronArrays}

const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A}
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}

@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
mA, nA = size(A); mB, nB = size(B)
mC, nC = mA*mB, nA*nB
Expand Down
73 changes: 45 additions & 28 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import Base: sort!, findall, copy!
import LinearAlgebra: promote_to_array_type, promote_to_arrays_
using LinearAlgebra: _SpecialArrays, _DenseConcatGroup

### The SparseVector

Expand Down Expand Up @@ -1175,24 +1174,10 @@ function _absspvec_vcat(X::AbstractSparseVector{Tv,Ti}...) where {Tv,Ti}
SparseVector(len, rnzind, rnzval)
end

hcat(Xin::Union{Vector, AbstractSparseVector}...) = hcat(map(sparse, Xin)...)
vcat(Xin::Union{Vector, AbstractSparseVector}...) = vcat(map(sparse, Xin)...)

### Concatenation of un/annotated sparse/special/dense vectors/matrices

const _SparseArrays = Union{AbstractSparseVector,
AbstractSparseMatrixCSC,
Adjoint{<:Any,<:AbstractSparseVector},
Transpose{<:Any,<:AbstractSparseVector}}
const _SparseConcatArrays = Union{_SpecialArrays, _SparseArrays}

const _Symmetric_SparseConcatArrays = Symmetric{<:Any,<:_SparseConcatArrays}
const _Hermitian_SparseConcatArrays = Hermitian{<:Any,<:_SparseConcatArrays}
const _Triangular_SparseConcatArrays = UpperOrLowerTriangular{<:Any,<:_SparseConcatArrays}
const _Annotated_SparseConcatArrays = Union{_Triangular_SparseConcatArrays, _Symmetric_SparseConcatArrays, _Hermitian_SparseConcatArrays}
# It's important that _SparseConcatGroup is a larger union than _DenseConcatGroup to make
# sparse cat-methods less specific and to kick in only if there is some sparse array present
const _SparseConcatGroup = Union{_DenseConcatGroup, _SparseConcatArrays, _Annotated_SparseConcatArrays}
# by type-pirating and subverting the Base.cat design by making these a subtype of the normal methods for it
# and re-defining all of it here. See https://github.com/JuliaLang/julia/issues/2326
# for what would have been a more principled way of doing this.

# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays

Expand All @@ -1204,23 +1189,55 @@ _sparse(A) = _makesparse(A)
_makesparse(x::Number) = x
_makesparse(x::AbstractVector) = convert(SparseVector, issparse(x) ? x : sparse(x))::SparseVector
_makesparse(x::AbstractMatrix) = convert(SparseMatrixCSC, issparse(x) ? x : sparse(x))::SparseMatrixCSC
anysparse() = false
anysparse(X) = X isa AbstractArray && issparse(X)
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)

function hcat(X::Union{Vector, AbstractSparseVector}...)
if anysparse(X...)
X = map(sparse, X)
end
return cat(X...; dims=Val(2))
end
function vcat(X::Union{Vector, AbstractSparseVector}...)
if anysparse(X...)
X = map(sparse, X)
end
return cat(X...; dims=Val(1))
end

# type-pirate the Base.cat design by making this a subtype of the existing method for it
# in future versions of Julia (v1.10+), in which https://github.com/JuliaLang/julia/issues/2326 is not fixed yet, the <:Number constraint could be relaxed
# but see also https://github.com/JuliaSparse/SparseArrays.jl/issues/71
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}

# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
Base.@constprop :aggressive function Base._cat(dims, Xin::_SparseConcatGroup...)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
T = promote_eltype(Xin...)
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
T = promote_eltype(X...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
end
return Base._cat_t(dims, T, X...)
end
function hcat(Xin::_SparseConcatGroup...)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
function hcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
end
return cat(X..., dims=Val(2))
end
function vcat(Xin::_SparseConcatGroup...)
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
function vcat(X::_SparseConcatGroup...)
if anysparse(X...)
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
end
return cat(X..., dims=Val(1))
end
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
vcat(_hvcat_rows(rows, X...)...)
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if anysparse(X...)
vcat(_hvcat_rows(rows, X...)...)
else
typed_hvcat(promote_eltypeof(X...), rows, X...)
end
end
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
if row1 ≤ 0
throw(ArgumentError("length of block row must be positive, got $row1"))
Expand All @@ -1237,7 +1254,7 @@ end
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()

# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)

"""
Expand Down