2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from b487ec to 23bce0
6 changes: 6 additions & 0 deletions src/layout/layout.h
@@ -258,6 +258,12 @@ Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
namespace attr {
// BlockAttr, Containing the layout for all the buffers in the block
constexpr const char *kLayoutMap = "layout_map";
// ForAttr, Containing the parallel loop layout for a parallel for loop
constexpr const char *kParallelLoopLayout = "parallel_loop_layout";
// ForAttr, Containing the predicate for a parallel for loop
constexpr const char *kParallelLoopPredicate = "parallel_loop_predicate";
// ForAttr, Width (in elements) for coalesced memory access
constexpr const char *kCoalescedWidth = "coalesced_width";
} // namespace attr

} // namespace tl
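These new `ForAttr` keys are read back by the lowering code (e.g. `kCoalescedWidth` in `atomic_add.cc` and `copy.cc` below). A minimal sketch of attaching one of them to a parallel loop, assuming the loop variable, extent, and body are built elsewhere (helper name and include path are illustrative only):

```cpp
#include <tvm/tir/stmt.h>

#include "layout/layout.h"  // tl::attr::kCoalescedWidth (path is illustrative)

using namespace tvm;
using namespace tvm::tir;

// Hypothetical helper: build a parallel loop carrying the kCoalescedWidth annotation.
For MakeCoalescedParallelLoop(Var loop_var, PrimExpr extent, Stmt body) {
  Map<String, ObjectRef> loop_annotations;
  // Request 4-element coalesced accesses on this loop.
  loop_annotations.Set(tl::attr::kCoalescedWidth, IntImm(DataType::Int(32), 4));
  return For(loop_var, 0, extent, ForKind::kParallel, body,
             /*thread_binding=*/std::nullopt, loop_annotations);
}
```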
41 changes: 18 additions & 23 deletions src/op/atomic_add.cc
@@ -10,6 +10,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/layout.h"
#include "../target/utils.h"
#include "../transform/atomicadd_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
@@ -23,23 +24,24 @@ namespace tl {
using namespace tir;

/**
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
* @brief Construct an AtomicAdd operator from call arguments and annotations.
*
* Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two region-style expressions
* in `args` (BufferLoad/BufferRegion), and stores them along with their
* ranges. If a third argument is provided, it is interpreted as an integer
* immediate and stored as the node's coalesced width.
* ranges. Annotations are copied directly from the Call node.
*
* @param args Call-style PrimExprs where:
* - args[0] is the source region call,
* - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width.
* - args[1] is the destination region call.
* @param annotations Map containing optional keys:
* - "use_tma": whether to use TMA for memory operations
* - "memory_order": memory order for atomic operations
* Notes:
* - The constructor checks that args[0] and args[1] are region-compatible.
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args) {
AtomicAdd::AtomicAdd(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
Expand All @@ -50,16 +52,8 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args) {
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
node->use_tma = Downcast<IntImm>(args[2]);
}
node->memory_order = IntImm(0);
if (args.size() >= 4) {
node->memory_order = Downcast<IntImm>(args[3]);
}
if (args.size() >= 5) {
node->coalesced_width = Downcast<IntImm>(args[4]);
}
// Copy annotations from the Call node
node->annotations = annotations;
data_ = std::move(node);
}

@@ -284,21 +278,22 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {

new_args.push_back(dst_ptr);
new_args.push_back(src_value);
new_args.push_back(memory_order);
new_args.push_back(GetMemoryOrder());

Call atomicadd_call =
tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args);

Stmt body = tvm::tir::Evaluate(atomicadd_call);

for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
Map<String, ObjectRef> loop_annotations;
if (annotations.count(attr::kCoalescedWidth)) {
loop_annotations.Set(attr::kCoalescedWidth,
annotations.Get(attr::kCoalescedWidth).value());
}

body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, std::nullopt, annotations);
ForKind::kParallel, body, std::nullopt, loop_annotations);
}
return Downcast<For>(body);
}
@@ -377,7 +372,7 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
*/
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
if (use_tma->value != 0) {
if (GetUseTMA()) {
Array<PrimExpr> src_indices, dst_indices;
PrimExpr src_size, dst_size;
std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
Expand Down Expand Up @@ -487,7 +482,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int sm = GetArchInt(target);
auto plan = planner.Plan(loop, sm);
int vec = std::max(plan.vector_size, 1);
if (auto cw = loop->annotations.Get("coalesced_width")) {
if (auto cw = loop->annotations.Get(attr::kCoalescedWidth)) {
if (const auto *imm = cw->as<IntImmNode>()) {
int expected = imm->value;
ICHECK_GT(expected, 0);
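For reference, a minimal sketch of constructing the operator through the new annotation-based interface (hypothetical call site; `src_region` and `dst_region` stand for the region-style `PrimExpr`s the constructor expects as `args[0]` and `args[1]`):

```cpp
// Sketch only: src_region / dst_region are assumed to be the
// BufferLoad/BufferRegion-style call expressions described above.
Map<String, ObjectRef> annotations;
annotations.Set("use_tma", IntImm(DataType::Int(32), 1));       // enable TMA
annotations.Set("memory_order", IntImm(DataType::Int(32), 0));  // 0 = relaxed (default)
AtomicAdd atomic_add({src_region, dst_region}, annotations);
```

Keying the optional knobs by name means a new option no longer shifts positional argument indices, which is what the removed `args.size() >= N` checks were guarding.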
37 changes: 29 additions & 8 deletions src/op/atomic_add.h
@@ -19,10 +19,12 @@ class AtomicAddNode : public TileOperatorNode {
public:
Buffer src, dst; ///< Source and destination buffers
Array<Range> src_range,
dst_range; ///< Access ranges for source and destination
IntImm use_tma; ///< Whether to use TMA for memory operations
IntImm coalesced_width; ///< Width for memory coalescing optimization
IntImm memory_order; ///< Memory order for atomic operations
dst_range; ///< Access ranges for source and destination
Map<String, ObjectRef> annotations; ///< Annotations for the atomic operation
// Supported annotation keys:
// - "use_tma": IntImm, whether to use TMA for memory operations
// - "coalesced_width": IntImm, width for memory coalescing optimization
// - "memory_order": IntImm, memory order for atomic operations

mutable ParallelOp par_op_; ///< Associated parallel operation
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode,
@@ -41,9 +43,26 @@
.def_ro("dst", &AtomicAddNode::dst)
.def_ro("src_range", &AtomicAddNode::src_range)
.def_ro("dst_range", &AtomicAddNode::dst_range)
.def_ro("use_tma", &AtomicAddNode::use_tma)
.def_ro("coalesced_width", &AtomicAddNode::coalesced_width)
.def_ro("memory_order", &AtomicAddNode::memory_order);
.def_ro("annotations", &AtomicAddNode::annotations);
}

// Helper methods to get annotation values
bool GetUseTMA() const {
if (auto val = annotations.Get("use_tma")) {
if (auto int_val = val->as<IntImmNode>()) {
return int_val->value != 0;
}
}
return false;
}

int GetMemoryOrder() const {
if (auto val = annotations.Get("memory_order")) {
if (auto int_val = val->as<IntImmNode>()) {
return int_val->value;
}
}
return 0; // default: relaxed
}

protected:
@@ -65,7 +84,9 @@ class AtomicAdd : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator,
AtomicAddNode);
TVM_DLL AtomicAdd(Array<PrimExpr> args);
TVM_DLL
AtomicAdd(Array<PrimExpr> args,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>());
static const Op &Get();
};

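The two accessors above cover `use_tma` and `memory_order`; the `coalesced_width` key is still read straight out of `annotations` in `atomic_add.cc`. If a symmetric accessor were wanted, a sketch in the same style (not part of this diff) could be:

```cpp
// Hypothetical helper mirroring GetUseTMA/GetMemoryOrder.
int GetCoalescedWidth() const {
  if (auto val = annotations.Get("coalesced_width")) {
    if (auto int_val = val->as<IntImmNode>()) {
      return int_val->value;
    }
  }
  return 0;  // 0: no explicit coalesced width requested
}
```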
61 changes: 24 additions & 37 deletions src/op/copy.cc
@@ -99,11 +99,11 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
return Array<T>{array.rbegin(), array.rend()};
}

// Constructs a Copy operator node from call arguments.
// Constructs a Copy operator node from call arguments and annotations.
// args[0]: source region, args[1]: destination region
// Optional: args[2] coalesced_width, args[3] disable_tma, args[4]
// eviction_policy
Copy::Copy(Array<PrimExpr> args) {
// annotations: Map containing coalesced_width, disable_tma, eviction_policy,
// etc.
Copy::Copy(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2];
Buffer bf[2];
Expand All @@ -114,18 +114,8 @@ Copy::Copy(Array<PrimExpr> args) {
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
auto coalesced_width = Downcast<IntImm>(args[2]);
if (coalesced_width->value > 0) {
node->coalesced_width = coalesced_width;
}
}
if (args.size() >= 4) {
node->disable_tma = Downcast<Bool>(args[3]);
}
if (args.size() >= 5) {
node->eviction_policy = args[4].as<IntImmNode>()->value;
}
// Copy annotations from the Call node
node->annotations = annotations;
data_ = std::move(node);
}

@@ -323,12 +313,13 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return For(Var("i"), 0, 1, ForKind::kSerial, body);
}
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
Map<String, ObjectRef> loop_annotations;
if (annotations.count(attr::kCoalescedWidth)) {
loop_annotations.Set(attr::kCoalescedWidth,
annotations.Get(attr::kCoalescedWidth).value());
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
ForKind::kParallel, body, std::nullopt, annotations);
ForKind::kParallel, body, std::nullopt, loop_annotations);
}
return Downcast<For>(body);
}
@@ -361,7 +352,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(),
T.layout_map, T.analyzer, T.buffer_oob);

// Handle tensor memory (tmem) layout inference
@@ -736,7 +727,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(),
T.layout_map, analyzer);
if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
auto tmem_copy = LowerTmemCopy(T, analyzer);
@@ -783,6 +774,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
<< "` may cause conflicted write.";
}
vectorized_thread_loop = VectorizeLoop(transformed_loop);
return vectorized_thread_loop;
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
Expand All @@ -797,17 +789,11 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
level);
}
auto loop_layout = par_op->GetLoopLayout();
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
}

if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
// Use LowerParallelLoop to handle partitioning, vectorization, and
// predicate
return LowerParallelLoop(par_op->GetRoot(), loop_layout, T.thread_var,
analyzer, par_op->GetPredicate(T.thread_var));
}
return vectorized_thread_loop;
}

// Lowers copy to LDSM/STSM (warp-level 8x8 matrix) instructions.
@@ -1452,7 +1438,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
int need_reduce = 0;
if (!is_load)
args.push_back(need_reduce);
args.push_back(this->eviction_policy);
args.push_back(GetEvictionPolicy());
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
} else {
Expand All @@ -1464,7 +1450,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
int need_reduce = 0;
if (!is_load)
args.push_back(need_reduce);
args.push_back(this->eviction_policy);
args.push_back(GetEvictionPolicy());
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
@@ -1536,13 +1522,13 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer,
tma_copy = Evaluate(
Call(DataType::Handle(), tma_load(),
{shared_addr, global_addr, 0,
elements * shared_tensor->dtype.bytes(), this->eviction_policy}));
elements * shared_tensor->dtype.bytes(), GetEvictionPolicy()}));
} else {
int need_reduce = 0;
tma_copy = Evaluate(
Call(DataType::Handle(), tma_store(),
{global_addr, shared_addr, elements * shared_tensor->dtype.bytes(),
need_reduce, this->eviction_policy}));
need_reduce, GetEvictionPolicy()}));
}
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
@@ -1575,7 +1561,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
// Constructs a Conv2DIm2ColOp node from call arguments.
// args: src, dst, nhw_step, c_step, kernel, stride, dilation, padding,
// eviction_policy
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args) {
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args,
Map<String, ObjectRef> annotations) {
ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
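As with `AtomicAdd`, the optional `Copy` knobs now travel in the annotations map rather than as trailing positional arguments. A minimal sketch of a hypothetical call site using the keys named in the updated doc comment (`src_region`/`dst_region` are the region-style `args[0]`/`args[1]`):

```cpp
// Sketch only: key names and value types follow the old positional arguments
// (coalesced_width: IntImm, disable_tma: Bool, eviction_policy: IntImm).
Map<String, ObjectRef> annotations;
annotations.Set("coalesced_width", IntImm(DataType::Int(32), 8));
annotations.Set("disable_tma", Bool(true));
annotations.Set("eviction_policy", IntImm(DataType::Int(32), 0));
Copy copy({src_region, dst_region}, annotations);
```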