Skip to content

Commit

Permalink
inference: enable CodeInfo method_for_inference_limit_heuristics supp…
Browse files Browse the repository at this point in the history
…ort (#26822)

allow unbounded inference recursion, as long as the user-provided Method in
method_for_inference_limit_heuristics does not match the target frame
  • Loading branch information
vtjnash authored and jrevels committed Apr 18, 2018
1 parent 8e7034e commit 69e559d
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 64 deletions.
5 changes: 4 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
cyclei = 0
infstate = sv
edgecycle = false
method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing}
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
Expand All @@ -197,7 +198,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl
edgecycle = true
break
end
if topmost === nothing
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
if topmost === nothing && method2 === inf_method2
# inspect the parent of this edge,
# to see if they are the same Method as sv
# in which case we'll need to ensure it is convergent
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
method = linfo.def::Method
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
tree.signature_for_inference_heuristics = nothing
tree.method_for_inference_limit_heuristics = nothing
tree.slotnames = Any[ COMPILER_TEMP_SYM for i = 1:method.nargs ]
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
tree.slottypes = nothing
Expand Down
32 changes: 10 additions & 22 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,33 +155,21 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any, UInt), method, atypes, sparams, world)
end

# TODO: Use these functions instead of directly manipulating
# the "actual" method for appropriate places in inference (see #24676)
function method_for_inference_heuristics(cinfo, default)
if isa(cinfo, CodeInfo)
# appropriate format for `sig` is svec(ftype, argtypes, world)
sig = cinfo.signature_for_inference_heuristics
if isa(sig, SimpleVector) && length(sig) == 3
methods = _methods(sig[1], sig[2], -1, sig[3])
if length(methods) == 1
_, _, m = methods[]
if isa(m, Method)
return m
end
end
end
end
return default
end

function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
# This function is used for computing alternate limit heuristics
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
if isdefined(method, :generator) && method.generator.expand_early
method_instance = code_for_method(method, sig, sparams, world, false)
if isa(method_instance, MethodInstance)
return method_for_inference_heuristics(get_staged(method_instance), method)
cinfo = get_staged(method_instance)
if isa(cinfo, CodeInfo)
method2 = cinfo.method_for_inference_limit_heuristics
if method2 isa Method
return method2
end
end
end
end
return method
return nothing
end

function exprtype(@nospecialize(x), src, mod::Module)
Expand Down
3 changes: 2 additions & 1 deletion src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2300,7 +2300,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), 1);
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), copy);
}

ios_putc('\0', s.s);
Expand Down
2 changes: 1 addition & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2026,7 +2026,7 @@ void jl_init_types(void)
jl_perm_symsvec(12,
"code",
"codelocs",
"signature_for_inference_heuristics",
"method_for_inference_limit_heuristics",
"slottypes",
"ssavaluetypes",
"linetable",
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ typedef struct _jl_llvm_functions_t {
typedef struct _jl_code_info_t {
jl_array_t *code; // Any array of statements
jl_value_t *codelocs; // Int array of indicies into the line table
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
jl_value_t *method_for_inference_limit_heuristics; // optional method used during inference
jl_value_t *slottypes; // types of variable slots (or `nothing`)
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
jl_value_t *linetable; // Table of locations
Expand Down
4 changes: 2 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
jl_array_del_end(meta, na - ins);
}
}
li->signature_for_inference_heuristics = jl_nothing;
li->method_for_inference_limit_heuristics = jl_nothing;
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
size_t nslots = jl_array_len(vis);
Expand Down Expand Up @@ -303,7 +303,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
jl_code_info_type);
src->code = NULL;
src->signature_for_inference_heuristics = NULL;
src->method_for_inference_limit_heuristics = NULL;
src->slotnames = NULL;
src->slotflags = NULL;
src->slottypes = NULL;
Expand Down
2 changes: 1 addition & 1 deletion src/toplevel.c
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
jl_gc_wb(src, src->slotflags);
src->ssavaluetypes = jl_box_long(0);
jl_gc_wb(src, src->ssavaluetypes);
src->signature_for_inference_heuristics = jl_nothing;
src->method_for_inference_limit_heuristics = jl_nothing;
src->codelocs = jl_nothing;
src->linetable = jl_nothing;

Expand Down
112 changes: 79 additions & 33 deletions test/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1310,73 +1310,119 @@ function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, li
return Expr(:meta, :generated, stub)
end

f24852_kernel(x, y) = x * y

