Conversation

@MasonProtter (Collaborator) commented Dec 21, 2025

This is unfinished and very rough around the edges, but I thought I should open it up so people can take a look at it, give feedback, and follow along. The idea here is to lift the creation of AD rules into generated function bodies so that it doesn't need a separate preparation step. I'm working on both the ForwardMode implementation and the ReverseMode implementation, but at the time of writing, this PR only has the forward mode commits because I'm still struggling with some things in reverse mode.

Prior approaches I've tried

  • In JuliaLang/julia#60348 ("Allocations when invoking a CodeInstance from a generated function"), what I attempted was a much simpler approach: I just made a generated function that called build_rrule/build_frule in the generator to create the opaque closures, and then interpolated that opaque closure into the generated code. This seemed sensible enough to me, but unfortunately the compiler does not like it when opaque closures are interpolated into function bodies, so calling them incurs some minor allocations. Jameson and Keno both said this likely just won't work.

  • Another thing I tried (but never submitted a PR for) was to use the mechanism from JuliaLang/julia#56660 ("Extend invoke to accept CodeInstance") to generate a (co-)dual-IR CodeInstance in the generated function, and then interpolate and invoke that CodeInstance. Medium-to-long term, I still think this is likely the right way to do things, but currently this mechanism also incurs allocations when invoking the CodeInstance (JuliaLang/julia#60441, "Allocations when invoke-ing a constant CodeInstance"). We'll see, but I suspect this indicates the invoke mechanism needs support at the LLVM level. (Update: I've found a potential fix for this in JuliaLang/julia#60442, "Add optimizer support for invoke(f, ::CodeInstance, args...)".)

    • A remaining problem is that I think this requires a bit of a redesign of the codegen pipeline, since we'd be more closely coupled to Julia's own compiler pipeline. We'd basically need to do everything in the Compiler.optimize step, which seems generally kinda hard and different from what we currently do (generate the primal IRCode, copy it, modify it, and then stick it in an OpaqueClosure). I'm not entirely sure how to approach this yet.

This PR's Approach

This PR takes a slightly more old-fashioned approach, which is that I essentially do a more convoluted version of

# This is pseudo-code; make_dual_ir and irc_to_codeinfo stand in for the
# actual rule-construction machinery.
@generated function generated_frule!!(args...)
    # Run the usual pipeline to get fully optimized dual IR for this signature.
    dual_ir::IRCode = make_dual_ir(args...; world)
    # Convert the IRCode back into a CodeInfo the generator can return.
    ci::CodeInfo = irc_to_codeinfo(dual_ir)
    # Erase the type information so the compiler will re-infer it.
    strip_type_information!(ci)
    ci
end

That is, at compile time we look up the dual IR produced at the end of the compiler pipeline, erase all of its type information, and then pass it back through the compiler.

This is doubtless something that Jameson would say is not allowed, but for what it's worth, Diffractor.jl actually also does essentially the same thing (https://github.com/JuliaDiff/Diffractor.jl/blob/main/src/stage1/recurse.jl#L298-L312), a fact that I discovered after trying out this mechanism.

Some assorted notes:

  1. I was able to remove LazyFRule and DynamicFRule; they're unnecessary here. This new mechanism handles non-inlined function calls (previously LazyFRule) and dynamic dispatches (previously DynamicFRule) just fine.
  2. I kept the generated function object capable of storing captured variables like before, but I'm not actually sure it's needed for forward mode without the Lazy/Dynamic FRule types (it is needed for reverse mode, though).
  3. It's fast without preparation!
julia> f(x) = 2*x^2 - 3x;

# No preparation step required!

julia> @btime derivative(f, AutoMooncakeForward(), x)  setup=x=rand()
  2.023 ns (0 allocations: 0 bytes)
0.8623908025118019
  4. There's no rule-generation in the generated code:
julia> @code_typed Mooncake.generated_frule!!(Mooncake.zero_dual(f), Dual(2.0, 1.0))
CodeInfo(
1 ─ %1  =   builtin Core.getfield(args, 2)::Dual{Float64, Float64}
│   %2  =   builtin Base.getfield(%1, :primal)::Float64
│   %3  =   builtin Base.getfield(%1, :primal)::Float64
│   %4  = intrinsic (Core.Intrinsics.mul_float)(%2, %3)::Float64
│   %5  =   builtin Base.getfield(%1, :primal)::Float64
│   %6  =   builtin Base.getfield(%1, :tangent)::Float64
│   %7  = intrinsic (Core.Intrinsics.mul_float)(%5, %6)::Float64
│   %8  =   builtin Base.getfield(%1, :primal)::Float64
│   %9  =   builtin Base.getfield(%1, :tangent)::Float64
│   %10 = intrinsic (Core.Intrinsics.mul_float)(%8, %9)::Float64
│   %11 = intrinsic (Core.Intrinsics.add_float)(%7, %10)::Float64
│   %12 = intrinsic (Core.Intrinsics.mul_float)(2.0, %4)::Float64
│   %13 = intrinsic (Core.Intrinsics.mul_float)(2.0, %11)::Float64
│   %14 = intrinsic (Core.Intrinsics.mul_float)(%4, 0.0)::Float64
│   %15 = intrinsic (Core.Intrinsics.add_float)(%13, %14)::Float64
│   %16 =   builtin Base.getfield(%1, :primal)::Float64
│   %17 = intrinsic (Core.Intrinsics.mul_float)(3.0, %16)::Float64
│   %18 =   builtin Base.getfield(%1, :tangent)::Float64
│   %19 = intrinsic (Core.Intrinsics.mul_float)(3.0, %18)::Float64
│   %20 =   builtin Base.getfield(%1, :primal)::Float64
│   %21 = intrinsic (Core.Intrinsics.mul_float)(%20, 0.0)::Float64
│   %22 = intrinsic (Core.Intrinsics.add_float)(%19, %21)::Float64
│   %23 = intrinsic (Core.Intrinsics.sub_float)(%12, %17)::Float64
│   %24 = intrinsic (Core.Intrinsics.sub_float)(%15, %22)::Float64
│   %25 = %new(Dual{Float64, Float64}, %23, %24)::Dual{Float64, Float64}
└──       return %25
) => Dual{Float64, Float64}

which means that it's both easy and fast to do higher order AD:

julia> derivative(AutoMooncakeForward(), 1.0) do x
           derivative(AutoMooncakeForward(), x) do x
               f(x)
           end
       end
4.0

julia> @btime derivative(AutoMooncakeForward(), x) do x
           derivative(AutoMooncakeForward(), x) do x
               sin(x)^2 + 2x
           end
       end setup=(x=rand())
  10.370 ns (0 allocations: 0 bytes)
1.8145635969158271
  5. Even though higher-order AD works right away out of the box, we'll probably still want to make some adjustments for it, because the inner AD pass will typically end up erasing function barriers that should have been treated as is_primitive.
  6. Even though we don't use tags, this mechanism is immune to "perturbation confusion". For example, here are some old examples of perturbation confusion from other packages:
julia> D(f, x) = derivative(f, AutoMooncakeForward(), x);

julia> D(x -> x * D(y -> x + y, 1.0), 1.0)
1.0 # correct
julia> let x = 2.0
           D(1.0) do y
               x = x*y
           end
           D(1.0) do y
               x = y*x
           end
       end
2.0 # correct
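For reference, the first result above can be cross-checked without any AD machinery at all, using plain central finite differences (fd is a throwaway helper defined here, not part of Mooncake):

```julia
# Cross-check via central finite differences; no AD involved.
fd(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# The inner derivative d/dy (x + y) is exactly 1, so g(x) = x and g'(x) = 1.
g(x) = x * fd(y -> x + y, 1.0)

fd(g, 1.0)  # ≈ 1.0, matching the correct answer above
```

A perturbation-confused implementation returns 2.0 on this example, because the outer perturbation on x leaks into the inner derivative.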
  7. Some stuff is not working and I need to figure out why. For example, closures with boxed variables mostly work:
julia> D(1.0) do x
           foreach(1:3) do i
               x *= i
           end
           x
       end
6.0

but for some reason this map call blows up if the boxed variable is the final value:

julia> D(1.0) do x
           map(1:3) do i
               x *= i
               x
           end
           x
       end
Unreachable reached at 0x7f8cf9960780

[245757] signal 4 (2): Illegal instruction
in expression starting at REPL[4]:1
collect_to! at ./array.jl:852 [inlined]
GeneratedFRule at /home/masonp/Nextcloud/Julia/MooncakeWork/Mooncake/src/interpreter/forward_mode.jl:0
generated_frule!! at ./none (unknown line) [inlined]
collect_to_with_first! at ./array.jl:826 [inlined]
generated_frule!! at ./none (unknown line) [inlined]
generated_frule!! at /home/masonp/Nextcloud/Julia/MooncakeWork/Mooncake/src/interpreter/forward_mode.jl:0
unknown function (ip: 0x7f8cf9960a0b) at (unknown file)
_collect at ./array.jl:820 [inlined]
GeneratedFRule at /home/masonp/Nextcloud/Julia/MooncakeWork/Mooncake/src/interpreter/forward_mode.jl:0
unknown function (ip: 0x7f8cf993b0bf) at (unknown file)
generated_frule!! at ./none (unknown line) [inlined]
generated_frule!! at /home/masonp/Nextcloud/Julia/MooncakeWork/Mooncake/src/interpreter/forward_mode.jl:0
unknown function (ip: 0x7f8cf993a91f) at (unknown file)
collect_similar at ./array.jl:732 [inlined]
map at ./abstractarray.jl:3372 [inlined]
#9 at ./REPL[4]:2 [inlined]
GeneratedFRule at /home/masonp/Nextcloud/Julia/MooncakeWork/Mooncake/src/interpreter/forward_mode.jl:0
unknown function (ip: 0x7f8cf993a6f3) at (unknown file)
generated_frule!! at ./none (unknown line) [inlined]
value_and_derivative!! at /home/masonp/Nextcloud/Julia/MooncakeWork/Mooncake/src/interface.jl:610 [inlined]
[...]

This has to do with uninit_duals not being properly initialized, but I'm not totally sure where I went wrong such that the initialization isn't occurring correctly.

@MasonProtter (Collaborator, Author)

Okay, some progress:

  • I figured out what was causing those segfaults at the end of the above post, and that's now fixed. (I needed to remove some PiNodes that had appeared in the IR). This also unblocked some problems I was having in reverse mode, but it's still not quite working.
  • Backedge tracking is now pretty good, it'll now recompile the rule if any method in the primal body gets invalidated, even if it's inlined:
julia> f(x) = sin(x) + 10x;

julia> g(x) = f(x);

julia> derivative(g, AutoMooncakeForward(), 1.0)
10.54030230586814

julia> f(x) = sin(x) + 1;

julia> derivative(g, AutoMooncakeForward(), 1.0)
0.5403023058681398

however, I still need to do some more work to get the method invalidated whenever the user defines an frule!! or _is_primitive method that would affect the rule body.

@yebai (Member) commented Jan 9, 2026

One core benefit of a compiled rule is reduced time-to-first-gradient through support for precompilation. This appears achievable using the approach proposed in this PR. Other benefits, such as higher-order nesting, are appealing; however, forward-over-reverse is, in principle, the preferred approach for Hessians. Support beyond second order is outside the scope of Mooncake.

One important caution is that generated functions are more vulnerable to specialization explosion than opaque closures, which can significantly increase compilation time; this was a problem encountered in Zygote and a key differentiator of Diffractor. The table below clarifies the primary trade-off. This choice directly governs compilation cost, precompilation behavior, and compatibility with compiler-based transformations.

The table was created with the assistance of ChatGPT and may not be entirely accurate.

Feature                            Generated function   Regular closure   Opaque closure (misty closure)
Specialize on arguments            Yes                  Yes               Limited
Specialize on captured state       NIL                  Yes               No
Inlining allowed                   Yes                  Yes               No
Number of codegen instances        Large                Large             Constant
Precompilation friendliness        ⚠️
AD / compiler-transform friendly   ⚠️
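To make the "Number of codegen instances" row concrete, here's a minimal standalone illustration (nothing Mooncake-specific; gdouble and oc are made-up names): each fresh concrete argument type gives a generated function a new specialization, whereas an opaque closure built once remains a single compiled object.

```julia
# A trivial generated function: each concrete signature triggers the
# generator and a fresh codegen instance.
@generated function gdouble(x)
    return :(x + x)
end

gdouble(1); gdouble(1.0); gdouble(1f0)  # three concrete signatures

# Count the accumulated specializations (Base.specializations, Julia 1.10+).
nspec = length(collect(Base.specializations(only(methods(gdouble)))))
@assert nspec >= 3

# By contrast, an opaque closure has one fixed signature and one compiled
# body, no matter how often it is called.
oc = Base.Experimental.@opaque (x::Float64) -> x + x
oc(1.0); oc(2.0)
```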

@yebai (Member) commented Jan 15, 2026

That is, at compile time we look up the dual IR which is from the end of the compiler pipeline, then we take that dual IR, erase all the type information from it, and then pass it back through the compiler.

This won't matter for forward-mode but would have serious performance implications for reverse-mode. In a nutshell, Mooncake fixes Zygote’s main performance issue by differentiating after type inference rather than before.

Core difference in the compiler pipeline

  • Zygote

    • Rewrites lowered, untyped SSA IR
    • Generates untyped adjoint (dual) IR
    • Relies on Julia’s type inference to recover types afterward
  • Mooncake

    • Runs after Julia’s type inference
    • Rewrites typed SSA IR
    • Emits typed adjoint IR with resolved dispatch

Why this matters for performance

  • Inference

    • Zygote: must infer a much larger, higher-order reverse-mode program → often fails
    • Mooncake: inference is already done → little recovery needed
  • Pullbacks

    • Zygote: higher-order, closure-heavy, frequently abstract
    • Mooncake: concrete, often stack-allocated, optimisable
  • Optimization

    • Zygote: Any blocks LLVM optimizations
    • Mooncake: concrete types enable aggressive optimisation
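The pipeline difference is easy to see on a toy function: code_lowered returns the untyped SSA IR that Zygote-style transforms operate on, while code_typed returns the post-inference IR that Mooncake operates on (h here is just an illustrative function, not from the PR).

```julia
# Compare pre- and post-inference IR for the same function.
h(x) = sin(x) + 10x

lowered = only(code_lowered(h, (Float64,)))           # untyped SSA IR
typed_ci, rettype = only(code_typed(h, (Float64,)))   # typed, optimized IR

# After inference, the return type and call targets are concrete.
@assert rettype === Float64
```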

@MasonProtter (Collaborator, Author)

The type information is all still present during the IR construction, and the final performance is the same. We use the type information to derive the IR, optimize it, and then at the very end erase it; the compiler just re-infers it.

@yebai (Member) commented Jan 15, 2026

Thanks for the clarification, very helpful.

but the compiler just re-infers it.

This is the tricky step. Are you sure the compiler can re-infer types for the adjoint? It would be very helpful if you could elaborate on that.

@MasonProtter (Collaborator, Author)

Yes, it's very easy for the compiler to do because we already hand it the post-optimized form. That means that when the compiler tries to type infer and optimize the IR, it's already in exactly the form that it should ideally end up in at the end.

The type inference step just ends up being it filling in a bunch of empty slots that trivially follow from one another, and the control flow and whatnot are already in an optimal state for the compiler, so most of the transformations just automatically pass through.

@yebai (Member) commented Jan 15, 2026

The type inference step just ends up being it filling in a bunch of empty slots that trivially follow from one another, and the control flow and whatnot are already in an optimal state for the compiler, so most of the transformations just automatically pass through.

I don't understand this, but it may be hard to explain more.

It's worth noting that since generated functions need to be pure, they may behave differently from opaque closures with respect to world age and global variables. Hopefully, we will understand these differences better when there is a prototype.

@MasonProtter (Collaborator, Author) commented Jan 16, 2026

I don't understand this, but it may be hard to explain more.

Yeah, I can show some examples once the reverse mode stuff is working

It's worth noting that since generated functions need to be pure, they may have different behaviour compared to opaque closure wrt world-age and global variables. Hopefully, we will understand these differences better when there is a prototype

Anything that was supported before should also be supported in this version, including closures and references to global variables. The trick is that the user's primal code already creates the closure objects, and Mooncake can then just create codual and dual instances of those pre-existing types, so there aren't really any generated-function restrictions in play.

I believe the user can also advance the world age in the function being differentiated without issue.
