Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/interpreter/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function build_frule(
else
# Derive forward-pass IR, and shove in a `MistyClosure`.
dual_ir, captures, info = generate_dual_ir(interp, sig_or_mi; debug_mode)
dual_oc = optimized_misty_closure(
dual_oc = misty_closure(
info.dual_ret_type, dual_ir, captures...; do_compile=true
)
raw_rule = DerivedFRule{sig,typeof(dual_oc),info.isva,info.nargs}(dual_oc)
Expand Down Expand Up @@ -109,6 +109,9 @@ function generate_dual_ir(

# Grab code associated to the primal.
primal_ir, _ = lookup_ir(interp, sig_or_mi)
@static if VERSION > v"1.12-"
primal_ir = set_valid_world!(primal_ir, interp.world)
end
nargs = length(primal_ir.argtypes)

# Normalise the IR.
Expand Down
33 changes: 32 additions & 1 deletion src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance)
propagate_inbounds = true
spec_info = CC.SpecInfo(nargs, isva, propagate_inbounds, nothing)
min_world = world = get_inference_world(interp)
max_world = Base.get_world_counter()
max_world = Base.get_world_counter() # TODO: why do we allow other worlds?
irsv = CC.IRInterpretationState(
interp, spec_info, ir, mi, ir.argtypes, world, min_world, max_world
)
Expand Down Expand Up @@ -285,6 +285,37 @@ function lookup_ir(::CC.AbstractInterpreter, mc::MistyClosure; optimize_until=no
return mc.ir[], return_type(mc.oc)
end

@static if VERSION > v"1.12-"
"""
set_valid_world!(ir::IRCode, world::UInt)::IRCode

(1.12+ only)
Create a shallow copy of the given IR code, with its `valid_worlds` field updated
to a single valid world. This allows the compiler to perform more inlining.

In particular, if the IR comes from say a function `f` which makes a call to another
function `g` which only got defined after `f`, then at the min_world when `f` was
defined, `g` was not available yet. If we restrict the IR to a world where `g` is
available then `g` can be inlined.

Will error if `world` is not in the existing `valid_worlds` of `ir`.
"""
function set_valid_world!(ir::IRCode, world::UInt)
if world ∉ ir.valid_worlds
error("World $world is not valid for this IRCode: $(ir.valid_worlds).")
end
return CC.IRCode(
ir.stmts,
ir.cfg,
ir.debuginfo,
ir.argtypes,
ir.meta,
ir.sptypes,
CC.WorldRange(world, world),
)
end
end

return_type(::Core.OpaqueClosure{A,B}) where {A,B} = B

"""
Expand Down
7 changes: 5 additions & 2 deletions src/interpreter/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1134,10 +1134,10 @@ function build_rrule(
else
# Derive forwards- and reverse-pass IR, and shove in `MistyClosure`s.
dri = generate_ir(interp, sig_or_mi; debug_mode)
fwd_oc = optimized_misty_closure(
fwd_oc = misty_closure(
dri.fwd_ret_type, dri.fwd_ir, dri.shared_data...
)
rvs_oc = optimized_misty_closure(
rvs_oc = misty_closure(
dri.rvs_ret_type, dri.rvs_ir, dri.shared_data...
)

Expand Down Expand Up @@ -1181,6 +1181,9 @@ function generate_ir(

# Grab code associated to the primal.
ir, _ = lookup_ir(interp, sig_or_mi)
@static if VERSION > v"1.12-"
ir = set_valid_world!(ir, interp.world)
end
Treturn = compute_ir_rettype(ir)
fwd_ret_type = forwards_ret_type(ir)
rvs_ret_type = pullback_ret_type(ir)
Expand Down
61 changes: 0 additions & 61 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,55 +354,6 @@ function opaque_closure(
)::Core.OpaqueClosure{sig,ret_type}
end

function optimized_opaque_closure(rtype, ir::IRCode, env...; kwargs...)
oc = opaque_closure(rtype, ir, env...; kwargs...)
world = UInt(oc.world)
set_world_bounds_for_optimization!(oc)
optimized_oc = optimize_opaque_closure(oc, rtype, env...; kwargs...)
return optimized_oc
end

function optimize_opaque_closure(oc::Core.OpaqueClosure, rtype, env...; kwargs...)
method = oc.source
ci = method.specializations.cache
world = UInt(oc.world)
ir = reinfer_and_inline(ci, world)
ir === nothing && return oc # nothing to optimize
return opaque_closure(rtype, ir, env...; kwargs...)
end

# Allows optimization to make assumptions about binding access,
# enabling inlining and other optimizations.
function set_world_bounds_for_optimization!(oc::Core.OpaqueClosure)
ci = oc.source.specializations.cache
ci.inferred === nothing && return nothing
ci.inferred.min_world = oc.world
ci.inferred.max_world = oc.world
end

function reinfer_and_inline(ci::Core.CodeInstance, world::UInt)
interp = CC.NativeInterpreter(world)
mi = get_mi(ci)
argtypes = collect(Any, mi.specTypes.parameters)
irsv = CC.IRInterpretationState(interp, ci, mi, argtypes, world)
irsv === nothing && return nothing
for stmt in irsv.ir.stmts
inst = stmt[:inst]
if isexpr(inst, :loopinfo) ||
isexpr(inst, :pop_exception) ||
isa(inst, CC.GotoIfNot) ||
isa(inst, CC.GotoNode) ||
isexpr(inst, :copyast)
continue
end
stmt[:flag] |= CC.IR_FLAG_REFINED
end
CC.ir_abstract_constant_propagation(interp, irsv)
state = CC.InliningState(interp)
ir = CC.ssa_inlining_pass!(irsv.ir, state, CC.propagate_inbounds(irsv))
return ir
end

"""
misty_closure(
ret_type::Type,
Expand All @@ -425,18 +376,6 @@ function misty_closure(
return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir))
end

function optimized_misty_closure(
ret_type::Type,
ir::IRCode,
@nospecialize env...;
isva::Bool=false,
do_compile::Bool=true,
)
return MistyClosure(
optimized_opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)
)
end

@static if VERSION > v"1.12-"
compute_ir_rettype(ir) = CC.compute_ir_rettype(ir)
compute_oc_signature(ir, nargs, isva) = CC.compute_oc_signature(ir, nargs, isva)
Expand Down
4 changes: 0 additions & 4 deletions test/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ end
@test length(stmt(ir.stmts)) == length(src.code)
oc = Mooncake.opaque_closure(rettype, ir)
@test oc(args...) == f(args...)
oc = Mooncake.optimized_opaque_closure(rettype, ir)
@test oc(args...) == f(args...)
end
@testset "infer_ir!" begin

Expand All @@ -44,8 +42,6 @@ end
ir.argtypes[1] = Tuple
oc = Mooncake.opaque_closure(Float64, ir)
@test oc(5.0) == cos(sin(5.0))
oc = Mooncake.optimized_opaque_closure(Float64, ir)
@test oc(5.0) == cos(sin(5.0))
end
@testset "lookup_ir" begin
tt = Tuple{typeof(sin),Float64}
Expand Down
5 changes: 0 additions & 5 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,9 @@
oc2 = Mooncake.opaque_closure(Any, ir)
@test oc2 isa Core.OpaqueClosure{Tuple{Float64},Any}
@test oc2(5.0) == sin(5.0)
oc2 = Mooncake.optimized_opaque_closure(Any, ir)
@test oc2 isa Core.OpaqueClosure{Tuple{Float64},Any}
@test oc2(5.0) == sin(5.0)

# Check that we can get a MistyClosure also.
mc = Mooncake.misty_closure(Float64, ir)
@test mc(5.0) == sin(5.0)
mc = Mooncake.optimized_misty_closure(Float64, ir)
@test mc(5.0) == sin(5.0)
end
end
Loading