From a624d445c02cc2928f8b6a34108aa78fdbb13121 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 23 Nov 2023 15:22:28 +0800 Subject: [PATCH] Widen diagonal var during `Type` unwrapping in `instanceof_tfunc` (#52228) close #52168 close #27031 --- base/compiler/tfuncs.jl | 14 ++- base/essentials.jl | 5 + src/subtype.c | 206 +++++++++++++++++++++++++++++++++++++ test/compiler/inference.jl | 13 +++ test/core.jl | 11 ++ 5 files changed, 245 insertions(+), 4 deletions(-) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 97d8bf90c060d..c30f4a1a237e1 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -95,7 +95,7 @@ add_tfunc(throw, 1, 1, @nospecs((𝕃::AbstractLattice, x)->Bottom), 0) # if isexact is false, the actual runtime type may (will) be a subtype of t # if isconcrete is true, the actual runtime type is definitely concrete (unreachable if not valid as a typeof) # if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int}) -function instanceof_tfunc(@nospecialize(t), astag::Bool=false) +function instanceof_tfunc(@nospecialize(t), astag::Bool=false, @nospecialize(troot) = t) if isa(t, Const) if isa(t.val, Type) && valid_as_lattice(t.val, astag) return t.val, true, isconcretetype(t.val), true @@ -103,6 +103,7 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false) return Bottom, true, false, false # runtime throws on non-Type end t = widenconst(t) + troot = widenconst(troot) if t === Bottom return Bottom, true, true, false # runtime unreachable elseif t === typeof(Bottom) || !hasintersect(t, Type) @@ -110,10 +111,15 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false) elseif isType(t) tp = t.parameters[1] valid_as_lattice(tp, astag) || return Bottom, true, false, false # runtime unreachable / throws on non-Type + if troot isa UnionAll + # Free `TypeVar`s inside `Type` has violated the "diagonal" rule. + # Widen them before `UnionAll` rewraping to relax concrete constraint. + tp = widen_diagonal(tp, troot) + end return tp, !has_free_typevars(tp), isconcretetype(tp), true elseif isa(t, UnionAll) t′ = unwrap_unionall(t) - t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag) + t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag, rewrap_unionall(t, troot)) tr = rewrap_unionall(t′′, t) if t′′ isa DataType && t′′.name !== Tuple.name && !has_free_typevars(tr) # a real instance must be within the declared bounds of the type, @@ -128,8 +134,8 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false) end return tr, isexact, isconcrete, istype elseif isa(t, Union) - ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag) - tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag) + ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag, troot) + tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag, troot) isconcrete = isconcrete_a && isconcrete_b istype = istype_a && istype_b # most users already handle the Union case, so here we assume that diff --git a/base/essentials.jl b/base/essentials.jl index 106826d140c57..a9f3bfc40f622 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -459,6 +459,11 @@ function rename_unionall(@nospecialize(u)) return UnionAll(nv, body{nv}) end +# remove concrete constraint on diagonal TypeVar if it comes from troot +function widen_diagonal(@nospecialize(t), troot::UnionAll) + body = ccall(:jl_widen_diagonal, Any, (Any, Any), t, troot) +end + function isvarargtype(@nospecialize(t)) return isa(t, Core.TypeofVararg) end diff --git a/src/subtype.c b/src/subtype.c index f80e31c58c46b..ce2572be2952d 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -4304,6 +4304,212 @@ int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv) return sub; } +// type utils +static void check_diagonal(jl_value_t *t, jl_varbinding_t *troot, int param) +{ + if (jl_is_uniontype(t)) { + int i, len = 0; + jl_varbinding_t *v; + for (v = troot; v != NULL; v = v->prev) + len++; + int8_t *occurs = (int8_t *)alloca(len); + for (v = troot, i = 0; v != NULL; v = v->prev, i++) + occurs[i] = v->occurs_inv | (v->occurs_cov << 2); + check_diagonal(((jl_uniontype_t *)t)->a, troot, param); + for (v = troot, i = 0; v != NULL; v = v->prev, i++) { + int8_t occurs_inv = occurs[i] & 3; + int8_t occurs_cov = occurs[i] >> 2; + occurs[i] = v->occurs_inv | (v->occurs_cov << 2); + v->occurs_inv = occurs_inv; + v->occurs_cov = occurs_cov; + } + check_diagonal(((jl_uniontype_t *)t)->b, troot, param); + for (v = troot, i = 0; v != NULL; v = v->prev, i++) { + if (v->occurs_inv < (occurs[i] & 3)) + v->occurs_inv = occurs[i] & 3; + if (v->occurs_cov < (occurs[i] >> 2)) + v->occurs_cov = occurs[i] >> 2; + } + } + else if (jl_is_unionall(t)) { + assert(troot != NULL); + jl_varbinding_t *v1 = troot, *v2 = troot->prev; + while (v2 != NULL) { + if (v2->var == ((jl_unionall_t *)t)->var) { + v1->prev = v2->prev; + break; + } + v1 = v2; + v2 = v2->prev; + } + check_diagonal(((jl_unionall_t *)t)->body, troot, param); + v1->prev = v2; + } + else if (jl_is_datatype(t)) { + int nparam = jl_is_tuple_type(t) ? 1 : 2; + if (nparam < param) nparam = param; + for (size_t i = 0; i < jl_nparams(t); i++) { + check_diagonal(jl_tparam(t, i), troot, nparam); + } + } + else if (jl_is_vararg(t)) { + jl_value_t *T = jl_unwrap_vararg(t); + jl_value_t *N = jl_unwrap_vararg_num(t); + int n = (N && jl_is_long(N)) ? jl_unbox_long(N) : 2; + if (T && n > 0) check_diagonal(T, troot, param); + if (T && n > 1) check_diagonal(T, troot, param); + if (N) check_diagonal(N, troot, 2); + } + else if (jl_is_typevar(t)) { + jl_varbinding_t *v = troot; + for (; v != NULL; v = v->prev) { + if (v->var == (jl_tvar_t *)t) { + if (param == 1 && v->occurs_cov < 2) v->occurs_cov++; + if (param == 2 && v->occurs_inv < 2) v->occurs_inv++; + break; + } + } + if (v == NULL) + check_diagonal(((jl_tvar_t *)t)->ub, troot, 0); + } +} + +static jl_value_t *insert_nondiagonal(jl_value_t *type, jl_varbinding_t *troot, int widen2ub) +{ + if (jl_is_typevar(type)) { + int concretekind = widen2ub > 1 ? 0 : 1; + jl_varbinding_t *v = troot; + for (; v != NULL; v = v->prev) { + if (v->occurs_inv == 0 && + v->occurs_cov > concretekind && + v->var == (jl_tvar_t *)type) + break; + } + if (v != NULL) { + if (widen2ub) { + type = insert_nondiagonal(((jl_tvar_t *)type)->ub, troot, 2); + } + else { + // we must replace each covariant occurrence of newvar with a different newvar2<:newvar (diagonal rule) + if (v->innervars == NULL) + v->innervars = jl_alloc_array_1d(jl_array_any_type, 0); + jl_value_t *newvar = NULL, *lb = v->var->lb, *ub = (jl_value_t *)v->var; + jl_array_t *innervars = v->innervars; + JL_GC_PUSH4(&newvar, &lb, &ub, &innervars); + newvar = (jl_value_t *)jl_new_typevar(v->var->name, lb, ub); + jl_array_ptr_1d_push(innervars, newvar); + JL_GC_POP(); + type = newvar; + } + } + } + else if (jl_is_unionall(type)) { + jl_value_t *body = ((jl_unionall_t*)type)->body; + jl_tvar_t *var = ((jl_unionall_t*)type)->var; + jl_varbinding_t *v = troot; + for (; v != NULL; v = v->prev) { + if (v->var == var) + break; + } + if (v) v->var = NULL; // Temporarily remove `type->var` from binding list. + jl_value_t *newbody = insert_nondiagonal(body, troot, widen2ub); + if (v) v->var = var; // And restore it after inner insertation. + jl_value_t *newvar = NULL; + JL_GC_PUSH2(&newbody, &newvar); + if (body == newbody || jl_has_typevar(newbody, var)) { + if (body != newbody) + newbody = jl_new_struct(jl_unionall_type, var, newbody); + // n.b. we do not widen lb, since that would be the wrong direction + newvar = insert_nondiagonal(var->ub, troot, widen2ub); + if (newvar != var->ub) { + newvar = (jl_value_t*)jl_new_typevar(var->name, var->lb, newvar); + newbody = jl_apply_type1(newbody, newvar); + newbody = jl_type_unionall((jl_tvar_t*)newvar, newbody); + } + } + type = newbody; + JL_GC_POP(); + } + else if (jl_is_uniontype(type)) { + jl_value_t *a = ((jl_uniontype_t*)type)->a; + jl_value_t *b = ((jl_uniontype_t*)type)->b; + jl_value_t *newa = NULL; + jl_value_t *newb = NULL; + JL_GC_PUSH2(&newa, &newb); + newa = insert_nondiagonal(a, troot, widen2ub); + newb = insert_nondiagonal(b, troot, widen2ub); + if (newa != a || newb != b) + type = simple_union(newa, newb); + JL_GC_POP(); + } + else if (jl_is_vararg(type)) { + // As for Vararg we'd better widen it's var to ub as otherwise they are still diagonal + jl_value_t *t = jl_unwrap_vararg(type); + jl_value_t *n = jl_unwrap_vararg_num(type); + if (widen2ub == 0) + widen2ub = !(n && jl_is_long(n)) || jl_unbox_long(n) > 1; + jl_value_t *newt; + JL_GC_PUSH2(&newt, &n); + newt = insert_nondiagonal(t, troot, widen2ub); + if (t != newt) + type = (jl_value_t *)jl_wrap_vararg(newt, n, 0); + JL_GC_POP(); + } + else if (jl_is_datatype(type)) { + if (jl_is_tuple_type(type)) { + jl_svec_t *newparams = NULL; + jl_value_t *newelt = NULL; + JL_GC_PUSH2(&newparams, &newelt); + for (size_t i = 0; i < jl_nparams(type); i++) { + jl_value_t *elt = jl_tparam(type, i); + newelt = insert_nondiagonal(elt, troot, widen2ub); + if (elt != newelt) { + if (!newparams) + newparams = jl_svec_copy(((jl_datatype_t*)type)->parameters); + jl_svecset(newparams, i, newelt); + } + } + if (newparams) + type = (jl_value_t*)jl_apply_tuple_type(newparams, 1); + JL_GC_POP(); + } + } + return type; +} + +static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) { + check_diagonal(t, troot, 0); + int any_concrete = 0; + for (jl_varbinding_t *v = troot; v != NULL; v = v->prev) + any_concrete |= v->occurs_cov > 1 && v->occurs_inv == 0; + if (!any_concrete) + return t; // no diagonal + return insert_nondiagonal(t, troot, 0); +} + +static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot) +{ + jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot }; + jl_value_t *nt; + JL_GC_PUSH2(&vb.innervars, &nt); + if (jl_is_unionall(u->body)) + nt = widen_diagonal(t, (jl_unionall_t *)u->body, &vb); + else + nt = _widen_diagonal(t, &vb); + if (vb.innervars != NULL) { + for (size_t i = 0; i < jl_array_nrows(vb.innervars); i++) { + jl_tvar_t *var = (jl_tvar_t*)jl_array_ptr_ref(vb.innervars, i); + nt = jl_type_unionall(var, nt); + } + } + JL_GC_POP(); + return nt; +} + +JL_DLLEXPORT jl_value_t *jl_widen_diagonal(jl_value_t *t, jl_unionall_t *ua) +{ + return widen_diagonal(t, ua, NULL); +} // specificity comparison diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 30a2763cb605b..108a62db4b940 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5593,3 +5593,16 @@ end |> only === Float64 @test Base.infer_exception_type(c::Bool -> c ? 1 : 2) == Union{} @test Base.infer_exception_type(c::Missing -> c ? 1 : 2) == TypeError @test Base.infer_exception_type(c::Any -> c ? 1 : 2) == TypeError + +# Issue #52168 +f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type} +@test f52168((1, 2.), Any) === (1, 2.) + +# Issue #27031 +let x = 1, _Any = Any + @noinline bar27031(tt::Tuple{T,T}, ::Type{Val{T}}) where {T} = notsame27031(tt) + @noinline notsame27031(tt::Tuple{T, T}) where {T} = error() + @noinline notsame27031(tt::Tuple{T, S}) where {T, S} = "OK" + foo27031() = bar27031((x, 1.0), Val{_Any}) + @test foo27031() == "OK" +end diff --git a/test/core.jl b/test/core.jl index c85868c496d91..7cd4be5427eb4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -8059,3 +8059,14 @@ check_globalref_lowering() = @insert_global let src = code_lowered(check_globalref_lowering)[1] @test length(src.code) == 2 end + +# Test correctness of widen_diagonal +let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base.unwrap_unionall(x), x), x), + check_widen_diagonal(x, y) = !<:(x, y) && x <: widen_diagonal(y) + @test Tuple{Int,Float64} <: widen_diagonal(NTuple) + @test Tuple{Int,Float64} <: widen_diagonal(Tuple{T,T} where {T}) + @test Tuple{Real,Int,Float64} <: widen_diagonal(Tuple{S,Vararg{T}} where {S, T<:S}) + @test Tuple{Int,Int,Float64,Float64} <: widen_diagonal(Tuple{S,S,Vararg{T}} where {S, T<:S}) + @test Union{Tuple{T}, Tuple{T,Int}} where {T} === widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T}) + @test Tuple === widen_diagonal(Union{Tuple{Vararg{S}}, Tuple{Vararg{T}}} where {S, T}) +end