[Refactor] Unify @jit and @lazy_jit into a single @jit decorator #1632
…nd tests Updated all instances of the @tilelang.lazy_jit decorator to @tilelang.jit in the lazyjit example notebooks and related test files. This change aligns the code with the new JIT compilation approach, enhancing consistency across the codebase. Additionally, removed the lazy_jit import from the module initialization to streamline the API.
…l loops Updated the T.copy function to accept a new keyword-only parameter, loop_layout, allowing users to specify layout hints for the outermost parallel loop. Enhanced layout annotation handling in CopyNode and AtomicAddNode classes to ensure compatibility with SIMT operations. Added tests to validate the functionality of loop layout annotations, improving robustness and clarity in layout management for parallel loops.
…tor/unify-jit-decorator
…tor/unify-jit-decorator
…ernel and T.const functions. Ensure proper error handling when Builder is unavailable, improving robustness in JIT context management.
… to improve performance metrics output by adding average FLOPS calculations. Modify layout_inference.cc to refine buffer layout inference checks by ensuring fragment buffers are considered.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
This PR consolidates TileLang's JIT system by replacing the …
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Pull request overview
This PR unifies @tilelang.lazy_jit and @tilelang.jit into a single @tilelang.jit decorator that automatically infers execution mode based on function behavior. The unified decorator supports two styles: lazy mode (functions that explicitly return a PrimFunc) and eager mode (functions that use the DSL builder pattern with T.const() and T.Kernel()).
Key changes:
- Added automatic mode inference by detecting whether functions return a `PrimFunc` or raise `JITNoBuilderError`
- Introduced custom exception classes (`JITNoBuilderError`, `EagerJITBuildError`) for distinguishing execution modes
- Removed `lazy_jit` from public exports
- Updated all examples and tests to use the unified `@tilelang.jit` decorator
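The mode inference described above can be sketched in plain Python. This is a minimal illustration, not the TileLang implementation: `NoBuilderError`, `FakePrimFunc`, `needs_builder`, and `infer_mode` are hypothetical stand-ins for `JITNoBuilderError`, `PrimFunc`, a Builder-dependent call such as `T.const()`, and the inference step.

```python
class NoBuilderError(RuntimeError):
    """Hypothetical stand-in for JITNoBuilderError."""


class FakePrimFunc:
    """Hypothetical stand-in for a TIR PrimFunc."""


def needs_builder():
    # Stand-in for T.const() / T.Kernel() called with no active Builder.
    raise NoBuilderError("no active Builder")


def infer_mode(func, *args, **kwargs):
    """Return "lazy" if func returns a FakePrimFunc, "eager" if it needs a Builder."""
    try:
        result = func(*args, **kwargs)
    except NoBuilderError:
        return "eager"
    if isinstance(result, FakePrimFunc):
        return "lazy"
    raise TypeError(f"cannot infer mode from return type {type(result).__name__}")


def lazy_style(n):
    # Lazy style: explicitly build and return the function object.
    return FakePrimFunc()


def eager_style(n):
    # Eager style: the body immediately touches a Builder-dependent API.
    needs_builder()


lazy_mode = infer_mode(lazy_style, 4)
eager_mode = infer_mode(eager_style, 4)
print(lazy_mode, eager_mode)
```

The sketch assumes that an eager-style body reaches a Builder-dependent call before doing anything else; a function that neither returns the expected type nor raises is rejected rather than guessed at.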
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `tilelang/jit/__init__.py` | Unified JIT decorator implementation with automatic mode inference and removed lazy_jit function |
| `tilelang/jit/exceptions.py` | Added new exception classes for JIT mode detection |
| `tilelang/language/v2/builder.py` | Enhanced LazyJITFunc with mode inference logic and updated TirTemplate to support lazy-style functions |
| `tilelang/language/kernel.py` | Added Builder existence check to T.Kernel() for eager mode validation |
| `tilelang/__init__.py` | Removed lazy_jit from public API exports |
| `testing/python/layout/test_tilelang_annotate_loop_layout.py` | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| `testing/python/language/test_tilelang_language_subtype.py` | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| `testing/python/language/test_tilelang_language_lazy_jit.py` | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| `examples/lazy_jit/lazyjit.zh.ipynb` | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| `examples/lazy_jit/lazyjit.en.ipynb` | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| `examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py` | Updated decorator usage and removed unused import |
| `examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py` | Code cleanup: removed unused variables and debug print statements |
| `src/transform/layout_inference.cc` | Fixed buffer validation logic in layout inference |
| `3rdparty/tvm` | Updated TVM submodule commit reference |
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py:
- Around line 391-393: The printed "GFLOPS" are incorrect because avg_flops =
total_flops / avg_time yields FLOPS; convert to GFLOPS by dividing by 1e9 and
update the formatted value and label accordingly for both the main measurement
(avg_flops, total_flops, avg_time) and the reference implementation's
corresponding variables (e.g., ref_total_flops/ref_avg_flops or whatever the
reference block uses around lines ~405-407), ensuring you compute gflops =
total_flops / avg_time / 1e9 and print "GFLOPS" with that converted value.
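The unit fix requested above amounts to one extra division. A minimal sketch, assuming `avg_time` is measured in seconds (the review does not state the unit) and using a hypothetical helper name `to_gflops` with made-up sample numbers:

```python
def to_gflops(total_flops: float, avg_time_s: float) -> float:
    """Convert an operation count and an elapsed time in seconds to GFLOPS."""
    # total_flops / avg_time_s is FLOPS; dividing by 1e9 gives GFLOPS.
    return total_flops / avg_time_s / 1e9


# Illustrative values only; they are not taken from the example script.
total_flops = 2.0e12
avg_time_s = 0.5
gflops = to_gflops(total_flops, avg_time_s)
print(f"Average GFLOPS: {gflops:.2f}")
```

If the benchmark reports time in milliseconds instead, the time would need to be scaled to seconds before applying the same formula.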
🧹 Nitpick comments (3)
tilelang/language/v2/builder.py (1)
1027-1034: Consider simplifying redundant cache check.
The cache check at line 1029 appears redundant since `parse_args()` (called at line 1028) already populates `p1_cache` at lines 1019-1023. The comment mentions "legacy gemm" but the logic path where `p1_cache` would be empty after `parse_args` is unclear.
♻️ Suggested simplification

     def get_tir(self, *args, **kwargs):
         (p1_key, _), tensor_args = self.parse_args(*args, **kwargs)
    -    if p1_key not in self.p1_cache:
    -        # in legacy gemm, we use lazy tir template to build the tir
    -        tir_temp = self._build_tir_template(*args, **kwargs)
    -        self.p1_cache[p1_key] = tir_temp
    -        return tir_temp.get_tir(**tensor_args)
         return self.p1_cache[p1_key].get_tir(**tensor_args)

If there's a specific legacy case requiring this, please add a more detailed comment explaining when `parse_args` wouldn't populate the cache.
tilelang/jit/__init__.py (2)
295-302: Consider using `TypeError` for invalid type.
Per Python conventions, `TypeError` is more appropriate than `ValueError` when the issue is an unexpected type.
♻️ Suggested fix

     if isinstance(self.func, PrimFunc):
         tir = self.func
     elif isinstance(self.func, (LazyJITFunc, Callable)):
         tir = self.func(*args, **kwargs)
     else:
    -    raise ValueError(f"Invalid function type: {type(self.func)}")
    +    raise TypeError(f"Invalid function type: {type(self.func)}")
     assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
     return tir

438-443: Pass kernel arguments as keyword arguments instead of unpacking dict values as positional arguments.
At line 441, `kernel(*kernel_args.values())` unpacks dictionary values positionally. While Python 3.7+ guarantees dict insertion order, this assumes the dict value order matches the kernel function's parameter order. To be explicit and avoid fragility, pass arguments as keywords: `kernel(**kernel_args)`.
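The fragility of positional unpacking can be shown with a small, self-contained example. The `kernel` function and the argument names here are hypothetical and only illustrate the ordering issue; they are not the TileLang kernel call.

```python
def kernel(a, b, out):
    # Hypothetical kernel that just reports which value landed in which slot.
    return f"a={a} b={b} out={out}"


# The dict was filled in an order that differs from the parameter order.
kernel_args = {"out": "C", "a": "A", "b": "B"}

# Positional unpacking follows insertion order, so values land in the wrong slots.
positional = kernel(*kernel_args.values())

# Keyword unpacking matches by name and is independent of insertion order.
keyword = kernel(**kernel_args)

print(positional)
print(keyword)
```

Keyword unpacking only works when the dict keys are valid parameter names of the callee, which is the assumption behind the review's suggestion.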
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- `3rdparty/tvm`
- `examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py`
- `examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py`
- `examples/lazy_jit/lazyjit.en.ipynb`
- `examples/lazy_jit/lazyjit.zh.ipynb`
- `src/transform/layout_inference.cc`
- `testing/python/language/test_tilelang_language_lazy_jit.py`
- `testing/python/language/test_tilelang_language_subtype.py`
- `testing/python/layout/test_tilelang_annotate_loop_layout.py`
- `tilelang/__init__.py`
- `tilelang/jit/__init__.py`
- `tilelang/jit/exceptions.py`
- `tilelang/language/kernel.py`
- `tilelang/language/v2/builder.py`
💤 Files with no reviewable changes (1)
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
🧰 Additional context used
🧠 Learnings (5)
📚 Learning: 2025-12-26T06:45:51.789Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:51.789Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.
Applied to files:
- `examples/lazy_jit/lazyjit.en.ipynb`
- `examples/lazy_jit/lazyjit.zh.ipynb`
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
- `examples/lazy_jit/lazyjit.en.ipynb`
- `examples/lazy_jit/lazyjit.zh.ipynb`
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
- `examples/lazy_jit/lazyjit.en.ipynb`
- `testing/python/language/test_tilelang_language_subtype.py`
- `examples/lazy_jit/lazyjit.zh.ipynb`
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- `testing/python/language/test_tilelang_language_subtype.py`
- `testing/python/language/test_tilelang_language_lazy_jit.py`
- `testing/python/layout/test_tilelang_annotate_loop_layout.py`
📚 Learning: 2025-12-15T07:48:45.785Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: src/target/codegen_cutedsl.cc:789-793
Timestamp: 2025-12-15T07:48:45.785Z
Learning: In tilelang/contrib/cutedsl, the `tl.make_rmem_tensor` function accepts both an Integer and a Tuple of Integer for its shape parameter. Therefore, both `tl.make_rmem_tensor(N, ...)` and `tl.make_rmem_tensor((N,), ...)` are valid syntaxes in CuteDSL-generated code.
Applied to files:
testing/python/language/test_tilelang_language_subtype.py
🧬 Code graph analysis (5)
- `testing/python/language/test_tilelang_language_subtype.py` (1)
  - `tilelang/jit/__init__.py` (3): `jit` (450-450), `jit` (454-464), `jit` (467-531)
- `testing/python/language/test_tilelang_language_lazy_jit.py` (2)
  - `tilelang/jit/__init__.py` (3): `jit` (450-450), `jit` (454-464), `jit` (467-531)
  - `tilelang/language/v2/builder.py` (1): `const` (826-854)
- `src/transform/layout_inference.cc` (2)
  - `src/op/utils.h` (1): `IsFragmentBuffer` (33-35)
  - `src/transform/common/loop_fusion_utils.h` (1): `IsFragmentBuffer` (70-75)
- `tilelang/language/v2/builder.py` (1)
  - `tilelang/jit/exceptions.py` (2): `JITNoBuilderError` (4-13), `EagerJITBuildError` (16-24)
- `tilelang/language/kernel.py` (2)
  - `tilelang/jit/exceptions.py` (1): `JITNoBuilderError` (4-13)
  - `tilelang/language/v2/builder.py` (2): `Builder` (166-676), `current` (178-180)
🪛 Ruff (0.14.10)
tilelang/__init__.py
145-145: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/v2/builder.py
845-845: Avoid specifying long messages outside the exception class
(TRY003)
1012-1012: Avoid specifying long messages outside the exception class
(TRY003)
1053-1053: Avoid specifying long messages outside the exception class
(TRY003)
1078-1078: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/jit/__init__.py
300-300: Prefer TypeError exception for invalid type
(TRY004)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
370-370: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/kernel.py
290-290: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: CodeQL analysis (python)
🔇 Additional comments (19)
3rdparty/tvm (1)
1-1: Verify submodule compatibility with JIT refactoring changes.
Submodule reference updated from `e47e76a...` to `001022b...`. Since this PR introduces mode-inference changes to the JIT system and adds builder-aware error handling, please confirm that the new TVM commit is compatible with these tilelang changes and does not introduce breaking changes or regressions.
examples/lazy_jit/lazyjit.zh.ipynb (1)
56-56: LGTM! Decorator migration is consistent throughout the notebook.
All `@tilelang.lazy_jit` usages have been correctly updated to `@tilelang.jit`. The notebook content properly demonstrates eager-style DSL patterns (using `T.const`, `T.Tensor` annotations, `T.Kernel`), which aligns with the unified JIT decorator's auto-detection behavior.
Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860
src/transform/layout_inference.cc (1)
225-231: LGTM! Warning condition correctly narrowed to fragment buffers only.
The updated condition `!use_list_.count(buffer) && IsFragmentBuffer(buffer)` appropriately limits the warning to cases where fragment buffers are missing from `use_list_`. Non-fragment buffers (global/shared memory) not appearing in `use_list_` is expected behavior and doesn't warrant a warning.
Note: Line 234 still accesses `use_list_[buffer]` even when the buffer isn't in `use_list_`. This is safe because `std::unordered_map::operator[]` creates an empty vector for missing keys, and iterating an empty vector is a no-op.
testing/python/layout/test_tilelang_annotate_loop_layout.py (1)
7-7: LGTM! Test file correctly migrated to unified `@tilelang.jit` decorator.
All three kernel functions (`loop_layout_kernel`, `copy_with_layout_kernel`, `replicate_loop_layout_kernel`) have been updated from `@tilelang.lazy_jit` to `@tilelang.jit`. The test assertions remain unchanged, ensuring the loop layout annotation functionality continues to work correctly with the unified decorator.
Also applies to: 54-54, 82-82
examples/lazy_jit/lazyjit.en.ipynb (1)
56-56: LGTM! English notebook correctly migrated to unified `@tilelang.jit` decorator.
All decorator usages have been consistently updated from `@tilelang.lazy_jit` to `@tilelang.jit`, matching the changes in the Chinese version of the notebook.
Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860
tilelang/language/kernel.py (2)
10-10: LGTM! Import for the new exception type.
The import of `JITNoBuilderError` is correctly placed and supports the new runtime guard for enforcing Builder context.
283-291: LGTM! Runtime guard correctly enforces Builder context for T.Kernel().
The guard ensures `T.Kernel()` can only be called within a proper JIT/prim_func context by checking for an active Builder. The lazy import of `Builder` appropriately avoids circular import issues.
The inline error message is clear and actionable. While static analysis (Ruff TRY003) suggests moving long messages into the exception class, the context-specific message here is acceptable since `JITNoBuilderError` is reused across multiple call sites (e.g., `T.const()`) that may need different messages.
tilelang/jit/exceptions.py (1)
1-24: LGTM! Well-structured custom exceptions for JIT error handling.
Both exception classes are cleanly defined with descriptive docstrings that clearly explain their purpose:
- `JITNoBuilderError`: Raised when Builder-dependent operations (like `T.Kernel()`, `T.const()`) are called outside a JIT/prim_func context
- `EagerJITBuildError`: Raised for failures during eager-style kernel construction

The separation allows callers to catch and handle these distinct error conditions appropriately.
tilelang/__init__.py (1)
145-145: LGTM! Public API correctly updated to remove `lazy_jit` export.
Removing `lazy_jit` from the public exports aligns with the PR objective to unify JIT functionality under the single `@tilelang.jit` decorator. Users should migrate from `@tilelang.lazy_jit` to `@tilelang.jit`.
Note: The static analysis hint about unused `noqa: F401` appears to be a false positive; the directive correctly suppresses the "imported but unused" warning for these intentional re-exports.
This is a breaking change for any external code using `tilelang.lazy_jit`. Please ensure the migration guide or release notes document this API removal.
testing/python/language/test_tilelang_language_subtype.py (1)
9-10: LGTM!
The decorator migration from `@tilelang.lazy_jit` to `@tilelang.jit` is consistent across all kernel functions. These kernels use the eager-style DSL pattern (with `T.dynamic`, tensor annotations, and `T.Kernel` context), which the unified decorator will correctly auto-detect.
Also applies to: 18-19, 98-99, 108-109, 119-120, 130-131, 142-143
testing/python/language/test_tilelang_language_lazy_jit.py (1)
9-9: LGTM!
The decorator migrations from `@tilelang.lazy_jit` to `@tilelang.jit` are correct across all kernel functions. All use the eager-style DSL pattern with `T.const`, tensor annotations, and `T.Kernel` context managers.
Also applies to: 48-48, 105-105, 112-112, 119-119, 126-126, 134-134, 141-141, 178-178, 184-184, 189-189, 195-195, 202-202, 208-208
tilelang/language/v2/builder.py (4)
841-845: LGTM with a note on naming.
The error handling correctly enforces that `T.const()` is only usable in eager mode with an active Builder. The internal flag name `builder.lazy_jit` being `True` for eager mode is counterintuitive but is existing behavior maintained for backward compatibility.
857-923: LGTM!
The `TirTemplate` class correctly distinguishes between lazy-style (direct PrimFunc return, no substitution needed) and eager-style (constexpr variable substitution required). The `from_lazy_style` factory and early return in `get_tir` are clean implementations.
965-993: Verify side-effect safety of mode inference probing.
The `_is_lazy_style()` method probes the original function by calling `self.orig_func(*args, **kwargs)` to detect the execution style. This works correctly for the detection logic, but be aware that:
- Any side effects (e.g., prints, logging, external calls) in the decorated function will execute during mode inference.
- The function may be called twice: once for probing and once for actual TIR building.

This is likely acceptable since:
- Lazy-style functions typically just construct and return a PrimFunc without side effects.
- Eager-style functions will raise `JITNoBuilderError` early before reaching user code.

Please confirm this probing behavior is intentional and document it if users might have functions with side effects.
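The double-call concern can be reproduced with a toy probe. This sketch uses hypothetical names (`FakePrimFunc`, `lazy_style`, `probe_then_build`) and only demonstrates the general hazard; whether the real implementation reuses the probed result is not stated in the review.

```python
class FakePrimFunc:
    """Hypothetical stand-in for a TIR PrimFunc."""


side_effects = []


def lazy_style(n):
    # A lazy-style function with an observable side effect.
    side_effects.append(f"called with {n}")
    return FakePrimFunc()


def probe_then_build(func, *args):
    # Call 1: probe the function only to inspect what it returns.
    probed = func(*args)
    is_lazy = isinstance(probed, FakePrimFunc)
    # Call 2: invoke it again to obtain the object that is actually used.
    built = func(*args)
    return is_lazy, built


is_lazy, built = probe_then_build(lazy_style, 8)
call_count = len(side_effects)
print(is_lazy, call_count)
```

One way to avoid the repeated side effect in a design like this would be to keep the object returned by the probe and reuse it for the build step.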
1091-1096: LGTM!
The `get_origin(annot[k]) is None` check correctly prevents calling typing generics (like `Optional[int]`, `Union[...]`, `List[...]`) which are callable but cannot be instantiated. This is a proper fix for handling modern type annotations.
tilelang/jit/__init__.py (4)
192-265: LGTM!
Excellent documentation for `JITImpl`. The docstring clearly explains both execution modes with practical examples, making the API behavior easy to understand for users.
304-316: LGTM!
The mode inference logic is clean and correct:
- Short-circuits for explicitly set modes
- Handles raw `PrimFunc` inputs by defaulting to lazy mode
- Delegates to `LazyJITFunc._is_lazy_style()` for automatic detection

363-398: LGTM!
The `compile` method correctly:
- Infers mode on first compilation
- Validates that `out_idx` is only used in lazy mode (as per PR objectives)
- Propagates mode to the underlying function
- Handles debug output appropriately

467-531: LGTM!
The unified `jit` decorator is well-implemented:
- Supports both `@jit` and `@jit(options=...)` usage patterns
- Initializes mode as "auto" for automatic detection
- Correctly wraps functions via `prim_func(func, lazy_jit=True)` to create `LazyJITFunc`
- Preserves source code and signature for debugging and introspection
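The dual usage pattern noted in the last comment (bare `@jit` and parameterized `@jit(...)`) is a common Python idiom. Below is a minimal sketch of that idiom only; the `options` keyword, the `mode` attribute, and the wrapper body are illustrative assumptions and do not reproduce TileLang's `jit` signature or behavior.

```python
import functools


def jit(func=None, *, options=None):
    """Toy decorator usable both as @jit and as @jit(options=...)."""

    def decorate(f):
        @functools.wraps(f)  # preserve name, docstring, and signature metadata
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        wrapper.options = dict(options or {})
        wrapper.mode = "auto"  # placeholder: mode is decided later
        return wrapper

    if func is None:
        # Called with arguments: @jit(options=...) returns the real decorator.
        return decorate
    # Called bare: @jit receives the function directly.
    return decorate(func)


@jit
def add(a, b):
    return a + b


@jit(options={"debug": True})
def mul(a, b):
    return a * b


print(add(2, 3), mul(2, 3), add.mode, mul.options)
```

Making `options` keyword-only keeps the bare form unambiguous, because the first positional argument can then only be the decorated function.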
…ation and improve error handling for eager mode. Update matmul tests to utilize both Cython and TVM FFI JIT kernels, ensuring consistency in output validation.
…m_intrinsics.py and related test files. Enhance layout.cc to simplify integer set analysis for multi-threaded compilation, ensuring constant integer outputs. Update lower_tile_op.cc to enforce constant integer checks in layout output shapes. Improve error handling in builder.py for JIT context management.
…d proxy.py for improved readability. Add missing import in test_tilelang_kernel_int4_gemm_mma.py and ensure consistent spacing in builder.py.
…ons within test_arith_hard.py for improved clarity and consistency. Update main execution flow to call test_hard_prove instead of tilelang.testing.main.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py (1)
18-159: Add `@simplify_prim_func` decorator to `tl_matmul` for consistency.
The `tl_matmul` function (line 18) is missing the `@simplify_prim_func` decorator that appears on `tl_matmul_weight_only_transform` (line 203) and on `tl_matmul` in all other similar test files (`test_tilelang_kernel_fp8_gemm_mma.py`, `test_tilelang_kernel_bf16_gemm_mma.py`, `test_tilelang_kernel_gemm_mma_intrinsic.py`, etc.). Apply the decorator to maintain consistency and ensure uniform IR simplification across equivalent kernels.
🤖 Fix all issues with AI agents
In @tilelang/jit/__init__.py:
- Line 298: The isinstance check mixes LazyJITFunc and typing.Callable
erroneously; remove Callable from the tuple so the branch only matches
LazyJITFunc (i.e., change isinstance(self.func, (LazyJITFunc, Callable)) to
isinstance(self.func, LazyJITFunc)), and then remove the now-unused Callable
import if it becomes unused so imports stay clean.
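The reason `Callable` makes this check too permissive can be seen in isolation. In the sketch below, `LazyFunc` is a hypothetical stand-in for `LazyJITFunc`; the point is only that any callable object passes an `isinstance(..., Callable)` test.

```python
from typing import Callable


class LazyFunc:
    """Hypothetical stand-in for LazyJITFunc."""

    def __call__(self):
        return "tir"


candidates = (LazyFunc(), len, print, int)

# With Callable in the tuple, every callable matches, not just LazyFunc.
loose = [isinstance(obj, (LazyFunc, Callable)) for obj in candidates]

# Checking against the class alone narrows the branch to the intended type.
strict = [isinstance(obj, LazyFunc) for obj in candidates]

print(loose)
print(strict)
```

With the loose check, a fallback branch that reports an invalid function type can only be reached by non-callable objects.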
🧹 Nitpick comments (3)
testing/python/language/test_tilelang_language_lazy_jit.py (1)
229-230: Commented test runner should be cleaned up.
The main test runner `tilelang.testing.main()` is commented out and replaced with a direct call to `test_jit2_return()`. This appears to be temporary debugging code that should either be reverted or properly justified.
🧹 Suggested cleanup

     if __name__ == "__main__":
    -    # tilelang.testing.main()
    -    test_jit2_return()
    +    tilelang.testing.main()

testing/python/language/test_tilelang_language_frontend_v2.py (2)
294-320: LGTM! In-place tensor modification pattern is correct.
Both functions correctly use the eager-mode pattern with tensor annotations inside the function body (`A: T.Tensor[(2,), T.float32]`), which is the expected TileLang syntax for eager execution with external tensor arguments.
Note: Line 313 contains an extraneous blank line that could be removed for consistency.
324-339: LGTM! Consider consistent use of `.item()` for tensor element access.
The function correctly implements the eager-mode pattern. Line 338 uses `.item()` to extract the scalar value for comparison, which is more explicit than the direct comparison pattern used in `test_var_assign` (lines 220-221). While both approaches work with PyTorch, using `.item()` is more robust and explicit.
For consistency, consider using `.item()` in `test_var_assign` as well:

    assert res[0].item() == 1
    assert res[1].item() == 2
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- `examples/gemm/example_gemm_intrinsics.py`
- `src/layout/layout.cc`
- `src/transform/lower_tile_op.cc`
- `testing/python/arith/test_arith_hard.py`
- `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- `testing/python/language/test_tilelang_language_frontend_v2.py`
- `testing/python/language/test_tilelang_language_lazy_jit.py`
- `testing/python/language/test_tilelang_language_ptr.py`
- `tilelang/jit/__init__.py`
- `tilelang/language/proxy.py`
- `tilelang/language/v2/builder.py`
💤 Files with no reviewable changes (1)
- examples/gemm/example_gemm_intrinsics.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- `testing/python/language/test_tilelang_language_ptr.py`
- `testing/python/language/test_tilelang_language_lazy_jit.py`
- `testing/python/language/test_tilelang_language_frontend_v2.py`
- `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
- `testing/python/language/test_tilelang_language_ptr.py`
- `testing/python/language/test_tilelang_language_frontend_v2.py`
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
- `testing/python/language/test_tilelang_language_lazy_jit.py`
- `testing/python/language/test_tilelang_language_frontend_v2.py`
- `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
🧬 Code graph analysis (6)
- `tilelang/language/proxy.py` (2)
  - `tilelang/jit/exceptions.py` (1): `JITNoBuilderError` (4-13)
  - `tilelang/language/v2/builder.py` (2): `Builder` (167-679), `current` (179-181)
- `testing/python/arith/test_arith_hard.py` (1)
  - `tilelang/language/tir/op.py` (1): `all` (1913-1930)
- `tilelang/jit/__init__.py` (3)
  - `tilelang/jit/adapter/tvm_ffi.py` (2): `func` (206-260), `prim_func` (319-321)
  - `tilelang/language/v2/builder.py` (10): `LazyJITFunc` (935-1062), `PrimFunc` (688-697), `_is_lazy_style` (974-1002), `set_mode` (1048-1050), `prim_func` (184-193), `prim_func` (1087-1128), `get_tir` (922-931), `get_tir` (1036-1043), `parse_args` (1023-1034), `get` (223-224)
  - `tilelang/jit/adapter/wrapper.py` (2): `prim_func` (575-585), `prim_func` (839-849)
- `testing/python/language/test_tilelang_language_lazy_jit.py` (2)
  - `tilelang/jit/__init__.py` (4): `jit` (451-451), `jit` (455-465), `jit` (468-532), `get_tir` (291-303)
  - `tilelang/language/copy_op.py` (1): `copy` (14-116)
- `testing/python/language/test_tilelang_language_frontend_v2.py` (2)
  - `tilelang/jit/__init__.py` (3): `jit` (451-451), `jit` (455-465), `jit` (468-532)
  - `tilelang/language/allocate.py` (3): `alloc_var` (86-86), `alloc_var` (90-90), `alloc_var` (93-147)
- `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py` (1)
  - `tilelang/transform/simplify.py` (1): `simplify_prim_func` (53-58)
🪛 Ruff (0.14.10)
tilelang/language/proxy.py
282-282: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/__init__.py
301-301: Prefer TypeError exception for invalid type
(TRY004)
301-301: Avoid specifying long messages outside the exception class
(TRY003)
324-324: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/v2/builder.py
191-191: Avoid specifying long messages outside the exception class
(TRY003)
717-717: Avoid specifying long messages outside the exception class
(TRY003)
854-854: Avoid specifying long messages outside the exception class
(TRY003)
1021-1021: Avoid specifying long messages outside the exception class
(TRY003)
1062-1062: Avoid specifying long messages outside the exception class
(TRY003)
1087-1087: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (12)
src/layout/layout.cc (1)
126-144: LGTM! Robust multi-level fallback for extent calculation.The simplification attempt with cascading fallbacks (IntImm check → ConstIntBound → input_size_ → Integer(1)) provides good defensive handling for multi-threaded compilation scenarios where int_set analysis may yield unsimplified symbolic expressions. The logic aligns well with the existing fallback path for unbounded cases (lines 107-124).
src/transform/lower_tile_op.cc (1)
69-70: LGTM! Essential validation for constant layout shapes.The ICHECK correctly enforces that layout output shapes must be constant integers before dereferencing on line 71, preventing potential null pointer crashes. The error message clearly indicates the requirement and shows the problematic value.
testing/python/arith/test_arith_hard.py (2)
6-6: LGTM: Import addition is appropriate.The import of
tir_allis necessary for constructing proper TIR expressions. The alias avoids shadowing Python's built-inallfunction.
27-59: LGTM: Correct usage oftir_allfor symbolic expressions.Replacing Python's
andoperator withtir_all(...)is the correct approach for combining symbolic boolean expressions in TIR. The Pythonandoperator would not create the proper TIR expression structure needed for the Analyzer to reason about these conditions. Usingtir_allensures the expressions are constructed as proper TIR intersection nodes that the symbolic reasoning engine can analyze.testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py (1)
9-9: LGTM: Import correctly added. The import of `simplify_prim_func` is properly placed and necessary for the decorator usage on line 203.
tilelang/language/proxy.py (1)
278-282: LGTM! Builder context check added for make_tensor. The runtime guard correctly enforces that `make_tensor` requires an active Builder context. The error message clearly guides users to use the function within `@tilelang.jit` or `@T.prim_func` context.
testing/python/language/test_tilelang_language_ptr.py (1)
44-58: LGTM! Multi-backend validation improves test coverage. The test now validates the matmul kernel across both Cython and TVM FFI execution backends, comparing results against a reference implementation and cross-validating between backends. This strengthens confidence in backend consistency.
testing/python/language/test_tilelang_language_lazy_jit.py (2)
9-9: LGTM! Decorator updated to unified jit API. The change from `@tilelang.lazy_jit` to `@tilelang.jit` aligns with the PR objective of unifying the JIT decorator. The mode will be auto-inferred based on function behavior.
149-149: LGTM! par_compile updated for new API. The `par_compile` call now correctly uses `get_tir()` to obtain TIR functions with dummy tensor arguments for each copy function. This aligns with the new mode-based JIT implementation.
testing/python/language/test_tilelang_language_frontend_v2.py (3)
206-222: LGTM! Eager-mode pattern correctly implemented. The function correctly uses the unified `@tilelang.jit` decorator with eager-style execution: it allocates the tensor internally, performs operations, and returns the result.
258-292: LGTM! Both nested test functions follow the eager-mode pattern. Both `stepped_serial` and `stepped_serial_neg` correctly use the unified decorator, allocate tensors internally, and return results for direct verification.
457-478: LGTM! Conditional logic correctly implemented for eager mode. The `probe` function correctly uses the unified decorator with eager-mode semantics. The tensor annotation inside the function body and the constexpr conditional logic based on the `tmp` parameter are both properly implemented.
…nctions in test_tilelang_issue_1549.py and test_tilelang_issue_1601.py to ensure CUDA requirements are explicitly stated for relevant tests.
…setting max value from ist.max(). Update test_valid_loop in test_tilelang_fragment_loop_checker.py to disable certain tests and ensure proper execution flow. Enhance error handling in builder.py to correctly manage PrimFunc return types.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @testing/python/analysis/test_tilelang_fragment_loop_checker.py:
- Around line 145-147: Two test functions valid_loop_not_use_loop_var() and
valid_loop_serial() are defined but never invoked, leaving dead code and reduced
coverage; either re-enable them by adding calls (uncomment the calls currently
commented out near valid_loop_not_frag()) and fix any failures they expose, or
delete the unused function definitions and add a brief TODO/issue reference
comment explaining why they were removed; ensure that the final change either
calls valid_loop_not_use_loop_var() and valid_loop_serial() so all three
valid-loop scenarios run, or removes both functions and updates test coverage
expectations accordingly.
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)
1008-1026: Validate mode before building TIR template. The method expects `mode` to be either "lazy" or "eager", but there's no validation that it's not "auto". Consider adding an assertion at the start: `assert self.mode in ("lazy", "eager"), f"Mode must be set before building TIR template, got: {self.mode}"`. This would make the requirement explicit and provide a clearer error if the caller forgets to set the mode.
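A minimal sketch of that precondition, assuming a hypothetical `TemplateBuilder` class in place of the real one. It raises `ValueError` rather than using `assert`, because assertions are stripped when Python runs with `-O`; the review's suggested `assert` would work equally well in normal runs.

```python
class TemplateBuilder:
    """Hypothetical stand-in; only the mode check mirrors the review suggestion."""

    def __init__(self, mode="auto"):
        self.mode = mode

    def build_tir_template(self):
        # Refuse to build while the mode is still unresolved ("auto").
        if self.mode not in ("lazy", "eager"):
            raise ValueError(
                f"Mode must be set before building TIR template, got: {self.mode}"
            )
        # Placeholder result; the real method would construct a TIR template.
        return f"template[{self.mode}]"


def try_build(mode):
    """Return the template, or the error text if the mode is invalid."""
    try:
        return TemplateBuilder(mode).build_tir_template()
    except ValueError as exc:
        return str(exc)
```

Calling `try_build("auto")` returns the error message, while `try_build("eager")` returns the placeholder template.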
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
testing/python/analysis/test_tilelang_fragment_loop_checker.py
tilelang/language/v2/builder.py
🧰 Additional context used
🧬 Code graph analysis (2)
testing/python/analysis/test_tilelang_fragment_loop_checker.py (1)
tilelang/env.py (1)
disable_cache(297-298)
tilelang/language/v2/builder.py (1)
tilelang/jit/exceptions.py (2)
JITNoBuilderError(4-13)
EagerJITBuildError(16-24)
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py
191-191: Avoid specifying long messages outside the exception class
(TRY003)
717-717: Avoid specifying long messages outside the exception class
(TRY003)
854-854: Avoid specifying long messages outside the exception class
(TRY003)
999-999: Consider moving this statement to an else block
(TRY300)
1025-1025: Avoid specifying long messages outside the exception class
(TRY003)
1066-1066: Avoid specifying long messages outside the exception class
(TRY003)
1091-1091: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (13)
testing/python/analysis/test_tilelang_fragment_loop_checker.py (1)
150-153: Verify if main block changes are intentional. The main block now bypasses the standard test runner (`tilelang.testing.main()`) and directly calls `test_valid_loop()`. This pattern is unusual for committed test code and suggests debugging/manual testing.
Please clarify:
- Is this change intentional for production, or was it left in during debugging?
- If intentional, why is the standard test runner bypassed?
- Should `tilelang.disable_cache()` be in the main block, or should it be a test fixture?
If this was debugging code, consider reverting to the standard pattern:
♻️ Revert to standard test execution pattern
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    tilelang.disable_cache()
-    test_valid_loop()
+    tilelang.testing.main()
tilelang/language/v2/builder.py (12)
21-21: LGTM! The new imports support the lazy/eager mode detection logic and builder context validation.
Also applies to: 32-32
186-193: Good defensive cleanup pattern. The try/finally ensures `thread_local_storage.builder` is always cleared, preventing builder leakage if an exception occurs during function body execution.
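The cleanup pattern described above can be sketched in isolation. The names below (`_storage`, `run_with_builder`, `current_builder`) are hypothetical and a plain `object()` stands in for the real Builder; only the try/finally shape is taken from the review.

```python
import threading

# Per-thread slot holding the active builder (hypothetical stand-in).
_storage = threading.local()


def current_builder():
    """Return the active builder for this thread, or None if there is none."""
    return getattr(_storage, "builder", None)


def run_with_builder(body):
    """Install a builder, run `body`, and always clear the slot afterwards."""
    _storage.builder = object()
    try:
        return body()
    finally:
        # Runs on success and on exception, so no stale builder leaks out.
        _storage.builder = None


def failing_body():
    raise RuntimeError("error while tracing the function body")


try:
    run_with_builder(failing_body)
except RuntimeError:
    pass
leaked = current_builder()
```

After the failed call `leaked` is `None`; without the `finally`, the next caller on the same thread would see the stale builder.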
715-718: LGTM! Enforcing builder existence for macros is correct, as they require an active JIT context to generate IR.
835-864: LGTM! The updated docstring clearly explains eager mode usage, and the explicit exception handling is more appropriate than assertions for user-facing error cases.
867-907: Lazy-style templates bypass matcher safely. The `is_lazy_style` field distinguishes between eager and lazy templates. For lazy-style templates created via `from_lazy_style()`, the `matcher` is `None`. This is safe because `get_tir()` (line 923) returns early for lazy-style templates, so `_parse_phase2_key()` (line 909) is never invoked.
922-931: LGTM! The early return for lazy-style templates correctly bypasses shape substitution, as lazy functions return PrimFunc directly without constexpr variables.
974-1006: Mode inference through controlled exception handling. The approach is clever: calling the function and inspecting the outcome (PrimFunc vs. JITNoBuilderError) distinguishes lazy from eager style. The caching at lines 996-998 for lazy-style avoids redundant PrimFunc construction.
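A rough sketch of trial-call mode inference, under simplifying assumptions: `NoBuilderError` and `FakePrimFunc` are hypothetical stand-ins for `JITNoBuilderError` and the real PrimFunc type, and result caching is omitted.

```python
class NoBuilderError(Exception):
    """Hypothetical stand-in for the error raised when no Builder is active."""


class FakePrimFunc:
    """Hypothetical stand-in for a compiled TIR function object."""


def infer_mode(func, *args, **kwargs):
    """Trial-call `func` and classify it as "lazy" or "eager"."""
    try:
        result = func(*args, **kwargs)
    except NoBuilderError:
        # Eager-only features refuse to run outside a Builder context.
        return "eager"
    # A lazy-style function builds and returns the PrimFunc itself.
    return "lazy" if isinstance(result, FakePrimFunc) else "eager"


def lazy_kernel(n):
    return FakePrimFunc()


def eager_kernel(n):
    raise NoBuilderError("requires an active Builder")
```

Because the probe actually executes the user function, any side effects before the return or raise happen during inference too.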
1027-1038: LGTM! The caching logic preserves performance by avoiding redundant TIR construction. The comment at line 1034 documents the contract with JITImpl to set mode before calling this method.
1040-1047: LGTM! The fallback at lines 1042-1046 handles the case where `parse_args()` returns early without caching (line 1030), ensuring the TIR template is always built when needed.
1052-1054: LGTM! The setter allows JITImpl to configure the execution mode after inference, correctly typed to exclude "auto".
1056-1066: LGTM! Proxying function attributes is necessary for compatibility with introspection tools and the autotuner. The implementation correctly handles `__wrapped__` as a special case and delegates other introspection attributes to the original function.
1091-1131: Critical fix for type annotation handling. The `get_origin()` check at line 1108 prevents attempting to instantiate typing generics (like `Optional[int]`, `Union[...]`) which are callable but should not be called. This correctly distinguishes between factory functions and typing constructs.
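The `get_origin()` distinction can be shown with the standard library alone. `resolve_annotation` below is a hypothetical helper, not the function from builder.py; it only demonstrates that `typing.get_origin` returns `None` for plain callables and a non-`None` origin for subscripted typing constructs.

```python
from typing import List, Optional, get_origin


def resolve_annotation(annot):
    """Call plain factory callables; leave typing generics untouched."""
    if callable(annot) and get_origin(annot) is None:
        return annot()
    return annot


def make_default():
    """Example factory: a zero-argument callable used as an annotation."""
    return "built by factory"


from_factory = resolve_annotation(make_default)   # factory is invoked
from_optional = resolve_annotation(Optional[int])  # returned unchanged
from_list = resolve_annotation(List[int])          # returned unchanged
```

Without the origin check, `List[int]` would be called and silently replaced by an empty list, which is the kind of subtle bug the review describes.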
@regression-perf
…hecker.py to call tilelang.testing.main. Modify JIT implementation to handle callable functions more generically. Enhance TirTemplate class in builder.py to return an empty tuple if matcher is None, improving robustness.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)
1093-1093: Minor: Consider explicit Optional type annotation. The static analysis hint (RUF013) notes that PEP 484 prohibits implicit `Optional`. Line 1093 has:
def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> ...
For strict PEP 484 compliance, this should be:
def prim_func(func: Optional[Callable[_P, _T]] = None, *, lazy_jit: bool = False) -> ...
However, this is a very minor stylistic issue and doesn't affect runtime behavior. Python's type checkers are generally lenient about this pattern in decorators.
📝 Optional fix for PEP 484 compliance
+from typing import Optional
+
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Optional[Callable[_P, _T]] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
testing/python/analysis/test_tilelang_fragment_loop_checker.py
tilelang/jit/__init__.py
tilelang/language/v2/builder.py
🚧 Files skipped from review as they are similar to previous changes (1)
- testing/python/analysis/test_tilelang_fragment_loop_checker.py
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/v2/builder.py (2)
tilelang/jit/__init__.py (3)
jit(451-451)
jit(455-465)
jit(468-532)
tilelang/jit/exceptions.py (2)
JITNoBuilderError(4-13)
EagerJITBuildError(16-24)
tilelang/jit/__init__.py (1)
tilelang/language/v2/builder.py (10)
LazyJITFunc(937-1068)
PrimFunc(688-697)
_is_lazy_style(976-1008)
set_mode(1054-1056)
prim_func(184-193)
prim_func(1093-1134)
get_tir(924-933)
get_tir(1042-1049)
parse_args(1029-1040)
get(223-224)
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py
191-191: Avoid specifying long messages outside the exception class
(TRY003)
717-717: Avoid specifying long messages outside the exception class
(TRY003)
854-854: Avoid specifying long messages outside the exception class
(TRY003)
1001-1001: Consider moving this statement to an else block
(TRY300)
1027-1027: Avoid specifying long messages outside the exception class
(TRY003)
1068-1068: Avoid specifying long messages outside the exception class
(TRY003)
1093-1093: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/jit/__init__.py
324-324: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (14)
tilelang/jit/__init__.py (5)
277-279: LGTM! Mode attribute and func type narrowing are well-structured. The addition of the `mode` attribute and narrowing `func` to `LazyJITFunc[_KP, _T]` properly support the new unified JIT workflow. The type narrowing is sound given that the decorator always creates a LazyJITFunc via `prim_func(func, lazy_jit=True)` at line 520.
373-399: Approve compile method changes. The refactoring to use `prim_func` (line 374) instead of `func` and `prim_func.script()` (line 397) instead of `func.script()` is correct and aligns with the new mode-based workflow.
426-444: Approve call implementation with minor observation. The early mode inference (lines 429-431) and conditional execution based on mode (lines 439-444) are well-structured. The eager mode execution path correctly invokes the kernel immediately with `kernel(*kernel_args.values())`.
One minor observation: At line 442, you're unpacking `kernel_args.values()` with `*`. Ensure that the ordering of `kernel_args` is deterministic and matches the expected parameter order for the kernel. Based on the code in `builder.py` where `tensor_args` is built from a dict comprehension, dictionary order is preserved (Python 3.7+), so this should be fine.
518-530: Approve unified jit decorator implementation. The decorator properly creates a `LazyJITFunc` with `lazy_jit=True` (line 520) and initializes the mode to "auto" (line 519), which will be inferred on first call. The removal of the legacy `lazy_jit` parameter and the unified approach is clean.
295-325: The mode inference and initialization logic is sound and handles edge cases appropriately. The three points are verified as non-blocking:
- Defaulting to "lazy" mode (line 315-316): Safe and correct. All decorated functions pass through `prim_func(..., lazy_jit=True)` during initialization, so this default only applies to edge cases where a non-LazyJITFunc is used directly, a rare scenario where "lazy" is the appropriate fallback.
- Error message for out_idx in eager mode (line 323-324): The message is clear and helpful. Since it appears only once in the file, extracting to a constant is optional rather than necessary.
- Mode state mutation (lines 321, 430): Intentional and safe. Mode is inferred once on the first call (when `self.mode == "auto"`) and cached for subsequent calls. All operations that depend on mode occur after initialization, making the caching behavior correct and efficient.
tilelang/language/v2/builder.py (9)
186-193: Excellent: Builder cleanup in finally block. Wrapping the cleanup in a `finally` block (line 192-193) ensures `thread_local_storage.builder` is always cleared even when exceptions occur. This prevents stale Builder references from leaking across prim_func invocations.
714-721: Approve runtime guard for Macro usage. The check at lines 716-717 correctly enforces that macros can only be used within a JIT context by raising `JITNoBuilderError` when no active Builder exists. This aligns with the PR's goal of adding Builder existence checks.
835-863: Approve eager-mode enforcement in const(). The updated docstring (lines 836-849) clearly documents that `T.const()` is for eager mode only. The runtime check (lines 853-854) properly raises `JITNoBuilderError` when called outside of eager JIT context.
The commented-out assertions (lines 851-852) should ideally be removed rather than left as comments, but since they're immediately followed by the replacement code, this is acceptable for clarity during development.
904-907: LGTM: TirTemplate.from_lazy_style factory method. The new factory method correctly creates a lazy-style template with `is_lazy_style=True` and no matcher, which aligns with lazy mode where PrimFunc is used directly without substitution.
976-1008: Approve mode inference logic with strong exception-based detection. The `_is_lazy_style` method (lines 976-1008) implements clever mode detection:
- Lines 994-1000: If the function returns a PrimFunc, it's lazy style. The early caching at lines 998-999 is a nice optimization.
- Lines 1002-1008: Catching `JITNoBuilderError` and `EagerJITBuildError` to detect eager mode is elegant. The comment clearly explains that eager-only features (like `T.const()` or `T.Kernel()`) raise these exceptions when no Builder exists during the trial call.
- Line 1001: The static analysis hint (TRY300) suggests moving `return False` to an `else` block. However, the current structure is clear and the early `return True` makes the logic easier to follow. The hint can be safely ignored.
1010-1028: LGTM: _build_tir_template implements mode-based template construction. The method correctly branches on mode:
- Lazy mode (lines 1012-1014): Calls the function directly to get the PrimFunc.
- Eager mode (lines 1015-1025): Traces through Builder to construct TIR, similar to the original prim_func flow.
The error message at line 1027 clearly indicates invalid mode values. Per the static analysis hint (TRY003), this could be a custom exception, but for an internal method, a ValueError with a clear message is acceptable.
1029-1056: Approve parse_args and set_mode methods.
- parse_args (lines 1029-1040): Correctly handles both lazy and eager styles, with proper caching and template building. The comment at line 1036 helpfully notes that mode should be set by JITImpl before calling this method.
- set_mode (lines 1054-1056): Simple setter for internal use. The docstring clarifies this is internal-only.
1058-1069: Excellent: Attribute proxying for autotuner compatibility. The `__getattr__` implementation (lines 1063-1068) with `_PROXIED_ATTRS` (line 1061) elegantly proxies function attributes like `__closure__`, `__code__`, etc., to the original function. The comment (lines 1058-1060) explains this is needed for autotuner and inspection, which is valuable context.
The special handling of `__wrapped__` (lines 1065-1066) follows Python conventions for wrapped functions.
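The proxying idea can be sketched with a small wrapper class. `FuncWrapper` and the exact contents of `_PROXIED_ATTRS` here are assumptions for illustration; the review only names `__closure__`, `__code__`, and `__wrapped__`.

```python
class FuncWrapper:
    """Hypothetical wrapper that forwards introspection attributes."""

    # Assumed allow-list; the real set in builder.py may differ.
    _PROXIED_ATTRS = frozenset(
        {"__closure__", "__code__", "__defaults__", "__globals__"}
    )

    def __init__(self, func):
        self.orig_func = func

    def __getattr__(self, name):
        # Only reached when normal attribute lookup has already failed.
        if name == "__wrapped__":
            return self.orig_func
        if name in self._PROXIED_ATTRS:
            return getattr(self.orig_func, name)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )


def add(a, b=2):
    return a + b


wrapped = FuncWrapper(add)
```

With this in place, `wrapped.__code__` and `wrapped.__defaults__` resolve to those of `add`, while unknown attributes still raise `AttributeError`, so `hasattr` keeps working as expected.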
1107-1111: Approve callable annotation handling with get_origin check. The updated logic (lines 1107-1111) correctly distinguishes between:
- Callable factory functions (should be called)
- Typing generics like `Optional[int]`, `Union[...]`, `List[...]` (should NOT be called)
The `get_origin(annot[k]) is None` check ensures typing constructs are not mistakenly invoked. This is a solid fix for a subtle bug.
Performance Regression Test Report
Triggered by: @LeiWang1999
Results
Artifacts
…ss the codebase, updating related function signatures and logic to support eager JIT compilation mode. This change enhances the clarity and functionality of JIT operations.
…escription and remove outdated comment regarding implicit TIR construction.
changes applied, pls take a look @SiriusNEO
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @tilelang/jit/__init__.py:
- Around line 318-324: initialize_jit_mode currently calls
self.func.set_mode(self.mode) unconditionally which will raise AttributeError
when self.func is a PrimFunc; check the actual type before calling set_mode
(e.g., use isinstance(self.func, JITFunc) or mirror the guard used in
_infer_jit_mode) and only call self.func.set_mode(self.mode) when self.func is a
JITFunc; apply the same guarded check to the other unconditional call to
self.func.set_mode at the other location (around the later call referenced at
line 430) so PrimFunc paths do not attempt to call set_mode.
- Around line 438-443: The eager-mode return path assumes kernel_args is a dict
but parse_args can return None; update the eager branch in __init__.py to handle
kernel_args being None by calling kernel() with no arguments when kernel_args is
None (or using an empty list), e.g. compute args = list(kernel_args.values()) if
kernel_args else [] and then return kernel(*args); reference symbols:
kernel_args, parse_args, kernel, self.mode.
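Both fixes requested above can be sketched with stand-in classes. `ModeAwareFunc`, `PlainFunc`, `apply_mode`, and `run_eager` are hypothetical names; the capability check uses `hasattr`, though an `isinstance` check against the real JIT function class would serve the same purpose.

```python
class ModeAwareFunc:
    """Hypothetical stand-in for a JIT function object with set_mode()."""

    def __init__(self):
        self.mode = None

    def set_mode(self, mode):
        self.mode = mode


class PlainFunc:
    """Hypothetical stand-in for an object without set_mode()."""


def apply_mode(func, mode):
    """Set the mode only when the object supports it; report whether it did."""
    if hasattr(func, "set_mode"):
        func.set_mode(mode)
        return True
    return False


def run_eager(kernel, kernel_args):
    """Call the kernel, treating a missing or empty argument dict as no args."""
    args = list(kernel_args.values()) if kernel_args else []
    return kernel(*args)


aware = ModeAwareFunc()
applied = apply_mode(aware, "eager")
skipped = apply_mode(PlainFunc(), "eager")
```

`run_eager` relies on dict insertion order to line up values with kernel parameters, so the dict should be built in parameter order.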
🧹 Nitpick comments (4)
tilelang/language/v2/builder.py (3)
866-871: Remove commented-out code. The commented-out assertions on lines 867-868 are dead code that should be removed since the proper check is already implemented on lines 869-870.
♻️ Proposed fix
 builder = Builder.current()
-# assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
-# assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
 if builder is None or not builder.eager_jit:
     raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)")
1058-1065: Consider clarifying the cache check redundancy. After `parse_args` populates `p1_cache` (lines 1051-1054), the check at line 1060 `if p1_key not in self.p1_cache` seems redundant. The comment mentions "legacy gemm" - if this is for backward compatibility, consider making it more explicit or removing if the legacy path is no longer needed.
1109-1109: Type annotation uses implicit Optional. Per PEP 484, `func: Callable[_P, _T] = None` should be `func: Callable[_P, _T] | None = None` for type checker compatibility. This is a minor typing issue that doesn't affect runtime behavior.
♻️ Proposed fix
-def prim_func(func: Callable[_P, _T] = None, *, eager_jit: bool = False) -> PrimFunc[_P, _T] | JITFunc[_P, _T]:
+def prim_func(func: Callable[_P, _T] | None = None, *, eager_jit: bool = False) -> PrimFunc[_P, _T] | JITFunc[_P, _T]:
tilelang/jit/__init__.py (1)
426-431: Mode inference logic is duplicated. Lines 428-430 duplicate the mode inference from `initialize_jit_mode` (lines 318-321). Consider extracting to a shared helper or calling `initialize_jit_mode` here to ensure consistent behavior, especially for the `out_idx` validation at line 322-323 which is skipped in `__call__`.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
tilelang/jit/__init__.py
tilelang/language/v2/__init__.py
tilelang/language/v2/builder.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/v2/__init__.py (1)
tilelang/language/v2/builder.py (7)
prim_func(189-198)prim_func(1109-1150)macro(201-228)macro(746-782)PrimFunc(704-713)JITFunc(953-1084)const(851-879)
tilelang/jit/__init__.py (7)
tilelang/language/v2/builder.py (5)
PrimFunc(704-713)prim_func(189-198)prim_func(1109-1150)JITFunc(953-1084)get(230-231)tilelang/jit/adapter/wrapper.py (2)
prim_func(575-585)prim_func(839-849)tilelang/jit/adapter/cython/adapter.py (1)
prim_func(356-358)tilelang/jit/adapter/tvm_ffi.py (2)
prim_func(319-321)func(206-260)tilelang/jit/adapter/nvrtc/adapter.py (1)
prim_func(267-269)src/transform/simplify.cc (2)
func(222-279)func(223-225)tilelang/jit/kernel.py (1)
out_idx(609-610)
tilelang/language/v2/builder.py (2)
tilelang/jit/__init__.py (3)
jit(450-450)jit(454-464)jit(467-531)tilelang/jit/exceptions.py (2)
JITNoBuilderError(4-13)EagerJITBuildError(16-24)
🪛 Ruff (0.14.10)
tilelang/language/v2/__init__.py
1-1: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/jit/__init__.py
323-323: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/v2/builder.py
196-196: Avoid specifying long messages outside the exception class
(TRY003)
733-733: Avoid specifying long messages outside the exception class
(TRY003)
870-870: Avoid specifying long messages outside the exception class
(TRY003)
1017-1017: Consider moving this statement to an else block
(TRY300)
1043-1043: Avoid specifying long messages outside the exception class
(TRY003)
1084-1084: Avoid specifying long messages outside the exception class
(TRY003)
1109-1109: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (11)
tilelang/language/v2/builder.py (6)
21-21: LGTM: Import additions for mode-based JIT support. The imports of `Literal`, `get_origin` from typing and exception classes from `tilelang.jit.exceptions` are appropriate for the new unified JIT behavior.
Also applies to: 32-32
188-198: Good: Proper cleanup with try/finally pattern. The cleanup of `thread_local_storage.builder` in the `finally` block ensures the builder is always cleared, preventing stale state. This is a solid improvement over the previous implementation.
730-737: Good: Runtime guard for macro invocation context. Raising `JITNoBuilderError` when no Builder is active ensures macros are only used within a proper JIT context, enabling the mode inference mechanism.
882-950: LGTM: TirTemplate properly handles lazy vs eager styles. The `is_lazy_style` field and `from_lazy_style` factory method cleanly separate the two code paths. The early returns in `_parse_phase2_key` and `get_tir` for lazy style are correct.
1010-1024: Mode inference via exception catching is a reasonable pattern, but document the side-effect risk. Calling `self.orig_func(*args, **kwargs)` to determine if the function returns a PrimFunc works, but note that:
- If the function has side effects before returning, those will execute during mode inference
- The result is cached in `p1_cache` on success (line 1015), so subsequent calls won't re-execute
This pattern is acceptable given the design constraints, but consider adding a comment noting that the original function should be side-effect-free up to the point of returning or raising.
1122-1127: Good: Properly handles typing generics. The `get_origin(annot[k]) is None` check correctly prevents calling typing constructs like `Optional[int]` or `Union[...]` which are callable but cannot be instantiated as factories.
tilelang/language/v2/__init__.py (1)
1-2: LGTM: Export updated from LazyJITFunc to JITFunc. The public API correctly exports the new `JITFunc` class. The `noqa: F401` directive suppresses unused import warnings for these re-exports, which is appropriate. The Ruff warning about "unused noqa" suggests F401 may not be enabled in the project's Ruff configuration.
tilelang/jit/__init__.py (4)
27-27: LGTM: Import updated for unified JITFunc. The import correctly brings in `JITFunc` to replace the previous `LazyJITFunc`.
192-264: Excellent: Comprehensive documentation for the unified JIT interface. The docstring clearly explains the two execution modes with concrete examples, making it easy for users to understand the behavior difference between lazy and eager styles.
517-529: LGTM: Decorator correctly initializes unified JIT with auto mode. The decorator always creates a `JITFunc` via `prim_func(func, eager_jit=True)` and sets `mode="auto"` for later inference. This unified approach correctly defers mode determination to runtime.
449-464: LGTM: Overload signatures for decorator patterns. The overloads correctly handle both `@jit` (no arguments) and `@jit(out_idx=...)` (with keyword arguments) usage patterns.
…ate imports across multiple files to replace references to the deprecated v2 module with the new eager module, ensuring compatibility with the latest JIT compilation features. Remove obsolete v2 module files to streamline the codebase.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/jit/__init__.py (2)
156-190: `par_compile` has a syntax/runtime blocker (`results = [... for _ in futures]`) and dead `return`
If this is literal code, it won’t run. Also, `return results` after the `with` is unreachable.
Proposed fix
@@
- results = [... for _ in futures]
+ results: list[JITKernel[_KP, _T] | None] = [None] * len(futures)
@@
- return results
- return results
+ return results
47-107: `compile()` type contract is inconsistent (`func: PrimFunc = None` but immediately asserted)
Either make `func` required, or raise a clear `TypeError` when `None` is passed (the current signature suggests `None` is acceptable).
Proposed fix
@@
-def compile(
- func: PrimFunc[_KP, _T] = None,
+def compile(
+ func: PrimFunc[_KP, _T],
@@
-) -> JITKernel[_KP, _T]:
+) -> JITKernel[_KP, _T]:
@@
- assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
+ if not isinstance(func, PrimFunc):
+ raise TypeError(f"target function must be a PrimFunc but got {type(func)}")
tilelang/language/eager/builder.py (1)
981-991: Root cause of eager kwargs bug: `_parse_phase1_key` mutates kwargs via `pop()`
`_parse_phase1_key` currently `pop()`s tensor kwargs from the passed dict. Since `JITImpl.__call__` passes its own `kwargs` into `parse_args`, this can delete user-supplied tensor kwargs before later compilation/execution steps.
Fix by copying kwargs at the start of `_parse_phase1_key` (or at least inside `parse_args`) so caller-owned dicts aren’t mutated.
Proposed fix
@@ def _parse_phase1_key(self, *args, **kwargs):
- kwargs.update({k: v for k, v in zip(self.arg_names, args)})
+ kwargs = dict(kwargs)  # do not mutate caller kwargs
+ kwargs.update({k: v for k, v in zip(self.arg_names, args)})
  tensor_args = {}
  for k in self.tensor_args:
  if k in kwargs:
  tensor_args[k] = kwargs.pop(k)
  elif k in self.tensor_args_defaults:
  tensor_args[k] = self.tensor_args_defaults[k]
  p1_key = tuple(sorted(kwargs.items()))
  return p1_key, tensor_args, kwargs
Also applies to: 1045-1057
🤖 Fix all issues with AI agents
In @testing/python/language/test_tilelang_language_frontend_v2.py:
- Around line 208-223: The assertions in test_var_assign compare tensor elements
to Python ints, which can yield ambiguous truth values; change the checks to
extract Python scalars or compare tensors explicitly: call .item() on res[0] and
res[1] (e.g., res[0].item() == 1, res[1].item() == 2) or use torch.equal/tensor
comparison APIs to assert equality of tensors returned by test_var_assign.
- Around line 260-288: In both JIT tests, the tensor A is T.int32 but the code
writes float literals (1.0, 2.0) into it; change the assignments in
stepped_serial to use integer literals (1 and 2) so writes match A's dtype and
avoid backend-dependent casting, and verify stepped_serial_neg already writes
integer i (no change needed) — update A[i] = 1.0 -> A[i] = 1 and A[i] = 2.0 ->
A[i] = 2 inside the stepped_serial function.
In @tilelang/__init__.py:
- Line 145: Remove the ineffective "# noqa: F401" markers on the re-export
import lines for jit, JITKernel, compile, and par_compile in tilelang.__init__;
either delete those trailing noqa comments on the import statements (and let
lint report/unify behavior) or instead enable the F401/unused-import rule in
your Ruff/linters if you intentionally want to suppress unused-import warnings;
apply the same change to the other identical import line referenced in the
comment.
- Line 145: The public export of lazy_jit was removed but undocumented; either
add a release note directing users to replace lazy_jit with
@tilelang.jit(mode='lazy') or restore a deprecated alias: re-export lazy_jit in
tilelang.__init__ that points to the existing jit wrapper (e.g., def
lazy_jit(*args, **kwargs): warnings.warn("lazy_jit is deprecated; use
tilelang.jit(mode='lazy')", DeprecationWarning, stacklevel=2); return jit(*args,
mode='lazy', **kwargs)), and ensure the alias appears alongside jit, JITKernel,
compile, par_compile so imports continue to work while emitting the deprecation
warning.
In @tilelang/jit/__init__.py:
- Around line 318-325: The jit decorator claims to accept PrimFunc but always
funnels to prim_func(func, eager_jit=True) which expects a Python callable;
update the wrapper for jit to detect PrimFunc inputs (e.g., isinstance(func,
PrimFunc) or an equivalent check) and raise a clear, targeted error telling
users to use tilelang.jit.compile(...) (or the appropriate compile path) instead
of passing a PrimFunc, and apply the same check in the other jit entry path
referenced around the 467-529 region so PrimFunc is rejected consistently rather
than hitting inspect.signature in prim_func.
- Around line 318-325: The eager-path bug comes from JITFunc._parse_phase1_key
popping tensor kwargs out of the dict passed from JITImpl.__call__, which
mutates kwargs before compile/execute; fix by making defensive shallow copies in
JITImpl.__call__: create kw_for_parse = dict(kwargs) and pass that to
self.func.parse_args(...), then create kw_for_compile = dict(kwargs) (or reuse
the original kwargs untouched) when calling self.compile(*args,
**kw_for_compile) / self.execute(...). Also ensure eager arg ordering is stable
by building the eager positional list from args plus a deterministic walk of
parameter names (e.g., for name in self.func.param_names_in_order: if name in
kw_for_compile: append(kw_for_compile[name])) so keyword tensors are preserved
and ordered consistently.
- Around line 318-325: initialize_jit_mode currently unconditionally calls
self.func.set_mode(...) which will fail when self.func is a PrimFunc; guard the
call by checking the callable's capability (e.g., if hasattr(self.func,
"set_mode") or isinstance(self.func, the expected wrapper) ) and only call
set_mode when present, otherwise skip setting mode on PrimFunc (or call the
appropriate PrimFunc API if one exists). Also ensure the rest of
initialize_jit_mode behavior (mode inference via _infer_jit_mode, out_idx check,
and returning self.mode) remains unchanged.
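The kwargs-mutation item above can be illustrated in isolation. This is a sketch under assumptions: `parse_tensor_kwargs`, `call_buggy`, and `call_fixed` are hypothetical names, not TileLang functions; they only model a parser that pops from the caller's dict and the defensive-copy fix with deterministic ordering.

```python
def parse_tensor_kwargs(kwargs, tensor_names):
    """Hypothetical parser that, like the reported bug, pops entries in place."""
    return {name: kwargs.pop(name) for name in tensor_names if name in kwargs}


def call_buggy(kwargs, tensor_names):
    parse_tensor_kwargs(kwargs, tensor_names)  # mutates the caller's dict
    return kwargs  # the tensors are gone by the time we would execute


def call_fixed(args, kwargs, param_names):
    kw_for_parse = dict(kwargs)  # the shallow copy absorbs the pops
    parse_tensor_kwargs(kw_for_parse, param_names)
    # Stable positional list: args first, then keyword values
    # in declared parameter order rather than call-site order.
    ordered = list(args)
    ordered += [kwargs[name] for name in param_names if name in kwargs]
    return ordered


buggy_view = call_buggy({"B": "tB", "A": "tA"}, ["A", "B"])
fixed_view = call_fixed(("x",), {"B": "tB", "A": "tA"}, ["A", "B"])
```

In the buggy path the dict ends up empty; in the fixed path the original kwargs survive and are emitted in parameter order regardless of how the caller ordered them.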
In @tilelang/language/__init__.py:
- Around line 14-16: The wildcard re-export from .eager (from .eager import *)
is triggering the wrong noqa and risking API drift; either change the noqa
comment to suppress the wildcard-import error (use # noqa: F403 or # noqa:
F401,F403) or replace the wildcard with explicit names and a curated __all__
list (import the public symbols from .eager and .tir.ir and set __all__ =
['SymbolA', 'SymbolB', 'Layout', 'Fragment', ...]) so the public surface is
pinned and Ruff no longer flags unused/wildcard re-exports.
In @tilelang/language/eager/builder.py:
- Around line 731-737: The guard that raises JITNoBuilderError when
Builder.current() is None is correct, but the error text is misleading; update
the raised message (the string passed to JITNoBuilderError in the block that
checks Builder.current()) to mention the supported builder contexts (for
example: "T.macro can only be used inside a Builder context (e.g. @tilelang.jit
or @T.prim_func)") so users in a prim_func context aren’t misled.
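One way to address both the misleading message in the last item and the TRY003 notes is to build the message inside the exception class. The sketch below uses a hypothetical `NoBuilderError` and a module-level placeholder instead of `Builder.current()`; the real `JITNoBuilderError` constructor may look different.

```python
class NoBuilderError(RuntimeError):
    """Hypothetical stand-in for JITNoBuilderError; real signature may differ."""

    def __init__(self, api_name: str):
        self.api_name = api_name
        # The message lives here, so call sites stay short (TRY003-friendly).
        super().__init__(
            f"{api_name} can only be used inside a Builder context "
            "(e.g. @tilelang.jit or @T.prim_func)"
        )


_current_builder = None  # placeholder for whatever Builder.current() returns


def macro_guard():
    """Raise when no builder is active, naming the supported contexts."""
    if _current_builder is None:
        raise NoBuilderError("T.macro")


try:
    macro_guard()
except NoBuilderError as err:
    message = str(err)
```

Call sites then pass only the API name, and every guard reports the same set of supported contexts.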
🧹 Nitpick comments (5)
tilelang/language/proxy.py (1)
278-283: Runtime guard enforces Builder context requirement.
The guard ensures `make_tensor` is only called within a JIT or prim_func context by checking for an active Builder. The lazy import of Builder avoids circular dependencies.
Note: The static analysis tool flags TRY003 (long message outside exception class). While the descriptive error message is helpful, consider defining it within the `JITNoBuilderError` class if this pattern becomes widespread.
tilelang/language/kernel.py (1)
283-291: Runtime guard enforces Builder context requirement.
The guard ensures `T.Kernel()` is only called within a JIT or prim_func context. The implementation mirrors the pattern in `proxy.py` (lines 278-283) with a lazy Builder import to avoid circular dependencies.
Note: The static analysis tool flags TRY003 (long message outside exception class). Consider consolidating error messages into the exception class definition if this pattern recurs frequently.
tilelang/language/eager/__init__.py (1)
1-2: Public re-export module: define `__all__` (or fix noqa) to stabilize the API surface
Since this module is now the public “surface” re-exported by `tilelang.language`, consider adding `__all__` to prevent accidental symbol leaks as eager evolves. At minimum, align the noqa with what’s actually triggered (`F403` for wildcard), since `# noqa: F401` is currently flagged as unused.
testing/python/language/test_tilelang_language_frontend_v2.py (1)
42-57: `test_expr`: direct import of `_all_dtypes` is brittle
Importing a private underscore symbol (`_all_dtypes`) from `tilelang.language.eager.dtypes` couples the test to internal layout. Prefer a public export (e.g., `tilelang.dtypes` or a documented list) if one exists.
tilelang/language/eager/builder.py (1)
1109-1149: Typing: `func: Callable = None` should be `Callable | None` (or `Optional[Callable]`)
Ruff’s `RUF013` note is valid: the annotation currently implies `func` is always callable even though `None` is a supported value.
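The typing note matters most for decorators that work both bare and with arguments, where `None` is the signal that arguments were supplied. A minimal sketch, with `prim_func_like` as a hypothetical name rather than the real `prim_func` signature:

```python
from typing import Any, Callable, Optional


def prim_func_like(
    func: Optional[Callable[..., Any]] = None, *, eager: bool = False
):
    """Hypothetical decorator usable as @prim_func_like or @prim_func_like(eager=True)."""

    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
        f.eager = eager  # tag the function so the effect is visible
        return f

    # func is None exactly when the decorator was called with arguments,
    # which is why the annotation has to admit None explicitly.
    return decorate if func is None else decorate(func)


@prim_func_like
def a():
    return 1


@prim_func_like(eager=True)
def b():
    return 2
```

`Optional[Callable[..., Any]]` is used here instead of `Callable | None` only so the sketch also runs on Python versions before 3.10.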
📜 Review details
📒 Files selected for processing (15)
- testing/python/language/test_tilelang_language_frontend_v2.py
- tilelang/__init__.py
- tilelang/jit/__init__.py
- tilelang/jit/adapter/tvm_ffi.py
- tilelang/language/__init__.py
- tilelang/language/allocate.py
- tilelang/language/eager/__init__.py
- tilelang/language/eager/ast.py
- tilelang/language/eager/builder.py
- tilelang/language/eager/dtypes.py
- tilelang/language/eager/utils.py
- tilelang/language/kernel.py
- tilelang/language/loop.py
- tilelang/language/print_op.py
- tilelang/language/proxy.py
✅ Files skipped from review due to trivial changes (1)
- tilelang/language/loop.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_frontend_v2.py
🧬 Code graph analysis (7)
tilelang/jit/adapter/tvm_ffi.py (1)
  tilelang/language/eager/dtypes.py (1)
    dtype(14-15)
tilelang/language/eager/builder.py (1)
  tilelang/jit/exceptions.py (2)
    JITNoBuilderError(4-13), EagerJITBuildError(16-24)
testing/python/language/test_tilelang_language_frontend_v2.py (6)
  tilelang/language/tir/op.py (1)
    all(1913-1930)
  tilelang/jit/__init__.py (3)
    jit(450-450), jit(454-464), jit(467-531)
  tilelang/language/allocate.py (3)
    alloc_var(86-86), alloc_var(90-90), alloc_var(93-147)
  tilelang/language/eager/dtypes.py (3)
    int32(245-245), dtype(14-15), float32(300-300)
  tilelang/language/kernel.py (2)
    Kernel(229-312), threads(215-219)
  tilelang/language/proxy.py (1)
    Tensor(233-233)
tilelang/language/print_op.py (1)
  tilelang/language/eager/builder.py (1)
    Builder(167-695)
tilelang/language/allocate.py (2)
  tilelang/language/eager/dtypes.py (1)
    dtype(14-15)
  tilelang/language/eager/builder.py (1)
    OutTensor(114-120)
tilelang/language/eager/__init__.py (3)
  tilelang/jit/adapter/tvm_ffi.py (1)
    prim_func(319-321)
  tilelang/language/eager/builder.py (8)
    prim_func(189-198), prim_func(1109-1150), macro(201-228), macro(746-782), PrimFunc(704-713), JITFunc(953-1084), Ref(124-135), const(851-879)
  tilelang/language/proxy.py (2)
    Ref(245-245), Ref(253-253)
tilelang/language/proxy.py (2)
  tilelang/jit/exceptions.py (1)
    JITNoBuilderError(4-13)
  tilelang/language/eager/builder.py (2)
    Builder(167-695), current(184-186)
🪛 Ruff (0.14.10)
tilelang/language/kernel.py
290-290: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/jit/__init__.py
323-323: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/__init__.py
145-145: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
163-163: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/eager/builder.py
196-196: Avoid specifying long messages outside the exception class
(TRY003)
733-733: Avoid specifying long messages outside the exception class
(TRY003)
870-870: Avoid specifying long messages outside the exception class
(TRY003)
1017-1017: Consider moving this statement to an else block
(TRY300)
1043-1043: Avoid specifying long messages outside the exception class
(TRY003)
1084-1084: Avoid specifying long messages outside the exception class
(TRY003)
1109-1109: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/language/__init__.py
14-14: from .eager import * used; unable to detect undefined names
(F403)
14-14: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/eager/__init__.py
1-1: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
2-2: from .dtypes import * used; unable to detect undefined names
(F403)
tilelang/language/proxy.py
282-282: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (11)
tilelang/jit/adapter/tvm_ffi.py (1)
22-22: LGTM: Import migration to eager path.
The import path update from `v2.dtypes` to `eager.dtypes` aligns with the broader migration to the eager backend. The functional behavior remains unchanged.
tilelang/language/proxy.py (1)
13-13: LGTM: Import for new runtime guard.
The `JITNoBuilderError` import supports the new Builder context enforcement added below.
tilelang/language/kernel.py (1)
10-10: LGTM: Import for new runtime guard.
The `JITNoBuilderError` import supports the Builder context enforcement added in the `Kernel` function.
tilelang/language/print_op.py (1)
6-6: LGTM: Import migration to eager path.
The Builder import path update from `v2.builder` to `eager.builder` aligns with the broader migration to the eager backend. The functional behavior at line 128 (`Builder.current().get_fileline_stack`) remains unchanged.
tilelang/language/allocate.py (1)
31-33: LGTM: Import migrations to eager path.
All three import path updates (`dtypes`, `tl_dtype`, `OutTensor`) consistently migrate from the v2 to eager backend. The functional behavior throughout the file (lines 144, 159, 232, 249, 253, 257, 262, 267, 272, 274, 276) remains unchanged.
testing/python/language/test_tilelang_language_frontend_v2.py (4)
7-9: Boolean expression test: good shift to TIR ops, but keep operator choice consistent
Switching `cond()` to `Or(Not(tir_all(...)), ...)` is the right direction for TIR-safe boolean composition.
Also applies to: 451-455
295-321: `swap_var`/`swap_idx`: eager-style annotation placement looks correct
Annotating `A: T.Tensor[...]` inside the jitted function body is consistent with the eager builder pattern used elsewhere in this PR.
323-340: `while_loop`: return + `.item()` usage looks correct
This test is aligned with the eager execution behavior (jitted function returns the output tensor directly).
457-477: `probe(A, tmp: bool)`: constexpr branching looks intentional
Using a Python `bool` to control compile-time branching is consistent with “constexpr if” behavior; call sites (`probe(A, True/False)`) validate both branches.
tilelang/language/eager/builder.py (2)
189-199: Builder context cleanup in `prim_func`: good hardening
The `try/finally` restoring `thread_local_storage.builder` and the “all outputs returned” validation make eager/lazy behavior much less footgun-prone.
851-880: `T.const()` eager-only guard: good, but keep exception type stable
Enforcing `builder is None or not builder.eager_jit` → `JITNoBuilderError` matches the mode-inference strategy. This is a good contract for the “call once outside builder to detect eager” trick.
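The two `builder.py` comments above fit together: a thread-local builder restored in `try/finally`, and a builder-only API whose failure outside a builder is used to classify a function as eager. The sketch below is a simplified model with hypothetical names (`NoBuilderError`, `PrimFuncStub`, `infer_mode`, and this `Builder`); it does not reproduce TileLang's `_infer_jit_mode`, and it omits the `eager_jit` flag check mentioned in the review.

```python
import threading

_tls = threading.local()


class NoBuilderError(RuntimeError):
    """Hypothetical stand-in for JITNoBuilderError."""


class PrimFuncStub:
    """Hypothetical stand-in for a TIR PrimFunc object."""


class Builder:
    @staticmethod
    def current():
        return getattr(_tls, "builder", None)

    def run(self, fn, *args):
        prev = Builder.current()
        _tls.builder = self
        try:
            return fn(*args)
        finally:
            _tls.builder = prev  # restore even if fn raises


def const(name):
    """Builder-only API: fails fast when called outside a builder."""
    if Builder.current() is None:
        raise NoBuilderError(f"const({name!r}) needs an active Builder")
    return name


def infer_mode(fn):
    """Call fn once outside any builder and classify it by the outcome."""
    try:
        result = fn()
    except NoBuilderError:
        return "eager"
    return "lazy" if isinstance(result, PrimFuncStub) else "unknown"


def lazy_style():
    return PrimFuncStub()


def eager_style():
    return const("N")
```

This is why the review asks for the exception type to stay stable: if the guard raised a different exception, a probe like `infer_mode` would propagate it instead of reporting eager mode.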
…ove compatibility with eager execution. Refactor related function signatures and logic, ensuring seamless integration with the latest compilation features. This change aims to streamline JIT operations and enhance overall code clarity.
kurisu6912
left a comment
LGTM! Maybe we can rename examples/lazy_jit to examples/eager_jit (XD: lazy is eager, non-lazy is lazy)
SiriusNEO
left a comment
LGTM, just two small comments
This feature seems to be broken. When using lazy mode, some syntax error can falsely report |
met same problem ... |
@Da1sypetals @senlyu163 Hi! The message seems to show that you use |
@SiriusNEO thx, it works fine. I made a mistake.
If there is a TileLang error inside @T.prim_func, this error is wrongly raised. A simple example is creating a buffer without a dtype:

    from icecream import ic
    import torch
    import einops as ein
    import tilelang
    import tilelang.language as T

    @tilelang.jit(
        out_idx=[0],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        },
    )
    def example():
        @T.prim_func
        def main(x: T.Tensor([8, 8], dtype=T.float32)):
            with T.Kernel(1, threads=128) as i_seq:
                # Deliberately missing dtype: this is the error being reproduced.
                R = T.alloc_shared([8, 8])

        return main

    kernel = example()
    kernel()
@Da1sypetals it works for me on the upstream commit.

    R = T.alloc_shared([8, 8])
        ^^^^^^^^^^^^^^^^^^^^^^
    TypeError: alloc_shared() missing 1 required positional argument: 'dtype'
Thanks! I'll test it when the next release comes out.
Summary
This PR unifies `@tilelang.jit` and `@tilelang.lazy_jit` into a single `@tilelang.jit` decorator that automatically infers the execution mode based on function behavior.
Two JIT Writing Styles
TileLang now supports two kernel writing styles with a unified decorator:
1. Lazy Style (explicit PrimFunc return)
2. Eager Style (DSL builder pattern)
How Mode Inference Works
The decorator automatically distinguishes between lazy and eager styles by:
- `PrimFunc` → lazy mode
- `JITNoBuilderError` → eager mode
The key insight is that eager-style functions use features like `T.const()` and `T.Kernel()` which require an active Builder context. When called directly (without Builder), these functions raise `JITNoBuilderError`, signaling that the function is eager-style.
New Exception Classes
Added `tilelang/jit/exceptions.py` with:
- `JITNoBuilderError`: Raised when `T.const()` or `T.Kernel()` is called outside JIT/prim_func context
- `EagerJITBuildError`: General error during eager-style kernel construction
Key Changes
- `lazy_jit` export from `tilelang/__init__.py`
- `JITImpl` to handle both modes via `mode` attribute (`"auto"`, `"lazy"`, `"eager"`)
- `_infer_jit_mode()` method for automatic mode detection
- `T.Kernel()` function
- `@tilelang.jit`
Test plan
- `out_idx` validation (only allowed in lazy mode)
Summary by CodeRabbit
New Features
- `JITFunc` class supporting lazy/eager execution models.
- `JITNoBuilderError` and `EagerJITBuildError` exception types for clearer error diagnostics.
Refactoring
- `@tilelang.lazy_jit` to unified `@tilelang.jit` with configurable execution modes.
- `v2` to `eager` backend.
- `simplify_prim_func` decorator across examples and tests.
Improvements