Skip to content

Commit c93589c

Browse files
authored
Reland new hvcat design (#384, #407)
This reverts commit 2c7f4d6 (#406), now that JuliaLang/julia#48977 is finished.
1 parent a9637dd commit c93589c

File tree

2 files changed

+53
-29
lines changed

2 files changed

+53
-29
lines changed

src/linalg.jl

+7
Original file line numberDiff line numberDiff line change
@@ -1301,6 +1301,13 @@ const _Triangular_SparseKronArrays = UpperOrLowerTriangular{<:Any,<:_SparseKronA
13011301
const _Annotated_SparseKronArrays = Union{_Triangular_SparseKronArrays, _Symmetric_SparseKronArrays, _Hermitian_SparseKronArrays}
13021302
const _SparseKronGroup = Union{_SparseKronArrays, _Annotated_SparseKronArrays}
13031303

1304+
const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
1305+
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
1306+
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
1307+
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A}
1308+
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
1309+
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
1310+
13041311
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
13051312
mA, nA = size(A); mB, nB = size(B)
13061313
mC, nC = mA*mB, nA*nB

src/sparsevector.jl

+46-29
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import Base: sort!, findall, copy!
66
import LinearAlgebra: promote_to_array_type, promote_to_arrays_
7-
8-
using LinearAlgebra: adj_or_trans, _SpecialArrays, _DenseConcatGroup
7+
using LinearAlgebra: adj_or_trans
98

109
### The SparseVector
1110

@@ -1176,24 +1175,10 @@ function _absspvec_vcat(X::AbstractSparseVector{Tv,Ti}...) where {Tv,Ti}
11761175
SparseVector(len, rnzind, rnzval)
11771176
end
11781177

1179-
hcat(Xin::Union{Vector, AbstractSparseVector}...) = hcat(map(sparse, Xin)...)
1180-
vcat(Xin::Union{Vector, AbstractSparseVector}...) = vcat(map(sparse, Xin)...)
1181-
11821178
### Concatenation of un/annotated sparse/special/dense vectors/matrices
1183-
1184-
const _SparseArrays = Union{AbstractSparseVector,
1185-
AbstractSparseMatrixCSC,
1186-
Adjoint{<:Any,<:AbstractSparseVector},
1187-
Transpose{<:Any,<:AbstractSparseVector}}
1188-
const _SparseConcatArrays = Union{_SpecialArrays, _SparseArrays}
1189-
1190-
const _Symmetric_SparseConcatArrays = Symmetric{<:Any,<:_SparseConcatArrays}
1191-
const _Hermitian_SparseConcatArrays = Hermitian{<:Any,<:_SparseConcatArrays}
1192-
const _Triangular_SparseConcatArrays = UpperOrLowerTriangular{<:Any,<:_SparseConcatArrays}
1193-
const _Annotated_SparseConcatArrays = Union{_Triangular_SparseConcatArrays, _Symmetric_SparseConcatArrays, _Hermitian_SparseConcatArrays}
1194-
# It's important that _SparseConcatGroup is a larger union than _DenseConcatGroup to make
1195-
# sparse cat-methods less specific and to kick in only if there is some sparse array present
1196-
const _SparseConcatGroup = Union{_DenseConcatGroup, _SparseConcatArrays, _Annotated_SparseConcatArrays}
1179+
# by type-pirating and subverting the Base.cat design by making these a subtype of the normal methods for it
1180+
# and re-defining all of it here. See https://github.com/JuliaLang/julia/issues/2326
1181+
# for what would have been a more principled way of doing this.
11971182

11981183
# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays
11991184

@@ -1205,23 +1190,55 @@ _sparse(A) = _makesparse(A)
12051190
_makesparse(x::Number) = x
12061191
_makesparse(x::AbstractVector) = convert(SparseVector, issparse(x) ? x : sparse(x))::SparseVector
12071192
_makesparse(x::AbstractMatrix) = convert(SparseMatrixCSC, issparse(x) ? x : sparse(x))::SparseMatrixCSC
1193+
anysparse() = false
1194+
anysparse(X) = X isa AbstractArray && issparse(X)
1195+
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)
1196+
1197+
function hcat(X::Union{Vector, AbstractSparseVector}...)
1198+
if anysparse(X...)
1199+
X = map(sparse, X)
1200+
end
1201+
return cat(X...; dims=Val(2))
1202+
end
1203+
function vcat(X::Union{Vector, AbstractSparseVector}...)
1204+
if anysparse(X...)
1205+
X = map(sparse, X)
1206+
end
1207+
return cat(X...; dims=Val(1))
1208+
end
1209+
1210+
# type-pirate the Base.cat design by making this a subtype of the existing method for it
1211+
# in future versions of Julia (v1.10+), in which https://github.com/JuliaLang/julia/issues/2326 is not fixed yet, the <:Number constraint could be relaxed
1212+
# but see also https://github.com/JuliaSparse/SparseArrays.jl/issues/71
1213+
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}
12081214

12091215
# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
1210-
Base.@constprop :aggressive function Base._cat(dims, Xin::_SparseConcatGroup...)
1211-
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
1212-
T = promote_eltype(Xin...)
1216+
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
1217+
T = promote_eltype(X...)
1218+
if anysparse(X...)
1219+
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1220+
end
12131221
return Base._cat_t(dims, T, X...)
12141222
end
1215-
function hcat(Xin::_SparseConcatGroup...)
1216-
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
1223+
function hcat(X::_SparseConcatGroup...)
1224+
if anysparse(X...)
1225+
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1226+
end
12171227
return cat(X..., dims=Val(2))
12181228
end
1219-
function vcat(Xin::_SparseConcatGroup...)
1220-
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
1229+
function vcat(X::_SparseConcatGroup...)
1230+
if anysparse(X...)
1231+
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1232+
end
12211233
return cat(X..., dims=Val(1))
12221234
end
1223-
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
1224-
vcat(_hvcat_rows(rows, X...)...)
1235+
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1236+
if anysparse(X...)
1237+
vcat(_hvcat_rows(rows, X...)...)
1238+
else
1239+
Base.typed_hvcat(promote_eltypeof(X...), rows, X...)
1240+
end
1241+
end
12251242
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
12261243
if row1 0
12271244
throw(ArgumentError("length of block row must be positive, got $row1"))
@@ -1238,7 +1255,7 @@ end
12381255
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()
12391256

12401257
# make sure UniformScaling objects are converted to sparse matrices for concatenation
1241-
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
1258+
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
12421259
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)
12431260

12441261
"""

0 commit comments

Comments
 (0)