[Refactor] Reorganize ParallelOp code structure and move ProveFragmentContains to layout utils#1779
…d Validation
- Added the ProveFragmentContains function to check whether the threads accessing elements of a smaller fragment are a subset of those accessing a larger fragment.
- This function ensures valid access when transitioning from a smaller to a larger fragment layout.
- Updated layout.cc and utils.cc to incorporate this new functionality, enhancing the layout validation process.
- Removed the previous implementation of ProveFragmentContains from parallel.cc to streamline the codebase.
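The subset condition described in this commit can be illustrated with a brute-force sketch. This is only an illustration under stated assumptions: `ToyFragment`, `ThreadsFor`, and `ToyFragmentContains` are hypothetical names, not TileLang APIs, and the sketch enumerates concrete thread ids, whereas the function in this PR takes an analyzer and reasons about index expressions.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <set>

// Hypothetical stand-in for a fragment layout (NOT the TileLang Fragment
// class): maps (element index, replica index) to a thread id.
struct ToyFragment {
  int64_t num_elements;
  int64_t replicate_extent;
  std::function<int64_t(int64_t elem, int64_t rep)> forward_thread;
};

// Collect every thread that touches `elem` across all replicas.
std::set<int64_t> ThreadsFor(const ToyFragment &frag, int64_t elem) {
  std::set<int64_t> threads;
  for (int64_t rep = 0; rep < frag.replicate_extent; ++rep) {
    threads.insert(frag.forward_thread(elem, rep));
  }
  return threads;
}

// Brute-force containment check: for each element of the small fragment,
// every thread that accesses it must also access the corresponding element
// of the large fragment (given by `small_to_large`).
bool ToyFragmentContains(
    const ToyFragment &small_frag, const ToyFragment &large_frag,
    const std::function<int64_t(int64_t)> &small_to_large) {
  for (int64_t e = 0; e < small_frag.num_elements; ++e) {
    const std::set<int64_t> small_threads = ThreadsFor(small_frag, e);
    const std::set<int64_t> large_threads =
        ThreadsFor(large_frag, small_to_large(e));
    for (int64_t t : small_threads) {
      if (large_threads.count(t) == 0) {
        return false;  // a thread reads the small element but not the large one
      }
    }
  }
  return true;
}
```

In this toy model, a fragment with replication extent 2 contains its extent-1 counterpart but not the other way around, which matches the "smaller to larger" direction the commit message describes.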
📝 Walkthrough
Moves fragment-containment verification into layout utilities (adds ProveFragmentContains), binds thread ranges onto de-replicated and fully-replicated Fragments, removes local ProveFragmentContains from ParallelOp, and refactors ParallelOpNode to use IsBufferCompletelyReplicated, add cross-thread/store tracking, and update validation APIs and inference flow.
Sequence Diagram(s)
sequenceDiagram
participant Parallel as ParallelOpNode
participant Layout as LayoutUtils::ProveFragmentContains
participant Analyzer as Analyzer
participant Buffer as Buffer/Fragment
Parallel->>Buffer: collect fragments, accesses, indices
Parallel->>Parallel: IsBufferCompletelyReplicated(buffer, layout_map)
alt buffer fully replicated
Parallel-->>Parallel: skip containment proof, bind thread range
else not fully replicated
Parallel->>Layout: ProveFragmentContains(small_frag, large_frag, small_idx, large_idx, Analyzer, check_forward)
Layout-->>Parallel: true / false
Parallel->>Parallel: ValidateCandidateAgainstFragments(candidate, T, throw_on_error, check_forward_index, source_buffer)
Parallel->>Parallel: build replication guards, bind thread ranges
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/op/parallel.cc`:
- Around line 476-487: There is a stray semicolon after the call to
DeReplicate() and several info logs should be debug-only: remove the extraneous
semicolon following loop_layout_->DeReplicate() and ensure the result is
assigned to dereplicated_layout as intended, and replace LOG(INFO) calls inside
ComputeLoopLayoutFromBuffer with DLOG(INFO) for consistent debug logging; check
the code paths around DeReplicate(), loop_layout_, and
ValidateCandidateAgainstFragments to confirm behavior is unchanged after
removing the semicolon and update the LOG calls to DLOG to match the file's
debug logging convention.
🧹 Nitpick comments (3)
src/layout/utils.cc (1)
401-434: Analyzer binding has potential side effects on caller.
The `analyzer.Bind(rep_small, ..., true)` call modifies the analyzer's state. Since the analyzer is passed by reference, this binding persists after the function returns, which could affect subsequent analysis in the caller.
Consider whether this is intentional. If not, you may want to either:
- Document this side effect clearly
- Use a local analyzer copy if isolation is needed
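The difference between the two options can be sketched with a toy model. `ToyAnalyzer` and both `Check...` functions below are hypothetical and exist only for illustration; the sketch does not establish whether the real analyzer type used in `utils.cc` can be copied, or at what cost.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical miniature analyzer (NOT the real analyzer type): it only
// remembers a [min, min + extent) range per variable name.
struct ToyAnalyzer {
  std::map<std::string, std::pair<long, long>> ranges;

  void Bind(const std::string &var, long min, long extent) {
    ranges[var] = {min, extent};
  }
  bool IsBound(const std::string &var) const { return ranges.count(var) > 0; }
};

// Binds on the caller's analyzer: the binding outlives this call.
bool CheckWithSharedAnalyzer(ToyAnalyzer &analyzer) {
  analyzer.Bind("rep_small", 0, 4);
  return analyzer.IsBound("rep_small");
}

// Binds on a local copy: the caller's analyzer is left untouched.
bool CheckWithLocalCopy(const ToyAnalyzer &analyzer) {
  ToyAnalyzer local = analyzer;  // isolation comes at the cost of a copy
  local.Bind("rep_small", 0, 4);
  return local.IsBound("rep_small");
}
```

If copying is not an option, documenting the persistent binding in the function's header comment is the lighter-weight alternative the review mentions.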
src/op/parallel.cc (2)
418-425: Consider using `DLOG(INFO)` for debug logging.
These `LOG(INFO)` statements will emit output in production builds. Other debug logging in this file uses `DLOG(INFO)` (e.g., lines 434, 437-438, 621-622). For consistency and to avoid noisy logs in production, consider changing these to `DLOG(INFO)`.
♻️ Proposed fix
 if (read_source_buffer.defined() && allow_layout_propgate) {
   candidate_from_buffer = ComputeLoopLayoutFromBuffer(read_source_buffer, T);
-  LOG(INFO) << "read_source_buffer: " << read_source_buffer;
-  LOG(INFO) << "candidate_from_buffer: " << candidate_from_buffer;
+  DLOG(INFO) << "read_source_buffer: " << read_source_buffer;
+  DLOG(INFO) << "candidate_from_buffer: " << candidate_from_buffer;
 }
 // try to infer loop layout with two mechanisms and choose the best one
 {
   candidate_from_plan = ComputePlanCandidate(T);
-  LOG(INFO) << "candidate_from_plan: " << candidate_from_plan;
+  DLOG(INFO) << "candidate_from_plan: " << candidate_from_plan;
 }
629-656: Use `DLOG(INFO)` instead of `LOG(INFO)` for debug logging.
These logging statements should use `DLOG(INFO)` for consistency with the rest of the file and to avoid verbose output in production.
♻️ Proposed fix
 Var rep("_rep");
 auto rep_iter =
     IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar);
-LOG(INFO) << "rep: " << rep;
-LOG(INFO) << "src_layout: " << src_layout->DebugOutput();
-LOG(INFO) << "src_layout->DeReplicate(): "
-          << src_layout->DeReplicate()->DebugOutput();
-LOG(INFO) << "Create rep_iter: " << rep_iter
-          << " from rep_extent: " << src_layout->ReplicateExtent();
+DLOG(INFO) << "rep: " << rep;
+DLOG(INFO) << "src_layout: " << src_layout->DebugOutput();
+DLOG(INFO) << "src_layout->DeReplicate(): "
+           << src_layout->DeReplicate()->DebugOutput();
+DLOG(INFO) << "Create rep_iter: " << rep_iter
+           << " from rep_extent: " << src_layout->ReplicateExtent();
 PrimExpr loop_var_to_thread =
     src_layout->ForwardThread(GetAccessInfo(buffer).indices, rep);
 loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
-LOG(INFO) << "loop_var_to_thread after simplify: " << loop_var_to_thread;
+DLOG(INFO) << "loop_var_to_thread after simplify: " << loop_var_to_thread;
 ...
 try {
   result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                ->BindThreadRange(T.thread_bounds);
-  LOG(INFO) << "result: " << result;
-  LOG(INFO) << "result->DeReplicate(): " << result->DeReplicate();
+  DLOG(INFO) << "result: " << result;
+  DLOG(INFO) << "result->DeReplicate(): " << result->DeReplicate();
 } catch (const tvm::runtime::Error &err) {
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/op/parallel.cc`:
- Around line 208-222: The ParallelOpNode::IsBufferCompletelyReplicated function
unsafely calls layout_map[buffer].as<Fragment>().value() and inconsistently
treats index errors (returns false for non-IntImm but LOG(FATAL) for non-zero
IntImm); fix by first checking the as<Fragment>() cast result (e.g., auto
frag_opt = layout_map[buffer].as<Fragment>(); if (!frag_opt) return false)
before accessing value, and make index handling consistent by returning false
(not LOG(FATAL)) when any index is non-IntImm or an IntImm with value != 0; use
GetAccessInfo(buffer).indices to perform these checks and then call
frag_opt->IsCompletedReplicated() if all checks pass.
🧹 Nitpick comments (3)
src/op/parallel.cc (3)
120-135: Potential duplicate buffer entries in tracking vectors.
The `PostOrderVisit` may encounter multiple stores to the same buffer, adding duplicates to `store_shared_global_buffers_` and `store_fragment_buffers_`. If downstream code iterates these vectors assuming unique entries, this could cause redundant processing or incorrect behavior.
Consider using a set or checking for existence before insertion:
♻️ Proposed fix to avoid duplicates
 PostOrderVisit(root_, [&](const ObjectRef &obj) {
   if (const auto *store = obj.as<BufferStoreNode>()) {
     auto buffer = store->buffer;
     if (IsSharedBuffer(buffer) || IsGlobalBuffer(buffer)) {
       has_cross_thread_access_ = true;
-      store_shared_global_buffers_.emplace_back(buffer);
+      if (std::find(store_shared_global_buffers_.begin(),
+                    store_shared_global_buffers_.end(), buffer) ==
+          store_shared_global_buffers_.end()) {
+        store_shared_global_buffers_.emplace_back(buffer);
+      }
     } else if (IsFragmentBuffer(buffer)) {
-      store_fragment_buffers_.emplace_back(buffer);
+      if (std::find(store_fragment_buffers_.begin(),
+                    store_fragment_buffers_.end(), buffer) ==
+          store_fragment_buffers_.end()) {
+        store_fragment_buffers_.emplace_back(buffer);
+      }
     }
   } else if (const auto *load = obj.as<BufferLoadNode>()) {
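The check-before-insert pattern from this suggestion can be factored into a small helper. The sketch below is a hypothetical, self-contained illustration that uses strings in place of buffers; `PushBackUnique` and `CollectStoredBuffers` are made-up names, not part of the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Append `item` only if it is not already present; preserves first-seen order.
// The linear scan is fine for the handful of buffers a loop body touches.
template <typename T>
void PushBackUnique(std::vector<T> &items, const T &item) {
  if (std::find(items.begin(), items.end(), item) == items.end()) {
    items.push_back(item);
  }
}

// Hypothetical stand-in for the visitor: records each stored buffer name once.
std::vector<std::string> CollectStoredBuffers(
    const std::vector<std::string> &stores) {
  std::vector<std::string> unique_buffers;
  for (const std::string &name : stores) {
    PushBackUnique(unique_buffers, name);
  }
  return unique_buffers;
}
```

Keeping a vector rather than switching to a set preserves visit order, which may matter if downstream code depends on the order in which buffers were first stored.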
477-484: Consider skipping redundant re-validation after successful DeReplicate.
When `ValidateCandidateAgainstFragments` succeeds for `dereplicated_layout` (line 477-480), `loop_layout_` is updated to `dereplicated_layout`. The subsequent call on line 482-484 then validates the same layout again with `throw_on_error=true`, which is redundant since it was just validated.
♻️ Proposed optimization
 Fragment dereplicated_layout = loop_layout_->DeReplicate();
 if (ValidateCandidateAgainstFragments(
         dereplicated_layout, T, /*throw_on_error=*/false,
         /*check_forward_index=*/false, source_buffer)) {
   loop_layout_ = dereplicated_layout;
+} else {
+  // DeReplicate failed validation; validate original with throwing enabled
+  ValidateCandidateAgainstFragments(loop_layout_, T, /*throw_on_error=*/true,
+                                    /*check_forward_index=*/false,
+                                    source_buffer);
 }
-ValidateCandidateAgainstFragments(loop_layout_, T, /*throw_on_error=*/true,
-                                  /*check_forward_index=*/false,
-                                  source_buffer);
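The control flow in the proposed optimization (a non-throwing probe of the de-replicated candidate, with a throwing validation of the original only on fallback) can be sketched in isolation. `ToyLayout`, `ValidateToyLayout`, and `ChooseLayout` are hypothetical stand-ins; the real validation takes more parameters and checks fragments rather than a boolean flag.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical stand-in for a layout candidate; `valid` replaces the real
// fragment checks.
struct ToyLayout {
  int replicate_extent;
  bool valid;
};

// Two-mode validation: report failure either as `false` or as an exception.
bool ValidateToyLayout(const ToyLayout &layout, bool throw_on_error) {
  if (layout.valid) {
    return true;
  }
  if (throw_on_error) {
    throw std::runtime_error("layout failed validation");
  }
  return false;
}

// Prefer the de-replicated candidate; validate the original (throwing on
// failure) only when the candidate is rejected, so nothing is checked twice.
ToyLayout ChooseLayout(const ToyLayout &original,
                       const ToyLayout &dereplicated) {
  if (ValidateToyLayout(dereplicated, /*throw_on_error=*/false)) {
    return dereplicated;
  }
  ValidateToyLayout(original, /*throw_on_error=*/true);
  return original;
}
```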
626-630: Reconsider the leading underscore in variable name.
The variable was renamed from `rep` to `_rep`, but it's actively used on line 630. Leading underscores typically indicate unused variables or have reserved meanings in C++ (identifiers starting with underscore followed by uppercase, or double underscores, are reserved). While `_rep` is technically allowed, it may be misleading since the variable is not unused.
♻️ Suggested alternatives
- Var rep("_rep");
+ Var rep("rep");
Or if distinguishing from other `rep` variables is needed:
- Var rep("_rep");
+ Var rep("loop_rep");
bool ParallelOpNode::IsBufferCompletelyReplicated(
    const Buffer &buffer, const LayoutMap &layout_map) const {
  if (!IsFragmentBuffer(buffer))
    return false;
  auto frag = layout_map[buffer].as<Fragment>().value();
  // buffer indices should be IntImm
  for (const auto &index : GetAccessInfo(buffer).indices) {
    if (!index.as<IntImmNode>()) {
      return false;
    } else if (index.as<IntImmNode>()->value != 0) {
      LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
    }
  }
  return frag->IsCompletedReplicated();
}
Unsafe `.value()` access and inconsistent error handling.
Two concerns:
- Line 212: `T.layout_map[buffer].as<Fragment>().value()` calls `.value()` without checking if the cast succeeded. If the layout is not a `Fragment`, this will crash.
- Lines 217-218: The method returns `false` for non-`IntImm` indices but calls `LOG(FATAL)` for non-zero `IntImm` values. This inconsistency is confusing; consider either returning `false` for both or documenting why non-zero indices are a fatal error.
🛡️ Proposed fix for safe access
bool ParallelOpNode::IsBufferCompletelyReplicated(
const Buffer &buffer, const LayoutMap &layout_map) const {
if (!IsFragmentBuffer(buffer))
return false;
- auto frag = layout_map[buffer].as<Fragment>().value();
+ auto frag_opt = layout_map[buffer].as<Fragment>();
+ if (!frag_opt.has_value())
+ return false;
+ auto frag = frag_opt.value();
// buffer indices should be IntImm
for (const auto &index : GetAccessInfo(buffer).indices) {
if (!index.as<IntImmNode>()) {
return false;
} else if (index.as<IntImmNode>()->value != 0) {
- LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+ return false; // Non-zero index means not completely replicated
}
}
return frag->IsCompletedReplicated();
}
- Removed the initial DeReplicate attempt from InferLayout to streamline layout inference.
- Added DeReplicate logic to ComputeLoopLayoutFromBuffer to reduce replication when validating layout candidates.
- Updated test cases to disable caching and ensure proper functionality of loop layout kernels.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@testing/python/layout/test_tilelang_annotate_loop_layout.py`:
- Around line 110-112: The test runner was replaced with a direct call to
test_loop_layout_identity(), skipping other tests; restore the standard runner
by re-enabling tilelang.testing.main() (or calling the test framework's main
entry) in the if __name__ == "__main__": block and remove the explicit
test_loop_layout_identity() invocation so that test_loop_layout_fragment_vec4,
test_copy_loop_layout_annotated_replicate_vec4, and
test_annotate_replicate_loop_layout_vec4 are executed via the test runner.
- Around line 43-47: Remove the debug print and either remove or document the
cache disable: delete the print(code) statement that prints kernel source (from
kernel.get_kernel_source()) and if tilelang.disable_cache() was only used for
temporary debugging, remove that call; if disabling cache is required for this
test, add a brief inline comment next to tilelang.disable_cache() explaining why
caching must be disabled for this test (e.g., to ensure fresh kernel
compilation). Ensure the assertion comparing the generated code remains
unchanged.
@regression-perf
Performance Regression Test Report
Triggered by: @LeiWang1999
- Removed caching disablement and print statements from the loop layout identity test for cleaner output.
- Updated the main execution block to directly call the testing framework, enhancing test execution flow.
Summary
- Move `ProveFragmentContains` function from `parallel.cc` to `layout/utils.cc` for better code organization
- Move `IfBufferRemapLoopGenerator` helper class into anonymous namespace
- Extract `buffer_is_completed_replicated` lambda to private method `IsBufferCompletelyReplicated`
- Add `DeReplicate` optimization attempt before layout validation to reduce replication when possible
- Update `ValidateCandidateAgainstFragments` to support both bool return and exception throwing modes
Changes
- `src/layout/utils.h`: Add `ProveFragmentContains` function declaration
- `src/layout/utils.cc`: Add `ProveFragmentContains` function implementation
- `src/op/parallel.h`:
  - Remove `ProveFragmentContains` declaration (moved to utils)
  - Add `IsBufferCompletelyReplicated` private method
  - Update `ValidateCandidateAgainstFragments` signature with optional parameters
- `src/op/parallel.cc`:
  - Remove `ProveFragmentContains` implementation
  - Move `IfBufferRemapLoopGenerator` to anonymous namespace
  - Add `IsBufferCompletelyReplicated` method implementation
  - `InferLayout`: `DeReplicate` attempt before validation
  - Update `ValidateCandidateAgainstFragments`
Summary by CodeRabbit
Refactor
Improvements
New Features
Tests