[Enhancement][AMD] Add preshuffle fp8 gemm example on amd. #1605
LeiWang1999 merged 3 commits into tile-ai:main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Adds a new example module that implements an autotuned TileLang FP8 GEMM using AMD MFMA with optional B preshuffling, plus utilities to build, validate, benchmark, and test the kernel against a PyTorch reference.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (4)
78-80: Remove unused variable `local_size_c`. The variable is assigned but never used in the kernel implementation.
Proposed fix
```diff
 local_size_a = mfma_emitter.local_size_a
 local_size_b = mfma_emitter.local_size_b
-local_size_c = mfma_emitter.local_size_out
```
92-102: `B_shared_shape` is computed but never used. Unlike `A_shared`, there's no `B_shared` buffer allocation - B is loaded directly from global memory via `ldmatrix_b` with `pid_m`/`pid_n`. If this is intentional for the preshuffle design, consider removing the dead code. If shared memory for B is planned, this appears incomplete.
Proposed fix (remove unused computation)
```diff
-    if b_preshuffle:
-        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
-                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
-                                                      pack_size_k, micro_size_y)
-        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
-                          pack_size_k) if b_transposed else (block_K // pack_size_k,
-                                                             block_N // micro_size_y, pack_size_k,
-                                                             micro_size_y)
-    else:
-        B_shape = (N, K) if b_transposed else (K, N)
-        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
+    if b_preshuffle:
+        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
+                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
+                                                      pack_size_k, micro_size_y)
+    else:
+        B_shape = (N, K) if b_transposed else (K, N)
```
210-226: Simplify reference computation with conditional transpose. The four branches can be consolidated. Also, the same `getattr(torch, out_dtype)` issue applies here, where `out_dtype` is `T.float32`.
Proposed simplification
```diff
-    if a_transposed and b_transposed:
-        # Get Reference Result
-        ref_c = torch.matmul(A.T.half(),
-                             B.T.half()).to(getattr(torch, out_dtype))
-    elif a_transposed and not b_transposed:
-        # Get Reference Result
-        ref_c = torch.matmul(A.T.half(),
-                             B.half()).to(getattr(torch, out_dtype))
-    elif not a_transposed and b_transposed:
-        # Get Reference Result
-        ref_c = torch.matmul(A.half(),
-                             B.T.half()).to(getattr(torch, out_dtype))
-    else:
-        # Get Reference Result
-        ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype))
+    A_ref = A.T.half() if a_transposed else A.half()
+    B_ref = B.T.half() if b_transposed else B.half()
+    ref_c = torch.matmul(A_ref, B_ref).to(torch.float32)  # out_dtype is T.float32
```
229-238: Test coverage is limited but acceptable for an example. The test only exercises `b_transposed=True, b_preshuffle=True`. Consider adding at least one case with `b_preshuffle=False` to validate the non-preshuffle path, especially since `B_shared_shape` is computed for it.
Suggested additional test case
```diff
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(
         512, 512, 512, k_pack=1, b_transposed=True, b_preshuffle=True)
     assert_tl_matmul_correctness(
         512, 512, 512, k_pack=2, b_transposed=True, b_preshuffle=True)
+    # Test non-preshuffle path
+    assert_tl_matmul_correctness(
+        512, 512, 512, k_pack=1, b_transposed=True, b_preshuffle=False)
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (6)
tilelang/layout/swizzle.py (1)
  make_swizzled_layout (62-71)
tilelang/intrinsics/mfma_macro_generator.py (2)
  MatrixCorePreshuffleIntrinEmitter (670-866)
  mfma (358-395)
tilelang/transform/simplify.py (1)
  simplify_prim_func (53-58)
tilelang/testing/__init__.py (1)
  set_random_seed (32-37)
tilelang/autotuner/tuner.py (1)
  autotune (691-786)
tilelang/language/v2/dtypes.py (2)
  float8_e4m3fnuz (348-348)
  float16 (299-299)
🪛 Ruff (0.14.10)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
80-80: Local variable local_size_c is assigned to but never used
Remove assignment to unused variable local_size_c
(F841)
102-102: Local variable B_shared_shape is assigned to but never used
Remove assignment to unused variable B_shared_shape
(F841)
🔇 Additional comments (4)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (4)
1-12: LGTM! Imports are appropriate and the random seed initialization ensures reproducibility for the example.
14-29: LGTM! Configuration generator is clean and produces a reasonable set of 54 tuning configurations (see the config-generation sketch after this list).
104-153: LGTM! The kernel implementation correctly implements tiled GEMM with pipelining, using shared memory for A and direct global memory access for preshuffled B. The `mfma_emitter` abstraction handles the MFMA intrinsics cleanly.
156-172: LGTM! The weight shuffling logic correctly reshapes and permutes the tensor for the preshuffle layout, with proper divisibility checks.
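For readers new to the autotuning setup, here is a minimal, self-contained sketch of the config-generation pattern the comment above refers to; the parameter value lists below are placeholders, not the example's exact choices.

```python
# Sketch of the get_configs() pattern: enumerate a cartesian product of tile
# sizes and stage counts as dicts for the autotuner. Values are placeholders.
import itertools

def get_configs():
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[64, 128, 256],
        block_K=[64, 128],
        num_stages=[0, 1, 2],
    )
    return [
        dict(zip(iter_params.keys(), values))
        for values in itertools.product(*iter_params.values())
    ]

print(len(get_configs()))  # 3 * 3 * 2 * 3 = 54 candidate configurations
```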
```python
A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
```
🧩 Analysis chain
🏁 Script executed:
cat -n examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py | head -250
Repository: tile-ai/tilelang
Length of output: 9312
🏁 Script executed:
rg "float8_e4m3fnuz" --type py -B 5 -A 5 | head -100
Repository: tile-ai/tilelang
Length of output: 8697
🏁 Script executed:
fd "language.py" --type f | head -5
Repository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
fd "language" --type f | grep -E "^tilelang.*\.py$" | head -10
Repository: tile-ai/tilelang
Length of output: 86
🏁 Script executed:
rg "class.*float8_e4m3fnuz|float8_e4m3fnuz\s*=" --type py -B 2 -A 2 | head -50
Repository: tile-ai/tilelang
Length of output: 2305
🏁 Script executed:
cat tilelang/language/v2/dtypes.py | head -100
Repository: tile-ai/tilelang
Length of output: 2447
🏁 Script executed:
cat tilelang/language/v2/dtypes.py | sed -n '100,250p'
Repository: tile-ai/tilelang
Length of output: 5383
🏁 Script executed:
```shell
python3 << 'EOF'
import sys
sys.path.insert(0, '/tmp/tilelang')
# Check if TVM DataType has __str__ method and what it returns
from tvm import DataType
test_dtype = DataType("float8_e4m3fnuz")
print(f"Type of test_dtype: {type(test_dtype)}")
print(f"String representation: {str(test_dtype)}")
print(f"Has as_torch: {hasattr(test_dtype, 'as_torch')}")
# Try getattr with the dtype object
try:
    result = getattr(__import__('torch'), test_dtype)
    print(f"getattr(torch, dtype_object) succeeded: {result}")
except TypeError as e:
    print(f"getattr(torch, dtype_object) failed with TypeError: {e}")
except AttributeError as e:
    print(f"getattr(torch, dtype_object) failed with AttributeError: {e}")
EOF
```
Repository: tile-ai/tilelang
Length of output: 146
🏁 Script executed:
rg "\.as_torch\(\)" --type py -B 2 -A 2 | head -50
Repository: tile-ai/tilelang
Length of output: 3171
🏁 Script executed:
sed -n '175,233p' examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
Repository: tile-ai/tilelang
Length of output: 2024
🏁 Script executed:
rg "def __str__|def __repr__" tilelang/language/v2/dtypes.py -B 2 -A 5
Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
rg "getattr\(torch," examples/ --type py -B 2 -A 2
Repository: tile-ai/tilelang
Length of output: 28083
🏁 Script executed:
cat examples/gemm_fp8/example_tilelang_gemm_fp8.py | sed -n '160,230p'
Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
wc -l examples/gemm_fp8/example_tilelang_gemm_fp8.py
Repository: tile-ai/tilelang
Length of output: 109
🏁 Script executed:
cat examples/gemm_fp8/example_tilelang_gemm_fp8.py
Repository: tile-ai/tilelang
Length of output: 2406
🏁 Script executed:
sed -n '176,232p' examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py | cat -n
Repository: tile-ai/tilelang
Length of output: 2421
Use .as_torch() to convert tilelang dtypes to torch dtypes.
Lines 199-200 pass in_dtype (a tilelang dtype object) to getattr(torch, in_dtype), which expects a string attribute name. This will fail at runtime. The correct approach is to use the .as_torch() method:
Diff
```diff
- A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
- B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
+ A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
+ B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
```
Also apply the same fix on lines 221, 224, 227, 230 where `out_dtype` is used with `getattr(torch, out_dtype)`:
```diff
- ref_c = torch.matmul(...).to(getattr(torch, out_dtype))
+ ref_c = torch.matmul(...).to(out_dtype.as_torch())
```
🤖 Prompt for AI Agents
In @examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py around lines
193-194, The tensor dtype conversion is wrong: replace calls using
getattr(torch, in_dtype) when constructing tensors A and B with the tilelang
dtype's as_torch() (i.e., use in_dtype.as_torch()) and similarly replace any
getattr(torch, out_dtype) occurrences (used later when creating other tensors)
with out_dtype.as_torch(); locate uses of the variables in_dtype and out_dtype
around the tensor constructions (e.g., where A and B are created and the later
tensor creations referenced in the review) and call .as_torch() to obtain a
valid torch.dtype for torch.rand()/tensor constructors.
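For reference, a minimal standalone sketch of the failure mode and the suggested fix. It assumes TileLang dtype objects such as `T.float8_e4m3fnuz` expose `.as_torch()` (as verified in the analysis chain above) and that the local PyTorch build supports `float8_e4m3fnuz`; it is an illustration, not part of the example.

```python
# Illustrative sketch only: TileLang dtype objects are not strings, so
# getattr(torch, dtype) raises TypeError; .as_torch() is the assumed mapping.
import torch
import tilelang.language as T

in_dtype = T.float8_e4m3fnuz  # a TileLang dtype object, not a str

try:
    getattr(torch, in_dtype)  # what the example currently does
except TypeError as exc:
    print(f"getattr(torch, in_dtype) fails: {exc}")

torch_dtype = in_dtype.as_torch()  # assumed conversion helper
A = (torch.rand(4, 4, dtype=torch.float16) / 10).to(torch_dtype)
print(A.dtype)  # expected: torch.float8_e4m3fnuz
```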
Actionable comments posted: 0
♻️ Duplicate comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (2)
199-200: Use `.as_torch()` to convert tilelang dtypes to torch dtypes. `getattr(torch, in_dtype)` will fail because `in_dtype` is a tilelang dtype object, not a string. Use the `.as_torch()` method instead.
Suggested fix
```diff
- A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
- B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
+ A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
+ B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
```
219-231: Same dtype conversion issue in reference computation. All `getattr(torch, out_dtype)` calls need to use `out_dtype.as_torch()`.
Suggested fix
```diff
 if a_transposed and b_transposed:
     # Get Reference Result
-    ref_c = torch.matmul(A.T.half(), B.T.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.T.half(), B.T.half()).to(out_dtype.as_torch())
 elif a_transposed and not b_transposed:
     # Get Reference Result
-    ref_c = torch.matmul(A.T.half(), B.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.T.half(), B.half()).to(out_dtype.as_torch())
 elif not a_transposed and b_transposed:
     # Get Reference Result
-    ref_c = torch.matmul(A.half(), B.T.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.half(), B.T.half()).to(out_dtype.as_torch())
 else:
     # Get Reference Result
-    ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.half(), B.half()).to(out_dtype.as_torch())
```
🧹 Nitpick comments (3)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (3)
80-82: Remove unused variable `local_size_c`. The variable is assigned but never referenced. This was flagged by static analysis (Ruff F841).
Suggested fix
```diff
 local_size_a = mfma_emitter.local_size_a
 local_size_b = mfma_emitter.local_size_b
-local_size_c = mfma_emitter.local_size_out
```
100-107: Remove unused `B_shared_shape` computation. This variable is computed but never used. The kernel loads B directly from global memory via `ldmatrix_b` rather than through shared memory. Flagged by static analysis (Ruff F841).
Suggested fix
```diff
 if b_preshuffle:
     B_shape = (
         (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
         if b_transposed
         else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
     )
-    B_shared_shape = (
-        (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
-        if b_transposed
-        else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
-    )
 else:
     B_shape = (N, K) if b_transposed else (K, N)
-    B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
```
235-237: Consider expanding test coverage for transpose combinations. Currently only `a_transposed=False, b_transposed=True` is tested. The kernel supports all four transpose combinations, but only one is validated. Consider adding tests for other configurations if they are intended to be supported (a parametrized sketch follows below).
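A possible shape for such coverage, sketched under the assumption that `assert_tl_matmul_correctness` accepts `a_transposed`/`b_transposed` keyword arguments (as shown elsewhere in this review) and that the example module is importable from the test's path:

```python
# Hedged sketch: sweep the four transpose combinations with pytest.
# The import path is an assumption; in practice the test could live in the
# example file itself and call assert_tl_matmul_correctness directly.
import pytest
from example_tilelang_gemm_amd_fp8_preshuffle import assert_tl_matmul_correctness

@pytest.mark.parametrize("a_transposed", [False, True])
@pytest.mark.parametrize("b_transposed", [False, True])
def test_tl_matmul_transpose_combinations(a_transposed, b_transposed):
    # 512x512x512 with k_pack=2 mirrors the existing test case.
    assert_tl_matmul_correctness(
        512, 512, 512, k_pack=2,
        a_transposed=a_transposed, b_transposed=b_transposed)
```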
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (3)
tilelang/layout/swizzle.py (1)
  make_swizzled_layout (62-71)
tilelang/intrinsics/mfma_macro_generator.py (2)
  MatrixCorePreshuffleIntrinEmitter (670-866)
  mfma (358-395)
tilelang/testing/__init__.py (1)
  set_random_seed (32-37)
🪛 Ruff (0.14.10)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
82-82: Local variable local_size_c is assigned to but never used
Remove assignment to unused variable local_size_c
(F841)
107-107: Local variable B_shared_shape is assigned to but never used
Remove assignment to unused variable B_shared_shape
(F841)
🔇 Additional comments (4)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (4)
1-11: LGTM! Imports are appropriate and the random seed initialization ensures reproducible autotuning as expected for example/test code.
14-31: LGTM! The config generation is clean and produces a reasonable search space for autotuning.
109-154: LGTM! The kernel body correctly implements the preshuffled MFMA GEMM pattern: A goes through shared memory with swizzled layout, B loads directly from global memory via the preshuffle path, and the pipeline structure is appropriate.
157-173: LGTM! The `shuffle_weight` function correctly implements the preshuffle transformation required for MFMA access patterns. The reshape-permute-contiguous pattern is standard for this kind of layout transformation.
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (1)
188-189: Use `.as_torch()` to convert TileLang dtypes to torch dtypes. Lines 188-189 pass `in_dtype` (a TileLang dtype object) to `getattr(torch, in_dtype)`, which expects a string attribute name. This will fail at runtime. The same issue occurs on lines 203, 206, 209, and 212 with `out_dtype`.
🔎 Proposed fix
```diff
- A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
- B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
+ A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
+ B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
```
Also fix lines 203, 206, 209, 212:
```diff
 if a_transposed and b_transposed:
     # Get Reference Result
-    ref_c = torch.matmul(A.T.half(), B.T.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.T.half(), B.T.half()).to(out_dtype.as_torch())
 elif a_transposed and not b_transposed:
     # Get Reference Result
-    ref_c = torch.matmul(A.T.half(), B.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.T.half(), B.half()).to(out_dtype.as_torch())
 elif not a_transposed and b_transposed:
     # Get Reference Result
-    ref_c = torch.matmul(A.half(), B.T.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.half(), B.T.half()).to(out_dtype.as_torch())
 else:
     # Get Reference Result
-    ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype))
+    ref_c = torch.matmul(A.half(), B.half()).to(out_dtype.as_torch())
```
🧹 Nitpick comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (2)
54-54: Consider making `b_preshuffle` a function parameter. Currently `b_preshuffle` is hardcoded to `True`. Making it a parameter would improve flexibility for benchmarking and comparing preshuffled vs. non-preshuffled performance.
🔎 Suggested refactor
```diff
 @tilelang.jit(out_idx=[-1])
 def tl_matmul(
     M,
     N,
     K,
     block_M,
     block_N,
     block_K,
     num_stages,
     k_pack=2,
     num_threads=256,
     in_dtype=T.float8_e4m3fnuz,
     out_dtype=T.float32,
     accum_dtype=T.float32,
     a_transposed=False,
     b_transposed=True,
+    b_preshuffle=True,
 ):
-    b_preshuffle = True
     warp_size = 64
```
217-222: Consider adding tests for different configurations. The current test only exercises one configuration (512×512×512 with k_pack=2). Adding tests for different matrix sizes, k_pack values, and transpose combinations would improve coverage.
For example:
```python
def test_assert_tl_matmul():
    # Test basic case
    assert_tl_matmul_correctness(512, 512, 512, k_pack=2)
    # Test different k_pack
    assert_tl_matmul_correctness(256, 256, 256, k_pack=1)
    # Test transpose combinations
    assert_tl_matmul_correctness(512, 512, 512, k_pack=2, a_transposed=True, b_transposed=False)
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (1)
tilelang/intrinsics/mfma_macro_generator.py (3)
  MatrixCorePreshuffleIntrinEmitter (670-866)
  make_mfma_store_layout (573-640)
  mfma (358-395)
🔇 Additional comments (7)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (7)
1-11: LGTM! Imports are appropriate and the random seed initialization ensures reproducible test results.
14-31: LGTM! Configuration generation logic is correct and provides a reasonable search space for autotuning.
34-53: LGTM! Function signature and decorators are appropriate for an autotuned TileLang kernel.
90-97: LGTM! Shape calculations correctly account for transposition and preshuffle layout requirements.
99-144: LGTM! Kernel implementation follows the expected pattern for tiled GEMM with preshuffle:
- Proper memory hierarchy (shared → local → fragment)
- Correct pipelining with configurable stages
- Appropriate layout annotations for MFMA
147-163: LGTM! The shuffle logic correctly transforms the weight tensor into the preshuffle layout expected by the kernel; a standalone sketch of the pattern follows below.
201-212: Transpose handling logic is correct. The reference computation correctly handles all four transpose combinations, matching the kernel semantics.
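To illustrate the reshape-permute-contiguous pattern referenced above, here is a minimal standalone sketch; the tile sizes and the transposed (N, K) layout are assumptions for illustration, not the example's exact parameters.

```python
# Hedged sketch of a preshuffle: split (N, K) into (N/my, my, K/pk, pk) tiles,
# then move the tile axes innermost -> (N/my, K/pk, my, pk). Tile sizes are
# placeholders; the example's actual values may differ.
import torch

def shuffle_weight_sketch(b: torch.Tensor, micro_size_y: int = 16, pack_size_k: int = 16) -> torch.Tensor:
    N, K = b.shape  # assumes the b_transposed=True layout (N, K)
    assert N % micro_size_y == 0 and K % pack_size_k == 0
    b = b.reshape(N // micro_size_y, micro_size_y, K // pack_size_k, pack_size_k)
    return b.permute(0, 2, 1, 3).contiguous()

B = torch.randn(128, 256)
print(shuffle_weight_sketch(B).shape)  # torch.Size([8, 16, 16, 16])
```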
Summary by CodeRabbit
New Features
- Added an autotuned TileLang FP8 GEMM example for AMD MFMA with optional B preshuffling.
Tests
- Added a correctness test validating the example kernel against a PyTorch reference.