Skip to content

Commit

Permalink
Fix #41096 and #43082, make sure env is restored when typeintersect…
Browse files Browse the repository at this point in the history
… tries a new Union decision (#46350)

* `intersect_all` should always `restore_env`. let `merge_env` track valid `env` change.

* Add test.

Co-authored-by: Jeff Bezanson <[email protected]>
  • Loading branch information
N5N3 and JeffBezanson authored Aug 18, 2022
1 parent f7144f7 commit 9aabb4c
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 29 deletions.
73 changes: 49 additions & 24 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,9 @@ static void restore_env(jl_stenv_t *e, jl_value_t *root, jl_savedenv_t *se) JL_N
jl_varbinding_t *v = e->vars;
int i = 0, j = 0;
while (v != NULL) {
if (root) v->lb = jl_svecref(root, i);
i++;
if (root) v->ub = jl_svecref(root, i);
i++;
if (root) v->innervars = (jl_array_t*)jl_svecref(root, i);
i++;
if (root) v->lb = jl_svecref(root, i++);
if (root) v->ub = jl_svecref(root, i++);
if (root) v->innervars = (jl_array_t*)jl_svecref(root, i++);
v->occurs_inv = se->buf[j++];
v->occurs_cov = se->buf[j++];
v = v->prev;
Expand Down Expand Up @@ -2323,6 +2320,11 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
JL_GC_POP();
return jl_bottom_type;
}
if (jl_is_uniontype(ub) && !jl_is_uniontype(a)) {
bb->ub = ub;
bb->lb = jl_bottom_type;
ub = (jl_value_t*)b;
}
}
if (ub != (jl_value_t*)b) {
if (jl_has_free_typevars(ub)) {
Expand Down Expand Up @@ -3166,26 +3168,50 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
return jl_bottom_type;
}

static int merge_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se, int count)
{
if (!count) {
save_env(e, root, se);
return 1;
}
int n = 0;
jl_varbinding_t *v = e->vars;
jl_value_t *ub = NULL, *vub = NULL;
JL_GC_PUSH2(&ub, &vub);
while (v != NULL) {
if (v->ub != v->var->ub || v->lb != v->var->lb) {
jl_value_t *lb = jl_svecref(*root, n);
if (v->lb != lb)
jl_svecset(*root, n, lb ? jl_bottom_type : v->lb);
ub = jl_svecref(*root, n+1);
vub = v->ub;
if (vub != ub)
jl_svecset(*root, n+1, ub ? simple_join(ub, vub) : vub);
}
n = n + 3;
v = v->prev;
}
JL_GC_POP();
return count + 1;
}

static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
e->Runions.depth = 0;
e->Runions.more = 0;
e->Runions.used = 0;
jl_value_t **is;
JL_GC_PUSHARGS(is, 3);
JL_GC_PUSHARGS(is, 4);
jl_value_t **saved = &is[2];
jl_savedenv_t se;
jl_value_t **merged = &is[3];
jl_savedenv_t se, me;
save_env(e, saved, &se);
int lastset = 0, niter = 0, total_iter = 0;
jl_value_t *ii = intersect(x, y, e, 0);
is[0] = ii; // root
if (ii == jl_bottom_type) {
restore_env(e, *saved, &se);
}
else {
free_env(&se);
save_env(e, saved, &se);
}
if (is[0] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type)
break;
Expand All @@ -3199,28 +3225,27 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)

is[0] = ii;
is[1] = intersect(x, y, e, 0);
if (is[1] == jl_bottom_type) {
restore_env(e, *saved, &se);
}
else {
free_env(&se);
save_env(e, saved, &se);
}
if (is[1] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
if (is[0] == jl_bottom_type)
ii = is[1];
else if (is[1] == jl_bottom_type)
ii = is[0];
else {
// TODO: the repeated subtype checks in here can get expensive
ii = jl_type_union(is, 2);
niter++;
}
total_iter++;
if (niter > 3 || total_iter > 400000) {
if (niter > 4 || total_iter > 400000) {
ii = y;
break;
}
}
if (niter){
restore_env(e, *merged, &me);
free_env(&me);
}
free_env(&se);
JL_GC_POP();
return ii;
Expand Down
34 changes: 29 additions & 5 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1928,12 +1928,24 @@ let A = Tuple{Ref{T}, Vararg{T}} where T,
B = Tuple{Ref{U}, Union{Ref{S}, Ref{U}, Int}, Union{Ref{S}, S}} where S where U,
C = Tuple{Ref{U}, Union{Ref{S}, Ref{U}, Ref{W}}, Union{Ref{S}, W, V}} where V<:AbstractArray where W where S where U
I = typeintersect(A, B)
Ts = (Tuple{Ref{Int}, Int, Int}, Tuple{Ref{Ref{Int}}, Ref{Int}, Ref{Int}})
@test I != Union{}
@test I <: A
@test I <: B
# avoid stack overflow
@test_broken I <: B
for T in Ts
if T <: A && T <: B
@test T <: I
end
end
J = typeintersect(A, C)
@test_broken J != Union{}
@test J != Union{}
@test J <: A
@test_broken J <: C
for T in Ts
if T <: A && T <: C
@test T <: J
end
end
end

let A = Tuple{Dict{I,T}, I, T} where T where I,
Expand Down Expand Up @@ -1964,8 +1976,9 @@ let A = Tuple{Any, Type{Ref{_A}} where _A},
B = Tuple{Type{T}, Type{<:Union{Ref{T}, T}}} where T,
I = typeintersect(A, B)
@test I != Union{}
# TODO: this intersection result is still too narrow
@test_broken Tuple{Type{Ref{Integer}}, Type{Ref{Integer}}} <: I
@test Tuple{Type{Ref{Integer}}, Type{Ref{Integer}}} <: I
# TODO: this intersection result seems too wide (I == B) ?
@test_broken !<:(Tuple{Type{Int}, Type{Int}}, I)
end

@testintersect(Tuple{Type{T}, T} where T<:(Tuple{Vararg{_A, _B}} where _B where _A),
Expand Down Expand Up @@ -1996,3 +2009,14 @@ let T = TypeVar(:T, Real),
@test !(UnionAll(T, UnionAll(V, UnionAll(T, Type{Pair{T, V}}))) <: UnionAll(T, UnionAll(V, Type{Pair{T, V}})))
@test !(UnionAll(T, UnionAll(V, UnionAll(T, S))) <: UnionAll(T, UnionAll(V, S)))
end

# issue #41096
let C = Val{Val{B}} where {B}
@testintersect(Val{<:Union{Missing, Val{false}, Val{true}}}, C, Val{<:Union{Val{true}, Val{false}}})
@testintersect(Val{<:Union{Nothing, Val{true}, Val{false}}}, C, Val{<:Union{Val{true}, Val{false}}})
@testintersect(Val{<:Union{Nothing, Val{false}}}, C, Val{Val{false}})
end

#issue #43082
struct X43082{A, I, B<:Union{Ref{I},I}}; end
@testintersect(Tuple{X43082{T}, Int} where T, Tuple{X43082{Int}, Any}, Tuple{X43082{Int}, Int})

0 comments on commit 9aabb4c

Please sign in to comment.