Skip to content

Commit d2db013

Browse files
committed
Comprehensively support WGMMA GEMM SS
1 parent 2dbaccc commit d2db013

File tree

7 files changed

+83
-29
lines changed

7 files changed

+83
-29
lines changed

src/layout/gemm_layouts.cc

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -385,6 +385,7 @@ Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
385385
Var i = InputPlaceholder(0);
386386
Var j = InputPlaceholder(1);
387387
int vector_size = 128 / element_size;
388+
LOG(INFO) << "makeQuarterBankSwizzleLayout: " << stride << ", " << continuous << ", " << element_size;
388389
ICHECK(stride % 8 == 0) << "stride=" << stride;
389390
ICHECK(continuous % (vector_size * 2) == 0)
390391
<< "continuous=" << continuous << ", vector_size=" << vector_size;
@@ -740,6 +741,7 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
740741

741742
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
742743
int continuity, int element_size, bool k_inner) {
744+
LOG(INFO) << "makeGemmABLayoutHopper: " << mat_stride << ", " << mat_continuous << ", " << continuity << ", " << element_size << ", " << k_inner;
743745
if (element_size == 64) {
744746
if (!k_inner && continuity % 16 == 0) // float64 KxN
745747
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
@@ -749,6 +751,12 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
749751
element_size);
750752
}
751753
int vector_size = 128 / element_size;
754+
LOG(INFO) << "makeGemmABLayoutHopper: mat_continuous: " << mat_continuous << ", mat_stride: " << mat_stride << ", element_size: " << element_size;
755+
LOG(INFO) << "vector_size: " << vector_size;
756+
LOG(INFO) << "mat_continuous % (vector_size * 8): " << mat_continuous % (vector_size * 8);
757+
LOG(INFO) << "mat_continuous % (vector_size * 4): " << mat_continuous % (vector_size * 4);
758+
LOG(INFO) << "mat_continuous % (vector_size * 2): " << mat_continuous % (vector_size * 2);
759+
LOG(INFO) << "mat_continuous % vector_size: " << mat_continuous % vector_size;
752760
if (mat_continuous % (vector_size * 8) == 0)
753761
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
754762
else if (mat_continuous % (vector_size * 4) == 0)

src/layout/layout.cc

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ void LayoutNode::RegisterReflection() {
8282
}
8383

8484
// Binds every (variable, domain) pair reported by getVarMap() into `analyzer`.
//
// getVarMap() is deliberately called unqualified rather than as
// LayoutNode::getVarMap(): a qualified call always selects LayoutNode's own
// implementation, whereas the unqualified call lets a subclass override be
// picked up.
// NOTE(review): this assumes getVarMap() is declared virtual — confirm in the
// class declaration, which is not visible here.
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
  for (const auto &[var, dom] : getVarMap()) {
    analyzer->Bind(var, dom);
  }
}
@@ -547,6 +547,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
547547
[](int stride, int continuous, int element_size) {
548548
return makeQuarterBankSwizzleLayout(stride, continuous,
549549
element_size);
550+
})
551+
.def("tl.make_linear_layout",
552+
[](int stride, int continuous) {
553+
return makeGemmLayoutLinear(stride, continuous);
550554
});
551555
});
552556

src/op/gemm.cc

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -756,6 +756,8 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
756756
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
757757
const int64_t continuity =
758758
trans_B ? mat_continuous : mat_continuous / warp_n;
759+
760+
LOG(INFO) << "gemm_inst: " << (int)gemm_inst << ", trans_B: " << trans_B;
759761
auto ABLayout =
760762
gemm_inst == GemmInst::kWGMMA
761763
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,

tilelang/intrinsics/wgmma_macro_generator.py

Lines changed: 44 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
make_full_bank_swizzled_layout,
1111
make_half_bank_swizzled_layout,
1212
make_quarter_bank_swizzled_layout,
13+
make_linear_layout,
1314
)
1415
from tvm.runtime import convert
1516
from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a,
@@ -131,13 +132,20 @@ def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
131132
self.micro_size_k = k_dim
132133

