Mock Enzyme plugin #636
Conversation
Codecov Report: all modified and coverable lines are covered by tests ✅

@@ Coverage Diff @@
##        09-26-make_gpuinterpreter_extensible     #636       +/-   ##
=======================================================================
- Coverage                               72.29%   61.78%   -10.51%
=======================================================================
  Files                                      24       24
  Lines                                    3353     3274       -79
=======================================================================
- Hits                                     2424     2023      -401
- Misses                                    929     1251      +322
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
idx = 0

# 0. Enzyme proper: Desugar args
primal_args = args
primal_argtypes = match.spec_types.parameters[2:end]

adjoint_rt = info.rt
adjoint_args = args # TODO
adjoint_argtypes = primal_argtypes

# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
expr = Expr(:foreigncall,
    "extern gpuc.lookup",
    Ptr{Cvoid},
    Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
    0,
    QuoteNode(:llvmcall),
    deferred_info.meta,
    case.invoke,
    primal_args...
)
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))
@aviatesk does this look correct?
handle_call! isn’t meant to be overloaded, so I think this approach is preferred:
function Core.Compiler.src_inlining_policy(interp::GPUCompiler#=or EnzymeInterpreter?=#,
    @nospecialize(src), info::AutodiffCallInfo, stmt_flag::UInt32)
    # Goal:
    # The IR we want to return here is:
    #   unpack the args ..
    #   ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
    #   ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
    ir = Core.Compiler.IRCode() # contains a placeholder
    ...
    return ir
end
By overloading src_inlining_policy (or inlining_policy in older versions), we can apply this custom inlining to const-propped call sites and semi-concrete interpreted call sites as well.
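For illustration, here is a minimal sketch of the second half of that goal IR, written in the same Expr(:foreigncall, ...) style as the handle_call! snippet above. The "extern __autodiff" symbol is taken from the goal comment; the insertion position and the use of adjoint_rt/adjoint_argtypes are assumptions, not the PR's final code:

# Hypothetical continuation: consume the pointer produced by "extern gpuc.lookup"
# with a call to the assumed "__autodiff" runtime entry point.
adjoint_expr = Expr(:foreigncall,
    "extern __autodiff",
    adjoint_rt,                                  # assumed: return type of the adjoint call
    Core.svec(Ptr{Cvoid}, adjoint_argtypes...),  # pointer to the deferred primal, then the adjoint args
    0,
    QuoteNode(:llvmcall),
    ptr,
    adjoint_args...
)
ret = insert_node!(ir, (idx += 1), NewInstruction(adjoint_expr, adjoint_rt))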
Haha, then I have misunderstood the comment in:
Ah, I just realized that overloading src_inlining_policy wouldn't be enough: we also need to overload retrieve_ir_for_inlining, but it doesn't take info::CallInfo, so maybe we need to tweak the interface.
But I believe this approach (overloading inlining_policy) works at least for pre-1.11.
Haha, then I have misunderstood the comment in:
It seems like I’ve ended up betraying my past self.
Yeah, I think you once told me to extend handle_call.
funcT = LLVM.called_type(call)
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])
direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary
The last op is the called value.
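That matches how LLVM.jl exposes call instructions: the callee shows up as the final operand, so dropping it leaves only the real arguments. A minimal sketch under that assumption (the helper name is illustrative):

using LLVM

# The called value is the last operand of an LLVM call instruction, so the
# actual call arguments are every operand before it.
function call_arguments(call::LLVM.CallInst)
    ops = collect(operands(call))
    return ops[1:end - 1]   # drop the trailing called value
end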
end

function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
I think presently Enzyme does more than that as well. To a rough approximation, it does the following as its entire compilation step:

- "Before anything else happens":
  - Set each llvmf to know about its world age, method instance, and return type: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6703
  - Make things inline-ready [e.g. remove some tbaa which is broken]: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6705
  - Rewrite some nvvm and related intrinsics: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6709
  - Mark various type-unstable calls as inactive and change inttoptr'd ccalls into calls by name [storing the actual int value to later restore]: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6714
  - Replace unhandled BLAS calls with a fallback: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6755
  - Annotate types and activities: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6894
  - Mark custom rules and related as noinline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7035
  - Lower the calling convention of the function being differentiated: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7465
- Optimization pipeline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7514
  Currently we use a modified optimization pipeline that also adds new passes which we found to be critical for performance (namely the new jl_inst_simplify pass, among others for interprocedural dead-argument elimination).
- AD:
  - First we run a Julia analysis pass if the function being differentiated was a closure and the caller requested that we error if it is written to: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7659
  - Upgrade some memcpys to load/store: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7799
  - Actually generate the derivatives: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7801
  - Inverse of the preserve-nvvm pass above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7897
  - Restore the actual inttoptr => function name from above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8074
  - Other immediate post-Enzyme passes: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8085, https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8087, https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8105 and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8129
- Post-Enzyme optimization (https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8131): this includes running new passes to fix garbage collection/etc., but presumably these can just be scheduled.
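To make the shape of the "right after IR is generated" hook discussed below concrete, here is a minimal hedged sketch of a module-level callback in the spirit of the mock_enzyme! signature above; the function name, the "__autodiff" marker, and the overall structure are illustrative assumptions, not Enzyme's actual implementation:

using LLVM

# Hypothetical module-level callback: find every call to the marker intrinsic
# and hand the call site to the AD transformation.
function run_mock_enzyme!(mod::LLVM.Module, marker::String = "__autodiff")
    haskey(functions(mod), marker) || return false
    marker_fn = functions(mod)[marker]
    for use in uses(marker_fn)
        call = user(use)
        call isa LLVM.CallInst || continue
        # ... generate the derivative for the called primal and replace `call`
        # with a direct call to the generated code ...
    end
    return true
end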
It would be more useful to think not about "what does Enzyme do", but "what information does Enzyme need and when".
The actual plugin PR is #633; we can add more callbacks, as long as the callbacks only trigger when we detect that a marker function is present in the module.
The issue is that Enzyme requires a composability axis orthogonal to the job level.
E.g. what are the interfaces you need, what is historic, and what are the design constraints?
As an example, I would like the LLVM IR to be serializable without references to runtime pointers.
Otherwise caching will be impossible, and CUDA.jl is already caching its LLVM IR (post-optimization currently, so we could just throw away the Enzyme metadata).
I think all of the before-Enzyme stuff is necessary atm. The caveat being that if there were a better way to make sure functions had better names than us doing the conversion from an inttoptr to the name and restoring it, that would be great. So we definitely need a "right after IR is generated pass plugin".
Ideally we also need other hooks into the optimization pipeline to add the other passes we run. For example https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler/optimize.jl#L2522
re 3) We need ways to know the full context of the Julia AD request. I think this is basically just the GPUCompiler config object [and the restore-inttoptr name map].
So a caveat of the implementation here is that we no longer create an EnzymeConfig object when inside a CUDA compilation. We still do for a host compilation, but wouldn't anymore for nested compilation.
In the lowering from autodiff to the llvmcall we would need to capture the necessary information and encode it on the calling side (e.g. similar to the Clang plugin for Enzyme).
"right after IR is generated pass plugin"
Right, that's what I added in #633 (comment)
So keep in mind that we are talking about two compilation modes: the first inside CUDA.jl and the second for nested compilation.
For nested CPU compilation you would still have a lot more control over code flow and could do all the BLAS things and the inttoptr handling, but those are errors already on the GPU.
(; fargs, argtypes) = arginfo

@assert f === autodiff
end

function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
abstract_call_known with this signature would never be called from Core.Compiler, so this overload would do nothing?
I have a double extension problem. GPUCompiler provides a GPUInterpreter, and both GPUCompiler/CUDA and Enzyme want to modify the rules being applied.
When we are applying Enzyme to CUDA code we must "inherit" the rules from CUDA; up to now Enzyme had an EnzymeInterpreter, but I would like to get rid of that.
At the same time, Enzyme rules shouldn't apply to CUDA code by default. So I also need to teach GPUCompiler about autodiff in an extensible manner, such that:
function kernel(args...)
    autodiff(f, ...)
end

@cuda kernel(args...)
works.
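To illustrate the kind of layering being described, here is a small self-contained sketch of meta-based dispatch; all names are hypothetical and this is not GPUCompiler's actual interface. The idea is that the interpreter threads an optional plugin meta through inference, the default (nothing) leaves autodiff unhandled, and a plugin type opts in by adding a method:

# Hypothetical sketch: an interpreter-level `meta` decides who handles `autodiff`.
struct MockEnzymeMeta end

# Default: no AD plugin active, so an `autodiff` call inside a kernel is an error.
abstract_autodiff_call(::Nothing, @nospecialize(f), args...) =
    error("autodiff requires an AD plugin to be active")

# A plugin opts in by dispatching on its own meta type and claiming the call.
abstract_autodiff_call(::MockEnzymeMeta, @nospecialize(f), args...) =
    (:autodiff_claimed, f, args)

# The GPU interpreter would thread its `meta` through, e.g.:
# abstract_autodiff_call(interp_meta, f, x, y)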
Make sure that the new infrastructure can handle the complicated song and dance Enzyme needs to do.