Skip to content

Commit e682465

Browse files
authored
[Enhancement] Add strict layout map for improved buffer layout inference (#594)
- Introduced a `strict_layout_map` to enhance layout inference by ensuring that buffers with strict layout requirements are properly accounted for during the inference process. - Updated the inference logic to check for the presence of buffers in the `strict_layout_map` before applying layout changes, improving the accuracy of layout assignments. - Refactored the layout inference steps to include the copying of layouts into the new strict map, ensuring a clear separation of layout handling based on inference levels.
1 parent f8fb5c4 commit e682465

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/transform/layout_inference.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
228228

229229
// Copy the annotated layout map to local variable
230230
Map<Buffer, Layout> layout_map = annotated_layout_map_;
231+
Map<Buffer, Layout> strict_layout_map;
231232
int num_infer = infer_list_.size();
232233

233234
// Prepare BFS queue for iterative inference
@@ -245,6 +246,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
245246
}
246247
q.push(i);
247248
}
249+
248250
auto run_infer_step = [&](int cur_infer_id, InferLevel level,
249251
bool update_queue) {
250252
// Range check for cur_infer_id
@@ -290,7 +292,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
290292
if (layout_map.count(buffer)) {
291293
// If replicate size of this buffer is greater than the old one
292294
if (buffer.scope() == "local.fragment" &&
293-
level != InferLevel::kStrict) {
295+
level != InferLevel::kStrict &&
296+
!strict_layout_map.count(buffer)) {
294297
const FragmentNode *dst_layout = layout.as<Fragment>().get();
295298
const FragmentNode *src_layout =
296299
layout_map[buffer].as<Fragment>().get();
@@ -358,6 +361,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
358361
run_infer_step(i, InferLevel::kStrict, false);
359362
}
360363

364+
for (const auto &[buffer, layout] : layout_map) {
365+
strict_layout_map.Set(buffer, layout);
366+
}
367+
361368
// step 2: infer common layout with BFS
362369
finish_infer_queue();
363370

@@ -366,7 +373,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
366373
run_infer_step(i, InferLevel::kFree, true);
367374
finish_infer_queue();
368375
}
369-
370376
// Check that all local.fragment buffers have inferred layouts
371377
for (const auto &[buffer, _] : use_list_) {
372378
if (buffer.scope() == "local.fragment") {

0 commit comments

Comments
 (0)