From 7dc649948b26fc7e5c7525560d41064f221244a5 Mon Sep 17 00:00:00 2001 From: Thomas Christensen Date: Fri, 2 Jul 2021 10:15:46 -0400 Subject: [PATCH] Fix method ambiguity for `qr` (#931) and lingering ambiguities for `lu` (#932) * fix `qr` method ambiguities (#931) and lingering `lu` ambiguities (#920) * fix inferrence issues due to using `@invoke` for `lu` keyword arguments * bump version --- Project.toml | 2 +- src/lu.jl | 28 ++++++++++++---------------- src/qr.jl | 33 +++++++++++++++++++-------------- test/lu.jl | 11 +++++++++++ test/qr.jl | 8 ++++++++ 5 files changed, 51 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 341be74e..87015e78 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.2.4" +version = "1.2.5" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/lu.jl b/src/lu.jl index 1544bedb..1e6c0f22 100644 --- a/src/lu.jl +++ b/src/lu.jl @@ -30,25 +30,21 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU) end # LU decomposition -function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true); check = true) - L, U, p = _lu(A, pivot, check) - LU(L, U, p) -end - -# For the square version, return explicit lower and upper triangular matrices. -# We would do this for the rectangular case too, but Base doesn't support that. -function lu(A::StaticMatrix{N,N}, pivot::Union{Val{false},Val{true}}=Val(true); - check = true) where {N} - L, U, p = _lu(A, pivot, check) - LU(LowerTriangular(L), UpperTriangular(U), p) -end +for pv in (:true, :false) + # ... define each `pivot::Val{true/false}` method individually to avoid ambiguties + @eval function lu(A::StaticMatrix, pivot::Val{$pv}; check = true) + L, U, p = _lu(A, pivot, check) + LU(L, U, p) + end -@static if VERSION >= v"1.7-DEV" - # disambiguation - function lu(A::StaticMatrix{N,N}, pivot::Val{true}) where {N} - Base.@invoke lu(A::StaticMatrix{N,N} where N, pivot::Union{Val{false},Val{true}}) + # For the square version, return explicit lower and upper triangular matrices. + # We would do this for the rectangular case too, but Base doesn't support that. + @eval function lu(A::StaticMatrix{N,N}, pivot::Val{$pv}; check = true) where {N} + L, U, p = _lu(A, pivot, check) + LU(LowerTriangular(L), UpperTriangular(U), p) end end +lu(A::StaticMatrix; check = true) = lu(A, Val(true); check=check) # location of the first zero on the diagonal, 0 when not found function _first_zero_on_diagonal(A::StaticMatrix{M,N,T}) where {M,N,T} diff --git a/src/qr.jl b/src/qr.jl index d4aeb078..6b174ed9 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -11,10 +11,26 @@ Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:p)) Base.iterate(S::QR, ::Val{:p}) = (S.p, Val(:done)) Base.iterate(S::QR, ::Val{:done}) = nothing +for pv in (:true, :false) + @eval begin + @inline function qr(A::StaticMatrix, pivot::Val{$pv}) + QRp = _qr(Size(A), A, pivot) + if length(QRp) === 2 + # create an identity permutation since that is cheap, + # and much safer since, in the case of isbits types, we can't + # safely leave the field undefined. + p = identity_perm(QRp[2]) + return QR(QRp[1], QRp[2], p) + else # length(QRp) === 3 + return QR(QRp[1], QRp[2], QRp[3]) + end + end + end +end """ - qr(A::StaticMatrix, pivot=Val(false)) + qr(A::StaticMatrix, pivot::Union{Val{true}, Val{false}} = Val(false)) -Compute the QR factorization of `A`. The factors can be obtain by iteration: +Compute the QR factorization of `A`. The factors can be obtained by iteration: ```julia julia> A = @SMatrix rand(3,4); @@ -34,18 +50,7 @@ julia> F.Q * F.R ≈ A true ``` """ -@inline function qr(A::StaticMatrix, pivot::Union{Val{false}, Val{true}} = Val(false)) - QRp = _qr(Size(A), A, pivot) - if length(QRp) === 2 - # create an identity permutation since that is cheap, - # and much safer since, in the case of isbits types, we can't - # safely leave the field undefined. - p = identity_perm(QRp[2]) - return QR(QRp[1], QRp[2], p) - else # length(QRp) === 3 - return QR(QRp[1], QRp[2], QRp[3]) - end -end +qr(A::StaticMatrix) = qr(A, Val(false)) function identity_perm(R::StaticMatrix{N,M,T}) where {N,M,T} return similar_type(R, Int, Size((M,)))(ntuple(x -> x, Val{M}())) diff --git a/test/lu.jl b/test/lu.jl index 7852a733..51a793d0 100644 --- a/test/lu.jl +++ b/test/lu.jl @@ -65,3 +65,14 @@ end @test_throws SingularException lu(A) @test !issuccess(lu(A; check = false)) end + +@testset "LU method ambiguity" begin + # Issue #920; just test that methods do not throw an ambiguity error when called + for A in ((@SMatrix [1.0 2.0; 3.0 4.0]), (@SMatrix [1.0 2.0 3.0; 4.0 5.0 6.0])) + @test isa(lu(A), StaticArrays.LU) + @test isa(lu(A, Val(true)), StaticArrays.LU) + @test isa(lu(A, Val(false)), StaticArrays.LU) + @test isa(lu(A; check=false), StaticArrays.LU) + @test isa(lu(A; check=true), StaticArrays.LU) + end +end \ No newline at end of file diff --git a/test/qr.jl b/test/qr.jl index a89fba36..f83be439 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -60,3 +60,11 @@ Random.seed!(42) test_qr(arr) end end + +@testset "QR method ambiguity" begin + # Issue #931; just test that methods do not throw an ambiguity error when called + A = @SMatrix [1.0 2.0 3.0; 4.0 5.0 6.0] + @test isa(qr(A), StaticArrays.QR) + @test isa(qr(A, Val(true)), StaticArrays.QR) + @test isa(qr(A, Val(false)), StaticArrays.QR) +end \ No newline at end of file