Skip to content

Commit 748149e

Browse files
committed
Only merge vars occur in the local union decision.
If we always merge the whole env, then the output bounds would be widen than input if different Union decision touch different vars. Also add missing `occurs_inv/cov`'s merge (by max).
1 parent 6deb98f commit 748149e

File tree

2 files changed

+115
-26
lines changed

2 files changed

+115
-26
lines changed

Diff for: src/subtype.c

+108-26
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ typedef struct jl_varbinding_t {
6565
jl_value_t *lb;
6666
jl_value_t *ub;
6767
int8_t right; // whether this variable came from the right side of `A <: B`
68+
int8_t occurs; // occurs in any position
6869
int8_t occurs_inv; // occurs in invariant position
6970
int8_t occurs_cov; // # of occurrences in covariant position
7071
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
@@ -161,7 +162,7 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
161162
typedef struct {
162163
int8_t *buf;
163164
int rdepth;
164-
int8_t _space[16];
165+
int8_t _space[24];
165166
} jl_savedenv_t;
166167

167168
static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
@@ -174,9 +175,9 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
174175
}
175176
if (root)
176177
*root = (jl_value_t*)jl_alloc_svec(len * 3);
177-
se->buf = (int8_t*)(len > 8 ? malloc_s(len * 2) : &se->_space);
178+
se->buf = (int8_t*)(len > 8 ? malloc_s(len * 3) : &se->_space);
178179
#ifdef __clang_gcanalyzer__
179-
memset(se->buf, 0, len * 2);
180+
memset(se->buf, 0, len * 3);
180181
#endif
181182
int i=0, j=0; v = e->vars;
182183
while (v != NULL) {
@@ -185,6 +186,7 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
185186
jl_svecset(*root, i++, v->ub);
186187
jl_svecset(*root, i++, (jl_value_t*)v->innervars);
187188
}
189+
se->buf[j++] = v->occurs;
188190
se->buf[j++] = v->occurs_inv;
189191
se->buf[j++] = v->occurs_cov;
190192
v = v->prev;
@@ -207,6 +209,7 @@ static void restore_env(jl_stenv_t *e, jl_value_t *root, jl_savedenv_t *se) JL_N
207209
if (root) v->lb = jl_svecref(root, i++);
208210
if (root) v->ub = jl_svecref(root, i++);
209211
if (root) v->innervars = (jl_array_t*)jl_svecref(root, i++);
212+
v->occurs = se->buf[j++];
210213
v->occurs_inv = se->buf[j++];
211214
v->occurs_cov = se->buf[j++];
212215
v = v->prev;
@@ -227,6 +230,15 @@ static int current_env_length(jl_stenv_t *e)
227230
return len;
228231
}
229232

233+
static void clean_occurs(jl_stenv_t *e)
234+
{
235+
jl_varbinding_t *v = e->vars;
236+
while (v) {
237+
v->occurs = 0;
238+
v = v->prev;
239+
}
240+
}
241+
230242
// type utilities
231243

