Skip to content

Commit d295d7c

Browse files
committed
Widen diagonal var during Type unwrapping in instanceof_tfunc (#52228)
close #52168 close #27031
1 parent 0b15b44 commit d295d7c

File tree

5 files changed

+245
-4
lines changed

5 files changed

+245
-4
lines changed

base/compiler/tfuncs.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -95,25 +95,31 @@ add_tfunc(throw, 1, 1, @nospecs((𝕃::AbstractLattice, x)->Bottom), 0)
9595
# if isexact is false, the actual runtime type may (will) be a subtype of t
9696
# if isconcrete is true, the actual runtime type is definitely concrete (unreachable if not valid as a typeof)
9797
# if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int})
98-
function instanceof_tfunc(@nospecialize(t))
98+
function instanceof_tfunc(@nospecialize(t), @nospecialize(troot) = t)
9999
if isa(t, Const)
100100
if isa(t.val, Type) && valid_as_lattice(t.val)
101101
return t.val, true, isconcretetype(t.val), true
102102
end
103103
return Bottom, true, false, false # runtime throws on non-Type
104104
end
105105
t = widenconst(t)
106+
troot = widenconst(troot)
106107
if t === Bottom
107108
return Bottom, true, true, false # runtime unreachable
108109
elseif t === typeof(Bottom) || !hasintersect(t, Type)
109110
return Bottom, true, false, false # literal Bottom or non-Type
110111
elseif isType(t)
111112
tp = t.parameters[1]
112113
valid_as_lattice(tp) || return Bottom, true, false, false # runtime unreachable / throws on non-Type
114+
if troot isa UnionAll
115+
# Free `TypeVar`s inside `Type` has violated the "diagonal" rule.
116+
# Widen them before `UnionAll` rewraping to relax concrete constraint.
117+
tp = widen_diagonal(tp, troot)
118+
end
113119
return tp, !has_free_typevars(tp), isconcretetype(tp), true
114120
elseif isa(t, UnionAll)
115121
t′ = unwrap_unionall(t)
116-
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′)
122+
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, rewrap_unionall(t, troot))
117123
tr = rewrap_unionall(t′′, t)
118124
if t′′ isa DataType && t′′.name !== Tuple.name && !has_free_typevars(tr)
119125
# a real instance must be within the declared bounds of the type,
@@ -128,8 +134,8 @@ function instanceof_tfunc(@nospecialize(t))
128134
end
129135
return tr, isexact, isconcrete, istype
130136
elseif isa(t, Union)
131-
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a)
132-
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b)
137+
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, troot)
138+
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, troot)
133139
isconcrete = isconcrete_a && isconcrete_b
134140
istype = istype_a && istype_b
135141
# most users already handle the Union case, so here we assume that

base/essentials.jl

+5
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,11 @@ function rename_unionall(@nospecialize(u))
411411
return UnionAll(nv, body{nv})
412412
end
413413

414+
# remove concrete constraint on diagonal TypeVar if it comes from troot
415+
function widen_diagonal(@nospecialize(t), troot::UnionAll)
416+
body = ccall(:jl_widen_diagonal, Any, (Any, Any), t, troot)
417+
end
418+
414419
function isvarargtype(@nospecialize(t))
415420
return isa(t, Core.TypeofVararg)
416421
end

src/subtype.c

+206
Original file line numberDiff line numberDiff line change
@@ -4304,6 +4304,212 @@ int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv)
43044304
return sub;
43054305
}
43064306

