diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 16528e6764531..dcaad4fdbdd91 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -442,16 +442,6 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _a end out end -function __muldiag!(out, D::Diagonal, B, _add::MulAddMul) - require_one_based_indexing(out, B) - alpha, beta = _add.alpha, _add.beta - if iszero(alpha) - _rmul_or_fill!(out, beta) - else - __muldiag_nonzeroalpha!(out, D, B, _add) - end - return out -end @inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} beta = _add.beta @@ -495,16 +485,6 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _a end out end -function __muldiag!(out, A, D::Diagonal, _add::MulAddMul) - require_one_based_indexing(out, A) - alpha, beta = _add.alpha, _add.beta - if iszero(alpha) - _rmul_or_fill!(out, beta) - else - __muldiag_nonzeroalpha!(out, A, D, _add) - end - return out -end @inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul) d1 = D1.diag @@ -515,15 +495,7 @@ end end out end -function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul) - alpha, beta = _add.alpha, _add.beta - if iszero(alpha) - _rmul_or_fill!(out.diag, beta) - else - __muldiag_nonzeroalpha!(out, D1, D2, _add) - end - return out -end + @inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul) d1 = D1.diag d2 = D2.diag @@ -532,11 +504,26 @@ end end out end + +# muldiag mainly handles the zero-alpha case, so that we need only +# specialize the non-trivial case +function __muldiag!(out, A, B, _add::MulAddMul) + require_one_based_indexing(out, A, B) + alpha, beta = _add.alpha, _add.beta + if iszero(alpha) + _rmul_or_fill!(out, beta) + else + __muldiag_nonzeroalpha!(out, A, B, _add) + end + return out +end + function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1} require_one_based_indexing(out) alpha, beta = _add.alpha, _add.beta _rmul_or_fill!(out, beta) if !iszero(alpha) + # we ony update the diagonal _add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true) __muldiag_nonzeroalpha!(out, D1, D2, _add_bis1) end