Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions src/op/atomic_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,19 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
<< "AtomicAdd expects at least 2 arguments (src, dst), got "
<< args.size();
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;

if (IsBufferLikeExpr(args[0])) {
auto region = NormalizeToBufferRegion(args[0]);
node->src = region->buffer;
node->src_range = region->region;
} else {
node->src_value = args[0];
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);

auto region = NormalizeToBufferRegion(args[1]);
node->dst = region->buffer;
node->dst_range = region->region;

// Copy annotations from the Call node
node->annotations = annotations;
data_ = std::move(node);
Expand Down Expand Up @@ -144,45 +148,49 @@ AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
*/
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.empty();
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
ICHECK(!loop_vars.empty()) << "MakeIterVars in AtomicOp should not return "
"empty vars (at least 1 var)";

for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);

ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;

ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
<< ", dst_range.size() = " << dst_range.size() << ", dst = " << dst->name;

Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

Array<PrimExpr> new_args;

// Optional bounds predicates for src and dst
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

// Load source value and cast to dst dtype if needed
PrimExpr src_value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
src_value = Cast(dst->dtype, src_value);
// Src arg to be passed to the Call atomic operation
PrimExpr src_value_arg;

// If src is a Buffer
if (!src_value.defined()) {
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size()
<< ", src = " << src->name << ", dst = " << dst->name;

Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
// Load source value
src_value_arg = BufferLoad(src, src_indices);
} else {
src_value_arg = src_value;
}
// Cast to dst dtype if needed
if (src_value_arg->dtype != dst->dtype)
src_value_arg = Cast(dst->dtype, src_value_arg);

// Build a pointer to destination element using tvm_access_ptr
PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(dst, dst_indices)});

new_args.push_back(dst_ptr);
new_args.push_back(src_value);
new_args.push_back(src_value_arg);
new_args.push_back(GetMemoryOrder());

// erase use_tma from annotations
Expand Down
8 changes: 7 additions & 1 deletion src/op/atomic_add.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AtomicAddNode : public AtomicOpBaseNode {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AtomicAddNode>()
.def_ro("src", &AtomicAddNode::src)
.def_ro("src_value", &AtomicAddNode::src_value)
.def_ro("dst", &AtomicAddNode::dst)
.def_ro("src_range", &AtomicAddNode::src_range)
.def_ro("dst_range", &AtomicAddNode::dst_range)
Expand All @@ -47,7 +48,12 @@ class AtomicAddNode : public AtomicOpBaseNode {
bool GetUseTMA() const {
if (auto val = annotations.Get("use_tma")) {
if (auto int_val = val->as<IntImmNode>()) {
return int_val->value != 0;
if (int_val->value != 0) {
ICHECK(!src_value.defined())
<< "TMA is not supported when using TiledAtomicAdd with PrimExpr "
"as value.";
return true;
}
}
}
return false;
Expand Down
107 changes: 65 additions & 42 deletions src/op/atomic_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@ AtomicMax::AtomicMax(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
<< "AtomicMax expects at least 2 arguments (src, dst), got "
<< args.size();
ObjectPtr<AtomicMaxNode> node = tvm::ffi::make_object<AtomicMaxNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;

if (IsBufferLikeExpr(args[0])) {
auto region = NormalizeToBufferRegion(args[0]);
node->src = region->buffer;
node->src_range = region->region;
} else {
node->src_value = args[0];
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);

auto region = NormalizeToBufferRegion(args[1]);
node->dst = region->buffer;
node->dst_range = region->region;

node->annotations = annotations;
data_ = std::move(node);
}
Expand All @@ -63,15 +67,19 @@ AtomicMin::AtomicMin(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
<< "AtomicMin expects at least 2 arguments (src, dst), got "
<< args.size();
ObjectPtr<AtomicMinNode> node = tvm::ffi::make_object<AtomicMinNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto region = NormalizeToBufferRegion(args[i]);
rgs[i] = region->region;
bf[i] = region->buffer;

if (IsBufferLikeExpr(args[0])) {
auto region = NormalizeToBufferRegion(args[0]);
node->src = region->buffer;
node->src_range = region->region;
} else {
node->src_value = args[0];
}
std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]);
std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]);

auto region = NormalizeToBufferRegion(args[1]);
node->dst = region->buffer;
node->dst_range = region->region;

