Conversation

@LeiWang1999 LeiWang1999 commented Sep 7, 2025

This change will help us integrate additional backends and reduce compilation time.
Note: This pull request is experimental and currently targets MMA only.

TODOs for this PR

  • Support more data types
    • fp16 ss
      • nn
      • nt
      • tn
      • tt
    • fp16 rs
      • nn
      • nt
      • tn
      • tt
    • fp16 sr
      • nn
      • nt
      • tn
      • tt
    • fp16 rr
      • nn
      • nt
      • tn
      • tt
    • bf16
    • n=8 support
    • int8
    • fp8
    • tf32
  • Enable software pipelining
  • Enable auto warp specialization (cc this part @chengyupku)
  • Wrap ldmatrix for improved readability

Summary by CodeRabbit

  • New Features

    • Experimental gemm_v2 API, GemmPy operator, GemmMMA/GemmBase implementations, and GemmWarpPolicy compute_warp_partition API.
    • New PTX codegen backend and helper APIs for MMA, ldmatrix and cp.async emission.
  • Improvements/Refactor

    • Reworked TensorCore layouts (A/B variants), ldmatrix emission, thread-binding control, safer kernel-launch guards, buffer-shape remapping fixes, and many public exports/re-exports.
  • Tests

    • Expanded GEMM tests to cover ss/sr/rs/rr variants and more dtypes.
  • Chores

    • Updated third-party submodule and build/CMake entries.

…bility

- Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
- Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
- Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety.
- Added new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.
coderabbitai bot commented Sep 7, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new TL GEMM stack (GemmPy/GemmMMA) with FFI bindings and GemmWarpPolicy compute wrapper, introduces gemm_v2 and GemmPy scriptable node, refactors MMA layouts/emitter and PTX codegen (ptx.cc/.h + codegen_cuda changes), renames FrontendLegalize→LetInline, and applies pipeline/lowering and test updates.

Changes

Cohort / File(s): Summary

  • Submodule update — 3rdparty/tvm: update submodule pointer.
  • GEMM C++ glue & FFI — src/op/gemm.cc, src/op/gemm_sp.cc: register tl.GemmWarpPolicy, add FFI reflection and tl.GemmWarpPolicyComputeWarpPartition; remove the private toPrimeFactors helper.
  • New TL GEMM op (C++) — src/op/gemm_py.h, src/op/gemm_py.cc: add the GemmPy TileOperator node/ref, argument parsing, kernel selection (MMA/WGMMA/MFMA), Lower/InferLayout, TL op registration, and reflection/FFI hooks.
  • CUDA codegen & PTX API — src/target/codegen_cuda.cc, src/target/ptx.h, src/target/ptx.cc, CMakeLists.txt: add a PTX codegen module/header; introduce PTX MMA/cp.async/barrier helpers; switch ldmatrix emission to the tl::ptx_ldmatrix_xN(_trans) call path; include ptx.cc in the build.
  • Target utilities & FFI — src/target/utils.cc, tilelang/utils/target.py: safer GetArchInt parsing; add a TVM_FFI_STATIC_INIT_BLOCK exposing target capability predicates and warp size; Python wrappers for the new _ffi_api.
  • TileOp GEMM Python runtime/FFI — tilelang/tileop/gemm/__init__.py, tilelang/tileop/__init__.py: add the GemmPy scriptable node and tl.gemm_py.lower/infer_layout FFI adapters; re-export GemmPy from tileop.
  • GEMM runtime classes (Python) — tilelang/tileop/gemm/gemm_base.py, tilelang/tileop/gemm/gemm_mma.py: add the GemmBase wrapper and a GemmMMA backend implementing infer_layout and lower for the SS/SR/RS/RR variants using TensorCoreIntrinEmitter.
  • Language API & gemm_v2 — tilelang/ir.py, tilelang/language/gemm.py, tilelang/language/__init__.py: extend GemmWarpPolicy fields and add a compute_warp_partition FFI call; add and re-export gemm_v2, which resolves pointers/strides/offsets and emits a TL GEMM intrinsic.
  • Intrinsics & emitter refactor (Python) — tilelang/intrinsics/mma_layout.py, tilelang/intrinsics/mma_macro_generator.py, tilelang/intrinsics/utils.py: replace the SR/RS layout scheme with A/B variants and trans wrappers; add 32x4/32x8/32x16 layouts and load helpers; TensorCoreIntrinEmitter gains an optional thread_var and get_thread_binding; rework load-layout logic.
  • Pass rename & simplify integration — src/transform/frontend_legalize.cc, tilelang/transform/__init__.py, tilelang/transform/simplify.py, tilelang/engine/phase.py, testing/python/transform/test_tilelang_transform_let_inline.py: rename FrontendLegalizer → LetInliner and FrontendLegalize() → LetInline(); add the LetInline export and an inline_let option in the simplify flow; update the engine and tests.
  • Pipeline / lowering fixes — src/transform/inject_pipeline.cc, src/transform/lower_tile_op.cc: wrap async-stage bodies in async_scope, set block read/write regions after rewriting, and separate original vs. remapped buffer shapes in indexing/remapping to fix offsets; minor log removals.
  • Tests & GEMM Python tests — testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py: expand GEMM tests to the SS/SR/RS/RR variants, add kernels and run/test wrappers, exercise swizzled layouts and gemm_v2 usage; rename helpers accordingly.
  • Miscellaneous Python & small C++ tweaks — tilelang/__init__.py, tilelang/layout/swizzle.py, tilelang/language/kernel.py, tilelang/profiler/__init__.py, tilelang/utils/tensor.py, src/op/copy.cc, .clang-tidy: re-export tileop; doc comments; KernelLaunchFrame guards; float8 cast handling in the profiler; random tensor fill via float32 then cast; minor whitespace; disable the clang-tidy performance-enum-size check.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant PyUser as Python (tilelang.language.gemm_v2)
  participant TL as TL Runtime (Python)
  participant FFI as TVM FFI
  participant Cpp as C++ TL Op (GemmPy)
  participant Impl as Python Impl (tileop.gemm.GemmMMA)
  participant CG as CUDA Codegen

  PyUser->>TL: gemm_v2(A,B,C,transpose_*,policy,...)
  TL->>TL: Resolve shapes, strides, pointers, offsets
  TL->>FFI: tl.gemm_py.lower(gemm_py, target, thread_bounds, thread_var)
  FFI->>Cpp: GemmPyNode::Lower(...)
  Cpp->>Impl: dispatch to GemmMMA.lower via FFI/Python
  Impl->>Impl: prepare emitter, layouts, fragments
  Impl-->>FFI: return PrimFunc
  FFI-->>TL: PrimFunc
  TL-->>CG: Lower to CUDA, emit PTX
  CG->>CG: Emit ldmatrix via tl::ptx_ldmatrix_xN(_trans)
  CG-->>PyUser: Compiled kernel
sequenceDiagram
  autonumber
  participant Py as Python (tilelang.ir.GemmWarpPolicy)
  participant FFI as TVM FFI
  participant Cpp as C++ GemmWarpPolicy

  Py->>FFI: GemmWarpPolicyComputeWarpPartition(policy,M,N,block_size,target,is_wgmma)
  FFI->>Cpp: policy->ComputeWarpPartition(...)
  Cpp-->>FFI: sets policy.m_warp, policy.n_warp
  FFI-->>Py: returns m_warp, n_warp
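The ComputeWarpPartition call in the second diagram can be illustrated with a small stand-in. This is a hedged Python sketch of what a Square-policy partition might compute — square_warp_partition, its cost function, and the warp_size default are illustrative assumptions, not the actual C++ implementation:

```python
def square_warp_partition(M, N, block_size, warp_size=32):
    """Pick (m_warp, n_warp) so the per-warp tile stays close to square."""
    num_warps = block_size // warp_size
    best = (1, num_warps)
    best_cost = None
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp:
            continue  # only exact factorizations of the warp count
        n_warp = num_warps // m_warp
        # cost: imbalance between the per-warp M and N tile extents
        cost = abs(M / m_warp - N / n_warp)
        if best_cost is None or cost < best_cost:
            best, best_cost = (m_warp, n_warp), cost
    return best
```

For a 128-thread block (4 warps) on a 128x128 tile this yields a 2x2 warp grid, matching the m_warp/n_warp pair the FFI call returns into the policy object.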

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Poem

I nibble bytes and hop through loops,
I stitch A, B in memory groups.
LetInline I tweak, layouts I fluff,
PTX shimmers — kernels call my bluff.
A rabbit grin — the warps compile! 🐇✨

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage — ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.

✅ Passed checks (2)
  • Title Check — Passed: the title clearly identifies the primary change, stating that the PR introduces an experimental Python-defined T.gemm_v2 operator for the TileOp component.
  • Description Check — Passed: check skipped because CodeRabbit's high-level summary is enabled.

github-actions bot commented Sep 7, 2025

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist bot left a comment

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an experimental, Python-defined General Matrix Multiplication (GEMM) operation, T.gemm_v2, designed to improve backend integration and reduce compilation overhead. It establishes a robust Python-C++ FFI for handling the operation's parameters, layout inference, and lowering, while also enhancing existing target utility functions and refactoring related code for better maintainability and extensibility.

Highlights

  • Experimental Python-defined GEMM (T.gemm_v2): Introduced a new experimental T.gemm_v2 function, defined in Python, to facilitate integration with additional backends and potentially reduce compilation time. This initial version is focused on MMA (Matrix Multiply-Accumulate) operations.
  • New C++ Backend for Python-defined Ops: Added new C++ files (src/op/gemm_py.cc, src/op/gemm_py.h) to support the Python-defined gemm_v2 operation. This includes GemmPyNode for handling GEMM parameters, WGMMA checks for Hopper architecture, and FFI registrations for layout inference and lowering functions.
  • Enhanced Python-C++ FFI for Target Utilities: Expanded the FFI (Foreign Function Interface) to expose various target-related utility functions (e.g., TargetIsCuda, TargetGetWarpSize) from C++ to Python, enabling more flexible and dynamic target-specific optimizations from the Python frontend.
  • Refactoring and Code Cleanup: Removed the toPrimeFactors utility function from src/op/gemm.cc and src/op/gemm_sp.cc, simplifying the codebase. Improved the GetArchInt parsing logic in src/target/utils.cc for robustness.
  • Updated MMA Macro Generator: The TensorCoreIntrinEmitter in tilelang/intrinsics/mma_macro_generator.py now includes a thread_var parameter and a get_thread_binding method, allowing ldmatrix and stmatrix macros to retrieve thread bindings more directly, reducing reliance on T.KernelLaunchFrame.Current().
  • Improved GemmWarpPolicy Integration: The GemmWarpPolicy object in tilelang/ir.py has been updated to expose its internal policy_type, m_warp, and n_warp attributes, and includes a compute_warp_partition method that calls into the C++ backend via FFI.
  • New tileop Python Package: A new Python package tilelang/tileop has been introduced, specifically for Python-defined tile operations. It currently contains the GemmPy class, which defines the Python-side logic for infer_layout and lower for the new gemm_v2 operation.
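The thread_var plumbing described above can be sketched with a plain-Python stand-in. ThreadBindingMixin and the current_frame_binding callback are hypothetical names used for illustration; the real emitter resolves a TIR thread variable, but the fallback logic is the same shape:

```python
class ThreadBindingMixin:
    """Sketch: prefer an explicitly supplied thread variable, else fall back
    to the active kernel-launch frame (the old T.KernelLaunchFrame lookup)."""

    def __init__(self, thread_var=None):
        self.thread_var = thread_var

    def get_thread_binding(self, current_frame_binding=lambda: "threadIdx.x"):
        if self.thread_var is not None:
            return self.thread_var  # caller-provided binding wins
        # fallback mirrors the previous implicit frame lookup
        return current_frame_binding()
```

Passing thread_var at construction lets the C++ Lower path hand the emitter its own binding instead of relying on frame state.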
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature / Command / Description

  • Code Review — /gemini review: performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help — /gemini help: displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces an experimental, Python-defined gemm_v2 operator to improve backend integration and reduce compilation times. The changes are extensive, adding new C++ and Python files for the operator, and exposing more C++ utilities to Python via FFI. The overall structure is sound. My review focuses on correctness and maintainability, and I've identified a few critical bugs and opportunities for refactoring to improve code quality.

equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_B) &&

critical

There appears to be a copy-paste error in SEqualReduce. offset_A is being compared with other->offset_B instead of other->offset_A. This will cause incorrect behavior in structural equality checks, which can lead to subtle bugs in TIR passes.

Suggested change
-           equal(offset_A, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&

Comment on lines 70 to 71
block_row_warps=self.M,
block_col_warps=self.N,

critical

The TensorCoreIntrinEmitter is initialized with block_row_warps=self.M and block_col_warps=self.N. However, self.M and self.N are matrix dimensions, while block_row_warps and block_col_warps are expected to be warp counts. The computed m_warp and n_warp should be used instead. This is a critical bug that will lead to incorrect behavior.

Suggested change
-                block_row_warps=self.M,
-                block_col_warps=self.N,
+                block_row_warps=m_warp,
+                block_col_warps=n_warp,

Comment on lines 99 to 100
block_row_warps=self.M,
block_col_warps=self.N,

critical

The TensorCoreIntrinEmitter is initialized with block_row_warps=self.M and block_col_warps=self.N. However, self.M and self.N are matrix dimensions, while block_row_warps and block_col_warps are expected to be warp counts. The computed m_warp and n_warp should be used instead. This is a critical bug that will lead to incorrect behavior.

Suggested change
-                block_row_warps=self.M,
-                block_col_warps=self.N,
+                block_row_warps=m_warp,
+                block_col_warps=n_warp,

Comment on lines +207 to +218
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}

medium

The function GetArchInt is duplicated from src/target/utils.cc. Since gemm_py.cc includes ../target/utils.h, this local static version is not needed. It also appears to be dead code as it's not called within this file. Please remove this duplicated function to improve maintainability.

Comment on lines 57 to 84
    def infer_layout(self, target: Target, thread_nums: int):
        if target_is_cuda(target):
            # TODO(lei): Support more cuda architectures, now mma only
            # Now only implement ssr layout
            m_warp, n_warp = self.policy.compute_warp_partition(
                self.M, self.N, thread_nums, target, False)
            warp_row_tiles = m_warp * 16
            warp_col_tiles = n_warp * 16
            mma_emitter = TensorCoreIntrinEmitter(
                a_dtype=self.in_dtype,
                b_dtype=self.in_dtype,
                accum_dtype=self.accum_dtype,
                a_transposed=self.trans_A,
                b_transposed=self.trans_B,
                block_row_warps=self.M,
                block_col_warps=self.N,
                warp_row_tiles=warp_row_tiles,
                warp_col_tiles=warp_col_tiles,
                chunk=self.chunk,
            )
            layout_map = {
                self.A: make_swizzled_layout(self.A),
                self.B: make_swizzled_layout(self.B),
                self.C: mma_emitter.make_mma_store_layout(self.C),
            }
            return layout_map
        else:
            raise ValueError(f"Unsupported target: {target}")


medium

The logic for calculating warp partitions and initializing TensorCoreIntrinEmitter is duplicated in infer_layout and lower. This could be refactored into a private helper method to improve maintainability and avoid potential inconsistencies between the two methods.
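A minimal sketch of the suggested refactor. The helper name make_emitter_config and the injected policy callback are hypothetical; the point is that both infer_layout and lower would consume the same computed kwargs, so the two paths can never drift apart:

```python
def make_emitter_config(M, N, thread_nums, compute_warp_partition, chunk):
    """Compute warp partition and emitter kwargs once, shared by both methods."""
    m_warp, n_warp = compute_warp_partition(M, N, thread_nums)
    return {
        "block_row_warps": m_warp,      # warp counts, not matrix dimensions
        "block_col_warps": n_warp,
        "warp_row_tiles": m_warp * 16,  # mirrors the snippet under review
        "warp_col_tiles": n_warp * 16,
        "chunk": chunk,
    }
```

Each method would then do `TensorCoreIntrinEmitter(..., **self._emitter_config(target, thread_nums))`, which also fixes the block_row_warps/block_col_warps bug flagged above in one place.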

Comment on lines +186 to +359
def gemm_v2(
    A: Union[tir.Buffer, tir.Var],
    B: Union[tir.Buffer, tir.Var],
    C: Union[tir.Buffer, tir.Var],
    transpose_A: bool = False,
    transpose_B: bool = False,
    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
    clear_accum: bool = False,
    k_pack: int = 1,
    wg_wait: int = 0,
):
    """Perform a General Matrix Multiplication (GEMM) operation.

    This function computes C = A @ B where A and B can optionally be transposed.
    The operation supports various warp policies and accumulation modes.

    Args:
        A (Union[tir.Buffer, tir.Var]): First input matrix
        B (Union[tir.Buffer, tir.Var]): Second input matrix
        C (Union[tir.Buffer, tir.Var]): Output matrix for results
        transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
        policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
        clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
        k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
        wg_wait (int, optional): Warp group wait count. Defaults to 0.

    Returns:
        tir.Call: A handle to the GEMM operation

    Raises:
        AssertionError: If the K dimensions of matrices A and B don't match
    """

    def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
        """Convert let-bound variables to their corresponding buffers.

        Args:
            arg (Union[tir.Buffer, tir.Var]): Input argument to legalize

        Returns:
            Union[tir.Buffer, tir.Var]: The legalized argument
        """
        if isinstance(arg, tir.Var) and T.has_let_value(arg):
            return T.get_let_value(arg).buffer
        return arg

    A = legalize_arguments(A)
    B = legalize_arguments(B)
    C = legalize_arguments(C)

    def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
        if isinstance(object, tir.Buffer):
            return object.shape
        elif isinstance(object, tir.BufferRegion):
            region = object.region
            shape = []
            for r in region:
                shape.append(r.extent)
            return shape
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
        if isinstance(object, tir.Buffer):
            strides = []
            stride = 1
            for s in reversed(object.shape):
                strides.insert(0, stride)
                stride *= s
            return strides
        elif isinstance(object, tir.BufferRegion):
            buffer, _ = object.buffer, object.region
            strides = []
            stride = 1
            for s in reversed(buffer.shape):
                strides.insert(0, stride)
                stride *= s
            return strides
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    A_shape = retrieve_shape(A)
    B_shape = retrieve_shape(B)
    C_shape = retrieve_shape(C)

    A_stride = retrieve_stride(A)
    B_stride = retrieve_stride(B)

    assert len(C_shape) == 2, "current only support C as a 2D tensor"
    assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
    assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
    if len(A_shape) > 2:
        for i in range(len(A_shape) - 2):
            assert A_shape[i] == 1, \
                "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
    if len(B_shape) > 2:
        for i in range(len(B_shape) - 2):
            assert B_shape[i] == 1, \
                "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"

    M, N = C_shape
    K = A_shape[-2] if transpose_A else A_shape[-1]
    K_B = B_shape[-1] if transpose_B else B_shape[-2]
    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"

    stride_a = A_stride[-2]
    stride_b = B_stride[-2]

    def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
                     access_type: str = "r") -> tir.PrimExpr:
        if isinstance(object, tir.Buffer):
            return object.access_ptr(access_type)
        elif isinstance(object, tir.BufferRegion):
            buffer, region = object.buffer, object.region
            indices = []
            for r in region:
                indices.append(r.min)
            strides = []
            stride = 1
            for s in reversed(buffer.shape):
                strides.insert(0, stride)
                stride *= s
            offset = 0
            # not offset the last two dimension
            for i in range(len(indices) - 2):
                offset += indices[i] * strides[i]
            return buffer.access_ptr(access_mask=access_type, offset=offset)
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
        """Retrieve the offset of the buffer or buffer region."""
        if isinstance(object, tir.Buffer):
            return [0] * len(object.shape)
        elif isinstance(object, tir.BufferRegion):
            _, region = object.buffer, object.region
            indices = []
            for r in region:
                indices.append(r.min)
            return indices
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    A_offset = retrieve_offset(A)
    B_offset = retrieve_offset(B)
    assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
    assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
    offset_a = A_offset[-1]
    offset_b = B_offset[-1]

    Aptr = retrieve_ptr(A, "r")
    Bptr = retrieve_ptr(B, "r")
    Cptr = retrieve_ptr(C, "rw")
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.gemm_py"),
        Aptr,
        Bptr,
        Cptr,
        transpose_A,
        transpose_B,
        M,
        N,
        K,
        policy,
        clear_accum,
        stride_a,
        stride_b,
        offset_a,
        offset_b,
        k_pack,
        wg_wait,
    )

medium

The new gemm_v2 function duplicates a significant amount of code from the existing gemm function, including the helper functions legalize_arguments, retrieve_shape, retrieve_stride, retrieve_ptr, and retrieve_offset. To improve maintainability and reduce redundancy, this common logic should be refactored into a shared helper function. The gemm and gemm_v2 functions would then call this helper and only differ in the final intrinsic they call (tl.gemm vs tl.gemm_py).

…nd functionality

- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
- Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
- Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.
- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.
- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.
- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
- Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
- Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
- Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.
@coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/layout/gemm_layouts.cc (1)

200-206: Add missing precondition checks for block_m and block_n divisibility.

This function uses block_m / warp_m and block_n / warp_n as repeat/replicate counts (Lines 208 and 215), but unlike makeGemmFragmentA/C, there are no ICHECKs ensuring exact divisibility. If not multiples, layouts silently truncate and produce incorrect tiling.

Apply this diff near the existing ICHECKs:

   // transposed
   ICHECK(warp_n % 8 == 0);
   ICHECK(block_k % 16 == 0);
+  ICHECK(block_m % warp_m == 0) << "block_m=" << block_m << ", warp_m=" << warp_m;
+  ICHECK(block_n % warp_n == 0) << "block_n=" << block_n << ", warp_n=" << warp_n;
src/transform/inject_pipeline.cc (1)

53-69: Recompute reads/writes also for reused BlockRealize blocks

When body is a BlockRealize with predicate true, the early return skips recomputing access regions. If the block’s body was rewritten earlier, reads/writes can be stale. Compute and set them in both paths.

 Block MakeBlock(const Stmt &body,
                 const Map<Var, Buffer> &buffer_data_to_buffer) {
   if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
     if (is_one(block_realize->predicate)) {
-      // no need to create a new block
-      return block_realize->block;
+      // reuse block but refresh access regions
+      Block block = block_realize->block;
+      Array<Array<BufferRegion>> access =
+          GetBlockReadWriteRegion(block, buffer_data_to_buffer);
+      BlockNode* n = block.CopyOnWrite();
+      n->reads = access[0];
+      n->writes = access[1];
+      return block;
     }
   }
   Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
               /*body*/ body);
   Array<Array<BufferRegion>> access =
       GetBlockReadWriteRegion(block, buffer_data_to_buffer);
   BlockNode *n = block.CopyOnWrite();
   n->reads = access[0];
   n->writes = access[1];
   return block;
 }
♻️ Duplicate comments (4)
src/op/gemm_py.cc (1)

208-219: Remove duplicated GetArchInt (already provided by utils.h).

This local static is dead code here and duplicates src/target/utils.cc. Please delete.

Apply:

-/**
- * @brief Parse and return the numeric GPU architecture from a Target's "arch"
- * attribute.
- * ...
- */
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
src/op/gemm_py.h (1)

71-71: Fix SEqualReduce comparison for offset_A (copy-paste bug).

offset_A is compared against other->offset_B. This breaks structural equality and downstream memoization/passes.

-           equal(offset_A, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
tilelang/language/gemm.py (1)

185-359: Deduplicate gemm and gemm_v2 by extracting shared helpers

This block duplicates almost all of gemm(). Factor the common logic into a private helper returning the prepared arguments; gemm() and gemm_v2 then differ only by the intrinsic op string.

Here’s a minimal sketch:

+from dataclasses import dataclass
+
+@dataclass
+class _GemmArgs:
+    Aptr: tir.PrimExpr
+    Bptr: tir.PrimExpr
+    Cptr: tir.PrimExpr
+    M: int
+    N: int
+    K: int
+    stride_a: int
+    stride_b: int
+    offset_a: int
+    offset_b: int
+
+def _prepare_gemm_args(A, B, C, transpose_A, transpose_B) -> _GemmArgs:
+    # move legalize_arguments/retrieve_shape/retrieve_stride/retrieve_ptr/retrieve_offset here
+    # and return _GemmArgs(...)
+    ...
+
 def gemm(...):
-    # duplicated helpers here
-    ...
-    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), ...)
+    args = _prepare_gemm_args(A, B, C, transpose_A, transpose_B)
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.gemm"),
+        args.Aptr, args.Bptr, args.Cptr,
+        transpose_A, transpose_B,
+        args.M, args.N, args.K, policy, clear_accum,
+        args.stride_a, args.stride_b, args.offset_a, args.offset_b, k_pack, wg_wait,
+    )
 
 def gemm_v2(...):
-    # duplicated helpers here
-    ...
-    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm_py"), ...)
+    args = _prepare_gemm_args(A, B, C, transpose_A, transpose_B)
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.gemm_py"),
+        args.Aptr, args.Bptr, args.Cptr,
+        transpose_A, transpose_B,
+        args.M, args.N, args.K, policy, clear_accum,
+        args.stride_a, args.stride_b, args.offset_a, args.offset_b, k_pack, wg_wait,
+    )

Bonus: rename helper parameters from “object” to “obj” to avoid shadowing the built-in.

tilelang/tileop/gemm/__init__.py (1)

56-75: Critical past bug fixed: using m_warp/n_warp (not M/N) for block warps.

Emitter now receives block_row_warps=m_warp and block_col_warps=n_warp. This resolves the earlier misconfiguration.

Also applies to: 105-125

🧹 Nitpick comments (30)
src/layout/gemm_layouts.cc (1)

325-331: Make xor8x8 parameter type consistent with xor2x2/xor4x4.

Minor: j is passed by value here while other helpers use const PrimExpr&. Align for consistency and to avoid an extra copy.

-PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
+PrimExpr xor8x8(const PrimExpr &i, const PrimExpr &j) {
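A quick host-side check (my own sketch, using plain ints rather than PrimExpr) confirms the property these xor helpers rely on: xor-with-row permutes each row of the 8x8 tile, so the swizzle never aliases two elements onto the same bank slot.

```python
def xor8x8(i: int, j: int) -> int:
    # Swizzle the column index by XOR with the row index (bank-conflict avoidance).
    return j ^ i

# The swizzle must be a bijection on the 8x8 tile: every (row, swizzled col)
# pair is hit exactly once.
cells = {(i, xor8x8(i, j)) for i in range(8) for j in range(8)}
assert len(cells) == 64
# Within each row it is a permutation of the 8 columns.
for i in range(8):
    assert sorted(xor8x8(i, j) for j in range(8)) == list(range(8))
print("xor8x8 is a per-row permutation")
```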
tilelang/layout/swizzle.py (1)

8-9: Clarify comment: helper doesn’t toggle swizzling

The function always builds a swizzled layout; it doesn’t decide enable/disable based on TMA. Reword to state that call sites/schedules decide when to use this layout.

src/target/utils.cc (1)

21-26: Stricter arch parsing looks good; consider broader formats

Current check enforces the "sm_" prefix. If upstream targets ever provide "sm90"/"compute_90", this will fail the ICHECK. If that's possible in your pipeline, normalize the accepted forms before stoi.
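A sketch of the kind of normalization meant here, as a Python analogue of the C++ helper; the accepted spellings below are assumptions, not the current behavior:

```python
import re

def parse_arch_int(arch: str) -> int:
    """Extract the numeric compute capability from 'sm_90', 'sm90' or 'compute_90'."""
    m = re.fullmatch(r"(?:sm|compute)_?(\d+[a-z]?)", arch)
    if m is None:
        raise ValueError(f"unrecognized arch string: {arch!r}")
    # Strip a trailing letter suffix such as '90a' -> 90.
    return int(re.match(r"\d+", m.group(1)).group())

assert parse_arch_int("sm_90") == 90
assert parse_arch_int("sm90a") == 90
assert parse_arch_int("compute_80") == 80
```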

src/target/codegen_cuda.cc (1)

1334-1344: Remove commented-out legacy asm path

The commented code for need_cast_smem_ptr_to_int_/PrintLoadMatrixAssembly adds noise. Clean up.

-      // need_cast_smem_ptr_to_int_ = true;
-      // this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
-      //                                         local_elem_offset, smem_ptr,
-      //                                         smem_elem_offset);
-      std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
+      std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
       if (trans == 1)
         func_name += "_trans";
-      // this->stream << func_name << "(" << local_ptr "" << ", " << smem_ptr << ");\n";
       this->PrintIndent();
this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset << ", " << local_ptr << " + " << local_elem_offset << ");\n";
src/transform/lower_tile_op.cc (2)

306-312: Rename old_shape/new_shape for clarity

The naming is easy to misread: old_shape (line 306) holds the original load->buffer->shape and new_shape (line 326) holds the remapped new_buffer->shape, which is consistent. Note that old_shape is used twice: once to linearize the original indices into an element offset (lines 317-320) and again to decompose that offset back into multi-dimensional indices (lines 341-345). That is correct, since the forward mapping at line 348 expects indices in the original buffer's coordinate system, not the new one.

The current implementation appears correct despite the confusing naming. Consider renaming for clarity:

-      Array<PrimExpr> old_shape = load->buffer->shape;
+      Array<PrimExpr> original_shape = load->buffer->shape;

-      CHECK_EQ(indices.size(), old_shape.size())
+      CHECK_EQ(indices.size(), original_shape.size())
          << "Indices size and shape size must match for general N-dimensional "
             "buffer "
          << "but got indices size: " << indices.size()
-          << " and shape size: " << old_shape.size();
+          << " and shape size: " << original_shape.size();

358-361: Simplify new_indices computation by using the remapped shape directly

The decomposition of new_offset into new_indices can be done more directly since you already have new_shape available.

       Array<PrimExpr> new_indices;
+      PrimExpr remaining = new_offset;
       for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
-        new_indices.insert(new_indices.begin(),
-                           floormod(new_offset, new_shape[i]));
-        new_offset = floordiv(new_offset, new_shape[i]);
+        new_indices.insert(new_indices.begin(),
+                           floormod(remaining, new_shape[i]));
+        remaining = floordiv(remaining, new_shape[i]);
       }
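The linearize/decompose pair this pass relies on is easy to sanity-check in isolation; a minimal Python model of the row-major arithmetic (my own, not the pass itself):

```python
from functools import reduce

def linearize(indices, shape):
    # Row-major: offset = ((i0 * s1 + i1) * s2 + i2) ...
    return reduce(lambda acc, pair: acc * pair[1] + pair[0], zip(indices, shape), 0)

def decompose(offset, shape):
    # Inverse of linearize: peel dimensions from the innermost axis outward.
    indices = []
    for extent in reversed(shape):
        indices.insert(0, offset % extent)
        offset //= extent
    return indices

old_shape = [4, 6, 8]
new_shape = [8, 24]          # same element count, different rank
idx = [3, 5, 7]
off = linearize(idx, old_shape)
assert decompose(off, old_shape) == idx                        # roundtrip through the old shape
assert linearize(decompose(off, new_shape), new_shape) == off  # remapping preserves the offset
```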
tilelang/utils/target.py (2)

8-15: Nit: fix constant name typo.

Rename AVALIABLE_TARGETS → AVAILABLE_TARGETS for clarity, and update references.

Apply:

-AVALIABLE_TARGETS = {
+AVAILABLE_TARGETS = {
...
-            target, Target) or target in AVALIABLE_TARGETS, f"Target {target} is not supported"
+            target, Target) or target in AVAILABLE_TARGETS, f"Target {target} is not supported"

82-84: Avoid re-wrapping an existing Target.

If return_var is already a Target, re-constructing it is unnecessary.

Apply:

-    if return_object:
-        return Target(return_var)
+    if return_object:
+        return return_var if isinstance(return_var, Target) else Target(return_var)
tilelang/intrinsics/mma_layout.py (1)

66-69: Clear public aliases.

Aliases expose SR/RS A/B variants succinctly. Consider adding one-line docstrings for discoverability.

src/op/gemm_py.cc (4)

95-110: Duplicate kernel-selection logic with GemmNode.

GetGemmInst mirrors src/op/gemm.cc. Prefer a shared helper to avoid divergence.


142-192: Duplicate WGMMA gate with GemmNode::CheckWGMMA.

Keep a single source of truth (e.g., extract to a common helper) to prevent subtle mismatches.


221-226: Drop unused pre-computation in Lower.

block_size/gemm_inst/warp partition are unused; remove to reduce confusion.

Apply:

-Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  auto block_size = *as_const_int(T.thread_bounds->extent);
-  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
-  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
+Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

15-15: Remove unused include.

tvm/ffi/string.h isn’t used.

Apply:

-#include "tvm/ffi/string.h"
src/op/gemm_py.h (2)

39-61: Ensure reflection is registered exactly once.

RegisterReflection() defines bindings but isn't obviously invoked. Please confirm it's called at init (e.g., from gemm_py.cc) to avoid missing reflection in Python/printing.

Example (in src/op/gemm_py.cc):

// At file scope
namespace {
struct GemmPyNodeReflectionRegister {
  GemmPyNodeReflectionRegister() { GemmPyNode::RegisterReflection(); }
} _gemm_py_node_reflection_register;
}  // namespace

17-17: Avoid using namespace in a header.

Headers should not inject wide namespaces. It's redundant here since members are already qualified.

-using namespace tir;
tilelang/__init__.py (1)

110-110: Remove unused noqa directive (RUF100).

# noqa: F401 is unnecessary here and triggers Ruff’s RUF100.

-from . import tileop  # noqa: F401
+from . import tileop
tilelang/engine/phase.py (1)

69-78: Docstring is now stale; mention LetInline instead of “frontend legalization”

Update the pipeline bullets to reflect the inlining step to avoid confusion for users reading this docstring.

Apply this diff to refresh the wording:

-    - Legalizes frontend Tile IR into TVM-compatible constructs.
+    - Inlines let expressions/statements (LetInline) to normalize IR.
tilelang/tileop/__init__.py (1)

1-1: Remove redundant noqa

Ruff reports RUF100; the F401 rule isn’t enabled, so the directive is unnecessary here.

-from .gemm import GemmPy  # noqa: F401
+from .gemm import GemmPy
tilelang/language/__init__.py (1)

46-46: Nit: unnecessary noqa

Same RUF100 note as elsewhere; consider dropping the directive or declaring explicit all if you want to keep re-export semantics without noqa.

-from .gemm import GemmWarpPolicy, gemm, gemm_v2  # noqa: F401
+from .gemm import GemmWarpPolicy, gemm, gemm_v2
tilelang/transform/__init__.py (1)

5-5: Publicly exposing LetInline looks good; minor noqa nit

The re-export is correct and aligns with the new pipeline; consider removing the redundant noqa to satisfy RUF100.

-from .simplify import Simplify, simplify_prim_func, LetInline  # noqa: F401
+from .simplify import Simplify, simplify_prim_func, LetInline
tilelang/ir.py (1)

35-39: Add return type for compute_warp_partition for clarity

Returning a Tuple[int, int] helps callers and type checkers; no behavior change.

-from tilelang import _ffi_api
+from tilelang import _ffi_api
+from typing import Tuple
@@
-    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
-                               is_wgmma: bool):
+    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
+                               is_wgmma: bool) -> Tuple[int, int]:
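For intuition, a toy Python model of what such an (m_warp, n_warp) partition returns; the real logic lives in GemmWarpPolicyNode::ComputeWarpPartition and also honors the policy kind, which this sketch ignores:

```python
def compute_warp_partition(M: int, N: int, block_size: int, warp_size: int = 32):
    """Split the block's warps into an (m_warp, n_warp) grid, preferring a
    factorization whose aspect ratio tracks M/N. Toy model only."""
    num_warps = block_size // warp_size
    best = (1, num_warps)
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp:
            continue
        n_warp = num_warps // m_warp
        # Score by how well the warp grid's aspect ratio matches the tile's.
        if abs(m_warp / n_warp - M / N) < abs(best[0] / best[1] - M / N):
            best = (m_warp, n_warp)
    return best

assert compute_warp_partition(128, 128, 128) == (2, 2)   # 4 warps, square tile
assert compute_warp_partition(256, 64, 128) == (4, 1)    # tall tile, tall warp grid
```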
tilelang/transform/simplify.py (2)

30-48: Return-type contract of _Simplify/apply_simplify is misleading

Both branches return a PrimFunc even when input is an IRModule. Either (a) document that PrimFunc is always returned, or (b) return IRModule when given IRModule. I suggest tightening types to PrimFunc to avoid surprising callers.

-from typing import Union, Callable
+from typing import Union, Callable
@@
-def _Simplify(stmt: Union[PrimFunc, IRModule],
-              inline_let: bool = False) -> Union[PrimFunc, IRModule]:
+def _Simplify(stmt: Union[PrimFunc, IRModule],
+              inline_let: bool = False) -> PrimFunc:
@@
-    elif isinstance(stmt, IRModule):
+    elif isinstance(stmt, IRModule):
@@
-        assert len(mod.functions) == 1, "Simplify should return a single function"
-        return list(mod.functions.values()).pop()
+        assert len(mod.functions) == 1, "Simplify should return a single function"
+        return list(mod.functions.values()).pop()

If you prefer (b), I can provide a patch that preserves the input type instead.


62-66: Expose the return type in apply_simplify docstring/signature

Keep it consistent with _Simplify if you adopt the suggestion above.

-def apply_simplify(stmt: Union[PrimFunc, IRModule],
-                   inline_let: bool = False) -> Union[PrimFunc, IRModule]:
-    """Apply Simplify pass to a PrimFunc or IRModule."""
+def apply_simplify(stmt: Union[PrimFunc, IRModule],
+                   inline_let: bool = False) -> PrimFunc:
+    """Apply Simplify pass and return a single PrimFunc."""
tilelang/language/gemm.py (4)

247-248: Prefer TypeError for invalid type arguments (matches Ruff hints)

These branches validate types; TypeError is more appropriate than ValueError.

-        else:
-            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+        else:
+            raise TypeError(f"Unsupported argument type: {type(object)} for buffer {object}")
@@
-        else:
-            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+        else:
+            raise TypeError(f"Unsupported argument type: {type(object)} for buffer {object}")
@@
-        else:
-            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+        else:
+            raise TypeError(f"Unsupported argument type: {type(object)} for buffer {object}")
@@
-        else:
-            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+        else:
+            raise TypeError(f"Unsupported argument type: {type(object)} for buffer {object}")

Also applies to: 266-267, 315-316, 328-329


185-196: Minor naming and readability nits

  • “experimental currently, for fast compilation” is vague; consider adding a brief note on what’s missing (e.g., supported dtypes/backends) and stability expectations.
  • Consider adding a deprecation/transition note if gemm_v2 is intended to replace gemm later.

220-232: Tiny style: avoid param name “object”

Using “object” as a parameter shadows the built-in. Rename to “obj” in the extracted helpers when you deduplicate.


1-359: Cross-file: potential SEqualReduce bug in C++ GemmPyNode (offset fields)

Unrelated to this Python file, but worth fixing before it bites equality/caching: in src/op/gemm_py.h, SEqualReduce compares offset_A to other->offset_B twice. That likely should be offset_A vs offset_A, and offset_B vs offset_B.

I can open a follow-up PR; suggested patch:

-           equal(offset_A, other->offset_B) &&
-           equal(offset_B, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
+           equal(offset_B, other->offset_B) &&
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)

470-505: Main guard: keep tests hermetic for CI.

Directly running run_gemm_ss(...) in __main__ can hide CI issues and cause flakiness on machines without GPUs. Prefer leaving only tilelang.testing.main() here.

Apply:

-    # tilelang.testing.main()
-    tilelang.disable_cache()
+    tilelang.testing.main()
+    # tilelang.disable_cache()
@@
-    run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
+    # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
tilelang/intrinsics/mma_macro_generator.py (1)

429-444: make_mma_load_layout: flexible SR/RS construction is good; minor robustness nits.

  • Good: clean separation of A vs B and axis-order, and use of IndexMap.inverse/forward.
  • Nit: the ValueError messages are long; Ruff TRY003 flags these—consider shorter messages or custom exception types.

Apply (representative):

-            raise ValueError(f"Unsupported matrix {matrix}")
+            raise ValueError("Unsupported matrix type")

Also applies to: 487-516

tilelang/tileop/gemm/__init__.py (1)

138-170: Lowering bodies: local buffer sizes and KI loop are consistent; one naming nit.

  • Local sizes use warp_rows * local_size_a and warp_cols * local_size_b—consistent with ldmatrix/mma strides.
  • Nit: two functions named _gemm_rsr (rs and rr cases). Different scopes, but renaming the rr one to _gemm_rrr would improve grepability.
-                def _gemm_rsr() -> None:
+                def _gemm_rrr() -> None:
@@
-                return _Simplify(_gemm_rsr, inline_let=True)
+                return _Simplify(_gemm_rrr, inline_let=True)

Also applies to: 171-198, 199-224, 225-242

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bcfc834 and e36740d.

📒 Files selected for processing (27)
  • 3rdparty/tvm (1 hunks)
  • src/layout/gemm_layouts.cc (1 hunks)
  • src/op/gemm.cc (1 hunks)
  • src/op/gemm_py.cc (1 hunks)
  • src/op/gemm_py.h (1 hunks)
  • src/op/gemm_sp.cc (0 hunks)
  • src/target/codegen_cuda.cc (3 hunks)
  • src/target/utils.cc (2 hunks)
  • src/transform/frontend_legalize.cc (2 hunks)
  • src/transform/inject_pipeline.cc (2 hunks)
  • src/transform/lower_tile_op.cc (2 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6 hunks)
  • testing/python/transform/test_tilelang_transform_let_inline.py (1 hunks)
  • tilelang/__init__.py (1 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/intrinsics/mma_layout.py (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (16 hunks)
  • tilelang/ir.py (2 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/gemm.py (1 hunks)
  • tilelang/language/kernel.py (1 hunks)
  • tilelang/layout/swizzle.py (1 hunks)
  • tilelang/tileop/__init__.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (1 hunks)
  • tilelang/transform/__init__.py (1 hunks)
  • tilelang/transform/simplify.py (3 hunks)
  • tilelang/utils/target.py (2 hunks)
💤 Files with no reviewable changes (1)
  • src/op/gemm_sp.cc
🧰 Additional context used
🧬 Code graph analysis (19)
tilelang/engine/phase.py (2)
src/transform/frontend_legalize.cc (2)
  • LetInline (85-90)
  • LetInline (85-85)
tilelang/transform/simplify.py (1)
  • LetInline (8-16)
testing/python/transform/test_tilelang_transform_let_inline.py (2)
src/transform/frontend_legalize.cc (2)
  • LetInline (85-90)
  • LetInline (85-85)
tilelang/transform/simplify.py (1)
  • LetInline (8-16)
tilelang/language/__init__.py (1)
tilelang/language/gemm.py (2)
  • gemm (9-182)
  • gemm_v2 (186-359)
tilelang/ir.py (3)
src/op/gemm_py.h (1)
  • tvm (13-124)
src/op/gemm_sp.h (1)
  • tvm (13-95)
tilelang/language/ast/ir.py (1)
  • target (1682-1713)
tilelang/tileop/__init__.py (4)
tilelang/language/gemm.py (1)
  • gemm (9-182)
src/op/gemm_py.cc (1)
  • GemmPy (50-80)
tilelang/tileop/gemm/__init__.py (1)
  • GemmPy (31-270)
src/op/gemm_py.h (1)
  • GemmPy (116-121)
tilelang/transform/__init__.py (2)
tilelang/transform/simplify.py (1)
  • LetInline (8-16)
src/transform/frontend_legalize.cc (2)
  • LetInline (85-90)
  • LetInline (85-85)
src/transform/frontend_legalize.cc (3)
src/transform/lower_tile_op.cc (2)
  • f (210-240)
  • f (210-210)
src/transform/legalize_vectorized_loop.cc (2)
  • f (49-58)
  • f (49-49)
tilelang/transform/simplify.py (1)
  • LetInline (8-16)
tilelang/language/kernel.py (2)
tilelang/language/frame.py (1)
  • Current (153-162)
tilelang/intrinsics/mma_macro_generator.py (1)
  • get_thread_binding (117-123)
src/transform/inject_pipeline.cc (1)
tilelang/language/ast/ir.py (1)
  • block (342-358)
tilelang/transform/simplify.py (2)
src/transform/frontend_legalize.cc (2)
  • LetInline (85-90)
  • LetInline (85-85)
src/transform/simplify.cc (2)
  • Simplify (481-489)
  • Simplify (481-481)
src/op/gemm_py.cc (3)
tilelang/tileop/gemm/__init__.py (1)
  • GemmPy (31-270)
src/op/gemm.cc (10)
  • Clone (89-92)
  • Clone (89-89)
  • GetGemmInst (94-108)
  • GetGemmInst (94-94)
  • CheckWGMMA (307-357)
  • CheckWGMMA (307-307)
  • Lower (399-435)
  • Lower (399-399)
  • InferLayout (456-604)
  • InferLayout (456-457)
src/target/utils.cc (8)
  • TargetGetWarpSize (114-119)
  • TargetGetWarpSize (114-114)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsCDNA (63-72)
  • TargetIsCDNA (63-63)
  • TargetIsCuda (11-13)
  • TargetIsCuda (11-11)
src/op/gemm.cc (1)
src/op/gemm.h (2)
  • RegisterReflection (34-40)
  • RegisterReflection (117-139)
src/transform/lower_tile_op.cc (3)
src/transform/atomicadd_vectorize.cc (2)
  • indices (83-131)
  • indices (83-83)
src/transform/loop_vectorize.cc (2)
  • indices (113-168)
  • indices (113-113)
src/transform/loop_vectorize_dynamic.cc (2)
  • indices (141-194)
  • indices (141-141)
tilelang/utils/target.py (1)
src/target/utils.cc (26)
  • TargetIsCuda (11-13)
  • TargetIsCuda (11-11)
  • TargetIsRocm (14-16)
  • TargetIsRocm (14-14)
  • TargetIsVolta (28-33)
  • TargetIsVolta (28-28)
  • TargetIsTuring (35-40)
  • TargetIsTuring (35-35)
  • TargetIsAmpere (42-47)
  • TargetIsAmpere (42-42)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsSM120 (56-61)
  • TargetIsSM120 (56-56)
  • TargetIsCDNA (63-72)
  • TargetIsCDNA (63-63)
  • TargetHasAsyncCopy (74-92)
  • TargetHasAsyncCopy (74-74)
  • TargetHasLdmatrix (93-98)
  • TargetHasLdmatrix (93-93)
  • TargetHasStmatrix (100-105)
  • TargetHasStmatrix (100-100)
  • TargetHasBulkCopy (107-112)
  • TargetHasBulkCopy (107-107)
  • TargetGetWarpSize (114-119)
  • TargetGetWarpSize (114-114)
src/op/gemm_py.h (2)
src/op/operator.h (2)
  • TileOperatorNode (55-95)
  • TileOperator (62-94)
src/op/gemm_py.cc (11)
  • CheckWGMMA (142-192)
  • CheckWGMMA (142-142)
  • Lower (221-252)
  • Lower (221-221)
  • InferLayout (254-269)
  • InferLayout (254-255)
  • Clone (90-93)
  • Clone (90-90)
  • GetGemmInst (95-110)
  • GetGemmInst (95-96)
  • GemmPy (50-80)
tilelang/tileop/gemm/__init__.py (8)
src/op/gemm_py.h (2)
  • tvm (13-124)
  • GemmPy (116-121)
tilelang/utils/target.py (1)
  • target_is_cuda (87-88)
tilelang/intrinsics/mma_macro_generator.py (10)
  • make_mma_store_layout (518-587)
  • make_mma_load_layout (374-516)
  • ldmatrix_a (165-209)
  • ldmatrix_a (675-779)
  • ldmatrix_b (211-257)
  • ldmatrix_b (781-893)
  • mma (259-317)
  • mma (895-944)
  • mma (949-1047)
  • mma (1052-1151)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-16)
tilelang/ir.py (2)
  • GemmWarpPolicy (30-39)
  • compute_warp_partition (35-39)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
src/op/gemm_py.cc (1)
  • GemmPy (50-80)
tilelang/intrinsics/mma_macro_generator.py (4)
tilelang/language/kernel.py (4)
  • get_thread_binding (151-156)
  • get_thread_binding (261-265)
  • KernelLaunchFrame (72-206)
  • Current (115-121)
tilelang/intrinsics/mma_layout.py (2)
  • shared_16x32_to_mma_32x16_layout (72-74)
  • shared_32x16_to_mma_32x16_layout (77-79)
tilelang/intrinsics/utils.py (2)
  • shared_16x32_to_mma_32x16_layout (58-60)
  • shared_32x16_to_mma_32x16_layout (63-65)
tilelang/layout/fragment.py (2)
  • replicate (146-160)
  • repeat (123-144)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/gemm.py (1)
  • gemm_v2 (186-359)
tilelang/language/__init__.py (1)
  • annotate_layout (102-140)
tilelang/layout/fragment.py (1)
  • make_swizzled_layout (210-216)
tilelang/profiler/__init__.py (1)
  • assert_allclose (76-137)
tilelang/env.py (1)
  • disable_cache (232-233)
tilelang/language/gemm.py (4)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/primitives/gemm/base.py (1)
  • GemmWarpPolicy (8-188)
tilelang/language/frame.py (2)
  • has_let_value (189-198)
  • get_let_value (201-210)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
🪛 Ruff (0.12.2)
tilelang/language/__init__.py

46-46: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/__init__.py

110-110: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/tileop/__init__.py

1-1: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/transform/__init__.py

5-5: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/tileop/gemm/__init__.py

101-101: Avoid specifying long messages outside the exception class

(TRY003)


103-103: Avoid specifying long messages outside the exception class

(TRY003)


243-243: Avoid specifying long messages outside the exception class

(TRY003)


245-245: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_macro_generator.py

427-427: Avoid specifying long messages outside the exception class

(TRY003)


442-442: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Avoid specifying long messages outside the exception class

(TRY003)


514-514: Avoid specifying long messages outside the exception class

(TRY003)

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py

246-246: Redefinition of unused threads from line 1

(F811)


362-362: Redefinition of unused threads from line 1

(F811)

tilelang/language/gemm.py

247-247: Prefer TypeError exception for invalid type

(TRY004)


247-247: Avoid specifying long messages outside the exception class

(TRY003)


266-266: Prefer TypeError exception for invalid type

(TRY004)


266-266: Avoid specifying long messages outside the exception class

(TRY003)


315-315: Prefer TypeError exception for invalid type

(TRY004)


315-315: Avoid specifying long messages outside the exception class

(TRY003)


328-328: Prefer TypeError exception for invalid type

(TRY004)


328-328: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (33)
src/layout/gemm_layouts.cc (2)

208-208: Formatting-only change; no semantic impact. LGTM.

Chaining Replicate(...)->Repeat(...) on one line preserves the original call order and arguments. No behavior change.


215-215: Formatting-only change; no semantic impact. LGTM.

Same call chain and arguments as before; just collapsed for brevity.

3rdparty/tvm (1)

1-1: Submodule bump: verify CI, reproducibility, and downstream API sync.

  • CI workflows already invoke git submodule update --init --recursive for 3rdparty/tvm.
  • Record and pin commit eddefbd65acb7b1ea51dd18068b4049754c4fa7a in your third-party docs/changelog and update license headers.
  • Confirm this TVM commit defines and exports (via TVM_REGISTER_GLOBAL and function implementations) all FFI symbols used by GemmPy and gemm_v2—namely tl.GemmWarpPolicyComputeWarpPartition, TargetIsCuda, TargetHasLdmatrix, and TargetGetWarpSize.

Consider referencing a stable annotated TVM tag alongside the commit hash for easier future audits.

src/transform/inject_pipeline.cc (2)

678-684: LGTM: async scope annotation

Marking async producer blocks with tir::attr::async_scope is appropriate and consistent with later commit/wait handling.


954-958: LGTM: refresh block access regions post-mutation

Recomputing reads/writes here prevents stale region metadata after rewriting.

src/target/utils.cc (1)

121-150: FFI exposure set LGTM

The reflection registrations match the helpers and align with Python wrappers.

src/target/codegen_cuda.cc (2)

1263-1263: LGTM: add indentation before emitting MMA asm

Improves formatting; harmless to codegen.

Also applies to: 1299-1299


1129-1142: Verify how many and which args tl::ptx_ldmatrix(...) is called with in copy.cc; this determines whether print_extern_call_stmt(func_name, 2) is correct or needs adjustment.

src/transform/frontend_legalize.cc (2)

37-44: LGTM! Clean refactoring from FrontendLegalizer to LetInliner

The renaming from FrontendLegalizer to LetInliner better reflects the actual functionality of this pass, which inlines let bindings. The implementation remains unchanged, preserving the existing behavior.


85-94: Consistent naming updates throughout the pass registration

The pass registration and FFI reflection are properly updated to use the new LetInline naming, maintaining consistency with the class rename.

src/op/gemm.cc (2)

611-612: LGTM! Proper TVM operator registration for GemmWarpPolicy

The registration correctly sets up the tl.GemmWarpPolicy operator with an appropriate TScript printer name for better debugging and visualization.


614-625: Well-structured FFI bindings for GEMM reflection and warp partition computation

The FFI initialization block properly:

  1. Registers reflection for both GemmNode and GemmWarpPolicyNode classes
  2. Exposes the GemmWarpPolicyComputeWarpPartition function to Python
  3. Correctly delegates to the ComputeWarpPartition method on the policy object

This enables Python code to query warp partitioning decisions, which is essential for the experimental T.gemm_v2 operator mentioned in the PR objectives.

tilelang/utils/target.py (1)

87-136: LGTM: thin wrappers cleanly mirror C++ FFI.

The wrappers are minimal and keep Python in sync with backend feature queries.

tilelang/intrinsics/mma_layout.py (2)

50-57: A-layout split and transposed helper look correct.

Thread id and local index mapping are consistent; transposed variant is a simple swap. Looks good.


58-65: B-layout split mirrors A-layout as expected.

Mapping matches the intended orientation for B; transposed helper is correct.

src/op/gemm_py.cc (1)

271-276: Confirm TL op arity matches Python frontend

set_num_inputs(5) must align with how gemm_py is called in Python; please verify that the Python binding passes exactly five arguments.

testing/python/transform/test_tilelang_transform_let_inline.py (1)

10-10: Switch to LetInline looks correct.

The pass rename aligns with the new API and keeps the test intent intact.

tilelang/engine/phase.py (1)

88-89: Switching to LetInline at the start of LowerAndLegalize is sensible

Inlining lets before InjectAssumes/Simplify is a good move for later analyses and matches the new public pass. No functional concerns from me.

tilelang/transform/simplify.py (1)

8-17: LetInline wrapper: LGTM

The pass exposure matches the C++ binding and keeps naming consistent across layers.

tilelang/language/gemm.py (1)

340-359: Registration of tl.gemm_py confirmed: the op is registered in src/op/gemm_py.cc via

TIR_REGISTER_TL_OP(GemmPy, gemm_py)

and has associated FFI hooks for .lower and .infer_layout.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)

35-37: Shared allocs: consistent explicit scopes look good.

Explicit scope="shared" and fragment allocs are consistent across variants. No issues spotted.

Also applies to: 149-151, 263-266, 380-384


48-49: Correctly switched to T.gemm_v2.

Calls into T.gemm_v2 are wired with expected buffer kinds (shared/fragment) and transpose flags. Good.

Also applies to: 167-168, 281-282, 401-402


154-156: Layout annotations are coherent with rs/sr/rr paths.

Swizzling the shared operand and keeping the fragment operand un-swizzled matches the gemm_v2 layout inference paths. Looks good.

Also applies to: 268-270, 386-389

tilelang/intrinsics/mma_macro_generator.py (6)

53-79: Thread binding override via thread_var is a solid addition.

Gracefully falls back to the current Kernel frame when thread_var is absent, enabling external control in lowering. Nice.

Also applies to: 117-124


179-206: ldmatrix_a: index/transpose coupling—double-check semantics.

You compute A_shared_buf_elem as [wk, wi] when a_transposed else [wi, wk], and pass trans=self.a_transposed to ptx_ldmatrix. This ties address orientation and the hardware transpose flag to the same boolean. Depending on your MMA variant (row.col), A often uses non-transposed LD for non-transposed A, but the required ldmatrix “trans” bit can be architecture- and layout-specific. Please confirm on a small kernel that both TN and NN paths load correct lanes for m16n8k16.

I can help craft a minimal kernel to validate lane mapping if needed.
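One cheap host-side check before writing a kernel: whichever (i, j) -> (thread, local) map the ldmatrix path assumes must be a bijection from the 16x16 fp16 tile onto 32 lanes x 8 registers. A standalone checker; the demo_layout below is illustrative only, not the map in mma_layout.py:

```python
def check_fragment_bijection(index_map, rows=16, cols=16, lanes=32, regs=8):
    """Verify index_map(i, j) -> (thread, local) covers lanes x regs exactly once."""
    seen = set()
    for i in range(rows):
        for j in range(cols):
            t, l = index_map(i, j)
            assert 0 <= t < lanes and 0 <= l < regs, (i, j, t, l)
            seen.add((t, l))
    assert len(seen) == lanes * regs, "mapping is not a bijection"

# An illustrative 16x16 -> 32x8 layout (assumed, for demonstration only).
def demo_layout(i, j):
    thread = 4 * (i % 8) + (j % 8) // 2
    local = 4 * (j // 8) + 2 * (i // 8) + (j % 2)
    return thread, local

check_fragment_bijection(demo_layout)
print("demo layout is a valid 32x8 fragment mapping")
```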


224-255: ldmatrix_b: inverted trans bit matches row.col expectation.

For B you use trans = not b_transposed and flip the index order accordingly. This aligns with typical row.col usage. LGTM pending A verification above.


333-371: stmatrix: correct handling of 2D vs 4D destinations.

Thread/lane mapping and local_id calculation align with the store IndexMap. Looks consistent.


399-404: New sr A/B transform funcs: ensure availability and dtype coverage.

The imports reference shared_16x16_to_mma_32x8_layout_sr_a/_sr_b for 16-bit and 16x32/32x16 for 8-bit. Verify these functions exist across all targeted CUDA arches and cover FP16/BF16/FP8. Missing symbols will break import at compile-time.

If any dtype isn’t ready, gate it behind a clear error in gemm_v2 or policy selection.

Also applies to: 421-426


280-316: mma: saturate flag explicitly false—OK.

Explicit boolean improves readability. No functional concerns.

tilelang/tileop/gemm/__init__.py (4)

17-21: FFI shims return thread extent correctly.

Bridging to infer_layout/lower using thread_bounds.extent looks correct.

Also applies to: 23-28


76-101: Infer-layout mappings align with operand kinds.

  • ss: swizzle A/B shared; C uses MMA store layout.
  • sr/rs: fragment side gets MMA load layout; shared side swizzled.
  • rr: both A/B load layouts; C store layout.
    This matches test variants and emitter expectations.

247-259: in_dtype/accum_dtype/chunk accessors look correct.

  • in_dtype guard asserts A/B dtype equality.
  • chunk respects trans_A by selecting K-positioned axis.

260-271: is_gemm_ predicates correctly classify shared vs fragment.*

Accurate use of is_shared/is_fragment. Good.

Comment on lines +50 to +80
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}

⚠️ Potential issue

Bounds-check args and validate BufferMap lookups.

Constructor indexes args[0..13] unconditionally and uses vmap[...] without presence checks. This risks OOB access or null buffers.

Apply:

 GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
+  // Required: Aptr, Bptr, Cptr, trans_A, trans_B, M, N, K, policy, clear_accum,
+  //           stride_A, stride_B, offset_A, offset_B
+  ICHECK_GE(args.size(), 14) << "gemm_py expects at least 14 positional args";
+
+  // Defaults for optional fields
+  node->kPack = 1;
+  node->wg_wait = 0;

   node->Aptr = args[0];
   node->Bptr = args[1];
   node->Cptr = args[2];
-  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
-  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
-  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
-  node->trans_A = args[3].as<Bool>().value();
-  node->trans_B = args[4].as<Bool>().value();
+  auto Avar = GetVarFromAccessPtr(node->Aptr);
+  auto Bvar = GetVarFromAccessPtr(node->Bptr);
+  auto Cvar = GetVarFromAccessPtr(node->Cptr);
+  ICHECK(vmap.count(Avar)) << "Aptr not found in BufferMap";
+  ICHECK(vmap.count(Bvar)) << "Bptr not found in BufferMap";
+  ICHECK(vmap.count(Cvar)) << "Cptr not found in BufferMap";
+  node->A = vmap.at(Avar);
+  node->B = vmap.at(Bvar);
+  node->C = vmap.at(Cvar);
+  node->trans_A = args[3].as<IntImm>().value()->value != 0;
+  node->trans_B = args[4].as<IntImm>().value()->value != 0;
   node->M = args[5].as<IntImm>().value()->value;
   node->N = args[6].as<IntImm>().value()->value;
   node->K = args[7].as<IntImm>().value()->value;
   node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
-  node->clear_accum = args[9].as<Bool>().value();
+  node->clear_accum = args[9].as<IntImm>().value()->value != 0;
   node->stride_A = args[10].as<IntImm>().value()->value;
   node->stride_B = args[11].as<IntImm>().value()->value;
   node->offset_A = args[12].as<IntImm>().value()->value;
   node->offset_B = args[13].as<IntImm>().value()->value;
   if (args.size() > 14) {
     node->kPack = args[14].as<IntImm>().value()->value;
     if (node->kPack != 1 && node->kPack != 2) {
       ICHECK(false) << "kPack must be 1 or 2";
     }
   }
   if (args.size() > 15) {
     node->wg_wait = args[15].as<IntImm>().value()->value;
   }
   data_ = std::move(node);
 }
🤖 Prompt for AI Agents
In src/op/gemm_py.cc around lines 50 to 80, the constructor assumes args has at
least 14 entries and that vmap contains entries for GetVarFromAccessPtr results,
which can cause OOB or missing-buffer crashes; add explicit checks before each
indexed access (ICHECK or if-checks verifying args.size() >= required_index+1)
and validate that .as<...>() returns non-null (or use ICHECK(node) semantics)
before dereferencing value(); for vmap lookups, compute the key once, check
vmap.count(key) (or vmap.find(key) != vmap.end()) and handle missing buffers
with a clear ICHECK(false) or throw with a descriptive message; keep existing
optional-field handling but gate accesses to args[14] and args[15] behind
args.size() checks; ensure all messages indicate which arg or buffer is missing
to aid debugging.

Comment on lines 110 to 116
# More test case can be found in kernel/test_tilelang_kernel_gemm.py
# GEMM tests for float16
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32,
2) # f16f16f16_nn
run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)


⚠️ Potential issue

num_stages=0 will likely break T.Pipelined; use >= 1.

Passing 0 to T.Pipelined(..., num_stages=...) can lead to invalid scheduling or no prefetch. Set it to at least 1, and prefer 2 to match the other tests.

Apply:

-    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
-    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
+    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 1)
+    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 1)
+    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 1)
🤖 Prompt for AI Agents
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py around lines 110 to
116: the GEMM calls pass num_stages=0 which can break T.Pipelined; change the
final argument from 0 to 2 (or at least 1) for each run_gemm_ss invocation so
Pipelined gets a valid number of stages (prefer 2 to match other tests).

Comment on lines +287 to +291
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"


⚠️ Potential issue

Fix error message to reference gemm_v2

The assertion message still says “T.gemm”; this will mislead users debugging shape issues.

-    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
+    assert K == K_B, f"T.gemm_v2 K shape check failed: K_A = {K}, K_B = {K_B}"
🤖 Prompt for AI Agents
In tilelang/language/gemm.py around lines 287 to 291, the assertion message
refers to "T.gemm" which is incorrect for this function; update the assertion
string to reference "T.gemm_v2" (or the correct function name) so the message
reads something like "T.gemm_v2 K shape check failed: K_A = {K}, K_B = {K_B}"
and retain the existing shape variables in the f-string for clarity.

Comment on lines +264 to 265
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_binding(dim)

🛠️ Refactor suggestion

Replace asserts with hard checks to avoid -O stripping

Runtime preconditions expressed with assert are stripped when Python runs with -O, risking a None dereference. Raise an explicit exception instead, and avoid calling Current() twice.

Apply this diff to each wrapper:

 def get_thread_binding(dim: int = 0) -> Var:
   """Returns the thread binding for the given dimension.
   """
-  assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
-  return KernelLaunchFrame.Current().get_thread_binding(dim)
+  current = KernelLaunchFrame.Current()
+  if current is None:
+      raise RuntimeError("KernelLaunchFrame is not initialized")
+  return current.get_thread_binding(dim)

Repeat the same pattern for:

  • get_thread_bindings()
  • get_block_binding()
  • get_block_bindings()
  • get_thread_extent()
  • get_thread_extents()
  • get_block_extent()
  • get_block_extents()

Optionally add a small helper to DRY:

def _current_kernel_frame() -> KernelLaunchFrame:
    current = KernelLaunchFrame.Current()
    if current is None:
        raise RuntimeError("KernelLaunchFrame is not initialized")
    return current

Then call _current_kernel_frame() in the wrappers.

Also applies to: 271-272, 278-279, 285-286, 292-293, 299-300, 306-307, 313-314

🤖 Prompt for AI Agents
In tilelang/language/kernel.py around lines 264-265 (and similarly at 271-272,
278-279, 285-286, 292-293, 299-300, 306-307, 313-314), replace the assert-based
runtime checks with explicit checks that raise a RuntimeError when
KernelLaunchFrame.Current() is None and avoid calling Current() twice; implement
(optionally) a small helper like _current_kernel_frame() that calls
KernelLaunchFrame.Current(), raises RuntimeError("KernelLaunchFrame is not
initialized") when None, and returns the current frame, then have each wrapper
call that helper and return the appropriate get_* result (e.g.,
get_thread_binding, get_thread_bindings, get_block_binding, etc.).
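The -O stripping behavior the comment above relies on is easy to demonstrate end to end:

```python
# Demonstrates why `assert` is unsafe for runtime preconditions: under
# `python -O`, assert statements are compiled away entirely.
import subprocess
import sys

prog = "assert False, 'precondition failed'\nprint('reached past the assert')"

normal = subprocess.run([sys.executable, "-c", prog],
                        capture_output=True, text=True)
optimized = subprocess.run([sys.executable, "-O", "-c", prog],
                           capture_output=True, text=True)

# Without -O the assert fires and the process exits nonzero; with -O the
# assert is removed and execution continues past it.
print(normal.returncode != 0)         # True
print("reached" in optimized.stdout)  # True
```

This is why an explicit `raise RuntimeError(...)` is the safer choice for the wrappers listed above.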

@@ -1,5 +1,6 @@
from typing import Literal, Union
from tilelang import tvm as tvm
from tilelang import _ffi_api

💡 Verification agent

🧩 Analysis chain

Guard against missing FFI bindings (verify symbols).

Import assumes the C++ FFI registrations exist. Please verify the symbols are present to avoid ImportError at runtime.

Run this to confirm all required FFI globals are registered:


🏁 Script executed:

#!/bin/bash
# Expect matches for each symbol below.
rg -nP 'TVM_REGISTER_GLOBAL\("tl\.(Target(Is|Has|Get).*)"\)' -g 'src/**' -S
# Quick check of Python-facing wrappers reference:
rg -nP '\bTarget(Is|Has|Get)\w+\(' src/target/utils.cc src/target/utils.h

Length of output: 3351


Add missing FFI global registrations for target utilities.
No TVM_REGISTER_GLOBAL("tl.TargetIs…"), "tl.TargetHas…" or "tl.TargetGet…" entries were found, so _ffi_api won’t expose TargetIsCuda, TargetHasAsyncCopy, TargetGetWarpSize, etc. Add corresponding registrations in src/target/utils.cc, for example:

TVM_REGISTER_GLOBAL("tl.TargetIsCuda")
  .set_body_typed(TargetIsCuda);
// …and similarly for TargetIsRocm, TargetHasAsyncCopy, TargetGetWarpSize, etc.
🤖 Prompt for AI Agents
In tilelang/utils/target.py around line 3, the module imports _ffi_api but the
corresponding C++ FFI globals are not registered, so functions like
TargetIsCuda, TargetIsRocm, TargetHasAsyncCopy, TargetGetWarpSize are not
exposed; open src/target/utils.cc and add TVM_REGISTER_GLOBAL entries mapping
those symbols to their typed C++ implementations (e.g.
TVM_REGISTER_GLOBAL("tl.TargetIsCuda").set_body_typed(TargetIsCuda); and
likewise for TargetIsRocm, TargetHasAsyncCopy, TargetGetWarpSize and any other
missing TargetIs*/TargetHas*/TargetGet* utilities), then rebuild to ensure
_ffi_api exposes the functions to tilelang/utils/target.py.

@oraluben oraluben mentioned this pull request Sep 10, 2025
- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
- Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
- Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
- Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/profiler/__init__.py (1)

129-145: Guard and refactor all direct torch.float8_* references
Wrap every use of torch.float8_e* across the codebase (e.g. in tilelang/utils/tensor.py, tilelang/contrib/dlpack.py, tests under testing/python/** and examples under examples/**) with a version-agnostic check (e.g. hasattr(torch, ...) or your new is_float8() helper), or fall back via getattr(torch, dtype_name), to avoid AttributeErrors on older PyTorch.

src/transform/inject_pipeline.cc (1)

676-684: Use Integer(1) and implement async_scope lowering

  • Change the literal 1 to Integer(1) in the AttrStmt for type safety.
  • No references to tir::attr::async_scope were found in lowering or codegen (the attribute is effectively a no-op). Implement its consumption downstream or remove the marker.
♻️ Duplicate comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)

504-504: Use "float32" instead of "float" for dtype consistency.

Line 504 uses "float" as the accumulation dtype, but based on the search results and the TensorCoreIntrinEmitter's dtype_abbrv mapping in the first file, this should be "float32" to avoid KeyError during lowering.

🧹 Nitpick comments (8)
tilelang/profiler/__init__.py (1)

119-137: Minor: avoid redefining helper inside the loop

Define is_float8 once at module scope to reduce per-iteration overhead and make it reusable.

tilelang/tileop/gemm/gemm_base.py (2)

13-18: Add lightweight contract checks early.

Assert basic dtype consistency up front to fail-fast.

 class GemmBase(object):
     gemm_node: "GemmPy"
 
     def infer_layout(self, target: Target, thread_nums: int):
         raise NotImplementedError("infer_layout is not implemented")
     
     def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
         raise NotImplementedError("lower is not implemented")
+
+    def __post_init__(self):
+        # Basic sanity: inputs must agree in dtype; C holds accum dtype.
+        assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
+        # If node carries explicit dtypes, keep them consistent with buffers.
+        if getattr(self.gemm_node, "in_dtype", None):
+            assert self.gemm_node.in_dtype == self.A.dtype, "in_dtype mismatch"
+        if getattr(self.gemm_node, "accum_dtype", None):
+            assert self.gemm_node.accum_dtype == self.C.dtype, "accum_dtype mismatch"

101-116: Tighten type hints for TIR scalars/exprs.

stride_*/offset_* are typically tir.PrimExpr, not plain int.

-    def stride_A(self) -> int:
+    def stride_A(self) -> tir.PrimExpr:
         return self.gemm_node.stride_A
@@
-    def stride_B(self) -> int:
+    def stride_B(self) -> tir.PrimExpr:
         return self.gemm_node.stride_B
@@
-    def offset_A(self) -> int:
+    def offset_A(self) -> tir.PrimExpr:
         return self.gemm_node.offset_A
@@
-    def offset_B(self) -> int:
+    def offset_B(self) -> tir.PrimExpr:
         return self.gemm_node.offset_B
tilelang/intrinsics/mma_layout.py (2)

106-116: Naming consistency (“ldmatrix_” vs “mma_load_”).

Elsewhere, loaders are ldmatrix_*. Consider renaming mma_load_{a,b}_32x16_to_shared_16x32_layout to ldmatrix_32x16_to_shared_16x32_layout_{a,b} or add aliases to avoid confusion.


50-53: Minor comment/style polish.

Nit: rephrase “if wanna trans” to “for transposed layout, apply map_indices” to keep code comments professional.

Also applies to: 60-68

tilelang/intrinsics/utils.py (1)

24-26: Broaden dtype annotation to include bf16 (supported in this PR).

Implementation handles any 16-bit dtype; the type hint restricts to "float16" only.

-    dtype: Literal["float16", "int8"] = "float16",
+    dtype: Literal["float16", "bfloat16", "int8"] = "float16",
tilelang/tileop/gemm/__init__.py (1)

52-65: Consider extracting the error message to reduce redundancy

Both methods raise the same error message. While the static analysis hint about TRY003 isn't critical, you could improve maintainability by defining the error message once.

+UNSUPPORTED_TARGET_MSG = "Unsupported target: {}"
+
 def infer_layout(self, target: Target, thread_nums: int):
     if target_is_cuda(target):
         # TODO(lei): Support more cuda architectures, now mma only
         return GemmMMA(self).infer_layout(target, thread_nums)
     else:
-        raise ValueError(f"Unsupported target: {target}")
+        raise ValueError(UNSUPPORTED_TARGET_MSG.format(target))

 def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
     if target_is_cuda(target):
         # TODO(lei): Support more cuda architectures, now mma only
         # Now only implement ssr layout
         return GemmMMA(self).lower(target, thread_nums, thread_var)
     else:
-        raise ValueError(f"Unsupported target: {target}")
+        raise ValueError(UNSUPPORTED_TARGET_MSG.format(target))
tilelang/tileop/gemm/gemm_mma.py (1)

197-207: Remove redundant is_gemm_* methods from gemm_mma.py
These four methods duplicate identical implementations in GemmBase; delete lines 197-207 to eliminate code duplication.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e36740d and a3f2564.

📒 Files selected for processing (14)
  • src/op/copy.cc (0 hunks)
  • src/op/gemm.cc (2 hunks)
  • src/op/gemm_py.cc (1 hunks)
  • src/target/codegen_cuda.cc (3 hunks)
  • src/transform/inject_pipeline.cc (2 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (7 hunks)
  • tilelang/intrinsics/mma_layout.py (2 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (18 hunks)
  • tilelang/intrinsics/utils.py (2 hunks)
  • tilelang/profiler/__init__.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (1 hunks)
  • tilelang/tileop/gemm/gemm_base.py (1 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (1 hunks)
  • tilelang/utils/tensor.py (1 hunks)
💤 Files with no reviewable changes (1)
  • src/op/copy.cc
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/target/codegen_cuda.cc
  • src/op/gemm_py.cc
🧰 Additional context used
🧬 Code graph analysis (8)
tilelang/profiler/__init__.py (2)
tilelang/engine/param.py (1)
  • is_float8 (81-91)
tilelang/utils/tensor.py (1)
  • torch_assert_close (220-312)
tilelang/intrinsics/utils.py (1)
tilelang/intrinsics/mma_layout.py (1)
  • ldmatrix_32x16_to_shared_16x32_layout_b (36-39)
tilelang/tileop/gemm/gemm_mma.py (5)
tilelang/tileop/gemm/gemm_base.py (21)
  • GemmBase (10-131)
  • infer_layout (13-14)
  • policy (130-131)
  • M (33-34)
  • N (37-38)
  • in_dtype (53-54)
  • in_dtype (65-67)
  • accum_dtype (57-58)
  • accum_dtype (70-71)
  • trans_A (45-46)
  • trans_B (49-50)
  • chunk (61-62)
  • chunk (74-75)
  • is_gemm_ss (20-21)
  • A (78-79)
  • B (82-83)
  • C (86-87)
  • is_gemm_sr (23-24)
  • is_gemm_rs (26-27)
  • is_gemm_rr (29-30)
  • lower (16-17)
tilelang/intrinsics/mma_macro_generator.py (11)
  • TensorCoreIntrinEmitter (24-630)
  • make_mma_store_layout (561-630)
  • make_mma_load_layout (421-559)
  • ldmatrix_a (190-241)
  • ldmatrix_a (718-822)
  • ldmatrix_b (243-302)
  • ldmatrix_b (824-936)
  • mma (304-364)
  • mma (938-987)
  • mma (992-1090)
  • mma (1095-1194)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/tileop/gemm/__init__.py (2)
  • infer_layout (52-57)
  • lower (59-65)
tilelang/tileop/gemm/gemm_base.py (3)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/tileop/gemm/__init__.py (2)
  • infer_layout (52-57)
  • lower (59-65)
tilelang/tileop/gemm/gemm_mma.py (6)
  • infer_layout (15-57)
  • lower (59-194)
  • is_gemm_ss (197-198)
  • is_gemm_sr (200-201)
  • is_gemm_rs (203-204)
  • is_gemm_rr (206-207)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/gemm.py (1)
  • gemm_v2 (186-359)
tilelang/jit/kernel.py (3)
  • get_kernel_source (378-389)
  • get_profiler (360-376)
  • out_idx (446-447)
tilelang/language/__init__.py (1)
  • annotate_layout (102-140)
tilelang/layout/fragment.py (1)
  • make_swizzled_layout (210-216)
tilelang/profiler/__init__.py (1)
  • assert_allclose (76-145)
tilelang/tileop/gemm/__init__.py (5)
src/op/gemm_py.h (2)
  • tvm (13-124)
  • GemmPy (116-121)
tilelang/utils/target.py (1)
  • target_is_cuda (87-88)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/tileop/gemm/gemm_mma.py (3)
  • GemmMMA (13-207)
  • infer_layout (15-57)
  • lower (59-194)
src/op/gemm_py.cc (1)
  • GemmPy (50-80)
src/op/gemm.cc (1)
src/op/gemm.h (2)
  • RegisterReflection (34-40)
  • RegisterReflection (117-139)
tilelang/intrinsics/mma_macro_generator.py (4)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (68-69)
  • get_ldmatrix_offset (19-50)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/intrinsics/mma_layout.py (3)
  • mma_load_a_32x16_to_shared_16x32_layout (106-109)
  • mma_load_b_32x16_to_shared_16x32_layout (112-115)
  • transform_func (176-179)
tilelang/language/kernel.py (5)
  • threads (195-199)
  • get_thread_binding (151-156)
  • get_thread_binding (261-265)
  • KernelLaunchFrame (72-206)
  • Current (115-121)
🪛 Ruff (0.12.2)
tilelang/tileop/gemm/gemm_mma.py

57-57: Avoid specifying long messages outside the exception class

(TRY003)


194-194: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_base.py

11-11: Undefined name GemmPy

(F821)


65-65: Redefinition of unused in_dtype from line 53

(F811)


70-70: Redefinition of unused accum_dtype from line 57

(F811)


74-74: Redefinition of unused chunk from line 61

(F811)

tilelang/tileop/gemm/__init__.py

57-57: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_macro_generator.py

468-468: Avoid specifying long messages outside the exception class

(TRY003)


485-485: Avoid specifying long messages outside the exception class

(TRY003)


543-543: Avoid specifying long messages outside the exception class

(TRY003)


557-557: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: CI
tilelang/tileop/gemm/gemm_mma.py

[error] 150-150: codespell: Misspelled word 'parrent' found at line 150; should be 'parent'.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (32)
src/transform/inject_pipeline.cc (1)

954-959: Good: recomputing block read/write regions to match the rewritten body.

This keeps region metadata in sync post-rewrite and reduces chances of stale reads/writes leaking into later passes.

tilelang/intrinsics/mma_layout.py (1)

76-88: A/B 32x16 mappings: request quick sanity checks.

Formulas look consistent with 32x16 microtiles; please confirm symmetry of A/B and transposed wrappers with a small table-driven check (thread_id∈[0,31], local_id∈[0,15]) to ensure ranges and uniqueness per warp.

You can validate purely with Python (no GPU) by checking for duplicate (row,col) pairs and expected ranges.

Also applies to: 90-92
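A harness of the kind suggested above is short to write. Note the mapping below is a placeholder bijection for illustration, not the real mma_load_*_32x16_to_shared_16x32 layout; only the checking pattern (range and uniqueness over all (thread_id, local_id) pairs) carries over.

```python
# Table-driven sanity check: enumerate every (thread_id, local_id) pair and
# verify the mapping stays in range and never collides.
def placeholder_layout(thread_id: int, local_id: int):
    # Placeholder bijection from (32 threads x 16 locals) to a 16x32 tile.
    row = thread_id // 2
    col = (thread_id % 2) * 16 + local_id
    return row, col

seen = set()
for thread_id in range(32):
    for local_id in range(16):
        row, col = placeholder_layout(thread_id, local_id)
        assert 0 <= row < 16 and 0 <= col < 32, "index out of range"
        assert (row, col) not in seen, "duplicate (row, col) pair"
        seen.add((row, col))

print(len(seen))  # 512 == 32 * 16, so the map covers the tile exactly once
```

Swapping in the real layout functions for `placeholder_layout` turns this into the requested no-GPU validation.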

tilelang/intrinsics/utils.py (2)

3-9: Import rename LGTM; keep both A/B helpers aligned.

Switching to ldmatrix_32x16_to_shared_16x32_layout_b matches the new A/B split.


39-43: Verified no callers of old B helper
Only its definition remains in tilelang/intrinsics/mma_layout.py; no compatibility alias needed.

src/op/gemm.cc (2)

616-631: LGTM! Clean FFI registration implementation

The FFI registration block is well-structured. The static initialization ensures proper registration of the GemmWarpPolicy operations at module load time, and the lambda wrapper cleanly exposes the ComputeWarpPartition functionality.


254-260: No action needed: existing guards prevent divide-by-zero and enforce divisibility
The if (m_per_warp < 1 || n_per_warp < 1) check precludes any zero or negative divisor, and the initial ICHECK constraints combined with the post-search splitting logic guarantee that both M % (m × kMPerWarp) and N % (n × kNPerWarp) are zero.

Likely an incorrect or invalid review comment.

tilelang/tileop/gemm/__init__.py (1)

13-16: Return value is unused but consistent with the design

The gemm_py.infer_layout() returns a dict mapping buffers to layouts, but the wrapper function doesn't explicitly return this value. This appears intentional as the TVM FFI may handle the return implicitly.

tilelang/intrinsics/mma_macro_generator.py (16)

5-5: LGTM: Var import correctly added for thread_var typing.

The import of Var from TVM is necessary to support the new optional thread_var parameter typing.


12-19: LGTM: MMA layout imports updated correctly.

The layout imports have been updated to support the new matrix-specific layout functions (_sr_a and _sr_b variants) and the mma_load functions, which are needed for the enhanced A/B matrix handling.


30-32: LGTM: Dynamic n_dim support added.

The change from a constant N_DIM = 16 to a dynamic n_dim = 16 with improved documentation is correct. This allows for flexible n_dim sizes (16 or 8) based on the warp configuration.


62-62: New parameter added: thread_var for custom thread binding.

The optional thread_var parameter allows overriding the default thread binding retrieval mechanism.


77-79: LGTM: Micro size initialization updated for dynamic n_dim.

The calls to _initialize_micro_size and _initialize_local_size now properly pass the dynamic n_dim value.


85-85: LGTM: Thread variable stored correctly.

The thread_var parameter is properly stored as an instance variable.


115-134: LGTM: Enhanced micro size initialization with dynamic n_dim.

The refactored _initialize_micro_size method now properly handles dynamic n_dim calculation based on warp_col_tiles, supporting both 16 and 8 dimensions. The validation logic ensures proper tile sizes and divisibility.


142-149: LGTM: Thread binding abstraction added.

The new get_thread_binding method provides a clean abstraction that allows using either a custom thread_var or the default kernel frame thread binding.


203-204: LGTM: Improved ldmatrix availability check.

The ldmatrix availability logic now correctly excludes int8 + transposed A matrix cases, which are not supported by the ldmatrix instruction.


206-206: LGTM: Thread binding usage centralized.

The use of self.get_thread_binding() provides consistent thread binding access throughout the method.


218-240: LGTM: Improved ldmatrix_a with proper transposition handling.

The enhanced implementation now properly handles transposed matrices with the correct indexing (A_shared_buf[wk, wi] vs A_shared_buf[wi, wk]) and passes the transposition flag correctly to the ldmatrix instruction.


256-300: LGTM: Enhanced ldmatrix_b with proper transposition and replicate handling.

The implementation now correctly:

  • Uses centralized thread binding
  • Calculates replicate_b based on n_dim == 16
  • Handles transposed B matrix indexing properly
  • Sets the correct trans flag for the ldmatrix instruction

319-319: LGTM: MMA instruction enhanced with replicate support.

The MMA implementation now properly handles replication for 16-dimensional cases with the additional MMA instruction call when replicate_b is true. The saturate parameter is correctly documented.

Also applies to: 345-362


376-380: LGTM: Dynamic n_dim usage in stmatrix.

The stmatrix method now correctly uses the dynamic n_dim value instead of a hardcoded constant, and maintains consistent thread binding usage.


447-558: Enhanced make_mma_load_layout with matrix-specific handling.

The method has been significantly refactored to properly handle A/B matrices with different axis orders and transposition states. The new implementation:

  • Uses matrix-specific flags (matrix_is_a, matrix_is_b) for clearer logic
  • Employs dedicated transform functions for A and B matrices
  • Handles sr/rs axis order conditions properly
  • Constructs layouts with appropriate warp/block/replicate dimensions

The logic appears sound for handling the various GEMM configurations (ss/sr/rs/rr).


729-729: LGTM: Thread binding consistency in ladder transform classes.

The ladder transform emitter classes now also use the centralized get_thread_binding() method for consistency.

Also applies to: 836-836

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (9)

34-35: LGTM: Shared memory allocation with explicit scope.

The shared memory allocations now use explicit scope="shared" parameter, which is consistent with the new memory management approach.


47-49: LGTM: Switch to gemm_v2 API.

The transition from the old T.gemm to T.gemm_v2 aligns with the PR objective of introducing the new GEMM operator stack.


54-68: LGTM: Function renamed and kernel source printing added.

The renaming from run_gemm to run_gemm_ss provides better clarity for the shared-shared configuration. Adding print(kernel.get_kernel_source()) aids in debugging and verification.

Also applies to: 93-94


110-128: LGTM: Expanded test coverage with corrected num_stages.

The test suite now includes:

  • Fixed num_stages=2 instead of 0 (avoiding potential pipeline issues)
  • Enhanced data type coverage (int8, float8_e5m2)
  • n8 dimension testing

This significantly improves test coverage for various GEMM configurations.


160-167: LGTM: RS configuration with swizzled layout and fragment handling.

The matmul_rs function properly implements the register-shared configuration (A in fragments, B in shared memory) with:

  • Swizzled layout annotation for A_shared
  • Fragment allocation and copy for A matrix
  • Correct gemm_v2 usage

Also applies to: 177-178


242-255: LGTM: Comprehensive RS test coverage.

The RS tests now include the full spectrum of configurations (transpose combinations, data types, and dimension variants) consistent with the SS tests.


258-384: LGTM: New SR configuration implementation and tests.

The matmul_sr function and corresponding tests properly implement the shared-register configuration (A in shared memory, B in fragments) with B-matrix fragmentation and a swizzled layout. The test coverage mirrors the comprehensive approach used in the other configurations.


386-516: LGTM: Complete RR configuration with dual fragmentation.

The matmul_rr function correctly implements the register-register configuration (both A and B in fragments) with dual fragmentation and swizzled layouts. The test suite provides comprehensive coverage of all transpose combinations and data types.


520-521: LGTM: Test runner updated for direct execution.

The main execution path now directly calls a specific GEMM test case, which is useful for focused testing and debugging.

Comment on lines +70 to +73
shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout
shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout
shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans
shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans

💡 Verification agent

🧩 Analysis chain

Back-compat: ensure no remaining uses of removed SR/RS aliases.

You introduced per-matrix aliases (_sr_a/_sr_b/_rs_a/_rs_b). If older sites still import shared_16x16_to_mma_32x8_layout_sr/..._rs, they will break.

Run to find potential breakages:

If matches exist, consider re-adding deprecated shims that forward to _sr_a/_rs_b with clear naming, or update call sites in this PR.


🏁 Script executed:

#!/bin/bash
# Old aliases that may be referenced elsewhere
rg -nP '\bshared_16x16_to_mma_32x8_layout_(sr|rs)\b|shared_16x16_to_mma_32x8_layout_trans\b|shared_16x16_to_mma_32x8_layout\b' -g '!**/site-packages/**' -C2

Length of output: 148


🏁 Script executed:

#!/bin/bash
# Search across all Python files, including those ignored by default, for any uses of the old aliases or base name
rg --no-ignore -nP '\bshared_16x16_to_mma_32x8_layout_(?:sr|rs)\b|\bshared_16x16_to_mma_32x8_layout_trans\b|\bshared_16x16_to_mma_32x8_layout\b' -g '*.py' -C2

Length of output: 1383


Restore deprecated SR/RS aliases or update example usage

• examples/plot_layout/fragment_mma_load_a.py (lines 35–36, 50–51) still imports and invokes the removed shared_16x16_to_mma_32x8_layout_sr/rs aliases. Either update these sites to the new per-matrix functions (e.g. _sr_a/_sr_b/_rs_a/_rs_b), or re-add fallback shims in tilelang/intrinsics/mma_layout.py that map the old names to the new ones.

🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_layout.py around lines 70 to 73, example code still
imports the removed shared_16x16_to_mma_32x8_layout_sr/rs aliases; restore
backward compatibility by re-adding alias shims that map the old names to the
new per-matrix functions (e.g. set shared_16x16_to_mma_32x8_layout_sr_a ->
shared_16x16_to_mma_a_32x8_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_b ->
shared_16x16_to_mma_b_32x8_layout_sr_b, shared_16x16_to_mma_32x8_layout_rs_a ->
shared_16x16_to_mma_a_32x8_layout_rs_a, shared_16x16_to_mma_32x8_layout_rs_b ->
shared_16x16_to_mma_b_32x8_layout_rs_b) so existing examples keep working;
alternatively update the example files to import and call the new per-matrix
names instead of the old aliases.
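One way to keep the old names importable is a thin deprecation shim. A sketch, assuming the new per-matrix functions exist under the names used in this PR (the layout formula below is a placeholder, not the real mapping):

```python
import warnings

def shared_16x16_to_mma_a_32x8_layout(i, j):
    # Placeholder body standing in for the real per-matrix layout function;
    # it returns a (thread_id, local_id) pair like the real one does.
    return 4 * (i % 8) + (j % 8) // 2, 4 * (j // 8) + 2 * (i // 8) + (j % 2)

def _deprecated_alias(new_fn, old_name):
    """Forward calls through old_name to new_fn with a DeprecationWarning."""
    def shim(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_fn.__name__} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)
    return shim

# Old alias kept importable, forwarding to the new per-matrix function.
shared_16x16_to_mma_32x8_layout_sr = _deprecated_alias(
    shared_16x16_to_mma_a_32x8_layout, "shared_16x16_to_mma_32x8_layout_sr")
```

Existing imports keep working but surface a warning under `-W error::DeprecationWarning`, which makes it easy to find remaining call sites before the shims are deleted.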

Comment on lines 52 to 63
@property
def in_dtype(self) -> str:
return self.gemm_node.in_dtype

@property
def accum_dtype(self) -> str:
return self.gemm_node.accum_dtype

@property
def chunk(self) -> int:
return self.gemm_node.chunk


⚠️ Potential issue

Resolve duplicate properties (F811) and pick a single source of truth.

in_dtype, accum_dtype, and chunk are defined twice; the latter overrides the former and confuses readers and linters. Recommend keeping the node-backed versions and removing the buffer-derived duplicates (you already assert consistency in __post_init__).

-    @property
-    def in_dtype(self) -> str:
-        assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
-        return self.A.dtype
-
-    @property
-    def accum_dtype(self) -> str:
-        return self.C.dtype
-
-    @property
-    def chunk(self) -> int:
-        return self.A.shape[-2] if self.trans_A else self.A.shape[-1]
+    # If needed later, expose K extent explicitly to avoid overloading "chunk":
+    # @property
+    # def k_extent(self) -> int:
+    #     return self.A.shape[-2] if self.trans_A else self.A.shape[-1]

Also applies to: 64-76

🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_base.py around lines 52-63 (and similarly for
64-76), duplicate property definitions for in_dtype, accum_dtype, and chunk
exist and shadow the previous definitions; remove the buffer-derived duplicate
property methods and keep the node-backed properties (self.gemm_node.*) as the
single source of truth, relying on the existing __post_init__ consistency checks
to ensure they match, and delete the redundant getters so linters and readers
see only one definition per property.

Comment on lines +179 to +192
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""

for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)

# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)

⚠️ Potential issue

Fix incorrect function name for RR variant

The function is named _gemm_rsr but should be _gemm_rr since this is the RR (fragment-fragment) variant.

 @T.prim_func
-def _gemm_rsr() -> None:
+def _gemm_rr() -> None:
     """
     The inner macro that loads data from shared buffers A_shared and
     B_shared into local fragments, then issues Tensor Core mma ops,
     accumulating into C_local.
     """

     for ki in T.serial(0, (block_K // micro_size_k)):
         # Perform Matrix Multiplication
         mma_emitter.mma(A_local, B_local, C_local, ki)

 # Simplify to optimize the index computing
 # Must inline let statements to simplify the analysis
-return _Simplify(_gemm_rsr, inline_let=True)
+return _Simplify(_gemm_rr, inline_let=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _gemm_rsr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
@T.prim_func
def _gemm_rr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rr, inline_let=True)
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma.py around lines 179 to 192, the inner macro is
incorrectly named `_gemm_rsr` but this is the RR (fragment-fragment) variant;
rename the function to `_gemm_rr` and update the return to call
`_Simplify(_gemm_rr, inline_let=True)` (and adjust any internal references or
docstrings mentioning `_gemm_rsr`) so the name matches the RR variant.

Comment on lines +180 to +184
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""

🛠️ Refactor suggestion

Misleading docstring for RR variant

The docstring mentions loading from "shared buffers" but in the RR variant, both A and B are already in local fragments and no loading from shared memory occurs.

 """
-The inner macro that loads data from shared buffers A_shared and
-B_shared into local fragments, then issues Tensor Core mma ops,
-accumulating into C_local.
+The inner macro that performs Tensor Core mma ops using
+local fragments A_local and B_local, accumulating into C_local.
 """
📝 Committable suggestion


Suggested change
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
"""
The inner macro that performs Tensor Core mma ops using
local fragments A_local and B_local, accumulating into C_local.
"""
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma.py around lines 180 to 184, the docstring for
the RR variant incorrectly states that data is loaded from "shared buffers
A_shared and B_shared"; update the docstring to reflect that in the RR variant
both A and B are already in local fragments (no shared-memory loads), and
describe that the macro issues Tensor Core MMA ops using those local fragments
and accumulates into C_local; keep the description concise and specific to the
RR behavior.

Comment on lines +116 to +117
return torch.empty(
*shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)

⚠️ Potential issue

Block non-floating dtypes for Uniform/Normal to avoid silent truncation (ints) and all-True tensors (bool).

Casting float32 samples to integer or boolean dtypes silently changes semantics (e.g., Uniform [-1, 1) yields mostly zeros for int32 and almost all True for bool). This is likely a regression versus the previous behavior, which failed fast with an error. Gate these branches to floating/FP8 dtypes only (int8 is already handled earlier).

Apply:

@@
-        elif supply_type == TensorSupplyType.Uniform:
-            return torch.empty(
-                *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
+        elif supply_type == TensorSupplyType.Uniform:
+            # Restrict to floating/FP8 to avoid silent truncation for ints/bools.
+            if dtype not in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and not param.is_float8():
+                raise TypeError(f"Uniform supply requires floating or float8 dtype, got {dtype}. "
+                                "Use TensorSupplyType.Integer instead.")
+            return torch.empty(
+                *shape, device=device, dtype=torch.float32
+            ).uniform_(-1.0, 1.0).to(dtype)
@@
-        elif supply_type == TensorSupplyType.Normal:
-            return torch.empty(
-                *shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype)
+        elif supply_type == TensorSupplyType.Normal:
+            if dtype not in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and not param.is_float8():
+                raise TypeError(f"Normal supply requires floating or float8 dtype, got {dtype}. "
+                                "Use TensorSupplyType.Integer instead.")
+            return torch.empty(
+                *shape, device=device, dtype=torch.float32
+            ).normal_(-1.0, 1.0).to(dtype)

Optional: extract a small helper to de-duplicate the float32→cast pattern for Uniform/Normal.

Also applies to: 119-120

🤖 Prompt for AI Agents
In tilelang/utils/tensor.py around lines 116-120, the Uniform/Normal branches
currently sample in float32 then cast to the requested dtype, which silently
corrupts non-floating dtypes (ints/bool); change these branches to first
validate that the requested dtype is a floating type (or supported FP8) and
raise a clear error if not; keep int8 handling where already done; optionally
factor out a small helper that samples in float32 and casts to dtype to avoid
duplication but only call it after the dtype check so integer/bool dtypes are
rejected before sampling.
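The gating can be prototyped without torch; a minimal sketch with dtype names as strings (the helper name and the allowed set are assumptions, not the tilelang API):

```python
# Floating and FP8 dtypes for which float32 sampling + cast is acceptable.
_FLOAT_LIKE = {"float16", "float32", "float64", "bfloat16",
               "float8_e4m3", "float8_e5m2"}

def check_supply_dtype(supply_type, dtype):
    """Reject non-floating dtypes for Uniform/Normal before any sampling.

    Sampling in float32 and casting afterwards would silently truncate
    integers (mostly zeros for Uniform[-1, 1)) and saturate bools to True.
    """
    if supply_type in ("Uniform", "Normal") and dtype not in _FLOAT_LIKE:
        raise TypeError(
            f"{supply_type} supply requires a floating or float8 dtype, got "
            f"{dtype}; use TensorSupplyType.Integer instead")
```

Calling the check at the top of the supply function keeps the error close to the user's dtype choice rather than surfacing as a silently wrong tensor.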

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (5)
tilelang/tileop/gemm/gemm_mma.py (2)

183-197: Rename RR macro for clarity and consistency.

This is the RR path but the function is named _gemm_rsr. Rename to _gemm_rr and update the return.

-            def _gemm_rsr() -> None:
+            def _gemm_rr() -> None:
@@
-            return _Simplify(_gemm_rsr, inline_let=True)
+            return _Simplify(_gemm_rr, inline_let=True)

126-133: Fix misleading docstrings per variant (SR/RS/RR).

Docstrings mention shared loads where fragments are already provided. Tighten wording to match behavior.

@@
-                """
-                The inner macro that loads data from shared buffers A_shared and
-                B_shared into local fragments, then issues Tensor Core mma ops,
-                accumulating into C_local.
-                """
+                """
+                Load A from shared into fragment; B is already a fragment.
+                Issue Tensor Core MMA and accumulate into C_local.
+                """
@@
-                """
-                The inner macro that loads data from shared buffers A_shared and
-                B_shared into local fragments, then issues Tensor Core mma ops,
-                accumulating into C_local.
-                """
+                """
+                Load B from shared into fragment; A is already a fragment.
+                Issue Tensor Core MMA and accumulate into C_local.
+                """
@@
-                """
-                The inner macro that loads data from shared buffers A_shared and
-                B_shared into local fragments, then issues Tensor Core mma ops,
-                accumulating into C_local.
-                """
+                """
+                Use local fragments A_local and B_local.
+                Issue Tensor Core MMA and accumulate into C_local.
+                """

Also applies to: 156-161, 184-189

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)

117-117: num_stages=0 will break T.Pipelined - use at least 1

Passing 0 to T.Pipelined(..., num_stages=...) can lead to invalid scheduling or no prefetch. Software pipelining requires at least 1 stage.

Apply this diff to fix the issue:

-    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
-    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
-    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)

Also applies to: 251-251, 385-385


521-521: Use "float32" instead of "float" for accum dtype

TensorCoreIntrinEmitter.dtype_abbrv expects "float32", not "float". Using "float" will cause a KeyError during lowering.

Apply this diff:

-    run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float32", 128, 256, 32, 2)
tilelang/intrinsics/mma_layout.py (1)

94-97: Address remaining usage of removed SR/RS aliases in examples.

Based on the past review comment, example files still reference the old shared_16x16_to_mma_32x8_layout_sr/shared_16x16_to_mma_32x8_layout_rs aliases that were removed. The current aliases use _sr_a/_sr_b/_rs_a/_rs_b suffixes.

Consider adding backward compatibility aliases or updating the example files:

# Add these aliases for backward compatibility
shared_16x16_to_mma_32x8_layout_sr = shared_16x16_to_mma_32x8_layout_sr_a  # or _sr_b
shared_16x16_to_mma_32x8_layout_rs = shared_16x16_to_mma_32x8_layout_rs_a  # or _rs_b

Or verify that all example files have been updated to use the new per-matrix naming scheme.

🧹 Nitpick comments (13)
.clang-tidy (1)

4-4: Avoid enabling and disabling the same checks; it’s confusing

These checks appear in both keep and disable lists, and end up disabled by the latter. Consider removing them from the retained block for clarity:

  • cppcoreguidelines-pro-type-static-cast-downcast
  • cppcoreguidelines-pro-type-member-init
  • cppcoreguidelines-narrowing-conversions

Suggested cleanup:

-  cppcoreguidelines-pro-type-static-cast-downcast,
-  cppcoreguidelines-pro-type-member-init,
-  cppcoreguidelines-narrowing-conversions,

Also applies to: 43-43, 5-5, 41-41, 9-9, 39-39

src/target/codegen_cuda.cc (2)

22-22: Avoid broad using-directive; prefer a narrow alias.

To reduce namespace pollution inside tvm::codegen, consider an alias instead of using namespace.

Example:

-using namespace tvm::tl::codegen;
+namespace tl_codegen = tvm::tl::codegen;

Then qualify calls, e.g., tl_codegen::PrintMMAAssembly(...).


1335-1341: Verify tl::ptx_ldmatrix_xN argument order and unify both call sites.

Here you emit tl::ptx_ldmatrix_xN(_trans)(smem + off, local + off), while the earlier tl::ptx_ldmatrix branch prints args starting at index 2 via print_extern_call_stmt (potentially a different order). Please confirm both wrappers take the same parameter order and adjust if necessary to avoid silent layout bugs. Consider centralizing this emission in one helper to keep the ordering consistent.

tilelang/tileop/gemm/gemm_base.py (1)

1-8: Avoid runtime import cycles; gate GemmWarpPolicy under TYPE_CHECKING.

Drop the eager import and keep the annotation as a forward-ref to prevent potential import loops and silence type-checker complaints.

@@
-from tilelang.ir import GemmWarpPolicy
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from tilelang.ir import GemmWarpPolicy
@@
-    def policy(self) -> GemmWarpPolicy:
+    def policy(self) -> "GemmWarpPolicy":
         return self.gemm_node.policy

Also applies to: 117-119

tilelang/tileop/gemm/gemm_mma.py (2)

202-212: De-duplicate is_gemm_ helpers; inherit from GemmBase.*

These override identical logic from GemmBase without changes. Remove to keep a single source of truth.

-    def is_gemm_ss(self) -> bool:
-        return is_shared(self.A) and is_shared(self.B)
-
-    def is_gemm_sr(self) -> bool:
-        return is_shared(self.A) and is_fragment(self.B)
-
-    def is_gemm_rs(self) -> bool:
-        return is_fragment(self.A) and is_shared(self.B)
-
-    def is_gemm_rr(self) -> bool:
-        return is_fragment(self.A) and is_fragment(self.B)
+    # Inherit is_gemm_* from GemmBase

84-89: Guard against non-divisible K blocking.

If chunk (K tile) isn’t divisible by micro_size_k, the loop truncates. Add a precondition.

         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
+        assert block_K % micro_size_k == 0, "block_K must be divisible by micro_size_k"

If this can be violated by some policy, consider raising ValueError with guidance.

tilelang/intrinsics/mma_macro_generator.py (3)

151-158: Annotate thread-binding getter for clarity.

Return type is a Var; helps readers and tools.

-    def get_thread_binding(self):
+    def get_thread_binding(self) -> Var:
         if self.thread_var is None:
             current_frame = T.KernelLaunchFrame.Current()
             assert current_frame is not None, "Must be called in a T.Kernel Frame"
             return current_frame.get_thread_binding()
         else:
             return self.thread_var

456-475: Docstring: this is a load-layout, not “store”.

Small wording fix to avoid confusion with make_mma_store_layout.

-        Create a layout function for storing MMA results into a fragment buffer.
-        This layout is used in conjunction with `inverse_mma_store_layout` to
-        map fragment indices to threads and local indices.
+        Create a layout function for loading data into an MMA fragment buffer.
+        This layout maps fragment indices to threads and local indices for MMA loads.

124-143: Tighten assert messages or use ValueError for consistency.

Messages say “greater than” but the check is “>=”. Either update text or switch to explicit ValueError for ruff TRY003 friendliness.

-        assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}"
+        assert warp_row_tiles >= 16, f"warp_row_tiles must be >= 16, got {warp_row_tiles}"
-        assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
+        assert warp_col_tiles >= 8, f"warp_col_tiles must be >= 8, got {warp_col_tiles}"
src/target/ptx.h (2)

44-45: Inconsistent parameter naming conventions in function documentation

The documentation uses inconsistent naming - sometimes with underscores (a_offset, a_ptr) and sometimes without (smem_ptr). Consider standardizing the parameter names.

Apply this naming convention consistently across all function signatures and documentation. For example:

  • Use a_elem_offset instead of mixing a_offset in docs and a_elem_offset in signature
  • Keep parameter names consistent between the documentation comments and actual function signatures

Also applies to: 46-47, 60-61, 74-76, 88-91, 103-105


150-154: Fix comment formatting for mbarrier documentation

Multi-line parameter documentation is improperly formatted with \param split across lines.

Apply this diff to fix the documentation:

-/*!
- * \brief Print ptx barrier arrival with expect tx operation using
- * mbarrier.arrive.expect_tx \param barrier: The name of the barrier in shared
- * memory. \param byte_count: Increases the tx count of the mbarrier object to
- * track completion of addtional async transactions.
- */
+/*!
+ * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx
+ * \param barrier The name of the barrier in shared memory.
+ * \param byte_count Increases the tx count of the mbarrier object to
+ * track completion of additional async transactions.
+ */
src/target/ptx.cc (2)

83-133: Consider using a map for more efficient dtype lookups

The current implementation uses multiple if-else statements for string to DataType conversion. For better performance and maintainability, consider using an unordered_map.

Consider refactoring to use a static map:

inline DataType DTypeFromString(const std::string& str) {
  static const std::unordered_map<std::string, DataType> dtype_map = {
    {"int4", DataType::kInt4}, {".s4", DataType::kInt4},
    {"uint4", DataType::kUInt4}, {".u4", DataType::kUInt4},
    {"int8", DataType::kInt8}, {".s8", DataType::kInt8},
    // ... rest of the mappings
  };
  
  auto it = dtype_map.find(str);
  if (it != dtype_map.end()) {
    return it->second;
  }
  LOG(FATAL) << "Unrecognized PTX data type " << str;
}

452-464: String replacement implementation is correct but could be more efficient

The current implementation correctly handles overlapping patterns by updating the search position. However, for better performance with many rules, consider pre-compiling patterns or using a more sophisticated string replacement algorithm.
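The position-advancing idea generalizes beyond the C++ implementation; a minimal Python sketch of the same technique (the function name is illustrative, not from the tilelang codebase):

```python
def replace_all(text, pattern, repl):
    """Replace every occurrence of pattern, resuming the scan just past each
    inserted replacement so substituted text is never rescanned. This is what
    makes replacements whose output contains the pattern itself safe."""
    pos = 0
    while True:
        pos = text.find(pattern, pos)
        if pos == -1:
            return text
        text = text[:pos] + repl + text[pos + len(pattern):]
        pos += len(repl)  # skip past the inserted text
```

Without the `pos += len(repl)` advance, a rule such as `mma.sync → mma.sync.aligned` would loop forever, because each replacement reintroduces the pattern at the position just scanned.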

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a3f2564 and aa62efb.

📒 Files selected for processing (11)
  • .clang-tidy (1 hunks)
  • CMakeLists.txt (1 hunks)
  • src/target/codegen_cuda.cc (4 hunks)
  • src/target/ptx.cc (1 hunks)
  • src/target/ptx.h (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (7 hunks)
  • tilelang/intrinsics/mma_layout.py (3 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (18 hunks)
  • tilelang/intrinsics/utils.py (3 hunks)
  • tilelang/tileop/gemm/gemm_base.py (1 hunks)
  • tilelang/tileop/gemm/gemm_mma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
tilelang/tileop/gemm/gemm_mma.py (5)
tilelang/tileop/gemm/gemm_base.py (18)
  • GemmBase (11-119)
  • infer_layout (14-15)
  • policy (118-119)
  • M (33-34)
  • N (37-38)
  • in_dtype (53-55)
  • accum_dtype (58-59)
  • trans_A (45-46)
  • trans_B (49-50)
  • chunk (62-63)
  • is_gemm_ss (20-21)
  • A (66-67)
  • B (70-71)
  • C (74-75)
  • is_gemm_sr (23-24)
  • is_gemm_rs (26-27)
  • is_gemm_rr (29-30)
  • lower (17-18)
tilelang/intrinsics/mma_macro_generator.py (10)
  • make_mma_store_layout (596-665)
  • make_mma_load_layout (452-594)
  • ldmatrix_a (199-261)
  • ldmatrix_a (753-857)
  • ldmatrix_b (263-333)
  • ldmatrix_b (859-971)
  • mma (335-395)
  • mma (973-1022)
  • mma (1027-1125)
  • mma (1130-1229)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/tileop/gemm/__init__.py (2)
  • infer_layout (52-57)
  • lower (59-65)
src/target/ptx.h (1)
src/target/ptx.cc (20)
  • PrintMMAAssembly (569-632)
  • PrintMMAAssembly (570-579)
  • PrintLoadMatrixAssembly (657-693)
  • PrintLoadMatrixAssembly (657-662)
  • PrintCpAsyncAssembly (695-722)
  • PrintCpAsyncAssembly (695-699)
  • PrintPredicatedCpAsyncAssembly (724-780)
  • PrintPredicatedCpAsyncAssembly (724-727)
  • PrintCpAsyncBulkAsm (782-809)
  • PrintCpAsyncBulkAsm (782-787)
  • PrintCpAsyncBarrierAsm (811-826)
  • PrintCpAsyncBarrierAsm (811-811)
  • PrintInitBarrierThreadCountAsm (828-846)
  • PrintInitBarrierThreadCountAsm (828-829)
  • PrintArriveBarrierAsm (848-863)
  • PrintArriveBarrierAsm (848-848)
  • PrintArriveBarrierExpectTxAsm (865-883)
  • PrintArriveBarrierExpectTxAsm (865-866)
  • PrintWaitBarrierAsm (885-901)
  • PrintWaitBarrierAsm (885-885)
tilelang/tileop/gemm/gemm_base.py (4)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/carver/roller/node.py (1)
  • Node (93-175)
tilelang/tileop/gemm/gemm_mma.py (6)
  • infer_layout (15-58)
  • lower (60-200)
  • is_gemm_ss (202-203)
  • is_gemm_sr (205-206)
  • is_gemm_rs (208-209)
  • is_gemm_rr (211-212)
tilelang/tileop/gemm/__init__.py (2)
  • infer_layout (52-57)
  • lower (59-65)
src/target/codegen_cuda.cc (1)
src/target/ptx.h (1)
  • codegen (33-164)
tilelang/intrinsics/mma_macro_generator.py (5)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (81-82)
  • get_ldmatrix_offset (21-63)
tilelang/utils/language.py (1)
  • is_fragment (68-78)
tilelang/intrinsics/mma_layout.py (5)
  • mma_load_a_32x4_to_shared_16x8_layout (130-133)
  • mma_load_b_32x4_to_shared_16x8_layout (136-139)
  • mma_load_a_32x16_to_shared_16x32_layout (142-145)
  • mma_load_b_32x16_to_shared_16x32_layout (148-151)
  • transform_func (212-215)
tilelang/language/kernel.py (5)
  • threads (195-199)
  • get_thread_binding (151-156)
  • get_thread_binding (261-265)
  • KernelLaunchFrame (72-206)
  • Current (115-121)
tilelang/layout/fragment.py (2)
  • replicate (146-160)
  • repeat (123-144)
tilelang/intrinsics/utils.py (1)
tilelang/intrinsics/mma_layout.py (7)
  • ldmatrix_32x4_to_shared_16x8_layout_a (6-9)
  • ldmatrix_32x4_to_shared_16x8_layout_b (12-15)
  • ldmatrix_32x8_to_shared_16x16_layout (18-21)
  • ldmatrix_trans_32x8_to_shared_16x16_layout (24-27)
  • ldmatrix_32x16_to_shared_16x32_layout_a (30-33)
  • ldmatrix_32x16_to_shared_16x32_layout_b (36-39)
  • transform_func (212-215)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/gemm.py (1)
  • gemm_v2 (186-359)
tilelang/jit/kernel.py (2)
  • get_profiler (360-376)
  • out_idx (446-447)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
tilelang/language/__init__.py (1)
  • annotate_layout (102-140)
tilelang/layout/fragment.py (1)
  • make_swizzled_layout (210-216)
🪛 Ruff (0.12.2)
tilelang/tileop/gemm/gemm_mma.py

57-58: Avoid specifying long messages outside the exception class

(TRY003)


199-200: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_macro_generator.py

122-122: Avoid specifying long messages outside the exception class

(TRY003)


224-224: Avoid specifying long messages outside the exception class

(TRY003)


290-290: Avoid specifying long messages outside the exception class

(TRY003)


503-503: Avoid specifying long messages outside the exception class

(TRY003)


520-520: Avoid specifying long messages outside the exception class

(TRY003)


578-578: Avoid specifying long messages outside the exception class

(TRY003)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/utils.py

41-41: Avoid specifying long messages outside the exception class

(TRY003)
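Ruff's TRY003 findings above all point at the same pattern: a long message string built at the `raise` site. The fix Ruff suggests is to move the message into the exception class itself. A minimal sketch of that refactor — the class and function names here are illustrative, not from this PR:

```python
class UnsupportedLdmatrixLayoutError(ValueError):
    """Raised when an ldmatrix layout combination is not supported."""

    def __init__(self, matrix: str, transposed: bool):
        # The message lives inside the exception class, so the raise
        # site stays short and TRY003 is satisfied.
        super().__init__(
            f"ldmatrix does not support matrix {matrix!r} with transposed={transposed}")


def check_layout(matrix: str, transposed: bool) -> None:
    # In this sketch only A non-transposed and B transposed are accepted.
    ok = (matrix == "A" and not transposed) or (matrix == "B" and transposed)
    if not ok:
        raise UnsupportedLdmatrixLayoutError(matrix, transposed)
```

The raise site then reads `raise UnsupportedLdmatrixLayoutError(matrix, transposed)` instead of carrying the full prose message inline.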

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: bot-task
🔇 Additional comments (18)
.clang-tidy (1)

43-45: Document rationale for disabling performance-enum-size and confirm clang-tidy support

Add a brief comment explaining the ABI risk above the disabled check, then ensure CI installs clang-tidy and includes “performance-enum-size”:

   -performance-unnecessary-value-param,
+  # Rationale: changing enum underlying types can break ABI; disable until enum usage around T.gemm_v2 is audited.
   -performance-enum-size,

Example CI validation:

command -v clang-tidy >/dev/null && clang-tidy --version \
  && clang-tidy -list-checks -checks='*' | grep -q 'performance-enum-size'
CMakeLists.txt (1)

135-135: PTX source inclusion under USE_CUDA looks correct.

Adding src/target/ptx.cc to TILE_LANG_CUDA_SRCS is properly gated by USE_CUDA and aligns with the new helpers used by codegen_cuda.cc.

src/target/codegen_cuda.cc (2)

17-17: Include path switch to local ptx.h is appropriate.

The relative include matches src/target/ptx.h and avoids stale paths.


1263-1265: Indentation before emitting asm improves readability.

The added PrintIndent() before MMA/PTX emission keeps generated code aligned.

Also applies to: 1299-1305

tilelang/intrinsics/utils.py (1)

52-59: Int8 layout switch to 32x16→16x32 A/B functions: looks right; please validate all transpose modes.

The A-non-transposed/B-transposed restriction matches ldmatrix constraints for int8. Please confirm coverage with tests across rs/sr/rr variants to avoid corner-case misindexing.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)

47-47: Good adoption of T.gemm_v2 for experimental testing

The transition from T.gemm to T.gemm_v2 aligns well with the PR's experimental nature. The commented-out original line provides a convenient fallback for debugging.


170-172: Consistent use of swizzled layouts for fragment optimization

Good implementation of swizzled memory layouts for the fragment variants (rs, sr, rr). This optimization improves memory access patterns for tensor cores.

Also applies to: 304-306, 440-443

src/target/ptx.cc (3)

594-631: Well-structured PTX assembly generation

The MMA assembly generation is well-organized with clear separation of concerns: validation, template generation, and pattern replacement. The use of the Replacer class provides good maintainability.


748-766: No action needed: C++17 support already enabled
CMakeLists.txt sets CMAKE_CXX_STANDARD to 17, so structured bindings are supported.


377-383: Correct layout validation logic – The condition properly enforces row-major A and column-major B for all cases except the supported m8n8k4 fp16 configuration, where custom layouts are allowed.

tilelang/intrinsics/mma_layout.py (8)

6-9: LGTM! Clean implementation of ldmatrix layout for matrix A.

The function correctly maps thread_id and local_id to row/col coordinates for a 32x4 to 16x8 layout transformation specific to matrix A.


12-15: LGTM! Proper matrix B layout implementation.

The layout calculation correctly handles the different indexing pattern needed for matrix B in the 32x4 to 16x8 transformation.


18-21: LGTM! Consistent 32x8 to 16x16 layout mapping.

The function maintains consistency with the existing pattern and correctly handles the wider local_id range (8 instead of 4).


24-27: LGTM! Proper transposed layout implementation.

The transposed variant correctly swaps the indexing pattern while maintaining the same coordinate calculation logic.


50-73: LGTM! Well-structured matrix A/B layout separation.

The refactor to separate matrix A and B layouts improves clarity and maintainability. The transpose variants and SR/RS aliases provide good backward compatibility.


76-97: LGTM! Consistent 16x16 to 32x8 layout implementation.

The matrix A/B separation follows the same pattern established for the 32x4 layouts, with proper transpose variants and SR/RS aliases.


100-121: LGTM! Proper 32x16 layout implementation.

The 16x32 to 32x16 layout functions correctly handle the larger dimension while maintaining consistency with the established patterns.


130-151: LGTM! MMA load helper functions are well-implemented.

The new load helper functions provide clear separation between matrix A and B loading patterns, with correct coordinate calculations for each matrix type.
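The layout functions reviewed above all share one shape: map a `(thread_id, local_id)` pair within a 32-thread warp to `(row, col)` coordinates in the tile. A sketch of the 32x8 -> 16x16 case, under the common convention that each thread owns 8 contiguous elements and two 16x8 halves sit side by side; the exact formulas in tilelang/intrinsics/mma_layout.py may differ:

```python
def ldmatrix_32x8_to_shared_16x16(thread_id: int, local_id: int):
    # Threads 0..15 fill the left 16x8 half, threads 16..31 the right half;
    # local_id selects the element within a thread's 8-wide run.
    row = thread_id % 16
    col = 8 * (thread_id // 16) + local_id
    return row, col
```

A useful sanity check for any such mapping is that the 32 * 8 = 256 `(thread_id, local_id)` pairs cover the 16x16 tile exactly once — the property the review comments are implicitly verifying.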

Comment on lines +502 to +507
const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) /
frag_attr_a.size / threads / (sparse ? 2 : 1),
num_operands_b =
(k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads,
num_operands_c =
(m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads;

⚠️ Potential issue

Potential integer overflow in operand calculations

The calculations for num_operands_* involve multiplication and division of potentially large values without overflow checks.

Add overflow checks or use safer integer types:

+  // Check for potential overflow
+  const uint64_t bits_a = static_cast<uint64_t>(m) * k * ptx::DTypeBits(dtype_a);
+  const uint64_t bits_b = static_cast<uint64_t>(k) * n * ptx::DTypeBits(dtype_b);
+  const uint64_t bits_c = static_cast<uint64_t>(m) * n * ptx::DTypeBits(dtype_c);
+  
+  CHECK(bits_a / frag_attr_a.size / threads <= INT_MAX) << "Operand count overflow";
+  
   const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) /
                              frag_attr_a.size / threads / (sparse ? 2 : 1),

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/target/ptx.cc around lines 502-507, the computation of num_operands_a/b/c
can overflow because several potentially large multiplications are done in int
before division; change the intermediate arithmetic to a wider unsigned integer
type (e.g., uint64_t or use __int128 for intermediate products), perform checks
for multiplication overflow (use builtin helpers like __builtin_mul_overflow or
check operands against max/value limits) before dividing, and cast back to the
target type only after safe division; if overflow is detected, handle it
deterministically (return error, clamp, or log) rather than producing a wrapped
value.
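The guarded computation the prompt asks for can be sketched as follows — in Python for brevity (Python ints never wrap, so the widening is implicit; in the C++ fix the intermediate would be `uint64_t`). `INT_MAX` mirrors the C++ `int` limit, and the parameter names are illustrative:

```python
INT_MAX = 2**31 - 1

def num_operands(rows: int, cols: int, dtype_bits: int,
                 frag_size: int, threads: int, divisor: int = 1) -> int:
    # Multiply in a wide type first, divide, then range-check before
    # narrowing back to int -- the order the review comment prescribes.
    bits = rows * cols * dtype_bits
    count = bits // frag_size // threads // divisor
    if count > INT_MAX:
        raise OverflowError("operand count exceeds int range")
    return count
```

For example, a 16x16 fp16 tile (`dtype_bits=16`) with 32-bit fragments across 32 threads yields 4 operands per thread; pathological shapes trip the check instead of silently wrapping.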

* \brief Code generation with inlined PTX code.
*/
#ifndef TVM_TL_TARGET_SOURCE_PTX_H_
#define TVM_TL_TARGET_SOURCE_PTX_H_

⚠️ Potential issue

Fix incorrect header guard macro name

The header guard uses TVM_TL_TARGET_SOURCE_PTX_H_ but the file is located at src/target/ptx.h. The macro should match the actual path.

Apply this diff to fix the header guard:

-#ifndef TVM_TL_TARGET_SOURCE_PTX_H_
-#define TVM_TL_TARGET_SOURCE_PTX_H_
+#ifndef TVM_TL_TARGET_PTX_H_
+#define TVM_TL_TARGET_PTX_H_

Also update line 167:

-#endif // TVM_TL_TARGET_SOURCE_PTX_H_
+#endif // TVM_TL_TARGET_PTX_H_

Suggested change
#define TVM_TL_TARGET_SOURCE_PTX_H_
#ifndef TVM_TL_TARGET_PTX_H_
#define TVM_TL_TARGET_PTX_H_
// … rest of ptx.h contents …
#endif // TVM_TL_TARGET_PTX_H_
🤖 Prompt for AI Agents
In src/target/ptx.h around line 25, the header guard macro is incorrect
(currently TVM_TL_TARGET_SOURCE_PTX_H_) and must match the file path; replace it
with a macro appropriate for src/target/ptx.h (e.g., TVM_TL_TARGET_PTX_H_, matching the suggested diff) and
ensure the closing #endif at line 167 is updated to reference the new macro name
so the guard pairs correctly.

Comment on lines +4 to +5
ldmatrix_32x4_to_shared_16x8_layout_a,
ldmatrix_32x4_to_shared_16x8_layout_b,

⚠️ Potential issue

Remove unreachable tf32/32-bit branch and its now-unused imports.

dtype is annotated as Literal["float16","int8"]; DataType(dtype).bits can’t be 32, so this block is dead code. The imports for ldmatrix_32x4_to_shared_16x8_layout_{a,b} become unused. Either wire this up with a real tf32 path (and broaden the type hints/tests) or drop it for now to avoid confusion.

Apply:

-from .mma_layout import (
-    ldmatrix_32x4_to_shared_16x8_layout_a,
-    ldmatrix_32x4_to_shared_16x8_layout_b,
+from .mma_layout import (
     ldmatrix_32x8_to_shared_16x16_layout,
     ldmatrix_trans_32x8_to_shared_16x16_layout,
     ldmatrix_32x16_to_shared_16x32_layout_a,
     ldmatrix_32x16_to_shared_16x32_layout_b,
     mma_store_32x8_to_shared_16x16_layout,
 )
@@
-    if dtype_bits == 32:
-        if matrix == "B" and transposed:
-            transform_func = ldmatrix_32x4_to_shared_16x8_layout_b
-            new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
-            return new_row_idx * stride + new_col_idx
-        elif matrix == "A" and not transposed:
-            transform_func = ldmatrix_32x4_to_shared_16x8_layout_a
-            new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
-            return new_row_idx * stride + new_col_idx
-        else:
-            raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
+    # tf32 not yet supported (see PR TODO); 32-bit branch intentionally omitted.

Also applies to: 31-42

🤖 Prompt for AI Agents
In tilelang/intrinsics/utils.py around lines 4-5 and 31-42 remove the dead
tf32/32-bit branch and the now-unused imports
ldmatrix_32x4_to_shared_16x8_layout_a and ldmatrix_32x4_to_shared_16x8_layout_b:
update the dtype handling to only consider the Literal["float16","int8"] cases
(or if you intend to support tf32/32-bit, broaden the dtype type hints/tests and
add a real code path), delete the unused import lines, and run tests/linters to
ensure no remaining references to those symbols remain.

- Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation process and aligns with the updated GEMM function signatures.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)

117-117: num_stages must be >= 1 for T.Pipelined.

Using 0 is invalid/fragile and has been flagged previously. Set to 1 (or 2) to enable prefetch.

-    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+    run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)

-    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+    run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)

-    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+    run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
#!/bin/bash
# Find any remaining uses of num_stages=0 in tests.
rg -nP '\bnum_stages\s*=\s*0\b' -g 'testing/**'

# Sanity-check run_gemm_* callsites passing a literal 0 as the stage arg.
rg -nP 'run_gemm_(ss|sr|rs|rr)\([^)]*,\s*0\s*(,|\))' -C2 -g 'testing/**'

Also applies to: 251-251, 385-385


521-521: Use "float32" (not "float") for accum dtype to match emitter mapping.

The emitter expects "float32"; "float" can KeyError during lowering. This was raised before; adjust here as well.

-    run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+    run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float32", 128, 256, 32, 2)
#!/bin/bash
# Ensure no remaining accum dtype "float" with bfloat16 inputs.
rg -nP 'run_gemm_(ss|sr|rs|rr)\([^)]*"bfloat16"[^)]*"bfloat16"[^)]*"float"'
🧹 Nitpick comments (3)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)

47-49: Remove the commented legacy GEMM call for clarity.

Keep only the v2 call to avoid confusion.

-                # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)

128-128: Misleading comment: these are float32 tests, not TF32.

TF32 support is listed as TODO; rename the comment to reflect fp32.

-    # tfloat32 test
+    # float32 tests (TF32 not enabled yet)

54-68: Reduce duplication across run_gemm_* wrappers.

Factor a small helper to compile, get_profiler, and assert against ref_program; pass the matmul builder as a parameter. Keeps variants focused on their distinct tiling only.

Also applies to: 189-203, 461-475

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aa62efb and b5f327c.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
tilelang/language/gemm.py (1)
  • gemm_v2 (186-359)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/jit/kernel.py (2)
  • get_profiler (360-376)
  • out_idx (446-447)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
tilelang/language/__init__.py (1)
  • annotate_layout (102-140)
tilelang/layout/fragment.py (1)
  • make_swizzled_layout (210-216)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (3)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)

170-173: Swizzled A_shared layout usage looks correct.

Annotating A_shared with a swizzled layout matches the rs variant intent.


304-306: Swizzled B_shared layout usage looks correct.

Matches the sr variant and improves MMA load patterns.


441-443: Swizzling both A_shared and B_shared is appropriate for rr.

This aligns with the rr data path and should help bank conflict avoidance.

RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…#793)

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability

- Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
- Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
- Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety.
- Added new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.

* Refactor GEMM and frontend legalize operations for improved clarity and functionality

- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
- Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
- Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.

* Enhance CUDA code generation and testing for GEMM operations

- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

* Refactor GEMM layout and Python integration for improved functionality

- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
- Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
- Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
- Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* Refactor GEMM layout and testing for improved clarity and functionality

- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
- Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
- Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
- Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.

These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.

* tfloat32 support.

* lint fix

* lint fix

* Refactor shared memory allocation in GEMM tests

- Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation process and aligns with the updated GEMM function signatures.