Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ec26c23
[Feature] Introduce WGMMA support and enhance GEMM layout handling
LeiWang1999 Sep 14, 2025
2ff5cbf
[Refactor] Clean up code formatting and enhance layout function reada…
LeiWang1999 Sep 14, 2025
0166a90
[Feature] Add descriptor initialization and offset manipulation for W…
LeiWang1999 Sep 15, 2025
ce83ace
[Refactor] Improve code formatting and readability in various files
LeiWang1999 Sep 15, 2025
72e900d
[Update] Update subproject commit and refactor layout function call
LeiWang1999 Sep 15, 2025
22131e7
support more data types
LeiWang1999 Sep 16, 2025
6632a70
gemm_rs support
LeiWang1999 Sep 16, 2025
eac5433
lint fix
LeiWang1999 Sep 16, 2025
51fcf15
wgmma wrapper
LeiWang1999 Sep 17, 2025
ce9f545
Remove debug logging for wgmma assembly code and refactor swizzle byt…
LeiWang1999 Sep 18, 2025
70699a9
Merge branch 'main' of https://github.com/tile-ai/tilelang into v2_wg…
LeiWang1999 Oct 7, 2025
2dbaccc
Refactor GEMM layout functions to replace 'kfactor' with 'k_inner' fo…
LeiWang1999 Oct 7, 2025
d2db013
Comprehensively support WGMMA GEMM SS
LeiWang1999 Oct 8, 2025
ce9e2b6
remove debug print
LeiWang1999 Oct 8, 2025
fef8d2a
lint fix
LeiWang1999 Oct 8, 2025
bd9bd37
remove debug print
LeiWang1999 Oct 9, 2025
ff3e04d
reduce bwd test shape
LeiWang1999 Oct 9, 2025
c6ab014
lint fix
LeiWang1999 Oct 9, 2025
cc9e32f
clear cache for pytest
LeiWang1999 Oct 9, 2025
5244c19
lint fix
LeiWang1999 Oct 9, 2025
3b5c075
Update sparse MLA examples to support SKV adjustment and correctness …
LeiWang1999 Oct 7, 2025
caa2e51
test fix
LeiWang1999 Oct 9, 2025
3858d81
adjust test case
LeiWang1999 Oct 9, 2025
4cdd131
test fix
LeiWang1999 Oct 9, 2025
8783aad
skip some test currently
LeiWang1999 Oct 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
ICHECK(block_m % warp_m == 0);
// ICHECK(block_n == warp_n);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;

auto warp_layout = makeGemmFragment8x8()->Repeat({2, warp_n / 8}, false,
false); // 16 x N (1 warp)
auto block_layout = warp_layout->Repeat({block_m / warp_m, block_n / warp_n},
Expand Down Expand Up @@ -385,6 +385,7 @@ Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
Var i = InputPlaceholder(0);
Var j = InputPlaceholder(1);
int vector_size = 128 / element_size;
LOG(INFO) << "makeQuarterBankSwizzleLayout: " << stride << ", " << continuous << ", " << element_size;
ICHECK(stride % 8 == 0) << "stride=" << stride;
ICHECK(continuous % (vector_size * 2) == 0)
<< "continuous=" << continuous << ", vector_size=" << vector_size;
Expand Down Expand Up @@ -576,8 +577,8 @@ Layout MakeGemmVoltaBLayoutCongruous(int stride, int continuous) {
}

Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor) {
if (kfactor == 2)
bool k_inner) {
if (k_inner)
return MakeGemmVoltaABLayoutCrosswise(stride, continuous);
if (is_a && continuous % 64 == 0)
return MakeGemmVoltaALayoutCongruous(stride, continuous);
Expand Down Expand Up @@ -705,29 +706,29 @@ Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
* select specific swizzling strategies. It might be the same as mat_continuous
* or different based on tiling or hardware details.
* \param element_size The size of each element in the matrix, in bits (e.g., 8,
* 16, 32, 64). \param kfactor An integer factor that influences layout
* 16, 32, 64). \param k_inner Whether the K dimension is in the inner loop.
* This flag influences layout selection, particularly for fp64 and int8 types.
* It often relates to how the K dimension of the GEMM (M x K * K x N) is
* handled or tiled.
* - For fp64 (element_size == 64):
* - kfactor == 1 often implies K is in the "outer" loop (e.g.,
* KxN matrix).
* - kfactor == 2 often implies K is in the "inner" loop (e.g.,
* NxK matrix).
* - k_inner == false often implies K is in the "outer" loop
* (e.g., KxN matrix).
* - k_inner == true often implies K is in the "inner" loop
* (e.g., NxK matrix).
* - For int8 (element_size == 8):
* - kfactor == 1 uses a padded layout.
* - k_inner == false uses a padded layout.
* \return A Layout object representing the chosen memory layout.
*/
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor) {
int element_size, bool k_inner) {
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
if (!k_inner && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (!k_inner && element_size == 8) // int8 KxN
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
Expand All @@ -739,16 +740,23 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
}

Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor) {
int continuity, int element_size, bool k_inner) {
LOG(INFO) << "makeGemmABLayoutHopper: " << mat_stride << ", " << mat_continuous << ", " << continuity << ", " << element_size << ", " << k_inner;
if (element_size == 64) {
if (kfactor == 1 && continuity % 16 == 0) // float64 KxN
if (!k_inner && continuity % 16 == 0) // float64 KxN
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (kfactor == 2 && continuity % 16 == 0) // float64 NxK
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}
int vector_size = 128 / element_size;
LOG(INFO) << "makeGemmABLayoutHopper: mat_continuous: " << mat_continuous << ", mat_stride: " << mat_stride << ", element_size: " << element_size;
LOG(INFO) << "vector_size: " << vector_size;
LOG(INFO) << "mat_continuous % (vector_size * 8): " << mat_continuous % (vector_size * 8);
LOG(INFO) << "mat_continuous % (vector_size * 4): " << mat_continuous % (vector_size * 4);
LOG(INFO) << "mat_continuous % (vector_size * 2): " << mat_continuous % (vector_size * 2);
LOG(INFO) << "mat_continuous % vector_size: " << mat_continuous % vector_size;
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
Expand All @@ -761,11 +769,11 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
else
ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
<< ", element_size=" << element_size << ", k_inner=" << k_inner;
}

Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor) {
int element_size, bool k_inner) {
if (element_size == 64) {
ICHECK(0) << "float64 on sm100 is not supported now";
}
Expand All @@ -782,7 +790,7 @@ Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
else
ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
<< ", element_size=" << element_size << ", k_inner=" << k_inner;
__builtin_unreachable(); // to prevent compiler warning
}

