diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 7a82a101d..5fa90ba95 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -53,15 +53,18 @@ bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, // (forward index) of both fragments are equal. This is required when // validating loop layout against buffer fragment, as code generation // needs to correctly derive buffer physical indices from loop layout. + bool large_physical_is_fully_replicated = large_frag->IsCompletedReplicated(); + if (large_physical_is_fully_replicated) { + return true; // fully replicated fragments are always compatible + } + if (check_forward_index) { auto small_physical = small_frag->Forward(small_frag_indices); auto large_physical = large_frag->Forward(large_frag_indices); - // Dimension mismatch means they are not equal. if (small_physical.size() != large_physical.size()) { return false; } - // Check each physical index component for equality. for (size_t i = 0; i < small_physical.size(); i++) { auto diff = analyzer_.Simplify(small_physical[i] - large_physical[i]); @@ -649,7 +652,7 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments( // check_forward_index=true: when validating loop layout against buffer // fragment, we need to ensure physical indices match for correct code gen. if (!ProveFragmentContains(candidate, fragment, vars, indice_map_[buffer], - analyzer_)) { + analyzer_, /*check_forward_index=*/true)) { return false; } } @@ -815,8 +818,7 @@ ParallelOpNode::ChooseBestCandidate(const Fragment &candidate_from_buffer, auto contains = [&](const Fragment &big, const Fragment &small) { // contains(A, B) means: for any loop index, the threads that access // B's elements are a subset of those that access A's elements. - return ProveFragmentContains(small, big, vars, vars, analyzer_, - /*check_forward_index=*/true); + return ProveFragmentContains(small, big, vars, vars, analyzer_); }; bool buf_ok = ValidateCandidateAgainstFragments(candidate_from_buffer, T);