Skip to content

Commit 2e119b4

Browse files
committed
calculate mma store dst index using inverse affine map
1 parent 9489434 commit 2e119b4

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

src/target/source/codegen_c.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,13 @@ void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*)
460460
void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
461461
PrintBinaryExpr(op, "/", os, this);
462462
}
463+
void CodeGenC::VisitExpr_(const FloorDivNode* op, std::ostream& os) { // NOLINT(*)
464+
PrintBinaryExpr(op, "/", os, this);
465+
}
466+
void CodeGenC::VisitExpr_(const FloorModNode* op, std::ostream& os) { // NOLINT(*)
467+
PrintBinaryExpr(op, "%", os, this);
468+
}
469+
463470
void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
464471
if (op->dtype.is_int() || op->dtype.is_uint()) {
465472
PrintBinaryExpr(op, "%", os, this);

src/target/source/codegen_c.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
134134
void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*)
135135
void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*)
136136
void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*)
137+
void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*)
138+
void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*)
137139
void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*)
138140
void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*)
139141
void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*)

src/target/source/codegen_cuda.cc

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
#include "literal/cuda_half_t.h"
3636
#include "ptx.h"
37+
#include "tvm/arith/iter_affine_map.h"
3738

3839
namespace tvm {
3940
namespace codegen {
@@ -838,21 +839,38 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
838839
std::string dst = this->PrintExpr(op->args[2]);
839840
std::string src = this->PrintExpr(op->args[3]);
840841
std::string src_offset = this->PrintExpr(op->args[4]);
841-
std::string stride = this->PrintExpr(op->args[5]);
842842

843843
if (m == 16 && n == 8) {
844+
std::string stride = this->PrintExpr(op->args[5]);
844845
os << "for (int i = 0; i < 4; ++i) {\n";
845846
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
846847
<< " + (threadIdx.x % 4) * 2 + i % 2]"
847848
<< " = " << src << "[" << src_offset << " + i];\n";
848849
os << "}\n";
849850
} else if (m == 16 && n == 16) {
850-
os << "for (int outer = 0; outer < 2; ++outer) {\n";
851-
os << "for (int i = 0; i < 4; ++i) {\n";
852-
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
853-
<< " + outer * 8 + (threadIdx.x % 4) * 2 + i % 2]"
854-
<< " = " << src << "[" << src_offset << " + outer * 4 + i];\n";
855-
os << "}\n";
851+
const auto* index_map =
852+
runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
853+
ICHECK(index_map);
854+
855+
Var var_i("i");
856+
Var var_j("j");
857+
Array<PrimExpr> forward_map = (*index_map)(var_i, var_j);
858+
859+
arith::Analyzer ana;
860+
auto iter_map = arith::DetectIterMap(
861+
forward_map, {{var_i, Range(0, 16)}, {var_j, Range(0, 16)}}, true, true, &ana, true);
862+
863+
Var thread_id("threadIdx.x");
864+
Var local_id("local_id");
865+
auto inverse_map = arith::InverseAffineIterMap(iter_map, {thread_id, local_id});
866+
PrimExpr stride = op->args[5];
867+
auto dst_idx = inverse_map[var_i] * stride + inverse_map[var_j];
868+
869+
var_idmap_[thread_id.get()] = "threadIdx.x";
870+
var_idmap_[local_id.get()] = "local_id";
871+
os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
872+
os << dst << "[" + this->PrintExpr(dst_idx) + "]"
873+
<< " = " << src << "[" << src_offset << " + local_id];\n";
856874
os << "}\n";
857875
}
858876
} else if (op->op.same_as(builtin::mma_fill())) {

tests/python/unittest/test_mma_16x8x16_4k_tune.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
1313
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
1414

1515

16+
@tvm._ffi.register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
17+
def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
18+
thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
19+
return tvm.runtime.convert([thread_id, local_id])
20+
21+
1622
@T.prim_func
1723
def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
1824
A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")

0 commit comments

Comments
 (0)