Expand Down
43 changes: 41 additions & 2 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
[](Layout layout) { return layout->GetForwardIndex(); })
.def("tl.Layout_forward_vars",
[](Layout layout) { return layout->GetForwardVars(); })
.def("tl.Layout_is_equal",
[](Layout layout, Layout other) {
const LayoutNode *other_node = other.as<LayoutNode>();
return layout->IsEqual(other_node);
})
.def_packed("tl.Fragment",
[](PackedArgs args, Any *rv) {
*rv = Fragment(
Expand All @@ -492,6 +497,11 @@ TVM_FFI_STATIC_INIT_BLOCK({
/*forward_thread=*/args[2].cast<PrimExpr>(),
/*thread_replicate=*/args[3].cast<IterVar>());
})
.def("tl.Fragment_is_equal",
[](Fragment fragment, Fragment other) {
const FragmentNode *other_node = other.as<FragmentNode>();
return fragment->IsEqual(other_node);
})
.def("tl.Fragment_thread_size",
[](Fragment fragment) { return fragment->ThreadExtent(); })
.def("tl.Fragment_thread",
Expand All @@ -509,9 +519,38 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("tl.Fragment_condense_rep_var",
[](Fragment fragment) { return fragment->CondenseReplicateVar(); })
.def("tl.make_swizzled_layout",
[](int stride, int continuous, int element_size, bool k_inner,
bool allow_pad = true) {
if (allow_pad) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, k_inner);
} else {
return makeGemmABLayoutHopper(stride, continuous, continuous,
element_size, k_inner);
}
})
.def("tl.make_wgmma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeGemmABLayout(stride, continuous, continuous,
element_size, 0);
return makeFullBankSwizzleLayout(stride, continuous, element_size);
})
.def("tl.make_half_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeHalfBankSwizzleLayout(stride, continuous, element_size);
})
.def("tl.make_quarter_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeQuarterBankSwizzleLayout(stride, continuous,
element_size);
})
.def("tl.make_linear_layout",
[](int stride, int continuous) {
return makeGemmLayoutLinear(stride, continuous);
});
});

Expand Down
10 changes: 5 additions & 5 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
int element_size, bool k_inner = true);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
int continuity, int element_size, bool k_inner = true);
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
int element_size, bool k_inner = true);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor);
int kPack);

Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
Expand All @@ -181,7 +181,7 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor);
bool k_inner = true);

Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
int elementsize, int crosswise);
Expand Down
20 changes: 20 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Registers the `ptx_wgmma_ss` builtin op with 15 inputs and an opaque call
// effect kind.
// NOTE(review): the signature documented for ptx_wgmma_ss in builtin.h lists
// 16 parameters, which does not match set_num_inputs(15) — confirm which
// count is intended.
TIR_DEFINE_TL_BUILTIN(ptx_wgmma_ss)
.set_num_inputs(15)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Registers the `ptx_wgmma_rs` builtin op with 15 inputs and an opaque call
// effect kind.
// NOTE(review): the signature documented for ptx_wgmma_rs in builtin.h lists
// 16 parameters, which does not match set_num_inputs(15) — confirm which
// count is intended.
TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_num_inputs(15)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down Expand Up @@ -239,5 +249,15 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));

