Skip to content

Commit 0296364

Browse files
committed
Mock Enzyme plugin
1 parent 052a118 commit 0296364

File tree

2 files changed

+161
-2
lines changed

2 files changed

+161
-2
lines changed

test/plugin_testsetup.jl

+140-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
3636
import GPUCompiler: abstract_call_known, GPUInterpreter
3737
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
3838
StmtInfo, AbsIntState, EFFECTS_TOTAL,
39-
MethodResultPure
39+
MethodResultPure, CallInfo, IRCode
4040

4141
function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f),
4242
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
@@ -69,5 +69,143 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
6969
return nothing
7070
end
7171

72+
# Field-less `meta` tag for the mock-Enzyme plugin. The `inlining_handler` and
# `abstract_call_known` methods in this module dispatch on it.
struct MockEnzymeMeta
end
7273

73-
end
74+
# Having to spell out this method for every meta type is annoying.
# TODO: introduce an `abstract type InferenceMeta` so that a shared fallback can exist.
#
# The mock plugin applies no inlining policy through this hook: it returns `nothing`
# unconditionally, whatever `atype` and `callinfo` are.
function inlining_handler(
    meta::MockEnzymeMeta,
    interp::GPUInterpreter,
    @nospecialize(atype),
    callinfo,
)
    return nothing
end
79+
80+
# Generic function declared without any methods. It serves as a marker: the
# `abstract_call_known(::MockEnzymeMeta, ...)` method below recognises calls to it by
# identity (`f === autodiff`), and `handle_call!` later rewrites those calls.
function autodiff end
81+
82+
import GPUCompiler: DeferredCallInfo
83+
# Call info attached to `autodiff(f, args...)` calls during abstract interpretation.
# Its type is what the `Core.Compiler.handle_call!` method in this module dispatches on.
struct AutodiffCallInfo <: CallInfo
    # Return type recorded for the call (currently the primal call's inferred `rt`).
    rt::Any
    # Deferred-call info wrapping the primal call `f(args...)`.
    info::DeferredCallInfo
end
87+
88+
# Abstract-interpretation hook for the mock-Enzyme plugin.
#
# Only calls whose callee is the `autodiff` marker are intercepted; for any other `f`
# the method returns `nothing`. For `autodiff(f, args...)` the primal call `f(args...)`
# is inferred and its result is wrapped in an `AutodiffCallInfo`, which the inliner
# picks up later.
function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
                             arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
    f === autodiff || return nothing

    (; fargs, argtypes) = arginfo

    # `autodiff` called without a function to differentiate: infer the call as `Union{}`.
    if length(argtypes) <= 1
        @static if VERSION < v"1.11.0-"
            return CallMeta(Union{}, Effects(), NoCallInfo())
        else
            return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
        end
    end

    # Drop the leading `autodiff` entry so that what remains describes `f(args...)`.
    primal_fargs = fargs === nothing ? nothing : fargs[2:end]
    primal_call = Core.Compiler.abstract_call(
        interp, ArgInfo(primal_fargs, argtypes[2:end]), si, sv, max_methods)

    # Real Enzyme must compute `rt` and `exct` according to Enzyme semantics
    # and likely perform an unwrapping of fargs...
    rt = primal_call.rt
    deferred = DeferredCallInfo(MockEnzymeMeta(), rt, primal_call.info)
    autodiff_info = AutodiffCallInfo(rt, deferred)

    # TODO: Edges? Effects?
    @static if VERSION < v"1.11.0-"
        # Can't use the primal call's effects, since otherwise this call might just be
        # replaced with rt
        return CallMeta(rt, Effects(), autodiff_info)
    else
        return CallMeta(rt, primal_call.exct, Effects(), autodiff_info)
    end
end
121+
122+
import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
123+
124+
# We really need a Compiler stdlib.
# Until then, forward `Base` indexing to the `Core.Compiler` implementations so that
# the `ir[ssa][:field] = value` syntax used further down works on compiler IR objects.
function Base.getindex(ir::IRCode, i)
    return Core.Compiler.getindex(ir, i)
end

function Base.setindex!(inst::Instruction, val, i)
    return Core.Compiler.setindex!(inst, val, i)
end
127+
128+
# Type of the `flag` argument in the `Core.Compiler.handle_call!` signature below:
# `UInt32` on Julia 1.11 and later, `UInt8` on older versions.
@static if VERSION >= v"1.11.0-"
    const FlagType = UInt32
else
    const FlagType = UInt8
