[Very WIP] Compile-time rule generation take-two #900
base: main

Conversation
Okay, some progress:

julia> f(x) = sin(x) + 10x;

julia> g(x) = f(x);

julia> derivative(g, AutoMooncakeForward(), 1.0)
10.54030230586814

julia> f(x) = sin(x) + 1;

julia> derivative(g, AutoMooncakeForward(), 1.0)
0.5403023058681398

However, I still need to do some more work to get the method to be invalidated whenever the user defines a […]
One core benefit of a compiled rule is reduced time-to-first-gradient through support for precompilation, which appears achievable using the approach proposed in this PR. Other benefits, such as higher-order nesting, are appealing, although forward-over-reverse is, in principle, the preferred approach for Hessians, and support beyond second order is outside the scope of Mooncake. One important caution is that generated functions are more vulnerable to specialization explosion than opaque closures, which can significantly increase compilation time; this was a problem encountered in Zygote and a key differentiator of Diffractor. The table below clarifies the primary trade-off: the choice of mechanism directly governs compilation cost, precompilation behavior, and compatibility with compiler-based transformations. The table was created with the assistance of ChatGPT and may not be entirely accurate.
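For reference, forward-over-reverse Hessians are already expressible today through DifferentiationInterface's SecondOrder wrapper. A minimal sketch, assuming the released DifferentiationInterface, ForwardDiff, and Mooncake APIs and that these backends compose (nothing here is specific to this PR):

using DifferentiationInterface
import ForwardDiff, Mooncake

f(x) = sum(abs2, x)

# Forward-over-reverse: ForwardDiff pushes tangents through Mooncake's reverse pass.
backend = SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing))
hessian(f, backend, [1.0, 2.0])  # ≈ [2 0; 0 2]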
Signed-off-by: Mason Protter <[email protected]>
This won't matter for forward mode but would have serious performance implications for reverse mode. In a nutshell, Mooncake fixes Zygote's main performance issue by differentiating after type inference rather than before.

Core difference in the compiler pipeline
Why this matters for performance
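To make the pipeline difference concrete, here is an illustrative way to inspect the two representations in a REPL (standard Julia reflection, not code from this PR):

f(x) = sin(x) + 10x

# Zygote's transform runs on untyped lowered code, before inference:
@code_lowered f(1.0)

# Mooncake's transform runs on inferred, optimized IRCode, after inference:
Base.code_ircode(f, (Float64,))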
The type information is all still present during the IR construction and the final performance is the same. We use the type information to derive the IR, optimize it, and then at the very end erase it, but the compiler just re-infers it.
Thanks for the clarification, very helpful.
This is the tricky step. Are you sure the compiler can re-infer types for the adjoint? It would be very helpful if you could elaborate on that.
Yes, it's very easy for the compiler to do because we already hand it the post-optimized form. That means that when the compiler tries to type-infer and optimize the IR, it's already in exactly the form that it should ideally end up in at the end. The type-inference step just ends up filling in a bunch of empty slots that trivially follow from one another, and the control flow is already in an optimal state for the compiler, so most of the transformations just pass straight through.
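A minimal sketch of the round trip being described, assuming Julia's internal Base.code_ircode reflection and the IRCode-accepting Core.OpaqueClosure constructor (both internal and version-dependent; this is much simplified relative to what Mooncake actually does):

f(x) = sin(x) + 10x

# Take fully inferred + optimized IR from the end of the compiler pipeline...
(ir, ret_type) = only(Base.code_ircode(f, (Float64,)))

# ...and hand it straight back to the compiler. Re-inference is trivial here,
# because each statement's type follows directly from its inputs.
oc = Core.OpaqueClosure(ir)
oc(1.0)  # ≈ 10.5403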
I don't understand this, but it may be hard to explain more. It's worth noting that since generated functions need to be pure, they may have different behaviour compared to opaque closures with respect to world age and global variables. Hopefully we will understand these differences better when there is a prototype.
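For reference, the world-age restriction being alluded to is the documented rule that a generated function's generator can only call functions defined before the generated function itself. A generic illustration (not from this PR):

# The generator body runs in the world age where the generated function was
# defined, so it cannot call functions defined later:
@generated gen(x) = helper()  # the *generator* calls helper()

helper() = :(2x)              # defined after gen, so not in gen's world

gen(1.0)  # MethodError: helper() is invisible at the generator's world age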
Yeah, I can show some examples once the reverse-mode stuff is working.
Anything that was supported before should also be supported in this version, including closures and references to global variables. The trick is that the user's primal code already creates the closure objects, and Mooncake can then just create codual and dual instances of those pre-existing types, so there aren't really any generated-function restrictions. I believe the user can also advance the world age in the function being differentiated without issue.
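A sketch of the kind of program this describes, using the AutoMooncakeForward backend from this PR (hypothetical and untested; offset and make_g are made up for illustration):

const offset = 10.0                       # a global the closure refers to

make_g(c) = x -> c * sin(x) + offset * x  # the user's primal code builds the closure

g = make_g(2.0)
# Only dual versions of the already-constructed closure type are needed:
derivative(g, AutoMooncakeForward(), 1.0)  # expect 2cos(1) + 10 ≈ 11.0806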
Force-pushed from dfec3c4 to 7b2f7ba.
This is unfinished and very rough around the edges, but I thought I should open it up so people can take a look at it, give feedback, and follow along. The idea here is to lift the creation of AD rules into generated function bodies so that it doesn't need a separate preparation step. I'm working on both the ForwardMode implementation and the ReverseMode implementation, but at the time of writing, this PR only has the forward-mode commits because I'm still struggling with some things in reverse mode.

Prior approaches I've tried
In Allocations when invoking a CodeInstance from a generated function JuliaLang/julia#60348, what I attempted was a much simpler approach where I just made a generated function that called build_rrule / build_frule in the generator to create the opaque closures, and then interpolated that opaque closure into the generated code. This seemed sensible enough to me, but unfortunately the compiler does not like it when opaque closures are interpolated into function bodies, and so calling them incurs some minor allocations. Jameson and Keno both said this likely just won't work.
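Roughly, that first attempt had this shape (a hypothetical simplification: build_frule_for is a stand-in for Mooncake's rule builder, here just wrapping plain evaluation in an OpaqueClosure so the sketch is self-contained):

# Stand-in for the rule builder; the real one builds a dual/codual rule.
build_frule_for(::Type{F}, ::Type{X}) where {F,X} =
    Base.Experimental.@opaque (f, x) -> f(x)

@generated function dual_call(f, x)
    # Inside a generator, f and x are the argument *types*.
    oc = build_frule_for(f, x)  # an OpaqueClosure built in the generator
    return :($oc(f, x))        # interpolating the OpaqueClosure into the
end                            # body is the pattern that incurs allocations

dual_call(sin, 1.0)  # ≈ 0.8415, but with the spurious allocations described above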
Another thing I tried (but never submitted a PR for) was to use the Extend invoke to accept CodeInstance JuliaLang/julia#56660 mechanism to generate a (co-)dual-IR CodeInstance in the generated function, and then interpolate and invoke that CodeInstance. Medium-to-long term, I still think this is likely the right way to do things, but currently this mechanism also incurs allocations when invoking the CodeInstance: Allocations when invoke-ing a constant CodeInstance JuliaLang/julia#60441. We'll see, but I suspect this indicates the invoke mechanism needs support at the LLVM level. (Update: I've found a potential fix for this: Add optimizer support for invoke(f, ::CodeInstance, args...) JuliaLang/julia#60442.)

A third possibility is to do the transformation as a Compiler.optimize step, which seems generally kinda hard and different from what we currently do, which is to generate the primal IRCode, copy it, modify it, and then stick it in an OpaqueClosure. I'm not entirely sure how to approach this yet.

This PR's Approach
This PR takes a slightly more old-fashioned approach. That is, at compile time we look up the dual IR from the end of the compiler pipeline, then we take that dual IR, erase all the type information from it, and pass it back through the compiler.
This is doubtless something that Jameson would say is not allowed, but for what it's worth, Diffractor.jl actually also does essentially the same thing (https://github.com/JuliaDiff/Diffractor.jl/blob/main/src/stage1/recurse.jl#L298-L312), a fact that I discovered after trying out this mechanism.
Some assorted notes:

- LazyFRule and DynamicFRule are gone; they're unnecessary here. This new mechanism handles non-inlined function calls (LazyFRule) and dynamic dispatches (DynamicFRule) just fine.
- […] which means that it's both easy and fast to do higher-order AD (see the first sketch after this list).
- […] is_primitive, but for some reason this map call blows up if the boxed variable is the final value (see the second sketch after this list). This has to do with uninit_duals not being properly initialized, but I'm not totally sure where I went wrong that it's not occurring correctly.
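As a sketch of the higher-order point above (hypothetical usage; the example originally attached to that note was not captured):

h(x) = sin(x) + 10x

# Nest forward mode over forward mode to get a second derivative:
h′(x) = derivative(h, AutoMooncakeForward(), x)
derivative(h′, AutoMooncakeForward(), 1.0)  # expect -sin(1) ≈ -0.8415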
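And a hypothetical reconstruction of the failing map pattern from the last note (the actual failing snippet was not captured; boxed_map is made up to show the shape of the problem):

function boxed_map(xs)
    a = 0.0
    # `a` is reassigned inside the closure, so Julia boxes it; the boxed
    # variable is also the closure's final value, the case that blows up
    # when this function is differentiated under the PR's forward mode.
    f = x -> (a = a + x; a)
    map(f, xs)
end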