Skip to content

Commit

Permalink
typeintersect: fix bounds merging during inner intersect_all (#55299)
Browse files Browse the repository at this point in the history
This PR reverts the optimization from
748149e (part of #48167), while
keeping the fix for merging occurs_inv/occurs_cov, as that optimzation
makes no sense especially when typevar occurs both inside and outside
the inner intersection.

Close #55206
  • Loading branch information
N5N3 authored Jul 30, 2024
1 parent f979ee9 commit fb6b790
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 141 deletions.
190 changes: 56 additions & 134 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ typedef struct jl_varbinding_t {
jl_value_t *lb;
jl_value_t *ub;
int8_t right; // whether this variable came from the right side of `A <: B`
int8_t occurs; // occurs in any position
int8_t occurs_inv; // occurs in invariant position
int8_t occurs_cov; // # of occurrences in covariant position
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
Expand Down Expand Up @@ -179,7 +178,7 @@ static int current_env_length(jl_stenv_t *e)
typedef struct {
int8_t *buf;
int rdepth;
int8_t _space[32]; // == 8 * 4
int8_t _space[24]; // == 8 * 3
jl_gcframe_t gcframe;
jl_value_t *roots[24]; // == 8 * 3
} jl_savedenv_t;
Expand Down Expand Up @@ -208,7 +207,6 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
roots[i++] = v->ub;
roots[i++] = (jl_value_t*)v->innervars;
}
se->buf[j++] = v->occurs;
se->buf[j++] = v->occurs_inv;
se->buf[j++] = v->occurs_cov;
se->buf[j++] = v->max_offset;
Expand Down Expand Up @@ -243,7 +241,7 @@ static void alloc_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
ct->gcstack = &se->gcframe;
}
}
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 4) : se->_space);
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 3) : se->_space);
#ifdef __clang_gcanalyzer__
memset(se->buf, 0, len * 3);
#endif
Expand Down Expand Up @@ -290,7 +288,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
v->ub = roots[i++];
v->innervars = (jl_array_t*)roots[i++];
}
v->occurs = se->buf[j++];
v->occurs_inv = se->buf[j++];
v->occurs_cov = se->buf[j++];
v->max_offset = se->buf[j++];
Expand All @@ -302,15 +299,6 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
memset(&e->envout[e->envidx], 0, (e->envsz - e->envidx)*sizeof(void*));
}

static void clean_occurs(jl_stenv_t *e)
{
jl_varbinding_t *v = e->vars;
while (v) {
v->occurs = 0;
v = v->prev;
}
}

#define flip_offset(e) ((e)->Loffset *= -1)

// type utilities
Expand Down Expand Up @@ -599,6 +587,8 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)

static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
Expand Down Expand Up @@ -679,8 +669,6 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par
// of determining whether the variable is concrete.
static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT
{
if (vb != NULL)
vb->occurs = 1;
if (vb != NULL && param) {
// saturate counters at 2; we don't need values bigger than that
if (param == 2 && e->invdepth > vb->depth0) {
Expand Down Expand Up @@ -915,7 +903,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
Expand Down Expand Up @@ -3312,7 +3300,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res = NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars);
save_env(e, &se, 1);
Expand Down Expand Up @@ -3341,7 +3329,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
vb.ub = vb.var->ub;
}
restore_env(e, &se, vb.constraintkind == 1 ? 1 : 0);
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
vb.occurs_cov = vb.occurs_inv = 0;
res = intersect_unionall_(t, u, e, R, param, &vb);
}
}
Expand Down Expand Up @@ -4042,79 +4030,12 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
return jl_bottom_type;
}

static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
static int merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se, int count)
{
if (count == 0)
alloc_env(e, se, 1);
jl_value_t **roots = NULL;
int nroots = 0;
if (se->gcframe.nroots == JL_GC_ENCODE_PUSHARGS(1)) {
jl_svec_t *sv = (jl_svec_t*)se->roots[0];
assert(jl_is_svec(sv));
roots = jl_svec_data(sv);
nroots = jl_svec_len(sv);
}
else {
roots = se->roots;
nroots = se->gcframe.nroots >> 2;
}
int m = 0, n = 0;
jl_varbinding_t *v = e->vars;
while (v != NULL) {
if (count == 0) {
// need to initialize this
se->buf[m] = 0;
se->buf[m+1] = 0;
se->buf[m+2] = 0;
se->buf[m+3] = v->max_offset;
}
jl_value_t *b1, *b2;
if (v->occurs) {
// only merge lb/ub if this var occurs.
b1 = roots[n];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->lb;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
roots[n] = b1 ? simple_meet(b1, b2, 0) : b2;
b1 = roots[n+1];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->ub;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
roots[n+1] = b1 ? simple_join(b1, b2) : b2;
// record the meeted vars.
se->buf[m] = 1;
}
// `innervars` might be re-sorted inside `finish_unionall`.
// We'd better always merge it.
b1 = roots[n+2];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = (jl_value_t*)v->innervars;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
else
roots[n+2] = b2;
}
// always merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > se->buf[m+1])
se->buf[m+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[m+2])
se->buf[m+2] = v->occurs_cov;
// always merge max_offset by min
if (!v->intersected && v->max_offset < se->buf[m+3])
se->buf[m+3] = v->max_offset;
m = m + 4;
n = n + 3;
v = v->prev;
if (count == 0) {
save_env(e, me, 1);
return 1;
}
assert(n == nroots); (void)nroots;
return count + 1;
}

