Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,12 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n);
}

Fragment Fragment::FullyReplicated(Array<PrimExpr> shape,
                                   PrimExpr thread_extent) {
  // No forward index and a replication-placeholder thread mapping: every
  // thread holds an identical, complete copy of the buffer.
  Array<PrimExpr> empty_forward_index;
  return Fragment(shape, empty_forward_index, ReplicationPlaceholder(),
                  thread_extent, std::nullopt);
}

// A fully replicated fragment is one whose forward_thread mapping reduces to
// the replication variable alone, i.e. forward_thread = lambda i, rep: rep.
bool FragmentNode::IsCompletedReplicated() const {
arith::Analyzer analyzer;
Expand Down
14 changes: 14 additions & 0 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,20 @@ class Fragment : public Layout {
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);

/*!
* \brief Create a fully replicated fragment layout.
*
* A fully replicated fragment means all threads hold identical copies of the
* entire buffer. This is useful for index buffers or masks that need to be
* accessed uniformly across all threads.
*
* \param shape The shape of the buffer.
* \param thread_extent The number of threads.
* \return A Fragment where each thread has a complete copy of all elements.
*/
TVM_DLL static Fragment FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent);

TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};

Expand Down
61 changes: 56 additions & 5 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,34 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;

Map<Buffer, Layout> result_map;

// Collect fragment buffers from indices and mark them as fully replicated
// For Bulk Load/Store, fragment buffers used as indices should be
// replicated across all threads
PrimExpr thread_extent = T.thread_bounds->extent;
for (const auto &range : src_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
for (const auto &range : dst_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}

// check shared layout is non-swizzle
// skip layout inference if shared layout is already annotated
if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
// create a new layout map for tma linear layout
Layout linear_layout = ComputeLinearLayout(shared_tensor);
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
result_map.Set(shared_tensor, linear_layout);
}
return {};
return result_map;
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
Expand All @@ -571,7 +591,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
arith::Analyzer analyzer;
par_op_ = ParallelOp((MakeSIMTLoop(&analyzer)));
}
return par_op_->InferLayout(T, level);
auto layout_map = par_op_->InferLayout(T, level);
return layout_map;
}
/**
* @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
Expand Down Expand Up @@ -940,8 +961,13 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
level);
}
auto loop_layout = par_op->GetLoopLayout();
Expand Down Expand Up @@ -2034,6 +2060,31 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
return args;
}

void CopyNode::CollectFragmentLayouts(const PrimExpr &expr,
                                      const Map<Var, PrimExpr> &let_var_to_expr,
                                      const LayoutMap &existing_layouts,
                                      PrimExpr thread_extent,
                                      Range thread_bounds,
                                      Map<Buffer, Layout> &result_map) const {
  // Walk the expression tree; every load from a "local.fragment" buffer that
  // has no layout assigned yet receives a fully replicated layout. Variables
  // introduced by LetStmt are chased into their defining expressions so that
  // indirect accesses (e.g. a = frag[i]; T.copy(A[a, 0], ...)) are found too.
  PostOrderVisit(expr, [&](const ObjectRef &node) {
    if (auto load = node.as<BufferLoadNode>()) {
      bool is_fragment = load->buffer.scope() == "local.fragment";
      bool already_known = existing_layouts.count(load->buffer) ||
                           result_map.count(load->buffer);
      if (is_fragment && !already_known) {
        Fragment replicated =
            Fragment::FullyReplicated(load->buffer->shape, thread_extent);
        result_map.Set(load->buffer, replicated->BindThreadRange(thread_bounds));
      }
    } else if (auto var_node = node.as<VarNode>()) {
      Var var = tvm::ffi::GetRef<Var>(var_node);
      if (let_var_to_expr.count(var)) {
        // Follow the let-binding chain into the bound expression.
        CollectFragmentLayouts(let_var_to_expr[var], let_var_to_expr,
                               existing_layouts, thread_extent, thread_bounds,
                               result_map);
      }
    }
  });
}

// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
Expand Down
22 changes: 22 additions & 0 deletions src/op/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ class CopyNode : public TileOperatorNode {
* @return Reference to the singleton TVM Op representing this operator.
*/
TileOperator Clone() const;

