Skip to content

Commit 4ef6019

Browse files
committed
Make GPUInterpreter extensible
Currently Enzyme uses it's own AbstractInterpreter, in particular to handle inlining blocking of functions with custom rules and to handle nested autodiff operations. - [ ] Create a version of Enzyme with this - [ ] Support a version of `gpuc.deferred(meta)`
1 parent 78d166b commit 4ef6019

14 files changed

+389
-76
lines changed

src/driver.jl

+15-10
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
## deferred compilation
4343

4444
"""
45-
var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
45+
var"gpuc.deferred"(meta, f, args...)::Ptr{Cvoid}
4646
4747
As if we were to call `f(args...)` but instead we are
4848
putting down a marker and return a function pointer to later
@@ -154,10 +154,11 @@ const __llvm_initialized = Ref(false)
154154

155155
@timeit_debug to "IR generation" begin
156156
ir, compiled = irgen(job)
157+
edge = Edge(inference_metadata(job), job.source)
157158
if job.config.entry_abi === :specfunc
158-
entry_fn = compiled[job.source].specfunc
159+
entry_fn = compiled[edge].specfunc
159160
else
160-
entry_fn = compiled[job.source].func
161+
entry_fn = compiled[edge].func
161162
end
162163
entry = functions(ir)[entry_fn]
163164
end
@@ -198,24 +199,28 @@ const __llvm_initialized = Ref(false)
198199
return val
199200
end
200201

201-
worklist = Dict{Any, Vector{LLVM.CallInst}}()
202+
worklist = Dict{Edge, Vector{LLVM.CallInst}}()
202203
for use in uses(dyn_marker)
203204
# decode the call
204205
call = user(use)::LLVM.CallInst
205-
dyn_mi_inst = find_base_object(operands(call)[1])
206+
dyn_meta_inst = find_base_object(operands(call)[1])
207+
@compiler_assert isa(dyn_meta_inst, LLVM.ConstantInt) job
208+
dyn_mi_inst = find_base_object(operands(call)[2])
206209
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
210+
dyn_meta = Base.unsafe_pointer_to_objref(
211+
convert(Ptr{Cvoid}, convert(Int, dyn_meta_inst)))
207212
dyn_mi = Base.unsafe_pointer_to_objref(
208-
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
209-
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
213+
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))::MethodInstance
214+
push!(get!(worklist, Edge(dyn_meta, dyn_mi), LLVM.CallInst[]), call)
210215
end
211216

212-
for dyn_mi in keys(worklist)
213-
dyn_fn_name = compiled[dyn_mi].specfunc
217+
for dyn_edge in keys(worklist)
218+
dyn_fn_name = compiled[dyn_edge].specfunc
214219
dyn_fn = functions(ir)[dyn_fn_name]
215220

216221
# insert a pointer to the function everywhere the entry is used
217222
T_ptr = convert(LLVMType, Ptr{Cvoid})
218-
for call in worklist[dyn_mi]
223+
for call in worklist[dyn_edge]
219224
@dispose builder=IRBuilder() begin
220225
position!(builder, call)
221226
fptr = if LLVM.version() >= v"17"

src/interface.jl

+12-6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ Several keyword arguments can be used to customize the compilation process:
8989
struct CompilerConfig{T,P}
9090
target::T
9191
params::P
92+
meta
9293

9394
kernel::Bool
9495
name::Union{Nothing,String}
@@ -98,6 +99,7 @@ struct CompilerConfig{T,P}
9899

99100
function CompilerConfig(target::AbstractCompilerTarget,
100101
params::AbstractCompilerParams;
102+
meta = nothing,
101103
kernel=true,
102104
name=nothing,
103105
entry_abi=:specfunc,
@@ -106,16 +108,16 @@ struct CompilerConfig{T,P}
106108
if entry_abi (:specfunc, :func)
107109
error("Unknown entry_abi=$entry_abi")
108110
end
109-
new{typeof(target), typeof(params)}(target, params, kernel, name, entry_abi,
111+
new{typeof(target), typeof(params)}(target, params, meta, kernel, name, entry_abi,
110112
always_inline, opt_level)
111113
end
112114
end
113115

114116
# copy constructor
115-
CompilerConfig(cfg::CompilerConfig; target=cfg.target, params=cfg.params,
117+
CompilerConfig(cfg::CompilerConfig; target=cfg.target, params=cfg.params, meta=cfg.meta,
116118
kernel=cfg.kernel, name=cfg.name, entry_abi=cfg.entry_abi,
117119
always_inline=cfg.always_inline, opt_level=cfg.opt_level) =
118-
CompilerConfig(target, params; kernel, entry_abi, name, always_inline, opt_level)
120+
CompilerConfig(target, params; meta, kernel, entry_abi, name, always_inline, opt_level)
119121

120122
function Base.show(io::IO, @nospecialize(cfg::CompilerConfig{T})) where {T}
121123
print(io, "CompilerConfig for ", T)
@@ -124,6 +126,7 @@ end
124126
function Base.hash(cfg::CompilerConfig, h::UInt)
125127
h = hash(cfg.target, h)
126128
h = hash(cfg.params, h)
129+
h = hash(cfg.meta, h)::UInt
127130

128131
h = hash(cfg.kernel, h)
129132
h = hash(cfg.name, h)
@@ -178,15 +181,17 @@ runtime_module(@nospecialize(job::CompilerJob)) = error("Not implemented")
178181
# check if a function is an intrinsic that can assumed to be always available
179182
isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
180183

184+
inference_metadata(@nospecialize(job::CompilerJob)) = job.config.meta
185+
181186
# provide a specific interpreter to use.
182187
if VERSION >= v"1.11.0-DEV.1552"
183188
get_interpreter(@nospecialize(job::CompilerJob)) =
184-
GPUInterpreter(job.world; method_table=method_table(job),
189+
GPUInterpreter(job.world; meta=inference_metadata(job), method_table=method_table(job),
185190
token=ci_cache_token(job), inf_params=inference_params(job),
186191
opt_params=optimization_params(job))
187192
else
188193
get_interpreter(@nospecialize(job::CompilerJob)) =
189-
GPUInterpreter(job.world; method_table=method_table(job),
194+
GPUInterpreter(job.world; meta=inference_metadata(job), method_table=method_table(job),
190195
code_cache=ci_cache(job), inf_params=inference_params(job),
191196
opt_params=optimization_params(job))
192197
end
@@ -227,10 +232,11 @@ struct GPUCompilerCacheToken
227232
target_type::Type
228233
always_inline::Bool
229234
method_table::Core.MethodTable
235+
metadata
230236
end
231237

232238
ci_cache_token(@nospecialize(job::CompilerJob)) =
233-
GPUCompilerCacheToken(typeof(job.config.target), job.config.always_inline, method_table(job))
239+
GPUCompilerCacheToken(typeof(job.config.target), job.config.always_inline, method_table(job), inference_metadata(job))
234240

235241
# the codeinstance cache to use -- should only be used for the constructor
236242
if VERSION >= v"1.11.0-DEV.1552"

src/irgen.jl

+14-13
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
function irgen(@nospecialize(job::CompilerJob))
44
mod, compiled = @timeit_debug to "emission" compile_method_instance(job)
5+
edge = Edge(inference_metadata(job), job.source)
56
if job.config.entry_abi === :specfunc
6-
entry_fn = compiled[job.source].specfunc
7+
entry_fn = compiled[edge].specfunc
78
else
8-
entry_fn = compiled[job.source].func
9+
entry_fn = compiled[edge].func
910
end
1011
@assert entry_fn !== nothing
1112
entry = functions(mod)[entry_fn]
@@ -70,25 +71,25 @@ function irgen(@nospecialize(job::CompilerJob))
7071
entry = deprecation_marker
7172
end
7273
if job.config.entry_abi === :specfunc
73-
func = compiled[job.source].func
74+
func = compiled[edge].func
7475
specfunc = LLVM.name(entry)
7576
else
7677
func = LLVM.name(entry)
77-
specfunc = compiled[job.source].specfunc
78+
specfunc = compiled[edge].specfunc
7879
end
7980

80-
compiled[job.source] =
81-
(; compiled[job.source].ci, func, specfunc)
81+
compiled[edge] =
82+
(; compiled[edge].ci, func, specfunc)
8283

8384
# Earlier we sanitize global names, this invalidates the
8485
# func, specfunc names safed in compiled. Update the names now,
8586
# such that when when use the compiled mappings to lookup the
8687
# llvm function for a methodinstance (deferred codegen) we have
8788
# valid targets.
88-
for mi in keys(compiled)
89-
mi == job.source && continue
90-
ci, func, specfunc = compiled[mi]
91-
compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
89+
for other in keys(compiled)
90+
other == edge && continue
91+
ci, func, specfunc = compiled[other]
92+
compiled[other] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
9293
end
9394

9495
# TODO: Should we rewrite gpuc.lookup here?
@@ -111,11 +112,11 @@ function irgen(@nospecialize(job::CompilerJob))
111112
# internalize all functions, but keep exported global variables.
112113
linkage!(entry, LLVM.API.LLVMExternalLinkage)
113114
preserved_gvs = String[LLVM.name(entry)]
114-
for mi in keys(compiled)
115+
for other in keys(compiled)
115116
# delay internalizing of deferred calls since
116117
# gpuc.lookup is not yet rewriten.
117-
mi == job.source && continue
118-
_, _, specfunc = compiled[mi]
118+
other == edge && continue
119+
_, _, specfunc = compiled[other]
119120
push!(preserved_gvs, specfunc) # this could be deleted if we rewrite gpuc.lookup earlier
120121
end
121122
for gvar in globals(mod)

0 commit comments

Comments
 (0)