// merge untouched vars' info.
static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
{
jl_value_t **merged = NULL;
jl_value_t **saved = NULL;
int nroots = 0;
Expand All @@ -4136,47 +4057,49 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
}
assert(nroots == current_env_length(e) * 3);
assert(nroots % 3 == 0);
for (int n = 0, m = 0; n < nroots; n += 3, m += 4) {
if (merged[n] == NULL)
merged[n] = saved[n];
if (merged[n+1] == NULL)
merged[n+1] = saved[n+1];
jl_value_t *b1, *b2;
int m = 0, n = 0;
jl_varbinding_t *v = e->vars;
while (v != NULL) {
jl_value_t *b0, *b1, *b2;
// merge `lb`
b0 = saved[n];
b1 = merged[n];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->lb;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
merged[n] = (b1 == b0 || b2 == b0) ? b0 : simple_meet(b1, b2, 0);
// merge `ub`
b0 = saved[n+1];
b1 = merged[n+1];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = v->ub;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
merged[n+1] = (b1 == b0 || b2 == b0) ? b0 : simple_join(b1, b2);
// merge `innervars`
b1 = merged[n+2];
JL_GC_PROMISE_ROOTED(b1); // clang-sagc doesn't know this came from our GC frame
b2 = saved[n+2];
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know this came from our GC frame
b2 = (jl_value_t*)v->innervars;
JL_GC_PROMISE_ROOTED(b2); // clang-sagc doesn't know the fields of this are stack GC roots
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
else
merged[n+2] = b2;
}
me->buf[m] |= se->buf[m];
}
}

static void expand_local_env(jl_stenv_t *e, jl_value_t *res)
{
jl_varbinding_t *v = e->vars;
// Here we pull in some typevar missed in fastpath.
while (v != NULL) {
v->occurs = v->occurs || jl_has_typevar(res, v->var);
assert(v->occurs == 0 || v->occurs == 1);
v = v->prev;
}
v = e->vars;
while (v != NULL) {
if (v->occurs == 1) {
jl_varbinding_t *v2 = e->vars;
while (v2 != NULL) {
if (v2 != v && v2->occurs == 0)
v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var));
v2 = v2->prev;
}
}
// merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > me->buf[m])
me->buf[m] = v->occurs_inv;
if (v->occurs_cov > me->buf[m+1])
me->buf[m+1] = v->occurs_cov;
// merge max_offset by min
if (!v->intersected && v->max_offset < me->buf[m+2])
me->buf[m+2] = v->max_offset;
m = m + 3;
n = n + 3;
v = v->prev;
}
assert(n == nroots); (void)nroots;
return count + 1;
}

static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
Expand All @@ -4189,25 +4112,19 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
jl_savedenv_t se, me;
save_env(e, &se, 1);
int niter = 0, total_iter = 0;
clean_occurs(e);
is[0] = intersect(x, y, e, 0); // root
if (is[0] != jl_bottom_type) {
expand_local_env(e, is[0]);
niter = merge_env(e, &me, niter);
}
if (is[0] != jl_bottom_type)
niter = merge_env(e, &me, &se, niter);
restore_env(e, &se, 1);
while (next_union_state(e, 1)) {
if (e->emptiness_only && is[0] != jl_bottom_type)
break;
e->Runions.depth = 0;
e->Runions.more = 0;

clean_occurs(e);
is[1] = intersect(x, y, e, 0);
if (is[1] != jl_bottom_type) {
expand_local_env(e, is[1]);
niter = merge_env(e, &me, niter);
}
if (is[1] != jl_bottom_type)
niter = merge_env(e, &me, &se, niter);
restore_env(e, &se, 1);
if (is[0] == jl_bottom_type)
is[0] = is[1];
Expand All @@ -4216,13 +4133,18 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
is[0] = jl_type_union(is, 2);
}
total_iter++;
if (niter > 4 || total_iter > 400000) {
if (has_next_union_state(e, 1) && (niter > 4 || total_iter > 400000)) {
is[0] = y;
// we give up precise intersection here, just restore the saved env
restore_env(e, &se, 1);
if (niter > 0) {
free_env(&me);
niter = 0;
}
break;
}
}
if (niter) {
final_merge_env(e, &me, &se);
restore_env(e, &me, 1);
free_env(&me);
}
Expand Down Expand Up @@ -4707,7 +4629,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {

static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
{
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_value_t *nt;
JL_GC_PUSH2(&vb.innervars, &nt);
if (jl_is_unionall(u->body))
Expand Down
43 changes: 36 additions & 7 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2380,12 +2380,41 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})},
@testintersect(S, T, !Union{})
end

