Skip to content

Commit

Permalink
fixes for cycles and stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Oct 30, 2024
1 parent c67759e commit 1b9af2e
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 181 deletions.
4 changes: 1 addition & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,9 +1401,7 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
result, match, argtypes, info, flag, state; allow_typevars=true)
end
if !fully_covered
atype = argtypes_to_type(sig.argtypes)
# We will emit an inline MethodError so we need a backedge to the MethodTable
add_uncovered_edges!(state.edges, info, atype)
# We will emit an inline MethodError in this case, but that info already came inference, so we must already have the uncovered edge for it
end
elseif !isempty(cases)
# if we've not seen all candidates, union split is valid only for dispatch tuples
Expand Down
21 changes: 0 additions & 21 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
getresult_impl(::MethodMatchInfo, ::Int) = nothing
function add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype))
fully_covering(info) || push!(edges, info.mt, atype)
nothing
end

"""
info::UnionSplitInfo <: CallInfo
Expand All @@ -154,22 +150,6 @@ _add_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, mi_edge::Bool=false) =
nsplit_impl(info::UnionSplitInfo) = length(info.split)
getsplit_impl(info::UnionSplitInfo, idx::Int) = getsplit(info.split[idx], 1)
getresult_impl(::UnionSplitInfo, ::Int) = nothing
function add_uncovered_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype))
all(fully_covering, info.split) && return nothing
# add mt backedges with removing duplications
for mt in uncovered_method_tables(info)
push!(edges, mt, atype)
end
end
function uncovered_method_tables(info::UnionSplitInfo)
mts = MethodTable[]
for mminfo in info.split
fully_covering(mminfo) && continue
any(mt′::MethodTable->mt′===mminfo.mt, mts) && continue
push!(mts, mminfo.mt)
end
return mts
end

abstract type ConstResult end

Expand Down Expand Up @@ -215,7 +195,6 @@ add_edges_impl(edges::Vector{Any}, info::ConstCallInfo) = add_edges!(edges, info
nsplit_impl(info::ConstCallInfo) = nsplit(info.call)
getsplit_impl(info::ConstCallInfo, idx::Int) = getsplit(info.call, idx)
getresult_impl(info::ConstCallInfo, idx::Int) = info.results[idx]
add_uncovered_edges_impl(edges::Vector{Any}, info::ConstCallInfo, @nospecialize(atype)) = add_uncovered_edges!(edges, info.call, atype)

"""
info::MethodResultPure <: CallInfo
Expand Down
245 changes: 89 additions & 156 deletions src/staticdata_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -833,28 +833,33 @@ static size_t verify_call(jl_value_t *sig, jl_svec_t *expecteds, size_t i, size_
return max_valid;
}

