Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[compat]
ExprTools = "0.1"
InteractiveUtils = "1"
LLVM = "8, 9"
LLVM = "9"
Libdl = "1"
Logging = "1"
PrecompileTools = "1"
Expand Down
103 changes: 103 additions & 0 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
entry::LLVM.Function)
entry_fn = LLVM.name(entry)

# get rid of unreachable control flow (JuliaLang/Metal.jl#370)
if job.config.target.macos < v"15"
for f in functions(mod)
replace_unreachable!(job, f)
end
end

# add kernel metadata
if job.config.kernel
entry = add_address_spaces!(job, mod, entry)
Expand All @@ -142,6 +149,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L

add_module_metadata!(job, mod)

# JuliaLang/Metal.jl#113
hide_noreturn!(mod)
end

Expand Down Expand Up @@ -1075,3 +1083,98 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod

return changed
end

# replace unreachable control flow with branches to the exit block
#
# before macOS 15, code generated by Julia 1.11 causes compilation failures in the back-end.
# the reduced example contains unreachable control flow executed divergently, so this is a
# similar issue as encountered with NVIDIA, albeit causing crashes instead of miscompiles.
#
# the proposed solution is to avoid (divergent) unreachable control flow, instead replacing
# it by branches to the exit block. since `unreachable` doesn't lower to anything that
# aborts the kernel anyway (can we fix this?), this transformation should be safe.
function replace_unreachable!(@nospecialize(job::CompilerJob), f::LLVM.Function)
# find unreachable instructions and exit blocks
unreachables = Instruction[]
exit_blocks = BasicBlock[]
for bb in blocks(f), inst in instructions(bb)
if isa(inst, LLVM.UnreachableInst)
push!(unreachables, inst)
end
if isa(inst, LLVM.RetInst)
push!(exit_blocks, bb)
end
end
isempty(unreachables) && return false

# if we don't have an exit block, we can't do much. we could insert a return, but that
# would probably keep the problematic control flow just as it is.
isempty(exit_blocks) && return false

@dispose builder=IRBuilder() begin
# if we have multiple exit blocks, take the last one, which is hopefully the least
# divergent (assuming divergent control flow is the root of the problem here).
exit_block = last(exit_blocks)
ret = terminator(exit_block)

# create a return block with only the return instruction, so that we only have to
# care about any values returned, and not about any other SSA value in the block.
if first(instructions(exit_block)) == ret
# we can reuse the exit block if it only contains the return
return_block = exit_block
else
# split the exit block right before the ret
return_block = BasicBlock(f, "ret")
move_after(return_block, exit_block)

# emit a branch
position!(builder, ret)
br!(builder, return_block)

# move the return
delete!(exit_block, ret)
position!(builder, return_block)
insert!(builder, ret)
end

# when returning a value, add a phi node to the return block, so that we can later
# add incoming undef values when branching from `unreachable` blocks
if !isempty(operands(ret))
position!(builder, ret)
# XXX: support aggregate returns?
val = only(operands(ret))
phi = phi!(builder, value_type(val))
for pred in predecessors(return_block)
push!(incoming(phi), (val, pred))
end
operands(ret)[1] = phi
end

# replace the unreachable with a branch to the return block
for unreachable in unreachables
bb = LLVM.parent(unreachable)

# remove preceding traps to avoid reconstructing unreachable control flow
prev = previnst(unreachable)
if isa(prev, LLVM.CallInst) && name(called_operand(prev)) == "llvm.trap"
unsafe_delete!(bb, prev)
end

# replace the unreachable with a branch to the return block
position!(builder, unreachable)
br!(builder, return_block)
unsafe_delete!(bb, unreachable)

# patch up any phi nodes in the return block
for inst in instructions(return_block)
if isa(inst, LLVM.PHIInst)
undef = UndefValue(value_type(inst))
vals = incoming(inst)
push!(vals, (undef, bb))
end
end
end
end

return true
end