# A simple case which has a small local union.
# make sure the env is not widened too much when we intersect(Int8, Int8).
struct T48006{A1,A2,A3} end
@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}},
Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}},
Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8})
let S = Dict{Int, S1} where {F1, S1<:Union{Int8, Val{F1}}},
T = Dict{F2, S2} where {F2, S2<:Union{Int8, Val{F2}}}
@test_broken typeintersect(S, T) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
@test typeintersect(T, S) == Dict{Int, S} where S<:Union{Val{Int}, Int8}
end

# Ensure inner `intersect_all` never under-esitimate.
let S = Tuple{F1, Dict{Int, S1}} where {F1, S1<:Union{Int8, Val{F1}}},
T = Tuple{Any, Dict{F2, S2}} where {F2, S2<:Union{Int8, Val{F2}}}
@test Tuple{Nothing, Dict{Int, Int8}} <: S
@test Tuple{Nothing, Dict{Int, Int8}} <: T
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(S, T)
@test Tuple{Nothing, Dict{Int, Int8}} <: typeintersect(T, S)
end

let S = Tuple{F1, Val{S1}} where {F1, S1<:Dict{F1}}
T = Tuple{Any, Val{S2}} where {F2, S2<:Union{map(T->Dict{T}, Base.BitInteger_types)...}}
ST = typeintersect(S, T)
TS = typeintersect(S, T)
for U in Base.BitInteger_types
@test Tuple{U, Val{Dict{U,Nothing}}} <: S
@test Tuple{U, Val{Dict{U,Nothing}}} <: T
@test Tuple{U, Val{Dict{U,Nothing}}} <: ST
@test Tuple{U, Val{Dict{U,Nothing}}} <: TS
end
end

#issue 55206
struct T55206{A,B<:Complex{A},C<:Union{Dict{Nothing},Dict{A}}} end
@testintersect(T55206, T55206{<:Any,<:Any,<:Dict{Nothing}}, T55206{A,<:Complex{A},<:Dict{Nothing}} where {A})
@testintersect(
Tuple{Dict{Int8, Int16}, Val{S1}} where {F1, S1<:AbstractSet{F1}},
Tuple{Dict{T1, T2}, Val{S2}} where {T1, T2, S2<:Union{Set{T1},Set{T2}}},
Tuple{Dict{Int8, Int16}, Val{S1}} where {S1<:Union{Set{Int8},Set{Int16}}}
)

f48167(::Type{Val{L2}}, ::Type{Union{Val{L1}, Set{R}}}) where {L1, R, L2<:L1} = 1
f48167(::Type{Val{L1}}, ::Type{Union{Val{L2}, Set{R}}}) where {L1, R, L2<:L1} = 2
Expand Down Expand Up @@ -2554,7 +2583,7 @@ end
let T = Tuple{Union{Type{T}, Type{S}}, Union{Val{T}, Val{S}}, Union{Val{T}, S}} where T<:Val{A} where A where S<:Val,
S = Tuple{Type{T}, T, Val{T}} where T<:(Val{S} where S<:Val)
# optimal = Union{}?
@test typeintersect(T, S) == Tuple{Type{A}, Union{Val{A}, Val{S} where S<:Union{Val, A}, Val{x} where x<:Val, Val{x} where x<:Union{Val, A}}, Val{A}} where A<:(Val{S} where S<:Val)
@test typeintersect(T, S) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {S<:Val, T<:Val}
@test typeintersect(S, T) == Tuple{Type{T}, Union{Val{T}, Val{S}}, Val{T}} where {T<:Val, S<:(Union{Val{A}, Val} where A)}
end

Expand Down

0 comments on commit fb6b790

Please sign in to comment.