133134
def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode:
    """Classify ``layout`` as one of the swizzle modes for ``buffer``.

    Candidate layouts are rebuilt from the buffer's last two extents and its
    element width in bits, then compared against ``layout`` with ``is_equal``.
    The candidate set is meant to follow the same selection as
    src/layout/gemm_layouts.cc::makeGemmABLayoutHopper.

    Args:
        buffer: buffer whose ``shape[-2]`` / ``shape[-1]`` are used as the
            stride / continuous extents and whose ``dtype`` gives the element
            width.
        layout: layout to classify, or ``None``.

    Returns:
        ``SwizzleMode.NONE`` when ``layout`` is ``None`` or equals the linear
        layout; otherwise ``SWIZZLE_32B`` / ``SWIZZLE_64B`` / ``SWIZZLE_128B``
        for the quarter / half / full bank swizzled layout respectively.

    Raises:
        ValueError: if ``layout`` matches none of the candidates.
    """
    # No layout attached: nothing to compare, and the buffer need not be read.
    if layout is None:
        return SwizzleMode.NONE

    mat_stride = int(buffer.shape[-2])
    mat_continuous = int(buffer.shape[-1])
    element_size = DataType(buffer.dtype).bits

    if layout.is_equal(make_linear_layout(mat_stride, mat_continuous)):
        return SwizzleMode.NONE

    # Candidates are built lazily, narrowest swizzle first; the first match wins
    # so wider candidates are never constructed unnecessarily.
    candidates = (
        (make_quarter_bank_swizzled_layout, SwizzleMode.SWIZZLE_32B),
        (make_half_bank_swizzled_layout, SwizzleMode.SWIZZLE_64B),
        (make_full_bank_swizzled_layout, SwizzleMode.SWIZZLE_128B),
    )
    for make_layout, mode in candidates:
        if layout.is_equal(make_layout(mat_stride, mat_continuous, element_size)):
            return mode

    raise ValueError(f"Unsupported swizzle mode: {layout}")
@@ -173,7 +181,11 @@ def wgmma(self,
173181
a_swizzle_mode = self._determinate_swizzle_mode(A_buf, self.a_shared_layout)
174182
b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout)
175183

176-
elems_in_bytes = DataType(self.a_dtype).bits // 8
184+
elems_in_bits = DataType(self.a_dtype).bits
185+
elems_in_bytes = elems_in_bits // 8
186+
187+
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
188+
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none() else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
177189

178190
# by default, we utilize non-swizzle layout offset
179191
a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim *
@@ -186,52 +198,59 @@ def wgmma(self,
186198
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
187199
if a_is_k_major:
188200
a_leading_byte_offset = 16
201+
a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size()
189202
else:
190203
# MN Major
191204
# LBO represents the distance between two atoms along the M dimension
192205
# SBO represents the distance between two atoms along the K dimension
193-
a_leading_byte_offset = a_swizzle_mode.swizzle_atom_size()
194-
a_stride_byte_offset = 8 * 64 * elems_in_bytes
206+
a_m_axis_atoms = m_dim // a_swizzle_atom_elems
207+
if a_m_axis_atoms <= 1:
208+
a_leading_byte_offset = 0
209+
else:
210+
a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
211+
212+
if a_m_axis_atoms <= 1:
213+
a_stride_byte_offset = 8 * elems_in_bytes * m_dim
214+
else:
215+
a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems
195216

196217
b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim *
197218
elems_in_bytes)
198-
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 *
199-
elems_in_bytes)
219+
b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (
220+
0 if n_dim == 8 else (8 * 8 * elems_in_bytes)
221+
)
200222
if not b_swizzle_mode.is_none():
201223
# swizzle mode doesn't require LBO/SBO to be 1
202224
# https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset
203225
if b_is_k_major:
204226
b_leading_byte_offset = 16
227+
b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size()
205228
else:
206229
# MN Major, K * N
207230
# LBO represents the distance between two atoms along the N dimension
208231
# SBO represents the distance between two atoms along the K dimension
209-
b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
232+
b_n_axis_atoms = n_dim // b_swizzle_atom_elems
210233
if b_n_axis_atoms <= 1:
211234
b_leading_byte_offset = 0
212235
else:
213236
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
214-
215237
if b_n_axis_atoms <= 1:
216238
b_stride_byte_offset = 8 * elems_in_bytes * n_dim
217239
else:
218-
b_stride_byte_offset = 8 * elems_in_bytes * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
219-
220-
240+
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
221241
print(f"a_leading_byte_offset: {a_leading_byte_offset >> 4}")
222242
print(f"a_stride_byte_offset: {a_stride_byte_offset >> 4}")
243+
print(f"b_leading_byte_offset: {b_leading_byte_offset >> 4}")
244+
print(f"b_stride_byte_offset: {b_stride_byte_offset >> 4}")
223245

