From eb68dd58942cb0cdcccb1824a473206c09367c02 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 17:36:47 +0530 Subject: [PATCH 01/11] Avoid materializing arrays in bidiag matmul --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 4 ++- stdlib/LinearAlgebra/src/bidiag.jl | 37 ++++++++++++------- stdlib/LinearAlgebra/test/bidiag.jl | 33 +++++++++++++++++ stdlib/LinearAlgebra/test/tridiag.jl | 44 +++++++++++++++++++++++ 4 files changed, 105 insertions(+), 13 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 27d4255fb656b..17216845b350c 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -673,7 +673,9 @@ matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS) _matprod_dest_diag(A, TS) = similar(A, TS) function _matprod_dest_diag(A::SymTridiagonal, TS) n = size(A, 1) - Tridiagonal(similar(A, TS, n-1), similar(A, TS, n), similar(A, TS, n-1)) + ev = similar(A, TS, max(0, n-1)) + dv = similar(A, TS, n) + Tridiagonal(ev, dv, similar(ev)) end # Special handling for adj/trans vec diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index d86bad7e41435..fb660f2f2c208 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -576,8 +576,13 @@ function _bibimul!(C, A, B, _add) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) - iszero(n) && return C - n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + if n <= 3 + # naive multiplication + for I in CartesianIndices(C) + _modify!(_add, sum(A[I[1], k] * B[k, I[2]] for k in axes(A,2)), C, I) + end + return C + end # We use `_rmul_or_fill!` instead of `_modify!` here since using # `_modify!` in the following loop will not update the # off-diagonal elements for non-zero beta. @@ -744,7 +749,14 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA nB = size(B,2) (iszero(nA) || iszero(nB)) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) - nA <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + if nA <= 3 + # naive multiplication + for I in CartesianIndices(C) + col = Base.tail(Tuple(I)) + _modify!(_add, sum(A[I[1], k] * B[k, col...] for k in axes(A,2)), C, I) + end + return C + end l = _diag(A, -1) d = _diag(A, 0) u = _diag(A, 1) @@ -767,10 +779,10 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) m = size(B,2) - (iszero(m) || iszero(n)) && return C - iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) - if n <= 3 || m <= 1 - return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + (iszero(_add.alpha) || iszero(m)) && return _rmul_or_fill!(C, _add.beta) + if m == 1 + B11 = B[1,1] + return mul!(C, A, B11, _add.alpha, _add.beta) end Bl = _diag(B, -1) Bd = _diag(B, 0) @@ -804,9 +816,6 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd m, n = size(A) (iszero(m) || iszero(n)) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) - if size(A, 1) <= 3 || size(B, 2) <= 1 - return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) - end @inbounds if B.uplo == 'U' for i in 1:m for j in n:-1:2 @@ -833,8 +842,12 @@ function _dibimul!(C, A, B, _add) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) - iszero(n) && return C - n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) + if n <= 3 + for I in CartesianIndices(C) + _modify!(_add, A.diag[I[1]] * B[I[1], I[2]], C, I) + end + return C + end _rmul_or_fill!(C, _add.beta) # see the same use above iszero(_add.alpha) && return C Ad = A.diag diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index ef50658a642fb..ebd6ffd88a76c 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1048,4 +1048,37 @@ end @test mul!(similar(D), B, D) == mul!(similar(D), D, B) == B * D end +@testset "mul for small matrices" begin + @testset for n in 0:4 + D = Diagonal(rand(n)) + v = rand(n) + @testset for uplo in (:L, :U) + B = Bidiagonal(rand(n), rand(max(n-1,0)), uplo) + M = Matrix(B) + + @test B * v ≈ M * v + @test mul!(similar(v), B, v) ≈ M * v + + @test B * B ≈ M * M + @test mul!(similar(B, size(B)), B, B) ≈ M * M + + for m in 1:6 + AL = rand(m,n) + AR = rand(n,m) + @test AL * B ≈ AL * M + @test B * AR ≈ M * AR + @test mul!(similar(AL), AL, B) ≈ AL * M + @test mul!(similar(AR), B, AR) ≈ M * AR + end + + @test B * D ≈ M * D + @test D * B ≈ D * M + @test mul!(similar(B), B, D) ≈ M * D + @test mul!(similar(B), B, D) ≈ M * D + @test mul!(similar(B, size(B)), D, B) ≈ D * M + @test mul!(similar(B, size(B)), B, D) ≈ M * D + end + end +end + end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 3330fa682fe5e..5ff8097cb9f44 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -970,4 +970,48 @@ end @test sprint(show, S) == "SymTridiagonal($(repr(diag(S))), $(repr(diag(S,1))))" end +@testset "mul for small matrices" begin + @testset for n in 0:4 + for T in ( + Tridiagonal(rand(max(n-1,0)), rand(n), rand(max(n-1,0))), + SymTridiagonal(rand(n), rand(max(n-1,0))), + ) + M = Matrix(T) + @test T * T ≈ M * M + @test mul!(similar(T, size(T)), T, T) ≈ M * M + + for m in 0:6 + AR = rand(n,m) + AL = rand(m,n) + @test AL * T ≈ AL * M + @test T * AR ≈ M * AR + @test mul!(similar(AL), AL, T) ≈ AL * M + @test mul!(similar(AR), T, AR) ≈ M * AR + end + + v = rand(n) + @test T * v ≈ M * v + @test mul!(similar(v), T, v) ≈ M * v + + D = Diagonal(rand(n)) + @test T * D ≈ M * D + @test D * T ≈ D * M + @test mul!(Tridiagonal(similar(T)), D, T) ≈ D * M + @test mul!(Tridiagonal(similar(T)), T, D) ≈ M * D + @test mul!(similar(T, size(T)), D, T) ≈ D * M + @test mul!(similar(T, size(T)), T, D) ≈ M * D + + B = Bidiagonal(rand(n), rand(max(0, n-1)), :U) + @test T * B ≈ M * B + @test B * T ≈ B * M + if n <= 2 + @test mul!(Tridiagonal(similar(T)), B, T) ≈ B * M + @test mul!(Tridiagonal(similar(T)), T, B) ≈ M * B + @test mul!(similar(T, size(T)), B, T) ≈ B * M + @test mul!(similar(T, size(T)), T, B) ≈ M * B + end + end + end +end + end # module TestTridiagonal From 632f41ba0b185c596a6decdb21c9f559c89f57e4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 21:20:45 +0530 Subject: [PATCH 02/11] Make iteration cache-friendly in Bidiagonal matmul --- stdlib/LinearAlgebra/src/bidiag.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index fb660f2f2c208..6ab27b5435504 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -817,17 +817,17 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd (iszero(m) || iszero(n)) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) @inbounds if B.uplo == 'U' + for j in n:-1:2, i in 1:m + _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) + end for i in 1:m - for j in n:-1:2 - _modify!(_add, A[i,j] * B.dv[j] + A[i,j-1] * B.ev[j-1], C, (i, j)) - end _modify!(_add, A[i,1] * B.dv[1], C, (i, 1)) end else # uplo == 'L' + for j in 1:n-1, i in 1:m + _modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j)) + end for i in 1:m - for j in 1:n-1 - _modify!(_add, A[i,j] * B.dv[j] + A[i,j+1] * B.ev[j], C, (i, j)) - end _modify!(_add, A[i,n] * B.dv[n], C, (i, n)) end end From 3f832724eca3b037ec0205f14400ff2ca5b2e72f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 10 Aug 2024 23:48:07 +0530 Subject: [PATCH 03/11] Specialize Bidiagonal * AbstractVecOrMat --- stdlib/LinearAlgebra/src/bidiag.jl | 38 +++++++++++++++++++++++++++++ stdlib/LinearAlgebra/test/bidiag.jl | 2 +- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 6ab27b5435504..3982b03d38a5b 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -757,6 +757,44 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA end return C end + _mul_bitrisym!(C, A, B, _add) +end +function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul) + nA = size(A,1) + nB = size(B,2) + d = B.dv + if A.uplo == 'U' + u = B.ev + @inbounds begin + for j = 1:nB + b₀, b₊ = B[1, j], B[2, j] + _modify!(_add, d[1]*b₀ + u[1]*b₊, C, (1, j)) + for i = 2:nA - 1 + b₀, b₊ = b₊, B[i + 1, j] + _modify!(_add, d[i]*b₀ + u[i]*b₊, C, (i, j)) + end + _modify!(_add, d[nA]*b₊, C, (nA, j)) + end + end + else + l = B.ev + @inbounds begin + for j = 1:nB + b₀, b₊ = B[1, j], B[2, j] + _modify!(_add, d[1]*b₀, C, (1, j)) + for i = 2:nA - 1 + b₋, b₀, b₊ = b₀, b₊, B[i + 1, j] + _modify!(_add, l[i - 1]*b₋ + d[i]*b₀, C, (i, j)) + end + _modify!(_add, l[nA - 1]*b₀ + d[nA]*b₊, C, (nA, j)) + end + end + end + C +end +function _mul_bitrisym!(C::AbstractVecOrMat, A::TriSym, B::AbstractVecOrMat, _add::MulAddMul) + nA = size(A,1) + nB = size(B,2) l = _diag(A, -1) d = _diag(A, 0) u = _diag(A, 1) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index ebd6ffd88a76c..559312cbbb2a1 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1062,7 +1062,7 @@ end @test B * B ≈ M * M @test mul!(similar(B, size(B)), B, B) ≈ M * M - for m in 1:6 + for m in 0:6 AL = rand(m,n) AR = rand(n,m) @test AL * B ≈ AL * M From 61369523d8f52701804f374499555702c594e112 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 13:52:18 +0530 Subject: [PATCH 04/11] Fix field access --- stdlib/LinearAlgebra/src/bidiag.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 3982b03d38a5b..025e731c6c860 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -762,9 +762,9 @@ end function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, _add::MulAddMul) nA = size(A,1) nB = size(B,2) - d = B.dv + d = A.dv if A.uplo == 'U' - u = B.ev + u = A.ev @inbounds begin for j = 1:nB b₀, b₊ = B[1, j], B[2, j] @@ -777,7 +777,7 @@ function _mul_bitrisym!(C::AbstractVecOrMat, A::Bidiagonal, B::AbstractVecOrMat, end end else - l = B.ev + l = A.ev @inbounds begin for j = 1:nB b₀, b₊ = B[1, j], B[2, j] From 149b84cef37b5129e0eda1225b926daf934d71f5 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 16:51:24 +0530 Subject: [PATCH 05/11] Split bibimul into cases --- stdlib/LinearAlgebra/src/bidiag.jl | 227 +++++++++++++++++++++++++-- stdlib/LinearAlgebra/test/bidiag.jl | 16 ++ stdlib/LinearAlgebra/test/tridiag.jl | 23 ++- 3 files changed, 243 insertions(+), 23 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 025e731c6c860..618fec4ed371a 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -588,12 +588,6 @@ function _bibimul!(C, A, B, _add) # off-diagonal elements for non-zero beta. _rmul_or_fill!(C, _add.beta) iszero(_add.alpha) && return C - Al = _diag(A, -1) - Ad = _diag(A, 0) - Au = _diag(A, 1) - Bl = _diag(B, -1) - Bd = _diag(B, 0) - Bu = _diag(B, 1) @inbounds begin # first row of C C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2, 1]) @@ -604,6 +598,31 @@ function _bibimul!(C, A, B, _add) C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]) C[2,3] += _add(A[2,2]*B[2,3] + A[2,3]*B[3,3]) C[2,4] += _add(A[2,3]*B[3,4]) + end + # middle rows + __bibimul!(C, A, B, _add) + @inbounds begin + # row before last of C + C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3]) + C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]) + C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) + C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]) + # last row of C + C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2]) + C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) + C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) + end # inbounds + C +end +function __bibimul!(C, A, B, _add) + n = size(A,1) + Al = _diag(A, -1) + Ad = _diag(A, 0) + Au = _diag(A, 1) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin for j in 3:n-2 Ajj₋1 = Al[j-1] Ajj = Ad[j] @@ -623,16 +642,192 @@ function _bibimul!(C, A, B, _add) C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) end - # row before last of C - C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3]) - C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]) - C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) - C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]) - # last row of C - C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2]) - C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) - C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) - end # inbounds + end + C +end +function __bibimul!(C, A, B::Bidiagonal, _add) + n = size(A,1) + Al = _diag(A, -1) + Ad = _diag(A, 0) + Au = _diag(A, 1) + if B.uplo == 'U' + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) + C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + end + end + else # B.uplo == 'L' + Bl = _diag(B, -1) + Bd = _diag(B, 0) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + end + end + end + C +end +function __bibimul!(C, A::Bidiagonal, B, _add) + n = size(A,1) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + if A.uplo == 'U' + Ad = _diag(A, 0) + Au = _diag(A, 1) + @inbounds begin + for j in 3:n-2 + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j-1] += _add(Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) + C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + end + end + else # A.uplo == 'L' + Al = _diag(A, -1) + Ad = _diag(A, 0) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j, j+1] += _add(Ajj *Bjj₊1) + end + end + end + C +end +function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) + n = size(A,1) + if A.uplo == 'U' && B.uplo == 'U' + Ad = _diag(A, 0) + Au = _diag(A, 1) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin + for j in 3:n-2 + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j ] += _add(Ajj*Bjj) + C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) + C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + end + end + elseif A.uplo == 'U' && B.uplo == 'L' + Ad = _diag(A, 0) + Au = _diag(A, 1) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + @inbounds begin + for j in 3:n-2 + Ajj = Ad[j] + Ajj₊1 = Au[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + C[j, j-1] += _add(Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + end + end + elseif A.uplo == 'L' && B.uplo == 'U' + Al = _diag(A, -1) + Ad = _diag(A, 0) + Bd = _diag(B, 0) + Bu = _diag(B, 1) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Bj₋1j₋1 = Bd[j-1] + Bj₋1j = Bu[j-1] + Bjj = Bd[j] + Bjj₊1 = Bu[j] + Bj₊1j₊1 = Bd[j+1] + Bj₊1j₊2 = Bu[j+1] + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j, j+1] += _add(Ajj *Bjj₊1) + end + end + else # A.uplo == 'L' && B.uplo == 'L' + Al = _diag(A, -1) + Ad = _diag(A, 0) + Bl = _diag(B, -1) + Bd = _diag(B, 0) + @inbounds begin + for j in 3:n-2 + Ajj₋1 = Al[j-1] + Ajj = Ad[j] + Bj₋1j₋2 = Bl[j-2] + Bj₋1j₋1 = Bd[j-1] + Bjj₋1 = Bl[j-1] + Bjj = Bd[j] + Bj₊1j = Bl[j] + Bj₊1j₊1 = Bd[j+1] + C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) + C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) + C[j, j ] += _add(Ajj*Bjj) + end + end + end C end diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 559312cbbb2a1..0c91378b881d0 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1058,9 +1058,11 @@ end @test B * v ≈ M * v @test mul!(similar(v), B, v) ≈ M * v + @test mul!(ones(size(v)), B, v, 2, 3) ≈ M * v * 2 .+ 3 @test B * B ≈ M * M @test mul!(similar(B, size(B)), B, B) ≈ M * M + @test mul!(ones(size(B)), B, B, 2, 4) ≈ M * M * 2 .+ 4 for m in 0:6 AL = rand(m,n) @@ -1069,6 +1071,8 @@ end @test B * AR ≈ M * AR @test mul!(similar(AL), AL, B) ≈ AL * M @test mul!(similar(AR), B, AR) ≈ M * AR + @test mul!(ones(size(AL)), AL, B, 2, 4) ≈ AL * M * 2 .+ 4 + @test mul!(ones(size(AR)), B, AR, 2, 4) ≈ M * AR * 2 .+ 4 end @test B * D ≈ M * D @@ -1077,7 +1081,19 @@ end @test mul!(similar(B), B, D) ≈ M * D @test mul!(similar(B, size(B)), D, B) ≈ D * M @test mul!(similar(B, size(B)), B, D) ≈ M * D + @test mul!(ones(size(B)), D, B, 2, 4) ≈ D * M * 2 .+ 4 + @test mul!(ones(size(B)), B, D, 2, 4) ≈ M * D * 2 .+ 4 end + BL = Bidiagonal(rand(n), rand(max(0, n-1)), :L) + ML = Matrix(BL) + BU = Bidiagonal(rand(n), rand(max(0, n-1)), :U) + MU = Matrix(BU) + T = Tridiagonal(zeros(max(0, n-1)), zeros(n), zeros(max(0, n-1))) + @test mul!(T, BL, BU) ≈ ML * MU + @test mul!(T, BU, BL) ≈ MU * ML + T = Tridiagonal(ones(max(0, n-1)), ones(n), ones(max(0, n-1))) + @test mul!(copy(T), BL, BU, 2, 3) ≈ ML * MU * 2 + T * 3 + @test mul!(copy(T), BU, BL, 2, 3) ≈ MU * ML * 2 + T * 3 end end diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 5ff8097cb9f44..22547fe72b0dd 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -979,6 +979,7 @@ end M = Matrix(T) @test T * T ≈ M * M @test mul!(similar(T, size(T)), T, T) ≈ M * M + @test mul!(ones(size(T)), T, T, 2, 4) ≈ M * M * 2 .+ 4 for m in 0:6 AR = rand(n,m) @@ -987,6 +988,8 @@ end @test T * AR ≈ M * AR @test mul!(similar(AL), AL, T) ≈ AL * M @test mul!(similar(AR), T, AR) ≈ M * AR + @test mul!(ones(size(AL)), AL, T, 2, 4) ≈ AL * M * 2 .+ 4 + @test mul!(ones(size(AR)), T, AR, 2, 4) ≈ M * AR * 2 .+ 4 end v = rand(n) @@ -1000,15 +1003,21 @@ end @test mul!(Tridiagonal(similar(T)), T, D) ≈ M * D @test mul!(similar(T, size(T)), D, T) ≈ D * M @test mul!(similar(T, size(T)), T, D) ≈ M * D - - B = Bidiagonal(rand(n), rand(max(0, n-1)), :U) - @test T * B ≈ M * B - @test B * T ≈ B * M - if n <= 2 - @test mul!(Tridiagonal(similar(T)), B, T) ≈ B * M - @test mul!(Tridiagonal(similar(T)), T, B) ≈ M * B + @test mul!(ones(size(T)), D, T, 2, 4) ≈ D * M * 2 .+ 4 + @test mul!(ones(size(T)), T, D, 2, 4) ≈ M * D * 2 .+ 4 + + for uplo in (:U, :L) + B = Bidiagonal(rand(n), rand(max(0, n-1)), uplo) + @test T * B ≈ M * B + @test B * T ≈ B * M + if n <= 2 + @test mul!(Tridiagonal(similar(T)), B, T) ≈ B * M + @test mul!(Tridiagonal(similar(T)), T, B) ≈ M * B + end @test mul!(similar(T, size(T)), B, T) ≈ B * M @test mul!(similar(T, size(T)), T, B) ≈ M * B + @test mul!(ones(size(T)), B, T, 2, 4) ≈ B * M * 2 .+ 4 + @test mul!(ones(size(T)), T, B, 2, 4) ≈ M * B * 2 .+ 4 end end end From 580500edbbe5802c3912713ca7b88a8ab41a93d6 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 22:38:16 +0530 Subject: [PATCH 06/11] Fix undef variables in __bibimul --- stdlib/LinearAlgebra/src/bidiag.jl | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 618fec4ed371a..b4c353c98529d 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -665,7 +665,7 @@ function __bibimul!(C, A, B::Bidiagonal, _add) Bj₊1j₊1 = Bd[j+1] Bj₊1j₊2 = Bu[j+1] C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) end @@ -705,9 +705,6 @@ function __bibimul!(C, A::Bidiagonal, B, _add) for j in 3:n-2 Ajj = Ad[j] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] - Bj₋1j = Bu[j-1] Bjj₋1 = Bl[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] @@ -733,9 +730,6 @@ function __bibimul!(C, A::Bidiagonal, B, _add) Bjj₋1 = Bl[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] - Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) @@ -756,8 +750,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) for j in 3:n-2 Ajj = Ad[j] Ajj₊1 = Au[j] - Bj₋1j₋1 = Bd[j-1] - Bj₋1j = Bu[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] Bj₊1j₊1 = Bd[j+1] @@ -776,8 +768,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) for j in 3:n-2 Ajj = Ad[j] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] Bjj₋1 = Bl[j-1] Bjj = Bd[j] Bj₊1j = Bl[j] @@ -800,8 +790,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) Bj₋1j = Bu[j-1] Bjj = Bd[j] Bjj₊1 = Bu[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) C[j, j+1] += _add(Ajj *Bjj₊1) @@ -820,8 +808,6 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) Bj₋1j₋1 = Bd[j-1] Bjj₋1 = Bl[j-1] Bjj = Bd[j] - Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) C[j, j ] += _add(Ajj*Bjj) From b8d8a6283d541f72cc32d36e6e70e1de792f7e76 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 11 Aug 2024 23:56:08 +0530 Subject: [PATCH 07/11] _diag for SymTridiagonal --- stdlib/LinearAlgebra/src/bidiag.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index b4c353c98529d..888895c862585 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -557,7 +557,8 @@ end # function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal # to avoid allocations in _mul! below (#24324, #24578) _diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du -_diag(A::SymTridiagonal, k) = k == 0 ? A.dv : A.ev +_diag(A::SymTridiagonal{<:Number}, k) = k == 0 ? A.dv : A.ev +_diag(A::SymTridiagonal, k) = k == 0 ? view(A, diagind(A, IndexStyle(A))) : view(A, diagind(A, 1, IndexStyle(A))) function _diag(A::Bidiagonal, k) if k == 0 return A.dv From 898c390a5f7c70f1b8407b18d712e0940094bfc8 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 12 Aug 2024 14:44:35 +0530 Subject: [PATCH 08/11] Column-major iteration in bibimul --- stdlib/LinearAlgebra/src/bidiag.jl | 212 ++++++++++++++--------------- 1 file changed, 106 insertions(+), 106 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 888895c862585..b1e9be81df6e7 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -590,28 +590,27 @@ function _bibimul!(C, A, B, _add) _rmul_or_fill!(C, _add.beta) iszero(_add.alpha) && return C @inbounds begin - # first row of C - C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2, 1]) - C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2]) - C[1,3] += _add(A[1,2]*B[2,3]) - # second row of C + # first column of C + C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2,1]) C[2,1] += _add(A[2,1]*B[1,1] + A[2,2]*B[2,1]) + C[3,1] += _add(A[3,2]*B[2,1]) + # second column of C + C[1,2] += _add(A[1,1]*B[1,2] + A[1,2]*B[2,2]) C[2,2] += _add(A[2,1]*B[1,2] + A[2,2]*B[2,2] + A[2,3]*B[3,2]) - C[2,3] += _add(A[2,2]*B[2,3] + A[2,3]*B[3,3]) - C[2,4] += _add(A[2,3]*B[3,4]) - end - # middle rows + C[3,2] += _add(A[3,2]*B[2,2] + A[3,3]*B[3,2]) + C[4,2] += _add(A[4,3]*B[3,2]) + end # inbounds + # middle columns __bibimul!(C, A, B, _add) @inbounds begin - # row before last of C - C[n-1,n-3] += _add(A[n-1,n-2]*B[n-2,n-3]) - C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2] + A[n-1,n-2]*B[n-2,n-2]) + C[n-3,n-1] += _add(A[n-3,n-2]*B[n-2,n-1]) + C[n-2,n-1] += _add(A[n-2,n-2]*B[n-2,n-1] + A[n-2,n-1]*B[n-1,n-1]) C[n-1,n-1] += _add(A[n-1,n-2]*B[n-2,n-1] + A[n-1,n-1]*B[n-1,n-1] + A[n-1,n]*B[n,n-1]) - C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1, n]*B[n ,n ]) - # last row of C - C[n,n-2] += _add(A[n,n-1]*B[n-1,n-2]) - C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) - C[n,n ] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) + C[n, n-1] += _add(A[n,n-1]*B[n-1,n-1] + A[n,n]*B[n,n-1]) + # last column of C + C[n-2, n] += _add(A[n-2,n-1]*B[n-1,n]) + C[n-1, n] += _add(A[n-1,n-1]*B[n-1,n ] + A[n-1,n]*B[n,n ]) + C[n, n] += _add(A[n,n-1]*B[n-1,n ] + A[n,n]*B[n,n ]) end # inbounds C end @@ -625,23 +624,24 @@ function __bibimul!(C, A, B, _add) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] - Ajj = Ad[j] + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] + Aj₋1j₋1 = Ad[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bj₋1j = Bu[j-1] - Bjj₋1 = Bl[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end C @@ -651,44 +651,43 @@ function __bibimul!(C, A, B::Bidiagonal, _add) Al = _diag(A, -1) Ad = _diag(A, 0) Au = _diag(A, 1) + Bd = _diag(B, 0) if B.uplo == 'U' - Bd = _diag(B, 0) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Ajj₊1 = Au[j] - Bj₋1j₋1 = Bd[j-1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj) end end else # B.uplo == 'L' Bl = _diag(B, -1) - Bd = _diag(B, 0) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] - Ajj = Ad[j] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] - Bjj₋1 = Bl[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bjj = Bd[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + + C[j-1, j] += _add(Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end end @@ -699,42 +698,45 @@ function __bibimul!(C, A::Bidiagonal, B, _add) Bl = _diag(B, -1) Bd = _diag(B, 0) Bu = _diag(B, 1) + Ad = _diag(A, 0) if A.uplo == 'U' - Ad = _diag(A, 0) Au = _diag(A, 1) @inbounds begin for j in 3:n-2 - Ajj = Ad[j] + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bjj₋1 = Bl[j-1] + Aj₋1j₋1 = Ad[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] + Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j, j-1] += _add(Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j₊1*Bj₊1j) end end else # A.uplo == 'L' Al = _diag(A, -1) - Ad = _diag(A, 0) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] + Aj₊1j₊1 = Ad[j+1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bj₋1j = Bu[j-1] - Bjj₋1 = Bl[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1) + Bj₊1j = Bl[j] + + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end end @@ -742,76 +744,74 @@ function __bibimul!(C, A::Bidiagonal, B, _add) end function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) n = size(A,1) + Ad = _diag(A, 0) + Bd = _diag(B, 0) if A.uplo == 'U' && B.uplo == 'U' - Ad = _diag(A, 0) Au = _diag(A, 1) - Bd = _diag(B, 0) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 + Aj₋2j₋1 = Au[j-2] + Aj₋1j = Au[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Ajj₊1 = Au[j] + Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - Bj₊1j₊1 = Bd[j+1] - Bj₊1j₊2 = Bu[j+1] - C[j, j ] += _add(Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1 + Ajj₊1*Bj₊1j₊1) - C[j, j+2] += _add(Ajj₊1*Bj₊1j₊2) + + C[j-2, j] += _add(Aj₋2j₋1*Bj₋1j) + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j + Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj) end end elseif A.uplo == 'U' && B.uplo == 'L' - Ad = _diag(A, 0) Au = _diag(A, 1) Bl = _diag(B, -1) - Bd = _diag(B, 0) @inbounds begin for j in 3:n-2 - Ajj = Ad[j] + Aj₋1j = Au[j-1] Ajj₊1 = Au[j] - Bjj₋1 = Bl[j-1] + Ajj = Ad[j] + Aj₊1j₊1 = Ad[j+1] Bjj = Bd[j] Bj₊1j = Bl[j] - Bj₊1j₊1 = Bd[j+1] - C[j, j-1] += _add(Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) - C[j, j+1] += _add(Ajj₊1*Bj₊1j₊1) + + C[j-1, j] += _add(Aj₋1j*Bjj) + C[j, j] += _add(Ajj*Bjj + Ajj₊1*Bj₊1j) + C[j+1, j] += _add(Aj₊1j₊1*Bj₊1j) end end elseif A.uplo == 'L' && B.uplo == 'U' Al = _diag(A, -1) - Ad = _diag(A, 0) - Bd = _diag(B, 0) Bu = _diag(B, 1) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] + Aj₋1j₋1 = Ad[j-1] Ajj = Ad[j] - Bj₋1j₋1 = Bd[j-1] + Ajj₋1 = Al[j-1] + Aj₊1j = Al[j] Bj₋1j = Bu[j-1] Bjj = Bd[j] - Bjj₊1 = Bu[j] - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1) - C[j, j ] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) - C[j, j+1] += _add(Ajj *Bjj₊1) + + C[j-1, j] += _add(Aj₋1j₋1*Bj₋1j) + C[j, j] += _add(Ajj₋1*Bj₋1j + Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj) end end else # A.uplo == 'L' && B.uplo == 'L' Al = _diag(A, -1) - Ad = _diag(A, 0) Bl = _diag(B, -1) - Bd = _diag(B, 0) @inbounds begin for j in 3:n-2 - Ajj₋1 = Al[j-1] Ajj = Ad[j] - Bj₋1j₋2 = Bl[j-2] - Bj₋1j₋1 = Bd[j-1] - Bjj₋1 = Bl[j-1] + Aj₊1j₊1 = Ad[j+1] + Aj₊1j = Al[j] + Aj₊2j₊1 = Al[j+1] Bjj = Bd[j] - C[j,j-2] += _add( Ajj₋1*Bj₋1j₋2) - C[j, j-1] += _add(Ajj₋1*Bj₋1j₋1 + Ajj*Bjj₋1) - C[j, j ] += _add(Ajj*Bjj) + Bj₊1j = Bl[j] + + C[j, j] += _add(Ajj*Bjj) + C[j+1, j] += _add(Aj₊1j*Bjj + Aj₊1j₊1*Bj₊1j) + C[j+2, j] += _add(Aj₊2j₊1*Bj₊1j) end end end From b54d60e601697a3d972389e7b690af8678a9629e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 12 Aug 2024 15:00:53 +0530 Subject: [PATCH 09/11] Tests for a matrix-valued eltype --- stdlib/LinearAlgebra/test/bidiag.jl | 20 +++++++++++++++++++- stdlib/LinearAlgebra/test/tridiag.jl | 20 +++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 0c91378b881d0..edad29d4ec180 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -1049,7 +1049,7 @@ end end @testset "mul for small matrices" begin - @testset for n in 0:4 + @testset for n in 0:6 D = Diagonal(rand(n)) v = rand(n) @testset for uplo in (:L, :U) @@ -1095,6 +1095,24 @@ end @test mul!(copy(T), BL, BU, 2, 3) ≈ ML * MU * 2 + T * 3 @test mul!(copy(T), BU, BL, 2, 3) ≈ MU * ML * 2 + T * 3 end + + n = 4 + arr = SizedArrays.SizedArray{(2,2)}(reshape([1:4;],2,2)) + for B in ( + Bidiagonal(fill(arr,n), fill(arr,n-1), :L), + Bidiagonal(fill(arr,n), fill(arr,n-1), :U), + ) + @test B * B ≈ Matrix(B) * Matrix(B) + BL = Bidiagonal(fill(arr,n), fill(arr,n-1), :L) + BU = Bidiagonal(fill(arr,n), fill(arr,n-1), :U) + @test BL * B ≈ Matrix(BL) * Matrix(B) + @test BU * B ≈ Matrix(BU) * Matrix(B) + @test B * BL ≈ Matrix(B) * Matrix(BL) + @test B * BU ≈ Matrix(B) * Matrix(BU) + D = Diagonal(fill(arr,n)) + @test D * B ≈ Matrix(D) * Matrix(B) + @test B * D ≈ Matrix(B) * Matrix(D) + end end end # module TestBidiagonal diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 22547fe72b0dd..15ac7f9f2147f 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -971,7 +971,7 @@ end end @testset "mul for small matrices" begin - @testset for n in 0:4 + @testset for n in 0:6 for T in ( Tridiagonal(rand(max(n-1,0)), rand(n), rand(max(n-1,0))), SymTridiagonal(rand(n), rand(max(n-1,0))), @@ -1021,6 +1021,24 @@ end end end end + + n = 4 + arr = SizedArrays.SizedArray{(2,2)}(reshape([1:4;],2,2)) + for T in ( + SymTridiagonal(fill(arr,n), fill(arr,n-1)), + Tridiagonal(fill(arr,n-1), fill(arr,n), fill(arr,n-1)), + ) + @test T * T ≈ Matrix(T) * Matrix(T) + BL = Bidiagonal(fill(arr,n), fill(arr,n-1), :L) + BU = Bidiagonal(fill(arr,n), fill(arr,n-1), :U) + @test BL * T ≈ Matrix(BL) * Matrix(T) + @test BU * T ≈ Matrix(BU) * Matrix(T) + @test T * BL ≈ Matrix(T) * Matrix(BL) + @test T * BU ≈ Matrix(T) * Matrix(BU) + D = Diagonal(fill(arr,n)) + @test D * T ≈ Matrix(D) * Matrix(T) + @test T * D ≈ Matrix(T) * Matrix(D) + end end end # module TestTridiagonal From 07b5987ea378a2f27f2d8cfa9cdc74cb601bc6c8 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 15 Sep 2024 11:26:59 +0530 Subject: [PATCH 10/11] Fast path for zero alpha --- stdlib/LinearAlgebra/src/bidiag.jl | 26 ++++++++++++++++---------- stdlib/LinearAlgebra/test/addmul.jl | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index b1e9be81df6e7..12d638f52add6 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -577,18 +577,19 @@ function _bibimul!(C, A, B, _add) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) + iszero(n) && return C + # We use `_rmul_or_fill!` instead of `_modify!` here since using + # `_modify!` in the following loop will not update the + # off-diagonal elements for non-zero beta. + _rmul_or_fill!(C, _add.beta) + iszero(_add.alpha) && return C if n <= 3 # naive multiplication for I in CartesianIndices(C) - _modify!(_add, sum(A[I[1], k] * B[k, I[2]] for k in axes(A,2)), C, I) + C[I] += _add(sum(A[I[1], k] * B[k, I[2]] for k in axes(A,2))) end return C end - # We use `_rmul_or_fill!` instead of `_modify!` here since using - # `_modify!` in the following loop will not update the - # off-diagonal elements for non-zero beta. - _rmul_or_fill!(C, _add.beta) - iszero(_add.alpha) && return C @inbounds begin # first column of C C[1,1] += _add(A[1,1]*B[1,1] + A[1, 2]*B[2,1]) @@ -1062,14 +1063,18 @@ function _dibimul!(C, A, B, _add) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) + iszero(n) && return C + # ensure that we fill off-band elements in the destination + _rmul_or_fill!(C, _add.beta) + iszero(_add.alpha) && return C if n <= 3 + # For simplicity, use a naive multiplication for small matrices + # that loops over all elements. for I in CartesianIndices(C) - _modify!(_add, A.diag[I[1]] * B[I[1], I[2]], C, I) + C[I] += _add(A.diag[I[1]] * B[I[1], I[2]]) end return C end - _rmul_or_fill!(C, _add.beta) # see the same use above - iszero(_add.alpha) && return C Ad = A.diag Bl = _diag(B, -1) Bd = _diag(B, 0) @@ -1103,7 +1108,8 @@ function _dibimul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add) check_A_mul_B!_sizes(size(C), size(A), size(B)) n = size(A,1) iszero(n) && return C - _rmul_or_fill!(C, _add.beta) # see the same use above + # ensure that we fill off-band elements in the destination + _rmul_or_fill!(C, _add.beta) iszero(_add.alpha) && return C Ad = A.diag Bdv, Bev = B.dv, B.ev diff --git a/stdlib/LinearAlgebra/test/addmul.jl b/stdlib/LinearAlgebra/test/addmul.jl index 3fff8289242f7..6cd2590855942 100644 --- a/stdlib/LinearAlgebra/test/addmul.jl +++ b/stdlib/LinearAlgebra/test/addmul.jl @@ -220,4 +220,23 @@ end end end +@testset "issue #55727" begin + C = zeros(1,1) + @testset "$(nameof(typeof(A)))" for A in Any[Diagonal([NaN]), + Bidiagonal([NaN], Float64[], :U), + SymTridiagonal([NaN], Float64[]), + Tridiagonal(Float64[], [NaN], Float64[]), + ] + @testset "$(nameof(typeof(B)))" for B in Any[SymTridiagonal([1.0], Float64[]), + Tridiagonal(Float64[], [1.0], Float64[]), + Bidiagonal([1.0], Float64[], :U), + Bidiagonal([1.0], Float64[], :L), + ] + C .= 0 + @test mul!(C, A, B, 0.0, false)[] === 0.0 + @test mul!(C, B, A, 0.0, false)[] === 0.0 + end + end +end + end # module From df9fe68ca2b5b6d0da119b07c10c15d121c48ce7 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 15 Sep 2024 14:08:11 +0530 Subject: [PATCH 11/11] Add more test combinations --- stdlib/LinearAlgebra/test/addmul.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/test/addmul.jl b/stdlib/LinearAlgebra/test/addmul.jl index 6cd2590855942..739b5545fdba7 100644 --- a/stdlib/LinearAlgebra/test/addmul.jl +++ b/stdlib/LinearAlgebra/test/addmul.jl @@ -224,13 +224,16 @@ end C = zeros(1,1) @testset "$(nameof(typeof(A)))" for A in Any[Diagonal([NaN]), Bidiagonal([NaN], Float64[], :U), + Bidiagonal([NaN], Float64[], :L), SymTridiagonal([NaN], Float64[]), Tridiagonal(Float64[], [NaN], Float64[]), ] - @testset "$(nameof(typeof(B)))" for B in Any[SymTridiagonal([1.0], Float64[]), - Tridiagonal(Float64[], [1.0], Float64[]), + @testset "$(nameof(typeof(B)))" for B in Any[ + Diagonal([1.0]), Bidiagonal([1.0], Float64[], :U), Bidiagonal([1.0], Float64[], :L), + SymTridiagonal([1.0], Float64[]), + Tridiagonal(Float64[], [1.0], Float64[]), ] C .= 0 @test mul!(C, A, B, 0.0, false)[] === 0.0