private:
/*!
* \brief Collect fragment buffers from expression and create fully replicated
* layouts.
*
* Recursively searches the expression for BufferLoad nodes with
* "local.fragment" scope, following let bindings. For each found fragment
* buffer, creates a fully replicated layout and adds it to result_map.
*
* \param expr Expression to search.
* \param let_var_to_expr Map from let variables to their bound expressions.
* \param existing_layouts Existing layout map to check for already-inferred
*        layouts.
* \param thread_extent Number of threads for replication.
* \param thread_bounds Thread bounds for binding the layout.
* \param result_map Output map to store collected fragment layouts.
*/
void CollectFragmentLayouts(const PrimExpr &expr,
const Map<Var, PrimExpr> &let_var_to_expr,
const LayoutMap &existing_layouts,
PrimExpr thread_extent, Range thread_bounds,
Map<Buffer, Layout> &result_map) const;
};

class Copy : public TileOperator {
Expand Down
18 changes: 14 additions & 4 deletions src/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,13 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
Expand All @@ -176,8 +181,13 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
false, T.buffer_remap},
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
Expand Down
6 changes: 6 additions & 0 deletions src/op/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
};

struct LayoutInferArgs {
Expand All @@ -48,6 +51,9 @@ struct LayoutInferArgs {
arith::Analyzer *analyzer;
bool buffer_oob = false;
Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
};

class TileOperator;
Expand Down
33 changes: 33 additions & 0 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,34 @@ TileOperator ParallelOpNode::Clone() const {
return ParallelOp(op);
}

void ParallelOpNode::ExpandLetBindings(
    const Map<Var, PrimExpr> &let_var_to_expr) {
  if (let_var_to_expr.empty())
    return;

  // Depth-first walk over an expression that also descends into the bound
  // expression of every let variable it meets. The first "local.fragment"
  // BufferLoad seen for a buffer contributes its indices to indice_map_;
  // later loads of the same buffer are ignored, so traversal order matters
  // and is kept identical to a plain recursive expansion.
  // NOTE(review): assumes the let-binding graph is acyclic (SSA-form TIR);
  // a cyclic chain would recurse without bound — confirm this invariant.
  std::function<void(const PrimExpr &)> visit_expr;
  visit_expr = [&](const PrimExpr &e) {
    PostOrderVisit(e, [&](const ObjectRef &node) {
      if (auto load = node.as<BufferLoadNode>()) {
        bool is_new_fragment = load->buffer.scope() == "local.fragment" &&
                               !indice_map_.count(load->buffer);
        if (is_new_fragment) {
          indice_map_.Set(load->buffer, load->indices);
        }
        return;
      }
      if (auto var_node = node.as<VarNode>()) {
        Var v = tvm::ffi::GetRef<Var>(var_node);
        if (let_var_to_expr.count(v)) {
          visit_expr(let_var_to_expr[v]);
        }
      }
    });
  };

  // Seed the walk with every top-level let binding.
  for (const auto &[var, bound_expr] : let_var_to_expr) {
    visit_expr(bound_expr);
  }
}

Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
return root_;
Expand Down Expand Up @@ -215,6 +243,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
if (loop_layout_.defined())
return {};

// Expand let bindings to find fragment buffer accesses
if (!T.let_var_to_expr.empty()) {
const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr);
}

if (level == InferLevel::kStrict) {
LayoutMap results;
// Deduce buffers that should be complicated replicated.
Expand Down
4 changes: 4 additions & 0 deletions src/op/parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ class ParallelOpNode : public TileOperatorNode {
void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
// Expand let bindings to find fragment buffer accesses and add them to
// indice_map_. This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0],
// ...)
void ExpandLetBindings(const Map<Var, PrimExpr> &let_var_to_expr);

// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
Expand Down
42 changes: 38 additions & 4 deletions src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
"required for layout inference.";

// Run InferLayout
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
cur_analyzer, buffer_oob},
level);
auto updates = next->InferLayout(LayoutInferArgs{target_,
thread_bounds,
layout_map,
cur_analyzer,
buffer_oob,
{},
let_var_to_expr_},
level);

