@@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
36
36
import GPUCompiler: abstract_call_known, GPUInterpreter
37
37
import Core. Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
38
38
StmtInfo, AbsIntState, EFFECTS_TOTAL,
39
- MethodResultPure
39
+ MethodResultPure, CallInfo, IRCode
40
40
41
41
function abstract_call_known (meta:: InlineStateMeta , interp:: GPUInterpreter , @nospecialize (f),
42
42
arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
@@ -69,5 +69,143 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
69
69
return nothing
70
70
end
71
71
72
+ struct MockEnzymeMeta end
72
73
73
- end
74
+ # Having to define this function is annoying
75
+ # introduce `abstract type InferenceMeta`
76
+ function inlining_handler (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (atype), callinfo)
77
+ return nothing
78
+ end
79
+
80
+ function autodiff end
81
+
82
+ import GPUCompiler: DeferredCallInfo
83
+ struct AutodiffCallInfo <: CallInfo
84
+ rt
85
+ info:: DeferredCallInfo
86
+ end
87
+
88
+ function abstract_call_known (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (f),
89
+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
90
+ (; fargs, argtypes) = arginfo
91
+
92
+ if f === autodiff
93
+ if length (argtypes) <= 1
94
+ @static if VERSION < v " 1.11.0-"
95
+ return CallMeta (Union{}, Effects (), NoCallInfo ())
96
+ else
97
+ return CallMeta (Union{}, Union{}, Effects (), NoCallInfo ())
98
+ end
99
+ end
100
+
101
+ other_fargs = fargs === nothing ? nothing : fargs[2 : end ]
102
+ other_arginfo = ArgInfo (other_fargs, argtypes[2 : end ])
103
+ call = Core. Compiler. abstract_call (interp, other_arginfo, si, sv, max_methods)
104
+ callinfo = DeferredCallInfo (MockEnzymeMeta (), call. rt, call. info)
105
+
106
+ # Real Enzyme must compute `rt` and `exct` according to enzyme semantics
107
+ # and likely perform a unwrapping of fargs...
108
+ rt = call. rt
109
+
110
+ # TODO : Edges? Effects?
111
+ @static if VERSION < v " 1.11.0-"
112
+ # Can't use call.effects since otherwise this call might be just replaced with rt
113
+ return CallMeta (rt, Effects (), AutodiffCallInfo (rt, callinfo))
114
+ else
115
+ return CallMeta (rt, call. exct, Effects (), AutodiffCallInfo (rt, callinfo))
116
+ end
117
+ end
118
+
119
+ return nothing
120
+ end
121
+
122
+ import Core. Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
123
+
124
+ # We really need a Compiler stdlib
125
+ Base. getindex (ir:: IRCode , i) = Core. Compiler. getindex (ir, i)
126
+ Base. setindex! (inst:: Instruction , val, i) = Core. Compiler. setindex! (inst, val, i)
127
+
128
+ const FlagType = VERSION >= v " 1.11.0-" ? UInt32 : UInt8
129
+ function Core. Compiler. handle_call! (todo:: Vector{Pair{Int,Any}} , ir:: IRCode , stmt_idx:: Int ,
130
+ stmt:: Expr , info:: AutodiffCallInfo , flag:: FlagType ,
131
+ sig:: Signature , state:: InliningState )
132
+
133
+ # Goal:
134
+ # The IR we want to inline here is:
135
+ # unpack the args ..
136
+ # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
137
+ # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
138
+
139
+ # 0. Obtain primal mi from DeferredCallInfo
140
+ # TODO : remove this code duplication
141
+ deferred_info = info. info
142
+ minfo = deferred_info. info
143
+ results = minfo. results
144
+ if length (results. matches) != 1
145
+ return nothing
146
+ end
147
+ match = only (results. matches)
148
+
149
+ # lookup the target mi with correct edge tracking
150
+ # TODO : Effects?
151
+ case = Core. Compiler. compileable_specialization (
152
+ match, Core. Compiler. Effects (), Core. Compiler. InliningEdgeTracker (state), info)
153
+ @assert case isa Core. Compiler. InvokeCase
154
+ @assert stmt. head === :call
155
+
156
+ # Now create the IR we want to inline
157
+ ir = Core. Compiler. IRCode () # contains a placeholder
158
+ args = [Core. Compiler. Argument (i) for i in 3 : length (stmt. args)]
159
+ idx = 0
160
+
161
+ # 0. Enzyme proper: Desugar args
162
+ primal_args = args
163
+ primal_argtypes = match. spec_types. parameters[2 : end ]
164
+
165
+ adjoint_rt = info. rt
166
+ adjoint_args = args # TODO
167
+ adjoint_argtypes = primal_argtypes
168
+
169
+ # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
170
+ expr = Expr (:foreigncall ,
171
+ " extern gpuc.lookup" ,
172
+ Ptr{Cvoid},
173
+ Core. svec (Any, Any, Any, primal_argtypes... ), # Must use Any for MethodInstance or ftype
174
+ 0 ,
175
+ QuoteNode (:llvmcall ),
176
+ deferred_info. meta,
177
+ case. invoke,
178
+ primal_args...
179
+ )
180
+ ptr = insert_node! (ir, (idx += 1 ), NewInstruction (expr, Ptr{Cvoid}))
181
+
182
+ # 2. Call to magic `__autodiff`
183
+ expr = Expr (:foreigncall ,
184
+ " extern __autodiff" ,
185
+ adjoint_rt,
186
+ Core. svec (Any, Ptr{Cvoid}, adjoint_argtypes... ),
187
+ 0 ,
188
+ QuoteNode (:llvmcall ),
189
+ ptr,
190
+ adjoint_args...
191
+ )
192
+ ret = insert_node! (ir, idx, NewInstruction (expr, adjoint_rt))
193
+
194
+ # Finally replace placeholder return
195
+ ir[Core. SSAValue (1 )][:inst ] = Core. ReturnNode (ret)
196
+ ir[Core. SSAValue (1 )][:type ] = Ptr{Cvoid}
197
+
198
+ ir = Core. Compiler. compact! (ir)
199
+
200
+ # which mi to use here?
201
+ # push inlining todos
202
+ # TODO : Effects
203
+ # aviatesk mentioned using inlining_policy instead...
204
+ itodo = Core. Compiler. InliningTodo (case. invoke, ir, Core. Compiler. Effects ())
205
+ @assert itodo. linear_inline_eligible
206
+ push! (todo, (stmt_idx=> itodo))
207
+
208
+ return nothing
209
+ end
210
+
211
+ end # module
0 commit comments