@@ -105,13 +105,16 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
105105 " required for layout inference." ;
106106
107107 // Run InferLayout
108+ DLOG (INFO) << " [RunInferStep] working on " << cur_infer_id << ' \n ' ;
108109 auto updates =
109110 next->InferLayout (LayoutInferArgs{target_, thread_bounds, layout_map,
110111 &analyzer_, buffer_oob},
111112 level);
112-
113113 // Process the returned updates
114114 for (const auto &[buffer, layout] : updates) {
115+ DLOG (INFO) << " consider update " << buffer << " as "
116+ << layout->DebugOutput () << ' \n ' ;
117+
115118 // Basic validity checks
116119 ICHECK (buffer.defined ()) << " InferLayout returned an undefined buffer." ;
117120 ICHECK (layout.defined ()) << " InferLayout returned an undefined layout." ;
@@ -140,6 +143,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
140143 if (ProveFragmentContains (src_layout, dst_layout, indices, indices,
141144 inner_analyzer)) {
142145 layout_map.Set (buffer, layout);
146+ DLOG (INFO) << " layout broadcast from "
147+ << src_layout->DebugOutput () << " , accepted" << ' \n ' ;
143148 continue ;
144149 }
145150 }
@@ -151,6 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
151156 } else {
152157 // Otherwise, update map
153158 layout_map.Set (buffer, layout);
159+ DLOG (INFO) << " new layout accepted" << ' \n ' ;
154160 if (!update_queue)
155161 continue ;
156162
@@ -210,6 +216,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
210216 << " Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
211217 " length." ;
212218
219+ DLOG (INFO) << " [InferLayout] all participating operators:" << ' \n ' ;
220+ for (int i = 0 ; i < infer_list_stmt_.size (); ++i) {
221+ DLOG (INFO) << " op " << i << " :" << infer_list_stmt_[i] << ' \n ' ;
222+ }
223+
213224 // If needed, you can also check that annotated_layout_map_ is not empty, or
214225 // anything else relevant to your setup.
215226
@@ -470,6 +481,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
470481
471482 void InferInFreeMode (LayoutMap &layout_map,
472483 const LayoutMap &strict_layout_map) {
484+
485+ DLOG (INFO) << " Enforced layout maps:" << ' \n ' ;
486+ for (auto &&[k, v] : layout_map) {
487+ DLOG (INFO) << " " << k << " : " << v->DebugOutput () << ' \n ' ;
488+ }
489+ DLOG (INFO) << ' \n ' ;
490+
473491 // Group operators into connected components
474492 UnionFind<int > uf;
475493 for (int i = 0 ; i < infer_list_.size (); i++) {
@@ -505,52 +523,53 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
505523 std::vector<bool > in_queue (infer_list_.size (), false );
506524
507525 for (auto &&[root, members] : components) {
526+ DLOG (INFO) << " ======================= processing component " << root
527+ << ' \n ' ;
508528 decltype (infer_list_) best_infer_list;
509529 LayoutMap best_layout_map;
510530 int64_t min_reg_num = INT64_MAX;
531+ int min_reg_num_infer_root = -1 ;
511532
533+ // Try each member as the root of inference for this component
512534 for (int attempt_infer_root : members) {
513- // backup infer_list_ in class member
535+ DLOG (INFO) << " ----------------------- try root " << attempt_infer_root
536+ << ' \n ' ;
537+ // Backup the current infer_list_ state
514538 auto back_infer_list = BackupInferList ();
515- // create temporarily used layout_map, new handle so that it copies on
516- // write
539+ // Copy the current layout_map for temporary use
517540 LayoutMap tmp_layout_map = layout_map;
518- // infer from attempt_infer_root in free mode
519541 bool do_update = true ;
520542 try {
543+ // Run inference starting from attempt_infer_root
521544 RunInferStep (attempt_infer_root, InferLevel::kFree , true ,
522545 tmp_layout_map, strict_layout_map, q, in_queue);
523546 FinishInferQueue (InferLevel::kFree , tmp_layout_map, strict_layout_map,
524547 q, in_queue);
525- // Silly workaround: we have no clue if single root will iterate over
526- // the entire component, since the InferLayout implementations have
527- // complicated conditioning inside and we know nothing about it.
528- // This would constantly result in incomplete layouts for buffers in
529- // this component. Instead of trying all combinations of root
530- // selection order, we simply go through all other loops in order
531- // after the first search from attempt_infer_root.
548+
549+ // After the first search, run inference for all other members in
550+ // order
532551 for (int other_infer_root : members) {
533552 if (other_infer_root != attempt_infer_root) {
534553 RunInferStep (other_infer_root, InferLevel::kFree , true ,
535554 tmp_layout_map, strict_layout_map, q, in_queue);
536- // must also be kFree here to avoid conflicts.
537555 FinishInferQueue (InferLevel::kFree , tmp_layout_map,
538556 strict_layout_map, q, in_queue);
539557 }
540558 }
541- } catch (LayoutConflictException e) {
542- // such an order fails, try others
559+ } catch (const LayoutConflictException &e) {
543560 do_update = false ;
544- } catch (NormalizeIterException e) {
545- // such an order encounters iterators that is not normalizable, try
546- // others e.g. i * 576 % 2048
561+ DLOG (INFO) << " attempt failed due to LayoutConflictException "
562+ << e. what () << ' \n ' ;
563+ } catch ( const NormalizeIterException &e) {
547564 do_update = false ;
565+ DLOG (INFO) << " attempt failed due to NormalizeIterException "
566+ << e.what () << ' \n ' ;
548567 }
549568
550569 if (do_update) {
551- // compute total register number
570+ // Compute the total register number for this layout
552571 int64_t reg_num = 0 ;
553- for (auto & &[buffer, layout] : tmp_layout_map) {
572+ for (const auto &[buffer, layout] : tmp_layout_map) {
554573 if (auto frag = layout.as <Fragment>()) {
555574 int64_t frag_reg_num = 1 ;
556575 for (auto i : frag.value ()->OutputShape ()) {
@@ -561,21 +580,24 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
561580 reg_num += frag_reg_num;
562581 }
563582 }
564- // if it's any better, update the best_* storage
583+ // Update the best plan if this one uses fewer registers
565584 if (reg_num < min_reg_num) {
566- best_infer_list = std::move (infer_list_);
585+ best_infer_list =
586+ BackupInferList (); // Use backup to avoid moving out infer_list_
567587 best_layout_map = tmp_layout_map;
568588 min_reg_num = reg_num;
589+ min_reg_num_infer_root = attempt_infer_root;
569590 }
570591 }
571- // recover stateful infer_list_, head on next
592+ // Restore infer_list_ state for the next attempt
572593 infer_list_ = std::move (back_infer_list);
573594 }
574- if (min_reg_num < INT64_MAX) {
575- // now apply the best plan for this component
576- infer_list_ = std::move (best_infer_list);
577- layout_map = best_layout_map;
578- }
595+ ICHECK (min_reg_num < INT64_MAX) << " no available layout found" << ' \n ' ;
596+ // Apply the best plan for this component
597+ infer_list_ = std::move (best_infer_list);
598+ layout_map = best_layout_map;
599+ DLOG (INFO) << " [InferInFreeMode] Final selection is attempt_infer_root = "
600+ << min_reg_num_infer_root << ' \n ' ;
579601 }
580602 }
581603};
@@ -682,20 +704,25 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
682704 // Here, A_local is a register-local buffer held independently by each
683705 // thread, so explicit thread binding is not required.
684706 //
685- // We use PostOrderVisit to detect whether the buffer store targets a
686- // "local" buffer , which indicates register usage and justifies skipping
707+ // We use PostOrderVisit to detect whether the loop only manipulates
708+ // "local" buffers , which indicates register usage and justifies skipping
687709 // thread binding.
688- bool is_register_store = false ;
710+ bool local_register_only = true ;
689711 PostOrderVisit (root, [&](const ObjectRef &obj) {
690712 if (const auto *store = obj.as <BufferStoreNode>()) {
691- if (store->buffer .scope () == " local" ) {
692- is_register_store = true ;
713+ if (store->buffer .scope () != " local" ) {
714+ local_register_only = false ;
715+ }
716+ } else if (const auto *load = obj.as <BufferLoadNode>()) {
717+ if (load->buffer .scope () != " local" ) {
718+ local_register_only = false ;
693719 }
694720 }
695721 });
696722
697723 auto loop_layout = result_.for_map [root];
698- bool parallel_loop = !is_register_store && !skip_thread_partition_;
724+ // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
725+ bool parallel_loop = !skip_thread_partition_ && !local_register_only;
699726
700727 if (parallel_loop) {
701728 for_node =
0 commit comments