Skip to content

Commit

Permalink
Merge pull request #15295 from JuliaLang/jb/spenv
Browse files Browse the repository at this point in the history
put only values in static parameter environments
  • Loading branch information
JeffBezanson committed Mar 1, 2016
2 parents eb59d63 + 54c3ffe commit dceac08
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 101 deletions.
48 changes: 15 additions & 33 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type VarInfo
label_counter::Int # index of the current highest label for this function
fedbackvars::ObjectIdDict
mod::Module
linfo::LambdaInfo
end

function VarInfo(linfo::LambdaInfo, ast=linfo.ast)
Expand All @@ -42,18 +43,8 @@ function VarInfo(linfo::LambdaInfo, ast=linfo.ast)
end
gensym_types = Any[ NF for i = 1:(ngs::Int) ]
nl = label_counter(body)+1
if length(linfo.sparam_vals) > 0
n = length(linfo.sparam_syms)
sp = Array(Any, n*2)
for i = 1:n
sp[i*2-1] = linfo.sparam_syms[i]
sp[i*2 ] = linfo.sparam_vals[i]
end
sp = svec(sp...)
else
sp = svec()
end
VarInfo(sp, vars, gensym_types, vinflist, nl, ObjectIdDict(), linfo.module)
sp = linfo.sparam_vals
VarInfo(sp, vars, gensym_types, vinflist, nl, ObjectIdDict(), linfo.module, linfo)
end

type VarState
Expand All @@ -79,8 +70,8 @@ end
inference_stack = EmptyCallStack()

function is_static_parameter(sv::VarInfo, s::Symbol)
sp = sv.sp
for i=1:2:length(sp)
sp = sv.linfo.sparam_syms
for i=1:length(sp)
if is(sp[i],s)
return true
end
Expand Down Expand Up @@ -443,13 +434,13 @@ const apply_type_tfunc = function (A::ANY, args...)
push!(tparams, val)
continue
elseif isa(inference_stack,CallStack) && isa(A[i],Symbol)
sp = inference_stack.sv.sp
sp = inference_stack.sv.linfo.sparam_syms
s = A[i]
found = false
for j=1:2:length(sp)
for j=1:length(sp)
if is(sp[j],s)
# static parameter
val = sp[j+1]
val = inference_stack.sv.sp[j]
if valid_tparam(val)
push!(tparams, val)
found = true
Expand Down Expand Up @@ -1148,11 +1139,11 @@ end
function abstract_eval_symbol(s::Symbol, vtypes::ObjectIdDict, sv::VarInfo)
t = get(vtypes,s,NF)
if is(t,NF)
sp = sv.sp
for i=1:2:length(sp)
sp = sv.linfo.sparam_syms
for i=1:length(sp)
if is(sp[i],s)
# static parameter
val = sp[i+1]
val = sv.sp[i]
if isa(val,TypeVar)
# static param bound to typevar
if Any <: val.ub
Expand Down Expand Up @@ -1554,18 +1545,10 @@ function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector,

if length(linfo.sparam_vals) > 0
# handled by VarInfo constructor
elseif isempty(sparams) && !isempty(linfo.sparam_syms)
sv.sp = svec(Any[ TypeVar(sym, Any, true) for sym in linfo.sparam_syms ]...)
else
sp = Any[]
for i = 1:2:length(sparams)
push!(sp, sparams[i].name)
push!(sp, sparams[i+1])
end
for i = 1:length(linfo.sparam_syms)
sym = linfo.sparam_syms[i]
push!(sp, sym)
push!(sp, TypeVar(sym, Any, true))
end
sv.sp = svec(sp...)
sv.sp = sparams
end

args = f_argnames(ast)
Expand Down Expand Up @@ -2323,8 +2306,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::VarInfo,
else
spvals = Any[]
for i = 1:length(spnames)
methsp[2 * i - 1].name === spnames[i] || error("sp env in the wrong order")
si = methsp[2 * i]
si = methsp[i]
if isa(si, TypeVar)
return NF
end
Expand Down
111 changes: 46 additions & 65 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,13 @@ static jl_value_t *ml_matches(jl_methlist_t *ml, jl_value_t *type,
jl_sym_t *name, int lim);

