-
Notifications
You must be signed in to change notification settings - Fork 332
[TileOp] Implement WGMMA for T.gemm_v2 #813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
ec26c23
2ff5cbf
0166a90
ce83ace
72e900d
22131e7
6632a70
eac5433
51fcf15
ce9f545
70699a9
2dbaccc
d2db013
ce9e2b6
fef8d2a
bd9bd37
ff3e04d
c6ab014
cc9e32f
5244c19
3b5c075
caa2e51
3858d81
4cdd131
8783aad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -166,13 +166,14 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, | |
| Layout makeGemmLayoutLinear(int stride, int continuous); | ||
| Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size); | ||
| Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, | ||
| int element_size, int kfactor); | ||
| int element_size, bool k_inner = true); | ||
| Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, | ||
| int continuity, int element_size, int kfactor); | ||
| int continuity, int element_size, | ||
| bool k_inner = true); | ||
| Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, | ||
| int element_size, int kfactor); | ||
| int element_size, bool k_inner = true); | ||
| Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, | ||
| int kfactor); | ||
| int kPack); | ||
|
Comment on lines
+169
to
+176
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Audit callers after k_inner bool API change
APIs now take `bool k_inner` (previously `int kfactor`). Also applies to: 171-173, 174-176, 185-186
🏁 Script executed:
#!/bin/bash
# Find suspicious non-bool last-arg usages of Gemm AB layout constructors.
rg -nPU -C2 '\bmakeGemmABLayout(?:Hopper|Sm100)?\s*\('
# Heuristic: show calls where 5th arg looks like a ternary yielding 1/2
rg -nPU '\bmakeGemmABLayout(?:Hopper|Sm100)?\s*\((?:[^,]*,){4}\s*[^),]*\?\s*[12]\s*:\s*[12]\s*\)' -n -C2
Length of output: 8230
Fix non-boolean arguments to makeGemmABLayout* calls
|
||
|
|
||
| Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, | ||
| const int warp_m, const int warp_n, | ||
|
|
@@ -181,7 +182,7 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, | |
| const int block_k, const int warp_m, | ||
| const int warp_n); | ||
| Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, | ||
| int kfactor); | ||
| bool k_inner = true); | ||
|
|
||
| Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous, | ||
| int elementsize, int crosswise); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,7 +109,7 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { | |
| * @param vmap Mapping from access pointer vars to Buffer objects used to | ||
| * resolve the Buffer corresponding to each pointer argument. | ||
| * | ||
| * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor | ||
| * @note If `kPack` is provided it must be 1; otherwise the constructor | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| * fails with an ICHECK (runtime assertion). No other validation is | ||
| * performed here. | ||
| */ | ||
|
|
@@ -670,7 +670,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| int dim_A = A->shape.size(); | ||
| results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]), | ||
| *as_const_int(A->shape[dim_A - 1]), | ||
| true, trans_A ? 1 : 2)); | ||
| true, !trans_A)); | ||
| } else if (A.scope() == "local.fragment") { | ||
| ICHECK(trans_A == false); | ||
| auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n); | ||
|
|
@@ -683,7 +683,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| int dim_B = B->shape.size(); | ||
| results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]), | ||
| *as_const_int(B->shape[dim_B - 1]), | ||
| false, trans_B ? 2 : 1)); | ||
| false, trans_B)); | ||
| } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || | ||
| TargetIsSM120(T.target) || | ||
| (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) { | ||
|
|
@@ -700,7 +700,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); | ||
| results.Set(A, | ||
| makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| A->dtype.bits(), trans_A ? 1 : 2)); | ||
| A->dtype.bits(), !trans_A)); | ||
| } else if (A.scope() == "local.fragment") { | ||
| auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, | ||
| A->dtype.bits(), trans_A); | ||
|
|
@@ -714,7 +714,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); | ||
| results.Set(B, | ||
| makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| B->dtype.bits(), trans_B ? 2 : 1)); | ||
| B->dtype.bits(), trans_B)); | ||
| } else if (B.scope() == "local.fragment") { | ||
| auto fragment = | ||
| makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); | ||
|
|
@@ -741,9 +741,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| auto ABLayout = | ||
| gemm_inst == GemmInst::kWGMMA | ||
| ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, | ||
| A->dtype.bits(), trans_A ? 1 : 2) | ||
| A->dtype.bits(), !trans_A) | ||
| : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| A->dtype.bits(), trans_A ? 1 : 2); | ||
| A->dtype.bits(), !trans_A); | ||
| results.Set(A, ABLayout); | ||
| } else { | ||
| auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, | ||
|
|
@@ -756,12 +756,13 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, | |
| const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); | ||
| const int64_t continuity = | ||
| trans_B ? mat_continuous : mat_continuous / warp_n; | ||
|
|
||
| auto ABLayout = | ||
| gemm_inst == GemmInst::kWGMMA | ||
| ? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, | ||
| B->dtype.bits(), trans_B ? 2 : 1) | ||
| B->dtype.bits(), trans_B) | ||
| : makeGemmABLayout(mat_stride, mat_continuous, mat_continuous, | ||
| B->dtype.bits(), trans_B ? 2 : 1); | ||
| B->dtype.bits(), trans_B); | ||
| results.Set(B, ABLayout); | ||
| } else { | ||
| auto fragment = | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing import for `torch`.
The code references `torch.bfloat16` on line 37, but the `torch` module is not imported in this file. This will cause a `NameError` at runtime when the test is executed.
Apply this diff to add the missing import:
 # ruff: noqa
+import torch
 import tilelang.testing
📝 Committable suggestion
🤖 Prompt for AI Agents