[EagerJIT] Add Support for Parameter Only Kernel Compilation #1664

LeiWang1999 merged 3 commits into tile-ai:main
Conversation
📝 Walkthrough

Changes add support for compile-time constants in JIT-compiled functions, expand AST argument type annotation parsing to handle Call and Subscript patterns with an explicit kwargs parameter representation, and refactor the constraint matcher from a 3-tuple to a 4-tuple structure that includes buffer element names for phase-2 key resolution.
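To make the "compile-time constants" idea concrete, here is a minimal, hypothetical sketch in plain Python — not the TileLang implementation — of specializing a kernel per set of compile-time constants, with one build cached per distinct constant set. The names `specialize` and `transpose_kernel` are illustrative only.

```python
# Hypothetical sketch of per-constant-set kernel specialization;
# `specialize` and `transpose_kernel` are illustrative names, not TileLang APIs.
def specialize(build):
    cache = {}

    def compile_with(**constants):
        # Key the cache on the full set of compile-time constants.
        key = tuple(sorted(constants.items()))
        if key not in cache:
            cache[key] = build(**constants)  # build once per constant set
        return cache[key]

    return compile_with


@specialize
def transpose_kernel(M, N, block_M=64, block_N=64):
    # Stand-in for real code generation: return a specialized description.
    return f"transpose<{M}x{N}, block {block_M}x{block_N}>"
```

Repeated calls with the same constants then return the same compiled object instead of rebuilding.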
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 0
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_eager_jit.py (1)
228-244: Test compiles but doesn't verify execution correctness.

This test validates that parameter-only compilation works, but unlike other tests in this file (e.g., test_jit2_gemm), it doesn't execute the kernel or verify results. Consider adding execution and assertions to catch potential issues in the transpose logic.

♻️ Suggested enhancement to verify execution

```diff
 transpose.compile(M=1024, N=1024, block_M=64, block_N=64)
+
+# Verify correctness
+X = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")
+Y = torch.empty(1024, 1024, dtype=torch.float32, device="cuda")
+transpose(X, Y)
+torch.testing.assert_close(Y, X.T)
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- testing/python/language/test_tilelang_language_eager_jit.py
- testing/python/layout/test_tilelang_annotate_loop_layout.py
- tilelang/language/eager/ast.py
- tilelang/language/eager/builder.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-06T05:20:51.649Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:51.649Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
- testing/python/layout/test_tilelang_annotate_loop_layout.py
- testing/python/language/test_tilelang_language_eager_jit.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/layout/test_tilelang_annotate_loop_layout.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_eager_jit.py
🧬 Code graph analysis (4)
testing/python/layout/test_tilelang_annotate_loop_layout.py (2)
- tilelang/jit/__init__.py (2): compile (47-107), compile (372-398)
- tilelang/jit/adapter/tvm_ffi.py (1): get_kernel_source (311-316)

tilelang/language/eager/builder.py (1)
- tilelang/jit/exceptions.py (2): JITNoBuilderError (4-13), EagerJITBuildError (16-24)

tilelang/language/eager/ast.py (4)
- tilelang/language/eager/builder.py (1): arg (670-674)
- tilelang/language/dtypes.py (1): dtype (14-19)
- tilelang/jit/adapter/tvm_ffi.py (1): func (206-260)
- tilelang/language/proxy.py (3): TensorProxy (137-155), StridedTensorProxy (158-167), ptr (256-275)

testing/python/language/test_tilelang_language_eager_jit.py (2)
- tilelang/language/eager/builder.py (1): const (851-879)
- tilelang/language/allocate.py (1): alloc_shared (39-54)
🪛 Ruff (0.14.10)
tilelang/language/eager/builder.py
942-946: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (10)
testing/python/layout/test_tilelang_annotate_loop_layout.py (4)
28-32: LGTM! Consistent update to dimension-based compilation.

The change from tensor-based to keyword-argument-based compilation (M=M, N=N) aligns with the new parameter-only kernel compilation feature. The assertion correctly validates the expected vectorized memory access pattern.
44-47: LGTM!

Same consistent pattern as above for the identity layout test.
70-76: LGTM!

Updated to dimension-based compilation with the reformatted assertion checking the expected bitwise masking pattern (& 63).
102-108: LGTM!

Consistent with other test updates in this file.
tilelang/language/eager/ast.py (2)
475-475: LGTM! Enables kwargs passthrough for parameter-only compilation.

Adding __kwargs to the function arguments allows keyword arguments to flow through the AST transformation, supporting the new dimension-based compilation API.
514-527: LGTM! Expanded type hint parsing for tensor annotations.

The new case handling correctly identifies T.Tensor[...], T.StridedTensor[...], and T.ptr(...) patterns by checking both ast.Call and ast.Subscript node types. Storing ptr for all tensor-like types allows the eager JIT builder to recognize these as buffer arguments requiring constexpr substitution.

tilelang/language/eager/builder.py (4)
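The type-hint parsing comment above distinguishes subscript patterns (T.Tensor[...], T.StridedTensor[...]) from call patterns (T.ptr(...)). A simplified, standalone sketch of that classification using Python's `ast` module — the function name and return value here are illustrative, not the actual parser in tilelang/language/eager/ast.py:

```python
import ast


def tensor_annotation_kind(annot: str):
    """Classify an annotation string as a tensor-like pattern (sketch).

    Subscript forms like T.Tensor[...] / T.StridedTensor[...] and call
    forms like T.ptr(...) are all mapped to "ptr", mirroring the idea that
    the builder treats them uniformly as buffer arguments.
    """
    node = ast.parse(annot, mode="eval").body
    # Subscript pattern: T.Tensor[...] or T.StridedTensor[...]
    if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Attribute):
        if node.value.attr in ("Tensor", "StridedTensor"):
            return "ptr"
    # Call pattern: T.ptr(...)
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
        if node.func.attr == "ptr":
            return "ptr"
    return None  # not a tensor-like annotation
```

Parsing with `mode="eval"` lets the free names (T, M, N) appear without being resolved, so the classification is purely syntactic.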
893-893: LGTM! Type annotation updated for 4-tuple matcher.

The matcher now stores (buffer_name, type, index, constexpr_name) to support both tensor-based inference and direct keyword argument lookup.
900-905: LGTM! Matcher construction captures constexpr variable names.

Storing s.name (the constexpr variable name) alongside the buffer metadata enables the phase-2 parser to resolve constants via direct keyword arguments like M=1024.
929-946: LGTM! Phase-2 parsing supports both named constants and tensor inference.

The updated logic correctly prioritizes direct keyword lookup (kwargs.get(name)) before falling back to tensor-based shape/stride extraction. The error message is clear and actionable, guiding users to provide either the constant value or the corresponding tensor.
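The kwargs-first, tensor-fallback resolution described above can be sketched as follows. This is a hypothetical simplification, not the builder code: matchers are the 4-tuples (buffer_name, kind, index, constexpr_name) mentioned in the review, and `resolve_constants` is an illustrative name.

```python
def resolve_constants(matchers, kwargs):
    """Phase-2 resolution sketch (hypothetical, not the TileLang builder).

    Each matcher is (buffer_name, kind, index, constexpr_name). A constant
    is taken from a direct keyword (e.g. M=1024) first, then inferred from
    the matching tensor's shape.
    """
    resolved = {}
    for buffer_name, kind, index, constexpr_name in matchers:
        value = kwargs.get(constexpr_name)  # direct keyword lookup first
        if value is None:
            tensor = kwargs.get(buffer_name)  # fall back to tensor inference
            if tensor is not None and kind == "shape":
                value = tensor.shape[index]
        if value is None:
            raise ValueError(
                f"cannot resolve constant {constexpr_name!r}: pass "
                f"{constexpr_name}=<int> or provide tensor {buffer_name!r}")
        resolved[constexpr_name] = value
    return resolved
```

Either calling convention then works: `resolve_constants(matchers, {"M": 1024, "N": 512})` or passing a tensor whose shape supplies the same values.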
1027-1033: LGTM! TypeError handling improves lazy/eager detection robustness.

Adding TypeError to the caught exceptions handles cases where calling the original function fails due to argument mismatches, correctly identifying such functions as eager-style.
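The detection idea in the comment above — treat a TypeError from calling the original function as a signal that it is eager-style — can be illustrated with a small sketch. The function name `is_eager_style` and the calling convention are assumptions for illustration, not the actual builder API:

```python
def is_eager_style(func, *sample_args):
    """Sketch of lazy/eager detection (hypothetical, not the builder code).

    Try calling the function with the given representative arguments; a
    TypeError (argument mismatch) marks it as eager-style rather than a
    directly callable lazy-style function.
    """
    try:
        func(*sample_args)
        return False  # call succeeded: plain lazy-style callable
    except TypeError:
        return True   # argument mismatch: treat as eager-style
```

Other exception types are deliberately allowed to propagate, so genuine runtime errors in the function body are not misclassified.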
[EagerJIT] Add Support for Parameter Only Kernel Compilation (#1664)
- [Fix] Refactor type hint extraction logic in DSLMutator for better clarity and handling of annotations
- [Refactor] Remove redundant tensor creation in loop layout tests and update kernel compilation parameters
This PR adds support for compiling a kernel with only its constexpr parameters.