// Process the returned updates
for (const auto &[buffer, layout] : updates) {
Expand Down Expand Up @@ -479,6 +483,10 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else if (auto buffer = getBufferFromRegion(arg)) {
addToUseList(buffer.value());
}
// Check if the argument uses any LetStmt variables that reference
// fragment buffers. If so, add those buffers to the use list.
// This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0], ...)
CollectFragmentBuffersFromExpr(arg);
}
// Compute thread_var_ and thread_bounds_
thread_var_vec_.push_back(thread_var_);
Expand Down Expand Up @@ -754,6 +762,30 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
IRVisitorWithAnalyzer::VisitStmt_(op);
}

void VisitStmt_(const LetStmtNode *op) final {
  // Record the Let variable -> bound expression mapping so that later layout
  // inference can resolve fragment-buffer accesses made through let variables
  // (e.g. a = block_mask_f[i]; T.copy(A[a, 0], ...)).
  // NOTE(review): bindings are never removed when the let scope ends, so this
  // relies on TIR vars being unique (SSA-like); shadowed or reused Var objects
  // would leave stale entries in let_var_to_expr_ — confirm that invariant.
  let_var_to_expr_.Set(op->var, op->value);
  IRVisitorWithAnalyzer::VisitStmt_(op);
}

// Helper: walk an expression and add every "local.fragment" buffer it reads
// to the use list, chasing let-bound variables into their defining
// expressions.
void CollectFragmentBuffersFromExpr(const PrimExpr &expr) {
  PostOrderVisit(expr, [this](const ObjectRef &node) {
    if (auto load = node.as<BufferLoadNode>()) {
      bool is_fragment =
          load->buffer.defined() && load->buffer.scope() == "local.fragment";
      if (is_fragment) {
        addToUseList(load->buffer);
      }
      return;
    }
    if (auto var_node = node.as<VarNode>()) {
      Var v = tvm::ffi::GetRef<Var>(var_node);
      if (let_var_to_expr_.count(v)) {
        // Follow the binding chain, e.g. a = frag[i]; use(a).
        CollectFragmentBuffersFromExpr(let_var_to_expr_[v]);
      }
    }
  });
}

void VisitExpr_(const BufferLoadNode *op) final {
// Collect buffer from BufferLoad
if (op->buffer.defined() && op->buffer->data.defined()) {
Expand Down Expand Up @@ -815,6 +847,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
}

Map<Var, Array<Buffer>> buffer_data_to_buffers_;
// Map from LetStmt variable to its bound expression
Map<Var, PrimExpr> let_var_to_expr_;
std::vector<ObjectRef> infer_list_stmt_;
std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
Expand Down
3 changes: 1 addition & 2 deletions src/transform/layout_reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
const auto &buffer = opt_buffer.value();
Fragment f;
if (info->rep == ReducerRepType::ALL) {
f = Fragment(buffer->shape, {}, ReplicationPlaceholder(),
thread_extent, std::nullopt);
f = Fragment::FullyReplicated(buffer->shape, thread_extent);
} else if (info->rep == ReducerRepType::NONE) {
PrimExpr flatten_idx = InputPlaceholder(0);
for (int i = 1; i < buffer->shape.size(); ++i)
Expand Down
14 changes: 10 additions & 4 deletions src/transform/lower_tile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,10 +638,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
thread_bounds = Range::FromMinExtent(0, 1);
}

auto lowered =
tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
callback, layout_map_, buffer_remap_},
analyzer_);
// Convert let_bindings_ to Map<Var, PrimExpr> for LowerArgs
Map<Var, PrimExpr> let_var_to_expr;
for (const auto &[var, expr] : let_bindings_) {
let_var_to_expr.Set(var, expr);
}

auto lowered = tile_op->Lower(
LowerArgs{target_, thread_bounds, thread_var_->var, callback,
layout_map_, buffer_remap_, let_var_to_expr},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}

Expand Down
Loading
Loading