Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/gdn/test_example_gdn_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
block_DK = 64
block_DV = 32
threads = 128
num_stages = 1
num_stages = 0


def test_example_wy_fast_compilation():
Expand Down
209 changes: 198 additions & 11 deletions src/transform/inject_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,77 @@ struct LetWrapper {
PrimExpr value;
};

/*!
* \brief Collector to find all buffers used in a statement.
*
* This is used to collect buffers that are actually used in the pipeline loop
* body, so that we can properly multi-version them for software pipelining.
*/
class BufferUsageCollector : public StmtExprVisitor {
public:
BufferUsageCollector(
const Map<Var, Buffer> &buffer_data_to_buffer,
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
&allocated_buffers)
: buffer_data_to_buffer_(buffer_data_to_buffer),
allocated_buffers_(allocated_buffers) {}

Array<Buffer> Collect(const Stmt &stmt) {
this->VisitStmt(stmt);
Array<Buffer> result;
for (const auto &buffer : used_buffers_) {
result.push_back(buffer);
}
return result;
}

private:
void VisitStmt_(const BufferStoreNode *op) final {
AddBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const BufferLoadNode *op) final {
AddBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const CallNode *op) final {
// Handle tvm_access_ptr which also accesses buffers
if (op->op.same_as(builtin::tvm_access_ptr())) {
if (op->args.size() > 1) {
if (const auto *var = op->args[1].as<VarNode>()) {
auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
if (it != buffer_data_to_buffer_.end()) {
AddBuffer((*it).second);
}
}
}
}
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BlockNode *op) final {
// Also collect buffers allocated in nested blocks within the pipeline body
for (const auto &buffer : op->alloc_buffers) {
used_buffers_.insert(buffer);
}
StmtExprVisitor::VisitStmt_(op);
}

void AddBuffer(const Buffer &buffer) {
// Only add buffers that are allocated (not function input/output buffers)
if (allocated_buffers_.count(buffer)) {
used_buffers_.insert(buffer);
}
}

const Map<Var, Buffer> &buffer_data_to_buffer_;
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
&allocated_buffers_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> used_buffers_;
};

/*!
* \brief Create a block and infer the access region with the given body.
*
Expand Down Expand Up @@ -219,13 +290,28 @@ class PipelineBodyRewriter : public StmtExprMutator {
*/
class PipelineRewriter : public StmtExprMutator {
public:
/*!
* \brief Constructor of PipelineRewriter.
* \param buffer_data_to_buffer The map from buffer data to buffer.
* \param pipeline_allocs All buffers that need multi-versioning in the
* pipeline. This includes buffers allocated in the pipeline block and
* buffers allocated in outer blocks that are used in the pipeline.
* \param local_allocs Buffers that are allocated in the pipeline block
* itself. These buffers will be re-allocated in the rewritten block.
* Buffers in pipeline_allocs but not in local_allocs are allocated in outer
* blocks and should not be re-allocated.
* \param pipeline_loop The original loop to be software pipelined.
* \param pipeline_info The pipeline annotation information.
* \param loop_var_let_wrappers Let wrappers that depend on the loop var.
*/
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info,
const Array<Buffer> &local_allocs, const For &pipeline_loop,
const PipelineInfo &pipeline_info,
const std::vector<LetWrapper> &loop_var_let_wrappers)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info),
pipeline_allocs_(pipeline_allocs), local_allocs_(local_allocs),
pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info),
loop_var_let_wrappers_(loop_var_let_wrappers) {}