end
129+
# Inlining hook for calls that abstract interpretation tagged with `AutodiffCallInfo`.
#
# Goal: the IR we want to inline in place of `autodiff(f, args...)` is
#
#     unpack the args ..
#     ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
#     ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
#
# Arguments used here:
# - `todo`:     list of pending inlining work; a `stmt_idx => InliningTodo` pair is pushed.
# - `stmt_idx`: index of the `autodiff` call statement in the caller.
# - `stmt`:     the call expression itself (asserted to have head `:call`).
# - `info`:     carries the recorded return type and the primal call's `DeferredCallInfo`.
# - `state`:    inlining state, used for edge tracking.
# `ir`, `flag` and `sig` are accepted to match the hook's signature but are not read.
#
# Always returns `nothing`. Nothing is pushed when the primal call does not have exactly
# one method match.
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
                                    stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
                                    sig::Signature, state::InliningState)
    # 0. Obtain primal mi from DeferredCallInfo
    # TODO: remove this code duplication
    deferred_info = info.info
    matches = deferred_info.info.results.matches
    length(matches) == 1 || return nothing
    primal_match = only(matches)

    # lookup the target mi with correct edge tracking
    # TODO: Effects?
    case = Core.Compiler.compileable_specialization(
        primal_match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
    @assert case isa Core.Compiler.InvokeCase
    @assert stmt.head === :call

    # Now create the IR we want to inline. It gets its own name so that the caller's
    # `ir` argument is not shadowed.
    new_ir = Core.Compiler.IRCode() # contains a placeholder
    args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
    insert_at = 1 # both new statements are inserted at the placeholder's position

    # 0. Enzyme proper: Desugar args
    primal_args = args
    primal_argtypes = primal_match.spec_types.parameters[2:end]

    adjoint_rt = info.rt
    adjoint_args = args # TODO
    adjoint_argtypes = primal_argtypes

    # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the
    #    deferred call
    lookup_expr = Expr(:foreigncall,
        "extern gpuc.lookup",
        Ptr{Cvoid},
        # Must use Any for MethodInstance or ftype
        Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...),
        0,
        QuoteNode(:llvmcall),
        deferred_info.meta,
        case.invoke,
        primal_args...
    )
    ptr = insert_node!(new_ir, insert_at, NewInstruction(lookup_expr, Ptr{Cvoid}))

    # 2. Call to magic `__autodiff`
    # The argument types must line up with the arguments `(ptr, f, args...)`: the
    # looked-up pointer comes first, then `f` passed as `Any`, then the remaining
    # argument types.
    autodiff_expr = Expr(:foreigncall,
        "extern __autodiff",
        adjoint_rt,
        Core.svec(#=ptr=# Ptr{Cvoid}, #=f=# Any, adjoint_argtypes...),
        0,
        QuoteNode(:llvmcall),
        ptr,
        adjoint_args...
    )
    ret = insert_node!(new_ir, insert_at, NewInstruction(autodiff_expr, adjoint_rt))

    # Finally replace placeholder return
    new_ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
    # NOTE(review): the return statement is typed `Ptr{Cvoid}` although the value it
    # returns has type `adjoint_rt` -- confirm this is intended.
    new_ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}

    new_ir = Core.Compiler.compact!(new_ir)

    # which mi to use here?
    # push inlining todos
    # TODO: Effects
    # aviatesk mentioned using inlining_policy instead...
    itodo = Core.Compiler.InliningTodo(case.invoke, new_ir, Core.Compiler.Effects())
    @assert itodo.linear_inline_eligible
    push!(todo, (stmt_idx => itodo))

    return nothing
end
210+
211+
end #module

test/ptx_tests.jl

+21
Original file line numberDiff line numberDiff line change
@@ -504,4 +504,25 @@ end
504504
ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta()))
505505
@test occursin("call fastcc i64 @julia_inline", ir)
506506
end
507+
508+
@testset "Mock Enzyme" begin
    # Primal function handed to the mock `autodiff`.
    f(x) = x^2

    # Kernel that stores the result of `autodiff(f, x)` through the pointer `a`.
    function kernel(a, x)
        y = Plugin.autodiff(f, x)
        unsafe_store!(a, y)
        return nothing
    end

    # Print the typed IR for inspection.
    @show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64}, meta=Plugin.MockEnzymeMeta())

    # FIXME: the fact that `meta` is necessary here almost invalidates that extension
    #        mechanism; we somehow need to be able to add this kind of "autodiff"
    #        abstract-interpretation handling.
    ir = sprint() do io
        PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, meta=Plugin.MockEnzymeMeta())
    end
    # The rewritten call must show up as a direct call to the `__autodiff` symbol.
    @test occursin("call double @__autodiff", ir)
end
527+
507528
end #testitem

0 commit comments

Comments
 (0)