Enable second-order differentiation via forward-over-reverse #878
Merged
Changes from all commits (43):
fd8f884 (sunxd3): Fix forward-over-reverse, add tests, DI doesn't work yet
1bb341d (sunxd3): fix version specific issue
4869aa9 (sunxd3): Use NativeInterpreter in frule_type to avoid maybe_primitive recursion
d6480be (sunxd3): add frule for _build_rule!
9159ef9 (sunxd3): revert the nativeinterpreter change
a426235 (sunxd3): add has_equal_data_internal for MistyClosureTangent
231a696 (sunxd3): add rule for jl_genericmemory_owner to fix 1.11
1d28b33 (sunxd3): add frule for jl_alloc_array_1d to fix 1.10 error
4bb3c61 (sunxd3): avoid hardcoded UInt length
64410c6 (sunxd3): formatting
3f8d7eb (sunxd3): more x86 compat
029a777 (sunxd3): deal with Union{} bottom type, make rule type Any for LazyFRule
03eb3c0 (sunxd3): revert the Any type change for LazyFRule
14b6fcf (sunxd3): test forward over forward to see
4339183 (sunxd3): Merge branch 'main' into sunxd/f-o-r
5db1ce2 (sunxd3): remove added DI tests
6578aa6 (yebai): Update forward_over_reverse.jl
5db16f0 (sunxd3): Add skip_world_age_check kwarg to build_frule for MistyClosure support
09c496a (sunxd3): Add rrule!! for _build_rule! that throws on reverse-over-reverse
b7413d5 (sunxd3): Remove unnecessary rules: literal_pow, push!, jl_genericmemory_owner
c2f14b0 (sunxd3): Merge branch 'main' into sunxd/f-o-r
2a9a680 (sunxd3): Preserve primitive inlining policy in optimise_ir! for forward-over-r…
e18edfd (sunxd3): Restrict getfield frule!! to AbstractArray to avoid ambiguity with St…
e324759 (sunxd3): Merge branch 'main' into sunxd/f-o-r
3168a13 (sunxd3): Fix Julia 1.12 Future assertion by using sv.interp for wrapper interp…
ed10d38 (sunxd3): Revert interpreter forwarding changes (too hairy for now)
d79f375 (sunxd3): Add back jl_genericmemory_owner frule/rrule for Julia 1.11+
c9109fe (sunxd3): Move `jl_genericmemory_owner` frule to test file.
7d258bf (sunxd3): Merge branch 'main' into sunxd/f-o-r
f6ba807 (sunxd3): version bump
d90977a (yebai): Merge branch 'main' into sunxd/f-o-r
3334682 (sunxd3): Add frules to make DI interface work.
ae7dea4 (sunxd3): Actually enable f-o-r tests
5769127 (sunxd3): Merge branch 'main' into sunxd/f-o-r
7a78468 (sunxd3): Move DI interface rules to avoiding_non_differentiable_code.jl
403f610 (sunxd3): Move interface.jl include before rules
344cb39 (sunxd3): Move jl_genericmemory_owner frule/rrule to main source for DI tests
f269afe (sunxd3): Move higher-order differentiation rules to high_order_derivative_patc…
edc3aca (sunxd3): Split DI tests into first-order and second-order CI jobs
7ed82fd (sunxd3): Restrict second-order DI tests to hessian only
b10b372 (sunxd3): Move FoR tests to rules/, add world age docs for _dual_mc
c1c2be9 (sunxd3): Merge branch 'main' into sunxd/f-o-r
d2713d4 (Copilot): Bump version to 0.4.193 (#913)
@@ -0,0 +1,205 @@ (new file)

```julia
# Forward-mode primitive for _build_rule! on LazyDerivedRule.
# This avoids differentiating through get_interpreter, which has a ccall to jl_get_world_counter.
# The tangent propagation happens through the fwds_oc MistyClosure call, not the rule building.
# Reverse-over-reverse is not supported; an rrule!! that throws is provided below.
@is_primitive MinimalCtx Tuple{typeof(_build_rule!),LazyDerivedRule,Tuple}

function frule!!(
    ::Dual{typeof(_build_rule!)},
    lazy_rule_dual::Dual{<:LazyDerivedRule{sig}},
    args_dual::Dual{<:Tuple},
) where {sig}
    lazy_rule = primal(lazy_rule_dual)
    lazy_tangent = tangent(lazy_rule_dual)
    primal_args = primal(args_dual)
    tangent_args = tangent(args_dual)

    # Build rrule if not built (primal operation, no differentiation needed).
    if !isdefined(lazy_rule, :rule)
        interp = get_interpreter(ReverseMode)
        lazy_rule.rule = build_rrule(interp, lazy_rule.mi; debug_mode=lazy_rule.debug_mode)
    end
    derived_rule = lazy_rule.rule

    # Initialize the tangent of the derived rule if needed.
    rule_tangent_field = lazy_tangent.fields.rule
    if !isdefined(rule_tangent_field, :tangent)
        # Need to update the MutableTangent's fields with a new PossiblyUninitTangent.
        new_rule_tangent = PossiblyUninitTangent(zero_tangent(derived_rule))
        lazy_tangent.fields = merge(lazy_tangent.fields, (; rule=new_rule_tangent))
        rule_tangent_field = new_rule_tangent
    end
    derived_tangent = rule_tangent_field.tangent

    # Forward-differentiate through the DerivedRule call.
    # DerivedRule(args...) internally calls fwds_oc(args...) and returns (CoDual, Pullback).
    fwds_oc = derived_rule.fwds_oc
    fwds_oc_tangent = derived_tangent.fields.fwds_oc

    # Handle varargs unflattening.
    isva = _isva(derived_rule)
    nargs = derived_rule.nargs
    N = length(primal_args)
    uf_primal_args = __unflatten_codual_varargs(isva, primal_args, nargs)
    uf_tangent_args = __unflatten_tangent_varargs(isva, tangent_args, nargs)

    # Create dual args for the frule!! call.
    dual_args = map(Dual, uf_primal_args, uf_tangent_args)

    # Call frule!! on fwds_oc to get the forward-differentiated result.
    dual_fwds_oc = Dual(fwds_oc, fwds_oc_tangent)
    codual_result_dual = frule!!(dual_fwds_oc, dual_args...)

    # Create the Pullback and its tangent.
    pb_oc_ref = derived_rule.pb_oc_ref
    pb_primal = Pullback(sig, pb_oc_ref, isva, N)
    pb_tangent = Tangent((; pb_oc=derived_tangent.fields.pb_oc_ref))

    # Return a Dual of (CoDual, Pullback).
    primal_result = (primal(codual_result_dual), pb_primal)
    tangent_result = (tangent(codual_result_dual), pb_tangent)
    return Dual(primal_result, tangent_result)
end

# Helper to unflatten tangent args, mirroring __unflatten_codual_varargs.
function __unflatten_tangent_varargs(isva::Bool, tangent_args, ::Val{nargs}) where {nargs}
    isva || return tangent_args
    group_tangent = tangent_args[nargs:end]
    return (tangent_args[1:(nargs - 1)]..., group_tangent)
end

# Reverse-over-reverse is not supported. Throw an informative error.
function rrule!!(
    ::CoDual{typeof(_build_rule!)}, ::CoDual{<:LazyDerivedRule}, ::CoDual{<:Tuple}
)
    throw(
        ArgumentError(
            "Reverse-over-reverse differentiation is not supported. " *
            "Encountered attempt to differentiate _build_rule! in reverse mode.",
        ),
    )
end
```
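The regrouping performed by the helper is easiest to see on a concrete call. A minimal illustration with made-up tangent values (not from the PR): for a vararg rule expecting `nargs = 2`, the trailing flat tangents are collected into one tuple, mirroring what `__unflatten_codual_varargs` does for coduals.

```julia
# Hypothetical illustration of __unflatten_tangent_varargs (values made up).
ta = (1.0, 2.0, 3.0)                           # flat tangents for (x, ys...)
__unflatten_tangent_varargs(true, ta, Val(2))  # (1.0, (2.0, 3.0)): vararg tangents grouped
__unflatten_tangent_varargs(false, ta, Val(2)) # (1.0, 2.0, 3.0): non-vararg, unchanged
```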
```julia
# TODO: This is a workaround for forward-over-reverse. Primitives in reverse mode can get
# inlined when building the forward rule, exposing internal ccalls that lack an frule!!.
# For example, `dataids` is a reverse-mode primitive, but inlining it exposes
# `jl_genericmemory_owner`. The proper fix is to prevent primitive inlining during
# forward-over-reverse by forwarding `inlining_policy` through `BugPatchInterpreter` to
# `MooncakeInterpreter` during `optimise_ir!`, but this causes allocation regressions.
# See https://github.com/chalk-lab/Mooncake.jl/pull/878 for details.
@static if VERSION >= v"1.11-"
    function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_genericmemory_owner}},
        ::Dual{Val{Any}},
        ::Dual{Tuple{Val{Any}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        a::Dual{<:Memory},
    )
        return zero_dual(ccall(:jl_genericmemory_owner, Any, (Any,), primal(a)))
    end
    function rrule!!(
        ::CoDual{typeof(_foreigncall_)},
        ::CoDual{Val{:jl_genericmemory_owner}},
        ::CoDual{Val{Any}},
        ::CoDual{Tuple{Val{Any}}},
        ::CoDual{Val{0}},
        ::CoDual{Val{:ccall}},
        a::CoDual{<:Memory},
    )
        y = zero_fcodual(ccall(:jl_genericmemory_owner, Any, (Any,), primal(a)))
        return y, NoPullback(ntuple(_ -> NoRData(), 7))
    end
end
```
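For context on why this foreigncall surfaces at all: on Julia 1.11+, `Array` is backed by `Memory`, and inlining a reverse-mode primitive such as `dataids` (per the comment above) can expose the raw `jl_genericmemory_owner` call. A sketch of the primal operation the two rules wrap; the ownership detail is my understanding rather than something stated in the PR:

```julia
# Sketch of the primal operation only (Julia 1.11+).
m = Memory{Float64}(undef, 4)

# Ask the runtime which object owns the underlying buffer. For a freshly
# allocated Memory this is, to my understanding, the Memory itself; either
# way the answer is aliasing metadata rather than differentiable data, so
# the rules above attach a zero tangent / NoPullback.
owner = ccall(:jl_genericmemory_owner, Any, (Any,), m)
```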
```julia
# Avoid differentiating through AD infrastructure during second-order differentiation.
@zero_derivative MinimalCtx Tuple{
    typeof(Core.kwcall),NamedTuple,typeof(prepare_gradient_cache),Vararg
}
@zero_derivative MinimalCtx Tuple{
    typeof(Core.kwcall),NamedTuple,typeof(prepare_derivative_cache),Vararg
}
@zero_derivative MinimalCtx Tuple{
    typeof(Core.kwcall),NamedTuple,typeof(prepare_pullback_cache),Vararg
}
@zero_derivative MinimalCtx Tuple{typeof(zero_tangent),Any}
```
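The `Core.kwcall` entries cover keyword-argument invocations of Mooncake's cache-preparation API: when the inner reverse pass is set up underneath the outer forward pass, that setup is bookkeeping and should be treated as constant. A hypothetical sketch of the kind of code these rules protect (the call pattern is my assumption, not taken from the PR):

```julia
using Mooncake

f(x) = sum(abs2, x)

# A function that itself computes a Mooncake gradient. Forward-differentiating
# grad_sumsq, as the outer layer of forward-over-reverse, reaches the calls
# below; the @zero_derivative rules mark the cache setup as non-differentiable,
# while the gradient evaluation itself is differentiated.
function grad_sumsq(x)
    cache = Mooncake.prepare_gradient_cache(f, x)            # zero derivative
    _, (_, dx) = Mooncake.value_and_gradient!!(cache, f, x)  # differentiated
    return sum(abs2, dx)
end
```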
```julia
# Julia 1.10 allocates arrays via the jl_alloc_array_{1,2,3}d foreigncalls, so
# provide frules that allocate a matching tangent array alongside the primal.
@static if VERSION < v"1.11-"
    @generated function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_alloc_array_1d}},
        ::Dual{Val{Vector{P}}},
        ::Dual{Tuple{Val{Any},Val{Int}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        ::Dual{Type{Vector{P}}},
        n::Dual{Int},
        args::Vararg{Dual},
    ) where {P}
        T = tangent_type(P)
        return quote
            _n = primal(n)
            y = ccall(:jl_alloc_array_1d, Vector{$P}, (Any, Int), Vector{$P}, _n)
            dy = ccall(:jl_alloc_array_1d, Vector{$T}, (Any, Int), Vector{$T}, _n)
            return Dual(y, dy)
        end
    end
    @generated function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_alloc_array_2d}},
        ::Dual{Val{Matrix{P}}},
        ::Dual{Tuple{Val{Any},Val{Int},Val{Int}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        ::Dual{Type{Matrix{P}}},
        m::Dual{Int},
        n::Dual{Int},
        args::Vararg{Dual},
    ) where {P}
        T = tangent_type(P)
        return quote
            _m, _n = primal(m), primal(n)
            y = ccall(:jl_alloc_array_2d, Matrix{$P}, (Any, Int, Int), Matrix{$P}, _m, _n)
            dy = ccall(:jl_alloc_array_2d, Matrix{$T}, (Any, Int, Int), Matrix{$T}, _m, _n)
            return Dual(y, dy)
        end
    end
    @generated function frule!!(
        ::Dual{typeof(_foreigncall_)},
        ::Dual{Val{:jl_alloc_array_3d}},
        ::Dual{Val{Array{P,3}}},
        ::Dual{Tuple{Val{Any},Val{Int},Val{Int},Val{Int}}},
        ::Dual{Val{0}},
        ::Dual{Val{:ccall}},
        ::Dual{Type{Array{P,3}}},
        l::Dual{Int},
        m::Dual{Int},
        n::Dual{Int},
        args::Vararg{Dual},
    ) where {P}
        T = tangent_type(P)
        return quote
            _l, _m, _n = primal(l), primal(m), primal(n)
            y = ccall(
                :jl_alloc_array_3d, Array{$P,3}, (Any, Int, Int, Int), Array{$P,3}, _l, _m, _n
            )
            dy = ccall(
                :jl_alloc_array_3d, Array{$T,3}, (Any, Int, Int, Int), Array{$T,3}, _l, _m, _n
            )
            return Dual(y, dy)
        end
    end
end
```
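Taken together, these patches enable second-order differentiation end to end; per the commit history, CI exercises this through DifferentiationInterface's `hessian`. A rough usage sketch, where the exact backend spelling is an assumption on my part:

```julia
using DifferentiationInterface  # assumed to re-export ADTypes backends like AutoMooncake
import Mooncake

f(x) = sum(abs2, x) / 2  # Hessian of this f is the identity matrix

# Second-order differentiation: the Hessian is obtained by forward-differentiating
# Mooncake's reverse-mode gradient (forward-over-reverse).
backend = AutoMooncake(; config=nothing)
H = hessian(f, backend, [1.0, 2.0, 3.0])
```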