node->annotations = annotations;
data_ = std::move(node);
}
Expand All @@ -93,14 +101,23 @@ const Op &AtomicMinNode::GetElemOp() const { return atomic_min_elem_op(); }
Array<IterVar> AtomicOpBaseNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
// Make IterVars according to dst, not src
// Since src may be a scalar Expr
for (size_t i = 0; i < dst_range.size(); i++) {
if (is_one(dst_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
Var var = Var(std::string{char('i' + idx)}, dst_range[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
{Range(0, dst_range[i]->extent), var, IterVarType::kDataPar});
}

// If is scalar, create a dummy loop var
if (loop_vars.empty()) {
Var var = Var("i");
loop_vars.push_back({Range(0, 1), var, IterVarType::kDataPar});
}

return loop_vars;
}

Expand All @@ -117,9 +134,11 @@ Array<PrimExpr> AtomicOpBaseNode::MakeIndices(const Array<IterVar> &ivs,
idx++;
}
}
ICHECK(idx == ivs.size())
<< "idx = " << idx << ", ivs.size() = " << ivs.size()
<< "src name = " << src->name << ", dst name = " << dst->name;

// Special case: scalar range, when there is one var and one range(0, 1)
ICHECK(idx == ivs.size() || (idx == 0 && ivs.size() == 1))
<< "Unmatched indices: idx = " << idx << ", ivs.size() = " << ivs.size()
<< ", dst name = " << dst->name;
return indices;
}

Expand Down Expand Up @@ -156,41 +175,45 @@ PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer,

For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.empty();
if (is_scalar) {
return For(Var("i"), 0, 1, ForKind::kSerial,
BufferStore(dst, BufferLoad(src, {0}), {0}));
}
ICHECK(!loop_vars.empty()) << "MakeIterVars in AtomicOp should not return "
"empty vars (at least 1 var)";

for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);

ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;

ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
<< ", dst_range.size() = " << dst_range.size() << ", dst = " << dst->name;

Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

Array<PrimExpr> new_args;

// Load source value and cast to dst dtype if needed
PrimExpr src_value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype)
src_value = Cast(dst->dtype, src_value);
// Src arg to be passed to the Call atomic operation
PrimExpr src_value_arg;

// If src is a Buffer
if (!src_value.defined()) {
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size()
<< ", src = " << src->name << ", dst = " << dst->name;

Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
// Load source value
src_value_arg = BufferLoad(src, src_indices);
} else {
src_value_arg = src_value;
}
// Cast to dst dtype if needed
if (src_value_arg->dtype != dst->dtype)
src_value_arg = Cast(dst->dtype, src_value_arg);

// Build a pointer to destination element using tvm_access_ptr
PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(dst, dst_indices)});

new_args.push_back(dst_ptr);
new_args.push_back(src_value);
new_args.push_back(src_value_arg);
new_args.push_back(GetMemoryOrder());

// Use the appropriate elem_op based on the derived type (via virtual call)
Expand Down
5 changes: 4 additions & 1 deletion src/op/atomic_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ using namespace tir;
*/
class AtomicOpBaseNode : public TileOperatorNode {
public:
Buffer src, dst; ///< Source and destination buffers
PrimExpr src_value; ///< Source values, for cases src is not a buffer
Buffer src, dst; ///< Source and destination buffers
Array<Range> src_range,
dst_range; ///< Access ranges for source and destination
Map<String, ObjectRef> annotations; ///< Annotations for the atomic operation
Expand Down Expand Up @@ -81,6 +82,7 @@ class AtomicMaxNode : public AtomicOpBaseNode {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AtomicMaxNode>()
.def_ro("src", &AtomicMaxNode::src)
.def_ro("src_value", &AtomicMaxNode::src_value)
.def_ro("dst", &AtomicMaxNode::dst)
.def_ro("src_range", &AtomicMaxNode::src_range)
.def_ro("dst_range", &AtomicMaxNode::dst_range)
Expand Down Expand Up @@ -113,6 +115,7 @@ class AtomicMinNode : public AtomicOpBaseNode {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AtomicMinNode>()
.def_ro("src", &AtomicMinNode::src)
.def_ro("src_value", &AtomicMinNode::src_value)
.def_ro("dst", &AtomicMinNode::dst)
.def_ro("src_range", &AtomicMinNode::src_range)
.def_ro("dst_range", &AtomicMinNode::dst_range)
Expand Down
11 changes: 11 additions & 0 deletions src/op/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

#include "utils.h"
#include "tvm/tir/expr.h"

#include <tvm/tir/builtin.h>

Expand All @@ -12,6 +13,16 @@ namespace tl {

using namespace tir;

bool IsBufferLikeExpr(const PrimExpr &expr) {
if (expr.as<BufferLoadNode>() || expr.as<BufferRegionNode>()) {
return true;
}
if (const auto *call = expr.as<CallNode>()) {
return (call->op.same_as(RegionOp::Get()));
}
return false;
}

BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
Expand Down
5 changes: 5 additions & 0 deletions src/op/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "../target/stubs/cuda.h"
#include "./operator.h"
#include "region.h"
#include "tvm/runtime/base.h"
#include <tvm/tir/buffer.h>
#include <tvm/tir/op.h>

Expand All @@ -25,6 +26,10 @@ template <typename T> Array<T> ReverseArray(Array<T> array) {
return Array<T>{array.rbegin(), array.rend()};
}

// Check if an PrimExpr is a buffer-like (BufferRegion/BufferLoad/tl.region)
// expression.
TVM_DLL bool IsBufferLikeExpr(const PrimExpr &expr);

// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so ops can uniformly consume regions.
// Note: tvm_access_ptr is no longer supported here.
Expand Down
Loading
Loading