Skip to content

Commit

Permalink
add a new Expr head for directly calling a LambdaInfo
Browse files Browse the repository at this point in the history
Expr(:invoke, LambdaInfo, call-args...)

is a more primitive form of Expr(:call)
for which the dispatch logic has been pre-determined

this is not used by lowering, but is used by inference
to allowing moving of this logic out of codegen
  • Loading branch information
vtjnash committed Jun 2, 2016
1 parent bdcd426 commit 9e119fa
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 56 deletions.
9 changes: 8 additions & 1 deletion base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ function abstract_call(f::ANY, fargs, argtypes::Vector{Any}, vtypes::VarTable, s
return abstract_call_gf_by_type(f, atype, sv)
end

function abstract_eval_call(e, vtypes::VarTable, sv::InferenceState)
function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState)
argtypes = Any[abstract_eval(a, vtypes, sv) for a in e.args]
#print("call ", e.args[1], argtypes, "\n\n")
for x in argtypes
Expand Down Expand Up @@ -2432,6 +2432,13 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
return NF
end
=#

# convert call to invoke
cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype) # TODO: merge tfunc and spec arrays so this lookup unnecessary
if cache_linfo !== nothing
e.head = :invoke
unshift!(e.args, cache_linfo)
end
return NF
end

Expand Down
47 changes: 44 additions & 3 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,15 @@ end

function show(io::IO, l::LambdaInfo)
if isdefined(l, :def)
println(io, "LambdaInfo for ", l.def.name)
if (l === l.def.lambda_template)
print(io, "LambdaInfo template for ")
show(io, l.def)
println(io)
else
print(io, "LambdaInfo for ")
show_lambda_types(io, l.specTypes.parameters)
println(io)
end
else
println(io, "Toplevel LambdaInfo thunk")
end
Expand Down Expand Up @@ -947,15 +955,44 @@ function show_unquoted(io::IO, ex::Expr, indent::Int, prec::Int)
show(io, ex.head)
for arg in args
print(io, ", ")
show(io, arg)
if isa(arg, LambdaInfo) && isdefined(arg, :specTypes)
show_lambda_types(io, arg.specTypes.parameters)
else
show(io, arg)
end
end
print(io, "))")
end
show_type && show_expr_type(io, ex.typ, emphstate)
nothing
end

function show_lambda_types(io::IO, sig::SimpleVector)
# print a method signature tuple
ft = sig[1]
if ft <: Function && isempty(ft.parameters) &&
isdefined(ft.name.module, ft.name.mt.name) &&
ft == typeof(getfield(ft.name.module, ft.name.mt.name))
print(io, ft.name.mt.name)
elseif isa(ft, DataType) && is(ft.name, Type.name) && isleaftype(ft)
f = ft.parameters[1]
print(io, f)
else
print(io, "(::", ft, ")")
end
first = true
print(io, '(')
for i = 2:length(sig) # fixme (iter): `eachindex` with offset?
first || print(io, ", ")
first = false
print(io, "::", sig[i])
end
print(io, ')')
nothing
end

