@@ -69,5 +69,89 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
69
69
return nothing
70
70
end
71
71
72
+ struct MockEnzymeMeta end
72
73
74
+ function autodiff end
75
+
76
+ import GPUCompiler: DeferredCallInfo
77
+ struct AutodiffCallInfo <: CallInfo
78
+ info:: DeferredCallInfo
79
+ end
80
+
81
+ function abstract_call_known (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (f),
82
+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
83
+ (; fargs, argtypes) = arginfo
84
+
85
+ if f === autodiff
86
+ if length (argtypes) >= 1
87
+ @static if VERSION < v " 1.11.0-"
88
+ return CallMeta (Union{}, Effects (), NoCallInfo ())
89
+ else
90
+ return CallMeta (Union{}, Union{}, Effects (), NoCallInfo ())
91
+ end
92
+ end
93
+
94
+ other_fargs = fargs === nothing ? nothing : fargs[2 : end ]
95
+ other_arginfo = ArgInfo (other_fargs, argtypes[2 : end ])
96
+ call = Core. Compiler. abstract_call (interp, other_arginfo, si, sv, max_methods)
97
+ callinfo = DeferredCallInfo (MockEnzymeMeta (), call. rt, call. info)
98
+
99
+ # Real Enzyme must compute `rt` and `exct` according to enzyme semantics
100
+ # and likely perform a unwrapping of fargs...
101
+ rt = Nothing
102
+
103
+ # TODO : Edges? Effects?
104
+ @static if VERSION < v " 1.11.0-"
105
+ return CallMeta (rt, call. effects, AutodiffCallInfo (callinfo))
106
+ else
107
+ return CallMeta (rt, call. exct, call. effects, AutodiffCallInfo (callinfo))
108
+ end
109
+ end
110
+
111
+ return nothing
112
+ end
113
+
114
+ const FlagType = VERSION >= v " 1.11.0-" ? UInt32 : UInt8
115
+ function CC. handle_call! (todo:: Vector{Pair{Int,Any}} , ir:: CC.IRCode , idx:: CC.Int ,
116
+ stmt:: Expr , info:: AutodiffCallInfo , flag:: FlagType ,
117
+ sig:: CC.Signature , state:: CC.InliningState )
118
+ # Goal:
119
+ # The IR we want to inline here is:
120
+ # unpack the args ...
121
+ # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
122
+ # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
123
+
124
+ push! (todo, idx=> (info:: AutoDiffTodo ))
125
+
126
+ # # 1. Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
127
+ # deferred_info = info.info
128
+ # # TODO : This is code duplication is unfortunate...
129
+ # minfo = deferred_info.info
130
+ # results = minfo.results
131
+ # if length(results.matches) != 1
132
+ # return nothing
133
+ # end
134
+ # match = only(results.matches)
135
+
136
+ # # lookup the target mi with correct edge tracking
137
+ # case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
138
+ # info)
139
+ # @assert case isa CC.InvokeCase
140
+ # @assert stmt.head === :call
141
+
142
+ # stmt = Expr(:foreigncall,
143
+ # "extern gpuc.lookup",
144
+ # Ptr{Cvoid},
145
+ # Core.svec(Any, Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
146
+ # 0,
147
+ # QuoteNode(:llvmcall),
148
+ # info.meta,
149
+ # case.invoke,
150
+ # stmt.args[3:end]...
151
+ # )
152
+
153
+ # # 2. Form call to `__autodiff`
154
+ # # TODO !
155
+
156
+ return nothing
73
157
end
0 commit comments