// Registers the `initialize_descriptor` builtin op with 5 inputs and an
// opaque call effect kind.
// NOTE(review): the meaning of the five inputs is not documented at this
// registration site or in the builtin.h declaration — consider listing them.
TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Registers the `increase_descriptor_offset` builtin op with 2 inputs and an
// opaque call effect kind.
// NOTE(review): presumably the two inputs are the descriptor and the offset
// to add — confirm against the lowering/codegen that consumes this op.
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
44 changes: 42 additions & 2 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,21 +216,43 @@ TVM_DLL const Op &mbarrier_wait_parity();
*/
TVM_DLL const Op &mbarrier_expect_tx();

/*!
 * \brief tvm intrinsic for ptx tensor core wgmma instructions (`ss` variant).
 *
 * Documented call signature:
 *
 * void ptx_wgmma_ss(StringImm accum_dtype, StringImm wgmma_prefix, bool
 * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
 * b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
 * A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
 * scale_out, bool scale_in_a, bool scale_in_b);
 *
 * NOTE(review): the list above names 16 parameters, while the op is
 * registered with set_num_inputs(15) in builtin.cc — confirm which is
 * correct and update the other.
 */
TVM_DLL const Op &ptx_wgmma_ss();

/*!
 * \brief tvm intrinsic for ptx tensor core wgmma instructions (`rs` variant).
 *
 * Documented call signature:
 *
 * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
 * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
 * b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
 * A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
 * scale_out, bool scale_in_a, bool scale_in_b);
 *
 * NOTE(review): the list above names 16 parameters, while the op is
 * registered with set_num_inputs(15) in builtin.cc — confirm which is
 * correct and update the other.
 * NOTE(review): this parameter list is identical to the one documented for
 * ptx_wgmma_ss (including `A_descriptor`); confirm that it reflects how the
 * A operand is actually passed for the rs variant.
 */
TVM_DLL const Op &ptx_wgmma_rs();

/*!
* \brief tvm intrinsics for initializing tensor memory
*
* ptx_init_tensor_memory(tmem_buffer, num_cols)
*
*/
const Op &ptx_init_tensor_memory();
TVM_DLL const Op &ptx_init_tensor_memory();

/*!
* \brief tvm intrinsics for deallocating tensor memory
*
* tmem_deallocate(tmem_buffer)
*
*/
const Op &ptx_deallocate_tensor_memory();
TVM_DLL const Op &ptx_deallocate_tensor_memory();

/*!
* \brief tvm intrinsics for ldmatrix
Expand Down Expand Up @@ -398,6 +420,24 @@ TVM_DLL const Op &tl_gemm_sp();
*/
TVM_DLL const Op &tl_shuffle_elect();

/*!
 * \brief tilelang intrinsic for initializing a descriptor buffer for
 * wgmma/utcmma.
 *
 * This op is used to represent a descriptor initialization operation in
 * tilelang. It is registered in builtin.cc with 5 inputs and an opaque call
 * effect kind.
 *
 * NOTE(review): the five arguments are not described here — consider
 * documenting their order and meaning.
 */
TVM_DLL const Op &initialize_descriptor();

/*!
 * \brief tilelang intrinsic for setting the start address of a descriptor
 * buffer for wgmma/utcmma.
 *
 * This op is used to represent a descriptor start address setting operation in
 * tilelang. It is registered in builtin.cc with 2 inputs and an opaque call
 * effect kind.
 *
 * NOTE(review): the op name says "increase offset" while this text says
 * "setting the start address" — presumably the op advances the descriptor's
 * start address by an offset; confirm and align the wording with the name.
 */
TVM_DLL const Op &increase_descriptor_offset();

} // namespace tl
} // namespace tvm

Expand Down
20 changes: 11 additions & 9 deletions src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
* @param vmap Mapping from access pointer vars to Buffer objects used to
* resolve the Buffer corresponding to each pointer argument.
*
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
* @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Gemm Constructor Documentation Mismatch

The Gemm constructor's documentation for kPack is incomplete and inconsistent with its validation logic. The comment states kPack must be 1, but the underlying validation permits 1 or 2.

Fix in Cursor Fix in Web

* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
Expand Down Expand Up @@ -670,7 +670,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
int dim_A = A->shape.size();
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]),
true, trans_A ? 1 : 2));
true, !trans_A));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
Expand All @@ -683,7 +683,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
int dim_B = B->shape.size();
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
false, trans_B));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
TargetIsSM120(T.target) ||
(TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
Expand All @@ -700,7 +700,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2));
A->dtype.bits(), !trans_A));
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
Expand All @@ -714,7 +714,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1));
B->dtype.bits(), trans_B));
} else if (B.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
Expand All @@ -741,9 +741,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2)
A->dtype.bits(), !trans_A)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2);
A->dtype.bits(), !trans_A);
results.Set(A, ABLayout);
} else {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
Expand All @@ -756,12 +756,14 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;

LOG(INFO) << "gemm_inst: " << (int)gemm_inst << ", trans_B: " << trans_B;
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1)
B->dtype.bits(), trans_B)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1);
B->dtype.bits(), trans_B);
results.Set(B, ABLayout);
} else {
auto fragment =
Expand Down
Loading
Loading