function ismodulecall(ex::Expr)
ex.head == :call && (ex.args[1] == GlobalRef(Base,:getfield) ||
return ex.head == :call && (ex.args[1] == GlobalRef(Base,:getfield) ||
ex.args[1] == GlobalRef(Core,:getfield)) &&
isa(ex.args[2], Symbol) &&
isdefined(current_module(), ex.args[2]) &&
Expand Down Expand Up @@ -989,6 +1026,7 @@ function show(io::IO, tv::TypeVar)
print(io, "<:")
show(io, tv.ub)
end
nothing
end

function dump(io::IO, x::SimpleVector, n::Int, indent)
Expand All @@ -1008,6 +1046,7 @@ function dump(io::IO, x::SimpleVector, n::Int, indent)
end
end
end
nothing
end

function dump(io::IO, x::ANY, n::Int, indent)
Expand Down Expand Up @@ -1069,6 +1108,7 @@ function dump(io::IO, x::Array, n::Int, indent)
end
end
end
nothing
end
dump(io::IO, x::Symbol, n::Int, indent) = print(io, typeof(x), " ", x)

Expand All @@ -1090,6 +1130,7 @@ function dump(io::IO, x::DataType, n::Int, indent)
end
end
end
nothing
end

# dumptype is for displaying abstract type hierarchies like Jameson
Expand Down
28 changes: 2 additions & 26 deletions base/stacktraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,32 +162,8 @@ function show_spec_linfo(io::IO, frame::StackFrame)
end
else
linfo = get(frame.linfo)
params =
if isdefined(linfo, :specTypes)
linfo.specTypes.parameters
else
nothing
end
if params !== nothing
ft = params[1]
if ft <: Function && isempty(ft.parameters) &&
isdefined(ft.name.module, ft.name.mt.name) &&
ft == typeof(getfield(ft.name.module, ft.name.mt.name))
print(io, ft.name.mt.name)
elseif isa(ft, DataType) && is(ft.name, Type.name) && isleaftype(ft)
f = ft.parameters[1]
print(io, f)
else
print(io, "(::", ft, ")")
end
first = true
print(io, '(')
for i = 2:length(params) # fixme (iter): `eachindex` with offset?
first || print(io, ", ")
first = false
print(io, "::", params[i])
end
print(io, ')')
if isdefined(linfo, :specTypes)
Base.show_lambda_types(io, linfo.specTypes.parameters)
else
print(io, linfo.name)
end
Expand Down
3 changes: 2 additions & 1 deletion src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ jl_value_t *jl_memory_exception;
jl_value_t *jl_readonlymemory_exception;
union jl_typemap_t jl_cfunction_list;

jl_sym_t *call_sym; jl_sym_t *dots_sym;
jl_sym_t *call_sym; jl_sym_t *invoke_sym;
jl_sym_t *dots_sym;
jl_sym_t *module_sym; jl_sym_t *slot_sym;
jl_sym_t *empty_sym;
jl_sym_t *export_sym; jl_sym_t *import_sym;
Expand Down
32 changes: 30 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ void *jl_get_llvmf(jl_tupletype_t *tt, bool getwrapper, bool getdeclarations)
JL_GC_POP();
return NULL;
}
jl_compile_linfo(linfo); // make sure to compile this normally first, since `emit_function` doesn't handle recursive compilation correctly

if (!getdeclarations) {
// emit this function into a new module
Expand Down Expand Up @@ -1172,7 +1173,6 @@ void *jl_get_llvmf(jl_tupletype_t *tt, bool getwrapper, bool getdeclarations)
return specf;
}
}
jl_compile_linfo(linfo);
Function *llvmf;
if (!getwrapper && linfo->functionObjectsDecls.specFunctionObject != NULL) {
llvmf = (Function*)linfo->functionObjectsDecls.specFunctionObject;
Expand Down Expand Up @@ -2705,12 +2705,33 @@ static jl_cgval_t emit_call_function_object(jl_lambda_info_t *li, const jl_cgval
expr_type(callexpr, ctx), ctx);
}

static jl_cgval_t emit_invoke(jl_expr_t *ex, jl_codectx_t *ctx)
{
jl_value_t **args = (jl_value_t**)jl_array_data(ex->args);
size_t arglen = jl_array_dim0(ex->args);
size_t nargs = arglen - 2;
assert(arglen >= 2);
jl_lambda_info_t *li = (jl_lambda_info_t*)args[0];
assert(jl_is_lambda_info(li) && !li->inInference);

jl_compile_linfo(li);
assert(li->functionObjectsDecls.functionObject != NULL);
Value *theFptr = (Value*)li->functionObjectsDecls.functionObject;
jl_cgval_t fval = emit_expr(args[1], ctx);
jl_cgval_t result = emit_call_function_object(li, fval, theFptr, &args[1], nargs, (jl_value_t*)ex, ctx);
if (result.typ == jl_bottom_type) {
CreateTrap(builder);
}
return result;
}