4307+
// type utils
4308+
static void check_diagonal(jl_value_t *t, jl_varbinding_t *troot, int param)
4309+
{
4310+
if (jl_is_uniontype(t)) {
4311+
int i, len = 0;
4312+
jl_varbinding_t *v;
4313+
for (v = troot; v != NULL; v = v->prev)
4314+
len++;
4315+
int8_t *occurs = (int8_t *)alloca(len);
4316+
for (v = troot, i = 0; v != NULL; v = v->prev, i++)
4317+
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
4318+
check_diagonal(((jl_uniontype_t *)t)->a, troot, param);
4319+
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
4320+
int8_t occurs_inv = occurs[i] & 3;
4321+
int8_t occurs_cov = occurs[i] >> 2;
4322+
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
4323+
v->occurs_inv = occurs_inv;
4324+
v->occurs_cov = occurs_cov;
4325+
}
4326+
check_diagonal(((jl_uniontype_t *)t)->b, troot, param);
4327+
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
4328+
if (v->occurs_inv < (occurs[i] & 3))
4329+
v->occurs_inv = occurs[i] & 3;
4330+
if (v->occurs_cov < (occurs[i] >> 2))
4331+
v->occurs_cov = occurs[i] >> 2;
4332+
}
4333+
}
4334+
else if (jl_is_unionall(t)) {
4335+
assert(troot != NULL);
4336+
jl_varbinding_t *v1 = troot, *v2 = troot->prev;
4337+
while (v2 != NULL) {
4338+
if (v2->var == ((jl_unionall_t *)t)->var) {
4339+
v1->prev = v2->prev;
4340+
break;
4341+
}
4342+
v1 = v2;
4343+
v2 = v2->prev;
4344+
}
4345+
check_diagonal(((jl_unionall_t *)t)->body, troot, param);
4346+
v1->prev = v2;
4347+
}
4348+
else if (jl_is_datatype(t)) {
4349+
int nparam = jl_is_tuple_type(t) ? 1 : 2;
4350+
if (nparam < param) nparam = param;
4351+
for (size_t i = 0; i < jl_nparams(t); i++) {
4352+
check_diagonal(jl_tparam(t, i), troot, nparam);
4353+
}
4354+
}
4355+
else if (jl_is_vararg(t)) {
4356+
jl_value_t *T = jl_unwrap_vararg(t);
4357+
jl_value_t *N = jl_unwrap_vararg_num(t);
4358+
int n = (N && jl_is_long(N)) ? jl_unbox_long(N) : 2;
4359+
if (T && n > 0) check_diagonal(T, troot, param);
4360+
if (T && n > 1) check_diagonal(T, troot, param);
4361+
if (N) check_diagonal(N, troot, 2);
4362+
}
4363+
else if (jl_is_typevar(t)) {
4364+
jl_varbinding_t *v = troot;
4365+
for (; v != NULL; v = v->prev) {
4366+
if (v->var == (jl_tvar_t *)t) {
4367+
if (param == 1 && v->occurs_cov < 2) v->occurs_cov++;
4368+
if (param == 2 && v->occurs_inv < 2) v->occurs_inv++;
4369+
break;
4370+
}
4371+
}
4372+
if (v == NULL)
4373+
check_diagonal(((jl_tvar_t *)t)->ub, troot, 0);
4374+
}
4375+
}
4376+
4377+
static jl_value_t *insert_nondiagonal(jl_value_t *type, jl_varbinding_t *troot, int widen2ub)
4378+
{
4379+
if (jl_is_typevar(type)) {
4380+
int concretekind = widen2ub > 1 ? 0 : 1;
4381+
jl_varbinding_t *v = troot;
4382+
for (; v != NULL; v = v->prev) {
4383+
if (v->occurs_inv == 0 &&
4384+
v->occurs_cov > concretekind &&
4385+
v->var == (jl_tvar_t *)type)
4386+
break;
4387+
}
4388+
if (v != NULL) {
4389+
if (widen2ub) {
4390+
type = insert_nondiagonal(((jl_tvar_t *)type)->ub, troot, 2);
4391+
}
4392+
else {
4393+
// we must replace each covariant occurrence of newvar with a different newvar2<:newvar (diagonal rule)
4394+
if (v->innervars == NULL)
4395+
v->innervars = jl_alloc_array_1d(jl_array_any_type, 0);
4396+
jl_value_t *newvar = NULL, *lb = v->var->lb, *ub = (jl_value_t *)v->var;
4397+
jl_array_t *innervars = v->innervars;
4398+
JL_GC_PUSH4(&newvar, &lb, &ub, &innervars);
4399+
newvar = (jl_value_t *)jl_new_typevar(v->var->name, lb, ub);
4400+
jl_array_ptr_1d_push(innervars, newvar);
4401+
JL_GC_POP();
4402+
type = newvar;
4403+
}
4404+
}
4405+
}
4406+
else if (jl_is_unionall(type)) {
4407+
jl_value_t *body = ((jl_unionall_t*)type)->body;
4408+
jl_tvar_t *var = ((jl_unionall_t*)type)->var;
4409+
jl_varbinding_t *v = troot;
4410+
for (; v != NULL; v = v->prev) {
4411+
if (v->var == var)
4412+
break;
4413+
}
4414+
if (v) v->var = NULL; // Temporarily remove `type->var` from binding list.
4415+
jl_value_t *newbody = insert_nondiagonal(body, troot, widen2ub);
4416+
if (v) v->var = var; // And restore it after inner insertation.
4417+
jl_value_t *newvar = NULL;
4418+
JL_GC_PUSH2(&newbody, &newvar);
4419+
if (body == newbody || jl_has_typevar(newbody, var)) {
4420+
if (body != newbody)
4421+
newbody = jl_new_struct(jl_unionall_type, var, newbody);
4422+
// n.b. we do not widen lb, since that would be the wrong direction
4423+
newvar = insert_nondiagonal(var->ub, troot, widen2ub);
4424+
if (newvar != var->ub) {
4425+
newvar = (jl_value_t*)jl_new_typevar(var->name, var->lb, newvar);
4426+
newbody = jl_apply_type1(newbody, newvar);
4427+
newbody = jl_type_unionall((jl_tvar_t*)newvar, newbody);
4428+
}
4429+
}
4430+
type = newbody;
4431+
JL_GC_POP();
4432+
}
4433+
else if (jl_is_uniontype(type)) {
4434+
jl_value_t *a = ((jl_uniontype_t*)type)->a;
4435+
jl_value_t *b = ((jl_uniontype_t*)type)->b;
4436+
jl_value_t *newa = NULL;
4437+
jl_value_t *newb = NULL;
4438+
JL_GC_PUSH2(&newa, &newb);
4439+
newa = insert_nondiagonal(a, troot, widen2ub);
4440+
newb = insert_nondiagonal(b, troot, widen2ub);
4441+
if (newa != a || newb != b)
4442+
type = simple_union(newa, newb);
4443+
JL_GC_POP();
4444+
}
4445+
else if (jl_is_vararg(type)) {
4446+
// As for Vararg we'd better widen it's var to ub as otherwise they are still diagonal
4447+
jl_value_t *t = jl_unwrap_vararg(type);
4448+
jl_value_t *n = jl_unwrap_vararg_num(type);
4449+
if (widen2ub == 0)
4450+
widen2ub = !(n && jl_is_long(n)) || jl_unbox_long(n) > 1;
4451+
jl_value_t *newt;
4452+
JL_GC_PUSH2(&newt, &n);
4453+
newt = insert_nondiagonal(t, troot, widen2ub);
4454+
if (t != newt)
4455+
type = (jl_value_t *)jl_wrap_vararg(newt, n, 0);
4456+
JL_GC_POP();
4457+
}
4458+
else if (jl_is_datatype(type)) {
4459+
if (jl_is_tuple_type(type)) {
4460+
jl_svec_t *newparams = NULL;
4461+
jl_value_t *newelt = NULL;
4462+
JL_GC_PUSH2(&newparams, &newelt);
4463+
for (size_t i = 0; i < jl_nparams(type); i++) {
4464+
jl_value_t *elt = jl_tparam(type, i);
4465+
newelt = insert_nondiagonal(elt, troot, widen2ub);
4466+
if (elt != newelt) {
4467+
if (!newparams)
4468+
newparams = jl_svec_copy(((jl_datatype_t*)type)->parameters);
4469+
jl_svecset(newparams, i, newelt);
4470+
}
4471+
}
4472+
if (newparams)
4473+
type = (jl_value_t*)jl_apply_tuple_type(newparams);
4474+
JL_GC_POP();
4475+
}
4476+
}
4477+
return type;
4478+
}
4479+
4480+
static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {
4481+
check_diagonal(t, troot, 0);
4482+
int any_concrete = 0;
4483+
for (jl_varbinding_t *v = troot; v != NULL; v = v->prev)
4484+
any_concrete |= v->occurs_cov > 1 && v->occurs_inv == 0;
4485+
if (!any_concrete)
4486+
return t; // no diagonal
4487+
return insert_nondiagonal(t, troot, 0);
4488+
}
4489+
4490+
static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
4491+
{
4492+
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
4493+
jl_value_t *nt;
4494+
JL_GC_PUSH2(&vb.innervars, &nt);
4495+
if (jl_is_unionall(u->body))
4496+
nt = widen_diagonal(t, (jl_unionall_t *)u->body, &vb);
4497+
else
4498+
nt = _widen_diagonal(t, &vb);
4499+
if (vb.innervars != NULL) {
4500+
for (size_t i = 0; i < jl_array_nrows(vb.innervars); i++) {
4501+
jl_tvar_t *var = (jl_tvar_t*)jl_array_ptr_ref(vb.innervars, i);
4502+
nt = jl_type_unionall(var, nt);
4503+
}
4504+
}
4505+
JL_GC_POP();
4506+
return nt;
4507+
}
4508+
4509+
JL_DLLEXPORT jl_value_t *jl_widen_diagonal(jl_value_t *t, jl_unionall_t *ua)
4510+
{
4511+
return widen_diagonal(t, ua, NULL);
4512+
}
43074513

