
Conversation

@LeiWang1999 (Member) commented Nov 2, 2025

This pull request introduces a new version of the GEMM (matrix multiplication) kernel, adds a usage example and benchmarking script, and refactors how GEMM functions are exposed in the tilelang package. It also updates the CUDA code generation logic to treat the scale_out argument as a runtime value rather than a compile-time boolean, improving flexibility. The most important changes are grouped below.

GEMM Kernel Improvements and Refactoring

  • Added a new kernel implementation gemm_v2 and refactored the original kernel to gemm_v1. The default gemm now points to gemm_v1, letting users select between versions; see the sketch after this list. (tilelang/language/gemm.py, tilelang/language/__init__.py) [1] [2] [3]
  • Updated the TVM submodule, likely to pick up upstream changes needed for new kernel support. (3rdparty/tvm)
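
A minimal sketch of how a kernel author might opt into either variant. It assumes gemm_v1/gemm_v2 share T.gemm's (A, B, accumulator) calling convention; block sizes, dtypes, and the use_v2 switch are illustrative and not taken from this PR.

```python
# Sketch only: selecting between the v1 and v2 GEMM lowering paths in a
# TileLang kernel. Assumes gemm_v1/gemm_v2 share T.gemm's calling convention.
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, use_v2=False):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), "float16"),
            B: T.Tensor((K, N), "float16"),
            C: T.Tensor((M, N), "float16"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float16")
            B_shared = T.alloc_shared((block_K, block_N), "float16")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                if use_v2:
                    T.gemm_v2(A_shared, B_shared, C_local)  # new lowering path
                else:
                    T.gemm_v1(A_shared, B_shared, C_local)  # same as plain T.gemm
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```

Because use_v2 is a Python-level value, the branch is resolved at trace time, so each compiled kernel contains only the selected GEMM path.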

Usage Example and Benchmarking

  • Added a new script latency.py that demonstrates how to use the new kernel, validates correctness against PyTorch, and profiles latency. It supports toggling between gemm_v1 and gemm_v2 via a command-line flag; a host-side sketch follows. (maint/gemm_v2/latency.py)
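
A hedged sketch of that host-side flow, reusing the matmul builder from the previous sketch. The flag name, tensor-supply choice, and tolerances are assumptions, not the script's actual values.

```python
# Sketch only: validate against PyTorch, then profile latency.
# Assumes the matmul(...) builder from the sketch above is in scope; the
# --use_v2 flag, supply type, and tolerances are illustrative.
import argparse

import torch
from tilelang.utils.tensor import TensorSupplyType

parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true", help="lower T.gemm_v2 instead of T.gemm_v1")
args = parser.parse_args()

kernel = matmul(4096, 4096, 4096, use_v2=args.use_v2)


def ref_program(A, B):
    # Plain GEMM reference; the real latency.py also fuses a ReLU epilogue
    # into the kernel and mirrors it here.
    return (A.to(torch.float) @ B.to(torch.float)).to(torch.float16)


profiler = kernel.get_profiler(tensor_supply_type=TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)  # correctness vs. PyTorch
latency = profiler.do_bench()                                # average kernel time in ms
print(f"{'gemm_v2' if args.use_v2 else 'gemm_v1'} latency: {latency:.3f} ms")
```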

Code Generation and Runtime Flexibility

  • Changed the handling of the scale_out argument in CUDA code generation to treat it as a runtime value rather than a compile-time boolean, both in assembly generation and rule replacement. This enables more flexible kernel invocation. (src/target/codegen_cuda.cc) [1] [2] [3] [4]

Logging and Debugging

  • Removed a debug log statement from the GEMM instruction selection logic, reducing log noise during kernel selection. (src/op/gemm.cc)

Summary by CodeRabbit

  • New Features

    • GEMM now exposes v1 and v2 variants while maintaining backward compatibility.
    • Added a CUDA GEMM kernel variant with optional ReLU and latency profiling.
  • Tests

    • Added comprehensive GEMM correctness and latency tests across sizes, dtypes, and transpositions.
    • Added a TODO comment to a float32 test; removed one transform test.
  • Bug Fixes

    • Removed a noisy runtime debug log.

@github-actions bot commented Nov 2, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot (Contributor) commented Nov 2, 2025

Caution: Review failed. The pull request is closed.

Walkthrough

Adds comprehensive GEMM testing and profiling modules, introduces gemm_v1 alongside gemm_v2 (with alias gemm→gemm_v1), updates CUDA codegen type handling and MMA dispatch templates, extends MMA TF32 specializations, and removes a debug log and one outdated test.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **GEMM language API**<br>`tilelang/language/gemm.py`, `tilelang/language/__init__.py` | Added `gemm_v1` and `gemm_v2` implementations; exposed `gemm_v1` in the public API and aliased `gemm = gemm_v1`. |
| **GEMM correctness tests**<br>`maint/gemm_v2/correctness_evaluation.py` | New pytest-based correctness evaluation module with matmul variants (`matmul`, `matmul_rs`, `matmul_sr`, `matmul_rr`), compile-and-check utilities, parameterized test cases across shapes/dtypes/transposes, and Torch reference comparisons. |
| **GEMM latency and example**<br>`maint/gemm_v2/latency.py` | New module building a tiled GEMM+ReLU CUDA kernel, an optional `gemm_v2`/`gemm_v1` path toggle, host-side validation against PyTorch, and simple latency profiling. |
| **CUDA codegen adjustments**<br>`src/target/codegen_cuda.cc` | Derive AType/BType strings from normalized dtype enums (TensorFloat32 mapping), propagate the derived types into MMA/WGMMA template replacements, and switch `scale_out` replacements to string expressions. |
| **MMA dispatch extensions**<br>`src/tl_templates/cuda/instruction/mma.h` | Added two `MmaDispatcher` specializations for TF32 inputs on SM80 (M=16, N=8, K=4 and M=16, N=8, K=8), enabling FP32 math on Tensor Cores. |
| **Op cleanup**<br>`src/op/gemm.cc` | Removed a runtime debug logging statement in `GemmNode::GetGemmInst`. |
| **Tests adjusted**<br>`testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py`, `testing/python/transform/test_tilelang_transform_inject_fence_proxy.py` | Inserted a TODO comment in the tilelibrary GEMM test file; removed `test_wgmma_after_descriptor()` from the inject_fence_proxy tests. |
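
To make the correctness cohort concrete, here is a hedged sketch of how such a parameterized check is typically wired up in TileLang. The `matmul` builder and `ref_program` helper are stand-ins for whatever `maint/gemm_v2/correctness_evaluation.py` actually defines, and the parameter grid is illustrative.

```python
# Sketch only: a parameterized correctness check in the style described above.
# matmul(...) and ref_program(...) stand in for helpers defined in
# maint/gemm_v2/correctness_evaluation.py; the parameter grid is illustrative.
import pytest
import tilelang.testing

tilelang.testing.set_random_seed(0)


@pytest.mark.parametrize("M,N,K", [(1024, 1024, 1024), (8192, 8192, 8192)])
@pytest.mark.parametrize("trans_A,trans_B", [(False, False), (False, True), (True, False)])
@pytest.mark.parametrize("in_dtype,out_dtype", [("float16", "float16"), ("float32", "float32")])
def test_gemm_v2(M, N, K, trans_A, trans_B, in_dtype, out_dtype):
    kernel = matmul(M, N, K, trans_A, trans_B, in_dtype, out_dtype)  # compiles a gemm_v2 program
    profiler = kernel.get_profiler()
    profiler.assert_allclose(
        lambda A, B: ref_program(A, B, trans_A, trans_B, in_dtype, out_dtype),
        rtol=1e-2,
        atol=1e-2,
    )
```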

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Runner
    participant API as tilelang.language
    participant GEMMv1 as gemm_v1
    participant GEMMv2 as gemm_v2
    participant Codegen as CUDA Codegen
    participant Kernel as CUDA Kernel
    participant Ref as PyTorch Ref

    Test->>API: call run_gemm / run_gemm_* (params)
    API->>GEMMv2: build program (or GEMMv1 if used)
    GEMMv2->>Codegen: emit CUDA (resolve AType/BType, scale_out)
    Codegen-->>GEMMv2: CUDA source / binary
    GEMMv2->>Kernel: launch kernel
    Kernel-->>Test: outputs
    Test->>Ref: compute reference (torch)
    Test->>Test: compare outputs vs Ref (assert close)
    Note over Codegen: MMA dispatch may choose TF32 paths\n(templates use derived AType/BType)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Pay attention to:
    • maint/gemm_v2/correctness_evaluation.py (many variants, dtype/transposition logic).
    • src/target/codegen_cuda.cc (template replacer changes, AType/BType normalization, scale_out stringification).
    • MMA TF32 additions in src/tl_templates/cuda/instruction/mma.h for correctness on SM80.

Possibly related PRs

Suggested reviewers

  • tzj-fxz

Poem

🐰 Two GEMM paths hop in tidy rows,
v1 and v2 where tensor-buffer flows.
TF32 sparks on SM80's night,
Tests check results till numbers are right.
A little rabbit claps—kernels compile bright!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 6.12%, below the required threshold of 80.00%. | Run `@coderabbitai generate docstrings` to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title "[Language] Add Correctness and performance check scripts for V2" accurately describes real additions in the changeset: maint/gemm_v2/correctness_evaluation.py (a pytest-based correctness evaluation module) and maint/gemm_v2/latency.py (a performance and latency profiling script), both significant additions for V2 validation and benchmarking. However, the title focuses on the testing utilities and omits the core language-level refactoring that introduces gemm_v1 and gemm_v2 as distinct API variants with gemm aliased to gemm_v1, which is the primary structural change to the tilelang.language module. |

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f09ef4d and 548d5a0.

📒 Files selected for processing (4)
  • src/target/codegen_cuda.cc (7 hunks)
  • src/tl_templates/cuda/instruction/mma.h (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1 hunks)
  • testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (0 hunks)


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)

54-54: Drop the redundant noqa suppressor

gemm, gemm_v1, and gemm_v2 are all re-exported, so Ruff now flags # noqa: F401 as unused. Please remove the directive to keep lint clean.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aef0a6b and f09ef4d.

📒 Files selected for processing (7)
  • 3rdparty/tvm (1 hunks)
  • maint/gemm_v2/correctness_evaluation.py (1 hunks)
  • maint/gemm_v2/latency.py (1 hunks)
  • src/op/gemm.cc (0 hunks)
  • src/target/codegen_cuda.cc (4 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/gemm.py (2 hunks)
💤 Files with no reviewable changes (1)
  • src/op/gemm.cc
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/language/gemm.py (2)
tilelang/primitives/gemm/__init__.py (1)
  • gemm (10-46)
examples/gemm/example_gemm.py (1)
  • gemm (9-25)
maint/gemm_v2/latency.py (9)
tilelang/jit/__init__.py (1)
  • jit (233-306)
tilelang/language/allocate.py (2)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/gemm.py (2)
  • gemm_v2 (215-434)
  • gemm_v1 (10-211)
tilelang/language/parallel.py (1)
  • Parallel (9-29)
tilelang/jit/kernel.py (1)
  • get_profiler (367-383)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
maint/gemm_v2/correctness_evaluation.py (11)
tilelang/language/kernel.py (2)
  • threads (215-219)
  • num_threads (222-226)
tilelang/language/allocate.py (2)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (11-87)
tilelang/language/gemm.py (1)
  • gemm_v2 (215-434)
tilelang/jit/__init__.py (1)
  • compile (30-79)
tilelang/jit/kernel.py (2)
  • out_idx (453-454)
  • get_profiler (367-383)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-144)
tilelang/profiler/__init__.py (1)
  • assert_allclose (77-146)
tilelang/env.py (1)
  • disable_cache (271-272)
tilelang/testing/__init__.py (1)
  • set_random_seed (30-35)
tilelang/language/__init__.py (2)
tilelang/primitives/gemm/__init__.py (1)
  • gemm (10-46)
tilelang/language/gemm.py (2)
  • gemm_v1 (10-211)
  • gemm_v2 (215-434)
🪛 Ruff (0.14.2)
tilelang/language/__init__.py

54-54: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (3)
src/target/codegen_cuda.cc (2)

1841-1841: LGTM: scale_out now handled as runtime expression

The change from boolean extraction to expression string is correct. This allows scale_out to be a runtime-evaluated expression rather than a compile-time constant, which provides more flexibility.

Note: Unlike scale_in_a and scale_in_b (which are still booleans converted to "1"/"-1" strings at lines 1866-1867), scale_out is now directly substituted as an arbitrary expression into the template. Ensure the underlying WGMMA instruction template can handle expression strings for the scale_out parameter.

Also applies to: 1873-1873


1903-1903: LGTM: Consistent scale_out handling across WGMMA variants

The changes mirror the pattern in ptx_wgmma_ss (lines 1841, 1873), ensuring consistent handling of scale_out as a runtime expression across both ptx_wgmma_ss and ptx_wgmma_rs handlers.

Also applies to: 1942-1942

3rdparty/tvm (1)

1-1: Manually verify the TVM submodule commit references.

The sandbox environment cannot access the upstream TVM repository to verify these commit hashes. Confirm that both commit hashes exist in the official TVM repository and review the changes between them to ensure they align with the GEMM v2 implementation goals and introduce no unexpected breaking changes.

Comment on lines +77 to +88 of maint/gemm_v2/correctness_evaluation.py
```python
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
```

⚠️ Potential issue | 🟠 Major

Fix the dtype reinterpretation code path

This block tries to reinterpret tensor bits with a NumPy-style view(dtype), but PyTorch’s Tensor.view only accepts shape arguments. As soon as we exercise the float32 path, this raises TypeError: view(): argument 'size' must be tuple of ints, but found type torch.dtype. Please switch to an explicit dtype conversion before the bias adjustment:

```diff
-        if in_dtype == "float32":
-            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
-            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
+        if in_dtype == "float32":
+            A = (A.to(torch.int32) - 0x1000).to(torch.float32)
+            B = (B.to(torch.int32) - 0x1000).to(torch.float32)
```

That keeps the intended shift while avoiding the runtime failure.

🤖 Prompt for AI Agents
In maint/gemm_v2/correctness_evaluation.py around lines 77 to 88, the code
incorrectly uses Tensor.view(torch.int32) to reinterpret dtypes (PyTorch view
expects shape tuples), causing a TypeError; change the float32 reinterpretation
path to explicitly cast the tensor to an integer dtype (e.g., A =
A.to(torch.int32)), perform the subtraction (A = A - 0x1000), then cast back to
float (A = A.to(torch.float32)); do the same for B, then proceed with matmul and
conversion to out_dtype as before.

@LeiWang1999 LeiWang1999 merged commit d99853b into tile-ai:main Nov 2, 2025
1 check was pending
tzj-fxz pushed a commit to tzj-fxz/tilelang that referenced this pull request Nov 3, 2025
LeiWang1999 added a commit that referenced this pull request Nov 5, 2025
* [Test] Add cp async to avoid register spill

* [BugFix] GQA fwd and bwd
- Fix the undefined behavior of -inf in acc_s
- Fix the causal loop range in varlen scenario

* [TMA] Move on to TMA and locate the register spill issue

* [Debug] Not the reason of zero-assignment. Probably the combination of Parallel op & conditional qkT

* [Debug] The SIMT copy in producer occupies too many registers

* [BugFix] Use 3D lse and delta to avoid illegal instruction

* [Perf] Relaxed order for dQ and SIMT store for dKdV

* [Feat] For atomic add version

* [Lint]

* [Bugfix] Enable code lowering with producer‑copy‑only program (#1168)

* bugfix

* lint fix

* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated logic to annotate producer side when consumer is absent, ensuring robustness in degenerate warp-specialized patterns.

* Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.

* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector to ensure it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.

* [Bugfix] Support 16bits shfl_sync (#1169)

* Add type-safe warp shuffle helpers for 16-bit float types in common.h

- Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
- Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
- Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.

* lint fix

* [Testing] Move TMA 1D and test for its functionality (#1167)

* [Testing] Move TMA 1D and test for its functionality

* [Lint]

* [Refactor]: Change the params in pytest to avoid oom error during ci (#1170)

* [Refactor]: Change the params in pytest to avoid oom error during ci

* format

* fix

* Update test_example_cast.py

* Update parameters in test_example_cast

* Update test_example_flash_attention.py

* update

* format

* fix

* fix

* format

* [Bugfix] Fix tvm import path for editable build (#1172)

* [Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)

* remove debug print

* pipeline fix

* use the correct buffer access scope

* rs support

* warp warpgroup_fence_operand

* fix

* fp8 dtype ptx enhance

* mma fix

* TCGEN05 Interface

* tcgen05 support

* rebase

* update

* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.

* lint fix

* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.

* wgmma fix

---------

Co-authored-by: Zhiwen Mo <[email protected]>

* [Language] Add Correctness and performance check scripts for V2 (#1174)

* fix

* lint fix

* fix

* lint fix

* fix

* upd

* [Bugfix] Legalize Datatype for mma intrinisc codegen  (#1179)

* fix

* lint fix

* Enhance CUDA code generation by updating register type handling for float data types. Introduced a workaround for TF32 type compatibility and improved the registration of MMA register types for A and B operands.

* [Perf] Add layout and use_tma to boost performance

* [Lint]

* [Note]

---------

Co-authored-by: Lei Wang <[email protected]>
Co-authored-by: Yuqi Dong <[email protected]>
Co-authored-by: Zhiwen Mo <[email protected]>
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025