Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,25 @@ function prepare_llvm(interp, mod::LLVM.Module, job, meta)

fixup_1p12_sret!(llvmfn)
end

# We explicitly save the type of alloca's before they get lowered
for f in functions(mod)
for bb in blocks(f), inst in instructions(bb)
if !isa(inst, LLVM.CallInst)
continue
end
fn = LLVM.called_operand(inst)
if !isa(fn, LLVM.Function)
continue
end
if LLVM.name(fn) == "julia.gc_alloc_obj"
legal, RT, _ = abs_typeof(inst)
if legal
metadata(inst)["enzymejl_gc_alloc_rt"] = MDNode(LLVM.Metadata[MDString(string(convert(UInt, unsafe_to_pointer(RT))))])
end
end
end
end
end

include("compiler/optimize.jl")
Expand Down
18 changes: 15 additions & 3 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ end

EnzymeAttributorPass() = NewPMModulePass("enzyme_attributor", enzyme_attributor_pass!)
ReinsertGCMarkerPass() = NewPMFunctionPass("reinsert_gcmarker", reinsert_gcmarker_pass!)
RestoreAllocaType() = NewPMFunctionPass("restore_alloca_type", restore_alloca_type!)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vchuravy can you check what obvious thing I'm doing wrong here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

register!(pb, RestoreAllocaType())

SafeAtomicToRegularStorePass() = NewPMFunctionPass("safe_atomic_to_regular_store", safe_atomic_to_regular_store!)
Addr13NoAliasPass() = NewPMModulePass("addr13_noalias", addr13NoAlias)

function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, Addr13NoAliasPass())
register!(pb, RestoreAllocaType())
add!(pb, NewPMAAManager()) do aam
add!(aam, ScopedNoAliasAA())
add!(aam, TypeBasedAA())
Expand All @@ -50,6 +52,7 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
add!(mpm, AlwaysInlinerPass())
add!(mpm, NewPMFunctionPassManager()) do fpm
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
end
end
run!(pb, mod, tm)
Expand All @@ -76,6 +79,7 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
function middle_optimize!(second_stage=false)
@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, RestoreAllocaType())
add!(pb, NewPMAAManager()) do aam
add!(aam, ScopedNoAliasAA())
add!(aam, TypeBasedAA())
Expand All @@ -98,6 +102,7 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
add!(fpm, ReassociatePass())
add!(fpm, EarlyCSEPass())
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())

add!(fpm, NewPMLoopPassManager(use_memory_ssa=true)) do lpm
add!(lpm, LoopIdiomRecognizePass())
Expand All @@ -117,6 +122,7 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
# todo peeling=false?
add!(fpm, LoopUnrollPass(opt_level=2, partial=false)) # what opt level?
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
add!(fpm, SROAPass())
add!(fpm, GVNPass())

Expand All @@ -131,6 +137,7 @@ function optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)
add!(fpm, JumpThreadingPass())
add!(fpm, DSEPass())
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
add!(fpm, SimplifyCFGPass())


Expand Down Expand Up @@ -222,7 +229,8 @@ function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)
# merging the `alloca` for the unboxed data and the `alloca` created by the `alloc_opt`
# pass.

add!(fpm, AllocOptPass())
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
# consider AggressiveInstCombinePass at optlevel > 2

add!(fpm, InstCombinePass())
Expand All @@ -240,6 +248,7 @@ function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)
# Load forwarding above can expose allocations that aren't actually used
# remove those before optimizing loops.
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())

add!(fpm, NewPMLoopPassManager(use_memory_ssa=true)) do lpm
add!(lpm, LoopRotatePass())
Expand All @@ -261,6 +270,7 @@ function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)

# Run our own SROA on heap objects before LLVM's
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
# Re-run SROA after loop-unrolling (useful for small loops that operate,
# over the structure of an aggregate)
add!(fpm, SROAPass())
Expand All @@ -282,7 +292,8 @@ function addOptimizationPasses!(mpm::LLVM.NewPMPassManager)

# More dead allocation (store) deletion before loop optimization
# consider removing this:
add!(fpm, AllocOptPass())
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())

