Skip to content

Commit

Permalink
Add unwrapping mechanism for triangular mul and solves (#50058)
Browse files Browse the repository at this point in the history
This adds an unwrapping mechanism to triangular matrices, basically
following the BLAS example in terms of characters encoding wrappers. It
mirrors the `AdjOrTransOrHermOrSym` mechanism closely. Packages that
want to overload by storage type can overload `generic_trimatmul!` (and
potentially `generic_matrimul!`). Note the similarity to
`generic_matvecmul!` and `generic_matmatmul!`. There is, unfortunately,
some added code due to the fact that lazy conjugate wrappers have a
different "wrapper depth" compared to the classic, e.g.,
`*Triangular{<:Any,<:Adjoint}`. I believe that with this PR we cover all
wrappers of typically dense matrices with the unwrapping mechanism. ~~An
analogous approach could be applied to `ldiv!`, if that's of interest
and of benefit to the ecosystem.~~
  • Loading branch information
dkarrasch authored Jul 16, 2023
2 parents d215d91 + e67ddaa commit 18d18dc
Show file tree
Hide file tree
Showing 6 changed files with 550 additions and 504 deletions.
8 changes: 7 additions & 1 deletion stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ end
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)

# TODO: remove, is already replaced by wrapperop
"""
adj_or_trans(::AbstractArray) -> adjoint|transpose|identity
adj_or_trans(::Type{<:AbstractArray}) -> adjoint|transpose|identity
Return [`adjoint`](@ref) from an `Adjoint` type or object and
[`transpose`](@ref) from a `Transpose` type or object. Otherwise,
return [`identity`](@ref). Note that `Adjoint` and `Transpose` have
Expand All @@ -94,9 +94,15 @@ inplace_adj_or_trans(::Type{<:AbstractArray}) = copyto!
inplace_adj_or_trans(::Type{<:Adjoint}) = adjoint!
inplace_adj_or_trans(::Type{<:Transpose}) = transpose!

# unwraps Adjoint, Transpose, Symmetric, Hermitian
_unwrap(A::Adjoint) = parent(A)
_unwrap(A::Transpose) = parent(A)

# unwraps Adjoint and Transpose only
_unwrap_at(A) = A
_unwrap_at(A::Adjoint) = parent(A)
_unwrap_at(A::Transpose) = parent(A)

Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)
Base.unaliascopy(A::Union{Adjoint,Transpose}) = typeof(A)(Base.unaliascopy(A.parent))

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ function ldiv!(c::AbstractVecOrMat, A::Bidiagonal, b::AbstractVecOrMat)
end
ldiv!(A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) =
(t = adj_or_trans(A); _rdiv!(t(c), t(b), t(A)); return c)
(t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c)

### Generic promotion methods and fallbacks
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
Expand Down Expand Up @@ -846,7 +846,7 @@ end
rdiv!(A::AbstractMatrix, B::Bidiagonal) = @inline _rdiv!(A, A, B)
rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, A, B)
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) =
(t = adj_or_trans(B); ldiv!(t(C), t(B), t(A)); return C)
(t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C)

/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)

Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ for T = (:Number, :UniformScaling, :Diagonal)
end

function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
HH = _mulmattri!(_initarray(*, eltype(H), eltype(U), H), H, U)
HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U)
UpperHessenberg(HH)
end
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
HH = _multrimat!(_initarray(*, eltype(U), eltype(H), H), U, H)
HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H)
UpperHessenberg(HH)
end

Expand Down
16 changes: 6 additions & 10 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ AdjOrTransStridedMat{T} = Union{Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:Any, <:StridedMatrix{T}}}
StridedMaybeAdjOrTransVecOrMat{T} = Union{StridedVecOrMat{T}, AdjOrTrans{<:Any, <:StridedVecOrMat{T}}}

_parent(A) = A
_parent(A::Adjoint) = parent(A)
_parent(A::Transpose) = parent(A)

matprod(x, y) = x*y + x*y

# dot products
Expand Down Expand Up @@ -115,14 +111,14 @@ end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end

# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
Expand All @@ -131,13 +127,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
convert(AbstractArray{TS}, A),
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
Expand Down
Loading

0 comments on commit 18d18dc

Please sign in to comment.