Skip to content

Commit

Permalink
Simplify using constant expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Nov 25, 2024
1 parent e2e8153 commit 7353e4c
Showing 1 changed file with 65 additions and 62 deletions.
127 changes: 65 additions & 62 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
# add kernel metadata
if job.config.kernel
entry = add_parameter_address_spaces!(job, mod, entry)
add_global_address_spaces!(job, mod)
entry = LLVM.functions(mod)[entry_fn]
entry = add_global_address_spaces!(job, mod, entry)

add_argument_metadata!(job, mod, entry)

Expand Down Expand Up @@ -228,7 +227,8 @@ end
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
# be executed after optimization (where Julia's address spaces are stripped). If we ever
# want to execute it earlier, adapt remapType to rewrite all pointer types.
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
f::LLVM.Function)
ft = function_type(f)

# find the byref parameters
Expand Down Expand Up @@ -281,7 +281,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
# keep on using the original IR that assumed pointers without address spaces
new_args = LLVM.Value[]
@dispose builder=IRBuilder() begin
entry = BasicBlock(new_f, "parameter_conversion")
entry = BasicBlock(new_f, "conversion")
position!(builder, entry)

# perform argument conversions
Expand Down Expand Up @@ -334,76 +334,79 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
return new_f
end

# add addrspace 2 to global constants
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
# update address spaces of constant global objects
#
# constant global objects need to reside in address space 2, so we recreate each of them
# in that address space and clone every function that uses one, remapping its uses
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
entry::LLVM.Function)
# determine global variables we need to update
global_map = Dict{LLVM.Value, LLVM.Value}()
for gv in globals(mod)
if isconstant(gv) && addrspace(value_type(gv)) == 0
gv_ty = global_value_type(gv)
gv_name = LLVM.name(gv)

new_gv = GlobalVariable(mod, gv_ty, "", 2)

alignment!(new_gv, alignment(gv))
unnamed_addr!(new_gv, unnamed_addr(gv))
initializer!(new_gv, initializer(gv))
constant!(new_gv, true)
linkage!(new_gv, linkage(gv))
visibility!(new_gv, visibility(gv))

funcs = Set{LLVM.Function}()
for use in uses(gv)
inst = user(use)
bb = LLVM.parent(inst)
f = LLVM.parent(bb)

push!(funcs, f)
end
isconstant(gv) || continue
addrspace(value_type(gv)) == 0 || continue

for f in funcs
ft = function_type(f)
new_f = LLVM.Function(mod, "h", ft)
linkage!(new_f, linkage(f))
gv_ty = global_value_type(gv)
gv_name = LLVM.name(gv)

for (param, new_param) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_param, LLVM.name(param))
end
LLVM.name!(gv, gv_name * ".old")
new_gv = GlobalVariable(mod, gv_ty, gv_name, 2)

@dispose builder=IRBuilder() begin
entry = BasicBlock(new_f, "gv_conversion")
position!(builder, entry)
alignment!(new_gv, alignment(gv))
unnamed_addr!(new_gv, unnamed_addr(gv))
initializer!(new_gv, initializer(gv))
constant!(new_gv, true)
linkage!(new_gv, linkage(gv))
visibility!(new_gv, visibility(gv))

ptr = alloca!(builder, gv_ty, gv_name * ".local")
val = load!(builder, gv_ty, new_gv, gv_name * ".val")
store!(builder, val, ptr)
global_map[gv] = new_gv
end
isempty(global_map) && return entry

# map the arguments
value_map = Dict{LLVM.Value, LLVM.Value}(
param => new_param for (param, new_param) in zip(parameters(f), parameters(new_f))
)
# determine which functions we need to update
function_worklist = Set{LLVM.Function}()
for gv in keys(global_map), use in uses(gv)
inst = user(use)
bb = LLVM.parent(inst)
f = LLVM.parent(bb)

value_map[gv] = ptr
value_map[f] = new_f
clone_into!(new_f, f; value_map,
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

br!(builder, blocks(new_f)[2])
end
push!(function_worklist, f)
end

f_name = LLVM.name(f)
replace_uses!(f, new_f)
replace_metadata_uses!(f, new_f)
erase!(f)
LLVM.name!(new_f, f_name)
end
# update functions that use the global
if !isempty(function_worklist)
# we can't map the global variable directly, as the type change won't be applied
# recursively. Instead, map each old global to a constant expression that casts the
# new global (in address space 2) back to the old global's pointer type.
value_map = Dict{LLVM.Value,LLVM.Value}()
for (gv, new_gv) in global_map
ptr = const_addrspacecast(new_gv, value_type(gv))
@assert ptr isa LLVM.ConstantExpr
value_map[gv] = ptr
end

entry_fn = LLVM.name(entry)
for fun in function_worklist
fn = LLVM.name(fun)

new_fun = clone(fun; value_map)
replace_uses!(fun, new_fun)
replace_metadata_uses!(fun, new_fun)
erase!(fun)

@assert isempty(uses(gv))
replace_metadata_uses!(gv, new_gv)
erase!(gv)
LLVM.name!(new_gv, gv_name)
LLVM.name!(new_fun, fn)
end
entry = LLVM.functions(mod)[entry_fn]
end

return
# delete old globals
for (gv, new_gv) in global_map
@assert isempty(uses(gv))
replace_metadata_uses!(gv, new_gv)
erase!(gv)
end

return entry
end


Expand Down

0 comments on commit 7353e4c

Please sign in to comment.