232244
// quickly test that two types are identical
@@ -590,6 +602,8 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par
590602
// of determining whether the variable is concrete.
591603
static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT
592604
{
605+
if (vb != NULL)
606+
vb->occurs = 1;
593607
if (vb != NULL && param) {
594608
// saturate counters at 2; we don't need values bigger than that
595609
if (param == 2 && (vb->right ? e->Rinvdepth : e->invdepth) > vb->depth0) {
@@ -782,7 +796,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
782796
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
783797
{
784798
u = unalias_unionall(u, e);
785-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0,
799+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
786800
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
787801
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
788802
e->vars = &vb;
@@ -2741,7 +2755,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
27412755
{
27422756
jl_value_t *res=NULL, *save=NULL;
27432757
jl_savedenv_t se;
2744-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0,
2758+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
27452759
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
27462760
JL_GC_PUSH5(&res, &vb.lb, &vb.ub, &save, &vb.innervars);
27472761
save_env(e, &save, &se);
@@ -2754,13 +2768,13 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
27542768
else if (res != jl_bottom_type) {
27552769
if (vb.concrete || vb.occurs_inv>1 || vb.intvalued > 1 || u->var->lb != jl_bottom_type || (vb.occurs_inv && vb.occurs_cov)) {
27562770
restore_env(e, NULL, &se);
2757-
vb.occurs_cov = vb.occurs_inv = 0;
2771+
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
27582772
vb.constraintkind = vb.concrete ? 1 : 2;
27592773
res = intersect_unionall_(t, u, e, R, param, &vb);
27602774
}
27612775
else if (vb.occurs_cov && !var_occurs_invariant(u->body, u->var, 0)) {
27622776
restore_env(e, save, &se);
2763-
vb.occurs_cov = vb.occurs_inv = 0;
2777+
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
27642778
vb.constraintkind = 1;
27652779
res = intersect_unionall_(t, u, e, R, param, &vb);
27662780
}
@@ -3271,36 +3285,97 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
32713285

32723286
static int merge_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se, int count)
32733287
{
3274-
if (!count) {
3275-
save_env(e, root, se);
3276-
return 1;
3288+
if (count == 0) {
3289+
int len = current_env_length(e);
3290+
*root = (jl_value_t*)jl_alloc_svec(len * 3);
3291+
se->buf = (int8_t*)(len > 8 ? malloc_s(len * 3) : &se->_space);
3292+
memset(se->buf, 0, len * 3);
32773293
}
32783294
int n = 0;
32793295
jl_varbinding_t *v = e->vars;
32803296
jl_value_t *b1 = NULL, *b2 = NULL;
32813297
JL_GC_PUSH2(&b1, &b2); // clang-sagc does not understand that *root is rooted already
3298+
v = e->vars;
32823299
while (v != NULL) {
3283-
b1 = jl_svecref(*root, n);
3284-
b2 = v->lb;
3285-
jl_svecset(*root, n, simple_meet(b1, b2));
3286-
b1 = jl_svecref(*root, n+1);
3287-
b2 = v->ub;
3288-
jl_svecset(*root, n+1, simple_join(b1, b2));
3289-
b1 = jl_svecref(*root, n+2);
3290-
b2 = (jl_value_t*)v->innervars;
3291-
if (b2 && b1 != b2) {
3292-
if (b1)
3293-
jl_array_ptr_1d_append((jl_array_t*)b2, (jl_array_t*)b1);
3294-
else
3295-
jl_svecset(*root, n+2, b2);
3300+
if (v->occurs) {
3301+
// only merge lb/ub/innervars if this var occurs.
3302+
b1 = jl_svecref(*root, n);
3303+
b2 = v->lb;
3304+
jl_svecset(*root, n, b1 ? simple_meet(b1, b2) : b2);
3305+
b1 = jl_svecref(*root, n+1);
3306+
b2 = v->ub;
3307+
jl_svecset(*root, n+1, b1 ? simple_join(b1, b2) : b2);
3308+
b1 = jl_svecref(*root, n+2);
3309+
b2 = (jl_value_t*)v->innervars;
3310+
if (b2 && b1 != b2) {
3311+
if (b1)
3312+
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
3313+
else
3314+
jl_svecset(*root, n+2, b2);
3315+
}
3316+
// record the meeted vars.
3317+
se->buf[n] = 1;
32963318
}
3319+
// always merge occurs_inv/cov by max (never decrease)
3320+
if (v->occurs_inv > se->buf[n+1])
3321+
se->buf[n+1] = v->occurs_inv;
3322+
if (v->occurs_cov > se->buf[n+2])
3323+
se->buf[n+2] = v->occurs_cov;
32973324
n = n + 3;
32983325
v = v->prev;
32993326
}
33003327
JL_GC_POP();
33013328
return count + 1;
33023329
}
33033330

3331+
// merge untouched vars' info.
3332+
static void final_merge_env(jl_value_t **merged, jl_savedenv_t *me, jl_value_t **saved, jl_savedenv_t *se)
3333+
{
3334+
int l = jl_svec_len(*merged);
3335+
assert(l == jl_svec_len(*saved) && l%3 == 0);
3336+
jl_value_t *b1 = NULL, *b2 = NULL;
3337+
JL_GC_PUSH2(&b1, &b2);
3338+
for (int n = 0; n < l; n = n + 3) {
3339+
if (jl_svecref(*merged, n) == NULL)
3340+
jl_svecset(*merged, n, jl_svecref(*saved, n));
3341+
if (jl_svecref(*merged, n+1) == NULL)
3342+
jl_svecset(*merged, n+1, jl_svecref(*saved, n+1));
3343+
b1 = jl_svecref(*merged, n+2);
3344+
b2 = jl_svecref(*saved , n+2);
3345+
if (b2 && b1 != b2) {
3346+
if (b1)
3347+
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
3348+
else
3349+
jl_svecset(*merged, n+2, b2);
3350+
}
3351+
me->buf[n] |= se->buf[n];
3352+
}
3353+
JL_GC_POP();
3354+
}
3355+
3356+
static void expand_local_env(jl_stenv_t *e, jl_value_t *res)
3357+
{
3358+
jl_varbinding_t *v = e->vars;
3359+
// Here we pull in some typevar missed in fastpath.
3360+
while (v != NULL) {
3361+
v->occurs = v->occurs || jl_has_typevar(res, v->var);
3362+
assert(v->occurs == 0 || v->occurs == 1);
3363+
v = v->prev;
3364+
}
3365+
v = e->vars;
3366+
while (v != NULL) {
3367+
if (v->occurs == 1) {
3368+
jl_varbinding_t *v2 = e->vars;
3369+
while (v2 != NULL) {
3370+
if (v2 != v && v2->occurs == 0)
3371+
v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var));
3372+
v2 = v2->prev;
3373+
}
3374+
}
3375+
v = v->prev;
3376+
}
3377+
}
3378+
33043379
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
33053380
{
33063381
e->Runions.depth = 0;
@@ -3313,10 +3388,13 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
33133388
jl_savedenv_t se, me;
33143389
save_env(e, saved, &se);
33153390
int lastset = 0, niter = 0, total_iter = 0;
3391+
clean_occurs(e);
33163392
jl_value_t *ii = intersect(x, y, e, 0);
33173393
is[0] = ii; // root
3318-
if (is[0] != jl_bottom_type)
3394+
if (is[0] != jl_bottom_type) {
3395+
expand_local_env(e, is[0]);
33193396
niter = merge_env(e, merged, &me, niter);
3397+
}
33203398
restore_env(e, *saved, &se);
33213399
while (e->Runions.more) {
33223400
if (e->emptiness_only && ii != jl_bottom_type)
@@ -3330,9 +3408,12 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
33303408
lastset = set;
33313409

33323410
is[0] = ii;
3411+
clean_occurs(e);
33333412
is[1] = intersect(x, y, e, 0);
3334-
if (is[1] != jl_bottom_type)
3413+
if (is[1] != jl_bottom_type) {
3414+
expand_local_env(e, is[1]);
33353415
niter = merge_env(e, merged, &me, niter);
3416+
}
33363417
restore_env(e, *saved, &se);
33373418
if (is[0] == jl_bottom_type)
33383419
ii = is[1];
@@ -3348,7 +3429,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
33483429
break;
33493430
}
33503431
}
3351-
if (niter){
3432+
if (niter) {
3433+
final_merge_env(merged, &me, saved, &se);
33523434
restore_env(e, *merged, &me);
33533435
free_env(&me);
33543436
}

Diff for: test/subtype.jl

+7
Original file line numberDiff line numberDiff line change
@@ -2319,6 +2319,13 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})},
23192319
@testintersect(S, T, !Union{})
23202320
end
23212321

2322+
# A simple case which has a small local union.
2323+
# make sure the env is not widened too much when we intersect(Int8, Int8).
2324+
struct T48006{A1,A2,A3} end
2325+
@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}},
2326+
Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}},
2327+
Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8})
2328+
23222329
@testset "known subtype/intersect issue" begin
23232330
#issue 45874
23242331
# Causes a hang due to jl_critical_error calling back into malloc...

0 commit comments

Comments
 (0)