Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: experiments in diagonal constraints in jl_args_morespecific (ambiguity/method sorting) #16276

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions base/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,18 @@

## native julia error handling ##

error(s::AbstractString) = throw(Main.Base.ErrorException(s))
error(s...) = throw(Main.Base.ErrorException(Main.Base.string(s...)))
if isdefined(Main, :Base)
error(s::AbstractString) = throw(Main.Base.ErrorException(s))
error(s...) = throw(Main.Base.ErrorException(Main.Base.string(s...)))
else
error(s::AbstractString) = (print(s); throw(InterruptException()))
function error(ss...)
print("Error with ")
print(length(ss))
print(" arguments")
throw(InterruptException())
end
end

rethrow() = ccall(:jl_rethrow, Bottom, ())
rethrow(e) = ccall(:jl_rethrow_other, Bottom, (Any,), e)
Expand Down
2 changes: 1 addition & 1 deletion base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ end
find_vars(e) = find_vars(e, [])
function find_vars(e, lst)
if isa(e,Symbol)
if current_module()===Main && isdefined(e)
if current_module()===Main && isdefined(Main,:Base) && isdefined(e)
# Main runs on process 1, so send globals from there, excluding
# things defined in Base.
if !isdefined(Base,e) || eval(Base,e)!==eval(current_module(),e)
Expand Down
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ current_module() = ccall(:jl_get_current_module, Ref{Module}, ())

function fullname(m::Module)
m === Main && return ()
m === Base && return (:Base,) # issue #10653
isdefined(Main, :Base) && m === Base && return (:Base,) # issue #10653
mn = module_name(m)
mp = module_parent(m)
if mp === m
Expand Down
212 changes: 203 additions & 9 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1448,14 +1448,15 @@ static int solve_tvar_constraints(cenv_t *env, cenv_t *soln, jl_value_t **tvs, i
return 0;
}