static jl_lambda_info_t *cache_method(jl_methtable_t *mt, jl_tupletype_t *type,
jl_lambda_info_t *method, jl_tupletype_t *decl,
jl_svec_t *sparams, int8_t isstaged)
jl_lambda_info_t *method, jl_methlist_t *m,
jl_svec_t *sparams)
{
JL_LOCK(codegen); // Might GC
size_t i;
jl_tupletype_t *decl = m->sig;
int8_t isstaged = m->isstaged;
int need_guard_entries = 0;
int cache_as_orig = 0;
jl_value_t *temp=NULL;
Expand Down Expand Up @@ -719,17 +721,24 @@ static jl_lambda_info_t *cache_method(jl_methtable_t *mt, jl_tupletype_t *type,
}
else {
jl_value_t *lastdeclt = jl_tparam(decl,jl_nparams(decl)-1);
if (jl_svec_len(sparams) > 0) {
lastdeclt = (jl_value_t*)
jl_instantiate_type_with((jl_value_t*)lastdeclt,
jl_svec_data(sparams),
jl_svec_len(sparams)/2);
int nsp = jl_svec_len(sparams);
if (nsp > 0) {
temp2 = (jl_value_t*)jl_alloc_svec_uninit(2*nsp);
for(j=0; j < nsp; j++) {
if (j==0 && jl_is_typevar(m->tvars))
jl_svecset(temp2, 0, m->tvars);
else
jl_svecset(temp2, j*2, jl_svecref(m->tvars, j));
jl_svecset(temp2, j*2+1, jl_svecref(sparams,j));
}
lastdeclt = (jl_value_t*)jl_instantiate_type_with((jl_value_t*)lastdeclt,
jl_svec_data(temp2), nsp);
}
jl_svecset(limited, i, lastdeclt);
}
type = jl_apply_tuple_type(limited);
temp2 = (jl_value_t*)type;
// now there is a problem: the computed signature is more
// now there is a problem: the widened signature is more
// general than just the given arguments, so it might conflict
// with another definition that doesn't have cache instances yet.
// to fix this, we insert guard cache entries for all intersections
Expand All @@ -749,7 +758,7 @@ static jl_lambda_info_t *cache_method(jl_methtable_t *mt, jl_tupletype_t *type,
for(i=0; i < jl_array_len(temp); i++) {
jl_value_t *m = jl_cellref(temp, i);
jl_value_t *env = jl_svecref(m,1);
for(int k=1; k < jl_svec_len(env); k+=2) {
for(int k=0; k < jl_svec_len(env); k++) {
if (jl_is_typevar(jl_svecref(env,k))) {
unmatched_tvars = 1; break;
}
Expand Down Expand Up @@ -798,11 +807,8 @@ static jl_lambda_info_t *cache_method(jl_methtable_t *mt, jl_tupletype_t *type,
JL_UNLOCK(codegen);
return newmeth;
}
jl_svec_t *sparam_vals = jl_svec_len(sparams) == 0 ? jl_emptysvec : jl_alloc_svec_uninit(jl_svec_len(sparams)/2);
for (int i = 0; i < jl_svec_len(sparam_vals); i++) {
jl_svecset(sparam_vals, i, jl_svecref(sparams, i * 2 + 1));
}
newmeth = jl_add_static_parameters(method, sparam_vals, type);

newmeth = jl_add_static_parameters(method, sparams, type);

if (cache_as_orig)
(void)jl_method_cache_insert(mt, origtype, newmeth);
Expand Down Expand Up @@ -837,49 +843,24 @@ static jl_value_t *lookup_match(jl_value_t *a, jl_value_t *b, jl_svec_t **penv,
return ti;
JL_GC_PUSH1(&ti);
assert(jl_is_svec(*penv));
jl_value_t **ee = (jl_value_t**)alloca(sizeof(void*) * jl_svec_len(*penv));
int n=0;
// only keep vars in tvars list
jl_value_t **tvs;
int tvarslen;
if (jl_is_typevar(tvars)) {
tvs = (jl_value_t**)&tvars;
tvarslen = 1;
}
else {
tvs = jl_svec_data(tvars);
tvarslen = jl_svec_len(tvars);
}
int l = jl_svec_len(*penv);
for(int i=0; i < l; i+=2) {
jl_value_t *v = jl_svecref(*penv,i);
jl_value_t *val = jl_svecref(*penv,i+1);
for(int j=0; j < tvarslen; j++) {
if (v == tvs[j]) {
ee[n++] = v;
ee[n++] = val;
/*
since "a" is a concrete type, we assume that
(a∩b != Union{}) => a<:b. However if a static parameter is
forced to equal Union{}, then part of "b" might become Union{},
and therefore a subtype of "a". For example
(Type{Union{}},Int) ∩ (Type{T},T)
issue #5254
*/
if (val == (jl_value_t*)jl_bottom_type) {
if (!jl_subtype(a, ti, 0)) {
JL_GC_POP();
return (jl_value_t*)jl_bottom_type;
}
}
for(int i=0; i < l; i++) {
jl_value_t *val = jl_svecref(*penv,i);
/*
since "a" is a concrete type, we assume that
(a∩b != Union{}) => a<:b. However if a static parameter is
forced to equal Union{}, then part of "b" might become Union{},
and therefore a subtype of "a". For example
(Type{Union{}},Int) ∩ (Type{T},T)
issue #5254
*/
if (val == (jl_value_t*)jl_bottom_type) {
if (!jl_subtype(a, ti, 0)) {
JL_GC_POP();
return (jl_value_t*)jl_bottom_type;
}
}
}
if (n != l) {
jl_svec_t *en = jl_alloc_svec_uninit(n);
memcpy(jl_svec_data(en), ee, n*sizeof(void*));
*penv = en;
}
JL_GC_POP();
return ti;
}
Expand All @@ -903,13 +884,9 @@ JL_DLLEXPORT jl_lambda_info_t *jl_instantiate_staged(jl_lambda_info_t *generator
{
jl_expr_t *ex = NULL;
jl_value_t *linenum = NULL;
jl_svec_t *sparam_vals = NULL;
jl_svec_t *sparam_vals = env;
JL_GC_PUSH3(&ex, &linenum, &sparam_vals);

sparam_vals = jl_svec_len(env) == 0 ? jl_emptysvec : jl_alloc_svec_uninit(jl_svec_len(env)/2);
for (int i = 0; i < jl_svec_len(sparam_vals); i++) {
jl_svecset(sparam_vals, i, jl_svecref(env, i * 2 + 1));
}
assert(generator->sparam_vals == jl_emptysvec);
assert(jl_svec_len(generator->sparam_syms) == jl_svec_len(sparam_vals));
assert(generator->unspecialized == NULL && generator->specTypes == jl_anytuple_type);
Expand Down Expand Up @@ -974,7 +951,7 @@ static jl_lambda_info_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_datatype_t *
if (ti != (jl_value_t*)jl_bottom_type) {
// parametric methods only match if all typevars are matched by
// non-typevars.
for(i=1; i < jl_svec_len(env); i+=2) {
for(i=0; i < jl_svec_len(env); i++) {
if (jl_is_typevar(jl_svecref(env,i))) {
if (inexact) {
// "inexact" means the given type is compile-time,
Expand Down Expand Up @@ -1010,7 +987,7 @@ static jl_lambda_info_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_datatype_t *
// make sure the argument is rooted in `cache_method`
// in case another thread changed it.
newsig = m->sig;
jl_lambda_info_t *res = cache_method(mt, tt, func, m->sig, jl_emptysvec, m->isstaged);
jl_lambda_info_t *res = cache_method(mt, tt, func, m, jl_emptysvec);
JL_GC_POP();
return res;
}
Expand Down Expand Up @@ -1039,7 +1016,7 @@ static jl_lambda_info_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_datatype_t *
if (!cache)
nf = func;
else
nf = cache_method(mt, tt, func, newsig, env, m->isstaged);
nf = cache_method(mt, tt, func, m, env);
JL_GC_POP();
return nf;
}
Expand Down Expand Up @@ -1978,8 +1955,7 @@ jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs
if (m->isstaged)
func = jl_instantiate_staged(func, tt, tpenv);

newsig = m->sig;
mfunc = cache_method(m->invokes, tt, func, newsig, tpenv, m->isstaged);
mfunc = cache_method(m->invokes, tt, func, m, tpenv);
}

JL_GC_POP();
Expand Down Expand Up @@ -2169,11 +2145,16 @@ static jl_value_t *ml_matches(jl_methlist_t *ml, jl_value_t *type,
*/
int matched_all_typevars = 1;
size_t l = jl_svec_len(env);
for(i=1; i < l; i+=2) {
for(i=0; i < l; i++) {
jl_value_t *tv;
if (jl_is_typevar(ml->tvars))
tv = (jl_value_t*)ml->tvars;
else
tv = jl_svecref(ml->tvars, i);
if (jl_is_typevar(jl_svecref(env,i)) &&
// if tvar is at the top level it will definitely be matched.
// see issue #5575
!tvar_exists_at_top_level(jl_svecref(env,i-1), ml->sig, 1)) {
!tvar_exists_at_top_level(tv, ml->sig, 1)) {
matched_all_typevars = 0;
break;
}
Expand Down
5 changes: 2 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1584,13 +1584,12 @@ jl_value_t *jl_type_intersection_matching(jl_value_t *a, jl_value_t *b,
}

// return environment in same order as tvars
*penv = jl_alloc_svec_uninit(tvarslen*2);
*penv = jl_alloc_svec_uninit(tvarslen);
for(int tk=0; tk < tvarslen; tk++) {
jl_tvar_t *tv = (jl_tvar_t*)tvs[tk];
for(e=0; e < eqc.n; e+=2) {
if (eqc.data[e] == (jl_value_t*)tv) {
jl_svecset(*penv, tk*2, tv);
jl_svecset(*penv, tk*2+1, eqc.data[e+1]);
jl_svecset(*penv, tk, eqc.data[e+1]);
}
}
}
Expand Down

0 comments on commit dceac08

Please sign in to comment.