@@ -37,7 +37,7 @@ struct NeverInlineMeta <: InlineStateMeta end
37
37
import GPUCompiler: abstract_call_known, GPUInterpreter
38
38
import Core. Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
39
39
StmtInfo, AbsIntState, EFFECTS_TOTAL,
40
- MethodResultPure
40
+ MethodResultPure, CallInfo, IRCode
41
41
42
42
function abstract_call_known (meta:: InlineStateMeta , interp:: GPUInterpreter , @nospecialize (f),
43
43
arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
@@ -70,5 +70,178 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
70
70
return nothing
71
71
end
72
72
73
+ struct MockEnzymeMeta end
73
74
74
- end
75
+ # Having to define this function is annoying
76
+ # introduce `abstract type InferenceMeta`
77
+ function inlining_handler (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (atype), callinfo)
78
+ return nothing
79
+ end
80
+
81
+ function autodiff end
82
+
83
+ import GPUCompiler: DeferredCallInfo
84
+ struct AutodiffCallInfo <: CallInfo
85
+ rt
86
+ info:: DeferredCallInfo
87
+ end
88
+
89
+ function abstract_call_known (meta:: Nothing , interp:: GPUInterpreter , f:: typeof (autodiff),
90
+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
91
+ (; fargs, argtypes) = arginfo
92
+
93
+ @assert f === autodiff
94
+ if length (argtypes) <= 1
95
+ @static if VERSION < v " 1.11.0-"
96
+ return CallMeta (Union{}, Effects (), NoCallInfo ())
97
+ else
98
+ return CallMeta (Union{}, Union{}, Effects (), NoCallInfo ())
99
+ end
100
+ end
101
+
102
+ other_fargs = fargs === nothing ? nothing : fargs[2 : end ]
103
+ other_arginfo = ArgInfo (other_fargs, argtypes[2 : end ])
104
+ # TODO : Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
105
+ call = Core. Compiler. abstract_call (interp, other_arginfo, si, sv, max_methods)
106
+ callinfo = DeferredCallInfo (MockEnzymeMeta (), call. rt, call. info)
107
+
108
+ # Real Enzyme must compute `rt` and `exct` according to enzyme semantics
109
+ # and likely perform a unwrapping of fargs...
110
+ rt = call. rt
111
+
112
+ # TODO : Edges? Effects?
113
+ @static if VERSION < v " 1.11.0-"
114
+ # Can't use call.effects since otherwise this call might be just replaced with rt
115
+ return CallMeta (rt, Effects (), AutodiffCallInfo (rt, callinfo))
116
+ else
117
+ return CallMeta (rt, call. exct, Effects (), AutodiffCallInfo (rt, callinfo))
118
+ end
119
+ end
120
+
121
+ function abstract_call_known (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (f),
122
+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
123
+ return nothing
124
+ end
125
+
126
+ import Core. Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
127
+
128
+ # We really need a Compiler stdlib
129
+ Base. getindex (ir:: IRCode , i) = Core. Compiler. getindex (ir, i)
130
+ Base. setindex! (inst:: Instruction , val, i) = Core. Compiler. setindex! (inst, val, i)
131
+
132
+ const FlagType = VERSION >= v " 1.11.0-" ? UInt32 : UInt8
133
+ function Core. Compiler. handle_call! (todo:: Vector{Pair{Int,Any}} , ir:: IRCode , stmt_idx:: Int ,
134
+ stmt:: Expr , info:: AutodiffCallInfo , flag:: FlagType ,
135
+ sig:: Signature , state:: InliningState )
136
+
137
+ # Goal:
138
+ # The IR we want to inline here is:
139
+ # unpack the args ..
140
+ # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
141
+ # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
142
+
143
+ # 0. Obtain primal mi from DeferredCallInfo
144
+ # TODO : remove this code duplication
145
+ deferred_info = info. info
146
+ minfo = deferred_info. info
147
+ results = minfo. results
148
+ if length (results. matches) != 1
149
+ return nothing
150
+ end
151
+ match = only (results. matches)
152
+
153
+ # lookup the target mi with correct edge tracking
154
+ # TODO : Effects?
155
+ case = Core. Compiler. compileable_specialization (
156
+ match, Core. Compiler. Effects (), Core. Compiler. InliningEdgeTracker (state), info)
157
+ @assert case isa Core. Compiler. InvokeCase
158
+ @assert stmt. head === :call
159
+
160
+ # Now create the IR we want to inline
161
+ ir = Core. Compiler. IRCode () # contains a placeholder
162
+ args = [Core. Compiler. Argument (i) for i in 2 : length (stmt. args)] # f, args...
163
+ idx = 0
164
+
165
+ # 0. Enzyme proper: Desugar args
166
+ primal_args = args
167
+ primal_argtypes = match. spec_types. parameters[2 : end ]
168
+
169
+ adjoint_rt = info. rt
170
+ adjoint_args = args # TODO
171
+ adjoint_argtypes = primal_argtypes
172
+
173
+ # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
174
+ expr = Expr (:foreigncall ,
175
+ " extern gpuc.lookup" ,
176
+ Ptr{Cvoid},
177
+ Core. svec (#= meta=# Any, #= mi=# Any, #= f=# Any, primal_argtypes... ), # Must use Any for MethodInstance or ftype
178
+ 0 ,
179
+ QuoteNode (:llvmcall ),
180
+ deferred_info. meta,
181
+ case. invoke,
182
+ primal_args...
183
+ )
184
+ ptr = insert_node! (ir, (idx += 1 ), NewInstruction (expr, Ptr{Cvoid}))
185
+
186
+ # 2. Call to magic `__autodiff`
187
+ expr = Expr (:foreigncall ,
188
+ " extern __autodiff" ,
189
+ adjoint_rt,
190
+ Core. svec (Ptr{Cvoid}, Any, adjoint_argtypes... ),
191
+ 0 ,
192
+ QuoteNode (:llvmcall ),
193
+ ptr,
194
+ adjoint_args...
195
+ )
196
+ ret = insert_node! (ir, idx, NewInstruction (expr, adjoint_rt))
197
+
198
+ # Finally replace placeholder return
199
+ ir[Core. SSAValue (1 )][:inst ] = Core. ReturnNode (ret)
200
+ ir[Core. SSAValue (1 )][:type ] = Ptr{Cvoid}
201
+
202
+ ir = Core. Compiler. compact! (ir)
203
+
204
+ # which mi to use here?
205
+ # push inlining todos
206
+ # TODO : Effects
207
+ # aviatesk mentioned using inlining_policy instead...
208
+ itodo = Core. Compiler. InliningTodo (case. invoke, ir, Core. Compiler. Effects ())
209
+ @assert itodo. linear_inline_eligible
210
+ push! (todo, (stmt_idx=> itodo))
211
+
212
+ return nothing
213
+ end
214
+
215
+ function mock_enzyme! (@nospecialize (job), intrinsic, mod:: LLVM.Module )
216
+ changed = false
217
+
218
+ for use in LLVM. uses (intrinsic)
219
+ call = LLVM. user (use)
220
+ LLVM. @dispose builder= LLVM. IRBuilder () begin
221
+ LLVM. position! (builder, call)
222
+ ops = LLVM. operands (call)
223
+ target = ops[1 ]
224
+ if target isa LLVM. ConstantExpr && (LLVM. opcode (target) == LLVM. API. LLVMPtrToInt ||
225
+ LLVM. opcode (target) == LLVM. API. LLVMBitCast)
226
+ target = first (LLVM. operands (target))
227
+ end
228
+ funcT = LLVM. called_type (call)
229
+ funcT = LLVM. FunctionType (LLVM. return_type (funcT), LLVM. parameters (funcT)[3 : end ])
230
+ direct_call = LLVM. call! (builder, funcT, target, ops[3 : end - 1 ]) # why is the -1 necessary
231
+
232
+ LLVM. replace_uses! (call, direct_call)
233
+ end
234
+ if isempty (LLVM. uses (call))
235
+ LLVM. erase! (call)
236
+ changed = true
237
+ else
238
+ # the validator will detect this
239
+ end
240
+ end
241
+
242
+ return changed
243
+ end
244
+
245
+ GPUCompiler. register_plugin! (" __autodiff" , mock_enzyme!)
246
+
247
+ end # module
0 commit comments