jl_value_t *jl_type_intersection_matching(jl_value_t *a, jl_value_t *b,
jl_svec_t **penv, jl_svec_t *tvars)
jl_value_t *jl_type_intersection_matching_(jl_value_t *a, jl_value_t *b,
jl_svec_t **penv, jl_svec_t **tvars,
jl_svec_t *(*tvarsel)(jl_svec_t*,cenv_t*,cenv_t*))
{
jl_value_t **rts;
JL_GC_PUSHARGS(rts, 2 + 2*MAX_CENV_SIZE);
cenv_t eqc; eqc.n = 0; eqc.data = &rts[2];
cenv_t env; env.n = 0; env.data = &rts[2+MAX_CENV_SIZE];
eqc.tvars = tvars; env.tvars = tvars;
eqc.tvars = *tvars; env.tvars = *tvars;
jl_value_t **pti = &rts[0];
jl_value_t **extraroot = &rts[1];

Expand All @@ -1472,7 +1473,7 @@ jl_value_t *jl_type_intersection_matching(jl_value_t *a, jl_value_t *b,
*pti = (jl_value_t*)jl_bottom_type;
}
if (*pti == (jl_value_t*)jl_bottom_type ||
!(env.n > 0 || eqc.n > 0 || tvars != jl_emptysvec)) {
!(env.n > 0 || eqc.n > 0 || *tvars != jl_emptysvec)) {
JL_GC_POP();
return *pti;
}
Expand Down Expand Up @@ -1505,14 +1506,15 @@ jl_value_t *jl_type_intersection_matching(jl_value_t *a, jl_value_t *b,

jl_value_t **tvs;
int tvarslen;
if (jl_is_typevar(tvars)) {
tvs = (jl_value_t**)&tvars;
if (jl_is_typevar(*tvars)) {
tvs = (jl_value_t**)tvars;
tvarslen = 1;
}
else {
assert(jl_is_svec(tvars));
tvs = jl_svec_data(tvars);
tvarslen = jl_svec_len(tvars);
assert(jl_is_svec(*tvars));
*tvars = tvarsel(*tvars, &env, &eqc);
tvs = jl_svec_data(*tvars);
tvarslen = jl_svec_len(*tvars);
}

if (!solve_tvar_constraints(&env, &eqc, tvs, tvarslen)) {
Expand Down Expand Up @@ -1613,6 +1615,57 @@ jl_value_t *jl_type_intersection_matching(jl_value_t *a, jl_value_t *b,
return *pti;
}

// Use the tvars specified by the caller
jl_svec_t *tvarsel_caller(jl_svec_t *tvars, cenv_t *env, cenv_t *eqc)
{
return tvars;
}

// append all unique bound tvars
static void append_unique(jl_array_t *unique, cenv_t *env)
{
for (size_t i = 0; i < env->n; i+=2) {
jl_value_t *tvar = env->data[i];
if (!jl_is_typevar(tvar)) continue;
jl_tvar_t *tv = (jl_tvar_t*) tvar;
if (!tv->bound) continue;
int isunique = 1;
for (size_t j = 0; j < jl_array_len(unique); j++) {
if (jl_cellref(unique, j) == tvar) {
isunique = 0;
break;
}
}
if (isunique) {
jl_cell_1d_push(unique, tvar);
}
}
}

// Extract all unique bound tvars from both env and eqc
jl_svec_t *tvarsel_all(jl_svec_t *tvars, cenv_t *env, cenv_t *eqc)
{
jl_array_t *unique = NULL;
jl_svec_t *ret = NULL;
JL_GC_PUSH2(&unique, &ret);
unique = jl_alloc_cell_1d(0);
append_unique(unique, env);
append_unique(unique, eqc);
int n = jl_array_len(unique);
ret = jl_alloc_svec_uninit(n);
for (size_t i = 0; i < n; i++) {
jl_svecset(ret, i, jl_cellref(unique, i));
}
JL_GC_POP();
return ret;
}

jl_value_t *jl_type_intersection_matching(jl_value_t *a, jl_value_t *b,
jl_svec_t **penv, jl_svec_t *tvars)
{
return jl_type_intersection_matching_(a, b, penv, &tvars, tvarsel_caller);
}

// --- type instantiation and cache ---

static int extensionally_same_type(jl_value_t *a, jl_value_t *b)
Expand Down Expand Up @@ -2632,6 +2685,21 @@ static int type_eqv_with_ANY(jl_value_t *a, jl_value_t *b)

static int jl_type_morespecific_(jl_value_t *a, jl_value_t *b, int invariant);

jl_datatype_t *jl_fix_vararg_bound(jl_datatype_t *tt, int nfix)
{
assert(jl_is_va_tuple(tt));
assert(nfix >= 0);
jl_svec_t *tp = tt->parameters;
size_t ntp = jl_svec_len(tp);
jl_value_t *env[2] = {NULL, NULL};
JL_GC_PUSH2(env[0], env[1]);
env[0] = jl_tparam1(jl_tparam(tt, ntp-1));
env[1] = jl_box_long(nfix);
jl_datatype_t *ret = (jl_datatype_t*)jl_instantiate_type_with((jl_value_t*)tt, env, 2);
JL_GC_POP();
return ret;
}

static int jl_tuple_morespecific(jl_datatype_t *cdt, jl_datatype_t *pdt, int invariant)
{
size_t clenr = jl_nparams(cdt);
Expand Down Expand Up @@ -2843,6 +2911,132 @@ JL_DLLEXPORT int jl_type_morespecific(jl_value_t *a, jl_value_t *b)
return jl_type_morespecific_(a, b, 0);
}

int jl_args_morespecific_(jl_value_t *a, jl_value_t *b)
{
int msp = jl_type_morespecific(a,b);
int btv = jl_has_typevars(b);
if (btv) {
if (jl_type_match_morespecific(a,b) == (jl_value_t*)jl_false) {
if (jl_has_typevars(a))
return 0;
return msp;
}
if (jl_has_typevars(a)) {
type_match_invariance_mask = 0;
//int result = jl_type_match_morespecific(b,a) == (jl_value_t*)jl_false);
// this rule seems to work better:
int result = jl_type_match(b,a) == (jl_value_t*)jl_false;
type_match_invariance_mask = 1;
if (result)
return 1;
}
int nmsp = jl_type_morespecific(b,a);
if (nmsp == msp)
return 0;
}
if (jl_has_typevars((jl_value_t*)a)) {
int nmsp = jl_type_morespecific(b,a);
if (nmsp && msp)
return 1;
if (!btv && jl_types_equal(a,b))
return 1;
if (jl_type_match_morespecific(b,a) != (jl_value_t*)jl_false)
return 0;
}
return msp;
}

// Called when a is a bound-vararg and b has known length. Sets the
// vararg length in a to match b, as long as this makes some earlier
// argument more specific.
int jl_args_morespecific_fix1(jl_value_t *a, jl_value_t *b, int nfix, int swap)
{
assert(nfix >= 0);
jl_datatype_t *tta = (jl_datatype_t*)a;
jl_datatype_t *newtta = jl_fix_vararg_bound(tta, nfix);
int changed = 0;
for (size_t i = 0; i < jl_nparams(tta)-1; i++) {
if (jl_tparam(tta, i) != jl_tparam(newtta, i)) {
changed = 1;
break;
}
}
if (changed) {
JL_GC_PUSH1(&newtta);
int ret;
if (swap)
ret = jl_args_morespecific_(b, (jl_value_t*)newtta);
else
ret = jl_args_morespecific_((jl_value_t*)newtta, b);
JL_GC_POP();
return ret;
}
if (swap)
return jl_args_morespecific_(b, a);
return jl_args_morespecific_(a, b);
}

JL_DLLEXPORT int jl_args_morespecific(jl_value_t *a, jl_value_t *b)
{
jl_value_t *atv = NULL;
jl_value_t *btv = NULL;
jl_svec_t *penv = jl_emptysvec;
jl_svec_t *tvars = jl_emptysvec;
jl_svec_t *inst = NULL;
JL_GC_PUSH5(&atv, &btv, &penv, &tvars, &inst);
if (jl_has_typevars(a) || jl_has_typevars(b)) {
// Nail down any type parameters in terms of the intersection
// of the signatures.
jl_value_t *ti = jl_type_intersection_matching_(a, b, &penv, &tvars, tvarsel_all);
if (ti == jl_bottom_type) {
JL_GC_POP();
return 0;
}
int n = jl_svec_len(tvars);
inst = jl_alloc_svec_uninit(2*n);
for (size_t i = 0; i < n; i++) {
jl_svecset(inst, 2*i, jl_svecref(tvars, i));
jl_svecset(inst, 2*i+1, jl_svecref(penv, i));
}
/*
jl_(a);
jl_(b);
jl_(ti);
jl_(inst);
*/
atv = jl_instantiate_type_with(a, jl_svec_data(inst), n);
btv = jl_instantiate_type_with(b, jl_svec_data(inst), n);
if (atv != btv) {
a = atv;
b = btv;
}
}
int ret = -1;
if (jl_is_tuple_type(a) && jl_is_tuple_type(b)) {
jl_datatype_t *tta = (jl_datatype_t*)a;
jl_datatype_t *ttb = (jl_datatype_t*)b;
size_t alenf, blenf;
jl_vararg_kind_t akind, bkind;
jl_tuple_lenkind_t alenkind, blenkind;
alenf = tuple_vararg_params(tta->parameters, NULL, &akind, &alenkind);
blenf = tuple_vararg_params(ttb->parameters, NULL, &bkind, &blenkind);
// When one is JL_VARARG_BOUND and the other has fixed length,
// allow the argument length to fix the tvar
if (akind == JL_VARARG_BOUND && blenkind == JL_TUPLE_FIXED && blenf >= alenf) {
ret = jl_args_morespecific_fix1(a, b, blenf-alenf+1, 0);
}
else if (bkind == JL_VARARG_BOUND && alenkind == JL_TUPLE_FIXED && alenf >= blenf) {
ret = jl_args_morespecific_fix1(b, a, alenf-blenf+1, 1);
}
if (ret != -1) {
JL_GC_POP();
return ret;
}
}
ret = jl_args_morespecific_(a, b);
JL_GC_POP();
return ret;
}

// ----------------------------------------------------------------------------

Expand Down
35 changes: 0 additions & 35 deletions src/typemap.c
Original file line number Diff line number Diff line change
Expand Up @@ -884,41 +884,6 @@ jl_typemap_entry_t *jl_typemap_insert(union jl_typemap_t *cache, jl_value_t *par
return newrec;
}

JL_DLLEXPORT int jl_args_morespecific(jl_value_t *a, jl_value_t *b)
{
int msp = jl_type_morespecific(a,b);
int btv = jl_has_typevars(b);
if (btv) {
if (jl_type_match_morespecific(a,b) == (jl_value_t*)jl_false) {
if (jl_has_typevars(a))
return 0;
return msp;
}
if (jl_has_typevars(a)) {
type_match_invariance_mask = 0;
//int result = jl_type_match_morespecific(b,a) == (jl_value_t*)jl_false);
// this rule seems to work better:
int result = jl_type_match(b,a) == (jl_value_t*)jl_false;
type_match_invariance_mask = 1;
if (result)
return 1;
}
int nmsp = jl_type_morespecific(b,a);
if (nmsp == msp)
return 0;
}
if (jl_has_typevars((jl_value_t*)a)) {
int nmsp = jl_type_morespecific(b,a);
if (nmsp && msp)
return 1;
if (!btv && jl_types_equal(a,b))
return 1;
if (jl_type_match_morespecific(b,a) != (jl_value_t*)jl_false)
return 0;
}
return msp;
}

static int has_unions(jl_tupletype_t *type)
{
int i;
Expand Down
12 changes: 12 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,18 @@ let T = TypeVar(:T, Tuple{Vararg{RangeIndex}}, true)
@test args_morespecific(t2, t1)
end

let T = TypeVar(:T, Any, true), N = TypeVar(:N, Any, true)
a = Tuple{Array{T,N}, Vararg{Int,N}}
b = Tuple{Array,Int}
@test args_morespecific(a, b)
@test !args_morespecific(b, a)
a = Tuple{Array, Vararg{Int,N}}
@test !args_morespecific(a, b)
@test args_morespecific(b, a)
end

# with bound varargs

# issue #11840
f11840(::Type) = "Type"
f11840(::DataType) = "DataType"
Expand Down