
Mock Enzyme plugin #636

Draft: vchuravy wants to merge 1 commit into base 09-26-make_gpuinterpreter_extensible from 10-02-mock_enzyme_plugin
Conversation

vchuravy (Member) commented Oct 2, 2024

Make sure that the new infrastructure can handle the complicated song and dance Enzyme needs to do.

vchuravy (Member, Author) commented Oct 2, 2024

Warning

This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.

This stack of pull requests is managed by Graphite.

@vchuravy vchuravy marked this pull request as ready for review October 2, 2024 09:29
@vchuravy vchuravy marked this pull request as draft October 2, 2024 09:30
codecov bot commented Oct 2, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 61.78%. Comparing base (1638cc2) to head (d906034).

Additional details and impacted files
@@                            Coverage Diff                            @@
##           09-26-make_gpuinterpreter_extensible     #636       +/-   ##
=========================================================================
- Coverage                                 72.29%   61.78%   -10.51%     
=========================================================================
  Files                                        24       24               
  Lines                                      3353     3274       -79     
=========================================================================
- Hits                                       2424     2023      -401     
- Misses                                      929     1251      +322     


Comment on lines +157 to +184
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
idx = 0

# 0. Enzyme proper: desugar args
primal_args = args
primal_argtypes = match.spec_types.parameters[2:end]

adjoint_rt = info.rt
adjoint_args = args # TODO
adjoint_argtypes = primal_argtypes

# 1. Since Julia's inliner goes bottom-up, we need to pretend that we inlined the deferred call
expr = Expr(:foreigncall,
    "extern gpuc.lookup",
    Ptr{Cvoid},
    Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # must use Any for MethodInstance or ftype
    0,
    QuoteNode(:llvmcall),
    deferred_info.meta,
    case.invoke,
    primal_args...
)
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))
vchuravy (Member, Author):

@aviatesk does this look correct?

Contributor:

handle_call! isn’t meant to be overloaded, so I think this approach is preferred:

function Core.Compiler.src_inlining_policy(interp::GPUInterpreter, #= or EnzymeInterpreter? =#
        @nospecialize(src), info::AutodiffCallInfo, stmt_flag::UInt32)
    # Goal: the IR we want to return here is:
    #   unpack the args ...
    #   ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
    #   ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
    ir = Core.Compiler.IRCode() # contains a placeholder
    ...
    return ir
end

By overloading src_inlining_policy (or inlining_policy in older versions), we can apply this custom inlining to const-propped call sites and semi-concrete interpreted call sites as well.

Contributor:

Ah, I just realized that overloading src_inlining_policy wouldn't be enough. We need to overload retrieve_ir_for_inlining too, but it doesn't take info::CallInfo, so maybe we need to tweak the interface.

But I believe this approach (overloading inlining_policy) works, at least for pre-1.11.

Haha, then I have misunderstood the comment in:

https://github.com/JuliaLang/julia/blob/b9d9b69165493f6fc03870d975be05c67f14a30b/base/compiler/ssair/inlining.jl#L1668-L1669

It seems like I’ve ended up betraying my past self.

vchuravy (Member, Author):

Yeah, I think you once told me to extend handle_call!

test/ptx_tests.jl (outdated review thread, resolved)
@vchuravy vchuravy force-pushed the 09-26-make_gpuinterpreter_extensible branch from 052a118 to 4ef6019 on October 3, 2024 10:27
@vchuravy vchuravy force-pushed the 10-02-mock_enzyme_plugin branch 3 times, most recently from 3cfc5e1 to 1b8d203 on October 3, 2024 13:36
funcT = LLVM.called_type(call)
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])
direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary
Contributor:

The last op is the called value.
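That is, for an LLVM call instruction the operand list is the call arguments followed by the callee itself, which is why the slice above stops one short of the end. A minimal sketch of that invariant in LLVM.jl (hedged: `call` stands in for a `LLVM.CallInst` obtained elsewhere, and `called_operand` is named `called_value` on older LLVM.jl versions):

```julia
using LLVM

# For a call instruction, the operands are (args..., callee),
# so dropping the last operand leaves exactly the call arguments.
ops = collect(operands(call))          # `call` is a placeholder CallInst
call_args = ops[1:end-1]               # everything but the callee
@assert last(ops) == called_operand(call)
```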

end

function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
vchuravy (Member, Author):

Feedback from @wsmoses:

This pass would need access to the compiled dictionary so that Enzyme can do a lookup from "emitted function" to Julia function. This would speak for a different phase order in #633.

Or we add two callbacks: one "early" and one "as part of the optimization pipeline".

Contributor:

I think presently Enzyme does more than that as well. To a rough approximation, it does the following as its entire compilation step:

1. "Before anything else happens":
   - Set each llvmf to know about its world age, method instance, and return type: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6703
   - Make things inline-ready (e.g. remove some TBAA which is broken): https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6705
   - Rewrite some nvvm and related intrinsics: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6709
   - Mark various type-unstable calls as inactive and change inttoptr'd ccalls into calls by name (storing the actual int value to later restore): https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6714
   - Replace unhandled BLAS calls with fallbacks: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6755
   - Annotate types and activities: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L6894
   - Mark custom rules and related as noinline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7035
   - Lower the calling convention of the function being differentiated: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7465

2. Optimization pipeline: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7514
   - Currently we use a modified optimization pipeline that also adds new passes which we found to be critical for performance (namely the new jl_inst_simplify pass, among others for interprocedural dead-argument elimination).

3. AD:
   - First we run a Julia analysis pass if the function differentiated was a closure and the user requested that we error if it is written to: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7659
   - Upgrade some memcpys to load/store: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7799
   - Actually generate the derivatives: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7801
   - Inverse of the preserve-nvvm pass above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L7897
   - Restore the actual inttoptr => function-name mapping from above: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8074
   - Other immediate post-Enzyme passes: https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8085, https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8087, https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8105, and https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8129

4. Post-Enzyme optimization (https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler.jl#L8131):
   - This includes running new passes to fix garbage collection etc., but presumably can just be scheduled.

vchuravy (Member, Author):

It would be more useful to think not about "what does Enzyme do", but "what information does Enzyme need, and when".

The actual plugin PR is #633; we can add more callbacks, as long as the callbacks only trigger if we detect that a marker function is present in the module.

The issue is that Enzyme requires an orthogonal composability axis from the job level.
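The marker-gated callback could look roughly like the following sketch. It is an assumption, not this PR's actual API: the marker name "__autodiff" and the wrapper `maybe_mock_enzyme!` are hypothetical, with `mock_enzyme!(job, intrinsic, mod)` as defined elsewhere in this PR.

```julia
using LLVM

# Hypothetical gate: only fire the Enzyme callback when the marker
# intrinsic is present in the module; other modules pass through untouched.
function maybe_mock_enzyme!(@nospecialize(job), mod::LLVM.Module)
    intrinsic = "__autodiff"              # assumed marker name
    if haskey(functions(mod), intrinsic)  # cheap presence check
        mock_enzyme!(job, intrinsic, mod)
    end
    return mod
end
```

This keeps the plugin cost near zero for compilations that never call autodiff, which matters if the callback runs on every GPUCompiler job.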

vchuravy (Member, Author):

E.g.: what are the interfaces you need, what is historic, and what are the design constraints?

As an example, I would like the LLVM IR to be serializable without references to runtime pointers.
Otherwise caching will be impossible, and CUDA.jl is already caching its LLVM IR (currently post-optimization, so we could just throw away the Enzyme metadata).

Contributor:

I think all of the before-Enzyme stuff is necessary at the moment. The caveat being that if there were a better way to make sure functions had good names than us doing the conversion from an inttoptr to the name and then restoring it, that would be great. So we definitely need a "right after IR is generated" pass plugin.

We also ideally need other hooks into the optimization pipeline to add the other passes we run. For example https://github.com/EnzymeAD/Enzyme.jl/blob/3c0871d54b85e72a8b0dc2e2e5d48d2eb6e0a95e/src/compiler/optimize.jl#L2522

Re 3): we need ways to know the full context of the Julia AD request. I think this is basically just the GPUCompiler config object (and the restore-inttoptr name map).

vchuravy (Member, Author):

So a caveat of the implementation here is that we no longer create an EnzymeConfig object when inside a CUDA compilation. We still do for a host compilation, but wouldn't anymore for nested compilation.

We would need, in the lowering from autodiff to the llvmcall of autodiff, to capture the necessary information and encode it on the calling side (e.g. similar to the Clang plugin for Enzyme).

"right after IR is generated pass plugin"

Right, that's what I added in #633 (comment)

So keep in mind that we are talking about two compilation modes: the first inside CUDA.jl, and the second for nested compilation.

For CPU nested compilation you would still have a lot more control over code flow and could do all the BLAS things and inttoptr rewrites, but those are errors already on the GPU.

(; fargs, argtypes) = arginfo

@assert f === autodiff
Contributor:

Suggested change (delete this line):

@assert f === autodiff

end

function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
Contributor:

abstract_call_known with this signature would never be called from Core.Compiler, so this overload would do nothing?

vchuravy (Member, Author):

I have a double extension problem. GPUCompiler provides a GPUInterpreter.
Both GPUCompiler/CUDA and Enzyme want to modify the rules being applied.

But when we are applying Enzyme to CUDA code we must "inherit" the rules from CUDA. Up to now Enzyme had an EnzymeInterpreter, but I would like to get rid of that.

But Enzyme rules shouldn't apply to CUDA code by default. However, I also need to teach GPUCompiler about autodiff in an extensible manner, such that:

function kernel(args...)
    autodiff(f, ....)
end
@cuda kernel(args...)

works.