static jl_cgval_t emit_call(jl_expr_t *ex, jl_codectx_t *ctx)
{
jl_value_t *expr = (jl_value_t*)ex;
jl_value_t **args = (jl_value_t**)jl_array_data(ex->args);
size_t arglen = jl_array_dim0(ex->args);
size_t nargs = arglen - 1;
assert(arglen >= 1);
Value *theFptr = NULL;
jl_cgval_t result;
jl_value_t *aty = NULL;
Expand Down Expand Up @@ -2757,7 +2778,8 @@ static jl_cgval_t emit_call(jl_expr_t *ex, jl_codectx_t *ctx)
jl_sprint((jl_value_t*)aty));
}*/
jl_lambda_info_t *li = jl_get_specialization1((jl_tupletype_t*)aty);
if (li != NULL) {
if (li != NULL && !li->inInference) {
jl_compile_linfo(li);
assert(li->functionObjectsDecls.functionObject != NULL);
theFptr = (Value*)li->functionObjectsDecls.functionObject;
jl_cgval_t fval;
Expand Down Expand Up @@ -3204,6 +3226,9 @@ static jl_cgval_t emit_expr(jl_value_t *expr, jl_codectx_t *ctx)
builder.CreateCondBr(isfalse, ifnot, ifso);
builder.SetInsertPoint(ifso);
}
else if (head == invoke_sym) {
return emit_invoke(ex, ctx);
}
else if (head == call_sym) {
if (ctx->linfo->def) { // don't bother codegen constant-folding for toplevel
jl_value_t *c = static_eval(expr, ctx, true, true);
Expand Down Expand Up @@ -3530,7 +3555,10 @@ static Function *gen_cfun_wrapper(jl_function_t *ff, jl_value_t *jlrettype, jl_t
const char *name = "cfunction";
jl_lambda_info_t *lam = jl_get_specialization1((jl_tupletype_t*)sigt);
jl_value_t *astrt = (jl_value_t*)jl_any_type;
if (lam && lam->inInference)
lam = NULL;
if (lam != NULL) {
jl_compile_linfo(lam);
name = jl_symbol_name(lam->def->name);
astrt = lam->rettype;
if (astrt != (jl_value_t*)jl_bottom_type &&
Expand Down
2 changes: 1 addition & 1 deletion src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2408,7 +2408,7 @@ void jl_init_serializer(void)
// everything above here represents a class of object rather than only a literal

jl_emptysvec, jl_emptytuple, jl_false, jl_true, jl_nothing, jl_any_type,
call_sym, goto_ifnot_sym, return_sym, body_sym, line_sym,
call_sym, invoke_sym, goto_ifnot_sym, return_sym, body_sym, line_sym,
lambda_sym, jl_symbol("tuple"), assign_sym,

// empirical list of very common symbols
Expand Down
42 changes: 22 additions & 20 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ jl_lambda_info_t *jl_get_specialization1(jl_tupletype_t *types)

// make sure exactly 1 method matches (issue #7302).
int i;
for(i=0; i < jl_nparams(types); i++) {
for (i = 0; i < jl_nparams(types); i++) {
jl_value_t *ti = jl_tparam(types, i);
// if one argument type is DataType, multiple Type{} definitions
// might match. also be conservative with tuples rather than trying
Expand All @@ -1083,31 +1083,29 @@ jl_lambda_info_t *jl_get_specialization1(jl_tupletype_t *types)
JL_TRY {
sf = jl_method_lookup_by_type(mt, types, 1, 1);
} JL_CATCH {
goto not_found;
sf = NULL;
}
if (sf != NULL) {
jl_method_t *m = sf->def;
if (jl_has_call_ambiguities(types, m)) {
goto not_found;
}
}
if (sf == NULL || sf->code == NULL || sf->inInference)
goto not_found;
if (sf->functionObjectsDecls.functionObject == NULL) {
if (sf->fptr != NULL)
goto not_found;
jl_compile_linfo(sf);
if (sf == NULL || sf->code == NULL ||
jl_has_call_ambiguities(types, sf->def)) {
sf = NULL;
}
JL_GC_POP();
return sf;
not_found:
JL_GC_POP();
return NULL;
}

JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
{
return jl_get_specialization1(types) != NULL;
jl_lambda_info_t *li = jl_get_specialization1(types);
if (li == NULL)
return 0;
jl_compile_linfo(li);
return 1;
}

JL_DLLEXPORT jl_value_t *jl_get_spec_lambda(jl_tupletype_t *types)
{
jl_value_t *li = (jl_value_t*)jl_get_specialization1(types);
return li ? li : jl_nothing;
}

int jl_has_call_ambiguities(jl_tupletype_t *types, jl_method_t *m)
Expand Down Expand Up @@ -1170,8 +1168,10 @@ static int _compile_all_tvar_union(jl_tupletype_t *methsig, jl_svec_t *tvars)
// usually can create a specialized version of the function,
// if the signature is already a leaftype
jl_lambda_info_t *spec = jl_get_specialization1(methsig);
if (spec)
if (spec) {
jl_compile_linfo(spec);
return 1;
}
}
return 0;
}
Expand Down Expand Up @@ -1202,7 +1202,9 @@ static int _compile_all_tvar_union(jl_tupletype_t *methsig, jl_svec_t *tvars)
goto getnext; // signature wouldn't be callable / is invalid -- skip it
}
if (jl_is_leaf_type(sig)) {
if (jl_get_specialization1((jl_tupletype_t*)sig)) {
jl_lambda_info_t *spec = jl_get_specialization1((jl_tupletype_t*)sig);
if (spec) {
jl_compile_linfo(spec);
if (!jl_has_typevars((jl_value_t*)sig)) goto getnext; // success
}
}
Expand Down
18 changes: 18 additions & 0 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ static jl_value_t *do_call(jl_value_t **args, size_t nargs, jl_value_t **locals,
return result;
}

static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, jl_value_t **locals,
jl_lambda_info_t *lam)
{
jl_value_t **argv;
JL_GC_PUSHARGS(argv, nargs - 1);
size_t i;
for (i = 1; i < nargs; i++)
argv[i - 1] = eval(args[i], locals, lam);
jl_lambda_info_t *meth = (jl_lambda_info_t*)args[0];
assert(jl_is_lambda_info(meth) && !meth->inInference);
jl_value_t *result = jl_call_method_internal(meth, argv, nargs - 1);
JL_GC_POP();
return result;
}

jl_value_t *jl_eval_global_var(jl_module_t *m, jl_sym_t *e)
{
jl_value_t *v = jl_get_global(m, e);
Expand Down Expand Up @@ -173,6 +188,9 @@ static jl_value_t *eval(jl_value_t *e, jl_value_t **locals, jl_lambda_info_t *la
if (ex->head == call_sym) {
return do_call(args, nargs, locals, lam);
}
else if (ex->head == invoke_sym) {
return do_invoke(args, nargs, locals, lam);
}
else if (ex->head == assign_sym) {
jl_value_t *sym = args[0];
jl_value_t *rhs = eval(args[1], locals, lam);
Expand Down
1 change: 1 addition & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3732,6 +3732,7 @@ void jl_init_types(void)

empty_sym = jl_symbol("");
call_sym = jl_symbol("call");
invoke_sym = jl_symbol("invoke");
quote_sym = jl_symbol("quote");
inert_sym = jl_symbol("inert");
top_sym = jl_symbol("top");
Expand Down
3 changes: 2 additions & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ extern JL_DLLEXPORT jl_value_t *jl_false;
extern JL_DLLEXPORT jl_value_t *jl_nothing;

// some important symbols
extern jl_sym_t *call_sym; extern jl_sym_t *empty_sym;
extern jl_sym_t *call_sym; extern jl_sym_t *invoke_sym;
extern jl_sym_t *empty_sym;
extern jl_sym_t *dots_sym; extern jl_sym_t *vararg_sym;
extern jl_sym_t *quote_sym; extern jl_sym_t *newvar_sym;
extern jl_sym_t *top_sym; extern jl_sym_t *dot_sym;
Expand Down
1 change: 1 addition & 0 deletions src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ JL_DLLEXPORT jl_value_t *jl_threading_run(jl_svec_t *args)
JL_GC_PUSH1(&argtypes);
argtypes = arg_type_tuple(jl_svec_data(args), jl_svec_len(args));
jl_lambda_info_t *li = jl_get_specialization1(argtypes);
jl_compile_linfo(li);
jl_generate_fptr(li);

threadwork.command = TI_THREADWORK_RUN;
Expand Down
Loading

0 comments on commit 9e119fa

Please sign in to comment.