Skip to content

Commit d906034

Browse files
committed
Mock Enzyme plugin
1 parent 1638cc2 commit d906034

File tree

3 files changed

+199
-4
lines changed

3 files changed

+199
-4
lines changed

test/native_tests.jl

-2
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ end
171171
# smoke test
172172
job, _ = Native.create_job(eval(kernel), (Int64,))
173173

174-
# TODO: Add a `kernel=true` test
175-
176174
ci, rt = only(GPUCompiler.code_typed(job))
177175
@test rt === Ptr{Cvoid}
178176

test/plugin_testsetup.jl

+175-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct NeverInlineMeta <: InlineStateMeta end
3737
import GPUCompiler: abstract_call_known, GPUInterpreter
3838
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
3939
StmtInfo, AbsIntState, EFFECTS_TOTAL,
40-
MethodResultPure
40+
MethodResultPure, CallInfo, IRCode
4141

4242
function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f),
4343
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
@@ -70,5 +70,178 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
7070
return nothing
7171
end
7272

73+
struct MockEnzymeMeta end
7374

74-
end
75+
# Having to define this function is annoying
76+
# introduce `abstract type InferenceMeta`
77+
function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo)
78+
return nothing
79+
end
80+
81+
function autodiff end
82+
83+
import GPUCompiler: DeferredCallInfo
84+
struct AutodiffCallInfo <: CallInfo
85+
rt
86+
info::DeferredCallInfo
87+
end
88+
89+
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
90+
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
91+
(; fargs, argtypes) = arginfo
92+
93+
@assert f === autodiff
94+
if length(argtypes) <= 1
95+
@static if VERSION < v"1.11.0-"
96+
return CallMeta(Union{}, Effects(), NoCallInfo())
97+
else
98+
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
99+
end
100+
end
101+
102+
other_fargs = fargs === nothing ? nothing : fargs[2:end]
103+
other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
104+
# TODO: Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
105+
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
106+
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)
107+
108+
# Real Enzyme must compute `rt` and `exct` according to enzyme semantics
109+
# and likely perform a unwrapping of fargs...
110+
rt = call.rt
111+
112+
# TODO: Edges? Effects?
113+
@static if VERSION < v"1.11.0-"
114+
# Can't use call.effects since otherwise this call might be just replaced with rt
115+
return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo))
116+
else
117+
return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo))
118+
end
119+
end
120+
121+
function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
122+
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
123+
return nothing
124+
end
125+
126+
import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
127+
128+
# We really need a Compiler stdlib
129+
Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i)
130+
Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i)
131+
132+
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
133+
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
134+
stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
135+
sig::Signature, state::InliningState)
136+
137+
# Goal:
138+
# The IR we want to inline here is:
139+
# unpack the args ..
140+
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
141+
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
142+
143+
# 0. Obtain primal mi from DeferredCallInfo
144+
# TODO: remove this code duplication
145+
deferred_info = info.info
146+
minfo = deferred_info.info
147+
results = minfo.results
148+
if length(results.matches) != 1
149+
return nothing
150+
end
151+
match = only(results.matches)
152+
153+
# lookup the target mi with correct edge tracking
154+
# TODO: Effects?
155+
case = Core.Compiler.compileable_specialization(
156+
match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
157+
@assert case isa Core.Compiler.InvokeCase
158+
@assert stmt.head === :call
159+
160+
# Now create the IR we want to inline
161+
ir = Core.Compiler.IRCode() # contains a placeholder
162+
args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)] # f, args...
163+
idx = 0
164+
165+
# 0. Enzyme proper: Desugar args
166+
primal_args = args
167+
primal_argtypes = match.spec_types.parameters[2:end]
168+
169+
adjoint_rt = info.rt
170+
adjoint_args = args # TODO
171+
adjoint_argtypes = primal_argtypes
172+
173+
# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
174+
expr = Expr(:foreigncall,
175+
"extern gpuc.lookup",
176+
Ptr{Cvoid},
177+
Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
178+
0,
179+
QuoteNode(:llvmcall),
180+
deferred_info.meta,
181+
case.invoke,
182+
primal_args...
183+
)
184+
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))
185+
186+
# 2. Call to magic `__autodiff`
187+
expr = Expr(:foreigncall,
188+
"extern __autodiff",
189+
adjoint_rt,
190+
Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...),
191+
0,
192+
QuoteNode(:llvmcall),
193+
ptr,
194+
adjoint_args...
195+
)
196+
ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt))
197+
198+
# Finally replace placeholder return
199+
ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
200+
ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}
201+
202+
ir = Core.Compiler.compact!(ir)
203+
204+
# which mi to use here?
205+
# push inlining todos
206+
# TODO: Effects
207+
# aviatesk mentioned using inlining_policy instead...
208+
itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects())
209+
@assert itodo.linear_inline_eligible
210+
push!(todo, (stmt_idx=>itodo))
211+
212+
return nothing
213+
end
214+
215+
function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
216+
changed = false
217+
218+
for use in LLVM.uses(intrinsic)
219+
call = LLVM.user(use)
220+
LLVM.@dispose builder=LLVM.IRBuilder() begin
221+
LLVM.position!(builder, call)
222+
ops = LLVM.operands(call)
223+
target = ops[1]
224+
if target isa LLVM.ConstantExpr && (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt ||
225+
LLVM.opcode(target) == LLVM.API.LLVMBitCast)
226+
target = first(LLVM.operands(target))
227+
end
228+
funcT = LLVM.called_type(call)
229+
funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])
230+
direct_call = LLVM.call!(builder, funcT, target, ops[3:end - 1]) # why is the -1 necessary
231+
232+
LLVM.replace_uses!(call, direct_call)
233+
end
234+
if isempty(LLVM.uses(call))
235+
LLVM.erase!(call)
236+
changed = true
237+
else
238+
# the validator will detect this
239+
end
240+
end
241+
242+
return changed
243+
end
244+
245+
GPUCompiler.register_plugin!("__autodiff", mock_enzyme!)
246+
247+
end #module

test/ptx_tests.jl

+24
Original file line numberDiff line numberDiff line change
@@ -504,4 +504,28 @@ end
504504
ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta()))
505505
@test occursin("call fastcc i64 @julia_inline", ir)
506506
end
507+
508+
@testset "Mock Enzyme" begin
509+
function f(x)
510+
x^2
511+
end
512+
513+
function kernel(a, x)
514+
y = Plugin.autodiff(f, x)
515+
unsafe_store!(a, y)
516+
nothing
517+
end
518+
519+
# This tests deferred_codegen with kernel=true
520+
@show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64})
521+
522+
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=false))
523+
@test occursin("call double @__autodiff", ir)
524+
@test !occursin("call fastcc double @julia_f", ir)
525+
526+
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, optimize=true))
527+
@test !occursin("call double @__autodiff", ir)
528+
@test occursin("call fastcc double @julia_f", ir)
529+
end
530+
507531
end #testitem

0 commit comments

Comments
 (0)