224246
print(f"b_swizzle_atom_size: {b_swizzle_mode.swizzle_atom_size()}")
225247
print(f"b_swizzle_byte_size: {b_swizzle_mode.swizzle_byte_size()}")
226-
print(f"m_dim: {m_dim}")
227-
print(f"n_dim: {n_dim}")
228-
print(f"k_dim: {k_dim}")
229-
print(f"micro_size_k: {micro_size_k}")
230-
print(f"a_leading_byte_offset: {a_leading_byte_offset}")
231-
print(f"a_stride_byte_offset: {a_stride_byte_offset}")
232-
print(f"b_leading_byte_offset: {b_leading_byte_offset}")
233-
print(f"b_stride_byte_offset: {b_stride_byte_offset}")
234-
# exit()
248+
249+
# for example, if [n, k] where k is 128, we should split it into 2 atoms
250+
# where max specially handles the case when n_dim is 8.
251+
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
252+
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
253+
235254
@T.macro
236255
def _warp_mma(A_buf, B_buf, C_local_buf):
237256
desc_a = T.alloc_descriptor()
@@ -242,10 +261,8 @@ def _warp_mma(A_buf, B_buf, C_local_buf):
242261
int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4))
243262
for ki in T.serial(0, (k_dim // micro_size_k)):
244263
for i in T.serial(m_dim // 64):
245-
k_dim_offset = ki * micro_size_k
246-
A_offset = i * 64 * A_buf.shape[
247-
-1] + k_dim_offset if a_is_k_major else ki * micro_size_k * 64 + i * 64 * k_dim
248-
B_offset = k_dim_offset if b_is_k_major else k_dim_offset * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
264+
A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (ki // ak_atom_size) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
265+
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (ki % bk_atom_size) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
249266
C_offset = i * warp_cols * local_size_out # 4 warps as an unit
250267
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major,
251268
b_is_k_major, a_dtype_abbrv, b_dtype_abbrv,
@@ -300,7 +317,7 @@ def wgmma_rs(self,
300317
if b_n_axis_atoms <= 1:
301318
b_leading_byte_offset = 0
302319
else:
303-
b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim
320+
b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
304321

305322
if b_n_axis_atoms <= 1:
306323
b_stride_byte_offset = 8 * elems_in_bytes * n_dim

tilelang/layout/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,5 +9,6 @@
99
make_full_bank_swizzled_layout, # noqa: F401
1010
make_half_bank_swizzled_layout, # noqa: F401
1111
make_quarter_bank_swizzled_layout, # noqa: F401
12+
make_linear_layout, # noqa: F401
1213
)
1314
from .gemm_sp import make_metadata_layout # noqa: F401

tilelang/layout/swizzle.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer,
2525
assert len(buffer.shape) == 2
2626
if continuity is None:
2727
continuity = int(buffer.shape[1])
28+
print(f"make_wgmma_swizzled_layout: {buffer.shape[0]}, {buffer.shape[1]}, {continuity}, {tvm.DataType(buffer.dtype).bits}, {k_major}")
2829
return _ffi_api.make_wgmma_swizzled_layout(
2930
int(buffer.shape[0]),
3031
int(buffer.shape[1]),
@@ -107,3 +108,23 @@ def make_quarter_bank_swizzled_layout(*args):
107108
continuous,
108109
element_size,
109110
)
111+
112+
def make_linear_layout(*args):
    """Build a linear (non-swizzled) 2-D layout through the FFI.

    Args:
        args: either a single buffer, whose ``shape[0]`` and ``shape[1]`` are
            used as ``(stride, continuous)``, or the two extents passed
            explicitly as ``(stride, continuous)``.

    Returns:
        The result of ``_ffi_api.make_linear_layout(stride, continuous)``.

    Raises:
        ValueError: if the number of positional arguments is not 1 or 2.

    Examples:
        make_linear_layout(buffer)
        make_linear_layout(stride, continuous)
    """
    if len(args) == 1:
        buffer = args[0]
        stride, continuous = buffer.shape[0], buffer.shape[1]
    elif len(args) == 2:
        stride, continuous = args
    else:
        raise ValueError(f"Invalid arguments: {args}")
    # Coerce in one place so both call forms hand plain Python ints to the
    # FFI; previously only the buffer form was converted.
    return _ffi_api.make_linear_layout(int(stride), int(continuous))

tilelang/tileop/gemm/gemm_wgmma.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,8 @@ def infer_layout(self, target: Target, thread_nums: int):
3434

3535
if self.is_gemm_ss():
3636
a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp
37-
b_continuity = self.N if b_is_k_major else 4 * self.K // n_warp
37+
b_continuity = self.K if b_is_k_major else self.N // n_warp
38+
3839
return {
3940
# WGMMA does not support padding
4041
self.A:

0 commit comments

Comments
 (0)