Stmt BuildPipeline() {
Expand All @@ -234,7 +320,12 @@ class PipelineRewriter : public StmtExprMutator {
std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
infos = GetBufferAccessInfo();
for (const Buffer &buffer : pipeline_allocs_) {
int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
auto it = infos.find(buffer);
if (it == infos.end()) {
// Buffer is not accessed in the pipeline blocks, skip it
continue;
}
int num_versions = ComputeBufferVersions(buffer, it->second);
if (num_versions > 1) {
buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
}
Expand Down Expand Up @@ -302,8 +393,12 @@ class PipelineRewriter : public StmtExprMutator {

// Step 3: Make a new block that contains new buffer allocations after
// pipeline rewriting.
// Only include buffers that are locally allocated in the pipeline block.
// Buffers from outer blocks will be handled separately.
Array<Buffer> alloc_buffers;
for (const auto &alloc : pipeline_allocs_) {
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> local_allocs_set(
local_allocs_.begin(), local_allocs_.end());
for (const auto &alloc : local_allocs_) {
alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
buffer_data_to_buffer_.erase(alloc->data);
}
Expand All @@ -312,6 +407,12 @@ class PipelineRewriter : public StmtExprMutator {
return BlockRealize({}, Bool(true), block);
}

/*!
 * \brief Read-only view of the original-buffer -> rewritten-buffer table.
 *
 * Callers use this to patch alloc_buffers of enclosing blocks.
 * \return A const reference to this rewriter's buffer_remap_ member.
 */
const Map<Buffer, Buffer> &GetBufferRemap() const {
  return buffer_remap_;
}

private:
/*!
* \brief Analyze accesses to the buffers in the software pipeline.
Expand Down Expand Up @@ -804,6 +905,7 @@ class PipelineRewriter : public StmtExprMutator {
arith::Analyzer analyzer_;
Map<Var, Buffer> buffer_data_to_buffer_;
Array<Buffer> pipeline_allocs_;
Array<Buffer> local_allocs_;
For pipeline_loop_;
PipelineInfo pipeline_info_;
int max_stage_ = -1;
Expand Down Expand Up @@ -923,14 +1025,17 @@ class PipelineInjector : private StmtExprMutator {
Stmt pipeline_body_root{nullptr};
bool pipeline_body_from_block = false;
Array<Buffer> pipeline_allocs;
Array<Buffer>
block_local_allocs; // buffers allocated in the pipeline block itself
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
allocated_buffers_.insert(buffer);
}
pipeline_body_root = block->body;
pipeline_allocs = block->alloc_buffers;
block_local_allocs = block->alloc_buffers;
pipeline_body_from_block = true;
} else {
pipeline_body_root = for_node->body;
Expand Down Expand Up @@ -1021,13 +1126,49 @@ class PipelineInjector : private StmtExprMutator {
ICHECK(nested_pipeline_block->match_buffers
.empty()); // match_buffer should have been lowered
for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
pipeline_allocs.push_back(buffer);
buffer_data_to_buffer_.Set(buffer->data, buffer);
allocated_buffers_.insert(buffer);
}
}
f_add_child(child);
}

// Collect all buffers that are actually used in the pipeline loop body.
// This includes buffers allocated in outer blocks (like logits_smem) that
// are used inside the pipeline loop.
BufferUsageCollector collector(buffer_data_to_buffer_, allocated_buffers_);
pipeline_allocs = collector.Collect(SeqStmt(pipeline_body_seq->seq));

// Build a set of local allocs (buffers allocated in the pipeline block
// itself) for efficient lookup
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> local_allocs_set;
for (const auto &buffer : block_local_allocs) {
local_allocs_set.insert(buffer);
}
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
const Stmt &child = pipeline_body_seq->seq[i];
const auto *nested_block_realize = child.as<BlockRealizeNode>();
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
for (const auto &buffer : nested_block_realize->block->alloc_buffers) {
local_allocs_set.insert(buffer);
}
}
}

// Check if any external buffer (from outer blocks) is already used in
// another pipeline. This would cause conflicts in multi-versioning.
for (const auto &buffer : pipeline_allocs) {
// Only check external buffers (not locally allocated in this pipeline)
if (local_allocs_set.count(buffer) == 0) {
CHECK(buffers_used_in_pipeline_.count(buffer) == 0)
<< "Buffer '" << buffer->name
<< "' is used in multiple software pipeline loops. "
<< "This is not supported because multi-versioning would conflict.";
buffers_used_in_pipeline_.insert(buffer);
}
}

auto pipeline_stages = Downcast<Array<Integer>>(
op->annotations.at(tir::attr::software_pipeline_stage));
auto pipeline_orders = Downcast<Array<Integer>>(
Expand Down Expand Up @@ -1067,10 +1208,32 @@ class PipelineInjector : private StmtExprMutator {
ValidatePipelineBody(pipeline_info, original_order);

// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
tvm::ffi::GetRef<For>(op), pipeline_info,
loop_var_let_wrappers)
.BuildPipeline();
// local_allocs contains buffers allocated in the pipeline block itself.
// pipeline_allocs contains all buffers that need multi-versioning,
// including buffers from outer blocks.
Array<Buffer> local_allocs = block_local_allocs;
// Add nested block allocs to local_allocs
for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
const Stmt &child = pipeline_body_seq->seq[i];
const auto *nested_block_realize = child.as<BlockRealizeNode>();
if (nested_block_realize && is_one(nested_block_realize->predicate) &&
nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
const Block &nested_pipeline_block = nested_block_realize->block;
for (const auto &buffer : nested_pipeline_block->alloc_buffers) {
local_allocs.push_back(buffer);
}
}
}

PipelineRewriter rewriter(buffer_data_to_buffer_, pipeline_allocs,
local_allocs, tvm::ffi::GetRef<For>(op),
pipeline_info, loop_var_let_wrappers);
Stmt pipeline = rewriter.BuildPipeline();

// Store the buffer remapping for updating outer block alloc_buffers
for (const auto &kv : rewriter.GetBufferRemap()) {
pending_buffer_remap_.Set(kv.first, kv.second);
}
auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
stmt = (*it)(stmt);
Expand All @@ -1097,6 +1260,7 @@ class PipelineInjector : private StmtExprMutator {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
allocated_buffers_.erase(buffer);
}
}
return pipeline;
Expand All @@ -1105,18 +1269,35 @@ class PipelineInjector : private StmtExprMutator {
/*!
 * \brief Mutate a block while tracking the buffers it allocates.
 *
 * The block's alloc_buffers are registered before visiting its children and
 * unregistered afterwards. After the children are rewritten, any allocation
 * that has an entry in pending_buffer_remap_ is swapped for its remapped
 * buffer (and the entry is consumed), and the block's read/write regions are
 * recomputed.
 */
Stmt VisitStmt_(const BlockNode *op) final {
  // Make this block's allocations visible to everything nested inside it.
  for (const auto &alloc : op->alloc_buffers) {
    buffer_data_to_buffer_.Set(alloc->data, alloc);
    allocated_buffers_.insert(alloc);
  }

  Block mutated = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

  // Substitute allocations that an inner pipeline rewrite replaced.
  // Each pending entry is dropped once it has been applied.
  Array<Buffer> rewritten_allocs;
  for (const auto &alloc : mutated->alloc_buffers) {
    auto replacement = pending_buffer_remap_.Get(alloc);
    if (!replacement) {
      rewritten_allocs.push_back(alloc);
      continue;
    }
    rewritten_allocs.push_back(replacement.value());
    pending_buffer_remap_.erase(alloc);
  }

  // Regions are computed from the mutated block before it is edited in place.
  Array<Array<BufferRegion>> regions =
      GetBlockReadWriteRegion(mutated, buffer_data_to_buffer_);
  BlockNode *node = mutated.CopyOnWrite();
  node->reads = regions[0];
  node->writes = regions[1];
  node->alloc_buffers = std::move(rewritten_allocs);

  // Leaving the block: its allocations go out of scope again.
  for (const auto &alloc : op->alloc_buffers) {
    buffer_data_to_buffer_.erase(alloc->data);
    allocated_buffers_.erase(alloc);
  }
  return mutated;
}
Expand All @@ -1141,6 +1322,12 @@ class PipelineInjector : private StmtExprMutator {
}

Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers_;
Map<Buffer, Buffer> pending_buffer_remap_;
// Buffers from outer blocks that have been used in a pipeline loop.
// Used to detect if the same buffer is used in multiple pipeline loops.
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
buffers_used_in_pipeline_;
Optional<String> global_symbol_;
};
} // namespace software_pipeline
Expand Down
Loading