// Test all edges relevant to a method
static size_t jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, arraylist_t *tovisit, htable_t *visited)
// Test all edges relevant to a method:
//// Visit the entire call graph, starting from edges[idx] to determine if that method is valid
//// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
//// and slightly modified with an early termination option once the computation reaches its minimum
static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_t *maxworld, arraylist_t *stack, htable_t *visiting)
{
size_t max_valid = jl_atomic_load_relaxed(&codeinst->max_world);
if (max_valid != WORLD_AGE_REVALIDATION_SENTINEL)
return max_valid;
max_valid = ~(size_t)0;
size_t max_valid2 = jl_atomic_load_relaxed(&codeinst->max_world);
if (max_valid2 != WORLD_AGE_REVALIDATION_SENTINEL) {
*maxworld = max_valid2;
return 0;
}
jl_method_instance_t *caller = codeinst->def;
assert(jl_is_method_instance(caller) && jl_is_method(caller->def.method));
void **bp = ptrhash_bp(visited, caller);
void **bp = ptrhash_bp(visiting, codeinst);
if (*bp != HT_NOTFOUND)
//return 0; // handle cycles by giving up
return ~(size_t)0; // handle cycles by giving up
*bp = (void*)caller;
return (char*)*bp - (char*)HT_NOTFOUND; // cycle idx
arraylist_push(stack, (void*)codeinst);
size_t depth = stack->len;
*bp = (char*)HT_NOTFOUND + depth;
JL_TIMING(VERIFY_IMAGE, VERIFY_Methods);
jl_value_t *loctag = NULL;
jl_value_t *sig = NULL;
jl_value_t *matches = NULL;
jl_array_t *maxvalids2 = NULL;
JL_GC_PUSH4(&loctag, &maxvalids2, &matches, &sig);
JL_GC_PUSH3(&loctag, &matches, &sig);
jl_svec_t *callees = jl_atomic_load_relaxed(&codeinst->edges);
assert(jl_is_svec((jl_value_t*)callees));
// verify current edges
for (size_t j = 0; j < jl_svec_len(callees); ) {
jl_value_t *edge = jl_svecref(callees, j);
size_t max_valid2;
Expand Down Expand Up @@ -894,8 +899,8 @@ static size_t jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, ar
max_valid2 = verify_invokesig(edge, meth, minworld);
j += 2;
}
if (max_valid2 < max_valid)
max_valid = max_valid2;
if (*maxworld > max_valid2)
*maxworld = max_valid2;
if (max_valid2 != ~(size_t)0 && _jl_debug_method_invalidation) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, edge);
loctag = jl_cstr_to_string("insert_backedges_callee");
Expand All @@ -905,167 +910,89 @@ static size_t jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, ar
}
//jl_static_show((JL_STREAM*)ios_stderr, (jl_value_t*)edge);
//ios_puts(max_valid2 == ~(size_t)0 ? "valid\n" : "INVALID\n", ios_stderr);
if (max_valid == 0 && !_jl_debug_method_invalidation)
if (max_valid2 == 0 && !_jl_debug_method_invalidation)
break;
}
if (max_valid == ~(size_t)0 || _jl_debug_method_invalidation) {
JL_GC_POP();
// verify recursive edges (if valid, or debugging)
size_t cycle = depth;
jl_code_instance_t *cause = codeinst;
if (*maxworld == ~(size_t)0 || _jl_debug_method_invalidation) {
for (size_t j = 0; j < jl_svec_len(callees); j++) {
jl_value_t *edge = jl_svecref(callees, j);
if (!jl_is_code_instance(edge))
continue;
jl_code_instance_t *callee = (jl_code_instance_t*)edge;
size_t max_valid2 = jl_verify_method(callee, minworld, tovisit, visited);
if (max_valid2 < max_valid)
max_valid = max_valid2;
if (max_valid2 != ~(size_t)0 && _jl_debug_method_invalidation) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, edge);
loctag = jl_cstr_to_string("insert_backedges_callee");
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)codeinst);
loctag = jl_cstr_to_string("recursive"); // TODO?
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
size_t max_valid2 = ~(size_t)0;
size_t child_cycle = jl_verify_method(callee, minworld, &max_valid2, stack, visiting);
if (*maxworld > max_valid2) {
cause = callee;
*maxworld = max_valid2;
}
//jl_static_show((JL_STREAM*)ios_stderr, (jl_value_t*)callee->def);
//ios_puts(max_valid2 == ~(size_t)0 ? "valid\n" : "INVALID\n", ios_stderr);
if (max_valid == 0 && !_jl_debug_method_invalidation)
if (max_valid2 == 0) {
// found what we were looking for, so terminate early
break;
}
else if (child_cycle && child_cycle < cycle) {
// record the cycle will resolve at depth "cycle"
cycle = child_cycle;
}
}
}
if (max_valid != ~(size_t)0 && _jl_debug_method_invalidation) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)caller);
loctag = jl_cstr_to_string("verify_methods");
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)codeinst);
if (*maxworld != 0 && cycle != depth)
return cycle;
// If we are the top of the current cycle, now mark all other parts of
// our cycle with what we found.
// Or if we found a failed edge, also mark all of the other parts of the
// cycle as also having a failed edge.
while (stack->len >= depth) {
jl_code_instance_t *child = (jl_code_instance_t*)arraylist_pop(stack);
if (*maxworld != jl_atomic_load_relaxed(&child->max_world))
jl_atomic_store_relaxed(&child->max_world, *maxworld);
void **bp = ptrhash_bp(visiting, codeinst);
assert(*bp == (char*)HT_NOTFOUND + stack->len + 1);
*bp = HT_NOTFOUND;
if (_jl_debug_method_invalidation && *maxworld != ~(size_t)0) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)child);
loctag = jl_cstr_to_string("verify_methods");
JL_GC_PUSH1(&loctag);
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)cause);
JL_GC_POP();
}
}
jl_atomic_store_relaxed(&codeinst->max_world, max_valid);
//jl_static_show((JL_STREAM*)ios_stderr, (jl_value_t*)caller);
//ios_puts(max_valid == ~(size_t)0 ? "valid\n\n" : "INVALID\n\n", ios_stderr);
JL_GC_POP();
return max_valid;
return 0;
}

