-
Notifications
You must be signed in to change notification settings - Fork 332
[TileOp] Implement WGMMA for T.gemm_v2 #813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
ec26c23
2ff5cbf
0166a90
ce83ace
72e900d
22131e7
6632a70
eac5433
51fcf15
ce9f545
70699a9
2dbaccc
d2db013
ce9e2b6
fef8d2a
bd9bd37
ff3e04d
c6ab014
cc9e32f
5244c19
3b5c075
caa2e51
3858d81
4cdd131
8783aad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,7 +42,7 @@ using namespace tir; | |
| * @param vmap Mapping from access pointer vars to Buffer objects used to | ||
| * resolve the Buffer corresponding to each pointer argument. | ||
| * | ||
| * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor | ||
| * @note If `kPack` is provided it must be 1; otherwise the constructor | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| * fails with an ICHECK (runtime assertion). No other validation is | ||
| * performed here. | ||
| */ | ||
|
|
@@ -478,7 +478,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| int dim_A = A->shape.size(); | ||
| results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]), | ||
| *as_const_int(A->shape[dim_A - 1]), | ||
| true, trans_A ? 1 : 2)); | ||
| true, !trans_A)); | ||
| } else if (A.scope() == "local.fragment") { | ||
| ICHECK(trans_A == false); | ||
| auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n); | ||
|
|
@@ -491,7 +491,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| int dim_B = B->shape.size(); | ||
| results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]), | ||
| *as_const_int(B->shape[dim_B - 1]), | ||
| false, trans_B ? 2 : 1)); | ||
| false, trans_B)); | ||
| } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || | ||
| TargetIsSM120(T.target)) { | ||
| auto fragment = | ||
|
|
@@ -504,7 +504,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); | ||
| results.Set(A, | ||
| makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| A->dtype.bits(), trans_A ? 1 : 2)); | ||
| A->dtype.bits(), !trans_A)); | ||
| } else if (A.scope() == "local.fragment") { | ||
| auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, | ||
| A->dtype.bits(), trans_A); | ||
|
|
@@ -518,7 +518,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); | ||
| results.Set(B, | ||
| makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| B->dtype.bits(), trans_B ? 2 : 1)); | ||
| B->dtype.bits(), trans_B)); | ||
| } else if (B.scope() == "local.fragment") { | ||
| auto fragment = | ||
| makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); | ||
|
|
@@ -542,9 +542,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| auto ABLayout = | ||
| gemm_inst == GemmInst::kWGMMA | ||
| ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, | ||
| A->dtype.bits(), trans_A ? 1 : 2) | ||
| A->dtype.bits(), !trans_A) | ||
| : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| A->dtype.bits(), trans_A ? 1 : 2); | ||
| A->dtype.bits(), !trans_A); | ||
| results.Set(A, ABLayout); | ||
| } else { | ||
| auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, | ||
|
|
@@ -560,9 +560,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| auto ABLayout = | ||
| gemm_inst == GemmInst::kWGMMA | ||
| ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, | ||
| B->dtype.bits(), trans_B ? 2 : 1) | ||
| B->dtype.bits(), trans_B) | ||
| : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| B->dtype.bits(), trans_B ? 2 : 1); | ||
| B->dtype.bits(), trans_B); | ||
| results.Set(B, ABLayout); | ||
| } else { | ||
| auto fragment = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,8 @@ GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size, | |
| return GemmInst::kMMA; | ||
| } else { | ||
| ICHECK(0) << "Unsupported target for gemm: " << target->str(); | ||
| return GemmInst::kMMA; // This line will never be reached due to ICHECK, but | ||
| // satisfies compiler | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -225,8 +227,9 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |
| M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); | ||
|
|
||
| if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { | ||
| auto prim_func = Downcast<PrimFunc>( | ||
| (*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds, T.thread_var)); | ||
| auto prim_func = | ||
| Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target, | ||
| T.thread_bounds, T.thread_var)); | ||
| ICHECK(prim_func->attrs.defined()); | ||
|
Comment on lines
+231
to
233
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
🧩 Analysis chain
FFI lower signature now includes layout_map; ensure all providers are updated. The C++ call passes T.layout_map. Verify that the Python FFI target “tl.gemm_py.lower” accepts (gemm_py, layout_map, target, thread_bounds, thread_var) everywhere to avoid runtime errors.
🏁 Script executed:
#!/bin/bash
# Confirm the new signature at Python side
rg -nP 'def\s+lower\((self|gemm_py)[^)]*layout_map' -C2
# Find the FFI registration or dispatcher for "tl.gemm_py.lower"
rg -n 'gemm_py\.lower' -S
# Check base class signatures to avoid LSP mismatches
rg -nP 'class\s+GemmBase\b.*?\n' -n tilelang/tileop/gemm -C2
rg -nP 'def\s+lower\(' tilelang/tileop/gemm -S
Length of output: 2364
Add layout_map to GemmBase.lower. Change tilelang/tileop/gemm/gemm_base.py:18 from |
||
| auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol"); | ||
| ICHECK(global_symbol.defined()); | ||
|
|
@@ -249,6 +252,8 @@ Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { | |
| /*name_hint=*/global_symbol.value(), prim_func->body)); | ||
| } else { | ||
| LOG(FATAL) << "No lower function found for gemm_py"; | ||
| return Stmt(); // This line will never be reached due to LOG(FATAL), but | ||
| // satisfies compiler | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -275,5 +280,14 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py) | |
| Integer(CallEffectKind::kOpaque)); | ||
|
|
||
| TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); }); | ||
|
|
||
| TVM_FFI_STATIC_INIT_BLOCK({ | ||
| namespace refl = tvm::ffi::reflection; | ||
| refl::GlobalDef().def("tl.GemmPyGemmInst", | ||
| [](GemmPy gemm_py, int block_size, Target target) { | ||
| return gemm_py->GetGemmInst(block_size, target); | ||
| }); | ||
| }); | ||
|
|
||
| } // namespace tl | ||
| } // namespace tvm | ||
Uh oh!
There was an error while loading. Please reload this page.