# see if all of the constant folding has exposed more loops
# to simplification and deletion
Expand Down Expand Up @@ -435,12 +446,13 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
registerEnzymeAndPassPipeline!(pb)
register!(pb, ReinsertGCMarkerPass())
register!(pb, SafeAtomicToRegularStorePass())
register!(pb, RestoreAllocaType())
add!(pb, NewPMAAManager()) do aam
add!(aam, ScopedNoAliasAA())
add!(aam, TypeBasedAA())
add!(aam, BasicAA())
end
add!(pb, NewPMModulePassManager()) do mpm
add!(pb, NewPMModulePassManager()) do mpm
addOptimizationPasses!(mpm)
if machine
# TODO enable validate_return_roots
Expand Down
1 change: 1 addition & 0 deletions src/llvm/attributes.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const nofreefns = Set{String}((
"jl_genericmemory_copyto",
"utf8proc_toupper",
"ClientGetAddressableDevices",
"ClientNumAddressableDevices",
Expand Down
52 changes: 52 additions & 0 deletions src/llvm/transforms.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,51 @@
function restore_alloca_type!(f::LLVM.Function)
replaceAndErase = Tuple{LLVM.AllocaInst,Type, LLVMType}[]
dl = datalayout(LLVM.parent(f))

for bb in blocks(f), inst in instructions(bb)
if isa(inst, LLVM.AllocaInst)
if haskey(metadata(inst), "enzymejl_allocrt") || haskey(metadata(inst), "enzymejl_gc_alloc_rt")
mds = operands(metadata(arg)["enzymejl_allocart"])[1]::MDString
mds = Base.convert(String, mds)
ptr = reinterpret(Ptr{Cvoid}, parse(UInt, mds))
RT = Base.unsafe_pointer_to_objref(ptr)
lrt = convert(LLVMType, RT)
at = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(inst))
if at == lrt
continue
end
cnt = operands(inst)[1]
if !isa(cnt, LLVM.ConstantInt) || convert(UInt, cnt) != 1
continue
end
if LLVM.sizeof(dl, at) == LLVM.sizeof(dl, lrt) && CountTrackedPointers(at).count == 0
push!(replaceAndErase, (inst, RT, lrt))
end
end
end
end

for (al, RT, lrt) in replaceAndErase
if CountTrackedPointers(lrt).count != 0
lrt2 = strip_tracked_pointers(lrt)
@assert LLVM.sizeof(dl, lrt2) == LLVM.sizeof(dl, lrt)
lrt = lrt2
end
b = IRBuilder()
position!(b, al)
pname = LLVM.name(al)
LLVM.name!(al, "")
al2 = alloca!(b, lrt, pname)
cst = al2
if value_type(cst) != value_type(al)
cst = bitcast!(b, cst, value_type(al))
end
LLVM.replace_uses!(al, cst)
LLVM.API.LLVMInstructionEraseFromParent(al)
metadata(inst)["enzymejl_allocart"] = MDNode(LLVM.Metadata[MDString(string(convert(UInt, unsafe_to_pointer(RT))))])
end
return length(replaceAndErase) != 0
end

# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA
function rewrite_ccalls!(mod::LLVM.Module)
Expand Down Expand Up @@ -2643,11 +2691,13 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
propagate_returned!(mod)
LLVM.@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, RestoreAllocaType())
add!(pb, NewPMModulePassManager()) do mpm
add!(mpm, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
add!(fpm, JLInstSimplifyPass())
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
add!(fpm, SROAPass())
add!(fpm, EarlyCSEPass())
end
Expand All @@ -2670,11 +2720,13 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
LLVM.@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, EnzymeAttributorPass())
register!(pb, RestoreAllocaType())
add!(pb, NewPMModulePassManager()) do mpm
add!(mpm, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
add!(fpm, JLInstSimplifyPass())
add!(fpm, AllocOptPass())
add!(fpm, RestoreAllocaType())
add!(fpm, SROAPass())
end
if RunAttributor[]
Expand Down
31 changes: 31 additions & 0 deletions src/typeutils/lltypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,37 @@ nfields(Type::LLVM.VectorType) = size(Type)
nfields(Type::LLVM.ArrayType) = length(Type)
nfields(Type::LLVM.PointerType) = 1

function strip_tracked_pointers(@nospecialize(T::LLVM.LLVMType))
if !any_jltypes(T)
return T
end
if isa(T, LLVM.PointerType)
@assert isSpecialPtr(T)
if LLVM.is_opaque(T)
return LLVM.PointerType()
else
return LLVM.PointerType(eltype(T))
end
end
if isa(T, LLVM.ArrayType)
return LLVM.ArrayType(eltype(T), length(T))
end

if isa(T, LLVM.VectorType)
return LLVM.VectorType(eltype(T), length(T))
end

if isa(T, LLVM.StructType)
subtypes = LLVM.LLVMTypes[]
for (i, t) in enumerate(LLVM.elements(ty))
push!(subtypes, strip_tracked_pointers(t))
end
return LLVM.StructType(subtypes; packed=LLVM.ispacked(T))
end

throw(AssertionError("Unknown composite type"))
end

function store_nonjl_types!(B::LLVM.IRBuilder, @nospecialize(startval::LLVM.Value), @nospecialize(p::LLVM.Value))
T_jlvalue = LLVM.StructType(LLVMType[])
T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
Expand Down
Loading