static size_t jl_verify_method_graph(jl_code_instance_t *codeinst, size_t minworld, arraylist_t *tovisit, htable_t *visited)
static size_t jl_verify_method_graph(jl_code_instance_t *codeinst, size_t minworld, arraylist_t *stack, htable_t *visiting)
{
assert(tovisit->len == 0);
size_t max_valid = jl_verify_method(codeinst, minworld, tovisit, visited);
htable_reset(visited, 0);
return max_valid;
assert(stack->len == 0);
for (size_t i = 0, hsz = visiting->size; i < hsz; i++)
assert(visiting->table[i] == HT_NOTFOUND);
size_t maxworld = ~(size_t)0;
int child_cycle = jl_verify_method(codeinst, minworld, &maxworld, stack, visiting);
assert(child_cycle == 0); (void)child_cycle;
assert(stack->len == 0);
for (size_t i = 0, hsz = visiting->size / 2; i < hsz; i++) {
assert(visiting->table[2 * i + 1] == HT_NOTFOUND);
visiting->table[2 * i] = HT_NOTFOUND;
}
return maxworld;
}

//// Visit the entire call graph, starting from edges[idx] to determine if that method is valid
//// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
//// and slightly modified with an early termination option once the computation reaches its minimum
//static int jl_verify_graph_edge(size_t *maxvalids2_data, htable_t *idxs, jl_array_t *edges, size_t idx, arraylist_t *visited, arraylist_t *stack) jwn
//{
// assert(idx < visited->len);
// if (maxvalids2_data[idx] == 0) {
// visited->items[idx] = (void*)1;
// return 0;
// }
// size_t cycle = (size_t)visited->items[idx];
// if (cycle != 0)
// return cycle - 1; // depth remaining
// jl_value_t *cause = NULL;
// arraylist_push(stack, (void*)idx);
// size_t depth = stack->len;
// visited->items[idx] = (void*)(1 + depth);
// cycle = depth;
// jl_svec_t *callees = (jl_svec_t*)jl_array_ptr_ref(edges, 2 * idx + 1);
// assert(jl_is_svec((jl_value_t*)callees));
// for (size_t i = 0; i < jl_svec_len(callees); i++) {
// jl_value_t *edge = jl_svecref(callees, i);
// if (!jl_is_method_instance(edge))
// continue;
// void *verify_edge = ptrhash_get(idxs, edge);
// if (verify_edge == HT_NOTFOUND)
// continue;
// size_t childidx = (char*)verify_edge - (char*)HT_NOTFOUND - 1;
// int child_cycle = jl_verify_graph_edge(maxvalids2_data, idxs, edges, childidx, visited, stack);
// size_t child_max_valid = maxvalids2_data[childidx];
// if (child_max_valid < maxvalids2_data[idx]) {
// maxvalids2_data[idx] = child_max_valid;
// cause = jl_array_ptr_ref(edges, childidx * 2);
// }
// if (child_max_valid == 0) {
// // found what we were looking for, so terminate early
// break;
// }
// else if (child_cycle && child_cycle < cycle) {
// // record the cycle will resolve at depth "cycle"
// cycle = child_cycle;
// }
// }
// size_t max_valid = maxvalids2_data[idx];
// if (max_valid != 0 && cycle != depth)
// return cycle;
// // If we are the top of the current cycle, now mark all other parts of
// // our cycle with what we found.
// // Or if we found a failed edge, also mark all of the other parts of the
// // cycle as also having an failed edge.
// while (stack->len >= depth) {
// size_t childidx = (size_t)arraylist_pop(stack);
// assert(visited->items[childidx] == (void*)(2 + stack->len));
// if (idx != childidx) {
// if (max_valid < maxvalids2_data[childidx])
// maxvalids2_data[childidx] = max_valid;
// }
// visited->items[childidx] = (void*)1;
// if (_jl_debug_method_invalidation && max_valid != ~(size_t)0) {
// jl_method_instance_t *mi = (jl_method_instance_t*)jl_array_ptr_ref(edges, childidx * 2);
// jl_value_t *loctag = NULL;
// JL_GC_PUSH1(&loctag);
// jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)mi);
// loctag = jl_cstr_to_string("verify_methods");
// jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
// jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)cause);
// JL_GC_POP();
// }
// }
// return 0;
//}

