Skip to content

Commit

Permalink
fix lu deprecation warnings on Julia nightly
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub committed Jun 2, 2021
1 parent 61b8f12 commit 1500ad4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.8.0"
version = "0.8.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
11 changes: 7 additions & 4 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
# for derivations for wide and tall matrices, see
# https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/

const _RowMaximum = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum : Val{true}
const _NoPivot = VERSION >= v"1.7.0-DEV.1188" ? NoPivot : Val{false}

function frule(
(_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
(_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{_RowMaximum,_NoPivot}; kwargs...
)
F = lu!(A, pivot; kwargs...)
∂factors = pivot === Val(true) ? ΔA[F.p, :] : ΔA
∂factors = pivot === _RowMaximum() ? ΔA[F.p, :] : ΔA
m, n = size(∂factors)
q = min(m, n)
if m == n # square A
Expand Down Expand Up @@ -72,7 +75,7 @@ function frule(
end

function rrule(
::typeof(lu), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
::typeof(lu), A::StridedMatrix, pivot::Union{_RowMaximum,_NoPivot}; kwargs...
)
F = lu(A, pivot; kwargs...)
function lu_pullback(ΔF::Tangent)
Expand Down Expand Up @@ -124,7 +127,7 @@ function rrule(
ldiv!(L1', ∂A1)
rdiv!(∂A, U')
end
if pivot === Val(true)
if pivot === _RowMaximum()
∂A = ∂A[invperm(F.p), :]
end
return NoTangent(), ∂A, NoTangent()
Expand Down
21 changes: 12 additions & 9 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ function FiniteDifferences.to_vec(x::Val)
return Bool[], Val_from_vec
end

const ROW_MAXIMUM = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum() : Val(true)
const NO_PIVOT = VERSION >= v"1.7.0-DEV.1188" ? NoPivot() : Val(false)

@testset "Factorizations" begin
@testset "lu decomposition" begin
n = 10
@testset "lu! frule" begin
@testset "lu!(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
T in (Float64, ComplexF64),
pivot in (Val(true), Val(false)),
pivot in (ROW_MAXIMUM, NO_PIVOT),
m in (7, 10, 13)

test_frule(lu!, randn(T, m, n), pivot NoTangent())
Expand All @@ -35,26 +38,26 @@ end
Asingular = zeros(n, n)
ΔAsingular = rand_tangent(Asingular)
@test_throws SingularException frule(
(ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), Val(true)
(ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), ROW_MAXIMUM
)
frule((ZeroTangent(), ΔAsingular), lu!, Asingular, Val(true); check=false)
frule((ZeroTangent(), ΔAsingular), lu!, Asingular, ROW_MAXIMUM; check=false)
@test true # above line would have errored if this was not working right
end
end
@testset "lu rrule" begin
@testset "lu(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
T in (Float64, ComplexF64),
pivot in (Val(true), Val(false)),
pivot in (ROW_MAXIMUM, NO_PIVOT),
m in (7, 10, 13)

test_rrule(lu, randn(T, m, n), pivot NoTangent())
end
@testset "check=false passed to primal function" begin
Asingular = zeros(n, n)
F = lu(Asingular, Val(true); check=false)
F = lu(Asingular, ROW_MAXIMUM; check=false)
ΔF = Tangent{typeof(F)}(; U=rand_tangent(F.U), L=rand_tangent(F.L))
@test_throws SingularException rrule(lu, Asingular, Val(true))
_, back = rrule(lu, Asingular, Val(true); check=false)
@test_throws SingularException rrule(lu, Asingular, ROW_MAXIMUM)
_, back = rrule(lu, Asingular, ROW_MAXIMUM; check=false)
back(ΔF)
@test true # above line would have errored if this was not working right
end
Expand All @@ -72,8 +75,8 @@ end
end
@testset "matrix inverse using LU" begin
@testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64)
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), Val(true)))
test_rrule(inv, lu(randn(T, n, n), Val(true)))
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), ROW_MAXIMUM))
test_rrule(inv, lu(randn(T, n, n), ROW_MAXIMUM))
end
end
end
Expand Down

0 comments on commit 1500ad4

Please sign in to comment.