diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 48e6cdf6e..cece1e6f9 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -14,6 +14,7 @@ #include "../target/utils.h" #include "region.h" #include "tcgen5_meta.h" +#include "utils.h" namespace tvm { namespace tl { @@ -48,92 +49,9 @@ using namespace tir; * fails with an ICHECK (runtime assertion). No other validation is * performed here. */ -// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) -// to BufferRegion -static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { - // Case 1: Already a BufferRegion - if (arg->IsInstance()) { - return Downcast(arg); - } - - // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else - // extent=1) - if (const auto *load = arg.as()) { - Array ranges; - for (const PrimExpr &index : load->indices) { - if (const auto *ramp = index.as()) { - ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; - ICHECK_EQ(ramp->stride.as()->value, 1) - << "Only stride-1 Ramp is supported in GEMM region conversion"; - ICHECK(ramp->lanes.as()) - << "Scalable vector lanes not supported in GEMM region conversion"; - ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - ranges.push_back(Range::FromMinExtent(index, 1)); - } - } - return BufferRegion(load->buffer, ranges); - } +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - // Case 3: Call nodes - if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp - if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); - return BufferRegion(region->GetBuffer(), region->GetRanges()); - } - // builtin.tvm_access_ptr(...) — map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap[var]; - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } - } - - LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg; - throw; // Unreachable, keeps compiler happy -} - -// Build a tvm_access_ptr(handle) to the start of the 2D tile within a -// BufferRegion. Offset is computed from all but the last two dimensions; extent -// is the product of the last two extents. rw_mask: 1=read, 2=write, -// 3=readwrite. -static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, - int rw_mask) { - Buffer buf = region->buffer; - int ndim = static_cast(buf->shape.size()); - ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims"; - - // Compute row-major strides - std::vector strides(ndim); - PrimExpr one = make_const(buf->shape[0].dtype(), 1); - PrimExpr cur = one; - for (int i = ndim - 1; i >= 0; --i) { - strides[i] = cur; - cur = cur * buf->shape[i]; - } - - // Offset: sum_{i in [0..ndim-3]} min_i * stride_i - PrimExpr offset = make_const(buf->shape[0].dtype(), 0); - for (int i = 0; i < ndim - 2; ++i) { - offset = offset + region->region[i]->min * strides[i]; - } - - // Extent: last two extents product (elements) - PrimExpr extent = - region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; - - // ptype and return handle - PrimExpr ptype = tir::TypeAnnotation(buf->dtype); - Array acc_args{ptype, buf->data, offset, extent, - IntImm(DataType::Int(32), rw_mask)}; - return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); -} +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} Gemm::Gemm(Array args, BufferMap vmap) { ObjectPtr node = tvm::ffi::make_object(); @@ -535,9 +453,12 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst); // Build access pointers from regions locally - PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1); - PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1); - PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3); + PrimExpr Aptr = + MakeAccessPtrFromRegion(aRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Bptr = + MakeAccessPtrFromRegion(bRegion_, /*r*/ 1, /*require_2d*/ true); + PrimExpr Cptr = + MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3, /*require_2d*/ true); std::stringstream ss; std::string op_name; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index ac506ee09..a6ddef64f 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -14,98 +14,16 @@ #include "../target/utils.h" #include "region.h" #include "tcgen5_meta.h" +#include "utils.h" namespace tvm { namespace tl { using namespace tir; -// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region) -// to BufferRegion -static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { - // Case 1: Already a BufferRegion - if (arg->IsInstance()) { - return Downcast(arg); - } - - // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else - // extent=1) - if (const auto *load = arg.as()) { - Array ranges; - for (const PrimExpr &index : load->indices) { - if (const auto *ramp = index.as()) { - ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; - ICHECK_EQ(ramp->stride.as()->value, 1) - << "Only stride-1 Ramp is supported in GEMM region conversion"; - ICHECK(ramp->lanes.as()) - << "Scalable vector lanes not supported in GEMM region conversion"; - ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - ranges.push_back(Range::FromMinExtent(index, 1)); - } - } - return BufferRegion(load->buffer, ranges); - } - - // Case 3: Call nodes - if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp - if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); - return BufferRegion(region->GetBuffer(), region->GetRanges()); - } - // builtin.tvm_access_ptr(...) — map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap.at(var); - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } - } +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg; - throw; // Unreachable, keeps compiler happy -} - -// Build a tvm_access_ptr(handle) to the start of the 2D tile within a -// BufferRegion. Offset is computed from all but the last two dimensions; extent -// is the product of the last two extents. rw_mask: 1=read, 2=write, -// 3=readwrite. -static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, - int rw_mask) { - Buffer buf = region->buffer; - int ndim = static_cast(buf->shape.size()); - ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims"; - - // Compute row-major strides - std::vector strides(ndim); - PrimExpr one = make_const(buf->shape[0].dtype(), 1); - PrimExpr cur = one; - for (int i = ndim - 1; i >= 0; --i) { - strides[i] = cur; - cur = cur * buf->shape[i]; - } - - // Offset: sum_{i in [0..ndim-3]} min_i * stride_i - PrimExpr offset = make_const(buf->shape[0].dtype(), 0); - for (int i = 0; i < ndim - 2; ++i) { - offset = offset + region->region[i]->min * strides[i]; - } - - // Extent: last two extents product (elements) - PrimExpr extent = - region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; - - // ptype and return handle - PrimExpr ptype = tir::TypeAnnotation(buf->dtype); - Array acc_args{ptype, buf->data, offset, extent, - IntImm(DataType::Int(32), rw_mask)}; - return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); -} +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer diff --git a/src/op/reduce.cc b/src/op/reduce.cc index b6dbe8651..c326f5ac0 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -17,105 +17,16 @@ #include "region.h" #include "tir/transforms/ir_utils.h" #include "tvm/tir/stmt.h" +#include "utils.h" namespace tvm { namespace tl { using namespace tir; -// Normalize an argument (BufferRegion/BufferLoad/tl.region) -// to BufferRegion so Reduce can uniformly consume regions. -static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, - const BufferMap &vmap) { - // Case 1: Already a BufferRegion - if (arg->IsInstance()) { - return Downcast(arg); - } - - // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else - // extent=1) - if (const auto *load = arg.as()) { - Array ranges; - for (const PrimExpr &index : load->indices) { - if (const auto *ramp = index.as()) { - ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; - ICHECK_EQ(ramp->stride.as()->value, 1) - << "Only stride-1 Ramp is supported in region conversion"; - ICHECK(ramp->lanes.as()) - << "Scalable vector lanes not supported in region conversion"; - ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); - } else { - ranges.push_back(Range::FromMinExtent(index, 1)); - } - } - return BufferRegion(load->buffer, ranges); - } - - // Case 3: Call nodes (only tl.region) - if (const auto *call = arg.as()) { - // tl.region(...) — reconstruct via RegionOp - if (call->op.same_as(RegionOp::Get())) { - RegionOp region(call->args, vmap); - return BufferRegion(region->GetBuffer(), region->GetRanges()); - } - // builtin.tvm_access_ptr(...) — map var to Buffer and take full region - if (call->op.same_as(builtin::tvm_access_ptr())) { - Var var = Downcast(call->args[1]); - Buffer buf = vmap[var]; - Array ranges; - for (PrimExpr extent : buf->shape) { - ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); - } - return BufferRegion(buf, ranges); - } - } - - LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg; - throw; // Unreachable -} - -// Build a tvm_access_ptr(handle) to the start of the 2D tile within a -// BufferRegion. Offset is computed from all but the last two dimensions; extent -// is the product of the last two extents. rw_mask: 1=read, 2=write, -// 3=readwrite. -static PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, - int rw_mask) { - Buffer buf = region->buffer; - int ndim = static_cast(buf->shape.size()); - ICHECK(ndim == 1 || ndim == 2) << "Cumsum expects buffers with 1 or 2 dims"; - - PrimExpr offset, extent; - if (ndim == 1) { - // Simple 1D region: offset and extent come from the single axis. - auto axis = region->region[0]; - offset = axis->min; - extent = axis->extent; - } else { - // Compute row-major strides for ndim >= 2 - std::vector strides(ndim); - PrimExpr one = make_const(buf->shape[0].dtype(), 1); - PrimExpr cur = one; - for (int i = ndim - 1; i >= 0; --i) { - strides[i] = cur; - cur = cur * buf->shape[i]; - } - // Offset: sum_{i in [0..ndim-3]} min_i * stride_i - offset = make_const(buf->shape[0].dtype(), 0); - for (int i = 0; i < ndim - 2; ++i) { - offset = offset + region->region[i]->min * strides[i]; - } +// NormalizeToBufferRegion moved to src/op/utils.{h,cc} - // Extent: last two extents product (elements) - extent = - region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; - } - - // ptype and return handle - PrimExpr ptype = tir::TypeAnnotation(buf->dtype); - Array acc_args{ptype, buf->data, offset, extent, - IntImm(DataType::Int(32), rw_mask)}; - return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); -} +// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} ReduceOp::ReduceOp(Array args, BufferMap vmap) { ObjectPtr node = tvm::ffi::make_object(); diff --git a/src/op/utils.cc b/src/op/utils.cc new file mode 100644 index 000000000..59960b570 --- /dev/null +++ b/src/op/utils.cc @@ -0,0 +1,105 @@ +/*! + * \file tl/op/utils.cc + * \brief Common utilities implementation for TL ops. + */ + +#include "utils.h" + +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, + const BufferMap &vmap) { + // Case 1: Already a BufferRegion + if (arg->IsInstance()) { + return Downcast(arg); + } + + // Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else + // extent=1) + if (const auto *load = arg.as()) { + Array ranges; + for (const PrimExpr &index : load->indices) { + if (const auto *ramp = index.as()) { + ICHECK(ramp->stride.as()) << "Ramp stride must be IntImm"; + ICHECK_EQ(ramp->stride.as()->value, 1) + << "Only stride-1 Ramp is supported in region conversion"; + ICHECK(ramp->lanes.as()) + << "Scalable vector lanes not supported in region conversion"; + ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + } else { + ranges.push_back(Range::FromMinExtent(index, 1)); + } + } + return BufferRegion(load->buffer, ranges); + } + + // Case 3: Call nodes + if (const auto *call = arg.as()) { + // tl.region(...) — reconstruct via RegionOp + if (call->op.same_as(RegionOp::Get())) { + RegionOp region(call->args, vmap); + return BufferRegion(region->GetBuffer(), region->GetRanges()); + } + // builtin.tvm_access_ptr(...) — map var to Buffer and take full region + if (call->op.same_as(builtin::tvm_access_ptr())) { + Var var = Downcast(call->args[1]); + Buffer buf = vmap.at(var); + Array ranges; + for (PrimExpr extent : buf->shape) { + ranges.push_back(Range(IntImm(extent->dtype, 0), extent)); + } + return BufferRegion(buf, ranges); + } + } + + LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg; + throw; // Unreachable +} + +PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, int rw_mask, + bool require_2d) { + Buffer buf = region->buffer; + int ndim = static_cast(buf->shape.size()); + if (require_2d) { + ICHECK(ndim >= 2) << "Expect buffers with at least 2 dims"; + } + + PrimExpr offset, extent; + if (ndim == 1) { + // 1D: straightforward + auto axis = region->region[0]; + offset = axis->min; + extent = axis->extent; + } else { + // Compute row-major strides + std::vector strides(ndim); + PrimExpr one = make_const(buf->shape[0].dtype(), 1); + PrimExpr cur = one; + for (int i = ndim - 1; i >= 0; --i) { + strides[i] = cur; + cur = cur * buf->shape[i]; + } + // Offset: sum_{i in [0..ndim-3]} min_i * stride_i + offset = make_const(buf->shape[0].dtype(), 0); + for (int i = 0; i < ndim - 2; ++i) { + offset = offset + region->region[i]->min * strides[i]; + } + // Extent: last two extents product (elements) + extent = + region->region[ndim - 2]->extent * region->region[ndim - 1]->extent; + } + + // ptype and return handle + PrimExpr ptype = tir::TypeAnnotation(buf->dtype); + Array acc_args{ptype, buf->data, offset, extent, + IntImm(DataType::Int(32), rw_mask)}; + return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); +} + +} // namespace tl +} // namespace tvm diff --git a/src/op/utils.h b/src/op/utils.h new file mode 100644 index 000000000..9e7880acd --- /dev/null +++ b/src/op/utils.h @@ -0,0 +1,35 @@ +/*! + * \file tl/op/utils.h + * \brief Common utilities for TL ops. + */ + +#ifndef TVM_TL_OP_UTILS_H_ +#define TVM_TL_OP_UTILS_H_ + +#include "./operator.h" +#include "region.h" +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +// Normalize an argument (BufferRegion/BufferLoad/tl.region/tvm_access_ptr) +// to BufferRegion so ops can uniformly consume regions. +TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg, + const BufferMap &vmap); + +// Build a tvm_access_ptr(handle) from a BufferRegion. +// - If `require_2d` is true, checks buffer ndim >= 2. +// - For 1D regions (when allowed), offset=min, extent=extent. +// - For ndim >= 2, offset sums all but last two dims using row-major strides, +// extent is product of the last two extents. +TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion ®ion, + int rw_mask, bool require_2d = false); + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_UTILS_H_