// Restore backedges to external targets
// `edges` = [caller1, ...], the list of worklist-owned code instances internally
// `ext_ci_list` = [caller1, ...], the list of worklist-owned code instances externally
static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_ci_list, size_t minworld)
{
// determine which CodeInstance objects are still valid in our image
//ios_puts("===\n", ios_stderr);
//for (size_t i = 0; i < jl_array_nrows(edges); i++) {
// jl_code_instance_t *caller = (jl_code_instance_t*)jl_array_ptr_ref(edges, i);
// jl_svec_t *targets = jl_atomic_load_relaxed(&caller->edges);
// ios_puts(" => ", ios_stderr);
// jl_(caller);
// for (size_t j = 0; j < jl_svec_len(targets); j++) {
// jl_value_t *edge = jl_svecref(targets, j);
// if (jl_is_mtable(edge)) { j++; continue; }
// jl_(edge);
// }
//}
//ios_puts("===\n", ios_stderr);

size_t n_ext_cis = ext_ci_list ? jl_array_nrows(ext_ci_list) : 0;

// next jwn
for (size_t i = 0; i < n_ext_cis; i++) {
jl_code_instance_t *ci = (jl_code_instance_t*)jl_array_ptr_ref(ext_ci_list, i);
if (jl_atomic_load_relaxed(&ci->max_world) != WORLD_AGE_REVALIDATION_SENTINEL) {
assert(jl_atomic_load_relaxed(&ci->min_world) == 1);
assert(jl_atomic_load_relaxed(&ci->max_world) == ~(size_t)0);
jl_method_instance_t *caller = ci->def;
if (jl_atomic_load_relaxed(&ci->inferred) && jl_rettype_inferred(ci->owner, caller, minworld, ~(size_t)0) == jl_nothing) {
jl_mi_cache_insert(caller, ci);
}
//jl_static_show((JL_STREAM*)ios_stderr, (jl_value_t*)caller);
//ios_puts("free\n", ios_stderr);
}
}

arraylist_t tovisit;
arraylist_new(&tovisit, 0);
htable_t visited;
htable_new(&visited, 0);
// next enable any applicable new codes
// to enable any applicable new codes
arraylist_t stack;
arraylist_new(&stack, 0);
htable_t visiting;
htable_new(&visiting, 0);
for (size_t external = 0; external < (ext_ci_list ? 2 : 1); external++) {
if (external)
edges = ext_ci_list;
Expand All @@ -1074,10 +1001,16 @@ static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_ci_list, size
jl_code_instance_t *codeinst = (jl_code_instance_t*)jl_array_ptr_ref(edges, i);
jl_svec_t *callees = jl_atomic_load_relaxed(&codeinst->edges);
jl_method_instance_t *caller = codeinst->def;
if (jl_atomic_load_relaxed(&codeinst->min_world) != minworld)
continue;
size_t maxvalid = jl_verify_method_graph(codeinst, minworld, &tovisit, &visited);
assert(jl_atomic_load_relaxed(&codeinst->min_world) == minworld);
if (jl_atomic_load_relaxed(&codeinst->min_world) != minworld) {
if (external && jl_atomic_load_relaxed(&codeinst->max_world) != WORLD_AGE_REVALIDATION_SENTINEL) {
assert(jl_atomic_load_relaxed(&codeinst->min_world) == 1);
assert(jl_atomic_load_relaxed(&codeinst->max_world) == ~(size_t)0);
}
else {
continue;
}
}
size_t maxvalid = jl_verify_method_graph(codeinst, minworld, &stack, &visiting);
assert(jl_atomic_load_relaxed(&codeinst->max_world) == maxvalid);
if (maxvalid == ~(size_t)0) {
// if this callee is still valid, add all the backedges
Expand Down Expand Up @@ -1134,8 +1067,8 @@ static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_ci_list, size
}
}

htable_free(&visited);
arraylist_free(&tovisit);
htable_free(&visiting);
arraylist_free(&stack);
}

static jl_value_t *read_verify_mod_list(ios_t *s, jl_array_t *depmods)
Expand Down
2 changes: 1 addition & 1 deletion test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ precompile_test_harness("code caching") do dir
idxs = findall(==("verify_methods"), invalidations)
idxsbits = filter(idxs) do i
mi = invalidations[i-1]
mi.def == m
mi.def.def === m
end
idx = only(idxsbits)
tagbad = invalidations[idx+1]
Expand Down

0 comments on commit 1b9af2e

Please sign in to comment.