Skip to content

Commit befe30e

Browse files
committed
[language] add transpose E
1 parent 5d5cf85 commit befe30e

File tree

6 files changed

+29
-15
lines changed

6 files changed

+29
-15
lines changed

src/op/gemm_sp_py.cc

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,24 @@ GemmSPPy::GemmSPPy(Array<PrimExpr> args, BufferMap vmap) {
6161
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
6262
node->trans_A = args[4].as<Bool>().value();
6363
node->trans_B = args[5].as<Bool>().value();
64-
node->M = args[6].as<IntImm>().value()->value;
65-
node->N = args[7].as<IntImm>().value()->value;
66-
node->K = args[8].as<IntImm>().value()->value;
67-
node->policy = GemmWarpPolicy(args[9].as<IntImm>().value()->value);
68-
node->clear_accum = args[10].as<PrimExpr>().value();
69-
node->stride_A = args[11].as<IntImm>().value()->value;
70-
node->stride_B = args[12].as<IntImm>().value()->value;
71-
node->offset_A = args[13].as<IntImm>().value()->value;
72-
node->offset_B = args[14].as<IntImm>().value()->value;
73-
if (args.size() > 15) {
74-
node->kPack = args[15].as<IntImm>().value()->value;
64+
node->trans_E = args[6].as<Bool>().value();
65+
node->M = args[7].as<IntImm>().value()->value;
66+
node->N = args[8].as<IntImm>().value()->value;
67+
node->K = args[9].as<IntImm>().value()->value;
68+
node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
69+
node->clear_accum = args[11].as<PrimExpr>().value();
70+
node->stride_A = args[12].as<IntImm>().value()->value;
71+
node->stride_B = args[13].as<IntImm>().value()->value;
72+
node->offset_A = args[14].as<IntImm>().value()->value;
73+
node->offset_B = args[15].as<IntImm>().value()->value;
74+
if (args.size() > 16) {
75+
node->kPack = args[16].as<IntImm>().value()->value;
7576
if (node->kPack != 1 && node->kPack != 2) {
7677
ICHECK(false) << "kPack must be 1 or 2";
7778
}
7879
}
79-
if (args.size() > 16) {
80-
node->wg_wait = args[16].as<IntImm>().value()->value;
80+
if (args.size() > 17) {
81+
node->wg_wait = args[17].as<IntImm>().value()->value;
8182
}
8283
data_ = std::move(node);
8384
}

src/op/gemm_sp_py.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class GemmSPPyNode : public TileOperatorNode {
2424
tir::Buffer A, E, B, C;
2525
// pointer to the A, E, B, C
2626
PrimExpr Aptr, Eptr, Bptr, Cptr;
27-
bool trans_A, trans_B;
27+
bool trans_A, trans_B, trans_E;
2828
int M, N, K;
2929
int stride_A, stride_B;
3030
int offset_A, offset_B;
@@ -51,6 +51,7 @@ class GemmSPPyNode : public TileOperatorNode {
5151
.def_ro("Cptr", &GemmSPPyNode::Cptr)
5252
.def_ro("trans_A", &GemmSPPyNode::trans_A)
5353
.def_ro("trans_B", &GemmSPPyNode::trans_B)
54+
.def_ro("trans_E", &GemmSPPyNode::trans_E)
5455
.def_ro("M", &GemmSPPyNode::M)
5556
.def_ro("N", &GemmSPPyNode::N)
5657
.def_ro("K", &GemmSPPyNode::K)

tilelang/intrinsics/mma_sp_macro_generator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def __init__(
139139
accum_dtype: str = "float16",
140140
a_transposed: bool = False,
141141
b_transposed: bool = False,
142+
e_transposed: bool = False,
142143
block_row_warps: int = 2,
143144
block_col_warps: int = 2,
144145
warp_row_tiles: int = 8,
@@ -155,6 +156,7 @@ def __init__(
155156
self.accum_dtype = accum_dtype
156157
self.a_transposed = a_transposed
157158
self.b_transposed = b_transposed
159+
self.e_transposed = e_transposed
158160
# Hint Information
159161
self.block_row_warps = block_row_warps
160162
self.block_col_warps = block_col_warps
@@ -362,6 +364,7 @@ def ldmatrix_e(self, E_local_buf: Buffer, E_shared_buf: Buffer, ki: PrimExpr, rk
362364
local_size_e = self.local_size_e
363365
a_dtype = self.a_dtype
364366
e_dtype = self.e_dtype
367+
trans = self.e_transposed
365368
# ldmatrix cannot be used for int8 + trans case.
366369
# include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h
367370
ldmatrix_available = False # TODO: use ldmatrix when possible
@@ -413,7 +416,7 @@ def _warp_ldmatrix_e(
413416
rk * warp_k + ki * micro_size_k) // self.e_factor
414417
for j in T.serial(local_size_e):
415418
mi, mk = mma_load_layout(tx, j)
416-
E_local_buf[i * local_size_e + j] = E_shared_buf[wi + mi, wk + mk]
419+
E_local_buf[i * local_size_e + j] = E_shared_buf[wk + mk, wi + mi] if trans else E_shared_buf[wi + mi, wk + mk]
417420

418421
return _warp_ldmatrix_e(E_local_buf, E_shared_buf, ki, thread_binding, rk)
419422

tilelang/language/experimental/gemm_sp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def gemm_sp(
3131
C (Union[tir.Buffer, tir.Var]): Output matrix for results
3232
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
3333
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
34+
transpose_E (bool, optional): Whether to transpose matrix E. Defaults to False.
3435
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
3536
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
3637
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
@@ -95,6 +96,7 @@ def gemm_sp_v2(
9596
C: tir.Buffer | tir.Var,
9697
transpose_A: bool = False,
9798
transpose_B: bool = False,
99+
transpose_E: bool = False,
98100
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
99101
clear_accum: bool = False,
100102
k_pack: int = 1,
@@ -293,6 +295,7 @@ def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
293295
Cptr,
294296
transpose_A,
295297
transpose_B,
298+
transpose_E,
296299
M,
297300
N,
298301
K,

tilelang/tileop/gemm_sp/gemm_sp_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def trans_A(self) -> bool:
4949
def trans_B(self) -> bool:
5050
return self.gemm_sp_node.trans_B
5151

52+
@property
53+
def trans_E(self) -> bool:
54+
return self.gemm_sp_node.trans_E
55+
5256
@property
5357
def e_dtype(self) -> str:
5458
return self.E.dtype

tilelang/tileop/gemm_sp/gemm_sp_mma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def infer_layout(self, target: Target, thread_nums: int):
2323
accum_dtype=self.accum_dtype,
2424
a_transposed=self.trans_A,
2525
b_transposed=self.trans_B,
26+
e_transposed=self.trans_E,
2627
block_row_warps=m_warp,
2728
block_col_warps=n_warp,
2829
warp_row_tiles=warp_row_tiles,
@@ -69,6 +70,7 @@ def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
6970
accum_dtype=self.accum_dtype,
7071
a_transposed=self.trans_A,
7172
b_transposed=self.trans_B,
73+
e_transposed=self.trans_E,
7274
block_row_warps=m_warp,
7375
block_col_warps=n_warp,
7476
warp_row_tiles=warp_row_tiles,

0 commit comments

Comments
 (0)