From 69e559d496c69c7909ba87631cb552295e77255f Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 18 Apr 2018 18:02:54 -0400 Subject: [PATCH] inference: enable CodeInfo method_for_inference_limit_heuristics support (#26822) allow unbounded inference recursion, as long as the user-provided Method in method_for_inference_limit_heuristics does not match the target frame --- base/compiler/abstractinterpretation.jl | 5 +- base/compiler/typeinfer.jl | 2 +- base/compiler/utilities.jl | 32 +++---- src/dump.c | 3 +- src/jltypes.c | 2 +- src/julia.h | 2 +- src/method.c | 4 +- src/toplevel.c | 2 +- test/compiler/compiler.jl | 112 +++++++++++++++++------- test/compiler/ssair.jl | 2 +- 10 files changed, 102 insertions(+), 64 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index f3a8f98370c70..68299eb5385cb 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -187,6 +187,7 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl cyclei = 0 infstate = sv edgecycle = false + method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing} while !(infstate === nothing) infstate = infstate::InferenceState if method === infstate.linfo.def @@ -197,7 +198,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl edgecycle = true break end - if topmost === nothing + inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match + inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing} + if topmost === nothing && method2 === inf_method2 # inspect the parent of this edge, # to see if they are the same Method as sv # in which case we'll need to ensure it is convergent diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 67aed9648aeb8..8aea9e5d753ea 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -168,7 +168,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool, method = linfo.def::Method tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ()) tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ] - tree.signature_for_inference_heuristics = nothing + tree.method_for_inference_limit_heuristics = nothing tree.slotnames = Any[ COMPILER_TEMP_SYM for i = 1:method.nargs ] tree.slotflags = UInt8[ 0 for i = 1:method.nargs ] tree.slottypes = nothing diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 7955241182ca9..94a7d087a3bad 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -155,33 +155,21 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any, UInt), method, atypes, sparams, world) end -# TODO: Use these functions instead of directly manipulating -# the "actual" method for appropriate places in inference (see #24676) -function method_for_inference_heuristics(cinfo, default) - if isa(cinfo, CodeInfo) - # appropriate format for `sig` is svec(ftype, argtypes, world) - sig = cinfo.signature_for_inference_heuristics - if isa(sig, SimpleVector) && length(sig) == 3 - methods = _methods(sig[1], sig[2], -1, sig[3]) - if length(methods) == 1 - _, _, m = methods[] - if isa(m, Method) - return m - end - end - end - end - return default -end - -function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world) +# This function is used for computing alternate limit heuristics +function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt) if isdefined(method, :generator) && method.generator.expand_early method_instance = code_for_method(method, sig, sparams, world, false) if isa(method_instance, MethodInstance) - return method_for_inference_heuristics(get_staged(method_instance), method) + cinfo = get_staged(method_instance) + if isa(cinfo, CodeInfo) + method2 = cinfo.method_for_inference_limit_heuristics + if method2 isa Method + return method2 + end + end end end - return method + return nothing end function exprtype(@nospecialize(x), src, mod::Module) diff --git a/src/dump.c b/src/dump.c index a68163e6fb6fe..54135fc445cac 100644 --- a/src/dump.c +++ b/src/dump.c @@ -2300,7 +2300,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) size_t nf = jl_datatype_nfields(jl_code_info_type); for (i = 0; i < nf - 5; i++) { - jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), 1); + int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field + jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), copy); } ios_putc('\0', s.s); diff --git a/src/jltypes.c b/src/jltypes.c index 5d72efd4bf60b..2518943aa9126 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2026,7 +2026,7 @@ void jl_init_types(void) jl_perm_symsvec(12, "code", "codelocs", - "signature_for_inference_heuristics", + "method_for_inference_limit_heuristics", "slottypes", "ssavaluetypes", "linetable", diff --git a/src/julia.h b/src/julia.h index 17c8d2075dbb9..022a13926ed12 100644 --- a/src/julia.h +++ b/src/julia.h @@ -235,7 +235,7 @@ typedef struct _jl_llvm_functions_t { typedef struct _jl_code_info_t { jl_array_t *code; // Any array of statements jl_value_t *codelocs; // Int array of indicies into the line table - jl_value_t *signature_for_inference_heuristics; // optional method used during inference + jl_value_t *method_for_inference_limit_heuristics; // optional method used during inference jl_value_t *slottypes; // types of variable slots (or `nothing`) jl_value_t *ssavaluetypes; // types of ssa values (or count of them) jl_value_t *linetable; // Table of locations diff --git a/src/method.c b/src/method.c index 45b9678250578..0debde199e7a2 100644 --- a/src/method.c +++ b/src/method.c @@ -232,7 +232,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast) jl_array_del_end(meta, na - ins); } } - li->signature_for_inference_heuristics = jl_nothing; + li->method_for_inference_limit_heuristics = jl_nothing; jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1); jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0); size_t nslots = jl_array_len(vis); @@ -303,7 +303,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void) (jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t), jl_code_info_type); src->code = NULL; - src->signature_for_inference_heuristics = NULL; + src->method_for_inference_limit_heuristics = NULL; src->slotnames = NULL; src->slotflags = NULL; src->slottypes = NULL; diff --git a/src/toplevel.c b/src/toplevel.c index e8f368b9c26d4..20c736371ea6a 100644 --- a/src/toplevel.c +++ b/src/toplevel.c @@ -615,7 +615,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr) jl_gc_wb(src, src->slotflags); src->ssavaluetypes = jl_box_long(0); jl_gc_wb(src, src->ssavaluetypes); - src->signature_for_inference_heuristics = jl_nothing; + src->method_for_inference_limit_heuristics = jl_nothing; src->codelocs = jl_nothing; src->linetable = jl_nothing; diff --git a/test/compiler/compiler.jl b/test/compiler/compiler.jl index f93325681d53a..2d576b07339ef 100644 --- a/test/compiler/compiler.jl +++ b/test/compiler/compiler.jl @@ -1310,73 +1310,119 @@ function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, li return Expr(:meta, :generated, stub) end -f24852_kernel(x, y) = x * y - -function f24852_kernel_cinfo(x, y) - sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1] +f24852_kernel1(x, y::Tuple) = x * y[1][1][1] +f24852_kernel2(x, y::Tuple) = f24852_kernel1(x, (y,)) +f24852_kernel3(x, y::Tuple) = f24852_kernel2(x, (y,)) +f24852_kernel(x, y::Number) = f24852_kernel3(x, (y,)) + +function f24852_kernel_cinfo(fsig::Type) + world = typemax(UInt) # FIXME + sig, spvals, method = Base._methods_by_ftype(fsig, -1, world)[1] + isdefined(method, :source) || return (nothing, :(f(x, y))) code_info = Base.uncompressed_ast(method) body = Expr(:block, code_info.code...) - Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate) + Base.Core.Compiler.substitute!(body, 0, Any[], sig, Any[spvals...], 1, :propagate) + if startswith(String(method.name), "f24852") + for a in body.args + if a isa Expr && a.head == :(=) + a = a.args[2] + end + if a isa Expr && length(a.args) === 3 && a.head === :call + pushfirst!(a.args, Core.SlotNumber(1)) + end + end + end + pushfirst!(code_info.slotnames, Symbol("#self#")) + pushfirst!(code_info.slotflags, 0x00) return method, code_info end -function f24852_gen_cinfo_uninflated(X, Y, f, x, y) - _, code_info = f24852_kernel_cinfo(x, y) +function f24852_gen_cinfo_uninflated(X, Y, _, f, x, y) + _, code_info = f24852_kernel_cinfo(Tuple{f, x, y}) return code_info end -function f24852_gen_cinfo_inflated(X, Y, f, x, y) - method, code_info = f24852_kernel_cinfo(x, y) - code_info.signature_for_inference_heuristics = Core.Compiler.svec(f, (x, y), typemax(UInt)) +function f24852_gen_cinfo_inflated(X, Y, _, f, x, y) + method, code_info = f24852_kernel_cinfo(Tuple{f, x, y}) + code_info.method_for_inference_limit_heuristics = method return code_info end -function f24852_gen_expr(X, Y, f, x, y) - return :(f24852_kernel(x::$X, y::$Y)) +function f24852_gen_expr(X, Y, _, f, x, y) # deparse f(x::X, y::Y) where {X, Y} + if f === typeof(f24852_kernel) + f2 = :f24852_kernel3 + elseif f === typeof(f24852_kernel3) + f2 = :f24852_kernel2 + elseif f === typeof(f24852_kernel2) + f2 = :f24852_kernel1 + elseif f === typeof(f24852_kernel1) + return :((x::$X) * (y::$Y)[1][1][1]) + else + return :(error(repr(f))) + end + return :(f24852_late_expr($f2, x::$X, (y::$Y,))) end @eval begin - function f24852_late_expr(x::X, y::Y) where {X, Y} - $(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y], + function f24852_late_expr(f, x::X, y::Y) where {X, Y} + $(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y], Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false)) + $(Expr(:meta, :generated_only)) + #= no body =# end - function f24852_late_inflated(x::X, y::Y) where {X, Y} - $(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y], + function f24852_late_inflated(f, x::X, y::Y) where {X, Y} + $(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y], Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false)) + $(Expr(:meta, :generated_only)) + #= no body =# end - function f24852_late_uninflated(x::X, y::Y) where {X, Y} - $(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y], + function f24852_late_uninflated(f, x::X, y::Y) where {X, Y} + $(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y], Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false)) + $(Expr(:meta, :generated_only)) + #= no body =# end end @eval begin - function f24852_early_expr(x::X, y::Y) where {X, Y} - $(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y], + function f24852_early_expr(f, x::X, y::Y) where {X, Y} + $(_generated_stub(:f24852_gen_expr, Any[:self, :f, :x, :y], Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true)) + $(Expr(:meta, :generated_only)) + #= no body =# end - function f24852_early_inflated(x::X, y::Y) where {X, Y} - $(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y], + function f24852_early_inflated(f, x::X, y::Y) where {X, Y} + $(_generated_stub(:f24852_gen_cinfo_inflated, Any[:self, :f, :x, :y], Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true)) + $(Expr(:meta, :generated_only)) + #= no body =# end - function f24852_early_uninflated(x::X, y::Y) where {X, Y} - $(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y], + function f24852_early_uninflated(f, x::X, y::Y) where {X, Y} + $(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:self, :f, :x, :y], Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true)) + $(Expr(:meta, :generated_only)) + #= no body =# end end x, y = rand(), rand() result = f24852_kernel(x, y) -@test result === f24852_late_expr(x, y) -@test result === f24852_late_uninflated(x, y) -@test result === f24852_late_inflated(x, y) - -@test result === f24852_early_expr(x, y) -@test result === f24852_early_uninflated(x, y) -@test result === f24852_early_inflated(x, y) - -# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics` +@test result === f24852_late_expr(f24852_kernel, x, y) +@test Base.return_types(f24852_late_expr, typeof((f24852_kernel, x, y))) == Any[Any] +@test result === f24852_late_uninflated(f24852_kernel, x, y) +@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any] +@test result === f24852_late_uninflated(f24852_kernel, x, y) +@test Base.return_types(f24852_late_uninflated, typeof((f24852_kernel, x, y))) == Any[Any] + +@test result === f24852_early_expr(f24852_kernel, x, y) +@test Base.return_types(f24852_early_expr, typeof((f24852_kernel, x, y))) == Any[Any] +@test result === f24852_early_uninflated(f24852_kernel, x, y) +@test Base.return_types(f24852_early_uninflated, typeof((f24852_kernel, x, y))) == Any[Any] +@test result === @inferred f24852_early_inflated(f24852_kernel, x, y) +@test Base.return_types(f24852_early_inflated, typeof((f24852_kernel, x, y))) == Any[Float64] + +# TODO: test that `expand_early = true` + inflated `method_for_inference_limit_heuristics` # can be used to tighten up some inference result. # Test that Conditional doesn't get widened to Bool too quickly diff --git a/test/compiler/ssair.jl b/test/compiler/ssair.jl index a1e429dcd3149..695219a5c3090 100644 --- a/test/compiler/ssair.jl +++ b/test/compiler/ssair.jl @@ -22,4 +22,4 @@ let code = Any[ )) Compiler.run_passes(ci, 1, Compiler.LineInfoNode[Compiler.NullLineInfo]) -end \ No newline at end of file +end