diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index e7f9c6093..4ff3b8194 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -82,6 +82,7 @@ def postprocess_kernel( pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, }) def bwd( B, @@ -159,9 +160,8 @@ def sparse_mla_bwd_kernel( acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype) acc_dkv = T.alloc_fragment([BS, D], accum_dtype) acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) - acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) - acc_dkv_tail_shared = T.view( - KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) max_kv_i = s_i diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 2ada9fd08..c3f99f307 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -297,13 +297,17 @@ std::pair LayoutNode::InverseWithLevel() const { } Layout LayoutNode::Reshape(const Array &shape, - arith::Analyzer *analyzer) const { + arith::Analyzer *analyzer, + const PrimExpr rescale_num, + const PrimExpr rescale_den) const { + // Fast path: if shape is the same, return the original layout if (StructuralEqual()(InputShape(), shape)) { return ffi::GetRef(this); } - // Step 1. Prove the product of InputShape is equal to the product of shape + // Step 1. 
Prove the product relation holds under rescale: + // prod(InputShape) * rescale_num == prod(shape) * rescale_den PrimExpr input_shape_product = Integer(1); for (const auto &dim : InputShape()) { input_shape_product *= dim; @@ -317,8 +321,10 @@ Layout LayoutNode::Reshape(const Array &shape, // potential null dereference paths flagged by static analysis. arith::Analyzer fallback_analyzer; arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer; - ICHECK(az->CanProveEqual(input_shape_product, shape_product)) - << "InputShape() = " << InputShape() << " shape = " << shape; + ICHECK(az->CanProveEqual(input_shape_product * rescale_num, + shape_product * rescale_den)) + << "InputShape() = " << InputShape() << " shape = " << shape + << ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den; // Step 2. Create new forward indices by reshaping // For each dimension in the new shape, we create a placeholder variable @@ -339,13 +345,17 @@ Layout LayoutNode::Reshape(const Array &shape, } flat_index = flat_index + new_vars[i] * stride; } + // Convert new flat index (in units of new elements) to the old flat index + // (in units of old elements) using the rational rescale factor. + // old_flat = floor((flat_index * rescale_den) / rescale_num) + PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num); // Step 4. Convert flat index back to original shape indices // For original shape [s0, s1, ..., sm]: // i0 = flat_index // (s1 * s2 * ... * sm) // i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm) // ... 
Array original_indices; - PrimExpr remaining = flat_index; + PrimExpr remaining = old_flat_index; for (size_t i = 0; i < InputShape().size(); ++i) { PrimExpr stride = Integer(1); for (size_t j = i + 1; j < InputShape().size(); ++j) { @@ -373,7 +383,10 @@ Layout LayoutNode::Reshape(const Array &shape, } Layout FragmentNode::Reshape(const Array &shape, - arith::Analyzer *analyzer) const { + arith::Analyzer *analyzer, + const PrimExpr rescale_num, + const PrimExpr rescale_den) const { + // Fast path: identical input shape, return self if (StructuralEqual()(InputShape(), shape)) { return ffi::GetRef(this); @@ -390,8 +403,9 @@ Layout FragmentNode::Reshape(const Array &shape, // Use provided analyzer if present, otherwise a local fallback. arith::Analyzer fallback_analyzer; arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer; - ICHECK(az->CanProveEqual(input_prod, shape_prod)) + ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den)) << "InputShape() = " << InputShape() << " shape = " << shape + << ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den << " input fragment layout is = " << DebugOutput(); // 2) Build flat index from new-shape indices @@ -414,9 +428,12 @@ Layout FragmentNode::Reshape(const Array &shape, stride = stride * shape[j]; flat = flat + new_vars[i] * stride; } + // Convert to old flat index units using the rational rescale factor. 
+ // old_flat = floor((flat * rescale_den) / rescale_num) + PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num); // 3) Recover original indices from flat index Array orig_indices; - PrimExpr remain = flat; + PrimExpr remain = old_flat; for (size_t i = 0; i < InputShape().size(); ++i) { PrimExpr stride = Integer(1); for (size_t j = i + 1; j < InputShape().size(); ++j) @@ -536,6 +553,52 @@ bool FragmentNode::IsCompletedReplicated() const { ReplicationPlaceholder()); } +arith::IterMapResult FragmentNode::DetectInjective() const { + // NOTE: a strict injectivity check would invert this layout and verify + // surjectivity; for convenience we currently run the stricter bijective + // check. This can be relaxed in the future. + arith::Analyzer analyzer; + // Build a flat indices array: [forward_thread_, forward_index_[...]] + Array indices; + indices.push_back(forward_thread_); + for (const auto &e : forward_index_) { + indices.push_back(e); + } + + // Mirror Layout::InverseWithLevel(): if any participating shape is + // symbolic, relax to NoCheck and rely on runtime guards elsewhere. + auto collect_symbolic = [&](const Array &shape) { + Array symbolic_dims; + for (const auto &dim : shape) { + if (!as_const_int(dim)) { + symbolic_dims.push_back(dim); + } + } + return symbolic_dims; + }; + + Array symbolic_dims = collect_symbolic(InputShape()); + Array output_shape = OutputShape(); + symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(), + output_shape.end()); + // Also consider replicate size for fragments + if (!as_const_int(ReplicateExtent())) { + symbolic_dims.push_back(ReplicateExtent()); + } + symbolic_dims = collect_symbolic(symbolic_dims); + + bool is_static_shape = symbolic_dims.empty(); + auto level = is_static_shape ? 
arith::IterMapLevel::Bijective + : arith::IterMapLevel::NoCheck; + if (!is_static_shape) { + DLOG(WARNING) + << "Fragment::DetectInjective on symbolic layout, falling back to " + << "NoCheck; symbolic dims: " << symbolic_dims; + } + + return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer); +} + PrimExpr FragmentNode::ThreadExtent() const { Array ret(OutputDim(), 1); arith::Analyzer analyzer; diff --git a/src/layout/layout.h b/src/layout/layout.h index afa504187..369df4f2e 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -6,6 +6,7 @@ #ifndef TVM_TL_LAYOUT_LAYOUT_H_ #define TVM_TL_LAYOUT_LAYOUT_H_ +#include #include #include #include @@ -18,6 +19,25 @@ namespace tl { using namespace tir; +// Common layout-related exceptions +class LayoutConflictException : public std::exception { +public: + const char *what() const noexcept override { return msg_.c_str(); } + explicit LayoutConflictException(const std::string &msg) : msg_(msg) {} + +private: + std::string msg_; +}; + +class LoopLayoutInjectiveException : public std::exception { +public: + const char *what() const noexcept override { return msg_.c_str(); } + explicit LoopLayoutInjectiveException(const std::string &msg) : msg_(msg) {} + +private: + std::string msg_; +}; + class Layout; class Fragment; @@ -42,8 +62,18 @@ class LayoutNode : public Object { virtual Layout Inverse() const; + // Reshape the layout to a new logical shape. When aliasing buffers of + // different dtypes, the element count may change while the underlying + // byte-size stays equal. Use rescale_num/rescale_den to represent the + // ratio between the old element size and the new element size in bytes. + // Specifically, define factor = rescale_num / rescale_den where: + // new_num_elems = old_num_elems * factor + // For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1. + // i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4. 
virtual Layout Reshape(const Array &shape, - arith::Analyzer *analyzer) const; + arith::Analyzer *analyzer, + const PrimExpr rescale_num = Integer(1), + const PrimExpr rescale_den = Integer(1)) const; virtual std::pair InverseWithLevel() const; @@ -86,7 +116,9 @@ class FragmentNode : public LayoutNode { Layout Inverse() const final; - Layout Reshape(const Array &shape, arith::Analyzer *analyzer) const; + Layout Reshape(const Array &shape, arith::Analyzer *analyzer, + const PrimExpr rescale_num = Integer(1), + const PrimExpr rescale_den = Integer(1)) const; std::pair InverseWithLevel() const final; @@ -116,6 +148,8 @@ class FragmentNode : public LayoutNode { bool IsCompletedReplicated() const; + arith::IterMapResult DetectInjective() const; + static void RegisterReflection(); TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode); diff --git a/src/op/copy.cc b/src/op/copy.cc index 7bef87d64..72e73e162 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -551,7 +551,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, // This must be a global/shared layout, so we can skip the parallel op // layout inference (parallel layout inference only annotate the loop layout // and the register layout). - bool is_load = copy_inst == CopyInst::kBulkLoad; + bool is_load = + copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D; Buffer global_tensor = is_load ? src : dst; Buffer shared_tensor = is_load ? 
dst : src; // check shared layout is non-swizzle @@ -561,6 +562,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, Layout linear_layout = ComputeLinearLayout(shared_tensor); return Map({{shared_tensor, linear_layout}}); } + return {}; } // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 94572098d..7f755b475 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -214,6 +214,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { if (loop_layout_.defined()) return {}; + if (level == InferLevel::kStrict) { LayoutMap results; // Deduce buffers that should be complicated replicated. @@ -562,6 +563,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, } else { return {}; } + // check loop_layout_ is injective + auto injective_res = loop_layout_->DetectInjective(); + if (!injective_res->errors.empty()) { + std::ostringstream oss; + oss << "Loop layout is not injective: " << loop_layout_->DebugOutput() + << '\n' + << " errors: " << injective_res->errors << '\n' + << " loop AST: " << root_; + throw LoopLayoutInjectiveException(oss.str()); + } PrimExpr loop_thread_extent = loop_layout_->ThreadExtent(); diff --git a/src/op/parallel.h b/src/op/parallel.h index 8ebd7366e..4ff5484b8 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -24,15 +24,6 @@ namespace tl { using namespace tir; -class LayoutConflictException : public std::exception { -public: - const char *what() const noexcept override { return msg_.c_str(); } - LayoutConflictException(const std::string &msg) : msg_(msg) {} - -private: - std::string msg_; -}; - bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, Array small_frag_indices, Array large_frag_indices, diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index f5ccc42b4..e505bc6ea 100644 --- 
a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -72,7 +73,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue, LayoutMap &layout_map, const LayoutMap &strict_layout_map, - std::queue &q, std::vector &in_queue) { + std::deque &q, std::vector &in_queue) { auto num_infer = infer_list_.size(); // Range check for cur_infer_id @@ -112,9 +113,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, cur_analyzer, buffer_oob}, level); + // Process the returned updates for (const auto &[buffer, layout] : updates) { - // Basic validity checks ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer."; ICHECK(layout.defined()) << "InferLayout returned an undefined layout."; @@ -140,8 +141,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } } Layout target_layout = - shapes_equal ? src_layout - : src_layout->Reshape(sib->shape, &analyzer_); + shapes_equal + ? 
src_layout + : src_layout->Reshape(sib->shape, &analyzer_, + Integer(src_buffer->dtype.bytes()), + Integer(sib->dtype.bytes())); if (layout_map.count(sib)) { ICHECK(target_layout->IsEqual(layout_map[sib].get())) << "Get different layout for alias buffer " << sib @@ -152,10 +156,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { layout_map.Set(sib, target_layout); if (update_queue && use_list_.count(sib)) { for (int idx : use_list_[sib]) { - if (!in_queue[idx] && idx != cur_infer_id) { - in_queue[idx] = true; - q.push(idx); - } + EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map); } } } @@ -233,22 +234,20 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { << "Index in use_list_ for buffer " << buffer << " out of range: " << idx << " >= " << num_infer << "."; - if (!in_queue[idx] && idx != cur_infer_id) { - in_queue[idx] = true; - q.push(idx); - } + EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map); } } } }; void FinishInferQueue(InferLevel level, LayoutMap &layout_map, - const LayoutMap &strict_layout_map, std::queue &q, + const LayoutMap &strict_layout_map, std::deque &q, std::vector &in_queue) { auto num_infer = infer_list_.size(); + while (!q.empty()) { int cur_infer_id = q.front(); - q.pop(); + q.pop_front(); // Range check again, just to be safe ICHECK_GE(cur_infer_id, 0); ICHECK_LT(cur_infer_id, num_infer); @@ -289,7 +288,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { int num_infer = infer_list_.size(); // Prepare BFS queue for iterative inference - std::queue q; + std::deque q; std::vector in_queue(num_infer, true); for (int i = 0; i < num_infer; i++) { // Check that each infer_list_ entry is valid @@ -301,7 +300,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { if (!thread_var_vec_[i].defined() && skip_thread_partition_) { thread_var_vec_[i] = thread_var_; } - q.push(i); + q.push_back(i); } // step 1: infer strict layout @@ -352,10 +351,12 @@ class BufferUseDefCollector 
: public IRVisitorWithAnalyzer { } } - Layout reshaped = - shapes_equal - ? rep_layout.value() - : rep_layout.value()->Reshape(buf->shape, &analyzer_); + Layout reshaped = shapes_equal + ? rep_layout.value() + : rep_layout.value()->Reshape( + buf->shape, &analyzer_, + Integer(rep.value()->dtype.bytes()), + Integer(buf->dtype.bytes())); layout_map.Set(buf, reshaped); } } @@ -431,6 +432,38 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { return buffer_map; } + // Return true if all buffers that this op (idx) touches already have + // inferred layouts in layout_map. Used to prioritize enqueue order. + bool ShouldPrioritize(int idx, const LayoutMap &layout_map) const { + auto it = op_touched_buffers_.find(idx); + if (it == op_touched_buffers_.end() || it->second.empty()) + return false; + for (const auto &buf : it->second) { + if (!layout_map.count(buf)) + return false; + } + return true; + } + + // Enqueue idx to q with priority if all its buffers already + // have layouts. Also guards against duplicates and self-enqueue. + void EnqueueWithPriority(int idx, std::deque &q, + std::vector &in_queue, int cur_infer_id, + const LayoutMap &layout_map) const { + if (idx == cur_infer_id) + return; + if (idx < 0 || idx >= static_cast(in_queue.size())) + return; + if (in_queue[idx]) + return; + in_queue[idx] = true; + if (ShouldPrioritize(idx, layout_map)) { + q.push_front(idx); + } else { + q.push_back(idx); + } + } + void VisitExpr_(const CallNode *op) final { IRVisitorWithAnalyzer::VisitExpr_(op); // Do not analysis the call node to the global function. 
@@ -536,11 +569,28 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } void addToUseList(const Buffer &buffer) { + // buffer scope must be local.fragment + if (buffer.scope() != "local.fragment") { + return; + } int infer_idx = infer_list_.size(); if (use_list_.find(buffer) == use_list_.end()) { use_list_[buffer] = {}; } use_list_[buffer].push_back(infer_idx); + + // Track which buffers this op (infer_idx) touches for prioritization. + // Avoid duplicates. + auto &vec = op_touched_buffers_[infer_idx]; + bool exists = false; + for (const auto &b : vec) { + if (b.same_as(buffer)) { + exists = true; + break; + } + } + if (!exists) + vec.push_back(buffer); } void VisitStmt_(const ForNode *op) final { @@ -549,6 +599,71 @@ for (const auto &[buffer, _] : infer->GetIndiceMap()) { addToUseList(buffer); } + + PostOrderVisit(op->body, [this](const ObjectRef &node) { + if (auto *buffer_load = node.as()) { + if (buffer_load->buffer.defined() && + buffer_load->buffer->data.defined()) { + if (buffer_data_to_buffers_.count(buffer_load->buffer->data)) { + // Check if this buffer is already in the list + auto buffers = buffer_data_to_buffers_[buffer_load->buffer->data]; + bool found = false; + for (const auto &buf : buffers) { + if (buf.same_as(buffer_load->buffer)) { + found = true; + break; + } + } + if (!found) { + buffers.push_back(buffer_load->buffer); + buffer_data_to_buffers_.Set(buffer_load->buffer->data, buffers); + DLOG(INFO) << "[LayoutInference] BufferLoad: added buffer " + << buffer_load->buffer + << " buffer.get() = " << buffer_load->buffer.get() + << " data = " << buffer_load->buffer->data.get(); + } + } else { + buffer_data_to_buffers_.Set(buffer_load->buffer->data, + {buffer_load->buffer}); + DLOG(INFO) << "[LayoutInference] BufferLoad: new buffer " + << buffer_load->buffer + << " buffer.get() = " << buffer_load->buffer.get() + << " data = " << buffer_load->buffer->data.get(); + } + } + } 
else if (auto *buffer_store = node.as()) { + if (buffer_store->buffer.defined() && + buffer_store->buffer->data.defined()) { + if (buffer_data_to_buffers_.count(buffer_store->buffer->data)) { + auto buffers = + buffer_data_to_buffers_[buffer_store->buffer->data]; + bool found = false; + for (const auto &buf : buffers) { + if (buf.same_as(buffer_store->buffer)) { + found = true; + break; + } + } + if (!found) { + buffers.push_back(buffer_store->buffer); + buffer_data_to_buffers_.Set(buffer_store->buffer->data, + buffers); + DLOG(INFO) << "[LayoutInference] BufferStore: added buffer " + << buffer_store->buffer + << " buffer.get() = " << buffer_store->buffer.get() + << " data = " << buffer_store->buffer->data.get(); + } + } else { + buffer_data_to_buffers_.Set(buffer_store->buffer->data, + {buffer_store->buffer}); + DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " + << buffer_store->buffer + << " buffer.get() = " << buffer_store->buffer.get() + << " data = " << buffer_store->buffer->data.get(); + } + } + } + }); infer_list_stmt_.push_back(tvm::ffi::GetRef(op)); infer_list_.push_back(std::move(infer)); thread_var_vec_.push_back(thread_var_); @@ -615,7 +730,11 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { if (shapes_equal) { annotated_layout_map_.Set(buffer, layout); } else { - auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_); + // Use the first buffer sharing this var as the base for dtype ratio + int base_bytes = buffers[0]->dtype.bytes(); + auto reshaped_layout = + layout->Reshape(buffer->shape, &analyzer_, Integer(base_bytes), + Integer(buffer->dtype.bytes())); annotated_layout_map_.Set(buffer, reshaped_layout); } } @@ -699,6 +818,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { std::vector infer_list_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> use_list_; + // Per-op list of buffers it touches (fragment scope), used for prioritization + std::unordered_map> op_touched_buffers_; // This is a 
workaround for cpu backend, // we need to define a thread_var for the serial loop. IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), @@ -765,6 +886,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } } } + std::unordered_map> components; for (int i = 0; i < infer_list_.size(); i++) { int root = uf.Find(i); @@ -781,7 +903,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // For each component, try each op as root, and determine the least // replicated one - std::queue q; + std::deque q; std::vector in_queue(infer_list_.size(), false); for (auto &&[root, members] : components) { @@ -795,7 +917,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // Try each member as the root of inference for this component for (int attempt_infer_root : members) { DLOG(INFO) << "----------------------- try root " << attempt_infer_root - << '\n'; + << " members " << members.size() << '\n'; // Backup the current infer_list_ state auto back_infer_list = BackupInferList(); // Copy the current layout_map for temporary use @@ -826,6 +948,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { do_update = false; DLOG(INFO) << "attempt failed due to NormalizeIterException " << e.what() << '\n'; + } catch (const LoopLayoutInjectiveException &e) { + do_update = false; + DLOG(INFO) << "attempt failed due to LoopLayoutInjectiveException " + << e.what() << '\n'; } if (do_update) { diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py index 9073aebcd..df88573f8 100644 --- a/testing/python/analysis/test_tilelang_fragment_loop_checker.py +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -1,4 +1,5 @@ import tilelang +import tilelang.testing import tilelang.language as T import pytest