[TileOp] Introduce an experimental Python-defined T.gemm_v2 (#793)
Conversation
…bility

- Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
- Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
- Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
- Enhanced the `GetArchInt` function in `utils.cc` for better readability and type safety.
- Added a new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.
Walkthrough

Adds a new TL GEMM stack (GemmPy/GemmMMA) with FFI bindings and a GemmWarpPolicy compute wrapper, introduces gemm_v2 and the GemmPy scriptable node, refactors MMA layouts/emitter and PTX codegen (ptx.cc/.h plus codegen_cuda changes), renames FrontendLegalize to LetInline, and applies pipeline/lowering and test updates.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant PyUser as Python (tilelang.language.gemm_v2)
    participant TL as TL Runtime (Python)
    participant FFI as TVM FFI
    participant Cpp as C++ TL Op (GemmPy)
    participant Impl as Python Impl (tileop.gemm.GemmMMA)
    participant CG as CUDA Codegen
    PyUser->>TL: gemm_v2(A,B,C,transpose_*,policy,...)
    TL->>TL: Resolve shapes, strides, pointers, offsets
    TL->>FFI: tl.gemm_py.lower(gemm_py, target, thread_bounds, thread_var)
    FFI->>Cpp: GemmPyNode::Lower(...)
    Cpp->>Impl: dispatch to GemmMMA.lower via FFI/Python
    Impl->>Impl: prepare emitter, layouts, fragments
    Impl-->>FFI: return PrimFunc
    FFI-->>TL: PrimFunc
    TL-->>CG: Lower to CUDA, emit PTX
    CG->>CG: Emit ldmatrix via tl::ptx_ldmatrix_xN(_trans)
    CG-->>PyUser: Compiled kernel
```

```mermaid
sequenceDiagram
    autonumber
    participant Py as Python (tilelang.ir.GemmWarpPolicy)
    participant FFI as TVM FFI
    participant Cpp as C++ GemmWarpPolicy
    Py->>FFI: GemmWarpPolicyComputeWarpPartition(policy,M,N,block_size,target,is_wgmma)
    FFI->>Cpp: policy->ComputeWarpPartition(...)
    Cpp-->>FFI: sets policy.m_warp, policy.n_warp
    FFI-->>Py: returns m_warp, n_warp
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Pre-merge checks: 2 passed, 1 warning.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces an experimental, Python-defined General Matrix Multiplication (GEMM) operation, T.gemm_v2, designed to improve backend integration and reduce compilation overhead. It establishes a robust Python-C++ FFI for handling the operation's parameters, layout inference, and lowering, while also enhancing existing target utility functions and refactoring related code for better maintainability and extensibility.
Highlights
- Experimental Python-defined GEMM (`T.gemm_v2`): Introduced a new experimental `T.gemm_v2` function, defined in Python, to facilitate integration with additional backends and potentially reduce compilation time. This initial version is focused on MMA (Matrix Multiply-Accumulate) operations.
- New C++ Backend for Python-defined Ops: Added new C++ files (`src/op/gemm_py.cc`, `src/op/gemm_py.h`) to support the Python-defined `gemm_v2` operation. This includes `GemmPyNode` for handling GEMM parameters, WGMMA checks for the Hopper architecture, and FFI registrations for layout inference and lowering functions.
- Enhanced Python-C++ FFI for Target Utilities: Expanded the FFI (Foreign Function Interface) to expose various target-related utility functions (e.g., `TargetIsCuda`, `TargetGetWarpSize`) from C++ to Python, enabling more flexible and dynamic target-specific optimizations from the Python frontend.
- Refactoring and Code Cleanup: Removed the `toPrimeFactors` utility function from `src/op/gemm.cc` and `src/op/gemm_sp.cc`, simplifying the codebase. Improved the `GetArchInt` parsing logic in `src/target/utils.cc` for robustness.
- Updated MMA Macro Generator: The `TensorCoreIntrinEmitter` in `tilelang/intrinsics/mma_macro_generator.py` now includes a `thread_var` parameter and a `get_thread_binding` method, allowing `ldmatrix` and `stmatrix` macros to retrieve thread bindings more directly, reducing reliance on `T.KernelLaunchFrame.Current()`.
- Improved `GemmWarpPolicy` Integration: The `GemmWarpPolicy` object in `tilelang/ir.py` has been updated to expose its internal `policy_type`, `m_warp`, and `n_warp` attributes, and includes a `compute_warp_partition` method that calls into the C++ backend via FFI.
- New `tileop` Python Package: A new Python package `tilelang/tileop` has been introduced, specifically for Python-defined tile operations. It currently contains the `GemmPy` class, which defines the Python-side logic for `infer_layout` and `lower` for the new `gemm_v2` operation.
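To make the warp-partition idea above concrete, here is a hypothetical square-policy sketch in plain Python. The function name and heuristic are illustrative assumptions, not the actual C++ logic behind `GemmWarpPolicy.compute_warp_partition`:

```python
def square_warp_partition(m, n, block_size, warp_size=32):
    """Split a block's warps across M and N so the per-warp tiles are
    as close to square as possible. Illustrative only; the real
    GemmWarpPolicy heuristic in C++ differs."""
    num_warps = block_size // warp_size
    best, best_score = (1, num_warps), None
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp:
            continue  # keep m_warp * n_warp == num_warps
        n_warp = num_warps // m_warp
        # prefer per-warp tiles whose M and N extents are balanced
        score = abs(m / m_warp - n / n_warp)
        if best_score is None or score < best_score:
            best, best_score = (m_warp, n_warp), score
    return best

print(square_warp_partition(128, 128, 128))  # (2, 2): 4 warps, 2x2 grid
```

With 128 threads (4 warps) over a 128x128 tile, a 2x2 warp grid gives each warp a square 64x64 sub-tile.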
Code Review
This pull request introduces an experimental, Python-defined gemm_v2 operator to improve backend integration and reduce compilation times. The changes are extensive, adding new C++ and Python files for the operator, and exposing more C++ utilities to Python via FFI. The overall structure is sound. My review focuses on correctness and maintainability, and I've identified a few critical bugs and opportunities for refactoring to improve code quality.
```cpp
           equal(N, other->N) && equal(K, other->K) &&
           equal(stride_A, other->stride_A) &&
           equal(stride_B, other->stride_B) &&
           equal(offset_A, other->offset_B) &&
```
There appears to be a copy-paste error in `SEqualReduce`: `offset_A` is being compared with `other->offset_B` instead of `other->offset_A`. This will cause incorrect behavior in structural equality checks, which can lead to subtle bugs in TIR passes.
```diff
-           equal(offset_A, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
```
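To see why the mixed-up field is dangerous, here is a tiny hypothetical Python analogue of the buggy comparison (the class and method names are illustrative, not PR code): with mismatched fields, a node whose two offsets differ fails to equal even itself.

```python
class GemmNode:
    """Toy stand-in for the C++ node; only the offset fields matter here."""

    def __init__(self, offset_A, offset_B):
        self.offset_A = offset_A
        self.offset_B = offset_B

    def sequal_buggy(self, other):
        # mirrors the copy-paste bug: offset_A compared with other.offset_B
        return self.offset_A == other.offset_B and self.offset_B == other.offset_B

    def sequal_fixed(self, other):
        return self.offset_A == other.offset_A and self.offset_B == other.offset_B

node = GemmNode(offset_A=0, offset_B=16)
print(node.sequal_buggy(node))  # False: the buggy check is not even reflexive
print(node.sequal_fixed(node))  # True
```

Structural-equality bugs like this are hard to notice because most test nodes happen to have equal offsets, which masks the asymmetry.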
tilelang/tileop/gemm/__init__.py (outdated)
```python
            block_row_warps=self.M,
            block_col_warps=self.N,
```
The `TensorCoreIntrinEmitter` is initialized with `block_row_warps=self.M` and `block_col_warps=self.N`. However, `self.M` and `self.N` are matrix dimensions, while `block_row_warps` and `block_col_warps` are expected to be warp counts. The computed `m_warp` and `n_warp` should be used instead. This is a critical bug that will lead to incorrect behavior.
```diff
-            block_row_warps=self.M,
-            block_col_warps=self.N,
+            block_row_warps=m_warp,
+            block_col_warps=n_warp,
```
```cpp
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
```
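For reference, the `sm_` parsing above can be mirrored in Python. This is a hedged sketch of the same logic, not code from the PR; note that C++ `std::stoi` accepts trailing suffixes such as `sm_90a`, which the regex below reproduces:

```python
import re

def get_arch_int(arch: str) -> int:
    """Parse the numeric CUDA architecture out of an 'arch' string,
    e.g. 'sm_90' -> 90. Like std::stoi, trailing suffixes are ignored
    ('sm_90a' -> 90). Returns 0 for unrecognized formats."""
    m = re.match(r"sm_(\d+)", arch)
    return int(m.group(1)) if m else 0

print(get_arch_int("sm_80"))        # 80
print(get_arch_int("sm_90a"))       # 90
print(get_arch_int("compute_90"))   # 0
```

The `compute_90`-style fallback to 0 matches the C++ `else` branch; the reviewers below suggest normalizing such forms before parsing if they can occur.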
tilelang/tileop/gemm/__init__.py (outdated)
```python
    def infer_layout(self, target: Target, thread_nums: int):
        if target_is_cuda(target):
            # TODO(lei): Support more cuda architectures, now mma only
            # Now only implement ssr layout
            m_warp, n_warp = self.policy.compute_warp_partition(
                self.M, self.N, thread_nums, target, False)
            warp_row_tiles = m_warp * 16
            warp_col_tiles = n_warp * 16
            mma_emitter = TensorCoreIntrinEmitter(
                a_dtype=self.in_dtype,
                b_dtype=self.in_dtype,
                accum_dtype=self.accum_dtype,
                a_transposed=self.trans_A,
                b_transposed=self.trans_B,
                block_row_warps=self.M,
                block_col_warps=self.N,
                warp_row_tiles=warp_row_tiles,
                warp_col_tiles=warp_col_tiles,
                chunk=self.chunk,
            )
            layout_map = {
                self.A: make_swizzled_layout(self.A),
                self.B: make_swizzled_layout(self.B),
                self.C: mma_emitter.make_mma_store_layout(self.C),
            }
            return layout_map
        else:
            raise ValueError(f"Unsupported target: {target}")
```
```python
def gemm_v2(
    A: Union[tir.Buffer, tir.Var],
    B: Union[tir.Buffer, tir.Var],
    C: Union[tir.Buffer, tir.Var],
    transpose_A: bool = False,
    transpose_B: bool = False,
    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
    clear_accum: bool = False,
    k_pack: int = 1,
    wg_wait: int = 0,
):
    """Perform a General Matrix Multiplication (GEMM) operation.

    This function computes C = A @ B where A and B can optionally be transposed.
    The operation supports various warp policies and accumulation modes.

    Args:
        A (Union[tir.Buffer, tir.Var]): First input matrix
        B (Union[tir.Buffer, tir.Var]): Second input matrix
        C (Union[tir.Buffer, tir.Var]): Output matrix for results
        transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
        policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
        clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
        k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
        wg_wait (int, optional): Warp group wait count. Defaults to 0.

    Returns:
        tir.Call: A handle to the GEMM operation

    Raises:
        AssertionError: If the K dimensions of matrices A and B don't match
    """

    def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
        """Convert let-bound variables to their corresponding buffers.

        Args:
            arg (Union[tir.Buffer, tir.Var]): Input argument to legalize

        Returns:
            Union[tir.Buffer, tir.Var]: The legalized argument
        """
        if isinstance(arg, tir.Var) and T.has_let_value(arg):
            return T.get_let_value(arg).buffer
        return arg

    A = legalize_arguments(A)
    B = legalize_arguments(B)
    C = legalize_arguments(C)

    def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
        if isinstance(object, tir.Buffer):
            return object.shape
        elif isinstance(object, tir.BufferRegion):
            region = object.region
            shape = []
            for r in region:
                shape.append(r.extent)
            return shape
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
        if isinstance(object, tir.Buffer):
            strides = []
            stride = 1
            for s in reversed(object.shape):
                strides.insert(0, stride)
                stride *= s
            return strides
        elif isinstance(object, tir.BufferRegion):
            buffer, _ = object.buffer, object.region
            strides = []
            stride = 1
            for s in reversed(buffer.shape):
                strides.insert(0, stride)
                stride *= s
            return strides
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    A_shape = retrieve_shape(A)
    B_shape = retrieve_shape(B)
    C_shape = retrieve_shape(C)

    A_stride = retrieve_stride(A)
    B_stride = retrieve_stride(B)

    assert len(C_shape) == 2, "current only support C as a 2D tensor"
    assert len(A_shape) >= 2, "current only support A as a 2D or higher-order tensor"
    assert len(B_shape) >= 2, "current only support B as a 2D or higher-order tensor"
    if len(A_shape) > 2:
        for i in range(len(A_shape) - 2):
            assert A_shape[i] == 1, \
                "current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
    if len(B_shape) > 2:
        for i in range(len(B_shape) - 2):
            assert B_shape[i] == 1, \
                "current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"

    M, N = C_shape
    K = A_shape[-2] if transpose_A else A_shape[-1]
    K_B = B_shape[-1] if transpose_B else B_shape[-2]
    assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"

    stride_a = A_stride[-2]
    stride_b = B_stride[-2]

    def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
                     access_type: str = "r") -> tir.PrimExpr:
        if isinstance(object, tir.Buffer):
            return object.access_ptr(access_type)
        elif isinstance(object, tir.BufferRegion):
            buffer, region = object.buffer, object.region
            indices = []
            for r in region:
                indices.append(r.min)
            strides = []
            stride = 1
            for s in reversed(buffer.shape):
                strides.insert(0, stride)
                stride *= s
            offset = 0
            # not offset the last two dimension
            for i in range(len(indices) - 2):
                offset += indices[i] * strides[i]
            return buffer.access_ptr(access_mask=access_type, offset=offset)
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
        """Retrieve the offset of the buffer or buffer region."""
        if isinstance(object, tir.Buffer):
            return [0] * len(object.shape)
        elif isinstance(object, tir.BufferRegion):
            _, region = object.buffer, object.region
            indices = []
            for r in region:
                indices.append(r.min)
            return indices
        else:
            raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")

    A_offset = retrieve_offset(A)
    B_offset = retrieve_offset(B)
    assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
    assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
    offset_a = A_offset[-1]
    offset_b = B_offset[-1]

    Aptr = retrieve_ptr(A, "r")
    Bptr = retrieve_ptr(B, "r")
    Cptr = retrieve_ptr(C, "rw")
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.gemm_py"),
        Aptr,
        Bptr,
        Cptr,
        transpose_A,
        transpose_B,
        M,
        N,
        K,
        policy,
        clear_accum,
        stride_a,
        stride_b,
        offset_a,
        offset_b,
        k_pack,
        wg_wait,
    )
```
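Stripped of TIR types, the bookkeeping inside `gemm_v2` (row-major strides as in `retrieve_stride`, plus the K-dimension check) is plain integer arithmetic. A standalone Python distillation — the helper names here are illustrative, not the PR's API:

```python
def row_major_strides(shape):
    """Row-major strides: the last dimension is contiguous."""
    strides = []
    stride = 1
    for s in reversed(shape):
        strides.insert(0, stride)
        stride *= s
    return strides

def infer_k(a_shape, b_shape, trans_a, trans_b):
    """K is A's rows if A is transposed, else A's columns; symmetric for B."""
    k_a = a_shape[-2] if trans_a else a_shape[-1]
    k_b = b_shape[-1] if trans_b else b_shape[-2]
    assert k_a == k_b, f"K mismatch: {k_a} vs {k_b}"
    return k_a

print(row_major_strides([2, 64, 32]))            # [2048, 32, 1]
print(infer_k([128, 32], [32, 64], False, False))  # 32
```

`stride_a = A_stride[-2]` in the real code is then simply the row pitch of the (logically 2D) matrix.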
The new `gemm_v2` function duplicates a significant amount of code from the existing `gemm` function, including the helper functions `legalize_arguments`, `retrieve_shape`, `retrieve_stride`, `retrieve_ptr`, and `retrieve_offset`. To improve maintainability and reduce redundancy, this common logic should be refactored into a shared helper function. The `gemm` and `gemm_v2` functions would then call this helper and only differ in the final intrinsic they call (`tl.gemm` vs `tl.gemm_py`).
…nd functionality

- Updated `gemm_py.h` to include the correct header for GEMM operations.
- Renamed the `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
- Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
- Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
- Removed the obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
- Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.

- Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
- Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
- Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
- Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
- Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.

- Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
- Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
- Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
- Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

- Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
- Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
- Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
- Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/layout/gemm_layouts.cc (1)
200-206: Add missing precondition checks for block_m and block_n divisibility. This function uses block_m / warp_m and block_n / warp_n as repeat/replicate counts (lines 208 and 215), but unlike makeGemmFragmentA/C, there are no ICHECKs ensuring exact divisibility. If they are not multiples, layouts silently truncate and produce incorrect tiling.

Apply this diff near the existing ICHECKs:

```diff
   // transposed
   ICHECK(warp_n % 8 == 0);
   ICHECK(block_k % 16 == 0);
+  ICHECK(block_m % warp_m == 0) << "block_m=" << block_m << ", warp_m=" << warp_m;
+  ICHECK(block_n % warp_n == 0) << "block_n=" << block_n << ", warp_n=" << warp_n;
```

src/transform/inject_pipeline.cc (1)
53-69: Recompute reads/writes also for reused BlockRealize blocks. When body is a BlockRealize with predicate true, the early return skips recomputing access regions. If the block's body was rewritten earlier, reads/writes can be stale. Compute and set them in both paths.

```diff
 Block MakeBlock(const Stmt &body, const Map<Var, Buffer> &buffer_data_to_buffer) {
   if (const BlockRealizeNode *block_realize = body.as<BlockRealizeNode>()) {
     if (is_one(block_realize->predicate)) {
-      // no need to create a new block
-      return block_realize->block;
+      // reuse block but refresh access regions
+      Block block = block_realize->block;
+      Array<Array<BufferRegion>> access =
+          GetBlockReadWriteRegion(block, buffer_data_to_buffer);
+      BlockNode* n = block.CopyOnWrite();
+      n->reads = access[0];
+      n->writes = access[1];
+      return block;
     }
   }
   Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
               /*name_hint=*/"", /*body*/ body);
   Array<Array<BufferRegion>> access =
       GetBlockReadWriteRegion(block, buffer_data_to_buffer);
   BlockNode *n = block.CopyOnWrite();
   n->reads = access[0];
   n->writes = access[1];
   return block;
 }
```
♻️ Duplicate comments (4)
src/op/gemm_py.cc (1)
208-219: Remove duplicated GetArchInt (already provided by utils.h). This local static is dead code here and duplicates src/target/utils.cc. Please delete.

Apply:

```diff
-/**
- * @brief Parse and return the numeric GPU architecture from a Target's "arch"
- * attribute.
- * ...
- */
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
```

src/op/gemm_py.h (1)
71-71: Fix SEqualReduce comparison for offset_A (copy-paste bug).
`offset_A` is compared against `other->offset_B`. This breaks structural equality and downstream memoization/passes.

```diff
-           equal(offset_A, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
```

tilelang/language/gemm.py (1)
185-359: Deduplicate gemm and gemm_v2 by extracting shared helpers. This block duplicates almost all of `gemm()`. Factor the common logic into a private helper returning the prepared arguments; `gemm()` and `gemm_v2` then differ only by the intrinsic op string.
Here’s a minimal sketch:
```diff
+from dataclasses import dataclass
+
+@dataclass
+class _GemmArgs:
+    Aptr: tir.PrimExpr
+    Bptr: tir.PrimExpr
+    Cptr: tir.PrimExpr
+    M: int
+    N: int
+    K: int
+    stride_a: int
+    stride_b: int
+    offset_a: int
+    offset_b: int
+
+def _prepare_gemm_args(A, B, C, transpose_A, transpose_B) -> _GemmArgs:
+    # move legalize_arguments/retrieve_shape/retrieve_stride/retrieve_ptr/retrieve_offset here
+    # and return _GemmArgs(...)
+    ...
+
 def gemm(...):
-    # duplicated helpers here
-    ...
-    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), ...)
+    args = _prepare_gemm_args(A, B, C, transpose_A, transpose_B)
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.gemm"),
+        args.Aptr, args.Bptr, args.Cptr,
+        transpose_A, transpose_B,
+        args.M, args.N, args.K, policy, clear_accum,
+        args.stride_a, args.stride_b, args.offset_a, args.offset_b, k_pack, wg_wait,
+    )

 def gemm_v2(...):
-    # duplicated helpers here
-    ...
-    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm_py"), ...)
+    args = _prepare_gemm_args(A, B, C, transpose_A, transpose_B)
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.gemm_py"),
+        args.Aptr, args.Bptr, args.Cptr,
+        transpose_A, transpose_B,
+        args.M, args.N, args.K, policy, clear_accum,
+        args.stride_a, args.stride_b, args.offset_a, args.offset_b, k_pack, wg_wait,
+    )
```

Bonus: rename helper parameters from "object" to "obj" to avoid shadowing the built-in.

tilelang/tileop/gemm/__init__.py (1)
tilelang/tileop/gemm/__init__.py (1)
56-75: Critical past bug fixed: using m_warp/n_warp (not M/N) for block warps. Emitter now receives `block_row_warps=m_warp` and `block_col_warps=n_warp`. This resolves the earlier misconfiguration.

Also applies to: 105-125
🧹 Nitpick comments (30)
src/layout/gemm_layouts.cc (1)
325-331: Make xor8x8 parameter type consistent with xor2x2/xor4x4. Minor: j is passed by value here while the other helpers use const PrimExpr&. Align for consistency and to avoid an extra copy.
```diff
-PrimExpr xor8x8(const PrimExpr &i, const PrimExpr j) {
+PrimExpr xor8x8(const PrimExpr &i, const PrimExpr &j) {
```

tilelang/layout/swizzle.py (1)
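For context, `xor8x8` belongs to the family of XOR-based, bank-conflict-avoiding swizzles. A hedged Python sketch of the core idea (the actual layout math in `gemm_layouts.cc` is more involved; the function name here is an assumption for illustration):

```python
def xor_swizzle_8x8(i, j):
    """Permute column j within row i by XOR-ing the low 3 bits of the
    row index. Each row maps columns 0..7 onto a permutation of 0..7,
    which spreads same-column accesses across shared-memory banks."""
    return i, (j ^ (i & 7)) & 7

# every row's swizzled columns form a permutation of 0..7
for i in range(8):
    cols = sorted(xor_swizzle_8x8(i, j)[1] for j in range(8))
    assert cols == list(range(8))
print(xor_swizzle_8x8(3, 5))  # (3, 6), since 5 ^ 3 == 6
```

The permutation property is exactly what makes the parameter-type nit above purely cosmetic: the math is untouched.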
8-9: Clarify comment: the helper doesn't toggle swizzling. The function always builds a swizzled layout; it doesn't decide enable/disable based on TMA. Reword to state that call sites/schedules decide when to use this layout.
src/target/utils.cc (1)
21-26: Stricter arch parsing looks good; consider broader formats. The current check enforces the "sm_" prefix. If upstream targets ever provide "sm90"/"compute_90", this will ICHECK. If that's possible in your pipeline, normalize accepted forms before stoi.
src/target/codegen_cuda.cc (1)
1334-1344: Remove the commented-out legacy asm path. The commented code for need_cast_smem_ptr_to_int_/PrintLoadMatrixAssembly adds noise. Clean up.
```diff
-  // need_cast_smem_ptr_to_int_ = true;
-  // this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
-  //                                         local_elem_offset, smem_ptr,
-  //                                         smem_elem_offset);
-  std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
+  std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
   if (trans == 1)
     func_name += "_trans";
-  // this->stream << func_name << "(" << local_ptr "" << ", " << smem_ptr << ");\n";
   this->PrintIndent();
   this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset
                << ", " << local_ptr << " + " << local_elem_offset << ");\n";
```

src/transform/lower_tile_op.cc (2)
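The emission in this hunk boils down to assembling the helper's name: pick `tl::ptx_ldmatrix_xN` and append `_trans` when transposed. A tiny hypothetical mirror of that selection (illustrative, not the PR's code):

```python
def ldmatrix_func_name(num: int, trans: bool) -> str:
    # mirrors the codegen's construction of the tl:: helper name
    name = f"tl::ptx_ldmatrix_x{num}"
    if trans:
        name += "_trans"
    return name

print(ldmatrix_func_name(4, True))   # tl::ptx_ldmatrix_x4_trans
print(ldmatrix_func_name(2, False))  # tl::ptx_ldmatrix_x2
```

Delegating to a named `tl::` wrapper instead of inlining raw PTX is what lets the commented-out `PrintLoadMatrixAssembly` path be deleted.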
306-312: Fix incorrect variable naming that swapped old and new shapes. The variable names are misleading: the original buffer shape should be named old_shape and the remapped buffer shape new_shape, and indeed line 306 assigns the original load->buffer->shape to old_shape while line 326 assigns the remapped new_buffer->shape to new_shape. However, there's a logical issue here: you're using old_shape to compute the element offset from the original buffer (lines 317-320), then using it again to decompose that offset back into multi-dimensional indices (lines 341-345). This is correct. But the forward mapping at line 348 expects indices in the original buffer's coordinate system, not the new one. The current implementation appears correct despite the confusing naming. Consider renaming for clarity:

```diff
-  Array<PrimExpr> old_shape = load->buffer->shape;
+  Array<PrimExpr> original_shape = load->buffer->shape;
-  CHECK_EQ(indices.size(), old_shape.size())
+  CHECK_EQ(indices.size(), original_shape.size())
       << "Indices size and shape size must match for general N-dimensional "
          "buffer "
       << "but got indices size: " << indices.size()
-      << " and shape size: " << old_shape.size();
+      << " and shape size: " << original_shape.size();
```

358-361: Simplify new_indices computation by using the remapped shape directly. The decomposition of new_offset into new_indices can be done more directly since you already have new_shape available.

```diff
   Array<PrimExpr> new_indices;
+  PrimExpr remaining = new_offset;
   for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
-    new_indices.insert(new_indices.begin(),
-                       floormod(new_offset, new_shape[i]));
-    new_offset = floordiv(new_offset, new_shape[i]);
+    new_indices.insert(new_indices.begin(),
+                       floormod(remaining, new_shape[i]));
+    remaining = floordiv(remaining, new_shape[i]);
   }
```

tilelang/utils/target.py (2)
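The floordiv/floormod loop in that suggestion is the standard inverse of row-major flattening. A self-contained round-trip sketch in plain Python (illustrative only, not TIR code):

```python
def flatten(indices, shape):
    """Row-major linearization of multi-dimensional indices."""
    offset = 0
    for idx, dim in zip(indices, shape):
        offset = offset * dim + idx
    return offset

def unflatten(offset, shape):
    """Invert flatten(): peel dimensions from innermost to outermost,
    exactly as the floormod/floordiv loop in the diff does."""
    indices = []
    for dim in reversed(shape):
        indices.insert(0, offset % dim)
        offset //= dim
    return indices

shape, idx = [4, 5, 6], [2, 3, 4]
print(flatten(idx, shape))                      # 82
print(unflatten(flatten(idx, shape), shape))    # [2, 3, 4]
```

Using a separate `remaining` accumulator, as the reviewer suggests, keeps the original `new_offset` intact for any later use.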
8-15: Nit: fix constant name typo. Rename AVALIABLE_TARGETS → AVAILABLE_TARGETS for clarity, and update references.

Apply:

```diff
-AVALIABLE_TARGETS = {
+AVAILABLE_TARGETS = {
 ...
-    target, Target) or target in AVALIABLE_TARGETS, f"Target {target} is not supported"
+    target, Target) or target in AVAILABLE_TARGETS, f"Target {target} is not supported"
```

82-84: Avoid re-wrapping an existing Target. If return_var is already a Target, re-constructing it is unnecessary.

Apply:

```diff
-    if return_object:
-        return Target(return_var)
+    if return_object:
+        return return_var if isinstance(return_var, Target) else Target(return_var)
```

tilelang/intrinsics/mma_layout.py (1)
66-69: Clear public aliases. The aliases expose SR/RS A/B variants succinctly. Consider adding one-line docstrings for discoverability.
src/op/gemm_py.cc (4)
95-110: Duplicate kernel-selection logic with GemmNode. GetGemmInst mirrors src/op/gemm.cc. Prefer a shared helper to avoid divergence.

142-192: Duplicate WGMMA gate with GemmNode::CheckWGMMA. Keep a single source of truth (e.g., extract to a common helper) to prevent subtle mismatches.

221-226: Drop unused pre-computation in Lower. block_size/gemm_inst/warp partition are unused; remove to reduce confusion.

Apply:

```diff
-Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  auto block_size = *as_const_int(T.thread_bounds->extent);
-  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
-  auto [warp_m, warp_n] = policy->ComputeWarpPartition(
-      M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
+Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
```
15-15: Remove unused include.

`tvm/ffi/string.h` isn't used.

Apply:

```diff
-#include "tvm/ffi/string.h"
```

src/op/gemm_py.h (2)
39-61: Ensure reflection is registered exactly once.
`RegisterReflection()` defines bindings but isn't obviously invoked. Please confirm it's called at init (e.g., from gemm_py.cc) to avoid missing reflection in Python/printing.

Example (in src/op/gemm_py.cc):

```cpp
// At file scope
namespace {
struct GemmPyNodeReflectionRegister {
  GemmPyNodeReflectionRegister() { GemmPyNode::RegisterReflection(); }
} _gemm_py_node_reflection_register;
}  // namespace
```
17-17: Avoid `using namespace` in a header.

Headers should not inject wide namespaces. It's redundant here since members are already qualified.

```diff
-using namespace tir;
```

tilelang/__init__.py (1)
110-110: Remove unused `noqa` directive (RUF100).

`# noqa: F401` is unnecessary here and triggers Ruff's RUF100.

```diff
-from . import tileop  # noqa: F401
+from . import tileop
```

tilelang/engine/phase.py (1)
69-78: Docstring is now stale; mention LetInline instead of "frontend legalization".

Update the pipeline bullets to reflect the inlining step to avoid confusion for users reading this docstring.

Apply this diff to refresh the wording:

```diff
-    - Legalizes frontend Tile IR into TVM-compatible constructs.
+    - Inlines let expressions/statements (LetInline) to normalize IR.
```

tilelang/tileop/__init__.py (1)
1-1: Remove redundant noqa.

Ruff reports RUF100; the F401 rule isn't enabled, so the directive is unnecessary here.

```diff
-from .gemm import GemmPy  # noqa: F401
+from .gemm import GemmPy
```

tilelang/language/__init__.py (1)
46-46: Nit: unnecessary noqa.

Same RUF100 note as elsewhere; consider dropping the directive or declaring an explicit `__all__` if you want to keep re-export semantics without noqa.

```diff
-from .gemm import GemmWarpPolicy, gemm, gemm_v2  # noqa: F401
+from .gemm import GemmWarpPolicy, gemm, gemm_v2
```

tilelang/transform/__init__.py (1)
5-5: Publicly exposing LetInline looks good; minor noqa nit.

The re-export is correct and aligns with the new pipeline; consider removing the redundant noqa to satisfy RUF100.

```diff
-from .simplify import Simplify, simplify_prim_func, LetInline  # noqa: F401
+from .simplify import Simplify, simplify_prim_func, LetInline
```

tilelang/ir.py (1)
35-39: Add a return type to compute_warp_partition for clarity.

Returning a `Tuple[int, int]` helps callers and type checkers; no behavior change.

```diff
-from tilelang import _ffi_api
+from tilelang import _ffi_api
+from typing import Tuple
@@
-    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
-                               is_wgmma: bool):
+    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
+                               is_wgmma: bool) -> Tuple[int, int]:
```

tilelang/transform/simplify.py (2)
30-48: Return-type contract of _Simplify/apply_simplify is misleading.

Both branches return a PrimFunc even when the input is an IRModule. Either (a) document that a PrimFunc is always returned, or (b) return an IRModule when given an IRModule. I suggest tightening the types to PrimFunc to avoid surprising callers.

```diff
-def _Simplify(stmt: Union[PrimFunc, IRModule],
-              inline_let: bool = False) -> Union[PrimFunc, IRModule]:
+def _Simplify(stmt: Union[PrimFunc, IRModule],
+              inline_let: bool = False) -> PrimFunc:
@@
-        assert len(mod.functions) == 1, "Simplify should return a single function"
-        return list(mod.functions.values()).pop()
+        assert len(mod.functions) == 1, "Simplify should return a single function"
+        return list(mod.functions.values()).pop()
```

If you prefer (b), I can provide a patch that preserves the input type instead.
62-66: Expose the return type in the apply_simplify docstring/signature.

Keep it consistent with _Simplify if you adopt the suggestion above.

```diff
-def apply_simplify(stmt: Union[PrimFunc, IRModule],
-                   inline_let: bool = False) -> Union[PrimFunc, IRModule]:
-    """Apply Simplify pass to a PrimFunc or IRModule."""
+def apply_simplify(stmt: Union[PrimFunc, IRModule],
+                   inline_let: bool = False) -> PrimFunc:
+    """Apply Simplify pass and return a single PrimFunc."""
```

tilelang/language/gemm.py (4)
247-248: Prefer TypeError for invalid type arguments (matches Ruff hints).

These branches validate types; TypeError is more appropriate than ValueError. The same one-line change applies at each of the four branches:

```diff
-    else:
-        raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
+    else:
+        raise TypeError(f"Unsupported argument type: {type(object)} for buffer {object}")
```

Also applies to: 266-267, 315-316, 328-329
185-196: Minor naming and readability nits
- “experimental currently, for fast compilation” is vague; consider adding a brief note on what’s missing (e.g., supported dtypes/backends) and stability expectations.
- Consider adding a deprecation/transition note if gemm_v2 is intended to replace gemm later.
220-232: Tiny style: avoid the param name "object".

Using "object" as a parameter shadows the built-in. Rename it to "obj" in the extracted helpers when you deduplicate.
1-359: Cross-file: potential SEqualReduce bug in C++ GemmPyNode (offset fields).

Unrelated to this Python file, but worth fixing before it bites equality/caching: in src/op/gemm_py.h, SEqualReduce compares against `other->offset_B` twice. It should compare `offset_A` with `other->offset_A`, and `offset_B` with `other->offset_B`.
I can open a follow-up PR; suggested patch:
```diff
-         equal(offset_A, other->offset_B) &&
-         equal(offset_B, other->offset_B) &&
+         equal(offset_A, other->offset_A) &&
+         equal(offset_B, other->offset_B) &&
```

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)
470-505: Main guard: keep tests hermetic for CI.

Directly running `run_gemm_ss(...)` in `__main__` can hide CI issues and cause flakiness on machines without GPUs. Prefer leaving only `tilelang.testing.main()` here.

Apply:

```diff
-    # tilelang.testing.main()
-    tilelang.disable_cache()
+    tilelang.testing.main()
+    # tilelang.disable_cache()
@@
-    run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
+    # run_gemm_ss(128, 128, 128, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
```

tilelang/intrinsics/mma_macro_generator.py (1)
429-444: make_mma_load_layout: flexible SR/RS construction is good; minor robustness nits.
- Good: clean separation of A vs B and axis-order, and use of IndexMap.inverse/forward.
- Nit: the ValueError messages are long; Ruff TRY003 flags these—consider shorter messages or custom exception types.
Apply (representative):
```diff
-            raise ValueError(f"Unsupported matrix {matrix}")
+            raise ValueError("Unsupported matrix type")
```

Also applies to: 487-516
tilelang/tileop/gemm/__init__.py (1)
138-170: Lowering bodies: local buffer sizes and KI loop are consistent; one naming nit.
- Local sizes use `warp_rows * local_size_a` and `warp_cols * local_size_b`, consistent with ldmatrix/mma strides.
- Nit: two functions are named `_gemm_rsr` (the rs and rr cases). Different scopes, but renaming the rr one to `_gemm_rrr` would improve grepability.

```diff
-        def _gemm_rsr() -> None:
+        def _gemm_rrr() -> None:
@@
-        return _Simplify(_gemm_rsr, inline_let=True)
+        return _Simplify(_gemm_rrr, inline_let=True)
```

Also applies to: 171-198, 199-224, 225-242
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (27)
- 3rdparty/tvm (1 hunks)
- src/layout/gemm_layouts.cc (1 hunks)
- src/op/gemm.cc (1 hunks)
- src/op/gemm_py.cc (1 hunks)
- src/op/gemm_py.h (1 hunks)
- src/op/gemm_sp.cc (0 hunks)
- src/target/codegen_cuda.cc (3 hunks)
- src/target/utils.cc (2 hunks)
- src/transform/frontend_legalize.cc (2 hunks)
- src/transform/inject_pipeline.cc (2 hunks)
- src/transform/lower_tile_op.cc (2 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6 hunks)
- testing/python/transform/test_tilelang_transform_let_inline.py (1 hunks)
- tilelang/__init__.py (1 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/intrinsics/mma_layout.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (16 hunks)
- tilelang/ir.py (2 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/gemm.py (1 hunks)
- tilelang/language/kernel.py (1 hunks)
- tilelang/layout/swizzle.py (1 hunks)
- tilelang/tileop/__init__.py (1 hunks)
- tilelang/tileop/gemm/__init__.py (1 hunks)
- tilelang/transform/__init__.py (1 hunks)
- tilelang/transform/simplify.py (3 hunks)
- tilelang/utils/target.py (2 hunks)
💤 Files with no reviewable changes (1)
- src/op/gemm_sp.cc
🧰 Additional context used
🧬 Code graph analysis (19)
tilelang/engine/phase.py (2)
src/transform/frontend_legalize.cc (2)
LetInline(85-90)LetInline(85-85)tilelang/transform/simplify.py (1)
LetInline(8-16)
testing/python/transform/test_tilelang_transform_let_inline.py (2)
src/transform/frontend_legalize.cc (2)
LetInline(85-90)LetInline(85-85)tilelang/transform/simplify.py (1)
LetInline(8-16)
tilelang/language/__init__.py (1)
tilelang/language/gemm.py (2)
gemm(9-182)gemm_v2(186-359)
tilelang/ir.py (3)
src/op/gemm_py.h (1)
tvm(13-124)src/op/gemm_sp.h (1)
tvm(13-95)tilelang/language/ast/ir.py (1)
target(1682-1713)
tilelang/tileop/__init__.py (4)
tilelang/language/gemm.py (1)
gemm(9-182)src/op/gemm_py.cc (1)
GemmPy(50-80)tilelang/tileop/gemm/__init__.py (1)
GemmPy(31-270)src/op/gemm_py.h (1)
GemmPy(116-121)
tilelang/transform/__init__.py (2)
tilelang/transform/simplify.py (1)
LetInline(8-16)src/transform/frontend_legalize.cc (2)
LetInline(85-90)LetInline(85-85)
src/transform/frontend_legalize.cc (3)
src/transform/lower_tile_op.cc (2)
f(210-240)f(210-210)src/transform/legalize_vectorized_loop.cc (2)
f(49-58)f(49-49)tilelang/transform/simplify.py (1)
LetInline(8-16)
tilelang/language/kernel.py (2)
tilelang/language/frame.py (1)
Current(153-162)tilelang/intrinsics/mma_macro_generator.py (1)
get_thread_binding(117-123)
src/transform/inject_pipeline.cc (1)
tilelang/language/ast/ir.py (1)
block(342-358)
tilelang/transform/simplify.py (2)
src/transform/frontend_legalize.cc (2)
LetInline(85-90)LetInline(85-85)src/transform/simplify.cc (2)
Simplify(481-489)Simplify(481-481)
src/op/gemm_py.cc (3)
tilelang/tileop/gemm/__init__.py (1)
GemmPy(31-270)src/op/gemm.cc (10)
Clone(89-92)Clone(89-89)GetGemmInst(94-108)GetGemmInst(94-94)CheckWGMMA(307-357)CheckWGMMA(307-307)Lower(399-435)Lower(399-399)InferLayout(456-604)InferLayout(456-457)src/target/utils.cc (8)
TargetGetWarpSize(114-119)TargetGetWarpSize(114-114)TargetIsHopper(49-54)TargetIsHopper(49-49)TargetIsCDNA(63-72)TargetIsCDNA(63-63)TargetIsCuda(11-13)TargetIsCuda(11-11)
src/op/gemm.cc (1)
src/op/gemm.h (2)
RegisterReflection(34-40)RegisterReflection(117-139)
src/transform/lower_tile_op.cc (3)
src/transform/atomicadd_vectorize.cc (2)
indices(83-131)indices(83-83)src/transform/loop_vectorize.cc (2)
indices(113-168)indices(113-113)src/transform/loop_vectorize_dynamic.cc (2)
indices(141-194)indices(141-141)
tilelang/utils/target.py (1)
src/target/utils.cc (26)
TargetIsCuda(11-13)TargetIsCuda(11-11)TargetIsRocm(14-16)TargetIsRocm(14-14)TargetIsVolta(28-33)TargetIsVolta(28-28)TargetIsTuring(35-40)TargetIsTuring(35-35)TargetIsAmpere(42-47)TargetIsAmpere(42-42)TargetIsHopper(49-54)TargetIsHopper(49-49)TargetIsSM120(56-61)TargetIsSM120(56-56)TargetIsCDNA(63-72)TargetIsCDNA(63-63)TargetHasAsyncCopy(74-92)TargetHasAsyncCopy(74-74)TargetHasLdmatrix(93-98)TargetHasLdmatrix(93-93)TargetHasStmatrix(100-105)TargetHasStmatrix(100-100)TargetHasBulkCopy(107-112)TargetHasBulkCopy(107-107)TargetGetWarpSize(114-119)TargetGetWarpSize(114-114)
src/op/gemm_py.h (2)
src/op/operator.h (2)
TileOperatorNode(55-95)TileOperator(62-94)src/op/gemm_py.cc (11)
CheckWGMMA(142-192)CheckWGMMA(142-142)Lower(221-252)Lower(221-221)InferLayout(254-269)InferLayout(254-255)Clone(90-93)Clone(90-90)GetGemmInst(95-110)GetGemmInst(95-96)GemmPy(50-80)
tilelang/tileop/gemm/__init__.py (8)
src/op/gemm_py.h (2)
tvm(13-124)GemmPy(116-121)tilelang/utils/target.py (1)
target_is_cuda(87-88)tilelang/intrinsics/mma_macro_generator.py (10)
make_mma_store_layout(518-587)make_mma_load_layout(374-516)ldmatrix_a(165-209)ldmatrix_a(675-779)ldmatrix_b(211-257)ldmatrix_b(781-893)mma(259-317)mma(895-944)mma(949-1047)mma(1052-1151)tilelang/layout/swizzle.py (1)
make_swizzled_layout(10-16)tilelang/ir.py (2)
GemmWarpPolicy(30-39)compute_warp_partition(35-39)tilelang/transform/simplify.py (1)
_Simplify(30-49)tilelang/utils/language.py (2)
is_shared(25-39)is_fragment(68-78)src/op/gemm_py.cc (1)
GemmPy(50-80)
tilelang/intrinsics/mma_macro_generator.py (4)
tilelang/language/kernel.py (4)
get_thread_binding(151-156)get_thread_binding(261-265)KernelLaunchFrame(72-206)Current(115-121)tilelang/intrinsics/mma_layout.py (2)
shared_16x32_to_mma_32x16_layout(72-74)shared_32x16_to_mma_32x16_layout(77-79)tilelang/intrinsics/utils.py (2)
shared_16x32_to_mma_32x16_layout(58-60)shared_32x16_to_mma_32x16_layout(63-65)tilelang/layout/fragment.py (2)
replicate(146-160)repeat(123-144)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
tilelang/language/allocate.py (1)
alloc_shared(21-36)tilelang/language/gemm.py (1)
gemm_v2(186-359)tilelang/language/__init__.py (1)
annotate_layout(102-140)tilelang/layout/fragment.py (1)
make_swizzled_layout(210-216)tilelang/profiler/__init__.py (1)
assert_allclose(76-137)tilelang/env.py (1)
disable_cache(232-233)
tilelang/language/gemm.py (4)
tilelang/ir.py (1)
GemmWarpPolicy(30-39)tilelang/primitives/gemm/base.py (1)
GemmWarpPolicy(8-188)tilelang/language/frame.py (2)
has_let_value(189-198)get_let_value(201-210)tilelang/language/tir/op.py (1)
call_intrin(119-144)
🪛 Ruff (0.12.2)
tilelang/language/__init__.py
46-46: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/__init__.py
110-110: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/tileop/__init__.py
1-1: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/transform/__init__.py
5-5: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/tileop/gemm/__init__.py
101-101: Avoid specifying long messages outside the exception class
(TRY003)
103-103: Avoid specifying long messages outside the exception class
(TRY003)
243-243: Avoid specifying long messages outside the exception class
(TRY003)
245-245: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_macro_generator.py
427-427: Avoid specifying long messages outside the exception class
(TRY003)
442-442: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Avoid specifying long messages outside the exception class
(TRY003)
514-514: Avoid specifying long messages outside the exception class
(TRY003)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py
246-246: Redefinition of unused threads from line 1
(F811)
362-362: Redefinition of unused threads from line 1
(F811)
tilelang/language/gemm.py
247-247: Prefer TypeError exception for invalid type
(TRY004)
247-247: Avoid specifying long messages outside the exception class
(TRY003)
266-266: Prefer TypeError exception for invalid type
(TRY004)
266-266: Avoid specifying long messages outside the exception class
(TRY003)
315-315: Prefer TypeError exception for invalid type
(TRY004)
315-315: Avoid specifying long messages outside the exception class
(TRY003)
328-328: Prefer TypeError exception for invalid type
(TRY004)
328-328: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (33)
src/layout/gemm_layouts.cc (2)
208-208: Formatting-only change; no semantic impact. LGTM.

Chaining Replicate(...)->Repeat(...) on one line preserves the original call order and arguments. No behavior change.

215-215: Formatting-only change; no semantic impact. LGTM.

Same call chain and arguments as before; just collapsed for brevity.
3rdparty/tvm (1)
1-1: Submodule bump: verify CI, reproducibility, and downstream API sync.
- CI workflows already invoke `git submodule update --init --recursive` for 3rdparty/tvm.
- Record and pin commit eddefbd65acb7b1ea51dd18068b4049754c4fa7a in your third-party docs/changelog and update license headers.
- Confirm this TVM commit defines and exports (via `TVM_REGISTER_GLOBAL` and function implementations) all FFI symbols used by GemmPy and gemm_v2, namely `tl.GemmWarpPolicyComputeWarpPartition`, `TargetIsCuda`, `TargetHasLdmatrix`, and `TargetGetWarpSize`.

Consider referencing a stable annotated TVM tag alongside the commit hash for easier future audits.
src/transform/inject_pipeline.cc (2)
678-684: LGTM: async scope annotation.

Marking async producer blocks with tir::attr::async_scope is appropriate and consistent with later commit/wait handling.

954-958: LGTM: refresh block access regions post-mutation.

Recomputing reads/writes here prevents stale region metadata after rewriting.
src/target/utils.cc (1)
121-150: FFI exposure set LGTM.

The reflection registrations match the helpers and align with Python wrappers.
src/target/codegen_cuda.cc (2)
1263-1263: LGTM: add indentation before emitting MMA asm.

Improves formatting; harmless to codegen.
Also applies to: 1299-1299
1129-1142: The script will show exactly how many and which args `tl::ptx_ldmatrix()` pushes in copy.cc. This determines whether `print_extern_call_stmt(func_name, 2)` is correct or needs adjustment.

src/transform/frontend_legalize.cc (2)
37-44: LGTM! Clean refactoring from FrontendLegalizer to LetInline.

The renaming from `FrontendLegalizer` to `LetInline` better reflects the actual functionality of this pass, which inlines let bindings. The implementation remains unchanged, preserving the existing behavior.

85-94: Consistent naming updates throughout the pass registration.

The pass registration and FFI reflection are properly updated to use the new `LetInline` naming, maintaining consistency with the class rename.

src/op/gemm.cc (2)
611-612: LGTM! Proper TVM operator registration for GemmWarpPolicy.

The registration correctly sets up the `tl.GemmWarpPolicy` operator with an appropriate TScript printer name for better debugging and visualization.

614-625: Well-structured FFI bindings for GEMM reflection and warp partition computation.

The FFI initialization block properly:

- Registers reflection for both `GemmNode` and `GemmWarpPolicyNode` classes
- Exposes the `GemmWarpPolicyComputeWarpPartition` function to Python
- Correctly delegates to the `ComputeWarpPartition` method on the policy object

This enables Python code to query warp partitioning decisions, which is essential for the experimental `T.gemm_v2` operator mentioned in the PR objectives.

tilelang/utils/target.py (1)
87-136: LGTM: thin wrappers cleanly mirror C++ FFI.

The wrappers are minimal and keep Python in sync with backend feature queries.
tilelang/intrinsics/mma_layout.py (2)
50-57: A-layout split and transposed helper look correct.

Thread id and local index mapping are consistent; the transposed variant is a simple swap. Looks good.

58-65: B-layout split mirrors A-layout as expected.

Mapping matches the intended orientation for B; the transposed helper is correct.
src/op/gemm_py.cc (1)
271-276: Confirm TL op arity matches Python frontend
`set_num_inputs(5)` must align with how `gemm_py` is called in Python; please verify that the Python binding passes exactly five arguments.

testing/python/transform/test_tilelang_transform_let_inline.py (1)
10-10: Switch to LetInline looks correct.

The pass rename aligns with the new API and keeps the test intent intact.
tilelang/engine/phase.py (1)
88-89: Switching to LetInline at the start of LowerAndLegalize is sensible.

Inlining lets before InjectAssumes/Simplify is a good move for later analyses and matches the new public pass. No functional concerns from me.
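As context for readers unfamiliar with the pass, let-inlining substitutes each bound value into its uses before simplification. A toy sketch over a tiny tuple-based expression form (illustrative only; not the TVM IR or the actual C++ pass):

```python
def inline_lets(expr, env=None):
    """Toy let-inliner over nested tuples.

    Expression forms: ("let", name, value, body) | ("add", a, b) | str name | int literal.
    """
    env = env or {}
    if isinstance(expr, str):
        # a variable reference: substitute its bound value if known
        return env.get(expr, expr)
    if isinstance(expr, int):
        return expr
    tag = expr[0]
    if tag == "let":
        _, name, value, body = expr
        # bind the (already inlined) value, then rewrite the body with it
        return inline_lets(body, {**env, name: inline_lets(value, env)})
    # any other node: recurse into operands
    return (tag, *(inline_lets(e, env) for e in expr[1:]))

# ("let", "x", ("add", 1, 2), ("add", "x", "x"))
#   -> ("add", ("add", 1, 2), ("add", 1, 2))
```

Note that this duplicates the bound expression at each use site, which is exactly why running Simplify afterwards (as the pipeline does) matters.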
tilelang/transform/simplify.py (1)
8-17: LetInline wrapper: LGTM.

The pass exposure matches the C++ binding and keeps naming consistent across layers.
tilelang/language/gemm.py (1)
340-359: Registration of `tl.gemm_py` confirmed: the op is registered in `src/op/gemm_py.cc` via `TIR_REGISTER_TL_OP(GemmPy, gemm_py)` and has associated FFI hooks for `.lower` and `.infer_layout`.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)
35-37: Shared allocs: consistent explicit scopes look good.

Explicit `scope="shared"` and fragment allocs are consistent across variants. No issues spotted.

Also applies to: 149-151, 263-266, 380-384

48-49: Correctly switched to T.gemm_v2.

Calls into `T.gemm_v2` are wired with expected buffer kinds (shared/fragment) and transpose flags. Good.

Also applies to: 167-168, 281-282, 401-402

154-156: Layout annotations are coherent with rs/sr/rr paths.

Swizzling the shared operand and keeping the fragment operand un-swizzled matches the `gemm_v2` layout inference paths. Looks good.

Also applies to: 268-270, 386-389
tilelang/intrinsics/mma_macro_generator.py (6)
53-79: Thread binding override via thread_var is a solid addition.

Gracefully falls back to the current Kernel frame when `thread_var` is absent, enabling external control in lowering. Nice.

Also applies to: 117-124

179-206: ldmatrix_a: index/transpose coupling; double-check semantics.

You compute `A_shared_buf_elem` as `[wk, wi]` when `a_transposed` else `[wi, wk]`, and pass `trans=self.a_transposed` to `ptx_ldmatrix`. This ties address orientation and the hardware transpose flag to the same boolean. Depending on your MMA variant (row.col), A often uses a non-transposed load for non-transposed A, but the required ldmatrix "trans" bit can be architecture- and layout-specific. Please confirm on a small kernel that both TN and NN paths load correct lanes for m16n8k16.

I can help craft a minimal kernel to validate lane mapping if needed.
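Short of a full kernel, one cheap offline sanity check for such lane mappings is verifying that the (thread, local) assignment is a bijection over the fragment. A sketch; `example_layout` below is an assumed 16x16 -> (32 threads x 8 locals) convention for illustration only, not necessarily the exact TileLang mapping:

```python
def is_bijective_layout(layout, rows=16, cols=16, threads=32, locals_per_thread=8):
    """Every (thread, local) slot must be hit exactly once by the rows*cols elements."""
    seen = set()
    for i in range(rows):
        for j in range(cols):
            t, l = layout(i, j)
            assert 0 <= t < threads and 0 <= l < locals_per_thread
            seen.add((t, l))
    return len(seen) == rows * cols == threads * locals_per_thread

def example_layout(i, j):
    # assumed 16x16 -> 32-thread x 8-local mapping, for illustration only
    thread = (i % 8) * 4 + (j % 8) // 2
    local = (i // 8) * 4 + (j // 8) * 2 + (j % 2)
    return thread, local
```

Running such a check on both the transposed and non-transposed index orders quickly exposes a mapping where two elements collide on the same (thread, local) slot.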
224-255: ldmatrix_b: inverted trans bit matches row.col expectation.

For B you use `trans = not b_transposed` and flip the index order accordingly. This aligns with typical row.col usage. LGTM pending the A verification above.

333-371: stmatrix: correct handling of 2D vs 4D destinations.

Thread/lane mapping and local_id calculation align with the store IndexMap. Looks consistent.

399-404: New sr A/B transform funcs: ensure availability and dtype coverage.

The imports reference `shared_16x16_to_mma_32x8_layout_sr_a/_sr_b` for 16-bit and the 16x32/32x16 variants for 8-bit. Verify these functions exist across all targeted CUDA arches and cover FP16/BF16/FP8. Missing symbols will break import at compile time.

If any dtype isn't ready, gate it behind a clear error in `gemm_v2` or policy selection.

Also applies to: 421-426

280-316: mma: saturate flag explicitly false; OK.

The explicit boolean improves readability. No functional concerns.
tilelang/tileop/gemm/__init__.py (4)
17-21: FFI shims return thread extent correctly.

Bridging to `infer_layout`/`lower` using `thread_bounds.extent` looks correct.

Also applies to: 23-28
76-101: Infer-layout mappings align with operand kinds.
- ss: swizzle A/B shared; C uses MMA store layout.
- sr/rs: fragment side gets MMA load layout; shared side swizzled.
- rr: both A/B load layouts; C store layout.
This matches test variants and emitter expectations.
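The ss/sr/rs/rr split above amounts to classifying each operand by where it resides; a toy sketch of that classification (the scope strings are assumptions for illustration, not necessarily the exact TileLang scope names):

```python
def gemm_variant(a_scope: str, b_scope: str) -> str:
    """Map operand memory scopes to the two-letter GEMM variant (s=shared, r=fragment)."""
    kind = {"shared": "s", "shared.dyn": "s", "local.fragment": "r"}
    return kind[a_scope] + kind[b_scope]
```

Dispatching on this two-letter key is what selects the matching lowering body and layout map (swizzled shared layouts for "s" operands, MMA load layouts for "r" operands).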
247-259: in_dtype/accum_dtype/chunk accessors look correct.
The `in_dtype` guard asserts A/B dtype equality. `chunk` respects `trans_A` by selecting the K-positioned axis.

260-271: is_gemm_* predicates correctly classify shared vs fragment.

Accurate use of `is_shared`/`is_fragment`. Good.
```cpp
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();

  node->Aptr = args[0];
  node->Bptr = args[1];
  node->Cptr = args[2];
  node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
  node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
  node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
  node->trans_A = args[3].as<Bool>().value();
  node->trans_B = args[4].as<Bool>().value();
  node->M = args[5].as<IntImm>().value()->value;
  node->N = args[6].as<IntImm>().value()->value;
  node->K = args[7].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
  node->clear_accum = args[9].as<Bool>().value();
  node->stride_A = args[10].as<IntImm>().value()->value;
  node->stride_B = args[11].as<IntImm>().value()->value;
  node->offset_A = args[12].as<IntImm>().value()->value;
  node->offset_B = args[13].as<IntImm>().value()->value;
  if (args.size() > 14) {
    node->kPack = args[14].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 15) {
    node->wg_wait = args[15].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}
```
Bounds-check args and validate BufferMap lookups.
Constructor indexes args[0..13] unconditionally and uses vmap[...] without presence checks. This risks OOB access or null buffers.
Apply:
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
+ // Required: Aptr, Bptr, Cptr, trans_A, trans_B, M, N, K, policy, clear_accum,
+ // stride_A, stride_B, offset_A, offset_B
+ ICHECK_GE(args.size(), 14) << "gemm_py expects at least 14 positional args";
+
+ // Defaults for optional fields
+ node->kPack = 1;
+ node->wg_wait = 0;
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
- node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
- node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
- node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
- node->trans_A = args[3].as<Bool>().value();
- node->trans_B = args[4].as<Bool>().value();
+ auto Avar = GetVarFromAccessPtr(node->Aptr);
+ auto Bvar = GetVarFromAccessPtr(node->Bptr);
+ auto Cvar = GetVarFromAccessPtr(node->Cptr);
+ ICHECK(vmap.count(Avar)) << "Aptr not found in BufferMap";
+ ICHECK(vmap.count(Bvar)) << "Bptr not found in BufferMap";
+ ICHECK(vmap.count(Cvar)) << "Cptr not found in BufferMap";
+ node->A = vmap.at(Avar);
+ node->B = vmap.at(Bvar);
+ node->C = vmap.at(Cvar);
+ node->trans_A = args[3].as<IntImm>().value()->value != 0;
+ node->trans_B = args[4].as<IntImm>().value()->value != 0;
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
- node->clear_accum = args[9].as<Bool>().value();
+ node->clear_accum = args[9].as<IntImm>().value()->value != 0;
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
data_ = std::move(node);
 }
🤖 Prompt for AI Agents
In src/op/gemm_py.cc around lines 50 to 80, the constructor assumes args has at
least 14 entries and that vmap contains entries for GetVarFromAccessPtr results,
which can cause OOB or missing-buffer crashes; add explicit checks before each
indexed access (ICHECK or if-checks verifying args.size() >= required_index+1)
and validate that .as<...>() returns non-null (or use ICHECK(node) semantics)
before dereferencing value(); for vmap lookups, compute the key once, check
vmap.count(key) (or vmap.find(key) != vmap.end()) and handle missing buffers
with a clear ICHECK(false) or throw with a descriptive message; keep existing
optional-field handling but gate accesses to args[14] and args[15] behind
args.size() checks; ensure all messages indicate which arg or buffer is missing
to aid debugging.
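The check-before-index pattern described above can be sketched in Python; `parse_args` and its field names are hypothetical stand-ins for the ICHECK-guarded C++ constructor, shown only to illustrate the validation order:

```python
def parse_args(args, required=14):
    """Validate positional args before indexing, mirroring the suggested ICHECKs."""
    if len(args) < required:
        # fail loudly with a message that says what was expected, as the review asks
        raise ValueError(f"expected at least {required} args, got {len(args)}")
    parsed = {"M": args[5], "N": args[6], "K": args[7]}
    # optional trailing field: only read when present, with a documented default
    parsed["kPack"] = args[14] if len(args) > 14 else 1
    if parsed["kPack"] not in (1, 2):
        raise ValueError("kPack must be 1 or 2")
    return parsed
```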
```python
    # More test case can be found in kernel/test_tilelang_kernel_gemm.py
    # GEMM tests for float16
    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32,
             2)  # f16f16f16_nn
    run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
    run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
    run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
    run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
```
num_stages=0 will likely break T.Pipelined; use >= 1.
Passing 0 to T.Pipelined(..., num_stages=...) can lead to invalid scheduling or no prefetch. Align with other tests using 2 and set to at least 1.
Apply:
- run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 0)
- run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 0)
- run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 0)
- run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 0)
+ run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 1)
+ run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 1)
+ run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 1)
+ run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 1)
🤖 Prompt for AI Agents
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py around lines 110 to
116: the GEMM calls pass num_stages=0 which can break T.Pipelined; change the
final argument from 0 to 2 (or at least 1) for each run_gemm_ss invocation so
Pipelined gets a valid number of stages (prefer 2 to match other tests).
M, N = C_shape
K = A_shape[-2] if transpose_A else A_shape[-1]
K_B = B_shape[-1] if transpose_B else B_shape[-2]
assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
Fix error message to reference gemm_v2
The assertion message still says “T.gemm”; this will mislead users debugging shape issues.
- assert K == K_B, f"T.gemm K shape check failed: K_A = {K}, K_B = {K_B}"
+ assert K == K_B, f"T.gemm_v2 K shape check failed: K_A = {K}, K_B = {K_B}"
🤖 Prompt for AI Agents
In tilelang/language/gemm.py around lines 287 to 291, the assertion message
refers to "T.gemm" which is incorrect for this function; update the assertion
string to reference "T.gemm_v2" (or the correct function name) so the message
reads something like "T.gemm_v2 K shape check failed: K_A = {K}, K_B = {K_B}"
and retain the existing shape variables in the f-string for clarity.
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
return KernelLaunchFrame.Current().get_thread_binding(dim)
🛠️ Refactor suggestion
Replace asserts with hard checks to avoid -O stripping
Runtime preconditions expressed as `assert` are stripped under `python -O`, risking a `None` dereference. Raise an explicit exception instead, and avoid calling `Current()` twice.
Apply this diff to each wrapper:
def get_thread_binding(dim: int = 0) -> Var:
"""Returns the thread binding for the given dimension.
"""
- assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
- return KernelLaunchFrame.Current().get_thread_binding(dim)
+ current = KernelLaunchFrame.Current()
+ if current is None:
+ raise RuntimeError("KernelLaunchFrame is not initialized")
+    return current.get_thread_binding(dim)

Repeat the same pattern for:
- get_thread_bindings()
- get_block_binding()
- get_block_bindings()
- get_thread_extent()
- get_thread_extents()
- get_block_extent()
- get_block_extents()
Optionally add a small helper to DRY:
def _current_kernel_frame() -> KernelLaunchFrame:
current = KernelLaunchFrame.Current()
if current is None:
raise RuntimeError("KernelLaunchFrame is not initialized")
    return current

Then call _current_kernel_frame() in the wrappers.
Also applies to: 271-272, 278-279, 285-286, 292-293, 299-300, 306-307, 313-314
🤖 Prompt for AI Agents
In tilelang/language/kernel.py around lines 264-265 (and similarly at 271-272,
278-279, 285-286, 292-293, 299-300, 306-307, 313-314), replace the assert-based
runtime checks with explicit checks that raise a RuntimeError when
KernelLaunchFrame.Current() is None and avoid calling Current() twice; implement
(optionally) a small helper like _current_kernel_frame() that calls
KernelLaunchFrame.Current(), raises RuntimeError("KernelLaunchFrame is not
initialized") when None, and returns the current frame, then have each wrapper
call that helper and return the appropriate get_* result (e.g.,
get_thread_binding, get_thread_bindings, get_block_binding, etc.).
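For illustration, a runnable sketch of the suggested hard-check pattern follows; the `KernelLaunchFrame` below is a minimal stub standing in for the real tilelang class, not its actual implementation:

```python
from typing import Optional


class KernelLaunchFrame:
    """Minimal stub standing in for tilelang's real frame class."""
    _current: Optional["KernelLaunchFrame"] = None

    @classmethod
    def Current(cls) -> Optional["KernelLaunchFrame"]:
        return cls._current

    def get_thread_binding(self, dim: int = 0) -> str:
        return f"threadIdx.{'xyz'[dim]}"


def _current_kernel_frame() -> KernelLaunchFrame:
    # A hard check survives `python -O` (unlike an assert), and Current() is
    # called exactly once, so we never dereference a second lookup result.
    current = KernelLaunchFrame.Current()
    if current is None:
        raise RuntimeError("KernelLaunchFrame is not initialized")
    return current


def get_thread_binding(dim: int = 0) -> str:
    return _current_kernel_frame().get_thread_binding(dim)
```

The other wrappers (get_block_binding, get_thread_extent, and so on) would delegate to the same helper.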
@@ -1,5 +1,6 @@
 from typing import Literal, Union
 from tilelang import tvm as tvm
+from tilelang import _ffi_api
💡 Verification agent
🧩 Analysis chain
Guard against missing FFI bindings (verify symbols).
Import assumes the C++ FFI registrations exist. Please verify the symbols are present to avoid ImportError at runtime.
Run this to confirm all required FFI globals are registered:
🏁 Script executed:
#!/bin/bash
# Expect matches for each symbol below.
rg -nP 'TVM_REGISTER_GLOBAL\("tl\.(Target(Is|Has|Get).*)"\)' -g 'src/**' -S
# Quick check of Python-facing wrappers reference:
rg -nP '\bTarget(Is|Has|Get)\w+\(' src/target/utils.cc src/target/utils.h

Length of output: 3351
Add missing FFI global registrations for target utilities.
No TVM_REGISTER_GLOBAL("tl.TargetIs…"), "tl.TargetHas…" or "tl.TargetGet…" entries were found, so _ffi_api won’t expose TargetIsCuda, TargetHasAsyncCopy, TargetGetWarpSize, etc. Add corresponding registrations in src/target/utils.cc, for example:
TVM_REGISTER_GLOBAL("tl.TargetIsCuda")
.set_body_typed(TargetIsCuda);
// …and similarly for TargetIsRocm, TargetHasAsyncCopy, TargetGetWarpSize, etc.

🤖 Prompt for AI Agents
In tilelang/utils/target.py around line 3, the module imports _ffi_api but the
corresponding C++ FFI globals are not registered, so functions like
TargetIsCuda, TargetIsRocm, TargetHasAsyncCopy, TargetGetWarpSize are not
exposed; open src/target/utils.cc and add TVM_REGISTER_GLOBAL entries mapping
those symbols to their typed C++ implementations (e.g.
TVM_REGISTER_GLOBAL("tl.TargetIsCuda").set_body_typed(TargetIsCuda); and
likewise for TargetIsRocm, TargetHasAsyncCopy, TargetGetWarpSize and any other
missing TargetIs*/TargetHas*/TargetGet* utilities), then rebuild to ensure
_ffi_api exposes the functions to tilelang/utils/target.py.
- Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations. - Improved block realization handling in `gemm_py.cc` for better assignment of global symbols. - Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity. - Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations. These changes improve the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/profiler/__init__.py (1)
129-145: Guard and refactor all direct torch.float8_e* references
Wrap every use of torch.float8_e* across the codebase—e.g. in tilelang/utils/tensor.py, tilelang/contrib/dlpack.py, tests under testing/python/** and examples under examples/**—with a version-agnostic check (e.g. hasattr(torch, …) or your new is_float8() helper) or a fallback via getattr(torch, dtype_name) to avoid AttributeErrors on older PyTorch.

src/transform/inject_pipeline.cc (1)
676-684: Use Integer(1) and implement async_scope lowering
- Change the literal `1` to `Integer(1)` in the AttrStmt for type safety.
- No references to `tir::attr::async_scope` were found in lowering or codegen (the attribute is effectively a no-op). Implement its consumption downstream or remove the marker.
♻️ Duplicate comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (1)
504-504: Use "float32" instead of "float" for dtype consistency. Line 504 uses "float" as the accumulation dtype, but based on the search results and the TensorCoreIntrinEmitter's dtype_abbrv mapping in the first file, this should be "float32" to avoid a KeyError during lowering.
🧹 Nitpick comments (8)
tilelang/profiler/__init__.py (1)
119-137: Minor: avoid redefining the helper inside the loop. Define `is_float8` once at module scope to reduce per-iteration overhead and make it reusable.

tilelang/tileop/gemm/gemm_base.py (2)
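The module-level, version-agnostic helper suggested in the profiler nitpick above could be sketched as follows; the `torch_module` parameter and `SimpleNamespace` stand-ins are illustrative only, and the real helper in tilelang/engine/param.py would consult torch directly:

```python
from types import SimpleNamespace

# Float8 dtype attribute names that recent PyTorch builds may expose;
# older builds simply lack them, so only the ones that exist are consulted.
_FLOAT8_DTYPE_NAMES = (
    "float8_e4m3fn",
    "float8_e5m2",
    "float8_e4m3fnuz",
    "float8_e5m2fnuz",
)


def is_float8(dtype, torch_module) -> bool:
    """True if `dtype` is one of the float8 dtypes this build provides."""
    return any(
        hasattr(torch_module, name) and getattr(torch_module, name) == dtype
        for name in _FLOAT8_DTYPE_NAMES
    )


# Stand-ins: an old torch without float8 and a newer one with two dtypes.
old_torch = SimpleNamespace(float16="f16")
new_torch = SimpleNamespace(float16="f16", float8_e5m2="f8e5m2", float8_e4m3fn="f8e4m3")
```

Because the dtype names are probed with `hasattr`, the same helper works on PyTorch builds that predate float8 support without raising AttributeError.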
13-18: Add lightweight contract checks early. Assert basic dtype consistency up front to fail fast.
class GemmBase(object):
    gemm_node: "GemmPy"

    def infer_layout(self, target: Target, thread_nums: int):
        raise NotImplementedError("infer_layout is not implemented")

    def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
        raise NotImplementedError("lower is not implemented")
+
+    def __post_init__(self):
+        # Basic sanity: inputs must agree in dtype; C holds accum dtype.
+        assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
+        # If node carries explicit dtypes, keep them consistent with buffers.
+        if getattr(self.gemm_node, "in_dtype", None):
+            assert self.gemm_node.in_dtype == self.A.dtype, "in_dtype mismatch"
+        if getattr(self.gemm_node, "accum_dtype", None):
+            assert self.gemm_node.accum_dtype == self.C.dtype, "accum_dtype mismatch"
101-116: Tighten type hints for TIR scalars/exprs.
`stride_*`/`offset_*` are typically `tir.PrimExpr`, not plain `int`.

-    def stride_A(self) -> int:
+    def stride_A(self) -> tir.PrimExpr:
         return self.gemm_node.stride_A
@@
-    def stride_B(self) -> int:
+    def stride_B(self) -> tir.PrimExpr:
         return self.gemm_node.stride_B
@@
-    def offset_A(self) -> int:
+    def offset_A(self) -> tir.PrimExpr:
         return self.gemm_node.offset_A
@@
-    def offset_B(self) -> int:
+    def offset_B(self) -> tir.PrimExpr:
         return self.gemm_node.offset_B

tilelang/intrinsics/mma_layout.py (2)
106-116: Naming consistency (“ldmatrix_” vs “mma_load_”). Elsewhere, loaders are `ldmatrix_*`. Consider renaming `mma_load_{a,b}_32x16_to_shared_16x32_layout` to `ldmatrix_32x16_to_shared_16x32_layout_{a,b}`, or add aliases to avoid confusion.
50-53: Minor comment/style polish. Nit: rephrase “if wanna trans” to “for transposed layout, apply map_indices” to keep code comments professional.
Also applies to: 60-68
tilelang/intrinsics/utils.py (1)
24-26: Broaden dtype annotation to include bf16 (supported in this PR). The implementation handles any 16-bit dtype; the type hint restricts to "float16" only.

-    dtype: Literal["float16", "int8"] = "float16",
+    dtype: Literal["float16", "bfloat16", "int8"] = "float16",

tilelang/tileop/gemm/__init__.py (1)
52-65: Consider extracting the error message to reduce redundancy. Both methods raise the same error message. While the static analysis hint about TRY003 isn't critical, you could improve maintainability by defining the error message once.

+UNSUPPORTED_TARGET_MSG = "Unsupported target: {}"
+
 def infer_layout(self, target: Target, thread_nums: int):
     if target_is_cuda(target):
         # TODO(lei): Support more cuda architectures, now mma only
         return GemmMMA(self).infer_layout(target, thread_nums)
     else:
-        raise ValueError(f"Unsupported target: {target}")
+        raise ValueError(UNSUPPORTED_TARGET_MSG.format(target))

 def lower(self, target: Target, thread_nums: int, thread_var: tir.Var):
     if target_is_cuda(target):
         # TODO(lei): Support more cuda architectures, now mma only
         # Now only implement ssr layout
         return GemmMMA(self).lower(target, thread_nums, thread_var)
     else:
-        raise ValueError(f"Unsupported target: {target}")
+        raise ValueError(UNSUPPORTED_TARGET_MSG.format(target))

tilelang/tileop/gemm/gemm_mma.py (1)
197-207: Remove redundant is_gemm_* methods from gemm_mma.py
These four methods duplicate identical implementations in GemmBase; delete lines 197–207 to eliminate code duplication.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
src/op/copy.cc(0 hunks)src/op/gemm.cc(2 hunks)src/op/gemm_py.cc(1 hunks)src/target/codegen_cuda.cc(3 hunks)src/transform/inject_pipeline.cc(2 hunks)testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py(7 hunks)tilelang/intrinsics/mma_layout.py(2 hunks)tilelang/intrinsics/mma_macro_generator.py(18 hunks)tilelang/intrinsics/utils.py(2 hunks)tilelang/profiler/__init__.py(1 hunks)tilelang/tileop/gemm/__init__.py(1 hunks)tilelang/tileop/gemm/gemm_base.py(1 hunks)tilelang/tileop/gemm/gemm_mma.py(1 hunks)tilelang/utils/tensor.py(1 hunks)
💤 Files with no reviewable changes (1)
- src/op/copy.cc
🚧 Files skipped from review as they are similar to previous changes (2)
- src/target/codegen_cuda.cc
- src/op/gemm_py.cc
🧰 Additional context used
🧬 Code graph analysis (8)
tilelang/profiler/__init__.py (2)
tilelang/engine/param.py (1)
is_float8(81-91)tilelang/utils/tensor.py (1)
torch_assert_close(220-312)
tilelang/intrinsics/utils.py (1)
tilelang/intrinsics/mma_layout.py (1)
ldmatrix_32x16_to_shared_16x32_layout_b(36-39)
tilelang/tileop/gemm/gemm_mma.py (5)
tilelang/tileop/gemm/gemm_base.py (21)
GemmBase(10-131)infer_layout(13-14)policy(130-131)M(33-34)N(37-38)in_dtype(53-54)in_dtype(65-67)accum_dtype(57-58)accum_dtype(70-71)trans_A(45-46)trans_B(49-50)chunk(61-62)chunk(74-75)is_gemm_ss(20-21)A(78-79)B(82-83)C(86-87)is_gemm_sr(23-24)is_gemm_rs(26-27)is_gemm_rr(29-30)lower(16-17)tilelang/intrinsics/mma_macro_generator.py (11)
TensorCoreIntrinEmitter(24-630)make_mma_store_layout(561-630)make_mma_load_layout(421-559)ldmatrix_a(190-241)ldmatrix_a(718-822)ldmatrix_b(243-302)ldmatrix_b(824-936)mma(304-364)mma(938-987)mma(992-1090)mma(1095-1194)tilelang/utils/language.py (2)
is_shared(25-39)is_fragment(68-78)tilelang/transform/simplify.py (1)
_Simplify(30-49)tilelang/tileop/gemm/__init__.py (2)
infer_layout(52-57)lower(59-65)
tilelang/tileop/gemm/gemm_base.py (3)
tilelang/utils/language.py (2)
is_shared(25-39)is_fragment(68-78)tilelang/tileop/gemm/__init__.py (2)
infer_layout(52-57)lower(59-65)tilelang/tileop/gemm/gemm_mma.py (6)
infer_layout(15-57)lower(59-194)is_gemm_ss(197-198)is_gemm_sr(200-201)is_gemm_rs(203-204)is_gemm_rr(206-207)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
tilelang/language/allocate.py (1)
alloc_shared(21-36)tilelang/language/gemm.py (1)
gemm_v2(186-359)tilelang/jit/kernel.py (3)
get_kernel_source(378-389)get_profiler(360-376)out_idx(446-447)tilelang/language/__init__.py (1)
annotate_layout(102-140)tilelang/layout/fragment.py (1)
make_swizzled_layout(210-216)tilelang/profiler/__init__.py (1)
assert_allclose(76-145)
tilelang/tileop/gemm/__init__.py (5)
src/op/gemm_py.h (2)
tvm(13-124)GemmPy(116-121)tilelang/utils/target.py (1)
target_is_cuda(87-88)tilelang/ir.py (1)
GemmWarpPolicy(30-39)tilelang/tileop/gemm/gemm_mma.py (3)
GemmMMA(13-207)infer_layout(15-57)lower(59-194)src/op/gemm_py.cc (1)
GemmPy(50-80)
src/op/gemm.cc (1)
src/op/gemm.h (2)
RegisterReflection(34-40)RegisterReflection(117-139)
tilelang/intrinsics/mma_macro_generator.py (4)
tilelang/intrinsics/utils.py (2)
mma_store_index_map(68-69)get_ldmatrix_offset(19-50)tilelang/utils/language.py (1)
is_fragment(68-78)tilelang/intrinsics/mma_layout.py (3)
mma_load_a_32x16_to_shared_16x32_layout(106-109)mma_load_b_32x16_to_shared_16x32_layout(112-115)transform_func(176-179)tilelang/language/kernel.py (5)
threads(195-199)get_thread_binding(151-156)get_thread_binding(261-265)KernelLaunchFrame(72-206)Current(115-121)
🪛 Ruff (0.12.2)
tilelang/tileop/gemm/gemm_mma.py
57-57: Avoid specifying long messages outside the exception class
(TRY003)
194-194: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm/gemm_base.py
11-11: Undefined name GemmPy
(F821)
65-65: Redefinition of unused in_dtype from line 53
(F811)
70-70: Redefinition of unused accum_dtype from line 57
(F811)
74-74: Redefinition of unused chunk from line 61
(F811)
tilelang/tileop/gemm/__init__.py
57-57: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_macro_generator.py
468-468: Avoid specifying long messages outside the exception class
(TRY003)
485-485: Avoid specifying long messages outside the exception class
(TRY003)
543-543: Avoid specifying long messages outside the exception class
(TRY003)
557-557: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: CI
tilelang/tileop/gemm/gemm_mma.py
[error] 150-150: codespell: Misspelled word 'parrent' found at line 150; should be 'parent'.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (32)
src/transform/inject_pipeline.cc (1)
954-959: Good: recomputing block read/write regions to match the rewritten body.This keeps region metadata in sync post-rewrite and reduces chances of stale reads/writes leaking into later passes.
tilelang/intrinsics/mma_layout.py (1)
76-88: A/B 32x16 mappings: request quick sanity checks. Formulas look consistent with 32x16 microtiles; please confirm symmetry of A/B and transposed wrappers with a small table-driven check (thread_id∈[0,31], local_id∈[0,15]) to ensure ranges and uniqueness per warp.
You can validate purely with Python (no GPU) by checking for duplicate (row,col) pairs and expected ranges.
Also applies to: 90-92
tilelang/intrinsics/utils.py (2)
3-9: Import rename LGTM; keep both A/B helpers aligned. Switching to `ldmatrix_32x16_to_shared_16x32_layout_b` matches the new A/B split.
39-43: Verified no callers of old B helper
Only its definition remains in tilelang/intrinsics/mma_layout.py; no compatibility alias needed.

src/op/gemm.cc (2)
616-631: LGTM! Clean FFI registration implementation. The FFI registration block is well-structured. The static initialization ensures proper registration of the GemmWarpPolicy operations at module load time, and the lambda wrapper cleanly exposes the ComputeWarpPartition functionality.
254-260: No action needed: existing guards prevent divide-by-zero and enforce divisibility
The `if (m_per_warp < 1 || n_per_warp < 1)` check precludes any zero or negative divisor, and the initial `ICHECK` constraints combined with the post-search splitting logic guarantee that both M % (m × kMPerWarp) and N % (n × kNPerWarp) are zero.

Likely an incorrect or invalid review comment.
tilelang/tileop/gemm/__init__.py (1)
13-16: Return value is unused but consistent with the design. The `gemm_py.infer_layout()` call returns a dict mapping buffers to layouts, but the wrapper function doesn't explicitly return this value. This appears intentional, as the TVM FFI may handle the return implicitly.

tilelang/intrinsics/mma_macro_generator.py (16)
5-5: LGTM: Var import correctly added for thread_var typing. The import of `Var` from TVM is necessary to support the new optional `thread_var` parameter typing.
12-19: LGTM: MMA layout imports updated correctly. The layout imports have been updated to support the new matrix-specific layout functions (the `_sr_a` and `_sr_b` variants) and the mma_load functions, which are needed for the enhanced A/B matrix handling.
30-32: LGTM: Dynamic n_dim support added. The change from a constant `N_DIM = 16` to a dynamic `n_dim = 16` with improved documentation is correct. This allows for flexible n_dim sizes (16 or 8) based on the warp configuration.
62-62: New parameter added: thread_var for custom thread binding. The optional `thread_var` parameter allows overriding the default thread binding retrieval mechanism.
77-79: LGTM: Micro size initialization updated for dynamic n_dim. The calls to `_initialize_micro_size` and `_initialize_local_size` now properly pass the dynamic `n_dim` value.
85-85: LGTM: Thread variable stored correctly. The `thread_var` parameter is properly stored as an instance variable.
115-134: LGTM: Enhanced micro size initialization with dynamic n_dim. The refactored `_initialize_micro_size` method now properly handles dynamic n_dim calculation based on warp_col_tiles, supporting both 16 and 8 dimensions. The validation logic ensures proper tile sizes and divisibility.
142-149: LGTM: Thread binding abstraction added. The new `get_thread_binding` method provides a clean abstraction that allows using either a custom `thread_var` or the default kernel frame thread binding.
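The fallback order described above can be modeled with a toy class (not the real emitter; the names are illustrative):

```python
class EmitterSketch:
    """Models the emitter's thread-binding resolution order."""

    def __init__(self, thread_var=None, default_binding=None):
        self.thread_var = thread_var             # optional user-supplied var
        self._default_binding = default_binding  # e.g. the kernel frame's tx

    def get_thread_binding(self):
        # An explicitly supplied thread_var wins; otherwise fall back to the
        # binding provided by the enclosing kernel launch frame.
        if self.thread_var is not None:
            return self.thread_var
        return self._default_binding
```

This keeps every call site indifferent to where the binding came from, which is what makes the later `self.get_thread_binding()` centralization possible.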
203-204: LGTM: Improved ldmatrix availability check.The ldmatrix availability logic now correctly excludes int8 + transposed A matrix cases, which are not supported by the ldmatrix instruction.
206-206: LGTM: Thread binding usage centralized. The use of `self.get_thread_binding()` provides consistent thread binding access throughout the method.
218-240: LGTM: Improved ldmatrix_a with proper transposition handling. The enhanced implementation now properly handles transposed matrices with the correct indexing (`A_shared_buf[wk, wi]` vs `A_shared_buf[wi, wk]`) and passes the transposition flag correctly to the ldmatrix instruction.
256-300: LGTM: Enhanced ldmatrix_b with proper transposition and replicate handling. The implementation now correctly:
- Uses centralized thread binding
- Calculates `replicate_b` based on `n_dim == 16`
- Handles transposed B matrix indexing properly
- Sets the correct `trans` flag for the ldmatrix instruction
replicate_bis true. The saturate parameter is correctly documented.Also applies to: 345-362
376-380: LGTM: Dynamic n_dim usage in stmatrix.The stmatrix method now correctly uses the dynamic
n_dimvalue instead of a hardcoded constant, and maintains consistent thread binding usage.
447-558: Enhanced make_mma_load_layout with matrix-specific handling.The method has been significantly refactored to properly handle A/B matrices with different axis orders and transposition states. The new implementation:
- Uses matrix-specific flags (
matrix_is_a,matrix_is_b) for clearer logic- Employs dedicated transform functions for A and B matrices
- Handles sr/rs axis order conditions properly
- Constructs layouts with appropriate warp/block/replicate dimensions
The logic appears sound for handling the various GEMM configurations (ss/sr/rs/rr).
729-729: LGTM: Thread binding consistency in ladder transform classes.The ladder transform emitter classes now also use the centralized
get_thread_binding()method for consistency.Also applies to: 836-836
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (9)
34-35: LGTM: Shared memory allocation with explicit scope.The shared memory allocations now use explicit
scope="shared"parameter, which is consistent with the new memory management approach.
47-49: LGTM: Switch to gemm_v2 API.The transition from the old
T.gemmtoT.gemm_v2aligns with the PR objective of introducing the new GEMM operator stack.
54-68: LGTM: Function renamed and kernel source printing added.The renaming from
run_gemmtorun_gemm_ssprovides better clarity for the shared-shared configuration. Addingprint(kernel.get_kernel_source())aids in debugging and verification.Also applies to: 93-94
110-128: LGTM: Expanded test coverage with corrected num_stages.The test suite now includes:
- Fixed
num_stages=2instead of0(avoiding potential pipeline issues)- Enhanced data type coverage (int8, float8_e5m2)
- n8 dimension testing
This significantly improves test coverage for various GEMM configurations.
160-167: LGTM: RS configuration with swizzled layout and fragment handling. The `matmul_rs` function properly implements the reduction-spatial configuration with:
- Swizzled layout annotation for A_shared
- Fragment allocation and copy for A matrix
- Correct gemm_v2 usage
Also applies to: 177-178
242-255: LGTM: Comprehensive RS test coverage. The RS tests now include the full spectrum of configurations (transpose combinations, data types, and dimension variants) consistent with the SS tests.
258-384: LGTM: New SR configuration implementation and tests. The `matmul_sr` function and corresponding tests properly implement the spatial-reduction configuration with B matrix fragmentation and swizzled layout. The test coverage mirrors the comprehensive approach used in other configurations.
386-516: LGTM: Complete RR configuration with dual fragmentation. The `matmul_rr` function correctly implements the reduction-reduction configuration with both A and B matrix fragmentation and swizzled layouts. The test suite provides comprehensive coverage of all transpose combinations and data types.
520-521: LGTM: Test runner updated for direct execution. The main execution path now directly calls a specific GEMM test case, which is useful for focused testing and debugging.
shared_16x16_to_mma_32x8_layout_sr_a = shared_16x16_to_mma_a_32x8_layout
shared_16x16_to_mma_32x8_layout_sr_b = shared_16x16_to_mma_b_32x8_layout
shared_16x16_to_mma_32x8_layout_rs_a = shared_16x16_to_mma_a_32x8_layout_trans
shared_16x16_to_mma_32x8_layout_rs_b = shared_16x16_to_mma_b_32x8_layout_trans
💡 Verification agent
🧩 Analysis chain
Back-compat: ensure no remaining uses of removed SR/RS aliases.
You introduced per-matrix aliases (_sr_a/_sr_b/_rs_a/_rs_b). If older sites still import shared_16x16_to_mma_32x8_layout_sr/..._rs, they will break.
Run to find potential breakages:
If matches exist, consider re-adding deprecated shims that forward to _sr_a/_rs_b with clear naming, or update call sites in this PR.
🏁 Script executed:
#!/bin/bash
# Old aliases that may be referenced elsewhere
rg -nP '\bshared_16x16_to_mma_32x8_layout_(sr|rs)\b|shared_16x16_to_mma_32x8_layout_trans\b|shared_16x16_to_mma_32x8_layout\b' -g '!**/site-packages/**' -C2

Length of output: 148
🏁 Script executed:
#!/bin/bash
# Search across all Python files, including those ignored by default, for any uses of the old aliases or base name
rg --no-ignore -nP '\bshared_16x16_to_mma_32x8_layout_(?:sr|rs)\b|\bshared_16x16_to_mma_32x8_layout_trans\b|\bshared_16x16_to_mma_32x8_layout\b' -g '*.py' -C2

Length of output: 1383
Restore deprecated SR/RS aliases or update example usage
• examples/plot_layout/fragment_mma_load_a.py (lines 35–36, 50–51) still imports and invokes the removed
shared_16x16_to_mma_32x8_layout_sr/rs aliases—update these sites to the new per-matrix functions (e.g. _sr_a/_sr_b/_rs_a/_rs_b), or re-add fallback shims in tilelang/intrinsics/mma_layout.py mapping the old names to the new ones.
🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_layout.py around lines 70 to 73, example code still
imports the removed shared_16x16_to_mma_32x8_layout_sr/rs aliases; restore
backward compatibility by re-adding alias shims that map the old names to the
new per-matrix functions (e.g. set shared_16x16_to_mma_32x8_layout_sr_a ->
shared_16x16_to_mma_a_32x8_layout_sr_a, shared_16x16_to_mma_32x8_layout_sr_b ->
shared_16x16_to_mma_b_32x8_layout_sr_b, shared_16x16_to_mma_32x8_layout_rs_a ->
shared_16x16_to_mma_a_32x8_layout_rs_a, shared_16x16_to_mma_32x8_layout_rs_b ->
shared_16x16_to_mma_b_32x8_layout_rs_b) so existing examples keep working;
alternatively update the example files to import and call the new per-matrix
names instead of the old aliases.
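If the shim route is chosen, a small deprecation wrapper keeps the old names importable while steering users to the new per-matrix functions; the placeholder layout body below is illustrative only:

```python
import warnings


def _deprecated_alias(new_fn, old_name):
    """Forward `old_name` to `new_fn`, emitting a DeprecationWarning per call."""

    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_fn.__name__} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)

    wrapper.__name__ = old_name
    return wrapper


def shared_16x16_to_mma_a_32x8_layout(i, j):
    # Placeholder body; the real function maps shared-memory indices to
    # (thread_id, local_id) for the MMA fragment layout.
    return i, j


# Old public name forwards to the new per-matrix function.
shared_16x16_to_mma_32x8_layout_sr = _deprecated_alias(
    shared_16x16_to_mma_a_32x8_layout, "shared_16x16_to_mma_32x8_layout_sr")
```

Existing imports in examples/plot_layout keep working, but each call surfaces a warning pointing at the new name.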
tilelang/tileop/gemm/gemm_base.py
Outdated
    @property
    def in_dtype(self) -> str:
        return self.gemm_node.in_dtype

    @property
    def accum_dtype(self) -> str:
        return self.gemm_node.accum_dtype

    @property
    def chunk(self) -> int:
        return self.gemm_node.chunk
Resolve duplicate properties (F811) and pick a single source of truth.
in_dtype, accum_dtype, and chunk are defined twice; the latter overrides the former and confuses readers and linters. Recommend keeping the node-backed versions and removing the buffer-derived duplicates (you already assert consistency in __post_init__).
@property
def in_dtype(self) -> str:
- return self.gemm_node.in_dtype
+ return self.gemm_node.in_dtype
@@
@property
def accum_dtype(self) -> str:
- return self.gemm_node.accum_dtype
+ return self.gemm_node.accum_dtype
@@
@property
def chunk(self) -> int:
- return self.gemm_node.chunk
+ return self.gemm_node.chunk
@@
- @property
- def in_dtype(self) -> str:
- assert self.A.dtype == self.B.dtype, "A and B must have the same dtype"
- return self.A.dtype
-
- @property
- def accum_dtype(self) -> str:
- return self.C.dtype
-
- @property
- def chunk(self) -> int:
- return self.A.shape[-2] if self.trans_A else self.A.shape[-1]
+ # If needed later, expose K extent explicitly to avoid overloading "chunk":
+ # @property
+ # def k_extent(self) -> int:
+    #     return self.A.shape[-2] if self.trans_A else self.A.shape[-1]

Also applies to: 64-76
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_base.py around lines 52-63 (and similarly for
64-76), duplicate property definitions for in_dtype, accum_dtype, and chunk
exist and shadow the previous definitions; remove the buffer-derived duplicate
property methods and keep the node-backed properties (self.gemm_node.*) as the
single source of truth, relying on the existing __post_init__ consistency checks
to ensure they match, and delete the redundant getters so linters and readers
see only one definition per property.
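The F811 issue called out above is easy to reproduce in isolation. The sketch below uses a hypothetical class (not the real GemmBase) to show that a class body is just a namespace dict, so a second `@property` definition silently replaces the first:

```python
# Hypothetical minimal class (not the real GemmBase) showing why duplicate
# @property definitions trigger F811: the second definition overwrites the
# first in the class namespace, leaving the earlier one unreachable.


class Duplicated:

    @property
    def chunk(self) -> int:
        # "node-backed" version -- dead code once redefined below
        return 42

    @property
    def chunk(self) -> int:  # noqa: F811
        # "buffer-derived" version -- the only one that survives
        return 7


print(Duplicated().chunk)  # → 7
```

This is why keeping a single node-backed definition (and asserting consistency elsewhere) is the safer shape.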
def _gemm_rsr() -> None:
    """
    The inner macro that loads data from shared buffers A_shared and
    B_shared into local fragments, then issues Tensor Core mma ops,
    accumulating into C_local.
    """
    for ki in T.serial(0, (block_K // micro_size_k)):
        # Perform Matrix Multiplication
        mma_emitter.mma(A_local, B_local, C_local, ki)

# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
Fix incorrect function name for RR variant
The function is named _gemm_rsr but should be _gemm_rr since this is the RR (fragment-fragment) variant.
@T.prim_func
-def _gemm_rsr() -> None:
+def _gemm_rr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
-return _Simplify(_gemm_rsr, inline_let=True)
+return _Simplify(_gemm_rr, inline_let=True)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
@T.prim_func
def _gemm_rr() -> None:
    """
    The inner macro that loads data from shared buffers A_shared and
    B_shared into local fragments, then issues Tensor Core mma ops,
    accumulating into C_local.
    """
    for ki in T.serial(0, (block_K // micro_size_k)):
        # Perform Matrix Multiplication
        mma_emitter.mma(A_local, B_local, C_local, ki)

# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rr, inline_let=True)
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma.py around lines 179 to 192, the inner macro is
incorrectly named `_gemm_rsr` but this is the RR (fragment-fragment) variant;
rename the function to `_gemm_rr` and update the return to call
`_Simplify(_gemm_rr, inline_let=True)` (and adjust any internal references or
docstrings mentioning `_gemm_rsr`) so the name matches the RR variant.
| """ | ||
| The inner macro that loads data from shared buffers A_shared and | ||
| B_shared into local fragments, then issues Tensor Core mma ops, | ||
| accumulating into C_local. | ||
| """ |
🛠️ Refactor suggestion
Misleading docstring for RR variant
The docstring mentions loading from "shared buffers" but in the RR variant, both A and B are already in local fragments and no loading from shared memory occurs.
"""
-The inner macro that loads data from shared buffers A_shared and
-B_shared into local fragments, then issues Tensor Core mma ops,
-accumulating into C_local.
+The inner macro that performs Tensor Core mma ops using
+local fragments A_local and B_local, accumulating into C_local.
"""📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| """ | |
| The inner macro that loads data from shared buffers A_shared and | |
| B_shared into local fragments, then issues Tensor Core mma ops, | |
| accumulating into C_local. | |
| """ | |
| """ | |
| The inner macro that performs Tensor Core mma ops using | |
| local fragments A_local and B_local, accumulating into C_local. | |
| """ |
🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_mma.py around lines 180 to 184, the docstring for
the RR variant incorrectly states that data is loaded from "shared buffers
A_shared and B_shared"; update the docstring to reflect that in the RR variant
both A and B are already in local fragments (no shared-memory loads), and
describe that the macro issues Tensor Core MMA ops using those local fragments
and accumulates into C_local; keep the description concise and specific to the
RR behavior.
return torch.empty(
    *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
Block non-floating dtypes for Uniform/Normal to avoid silent truncation (ints) and all-True tensors (bool).
Casting float32 samples to integer/boolean dtypes silently changes semantics (e.g., Uniform [-1, 1) → mostly zeros for int32; almost all True for bool). This is likely a regression vs. previously erroring fast paths. Gate these branches to floating/FP8 only (int8 remains handled earlier).
Apply:
@@
- elif supply_type == TensorSupplyType.Uniform:
- return torch.empty(
- *shape, device=device, dtype=torch.float32).uniform_(-1.0, 1.0).to(dtype)
+ elif supply_type == TensorSupplyType.Uniform:
+ # Restrict to floating/FP8 to avoid silent truncation for ints/bools.
+ if dtype not in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and not param.is_float8():
+ raise TypeError(f"Uniform supply requires floating or float8 dtype, got {dtype}. "
+ "Use TensorSupplyType.Integer instead.")
+ return torch.empty(
+ *shape, device=device, dtype=torch.float32
+ ).uniform_(-1.0, 1.0).to(dtype)
@@
- elif supply_type == TensorSupplyType.Normal:
- return torch.empty(
- *shape, device=device, dtype=torch.float32).normal_(-1.0, 1.0).to(dtype)
+ elif supply_type == TensorSupplyType.Normal:
+ if dtype not in {torch.float16, torch.float32, torch.float64, torch.bfloat16} and not param.is_float8():
+ raise TypeError(f"Normal supply requires floating or float8 dtype, got {dtype}. "
+ "Use TensorSupplyType.Integer instead.")
+ return torch.empty(
+ *shape, device=device, dtype=torch.float32
+ ).normal_(-1.0, 1.0).to(dtype)

Optional: extract a small helper to de-duplicate the float32→cast pattern for Uniform/Normal.
Also applies to: 119-120
🤖 Prompt for AI Agents
In tilelang/utils/tensor.py around lines 116-120, the Uniform/Normal branches
currently sample in float32 then cast to the requested dtype, which silently
corrupts non-floating dtypes (ints/bool); change these branches to first
validate that the requested dtype is a floating type (or supported FP8) and
raise a clear error if not; keep int8 handling where already done; optionally
factor out a small helper that samples in float32 and casts to dtype to avoid
duplication but only call it after the dtype check so integer/bool dtypes are
rejected before sampling.
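The silent-truncation point is easy to demonstrate without torch, since plain Python casts behave the same way as tensor dtype casts here. A sketch using `random.uniform` as a stand-in for the tensor sampling:

```python
# Sketch of the review's point: casting Uniform(-1, 1) float samples to
# integer or boolean types destroys the distribution. Uses stdlib random
# as a stand-in for torch sampling.
import random

random.seed(0)
samples = [random.uniform(-1.0, 1.0) for _ in range(1000)]

# int() truncates toward zero, so values in (-1, 1) all collapse to 0.
as_int = [int(x) for x in samples]
print(sorted(set(as_int)))  # every sample truncates to the same value, 0

# bool() maps every nonzero float to True, so the "random" mask is useless.
as_bool = [bool(x) for x in samples]
print(sum(as_bool) / len(as_bool))  # fraction of True values -- essentially 1.0
```

Raising a `TypeError` before sampling, as the suggested diff does, surfaces this misuse instead of returning degenerate data.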
Actionable comments posted: 3
♻️ Duplicate comments (5)
tilelang/tileop/gemm/gemm_mma.py (2)
183-197: Rename RR macro for clarity and consistency.

This is the RR path but the function is named _gemm_rsr. Rename to _gemm_rr and update the return.

- def _gemm_rsr() -> None:
+ def _gemm_rr() -> None:
@@
- return _Simplify(_gemm_rsr, inline_let=True)
+ return _Simplify(_gemm_rr, inline_let=True)
126-133: Fix misleading docstrings per variant (SR/RS/RR).

Docstrings mention shared loads where fragments are already provided. Tighten wording to match behavior.

@@
-    """
-    The inner macro that loads data from shared buffers A_shared and
-    B_shared into local fragments, then issues Tensor Core mma ops,
-    accumulating into C_local.
-    """
+    """
+    Load A from shared into fragment; B is already a fragment.
+    Issue Tensor Core MMA and accumulate into C_local.
+    """
@@
-    """
-    The inner macro that loads data from shared buffers A_shared and
-    B_shared into local fragments, then issues Tensor Core mma ops,
-    accumulating into C_local.
-    """
+    """
+    Load B from shared into fragment; A is already a fragment.
+    Issue Tensor Core MMA and accumulate into C_local.
+    """
@@
-    """
-    The inner macro that loads data from shared buffers A_shared and
-    B_shared into local fragments, then issues Tensor Core mma ops,
-    accumulating into C_local.
-    """
+    """
+    Use local fragments A_local and B_local.
+    Issue Tensor Core MMA and accumulate into C_local.
+    """

Also applies to: 156-161, 184-189
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)
117-117: num_stages=0 will break T.Pipelined - use at least 1

Passing 0 to T.Pipelined(..., num_stages=...) can lead to invalid scheduling or no prefetch. Software pipelining requires at least 1 stage.

Apply this diff to fix the issue:

- run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+ run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
- run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+ run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
- run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+ run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)

Also applies to: 251-251, 385-385
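Why at least one stage is needed can be shown with a toy model (hypothetical, not the real T.Pipelined): software pipelining rotates `num_stages` shared-memory buffers, and with zero buffers there is nowhere to stage the next tile.

```python
# Toy model of a staged pipeline (illustrative only; T.Pipelined's real
# semantics live in the TileLang lowering passes).
def pipelined_sum(tiles, num_stages):
    """Rotate `num_stages` buffers, 'prefetching' each tile into a stage
    slot before it is consumed."""
    if num_stages < 1:
        raise ValueError("software pipelining needs at least 1 stage buffer")
    buffers = [None] * num_stages
    total = 0
    for i, tile in enumerate(tiles):
        slot = i % num_stages        # with 0 stages this is a division by zero
        buffers[slot] = tile         # stand-in for the async shared-mem copy
        total += buffers[slot]       # compute consumes the staged tile
    return total


print(pipelined_sum([1, 2, 3, 4], num_stages=2))  # → 10
```

With `num_stages=2` the model double-buffers; with `num_stages=0` even this toy cannot proceed, mirroring why the test values must be at least 1.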
521-521: Use "float32" instead of "float" for accum dtype

TensorCoreIntrinEmitter.dtype_abbrv expects "float32", not "float". Using "float" will cause a KeyError during lowering.

Apply this diff:

- run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+ run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float32", 128, 256, 32, 2)

tilelang/intrinsics/mma_layout.py (1)
94-97: Address remaining usage of removed SR/RS aliases in examples.

Based on the past review comment, example files still reference the old shared_16x16_to_mma_32x8_layout_sr / shared_16x16_to_mma_32x8_layout_rs aliases that were removed. The current aliases use _sr_a/_sr_b/_rs_a/_rs_b suffixes.

Consider adding backward compatibility aliases or updating the example files:

# Add these aliases for backward compatibility
shared_16x16_to_mma_32x8_layout_sr = shared_16x16_to_mma_32x8_layout_sr_a  # or _sr_b
shared_16x16_to_mma_32x8_layout_rs = shared_16x16_to_mma_32x8_layout_rs_a  # or _rs_b

Or verify that all example files have been updated to use the new per-matrix naming scheme.
🧹 Nitpick comments (13)
.clang-tidy (1)
4-4: Avoid enabling and disabling the same checks; it’s confusing

These checks appear in both keep and disable lists, and end up disabled by the latter. Consider removing them from the retained block for clarity:
- cppcoreguidelines-pro-type-static-cast-downcast
- cppcoreguidelines-pro-type-member-init
- cppcoreguidelines-narrowing-conversions
Suggested cleanup:
- cppcoreguidelines-pro-type-static-cast-downcast,
- cppcoreguidelines-pro-type-member-init,
- cppcoreguidelines-narrowing-conversions,

Also applies to: 43-43, 5-5, 41-41, 9-9, 39-39
src/target/codegen_cuda.cc (2)
22-22: Avoid broad using-directive; prefer a narrow alias.

To reduce namespace pollution inside tvm::codegen, consider an alias instead of using namespace.
Example:
-using namespace tvm::tl::codegen;
+namespace tl_codegen = tvm::tl::codegen;
1335-1341: Verify tl::ptx_ldmatrix_xN argument order and unify both call sites.Here you emit tl::ptx_ldmatrix_xN(_trans)(smem + off, local + off), while the earlier tl::ptx_ldmatrix branch prints args starting at index 2 via print_extern_call_stmt (potentially a different order). Please confirm both wrappers take the same parameter order and adjust if necessary to avoid silent layout bugs. Consider centralizing this emission in one helper to keep the ordering consistent.
tilelang/tileop/gemm/gemm_base.py (1)
1-8: Avoid runtime import cycles; gate GemmWarpPolicy under TYPE_CHECKING.Drop the eager import and keep the annotation as a forward-ref to prevent potential import loops and silence type-checker complaints.
@@ -from tilelang.ir import GemmWarpPolicy +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tilelang.ir import GemmWarpPolicy @@ - def policy(self) -> GemmWarpPolicy: + def policy(self) -> "GemmWarpPolicy": return self.gemm_node.policyAlso applies to: 117-119
tilelang/tileop/gemm/gemm_mma.py (2)
202-212: De-duplicate is_gemm_ helpers; inherit from GemmBase.*These override identical logic from GemmBase without changes. Remove to keep a single source of truth.
- def is_gemm_ss(self) -> bool: - return is_shared(self.A) and is_shared(self.B) - - def is_gemm_sr(self) -> bool: - return is_shared(self.A) and is_fragment(self.B) - - def is_gemm_rs(self) -> bool: - return is_fragment(self.A) and is_shared(self.B) - - def is_gemm_rr(self) -> bool: - return is_fragment(self.A) and is_fragment(self.B) + # Inherit is_gemm_* from GemmBase
84-89: Guard against non-divisible K blocking.

If chunk (K tile) isn’t divisible by micro_size_k, the loop truncates. Add a precondition.

 block_K = mma_emitter.chunk
 micro_size_k = mma_emitter.micro_size_k
+assert block_K % micro_size_k == 0, "block_K must be divisible by micro_size_k"

If this can be violated by some policy, consider raising ValueError with guidance.
tilelang/intrinsics/mma_macro_generator.py (3)
151-158: Annotate thread-binding getter for clarity.

Return type is a Var; helps readers and tools.

-    def get_thread_binding(self):
+    def get_thread_binding(self) -> Var:
         if self.thread_var is None:
             current_frame = T.KernelLaunchFrame.Current()
             assert current_frame is not None, "Must be called in a T.Kernel Frame"
             return current_frame.get_thread_binding()
         else:
             return self.thread_var
456-475: Docstring: this is a load-layout, not “store”.

Small wording fix to avoid confusion with make_mma_store_layout.

-    Create a layout function for storing MMA results into a fragment buffer.
-    This layout is used in conjunction with `inverse_mma_store_layout` to
-    map fragment indices to threads and local indices.
+    Create a layout function for loading data into an MMA fragment buffer.
+    This layout maps fragment indices to threads and local indices for MMA loads.
124-143: Tighten assert messages or use ValueError for consistency.

Messages say “greater than” but the check is “>=”. Either update text or switch to explicit ValueError for ruff TRY003 friendliness.

-    assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}"
+    assert warp_row_tiles >= 16, f"warp_row_tiles must be >= 16, got {warp_row_tiles}"
     assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
-    assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
+    assert warp_col_tiles >= 8, f"warp_col_tiles must be >= 8, got {warp_col_tiles}"
     assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"

src/target/ptx.h (2)
44-45: Inconsistent parameter naming conventions in function documentation

The documentation uses inconsistent naming - sometimes with underscores (a_offset, a_ptr) and sometimes without (smem_ptr). Consider standardizing the parameter names.

Apply this naming convention consistently across all function signatures and documentation. For example:

- Use a_elem_offset instead of mixing a_offset in docs and a_elem_offset in signature
- Keep parameter names consistent between the documentation comments and actual function signatures
Also applies to: 46-47, 60-61, 74-76, 88-91, 103-105
150-154: Fix comment formatting for mbarrier documentation

Multi-line parameter documentation is improperly formatted with \param split across lines.

Apply this diff to fix the documentation:

-/*!
- * \brief Print ptx barrier arrival with expect tx operation using
- * mbarrier.arrive.expect_tx \param barrier: The name of the barrier in shared
- * memory. \param byte_count: Increases the tx count of the mbarrier object to
- * track completion of addtional async transactions.
- */
+/*!
+ * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx
+ * \param barrier The name of the barrier in shared memory.
+ * \param byte_count Increases the tx count of the mbarrier object to
+ *        track completion of additional async transactions.
+ */
83-133: Consider using a map for more efficient dtype lookupsThe current implementation uses multiple if-else statements for string to DataType conversion. For better performance and maintainability, consider using an unordered_map.
83-133: Consider using a map for more efficient dtype lookups

The current implementation uses multiple if-else statements for string to DataType conversion. For better performance and maintainability, consider using an unordered_map.

Consider refactoring to use a static map:

inline DataType DTypeFromString(const std::string& str) {
  static const std::unordered_map<std::string, DataType> dtype_map = {
      {"int4", DataType::kInt4},   {".s4", DataType::kInt4},
      {"uint4", DataType::kUInt4}, {".u4", DataType::kUInt4},
      {"int8", DataType::kInt8},   {".s8", DataType::kInt8},
      // ... rest of the mappings
  };
  auto it = dtype_map.find(str);
  if (it != dtype_map.end()) {
    return it->second;
  }
  LOG(FATAL) << "Unrecognized PTX data type " << str;
}
452-464: String replacement implementation is correct but could be more efficient

The current implementation correctly handles overlapping patterns by updating the search position. However, for better performance with many rules, consider pre-compiling patterns or using a more sophisticated string replacement algorithm.
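The position-advancing strategy this comment refers to can be sketched in a few lines. The class below is a hypothetical Python analogue (the actual Replacer lives in src/target/ptx.cc); the key detail is moving the search cursor past each inserted replacement so a rule never re-matches text it just produced:

```python
class Replacer:
    """Applies (pattern, replacement) rules in registration order.
    Advancing `pos` past each inserted replacement prevents a rule from
    matching inside its own output, so replacement cannot loop forever."""

    def __init__(self):
        self.rules = []

    def register_rule(self, pattern: str, replacement: str) -> None:
        self.rules.append((pattern, replacement))

    def rewrite(self, text: str) -> str:
        for pattern, replacement in self.rules:
            pos = 0
            while (pos := text.find(pattern, pos)) != -1:
                text = text[:pos] + replacement + text[pos + len(pattern):]
                pos += len(replacement)  # skip over the inserted text
        return text


r = Replacer()
r.register_rule("a", "aa")  # replacement contains the pattern itself
print(r.rewrite("ab"))      # → aab  (terminates because the cursor advances)
```

Without the `pos += len(replacement)` step, the `"a" -> "aa"` rule would re-match its own output indefinitely.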
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- .clang-tidy (1 hunks)
- CMakeLists.txt (1 hunks)
- src/target/codegen_cuda.cc (4 hunks)
- src/target/ptx.cc (1 hunks)
- src/target/ptx.h (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (7 hunks)
- tilelang/intrinsics/mma_layout.py (3 hunks)
- tilelang/intrinsics/mma_macro_generator.py (18 hunks)
- tilelang/intrinsics/utils.py (3 hunks)
- tilelang/tileop/gemm/gemm_base.py (1 hunks)
- tilelang/tileop/gemm/gemm_mma.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
tilelang/tileop/gemm/gemm_mma.py (5)
  tilelang/tileop/gemm/gemm_base.py (18)
    GemmBase (11-119), infer_layout (14-15), policy (118-119), M (33-34), N (37-38), in_dtype (53-55), accum_dtype (58-59), trans_A (45-46), trans_B (49-50), chunk (62-63), is_gemm_ss (20-21), A (66-67), B (70-71), C (74-75), is_gemm_sr (23-24), is_gemm_rs (26-27), is_gemm_rr (29-30)
  tilelang/intrinsics/mma_macro_generator.py (10)
    make_mma_store_layout (596-665), make_mma_load_layout (452-594), ldmatrix_a (199-261, 753-857), ldmatrix_b (263-333, 859-971), mma (335-395, 973-1022, 1027-1125, 1130-1229)
  tilelang/utils/language.py (2)
    is_shared (25-39), is_fragment (68-78)
  tilelang/transform/simplify.py (1)
    _Simplify (30-49)
  tilelang/tileop/gemm/__init__.py (2)
    infer_layout (52-57), lower (59-65)
src/target/ptx.h (1)
  src/target/ptx.cc (20)
    PrintMMAAssembly (569-632, 570-579), PrintLoadMatrixAssembly (657-693, 657-662), PrintCpAsyncAssembly (695-722, 695-699), PrintPredicatedCpAsyncAssembly (724-780, 724-727), PrintCpAsyncBulkAsm (782-809, 782-787), PrintCpAsyncBarrierAsm (811-826, 811-811), PrintInitBarrierThreadCountAsm (828-846, 828-829), PrintArriveBarrierAsm (848-863, 848-848), PrintArriveBarrierExpectTxAsm (865-883, 865-866), PrintWaitBarrierAsm (885-901, 885-885)
tilelang/tileop/gemm/gemm_base.py (4)
  tilelang/utils/language.py (2)
    is_shared (25-39), is_fragment (68-78)
  tilelang/carver/roller/node.py (1)
    Node (93-175)
  tilelang/tileop/gemm/gemm_mma.py (6)
    infer_layout (15-58), lower (60-200), is_gemm_ss (202-203), is_gemm_sr (205-206), is_gemm_rs (208-209), is_gemm_rr (211-212)
  tilelang/tileop/gemm/__init__.py (2)
    infer_layout (52-57), lower (59-65)
src/target/codegen_cuda.cc (1)
  src/target/ptx.h (1)
    codegen (33-164)
tilelang/intrinsics/mma_macro_generator.py (5)
  tilelang/intrinsics/utils.py (2)
    mma_store_index_map (81-82), get_ldmatrix_offset (21-63)
  tilelang/utils/language.py (1)
    is_fragment (68-78)
  tilelang/intrinsics/mma_layout.py (5)
    mma_load_a_32x4_to_shared_16x8_layout (130-133), mma_load_b_32x4_to_shared_16x8_layout (136-139), mma_load_a_32x16_to_shared_16x32_layout (142-145), mma_load_b_32x16_to_shared_16x32_layout (148-151), transform_func (212-215)
  tilelang/language/kernel.py (5)
    threads (195-199), get_thread_binding (151-156, 261-265), KernelLaunchFrame (72-206), Current (115-121)
  tilelang/layout/fragment.py (2)
    replicate (146-160), repeat (123-144)
tilelang/intrinsics/utils.py (1)
  tilelang/intrinsics/mma_layout.py (7)
    ldmatrix_32x4_to_shared_16x8_layout_a (6-9), ldmatrix_32x4_to_shared_16x8_layout_b (12-15), ldmatrix_32x8_to_shared_16x16_layout (18-21), ldmatrix_trans_32x8_to_shared_16x16_layout (24-27), ldmatrix_32x16_to_shared_16x32_layout_a (30-33), ldmatrix_32x16_to_shared_16x32_layout_b (36-39), transform_func (212-215)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
  tilelang/language/allocate.py (1)
    alloc_shared (21-36)
  tilelang/language/gemm.py (1)
    gemm_v2 (186-359)
  tilelang/jit/kernel.py (2)
    get_profiler (360-376), out_idx (446-447)
  tilelang/utils/tensor.py (1)
    TensorSupplyType (11-18)
  tilelang/language/__init__.py (1)
    annotate_layout (102-140)
  tilelang/layout/fragment.py (1)
    make_swizzled_layout (210-216)
🪛 Ruff (0.12.2)
tilelang/tileop/gemm/gemm_mma.py
57-58: Avoid specifying long messages outside the exception class
(TRY003)
199-200: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_macro_generator.py
122-122: Avoid specifying long messages outside the exception class
(TRY003)
224-224: Avoid specifying long messages outside the exception class
(TRY003)
290-290: Avoid specifying long messages outside the exception class
(TRY003)
503-503: Avoid specifying long messages outside the exception class
(TRY003)
520-520: Avoid specifying long messages outside the exception class
(TRY003)
578-578: Avoid specifying long messages outside the exception class
(TRY003)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/utils.py
41-41: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: bot-task
🔇 Additional comments (18)
.clang-tidy (1)
43-45: Document rationale for disabling performance-enum-size and confirm clang-tidy support

Add a brief comment explaining the ABI risk above the disabled check, then ensure CI installs clang-tidy and includes “performance-enum-size”:

  -performance-unnecessary-value-param,
+ # Rationale: changing enum underlying types can break ABI; disable until enum usage around T.gemm_v2 is audited.
  -performance-enum-size,

Example CI validation:

command -v clang-tidy >/dev/null && clang-tidy --version \
  && clang-tidy -list-checks -checks='*' | grep -q 'performance-enum-size'

CMakeLists.txt (1)
135-135: PTX source inclusion under USE_CUDA looks correct.

Adding src/target/ptx.cc to TILE_LANG_CUDA_SRCS is properly gated by USE_CUDA and aligns with the new helpers used by codegen_cuda.cc.
src/target/codegen_cuda.cc (2)
17-17: Include path switch to local ptx.h is appropriate.

The relative include matches src/target/ptx.h and avoids stale paths.

1263-1265: Indentation before emitting asm improves readability.

The added PrintIndent() before MMA/PTX emission keeps generated code aligned.
Also applies to: 1299-1305
tilelang/intrinsics/utils.py (1)
52-59: Int8 layout switch to 32x16→16x32 A/B functions: looks right; please validate all transpose modes.

The A-non-transposed/B-transposed restriction matches ldmatrix constraints for int8. Please confirm coverage with tests across rs/sr/rr variants to avoid corner-case misindexing.
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)
47-47: Good adoption of T.gemm_v2 for experimental testing

The transition from T.gemm to T.gemm_v2 aligns well with the PR's experimental nature. The commented-out original line provides a convenient fallback for debugging.

170-172: Consistent use of swizzled layouts for fragment optimization

Good implementation of swizzled memory layouts for the fragment variants (rs, sr, rr). This optimization improves memory access patterns for tensor cores.
Also applies to: 304-306, 440-443
src/target/ptx.cc (3)
594-631: Well-structured PTX assembly generation

The MMA assembly generation is well-organized with clear separation of concerns: validation, template generation, and pattern replacement. The use of the Replacer class provides good maintainability.
748-766: No action needed: C++17 support already enabled
CMakeLists.txt sets CMAKE_CXX_STANDARD to 17, so structured bindings are supported.

377-383: Correct layout validation logic – The condition properly enforces row-major A and column-major B for all cases except the supported m8n8k4 fp16 configuration, where custom layouts are allowed.

tilelang/intrinsics/mma_layout.py (8)
6-9: LGTM! Clean implementation of ldmatrix layout for matrix A.

The function correctly maps thread_id and local_id to row/col coordinates for a 32x4 to 16x8 layout transformation specific to matrix A.

12-15: LGTM! Proper matrix B layout implementation.

The layout calculation correctly handles the different indexing pattern needed for matrix B in the 32x4 to 16x8 transformation.

18-21: LGTM! Consistent 32x8 to 16x16 layout mapping.

The function maintains consistency with the existing pattern and correctly handles the wider local_id range (8 instead of 4).

24-27: LGTM! Proper transposed layout implementation.

The transposed variant correctly swaps the indexing pattern while maintaining the same coordinate calculation logic.

50-73: LGTM! Well-structured matrix A/B layout separation.

The refactor to separate matrix A and B layouts improves clarity and maintainability. The transpose variants and SR/RS aliases provide good backward compatibility.

76-97: LGTM! Consistent 16x16 to 32x8 layout implementation.

The matrix A/B separation follows the same pattern established for the 32x4 layouts, with proper transpose variants and SR/RS aliases.

100-121: LGTM! Proper 32x16 layout implementation.

The 16x32 to 32x16 layout functions correctly handle the larger dimension while maintaining consistency with the established patterns.

130-151: LGTM! MMA load helper functions are well-implemented.

The new load helper functions provide clear separation between matrix A and B loading patterns, with correct coordinate calculations for each matrix type.
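A cheap sanity check for layout functions like the ones reviewed above is to verify that the mapping covers its tile exactly once. The sketch below uses a commonly cited fp16 ldmatrix mapping as an assumed stand-in; the repo's exact formulas may differ:

```python
def ldmatrix_32x8_to_shared_16x16_layout(thread_id: int, local_id: int):
    # Assumed mapping for illustration: 32 threads, each holding 8 fp16
    # elements, tile a 16x16 block.
    row = thread_id % 16
    col = 8 * (thread_id // 16) + local_id
    return row, col


coords = {
    ldmatrix_32x8_to_shared_16x16_layout(t, l)
    for t in range(32)
    for l in range(8)
}
# A valid layout is a bijection: 256 (thread, local) pairs -> 256 cells.
print(coords == {(r, c) for r in range(16) for c in range(16)})  # → True
```

The same set-equality test generalizes to the 32x4→16x8 and 32x16→16x32 variants by adjusting the ranges.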
const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) /
                               frag_attr_a.size / threads / (sparse ? 2 : 1),
          num_operands_b =
              (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads,
          num_operands_c =
              (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads;
Potential integer overflow in operand calculations
The calculations for num_operands_* involve multiplication and division of potentially large values without overflow checks.
Add overflow checks or use safer integer types:
+ // Check for potential overflow
+ const uint64_t bits_a = static_cast<uint64_t>(m) * k * ptx::DTypeBits(dtype_a);
+ const uint64_t bits_b = static_cast<uint64_t>(k) * n * ptx::DTypeBits(dtype_b);
+ const uint64_t bits_c = static_cast<uint64_t>(m) * n * ptx::DTypeBits(dtype_c);
+
+ CHECK(bits_a / frag_attr_a.size / threads <= INT_MAX) << "Operand count overflow";
+
const int num_operands_a = (m * k) * ptx::DTypeBits(dtype_a) /
  frag_attr_a.size / threads / (sparse ? 2 : 1),

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/target/ptx.cc around lines 502-507, the computation of num_operands_a/b/c
can overflow because several potentially large multiplications are done in int
before division; change the intermediate arithmetic to a wider unsigned integer
type (e.g., uint64_t or use __int128 for intermediate products), perform checks
for multiplication overflow (use builtin helpers like __builtin_mul_overflow or
check operands against max/value limits) before dividing, and cast back to the
target type only after safe division; if overflow is detected, handle it
deterministically (return error, clamp, or log) rather than producing a wrapped
value.
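The overflow the bot flags can be made concrete by emulating C's 32-bit signed arithmetic. The shape numbers below are illustrative, not taken from the PR:

```python
# Emulate 32-bit two's-complement wraparound to show why (m * k) * bits
# should be widened to 64 bits before the divisions.
INT32_MAX = 2**31 - 1


def wrap_int32(x: int) -> int:
    """Reduce x the way 32-bit two's-complement arithmetic would."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x


m, k, dtype_bits = 65536, 4096, 16   # illustrative large-but-plausible sizes
exact = m * k * dtype_bits           # 2**32: exceeds INT32_MAX
print(exact > INT32_MAX)             # → True

# Evaluating (m * k) * bits in a 32-bit int wraps before any division runs:
wrapped = wrap_int32(wrap_int32(m * k) * dtype_bits)
print(wrapped)                       # → 0

# Widening the intermediate to 64 bits, as the review suggests, stays exact:
print(exact <= 2**63 - 1)            # → True
```

In C this overflow is undefined behavior, so the CHECK-before-divide pattern in the suggested diff is the portable fix.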
 * \brief Code generation with inlined PTX code.
 */
#ifndef TVM_TL_TARGET_SOURCE_PTX_H_
#define TVM_TL_TARGET_SOURCE_PTX_H_
Fix incorrect header guard macro name
The header guard uses TVM_TL_TARGET_SOURCE_PTX_H_ but the file is located at src/target/ptx.h. The macro should match the actual path.
Apply this diff to fix the header guard:
-#ifndef TVM_TL_TARGET_SOURCE_PTX_H_
-#define TVM_TL_TARGET_SOURCE_PTX_H_
+#ifndef TVM_TL_TARGET_PTX_H_
+#define TVM_TL_TARGET_PTX_H_

Also update line 167:
-#endif // TVM_TL_TARGET_SOURCE_PTX_H_
+#endif  // TVM_TL_TARGET_PTX_H_
🤖 Prompt for AI Agents
In src/target/ptx.h around line 25, the header guard macro is incorrect
(currently TVM_TL_TARGET_SOURCE_PTX_H_) and must match the file path; replace it
with a macro appropriate for src/target/ptx.h (e.g., SRC_TARGET_PTX_H_) and
ensure the closing #endif at line 167 is updated to reference the new macro name
so the guard pairs correctly.
ldmatrix_32x4_to_shared_16x8_layout_a,
ldmatrix_32x4_to_shared_16x8_layout_b,
Remove unreachable tf32/32-bit branch and its now-unused imports.
dtype is annotated as Literal["float16","int8"]; DataType(dtype).bits can’t be 32, so this block is dead code. The imports for ldmatrix_32x4_to_shared_16x8_layout_{a,b} become unused. Either wire this up with a real tf32 path (and broaden the type hints/tests) or drop it for now to avoid confusion.
Apply:
-from .mma_layout import (
- ldmatrix_32x4_to_shared_16x8_layout_a,
- ldmatrix_32x4_to_shared_16x8_layout_b,
+from .mma_layout import (
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout,
)
@@
- if dtype_bits == 32:
- if matrix == "B" and transposed:
- transform_func = ldmatrix_32x4_to_shared_16x8_layout_b
- new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
- return new_row_idx * stride + new_col_idx
- elif matrix == "A" and not transposed:
- transform_func = ldmatrix_32x4_to_shared_16x8_layout_a
- new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
- return new_row_idx * stride + new_col_idx
- else:
- raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
+ # tf32 not yet supported (see PR TODO); 32-bit branch intentionally omitted.

Also applies to: 31-42
🤖 Prompt for AI Agents
In tilelang/intrinsics/utils.py around lines 4-5 and 31-42 remove the dead
tf32/32-bit branch and the now-unused imports
ldmatrix_32x4_to_shared_16x8_layout_a and ldmatrix_32x4_to_shared_16x8_layout_b:
update the dtype handling to only consider the Literal["float16","int8"] cases
(or if you intend to support tf32/32-bit, broaden the dtype type hints/tests and
add a real code path), delete the unused import lines, and run tests/linters to
ensure no remaining references to those symbols remain.
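For readers unfamiliar with these layout helpers: each one maps an ldmatrix fragment coordinate (lane, element) to a position in a shared-memory tile. The sketch below follows the common fp16 ldmatrix convention for the 32x8-to-16x16 case; the index arithmetic is an assumption about the library's actual implementation, shown only to make the review comment concrete:

```python
def ldmatrix_32x8_to_shared_16x16_layout(thread_id: int, local_id: int):
    """Map (lane, element) of a 32-lane x 8-element ldmatrix fragment
    onto a 16x16 shared-memory tile.

    Assumed convention: lanes 0..15 cover rows 0..15 of the left 16x8
    half; lanes 16..31 cover the right 16x8 half.
    """
    row = thread_id % 16
    col = 8 * (thread_id // 16) + local_id
    return row, col

# Sanity check: the mapping must be a bijection onto the 16x16 tile,
# i.e. every (row, col) is produced by exactly one (lane, element) pair.
seen = {ldmatrix_32x8_to_shared_16x16_layout(t, l) for t in range(32) for l in range(8)}
assert seen == {(r, c) for r in range(16) for c in range(16)}
```

The 32-bit (tf32) variants flagged above would use a 32x4 fragment with the same bijection requirement, which is why a dead branch importing them is worth either wiring up or removing.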
- Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
- This change simplifies the allocation process and aligns with the updated GEMM function signatures.
Actionable comments posted: 0
♻️ Duplicate comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (2)
117-117: num_stages must be >= 1 for T.Pipelined.Using 0 is invalid/fragile and has been flagged previously. Set to 1 (or 2) to enable prefetch.
- run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+ run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
- run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+ run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)
- run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
+ run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 1, 128)

#!/bin/bash
# Find any remaining uses of num_stages=0 in tests.
rg -nP '\bnum_stages\s*=\s*0\b' -g 'testing/**'
# Sanity-check run_gemm_* callsites passing a literal 0 as the stage arg.
rg -nP 'run_gemm_(ss|sr|rs|rr)\([^)]*,\s*0\s*(,|\))' -C2 -g 'testing/**'

Also applies to: 251-251, 385-385
521-521: Use "float32" (not "float") for accum dtype to match emitter mapping.The emitter expects "float32"; "float" can KeyError during lowering. This was raised before; adjust here as well.
- run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
+ run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float32", 128, 256, 32, 2)

#!/bin/bash
# Ensure no remaining accum dtype "float" with bfloat16 inputs.
rg -nP 'run_gemm_(ss|sr|rs|rr)\([^)]*"bfloat16"[^)]*"bfloat16"[^)]*"float"'
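A defensive alternative to fixing each call site is to normalize dtype aliases before the emitter lookup, so a loose spelling like "float" can never reach the table. This helper is hypothetical, not part of the current codebase:

```python
# Map loose dtype spellings to the canonical names an emitter table expects.
_DTYPE_ALIASES = {
    "float": "float32",
    "half": "float16",
}

def canonicalize_dtype(dtype: str) -> str:
    """Return the canonical dtype name, passing through already-canonical ones."""
    return _DTYPE_ALIASES.get(dtype, dtype)

# A lookup guarded this way raises a clear error instead of a bare KeyError.
def lookup_emitter(table: dict, dtype: str):
    key = canonicalize_dtype(dtype)
    if key not in table:
        raise ValueError(f"unsupported accum dtype: {dtype!r} (canonical: {key!r})")
    return table[key]
```

Whether the normalization belongs in `gemm_v2` or in the emitter itself is a design choice; either way the tests above should still pass canonical names.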
🧹 Nitpick comments (3)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)
47-49: Remove the commented legacy GEMM call for clarity.Keep only the v2 call to avoid confusion.
- # T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
128-128: Misleading comment: these are float32 tests, not TF32.TF32 support is listed as TODO; rename the comment to reflect fp32.
- # tfloat32 test
+ # float32 tests (TF32 not enabled yet)
54-68: Reduce duplication across run_gemm_ wrappers.*Factor a small helper to compile, get_profiler, and assert against ref_program; pass the matmul builder as a parameter. Keeps variants focused on their distinct tiling only.
Also applies to: 189-203, 461-475
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (6)
- tilelang/language/gemm.py (1): gemm_v2 (186-359)
- tilelang/language/copy.py (1): copy (84-152)
- tilelang/jit/kernel.py (2): get_profiler (360-376), out_idx (446-447)
- tilelang/utils/tensor.py (1): TensorSupplyType (11-18)
- tilelang/language/__init__.py (1): annotate_layout (102-140)
- tilelang/layout/fragment.py (1): make_swizzled_layout (210-216)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
🔇 Additional comments (3)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py (3)
170-173: Swizzled A_shared layout usage looks correct.Annotating A_shared with a swizzled layout matches the rs variant intent.
304-306: Swizzled B_shared layout usage looks correct.Matches the sr variant and improves MMA load patterns.
441-443: Swizzling both A_shared and B_shared is appropriate for rr.This aligns with the rr data path and should help bank conflict avoidance.
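The bank-conflict argument behind these swizzled-layout annotations can be illustrated with the classic XOR swizzle. This is a generic sketch of the idea; the parameters of the library's `make_swizzled_layout` may differ:

```python
def xor_swizzle(row: int, col: int, banks: int = 8) -> int:
    """Permute columns within a row by XOR-ing with the row index.

    Without the swizzle, threads reading down a column of the tile all
    land on the same shared-memory bank; with it they hit `banks`
    distinct banks and the accesses can proceed in parallel.
    """
    return col ^ (row % banks)

# Each row is still a permutation of the columns (no data is lost)...
for r in range(8):
    assert sorted(xor_swizzle(r, c) for c in range(8)) == list(range(8))
# ...and a column access now touches 8 distinct swizzled columns (banks).
assert len({xor_swizzle(r, 0) for r in range(8)}) == 8
```

Applying such a layout to both A_shared and B_shared in the rr variant matters because both operands are loaded column-wise by the MMA fragment loads.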
…#793)

* Refactor GEMM and GEMM-SP operations to enhance clarity and maintainability
  - Removed deprecated prime factorization functions from `gemm.cc` and `gemm_sp.cc`.
  - Introduced a new `GemmWarpPolicy` class to manage warp policy attributes and methods, improving encapsulation.
  - Updated reflection methods to include the new policy structure, ensuring proper registration and introspection capabilities.
  - Enhanced `GetArchInt` function in `utils.cc` for better readability and type safety.
  - Added new `gemm_v2` function in `gemm.py` for improved GEMM operation with additional parameters and checks.
* Refactor GEMM and frontend legalize operations for improved clarity and functionality
  - Updated `gemm_py.h` to include the correct header for GEMM operations.
  - Renamed `FrontendLegalizer` class to `LetInliner` and updated related methods to reflect this change, enhancing code clarity.
  - Modified the pass function from `FrontendLegalize` to `LetInline` for better alignment with its purpose.
  - Updated test cases to utilize the new `gemm_v2` function and adjusted the testing framework for improved output and clarity.
  - Removed obsolete test file `test_tilelang_transform_frontend_legalize.py` to streamline the test suite.
  - Enhanced the `LowerAndLegalize` function to utilize the new `LetInline` pass, improving the overall transformation process.
* Enhance CUDA code generation and testing for GEMM operations
  - Added indentation printing in `codegen_cuda.cc` for improved assembly code formatting.
  - Updated `test_tilelang_tilelibrary_gemm.py` to include additional GEMM test cases and shared memory allocation with specified scope.
  - Introduced new `matmul_sr` and `run_gemm_sr` functions for GEMM operations with shared and fragment memory layouts.
  - Refactored layout inference in `mma_macro_generator.py` to improve clarity and correctness in shared memory handling.
  - Enhanced `gemm/__init__.py` to support new GEMM operation combinations and layout inference logic.

  These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.
* Refactor GEMM layout and testing for improved clarity and functionality
  - Updated `gemm_layouts.cc` to enhance the layout generation logic for transposed and non-transposed GEMM operations.
  - Renamed and modified functions in `test_tilelang_tilelibrary_gemm.py` to reflect changes in GEMM function signatures and improve test coverage.
  - Introduced new GEMM operation combinations in `gemm/__init__.py` to support additional layouts and configurations.
  - Enhanced layout inference in `mma_layout.py` and `mma_macro_generator.py` for better handling of shared memory layouts.

  These changes improve the clarity, functionality, and testing coverage of GEMM operations in the TileLang framework.
* Refactor GEMM layout and Python integration for improved functionality
  - Updated `gemm_layouts.cc` to correct the order of layout replication and repetition for transposed and non-transposed GEMM operations.
  - Enhanced `gemm_py.cc` to handle block realization more robustly, ensuring correct assignment of global symbols and block attributes.
  - Refactored `inject_pipeline.cc` to streamline buffer read/write region handling, improving clarity and maintainability.
  - Cleaned up test cases in `test_tilelang_tilelibrary_gemm.py` by removing unnecessary print statements and adjusting function calls for better test execution flow.

  These changes enhance the clarity, functionality, and robustness of GEMM operations and their testing in the TileLang framework.
* Refactor GEMM layout and testing for improved clarity and functionality
  - Updated `gemm_layouts.cc` to enhance layout generation logic for transposed and non-transposed GEMM operations.
  - Improved block realization handling in `gemm_py.cc` for better assignment of global symbols.
  - Streamlined buffer read/write region handling in `inject_pipeline.cc` for clarity.
  - Enhanced test cases in `test_tilelang_tilelibrary_gemm.py` by adjusting function calls and adding new GEMM operation combinations.
* tfloat32 support.
* lint fix
* lint fix
* Refactor shared memory allocation in GEMM tests
  - Removed unnecessary scope specification in shared memory allocation for matrices A and B in `test_tilelang_tilelibrary_gemm.py`.
  - This change simplifies the allocation process and aligns with the updated GEMM function signatures.
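To give a feel for what the new `GemmWarpPolicy` class computes, here is a simplified Python sketch of partitioning warps between the M and N tile dimensions. The real class lives in C++ and handles more cases; the policy names and balancing heuristic here are illustrative assumptions:

```python
def compute_warp_partition(block_m: int, block_n: int, num_warps: int,
                           policy: str = "square") -> tuple:
    """Split num_warps into (m_warp, n_warp) with m_warp * n_warp == num_warps."""
    if policy == "full_row":
        return num_warps, 1
    if policy == "full_col":
        return 1, num_warps
    # "square": pick the factorization whose per-warp tile is most balanced.
    best, best_score = (num_warps, 1), float("inf")
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp:
            continue
        n_warp = num_warps // m_warp
        score = abs(block_m / m_warp - block_n / n_warp)
        if score < best_score:
            best, best_score = (m_warp, n_warp), score
    return best

# A 128x128 block with 4 warps splits into a 2x2 warp grid.
assert compute_warp_partition(128, 128, 4) == (2, 2)
```

Moving this logic into a dedicated class (rather than free prime-factorization helpers in `gemm.cc` and `gemm_sp.cc`) is what enables the reflection and introspection hooks mentioned in the commit message.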
This change will help us integrate additional backends and reduce compilation time.
Note: This pull request is experimental and currently targets MMA only.
TODOs for this PR
- `ldmatrix` for improved readability

Summary by CodeRabbit
New Features
Improvements/Refactor
Tests
Chores