diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index ab6389eaa..6f9fa5d2f 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -20,7 +20,7 @@ block_DK = 64 block_DV = 32 threads = 128 -num_stages = 1 +num_stages = 0 def test_example_wy_fast_compilation(): diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 79e78add9..9df8a5a1a 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -26,6 +26,77 @@ struct LetWrapper { PrimExpr value; }; +/*! + * \brief Collector to find all buffers used in a statement. + * + * This is used to collect buffers that are actually used in the pipeline loop + * body, so that we can properly multi-version them for software pipelining. + */ +class BufferUsageCollector : public StmtExprVisitor { +public: + BufferUsageCollector( + const Map &buffer_data_to_buffer, + const std::unordered_set + &allocated_buffers) + : buffer_data_to_buffer_(buffer_data_to_buffer), + allocated_buffers_(allocated_buffers) {} + + Array Collect(const Stmt &stmt) { + this->VisitStmt(stmt); + Array result; + for (const auto &buffer : used_buffers_) { + result.push_back(buffer); + } + return result; + } + +private: + void VisitStmt_(const BufferStoreNode *op) final { + AddBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const BufferLoadNode *op) final { + AddBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const CallNode *op) final { + // Handle tvm_access_ptr which also accesses buffers + if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->args.size() > 1) { + if (const auto *var = op->args[1].as()) { + auto it = buffer_data_to_buffer_.find(GetRef(var)); + if (it != buffer_data_to_buffer_.end()) { + AddBuffer((*it).second); + } + } + } + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BlockNode *op) final { + // Also collect buffers allocated in nested blocks within the pipeline body + for (const auto &buffer : op->alloc_buffers) { + used_buffers_.insert(buffer); + } + StmtExprVisitor::VisitStmt_(op); + } + + void AddBuffer(const Buffer &buffer) { + // Only add buffers that are allocated (not function input/output buffers) + if (allocated_buffers_.count(buffer)) { + used_buffers_.insert(buffer); + } + } + + const Map &buffer_data_to_buffer_; + const std::unordered_set + &allocated_buffers_; + std::unordered_set used_buffers_; +}; + /*! * \brief Create a block and infer the access region with the given body. * @@ -219,13 +290,28 @@ class PipelineBodyRewriter : public StmtExprMutator { */ class PipelineRewriter : public StmtExprMutator { public: + /*! + * \brief Constructor of PipelineRewriter. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \param pipeline_allocs All buffers that need multi-versioning in the + * pipeline. This includes buffers allocated in the pipeline block and + * buffers allocated in outer blocks that are used in the pipeline. + * \param local_allocs Buffers that are allocated in the pipeline block + * itself. These buffers will be re-allocated in the rewritten block. + * Buffers in pipeline_allocs but not in local_allocs are allocated in outer + * blocks and should not be re-allocated. + * \param pipeline_loop The original loop to be software pipelined. + * \param pipeline_info The pipeline annotation information. + * \param loop_var_let_wrappers Let wrappers that depend on the loop var. + */ PipelineRewriter(Map buffer_data_to_buffer, const Array &pipeline_allocs, - const For &pipeline_loop, const PipelineInfo &pipeline_info, + const Array &local_allocs, const For &pipeline_loop, + const PipelineInfo &pipeline_info, const std::vector &loop_var_let_wrappers) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), - pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), - pipeline_info_(pipeline_info), + pipeline_allocs_(pipeline_allocs), local_allocs_(local_allocs), + pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), loop_var_let_wrappers_(loop_var_let_wrappers) {} Stmt BuildPipeline() { @@ -234,7 +320,12 @@ class PipelineRewriter : public StmtExprMutator { std::unordered_map infos = GetBufferAccessInfo(); for (const Buffer &buffer : pipeline_allocs_) { - int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); + auto it = infos.find(buffer); + if (it == infos.end()) { + // Buffer is not accessed in the pipeline blocks, skip it + continue; + } + int num_versions = ComputeBufferVersions(buffer, it->second); if (num_versions > 1) { buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); } @@ -302,8 +393,12 @@ class PipelineRewriter : public StmtExprMutator { // Step 3: Make a new block that contains new buffer allocations after // pipeline rewriting. + // Only include buffers that are locally allocated in the pipeline block. + // Buffers from outer blocks will be handled separately. Array alloc_buffers; - for (const auto &alloc : pipeline_allocs_) { + std::unordered_set local_allocs_set( + local_allocs_.begin(), local_allocs_.end()); + for (const auto &alloc : local_allocs_) { alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); buffer_data_to_buffer_.erase(alloc->data); } @@ -312,6 +407,12 @@ class PipelineRewriter : public StmtExprMutator { return BlockRealize({}, Bool(true), block); } + /*! + * \brief Get the buffer remapping created during pipeline rewriting. + * This is used to update alloc_buffers in outer blocks. + */ + const Map &GetBufferRemap() const { return buffer_remap_; } + private: /*! * \brief Analyze accesses to the buffers in the software pipeline. @@ -804,6 +905,7 @@ class PipelineRewriter : public StmtExprMutator { arith::Analyzer analyzer_; Map buffer_data_to_buffer_; Array pipeline_allocs_; + Array local_allocs_; For pipeline_loop_; PipelineInfo pipeline_info_; int max_stage_ = -1; @@ -923,14 +1025,17 @@ class PipelineInjector : private StmtExprMutator { Stmt pipeline_body_root{nullptr}; bool pipeline_body_from_block = false; Array pipeline_allocs; + Array + block_local_allocs; // buffers allocated in the pipeline block itself if (const auto *realize = for_node->body.as()) { const auto &block = realize->block; for (const auto &buffer : block->alloc_buffers) { ICHECK(buffer->IsInstance()); buffer_data_to_buffer_.Set(buffer->data, buffer); + allocated_buffers_.insert(buffer); } pipeline_body_root = block->body; - pipeline_allocs = block->alloc_buffers; + block_local_allocs = block->alloc_buffers; pipeline_body_from_block = true; } else { pipeline_body_root = for_node->body; @@ -1021,13 +1126,49 @@ class PipelineInjector : private StmtExprMutator { ICHECK(nested_pipeline_block->match_buffers .empty()); // match_buffer should have been lowered for (const auto &buffer : nested_pipeline_block->alloc_buffers) { - pipeline_allocs.push_back(buffer); buffer_data_to_buffer_.Set(buffer->data, buffer); + allocated_buffers_.insert(buffer); } } f_add_child(child); } + // Collect all buffers that are actually used in the pipeline loop body. + // This includes buffers allocated in outer blocks (like logits_smem) that + // are used inside the pipeline loop. + BufferUsageCollector collector(buffer_data_to_buffer_, allocated_buffers_); + pipeline_allocs = collector.Collect(SeqStmt(pipeline_body_seq->seq)); + + // Build a set of local allocs (buffers allocated in the pipeline block + // itself) for efficient lookup + std::unordered_set local_allocs_set; + for (const auto &buffer : block_local_allocs) { + local_allocs_set.insert(buffer); + } + for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { + const Stmt &child = pipeline_body_seq->seq[i]; + const auto *nested_block_realize = child.as(); + if (nested_block_realize && is_one(nested_block_realize->predicate) && + nested_block_realize->block->body->IsInstance()) { + for (const auto &buffer : nested_block_realize->block->alloc_buffers) { + local_allocs_set.insert(buffer); + } + } + } + + // Check if any external buffer (from outer blocks) is already used in + // another pipeline. This would cause conflicts in multi-versioning. + for (const auto &buffer : pipeline_allocs) { + // Only check external buffers (not locally allocated in this pipeline) + if (local_allocs_set.count(buffer) == 0) { + CHECK(buffers_used_in_pipeline_.count(buffer) == 0) + << "Buffer '" << buffer->name + << "' is used in multiple software pipeline loops. " + << "This is not supported because multi-versioning would conflict."; + buffers_used_in_pipeline_.insert(buffer); + } + } + auto pipeline_stages = Downcast>( op->annotations.at(tir::attr::software_pipeline_stage)); auto pipeline_orders = Downcast>( @@ -1067,10 +1208,32 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. - Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, - tvm::ffi::GetRef(op), pipeline_info, - loop_var_let_wrappers) - .BuildPipeline(); + // local_allocs contains buffers allocated in the pipeline block itself. + // pipeline_allocs contains all buffers that need multi-versioning, + // including buffers from outer blocks. + Array local_allocs = block_local_allocs; + // Add nested block allocs to local_allocs + for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { + const Stmt &child = pipeline_body_seq->seq[i]; + const auto *nested_block_realize = child.as(); + if (nested_block_realize && is_one(nested_block_realize->predicate) && + nested_block_realize->block->body->IsInstance()) { + const Block &nested_pipeline_block = nested_block_realize->block; + for (const auto &buffer : nested_pipeline_block->alloc_buffers) { + local_allocs.push_back(buffer); + } + } + } + + PipelineRewriter rewriter(buffer_data_to_buffer_, pipeline_allocs, + local_allocs, tvm::ffi::GetRef(op), + pipeline_info, loop_var_let_wrappers); + Stmt pipeline = rewriter.BuildPipeline(); + + // Store the buffer remapping for updating outer block alloc_buffers + for (const auto &kv : rewriter.GetBufferRemap()) { + pending_buffer_remap_.Set(kv.first, kv.second); + } auto apply_wrappers = [&](Stmt stmt) { for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) { stmt = (*it)(stmt); @@ -1097,6 +1260,7 @@ class PipelineInjector : private StmtExprMutator { const auto &block = realize->block; for (const auto &buffer : block->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); + allocated_buffers_.erase(buffer); } } return pipeline; @@ -1105,18 +1269,35 @@ class PipelineInjector : private StmtExprMutator { Stmt VisitStmt_(const BlockNode *op) final { for (const auto &buffer : op->alloc_buffers) { buffer_data_to_buffer_.Set(buffer->data, buffer); + allocated_buffers_.insert(buffer); } Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + // Update alloc_buffers with any pending buffer remaps from pipeline + // rewriting. This handles buffers allocated in this block but + // multi-versioned during pipeline rewriting of inner loops. + Array new_alloc_buffers; + for (const auto &buffer : block->alloc_buffers) { + if (auto remapped = pending_buffer_remap_.Get(buffer)) { + new_alloc_buffers.push_back(remapped.value()); + // Remove from pending after applying + pending_buffer_remap_.erase(buffer); + } else { + new_alloc_buffers.push_back(buffer); + } + } + Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); BlockNode *n = block.CopyOnWrite(); n->reads = access[0]; n->writes = access[1]; + n->alloc_buffers = std::move(new_alloc_buffers); for (const auto &buffer : op->alloc_buffers) { buffer_data_to_buffer_.erase(buffer->data); + allocated_buffers_.erase(buffer); } return block; } @@ -1141,6 +1322,12 @@ class PipelineInjector : private StmtExprMutator { } Map buffer_data_to_buffer_; + std::unordered_set allocated_buffers_; + Map pending_buffer_remap_; + // Buffers from outer blocks that have been used in a pipeline loop. + // Used to detect if the same buffer is used in multiple pipeline loops. + std::unordered_set + buffers_used_in_pipeline_; Optional global_symbol_; }; } // namespace software_pipeline