43084514
// specificity comparison
43094515

test/compiler/inference.jl

+13
Original file line numberDiff line numberDiff line change
@@ -5129,3 +5129,16 @@ let TV = TypeVar(:T)
51295129
some = Some{Any}((TV, t))
51305130
@test abstract_call_unionall_vararg(some) isa UnionAll
51315131
end
5132+
5133+
# Issue #52168
5134+
f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type}
5135+
@test f52168((1, 2.), Any) === (1, 2.)
5136+
5137+
# Issue #27031
5138+
let x = 1, _Any = Any
5139+
@noinline bar27031(tt::Tuple{T,T}, ::Type{Val{T}}) where {T} = notsame27031(tt)
5140+
@noinline notsame27031(tt::Tuple{T, T}) where {T} = error()
5141+
@noinline notsame27031(tt::Tuple{T, S}) where {T, S} = "OK"
5142+
foo27031() = bar27031((x, 1.0), Val{_Any})
5143+
@test foo27031() == "OK"
5144+
end

test/core.jl

+11
Original file line numberDiff line numberDiff line change
@@ -8057,3 +8057,14 @@ end
80578057
# `SimpleVector`-operations should be concrete-eval eligible
80588058
@test Core.Compiler.is_foldable(Base.infer_effects(length, (Core.SimpleVector,)))
80598059
@test Core.Compiler.is_foldable(Base.infer_effects(getindex, (Core.SimpleVector,Int)))
8060+
8061+
# Test correctness of widen_diagonal
8062+
let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base.unwrap_unionall(x), x), x),
8063+
check_widen_diagonal(x, y) = !<:(x, y) && x <: widen_diagonal(y)
8064+
@test Tuple{Int,Float64} <: widen_diagonal(NTuple)
8065+
@test Tuple{Int,Float64} <: widen_diagonal(Tuple{T,T} where {T})
8066+
@test Tuple{Real,Int,Float64} <: widen_diagonal(Tuple{S,Vararg{T}} where {S, T<:S})
8067+
@test Tuple{Int,Int,Float64,Float64} <: widen_diagonal(Tuple{S,S,Vararg{T}} where {S, T<:S})
8068+
@test Union{Tuple{T}, Tuple{T,Int}} where {T} === widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T})
8069+
@test Tuple === widen_diagonal(Union{Tuple{Vararg{S}}, Tuple{Vararg{T}}} where {S, T})
8070+
end

0 commit comments

Comments
 (0)