Cannot compute nested gradients with reverse mode AD #28

Closed
bolgarbe opened this issue Jul 30, 2021 · 3 comments

@bolgarbe

The following fails in Diffractor (works in Zygote but only in mixed mode):

julia> using Diffractor: var"'"

julia> A = randn(3,3);

julia> f(x) = sum(A*x);

julia> g(x) = sum(x .* f'(x));

julia> g'(ones(3))
ERROR: MethodError: no method matching _adjoint_vec_pullback(::ChainRulesCore.ZeroTangent)
Closest candidates are:
  _adjoint_vec_pullback(::ChainRulesCore.Tangent) at /home/bj0rn/.julia/packages/ChainRules/xaVoS/src/rulesets/LinearAlgebra/structured.jl:119
  _adjoint_vec_pullback(::AbstractMatrix) at /home/bj0rn/.julia/packages/ChainRules/xaVoS/src/rulesets/LinearAlgebra/structured.jl:120
  _adjoint_vec_pullback(::ChainRulesCore.AbstractThunk) at /home/bj0rn/.julia/packages/ChainRules/xaVoS/src/rulesets/LinearAlgebra/structured.jl:121
Stacktrace:
 [1] ∂⃖¹₁times_pullback
   @ ./none:1
 [2] (::Diffractor.∂⃖rruleB{1, 1})(::ChainRulesCore.ZeroTangent, ::Vararg{Any})
   @ Diffractor ~/.julia/packages/Diffractor/EwwgG/src/stage1/generated.jl:64
 [3] ∂⃖²₂f
   @ ./none:1
 [4] (::Diffractor.∂⃖weaveInnerOdd{1, 1})(Δ::Tuple{ChainRulesCore.ZeroTangent, Vector{Float64}})
   @ Diffractor ~/.julia/packages/Diffractor/EwwgG/src/stage1/generated.jl:64
 [5] ∂⃖¹₁PrimeDerivativeBack
   @ ./none:1
 [6] ∂⃖¹₁g
   @ ./none:1
 [7] (::Diffractor.PrimeDerivativeBack{1, typeof(g)})(x::Vector{Float64})
   @ Diffractor ~/.julia/packages/Diffractor/EwwgG/src/interface.jl:160
 [8] top-level scope
   @ REPL[5]:1
@mzgubic
Member

mzgubic commented Aug 2, 2021

I think this is a special case of #25

@Keno
Collaborator

Keno commented Aug 2, 2021

The error as reported is fixed, but it then gets lost in higher-order AD of broadcast, which doesn't quite work yet. However, the use of broadcast in the sum rrule is somewhat gratuitous anyway, so this should work after JuliaDiff/ChainRules.jl#493.

@Keno
Collaborator

Keno commented Aug 4, 2021

Fixed by JuliaDiff/ChainRules.jl#494

@Keno Keno closed this as completed Aug 4, 2021
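
For reference, the value that `g'(ones(3))` should return can be worked out by hand, which gives a simple check for any fix. This is a sketch that assumes `f'` denotes the gradient of `f`; the all-ones vector $\mathbf{1}$ is notation introduced here, not something used in the thread.

```latex
\begin{aligned}
f(x) &= \operatorname{sum}(Ax) = \mathbf{1}^{\top} A x,
  & \nabla f(x) &= A^{\top}\mathbf{1}, \\
g(x) &= \operatorname{sum}\bigl(x \odot \nabla f(x)\bigr) = x^{\top} A^{\top}\mathbf{1} = \mathbf{1}^{\top} A x,
  & \nabla g(x) &= A^{\top}\mathbf{1}.
\end{aligned}
```

Because $\nabla f$ does not depend on $x$, $g$ coincides with $f$, so the expected result is the vector of column sums of `A`, independent of the input point.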