
[Enhancement][AMD] Add preshuffle fp8 gemm example on amd. #1605

Merged
LeiWang1999 merged 3 commits into tile-ai:main from Gongen-Ali:main on Jan 5, 2026

Conversation

@Gongen-Ali Gongen-Ali commented Jan 5, 2026

Summary by CodeRabbit

  • New Features

    • Added an autotuned FP8 tiled GEMM example for AMD with configurable tiling/staging and optional B preshuffling to improve performance.
    • Includes utilities to reorder weights for preshuffle and explicit FP8 input / FP32 accumulate-output support.
  • Tests

    • Adds end-to-end correctness checks and performance benchmarking against reference matmul across autotune configurations.

github-actions bot commented Jan 5, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project so that your changes are properly linted and formatted; this helps your contribution pass the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot commented Jan 5, 2026

📝 Walkthrough

Adds a new example module that implements an autotuned TileLang FP8 GEMM using AMD MFMA with optional B preshuffling, plus utilities to build, validate, benchmark, and test the kernel against a PyTorch reference.

Changes

Cohort / File(s): FP8 GEMM with Preshuffling Example (examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py)
Change Summary: New file adding tl_matmul (autotuned, JIT TileLang MFMA GEMM with block_M/block_N/block_K/num_stages and k_pack, supports A/B transpose and B preshuffle), shuffle_weight (B preshuffle/reorder), assert_tl_matmul_correctness (build/validate/benchmark against PyTorch), and test_assert_tl_matmul (smoke test).
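
For orientation, a minimal, hypothetical smoke-run sketch based on the change summary above. It assumes a ROCm-enabled TileLang/PyTorch environment and that the snippet is run from the repository root; the entry-point name test_assert_tl_matmul comes from the summary, and exact signatures may differ in the merged file.

import runpy

# Load the new example by file path and call its smoke-test entry point,
# which builds the autotuned kernel, preshuffles B, validates the output
# against a PyTorch reference, and benchmarks the result.
example = runpy.run_path(
    "examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py")
example["test_assert_tl_matmul"]()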

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through tiles where MFMA sings,
I shuffled B and tuned tiny things,
FP8 bytes in a moonlit sprawl,
Kernels compile — I clap my paws,
Hooray — matrices leap and fall! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title check ✅ Passed: The title clearly describes the main change: adding a new preshuffle FP8 GEMM example for AMD hardware, which aligns with the new file introducing an autotuned TileLang-based FP8 GEMM with B preshuffling.

Gongen-Ali changed the title from "[]Add preshuffle gemm example on amd." to "[Enhancement][AMD] Add preshuffle fp8 gemm example on amd." on Jan 5, 2026

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Fix all issues with AI Agents 🤖
In @examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py:
- Around line 193-194: The tensor dtype conversion is wrong: replace calls using
getattr(torch, in_dtype) when constructing tensors A and B with the tilelang
dtype's as_torch() (i.e., use in_dtype.as_torch()) and similarly replace any
getattr(torch, out_dtype) occurrences (used later when creating other tensors)
with out_dtype.as_torch(); locate uses of the variables in_dtype and out_dtype
around the tensor constructions (e.g., where A and B are created and the later
tensor creations referenced in the review) and call .as_torch() to obtain a
valid torch.dtype for torch.rand()/tensor constructors.
🧹 Nitpick comments (4)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (4)

78-80: Remove unused variable local_size_c.

The variable is assigned but never used in the kernel implementation.

Proposed fix
     local_size_a = mfma_emitter.local_size_a
     local_size_b = mfma_emitter.local_size_b
-    local_size_c = mfma_emitter.local_size_out

92-102: B_shared_shape is computed but never used.

Unlike A_shared, there's no B_shared buffer allocation - B is loaded directly from global memory via ldmatrix_b with pid_m/pid_n. If this is intentional for the preshuffle design, consider removing the dead code. If shared memory for B is planned, this appears incomplete.

Proposed fix (remove unused computation)
-    if b_preshuffle:
-        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
-                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
-                                                      pack_size_k, micro_size_y)
-        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
-                          pack_size_k) if b_transposed else (block_K // pack_size_k,
-                                                             block_N // micro_size_y, pack_size_k,
-                                                             micro_size_y)
-    else:
-        B_shape = (N, K) if b_transposed else (K, N)
-        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
+    if b_preshuffle:
+        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
+                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
+                                                      pack_size_k, micro_size_y)
+    else:
+        B_shape = (N, K) if b_transposed else (K, N)

210-226: Simplify reference computation with conditional transpose.

The four branches can be consolidated. Also, same getattr(torch, out_dtype) issue applies here where out_dtype is T.float32.

Proposed simplification
-    if a_transposed and b_transposed:
-        # Get Reference Result
-        ref_c = torch.matmul(A.T.half(),
-                             B.T.half()).to(getattr(torch, out_dtype))
-    elif a_transposed and not b_transposed:
-        # Get Reference Result
-        ref_c = torch.matmul(A.T.half(),
-                             B.half()).to(getattr(torch, out_dtype))
-    elif not a_transposed and b_transposed:
-        # Get Reference Result
-        ref_c = torch.matmul(A.half(),
-                             B.T.half()).to(getattr(torch, out_dtype))
-    else:
-        # Get Reference Result
-        ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype))
+    A_ref = A.T.half() if a_transposed else A.half()
+    B_ref = B.T.half() if b_transposed else B.half()
+    ref_c = torch.matmul(A_ref, B_ref).to(torch.float32)  # out_dtype is T.float32

229-238: Test coverage is limited but acceptable for an example.

The test only exercises b_transposed=True, b_preshuffle=True. Consider adding at least one case with b_preshuffle=False to validate the non-preshuffle path, especially since B_shared_shape is computed for it.

Suggested additional test case
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(
         512, 512, 512, k_pack=1, b_transposed=True, b_preshuffle=True)
     assert_tl_matmul_correctness(
         512, 512, 512, k_pack=2, b_transposed=True, b_preshuffle=True)
+    # Test non-preshuffle path
+    assert_tl_matmul_correctness(
+        512, 512, 512, k_pack=1, b_transposed=True, b_preshuffle=False)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32aec8a and dacab78.

📒 Files selected for processing (1)
  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (6)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (62-71)
tilelang/intrinsics/mfma_macro_generator.py (2)
  • MatrixCorePreshuffleIntrinEmitter (670-866)
  • mfma (358-395)
tilelang/transform/simplify.py (1)
  • simplify_prim_func (53-58)
tilelang/testing/__init__.py (1)
  • set_random_seed (32-37)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
tilelang/language/v2/dtypes.py (2)
  • float8_e4m3fnuz (348-348)
  • float16 (299-299)
🪛 Ruff (0.14.10)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py

80-80: Local variable local_size_c is assigned to but never used

Remove assignment to unused variable local_size_c

(F841)


102-102: Local variable B_shared_shape is assigned to but never used

Remove assignment to unused variable B_shared_shape

(F841)

🔇 Additional comments (4)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (4)

1-12: LGTM!

Imports are appropriate and the random seed initialization ensures reproducibility for the example.


14-29: LGTM!

Configuration generator is clean and produces a reasonable set of 54 tuning configurations.


104-153: LGTM!

The kernel implementation correctly implements tiled GEMM with pipelining, using shared memory for A and direct global memory access for preshuffled B. The mfma_emitter abstraction handles the MFMA intrinsics cleanly.


156-172: LGTM!

The weight shuffling logic correctly reshapes and permutes the tensor for the preshuffle layout, with proper divisibility checks.

Comment on lines +193 to +194
A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py | head -250

Repository: tile-ai/tilelang

Length of output: 9312


🏁 Script executed:

rg "float8_e4m3fnuz" --type py -B 5 -A 5 | head -100

Repository: tile-ai/tilelang

Length of output: 8697


🏁 Script executed:

fd "language.py" --type f | head -5

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

fd "language" --type f | grep -E "^tilelang.*\.py$" | head -10

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

rg "class.*float8_e4m3fnuz|float8_e4m3fnuz\s*=" --type py -B 2 -A 2 | head -50

Repository: tile-ai/tilelang

Length of output: 2305


🏁 Script executed:

cat tilelang/language/v2/dtypes.py | head -100

Repository: tile-ai/tilelang

Length of output: 2447


🏁 Script executed:

cat tilelang/language/v2/dtypes.py | sed -n '100,250p'

Repository: tile-ai/tilelang

Length of output: 5383


🏁 Script executed:

python3 << 'EOF'
import sys
sys.path.insert(0, '/tmp/tilelang')

# Check if TVM DataType has __str__ method and what it returns
from tvm import DataType

test_dtype = DataType("float8_e4m3fnuz")
print(f"Type of test_dtype: {type(test_dtype)}")
print(f"String representation: {str(test_dtype)}")
print(f"Has as_torch: {hasattr(test_dtype, 'as_torch')}")

# Try getattr with the dtype object
try:
    result = getattr(__import__('torch'), test_dtype)
    print(f"getattr(torch, dtype_object) succeeded: {result}")
except TypeError as e:
    print(f"getattr(torch, dtype_object) failed with TypeError: {e}")
except AttributeError as e:
    print(f"getattr(torch, dtype_object) failed with AttributeError: {e}")
EOF

Repository: tile-ai/tilelang

Length of output: 146


🏁 Script executed:

rg "\.as_torch\(\)" --type py -B 2 -A 2 | head -50

Repository: tile-ai/tilelang

Length of output: 3171


🏁 Script executed:

sed -n '175,233p' examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py

Repository: tile-ai/tilelang

Length of output: 2024


🏁 Script executed:

rg "def __str__|def __repr__" tilelang/language/v2/dtypes.py -B 2 -A 5

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

rg "getattr\(torch," examples/ --type py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 28083


🏁 Script executed:

cat examples/gemm_fp8/example_tilelang_gemm_fp8.py | sed -n '160,230p'

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

wc -l examples/gemm_fp8/example_tilelang_gemm_fp8.py

Repository: tile-ai/tilelang

Length of output: 109


🏁 Script executed:

cat examples/gemm_fp8/example_tilelang_gemm_fp8.py

Repository: tile-ai/tilelang

Length of output: 2406


🏁 Script executed:

sed -n '176,232p' examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py | cat -n

Repository: tile-ai/tilelang

Length of output: 2421


Use .as_torch() to convert tilelang dtypes to torch dtypes.

Lines 199-200 pass in_dtype (a tilelang dtype object) to getattr(torch, in_dtype), which expects a string attribute name. This will fail at runtime. The correct approach is to use the .as_torch() method:

Diff
-    A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
-    B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
+    A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
+    B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())

Also apply the same fix on lines 221, 224, 227, 230 where out_dtype is used with getattr(torch, out_dtype):

-        ref_c = torch.matmul(...).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(...).to(out_dtype.as_torch())
🤖 Prompt for AI Agents
In @examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py around lines
193-194, The tensor dtype conversion is wrong: replace calls using
getattr(torch, in_dtype) when constructing tensors A and B with the tilelang
dtype's as_torch() (i.e., use in_dtype.as_torch()) and similarly replace any
getattr(torch, out_dtype) occurrences (used later when creating other tensors)
with out_dtype.as_torch(); locate uses of the variables in_dtype and out_dtype
around the tensor constructions (e.g., where A and B are created and the later
tensor creations referenced in the review) and call .as_torch() to obtain a
valid torch.dtype for torch.rand()/tensor constructors.
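
For context, a minimal sketch of the dtype-conversion pattern this fix relies on. It assumes T is the example's tilelang.language namespace, that .as_torch() behaves as described above, and that the installed PyTorch build exposes torch.float8_e4m3fnuz; it is an illustration, not the merged code.

import torch
import tilelang.language as T  # assumed import alias, matching the example's use of T.*

in_dtype = T.float8_e4m3fnuz       # a TileLang dtype object, not a string
# getattr(torch, in_dtype) fails because getattr expects a string attribute name.
torch_dtype = in_dtype.as_torch()  # reviewer-suggested conversion to a torch.dtype
A = (torch.rand(4, 4, dtype=torch.float16) / 10).to(torch_dtype)
print(A.dtype)                     # expected: torch.float8_e4m3fnuz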

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (2)

199-200: Use .as_torch() to convert tilelang dtypes to torch dtypes.

getattr(torch, in_dtype) will fail because in_dtype is a tilelang dtype object, not a string. Use the .as_torch() method instead.

Suggested fix
-    A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
-    B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
+    A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
+    B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())

219-231: Same dtype conversion issue in reference computation.

All getattr(torch, out_dtype) calls need to use out_dtype.as_torch().

Suggested fix
     if a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.half(), B.T.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.half(), B.T.half()).to(out_dtype.as_torch())
     elif a_transposed and not b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.half(), B.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.half(), B.half()).to(out_dtype.as_torch())
     elif not a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.half(), B.T.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.half(), B.T.half()).to(out_dtype.as_torch())
     else:
         # Get Reference Result
-        ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.half(), B.half()).to(out_dtype.as_torch())
🧹 Nitpick comments (3)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (3)

80-82: Remove unused variable local_size_c.

The variable is assigned but never referenced. This was flagged by static analysis (Ruff F841).

Suggested fix
     local_size_a = mfma_emitter.local_size_a
     local_size_b = mfma_emitter.local_size_b
-    local_size_c = mfma_emitter.local_size_out

100-107: Remove unused B_shared_shape computation.

This variable is computed but never used. The kernel loads B directly from global memory via ldmatrix_b rather than through shared memory. Flagged by static analysis (Ruff F841).

Suggested fix
     if b_preshuffle:
         B_shape = (
             (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
             if b_transposed
             else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
         )
-        B_shared_shape = (
-            (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
-            if b_transposed
-            else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
-        )
     else:
         B_shape = (N, K) if b_transposed else (K, N)
-        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)

235-237: Consider expanding test coverage for transpose combinations.

Currently only tests a_transposed=False, b_transposed=True. The kernel supports all four transpose combinations, but only one is validated. Consider adding tests for other configurations if they are intended to be supported.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dacab78 and 38d0c75.

📒 Files selected for processing (1)
  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (3)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (62-71)
tilelang/intrinsics/mfma_macro_generator.py (2)
  • MatrixCorePreshuffleIntrinEmitter (670-866)
  • mfma (358-395)
tilelang/testing/__init__.py (1)
  • set_random_seed (32-37)
🪛 Ruff (0.14.10)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py

82-82: Local variable local_size_c is assigned to but never used

Remove assignment to unused variable local_size_c

(F841)


107-107: Local variable B_shared_shape is assigned to but never used

Remove assignment to unused variable B_shared_shape

(F841)

🔇 Additional comments (4)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (4)

1-11: LGTM!

Imports are appropriate and the random seed initialization ensures reproducible autotuning as expected for example/test code.


14-31: LGTM!

The config generation is clean and produces a reasonable search space for autotuning.


109-154: LGTM!

The kernel body correctly implements the preshuffled MFMA GEMM pattern: A goes through shared memory with swizzled layout, B loads directly from global memory via the preshuffle path, and the pipeline structure is appropriate.


157-173: LGTM!

The shuffle_weight function correctly implements the preshuffle transformation required for MFMA access patterns. The reshape-permute-contiguous pattern is standard for this kind of layout transformation.
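
To illustrate that pattern in isolation, here is a self-contained sketch of a reshape-permute-contiguous reorder into the 4-D blocked layout discussed above. The tile sizes micro_size_y=16 and pack_size_k=16 are placeholders; in the example they come from the MFMA emitter and k_pack, and this is not the merged shuffle_weight implementation.

import torch

def preshuffle_b_transposed(B, micro_size_y=16, pack_size_k=16):
    # Illustrative only: reorder an (N, K) weight into
    # (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k),
    # the preshuffled B layout used when b_transposed=True.
    N, K = B.shape
    assert N % micro_size_y == 0 and K % pack_size_k == 0  # divisibility checks
    B = B.reshape(N // micro_size_y, micro_size_y, K // pack_size_k, pack_size_k)
    B = B.permute(0, 2, 1, 3)  # bring the per-tile dimensions innermost
    return B.contiguous()      # materialize the reordered memory layout

B = torch.randn(512, 512, dtype=torch.float16)
print(preshuffle_b_transposed(B).shape)  # torch.Size([32, 32, 16, 16])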

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (1)

188-189: Use .as_torch() to convert TileLang dtypes to torch dtypes.

Lines 188-189 pass in_dtype (a TileLang dtype object) to getattr(torch, in_dtype), which expects a string attribute name. This will fail at runtime.

The same issue occurs on lines 203, 206, 209, and 212 with out_dtype.

🔎 Proposed fix
-    A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
-    B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype))
+    A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())
+    B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(in_dtype.as_torch())

Also fix lines 203, 206, 209, 212:

     if a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.half(), B.T.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.half(), B.T.half()).to(out_dtype.as_torch())
     elif a_transposed and not b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.T.half(), B.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.T.half(), B.half()).to(out_dtype.as_torch())
     elif not a_transposed and b_transposed:
         # Get Reference Result
-        ref_c = torch.matmul(A.half(), B.T.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.half(), B.T.half()).to(out_dtype.as_torch())
     else:
         # Get Reference Result
-        ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype))
+        ref_c = torch.matmul(A.half(), B.half()).to(out_dtype.as_torch())
🧹 Nitpick comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (2)

54-54: Consider making b_preshuffle a function parameter.

Currently b_preshuffle is hardcoded to True. Making it a parameter would improve flexibility for benchmarking and comparing preshuffled vs. non-preshuffled performance.

🔎 Suggested refactor
 @tilelang.jit(out_idx=[-1])
 def tl_matmul(
     M,
     N,
     K,
     block_M,
     block_N,
     block_K,
     num_stages,
     k_pack=2,
     num_threads=256,
     in_dtype=T.float8_e4m3fnuz,
     out_dtype=T.float32,
     accum_dtype=T.float32,
     a_transposed=False,
     b_transposed=True,
+    b_preshuffle=True,
 ):
-    b_preshuffle = True
     warp_size = 64

217-222: Consider adding tests for different configurations.

The current test only exercises one configuration (512×512×512 with k_pack=2). Adding tests for different matrix sizes, k_pack values, and transpose combinations would improve coverage.

For example:

def test_assert_tl_matmul():
    # Test basic case
    assert_tl_matmul_correctness(512, 512, 512, k_pack=2)
    
    # Test different k_pack
    assert_tl_matmul_correctness(256, 256, 256, k_pack=1)
    
    # Test transpose combinations
    assert_tl_matmul_correctness(512, 512, 512, k_pack=2, a_transposed=True, b_transposed=False)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 38d0c75 and e4ba345.

📒 Files selected for processing (1)
  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (1)
tilelang/intrinsics/mfma_macro_generator.py (3)
  • MatrixCorePreshuffleIntrinEmitter (670-866)
  • make_mfma_store_layout (573-640)
  • mfma (358-395)
🔇 Additional comments (7)
examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py (7)

1-11: LGTM!

Imports are appropriate and the random seed initialization ensures reproducible test results.


14-31: LGTM!

Configuration generation logic is correct and provides a reasonable search space for autotuning.


34-53: LGTM!

Function signature and decorators are appropriate for an autotuned TileLang kernel.


90-97: LGTM!

Shape calculations correctly account for transposition and preshuffle layout requirements.


99-144: LGTM!

Kernel implementation follows the expected pattern for tiled GEMM with preshuffle:

  • Proper memory hierarchy (shared → local → fragment)
  • Correct pipelining with configurable stages
  • Appropriate layout annotations for MFMA

147-163: LGTM!

The shuffle logic correctly transforms the weight tensor into the preshuffle layout expected by the kernel.


201-212: Transpose handling logic is correct.

The reference computation correctly handles all four transpose combinations, matching the kernel semantics.

@LeiWang1999 LeiWang1999 merged commit 7198aa5 into tile-ai:main Jan 5, 2026
6 checks passed