diff --git a/Project.toml b/Project.toml index 93741304e..586e1c376 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.8.0" +version = "0.8.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 109b16f80..aa05f9ccc 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -15,11 +15,14 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! # for derivations for wide and tall matrices, see # https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/ +const _RowMaximum = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum : Val{true} +const _NoPivot = VERSION >= v"1.7.0-DEV.1188" ? NoPivot : Val{false} + function frule( - (_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs... + (_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{_RowMaximum,_NoPivot}; kwargs... ) F = lu!(A, pivot; kwargs...) - ∂factors = pivot === Val(true) ? ΔA[F.p, :] : ΔA + ∂factors = pivot === _RowMaximum() ? ΔA[F.p, :] : ΔA m, n = size(∂factors) q = min(m, n) if m == n # square A @@ -72,7 +75,7 @@ function frule( end function rrule( - ::typeof(lu), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs... + ::typeof(lu), A::StridedMatrix, pivot::Union{_RowMaximum,_NoPivot}; kwargs... ) F = lu(A, pivot; kwargs...) function lu_pullback(ΔF::Tangent) @@ -124,7 +127,7 @@ function rrule( ldiv!(L1', ∂A1) rdiv!(∂A, U') end - if pivot === Val(true) + if pivot === _RowMaximum() ∂A = ∂A[invperm(F.p), :] end return NoTangent(), ∂A, NoTangent() diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 397f1622f..4fa290adb 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -20,13 +20,16 @@ function FiniteDifferences.to_vec(x::Val) return Bool[], Val_from_vec end +const ROW_MAXIMUM = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum() : Val(true) +const NO_PIVOT = VERSION >= v"1.7.0-DEV.1188" ? NoPivot() : Val(false) + @testset "Factorizations" begin @testset "lu decomposition" begin n = 10 @testset "lu! frule" begin @testset "lu!(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for T in (Float64, ComplexF64), - pivot in (Val(true), Val(false)), + pivot in (ROW_MAXIMUM, NO_PIVOT), m in (7, 10, 13) test_frule(lu!, randn(T, m, n), pivot ⊢ NoTangent()) @@ -35,26 +38,26 @@ end Asingular = zeros(n, n) ΔAsingular = rand_tangent(Asingular) @test_throws SingularException frule( - (ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), Val(true) + (ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), ROW_MAXIMUM ) - frule((ZeroTangent(), ΔAsingular), lu!, Asingular, Val(true); check=false) + frule((ZeroTangent(), ΔAsingular), lu!, Asingular, ROW_MAXIMUM; check=false) @test true # above line would have errored if this was not working right end end @testset "lu rrule" begin @testset "lu(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for T in (Float64, ComplexF64), - pivot in (Val(true), Val(false)), + pivot in (ROW_MAXIMUM, NO_PIVOT), m in (7, 10, 13) test_rrule(lu, randn(T, m, n), pivot ⊢ NoTangent()) end @testset "check=false passed to primal function" begin Asingular = zeros(n, n) - F = lu(Asingular, Val(true); check=false) + F = lu(Asingular, ROW_MAXIMUM; check=false) ΔF = Tangent{typeof(F)}(; U=rand_tangent(F.U), L=rand_tangent(F.L)) - @test_throws SingularException rrule(lu, Asingular, Val(true)) - _, back = rrule(lu, Asingular, Val(true); check=false) + @test_throws SingularException rrule(lu, Asingular, ROW_MAXIMUM) + _, back = rrule(lu, Asingular, ROW_MAXIMUM; check=false) back(ΔF) @test true # above line would have errored if this was not working right end @@ -72,8 +75,8 @@ end end @testset "matrix inverse using LU" begin @testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64) - test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), Val(true))) - test_rrule(inv, lu(randn(T, n, n), Val(true))) + test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), ROW_MAXIMUM)) + test_rrule(inv, lu(randn(T, n, n), ROW_MAXIMUM)) end end end