Skip to content

Commit 7c4dc5c

Browse files
committed
Mock Enzyme plugin
1 parent 052a118 commit 7c4dc5c

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

test/plugin_testsetup.jl

+84
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,89 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
6969
return nothing
7070
end
7171

72+
struct MockEnzymeMeta end
7273

74+
function autodiff end
75+
76+
import GPUCompiler: DeferredCallInfo
77+
struct AutodiffCallInfo <: CallInfo
78+
info::DeferredCallInfo
79+
end
80+
81+
function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
82+
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
83+
(; fargs, argtypes) = arginfo
84+
85+
if f === autodiff
86+
if length(argtypes) >= 1
87+
@static if VERSION < v"1.11.0-"
88+
return CallMeta(Union{}, Effects(), NoCallInfo())
89+
else
90+
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
91+
end
92+
end
93+
94+
other_fargs = fargs === nothing ? nothing : fargs[2:end]
95+
other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
96+
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
97+
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)
98+
99+
# Real Enzyme must compute `rt` and `exct` according to enzyme semantics
100+
# and likely perform a unwrapping of fargs...
101+
rt = Nothing
102+
103+
# TODO: Edges? Effects?
104+
@static if VERSION < v"1.11.0-"
105+
return CallMeta(rt, call.effects, AutodiffCallInfo(callinfo))
106+
else
107+
return CallMeta(rt, call.exct, call.effects, AutodiffCallInfo(callinfo))
108+
end
109+
end
110+
111+
return nothing
112+
end
113+
114+
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
115+
function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int,
116+
stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
117+
sig::CC.Signature, state::CC.InliningState)
118+
# Goal:
119+
# The IR we want to inline here is:
120+
# unpack the args ...
121+
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
122+
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
123+
124+
push!(todo, idx=>(info::AutoDiffTodo))
125+
126+
# # 1. Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
127+
# deferred_info = info.info
128+
# # TODO: This is code duplication is unfortunate...
129+
# minfo = deferred_info.info
130+
# results = minfo.results
131+
# if length(results.matches) != 1
132+
# return nothing
133+
# end
134+
# match = only(results.matches)
135+
136+
# # lookup the target mi with correct edge tracking
137+
# case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
138+
# info)
139+
# @assert case isa CC.InvokeCase
140+
# @assert stmt.head === :call
141+
142+
# stmt = Expr(:foreigncall,
143+
# "extern gpuc.lookup",
144+
# Ptr{Cvoid},
145+
# Core.svec(Any, Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
146+
# 0,
147+
# QuoteNode(:llvmcall),
148+
# info.meta,
149+
# case.invoke,
150+
# stmt.args[3:end]...
151+
# )
152+
153+
# # 2. Form call to `__autodiff`
154+
# # TODO!
155+
156+
return nothing
73157
end

0 commit comments

Comments
 (0)