function f24852_kernel_cinfo(x, y)
sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1]
f24852_kernel1(x, y::Tuple) = x * y[1][1][1]
f24852_kernel2(x, y::Tuple) = f24852_kernel1(x, (y,))
f24852_kernel3(x, y::Tuple) = f24852_kernel2(x, (y,))
f24852_kernel(x, y::Number) = f24852_kernel3(x, (y,))

function f24852_kernel_cinfo(fsig::Type)
world = typemax(UInt) # FIXME
sig, spvals, method = Base._methods_by_ftype(fsig, -1, world)[1]
isdefined(method, :source) || return (nothing, :(f(x, y)))
code_info = Base.uncompressed_ast(method)
body = Expr(:block, code_info.code...)
Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 1, :propagate)
if startswith(String(method.name), "f24852")
for a in body.args
if a isa Expr && a.head == :(=)
a = a.args[2]
end
if a isa Expr && length(a.args) === 3 && a.head === :call
pushfirst!(a.args, Core.SlotNumber(1))
end
end
end
pushfirst!(code_info.slotnames, Symbol("#self#"))
pushfirst!(code_info.slotflags, 0x00)
return method, code_info
end

function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
_, code_info = f24852_kernel_cinfo(x, y)
function f24852_gen_cinfo_uninflated(X, Y, _, f, x, y)
_, code_info = f24852_kernel_cinfo(Tuple{f, x, y})
return code_info
end

function f24852_gen_cinfo_inflated(X, Y, f, x, y)
method, code_info = f24852_kernel_cinfo(x, y)
code_info.signature_for_inference_heuristics = Core.Compiler.svec(f, (x, y), typemax(UInt))
function f24852_gen_cinfo_inflated(X, Y, _, f, x, y)
method, code_info = f24852_kernel_cinfo(Tuple{f, x, y})
code_info.method_for_inference_limit_heuristics = method
return code_info
end

function f24852_gen_expr(X, Y, f, x, y)
return :(f24852_kernel(x::$X, y::$Y))
function f24852_gen_expr(X, Y, _, f, x, y) # deparse f(x::X, y::Y) where {X, Y}
if f === typeof(f24852_kernel)
f2 = :f24852_kernel3
elseif f === typeof(f24852_kernel3)
f2 = :f24852_kernel2
elseif f === typeof(f24852_kernel2)
f2 = :f24852_kernel1
elseif f === typeof(f24852_kernel1)
return :((x::$X) * (y::$Y)[1][1][1])
else
return :(error(repr(f)))
end
return :(f24852_late_expr($f2, x::$X, (y::$Y,)))
end

@eval begin
function f24852_late_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
function f24852_late_expr(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_late_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
function f24852_late_inflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_late_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
function f24852_late_uninflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
$(Expr(:meta, :generated_only))
#= no body =#
end
end

@eval begin
function f24852_early_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
function f24852_early_expr(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_early_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
function f24852_early_inflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
$(Expr(:meta, :generated_only))
#= no body =#
end
function f24852_early_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
function f24852_early_uninflated(f, x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
$(Expr(:meta, :generated_only))
#= no body =#
end
end

x, y = rand(), rand()
result = f24852_kernel(x, y)

@test result === f24852_late_expr(x, y)
@test result === f24852_late_uninflated(x, y)
@test result === f24852_late_inflated(x, y)

@test result === f24852_early_expr(x, y)
@test result === f24852_early_uninflated(x, y)
@test result === f24852_early_inflated(x, y)

# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
@test result === f24852_late_expr(f24852_kernel, x, y)
@test Base.return_types(f24852_late_expr, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === f24852_late_uninflated(f24852_kernel, x, y)
@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === f24852_late_uninflated(f24852_kernel, x, y)
@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]

@test result === f24852_early_expr(f24852_kernel, x, y)
@test Base.return_types(f24852_early_expr, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === f24852_early_uninflated(f24852_kernel, x, y)
@test Base.return_types(f24852_early_uninflated, typeof((f24852_kernel, x, y))) == Any[Any]
@test result === @inferred f24852_early_inflated(f24852_kernel, x, y)
@test Base.return_types(f24852_early_inflated, typeof((f24852_kernel, x, y))) == Any[Float64]

# TODO: test that `expand_early = true` + inflated `method_for_inference_limit_heuristics`
# can be used to tighten up some inference result.

# Test that Conditional doesn't get widened to Bool too quickly
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/ssair.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ let code = Any[
))

Compiler.run_passes(ci, 1, Compiler.LineInfoNode[Compiler.NullLineInfo])
end
end

0 comments on commit 69e559d

Please sign in to comment.