diff --git a/Project.toml b/Project.toml index c90fc1443..e73dec2f2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.8.1" +version = "0.8.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 30769eca3..3145123a4 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -194,10 +194,19 @@ function frule((_, ΔA), ::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where return Y, ∂Y end +function rrule( + ::typeof(pinv), + x::Union{AbstractVector{T}, LinearAlgebra.AdjOrTransAbsVec{T}}, +) where {T<:Union{Real,Complex}} + y, full_pb = rrule(pinv, x, 0) + pinv_pullback(Δy) = return full_pb(Δy)[1:2] + return y, pinv_pullback +end + function rrule( ::typeof(pinv), x::AbstractVector{T}, - tol::Real = 0, + tol::Real, ) where {T<:Union{Real,Complex}} y = pinv(x, tol) function pinv_pullback(Δy) @@ -210,7 +219,7 @@ end function rrule( ::typeof(pinv), x::LinearAlgebra.AdjOrTransAbsVec{T}, - tol::Real = 0, + tol::Real, ) where {T<:Union{Real,Complex}} y = pinv(x, tol) function pinv_pullback(Δy) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 109b16f80..9bb0f64de 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -428,8 +428,12 @@ end ##### ##### `cholesky` ##### - -function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U) +function rrule(::typeof(cholesky), A::Real) + C, full_pb = rrule(cholesky, A, :U) + cholesky_pullback(ΔC::Tangent) = return full_pb(ΔC)[1:2] + return C, cholesky_pullback +end +function rrule(::typeof(cholesky), A::Real, uplo::Symbol) C = cholesky(A, uplo) function cholesky_pullback(ΔC::Tangent) return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent() diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 2e70338b4..511e01aeb 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -50,7 +50,8 @@ @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) test_frule(pinv, F(randn(T, 3)) ⊢ F(randn(T, 3))) - test_rrule(pinv, F(randn(T, 3))) + check_inferred = VERSION ≥ v"1.5" + test_rrule(pinv, F(randn(T, 3)); check_inferred=check_inferred) # Check types. # TODO: Do we need this still? diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 397f1622f..534c45be4 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -357,7 +357,8 @@ end # also we might be missing some overloads for different tangent-types in the rules @testset "cholesky" begin @testset "Real" begin - test_rrule(cholesky, 0.8) + check_inferred = VERSION ≥ v"1.5" + test_rrule(cholesky, 0.8; check_inferred=check_inferred) end @testset "Diagonal{<:Real}" begin D = Diagonal(rand(5) .+ 0.1)