[Refactor] Unify @jit and @lazy_jit into a single @jit decorator #1617
…nd tests Updated all instances of the @tilelang.lazy_jit decorator to @tilelang.jit in the lazyjit example notebooks and related test files. This change aligns the code with the new JIT compilation approach, enhancing consistency across the codebase. Additionally, removed the lazy_jit import from the module initialization to streamline the API.
…l loops Updated the T.copy function to accept a new keyword-only parameter, loop_layout, allowing users to specify layout hints for the outermost parallel loop. Enhanced layout annotation handling in CopyNode and AtomicAddNode classes to ensure compatibility with SIMT operations. Added tests to validate the functionality of loop layout annotations, improving robustness and clarity in layout management for parallel loops.
📝 Walkthrough
Replaces usages of …
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant Decorator as "tilelang.jit (decorator)"
participant JITImpl
participant Builder
participant Backend as "Compiler/Backend"
User->>Decorator: apply `@tilelang.jit`(func)
Decorator->>JITImpl: construct JITImpl(mode="auto", func=LazyJITFunc)
User->>JITImpl: call(...) or get_tir(...)
JITImpl->>JITImpl: _infer_jit_mode(args...)
alt inferred == "lazy"
JITImpl->>Builder: request PrimFunc / TirTemplate (lazy-style)
Builder-->>JITImpl: PrimFunc or TirTemplate
JITImpl->>User: return kernel object / PrimFunc (deferred)
else inferred == "eager"
JITImpl->>Builder: trace/build PrimFunc (eager-style)
Builder-->>JITImpl: PrimFunc
JITImpl->>Backend: compile & execute
Backend-->>JITImpl: result
JITImpl->>User: return result
end
note right of JITImpl: mode cached after first inference
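The flow in the diagram can be sketched in plain Python. This is a toy model, not the TileLang implementation: `FakePrimFunc`, `AutoModeDispatcher`, and `inference_count` are hypothetical names introduced only for illustration, and the probe call stands in for whatever `_infer_jit_mode` does internally.

```python
from typing import Any, Callable


class FakePrimFunc:
    """Hypothetical stand-in for a TIR PrimFunc object."""


class AutoModeDispatcher:
    """Toy model of the diagram: infer the mode once, then reuse it."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func
        self.mode = "auto"          # "auto", "lazy", or "eager"
        self.inference_count = 0    # exposed only to make the caching observable

    def _infer_mode(self, *args: Any, **kwargs: Any) -> str:
        if self.mode != "auto":
            return self.mode        # explicit or previously inferred mode wins
        self.inference_count += 1
        # Probe call: a function that returns a PrimFunc-like object is "lazy".
        result = self.func(*args, **kwargs)
        return "lazy" if isinstance(result, FakePrimFunc) else "eager"

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.mode = self._infer_mode(*args, **kwargs)
        return self.func(*args, **kwargs)


lazy_style = AutoModeDispatcher(lambda n: FakePrimFunc())
eager_style = AutoModeDispatcher(lambda n: n * 2)

lazy_style(4)
lazy_style(8)
print(lazy_style.mode, lazy_style.inference_count)  # lazy 1
print(eager_style(4), eager_style.mode)             # 8 eager
```

The second call to `lazy_style` skips inference because the mode is no longer `"auto"`, which mirrors the "mode cached after first inference" note.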
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ 1 failed (warning), ✅ 2 passed
…tor/unify-jit-decorator
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/jit/__init__.py (2)
363-378: Potential `AttributeError` when `func` is a `PrimFunc`.
The `get_tir` method (lines 295-296) allows `self.func` to be a `PrimFunc` directly, but line 367 unconditionally calls `self.func.set_mode(self.mode)`. If `self.func` is a `PrimFunc` instead of a `LazyJITFunc`, this will raise an `AttributeError` since `PrimFunc` doesn't have a `set_mode` method.
🔎 Proposed fix
         def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
             # infer jit mode on first compile
             if self.mode == "auto":
                 self.mode = self._infer_jit_mode(*args, **kwargs)
    -            self.func.set_mode(self.mode)
    +            if hasattr(self.func, 'set_mode'):
    +                self.func.set_mode(self.mode)
             prim_func = self.get_tir(*args, **kwargs)
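The guard proposed above can be tried in isolation. The sketch below uses two hypothetical classes (`PlainFunc`, `ModeAwareFunc`) and a hypothetical helper `propagate_mode`; none of them are TileLang APIs. It only illustrates how an attribute check avoids the `AttributeError`.

```python
class PlainFunc:
    """Hypothetical object that has no set_mode method."""


class ModeAwareFunc:
    """Hypothetical object that accepts a mode, like the wrapped JIT function."""

    def __init__(self) -> None:
        self.mode = None

    def set_mode(self, mode: str) -> None:
        self.mode = mode


def propagate_mode(func: object, mode: str) -> bool:
    """Call func.set_mode(mode) only if the method exists; report whether it ran."""
    setter = getattr(func, "set_mode", None)
    if callable(setter):
        setter(mode)
        return True
    return False


aware = ModeAwareFunc()
print(propagate_mode(aware, "eager"), aware.mode)  # True eager
print(propagate_mode(PlainFunc(), "eager"))        # False
```

A guard like this hides the mismatch rather than resolving it, so tightening the declared type of `func` is the alternative the later review comments lean toward.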
188-189: Unreachable code: duplicate `return` statement.
Line 189 is unreachable because the function already returns on line 188. This appears to be leftover code.
🔎 Proposed fix
         for future in tqdm(
             concurrent.futures.as_completed(futures),
             total=len(futures),
             desc="Parallel Compiling",
         ):
             idx = future_map[future]
             if ignore_error:
                 try:
                     results[idx] = future.result()
                 except Exception as e:
                     logger.warning(f"Error compiling function at index {idx}: {e}")
                     results[idx] = None
             else:
                 results[idx] = future.result()
         return results
    -    return results
🧹 Nitpick comments (8)
examples/lazy_jit/lazyjit.en.ipynb (1)
24-24: Update notebook documentation to reflect the unified decorator.
The notebook title and several text cells still reference "Lazy JIT" terminology, which may confuse users now that `lazy_jit` has been removed from the public API. Consider updating:
- Line 24: Title "Tilelang Lazy JIT"
- Line 40: Reference to "Tilelang Lazy JIT"
- Line 554: "LazyJIT has very small overhead"
- Line 621: "Both `lazyjit` and the original `jit`"
These should be updated to reflect the unified `@tilelang.jit` decorator and explain that the mode (lazy vs eager) is automatically inferred.
Also applies to: 40-40, 554-554, 621-621
tilelang/__init__.py (1)
145-145: Remove the unused `noqa` directive.
The static analysis tool correctly identified that the `# noqa: F401` comment is unnecessary. Since all imported items are now used or intentionally exported, the directive can be removed.
🔎 Proposed fix
    -from .jit import jit, JITKernel, compile, par_compile  # noqa: F401
    +from .jit import jit, JITKernel, compile, par_compile
examples/lazy_jit/lazyjit.zh.ipynb (1)
24-24: Update Chinese notebook documentation to reflect the unified decorator.
Similar to the English version, the Chinese notebook's title and text still reference "Lazy JIT" terminology. Consider updating the documentation to reflect that `@tilelang.jit` is now the unified decorator that automatically infers execution mode.
Also applies to: 40-40, 554-554, 621-621
tilelang/jit/__init__.py (1)
299-301: Consider using `TypeError` for invalid type.
Per Python conventions, `TypeError` is the more appropriate exception when an argument has an incorrect type.
🔎 Proposed fix
    -        else:
    -            raise ValueError(f"Invalid function type: {type(self.func)}")
    +        else:
    +            raise TypeError(f"Invalid function type: {type(self.func)}")
tilelang/language/v2/builder.py (4)
846-848: Consider improving assertion error messages.
Both assertions produce the same error message but check different conditions. The first checks if we're inside a JIT context, the second checks if the builder is in JIT mode. Clearer messages would help debugging.
🔎 Proposed fix
         builder = Builder.current()
    -    assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
    -    assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
    +    assert builder is not None, "T.const() must be called inside a @tilelang.jit decorated function"
    +    assert builder.lazy_jit, "T.const() requires eager mode (DSL builder pattern with tensor annotations)"
985-990: Broad exception catch in style detection.
Catching all `Exception` types is intentional here (to treat any failure as "not lazy style"), but it may mask genuine bugs in user code. Consider logging at a higher level or limiting to specific exceptions if possible.
🔎 Proposed fix
             try:
                 result = self.orig_func(*args, **kwargs)
                 return isinstance(result, PrimFunc)
    -        except Exception:
    -            logger.debug("Function doesn't return PrimFunc directly, treating as eager style")
    +        except Exception as e:
    +            logger.debug(f"Function doesn't return PrimFunc directly, treating as eager style: {e}")
                 return False
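The "limit to specific exceptions" option mentioned above could look roughly like the following. This is a sketch only: `FakePrimFunc` and `looks_lazy` are hypothetical names, and the particular exception tuple is an assumption about which failures a probe call is expected to produce.

```python
import logging

logger = logging.getLogger(__name__)


class FakePrimFunc:
    """Hypothetical stand-in for a TIR PrimFunc object."""


def looks_lazy(func, *args, **kwargs) -> bool:
    """Return True if calling func yields a PrimFunc-like object.

    Only the listed exception types are treated as "not lazy style";
    anything else (for example RuntimeError) propagates to the caller.
    """
    try:
        result = func(*args, **kwargs)
    except (TypeError, AttributeError, ValueError) as exc:
        logger.debug("Probe call failed, treating as eager style: %s", exc)
        return False
    return isinstance(result, FakePrimFunc)


def bad_value():
    raise ValueError("shape mismatch")


print(looks_lazy(lambda: FakePrimFunc()))  # True
print(looks_lazy(lambda: 42))              # False
print(looks_lazy(bad_value))               # False
```

The trade-off is that any exception type left out of the tuple now surfaces during mode detection instead of being silently mapped to eager style.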
1024-1028: Outdated comment referencing "legacy gemm".
The comment on line 1025 references "legacy gemm" which doesn't align with the new lazy/eager terminology. Consider updating for consistency.
🔎 Proposed fix
             (p1_key, _), tensor_args = self.parse_args(*args, **kwargs)
             if p1_key not in self.p1_cache:
    -            # in legacy gemm, we use lazy tir template to build the tir
    +            # build and cache the TIR template on first call with this config
                 tir_temp = self._build_tir_template(*args, **kwargs)
                 self.p1_cache[p1_key] = tir_temp
                 return tir_temp.get_tir(**tensor_args)
1061-1061: Implicit `Optional` type hint.
PEP 484 recommends explicit `Optional[X]` rather than `X = None` for default None parameters. This improves type checker compatibility.
🔎 Proposed fix
    +from typing import Optional
    +
    -def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
    +def prim_func(func: Optional[Callable[_P, _T]] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
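For context, a `func=None` default is typical of decorators that work both with and without parentheses. The sketch below shows that pattern with an explicit `Optional` annotation; `tagged` and its `label` parameter are hypothetical and unrelated to the real `prim_func` signature.

```python
from typing import Callable, Optional, TypeVar, Union

F = TypeVar("F", bound=Callable[..., object])


def tagged(func: Optional[F] = None, *, label: str = "default") -> Union[F, Callable[[F], F]]:
    """Attach a label to a function; usable as @tagged or @tagged(label=...)."""

    def wrap(f: F) -> F:
        f.label = label  # type: ignore[attr-defined]
        return f

    # Bare @tagged passes the function directly; @tagged(...) passes None first.
    return wrap if func is None else wrap(func)


@tagged
def plain() -> int:
    return 1


@tagged(label="custom")
def configured() -> int:
    return 2


print(plain.label, configured.label)  # default custom
```

Writing `Optional[F]` makes the `None` case visible to type checkers, which is the point of the RUF013 hint.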
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/lazy_jit/lazyjit.en.ipynb
- examples/lazy_jit/lazyjit.zh.ipynb
- testing/python/language/test_tilelang_language_lazy_jit.py
- testing/python/layout/test_tilelang_annotate_loop_layout.py
- tilelang/__init__.py
- tilelang/jit/__init__.py
- tilelang/language/v2/builder.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- testing/python/language/test_tilelang_language_lazy_jit.py
- testing/python/layout/test_tilelang_annotate_loop_layout.py
- examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2025-12-26T06:45:51.789Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:51.789Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.
Applied to files:
- examples/lazy_jit/lazyjit.en.ipynb
- examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
- examples/lazy_jit/lazyjit.en.ipynb
- examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
- examples/lazy_jit/lazyjit.en.ipynb
- testing/python/layout/test_tilelang_annotate_loop_layout.py
- examples/lazy_jit/lazyjit.zh.ipynb
🧬 Code graph analysis (4)
tilelang/__init__.py (2)
- tilelang/jit/__init__.py (5): jit (439-439), jit (443-453), jit (456-496), compile (47-107), compile (363-393)
- tilelang/jit/kernel.py (1): JITKernel (38-767)
testing/python/language/test_tilelang_language_lazy_jit.py (3)
- tilelang/jit/__init__.py (3): jit (439-439), jit (443-453), jit (456-496)
- tilelang/language/v2/builder.py (1): const (831-856)
- tilelang/language/symbolics.py (1): dynamic (12-29)
testing/python/layout/test_tilelang_annotate_loop_layout.py (1)
- tilelang/jit/__init__.py (3): jit (439-439), jit (443-453), jit (456-496)
tilelang/language/v2/builder.py (1)
- tilelang/jit/__init__.py (1): get_tir (291-302)
🪛 Ruff (0.14.10)
tilelang/__init__.py
145-145: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/jit/__init__.py
300-300: Prefer TypeError exception for invalid type
(TRY004)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/v2/builder.py
988-988: Do not catch blind exception: Exception
(BLE001)
1009-1009: Avoid specifying long messages outside the exception class
(TRY003)
1061-1061: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (10)
examples/lazy_jit/lazyjit.en.ipynb (1)
56-56: LGTM! Decorator changes are consistent.
All decorator changes from `@tilelang.lazy_jit` to `@tilelang.jit` are correctly applied.
Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860
tilelang/__init__.py (1)
145-145: LGTM! Breaking change correctly implements the API unification.
The removal of `lazy_jit` from the public exports is correct and aligns with the PR's stated goal of unifying the decorators. Users should migrate to `@tilelang.jit` as documented in the PR objectives.
testing/python/language/test_tilelang_language_lazy_jit.py (1)
9-9: LGTM! Test updates are comprehensive and correct.
All test function decorators have been correctly updated from `@tilelang.lazy_jit` to `@tilelang.jit`. The test coverage spans various annotation patterns (T.const, T.dynamic, T.Tensor, T.StridedTensor), ensuring the unified decorator works correctly across different usage scenarios.
Also applies to: 48-48, 105-105, 112-112, 119-119, 126-126, 134-134, 141-141, 178-178, 184-184, 189-189, 195-195, 202-202, 208-208
examples/lazy_jit/lazyjit.zh.ipynb (1)
56-56: LGTM! Decorator changes are consistent.
All decorator changes from `@tilelang.lazy_jit` to `@tilelang.jit` are correctly applied throughout the Chinese notebook.
Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860
testing/python/layout/test_tilelang_annotate_loop_layout.py (1)
7-7: LGTM! Layout test updates are correct.
All three kernel function decorators have been correctly updated from `@tilelang.lazy_jit` to `@tilelang.jit`. The tests cover loop layout annotation scenarios and will verify that the unified decorator works correctly with the layout system.
Also applies to: 54-54, 82-82
tilelang/jit/__init__.py (4)
192-279: LGTM!
The comprehensive docstring with examples for both lazy and eager modes significantly improves developer experience. The attribute documentation is clear and the type annotations are well-defined.
304-316: LGTM!
The mode inference logic is clean with appropriate early returns for explicit modes and sensible fallback behavior.
485-494: LGTM!
The decorator correctly wraps the function in a `LazyJITFunc` and initializes `JITImpl` with `mode="auto"` for automatic inference. The comment on line 486 could be slightly clarified to distinguish between the internal `lazy_jit` wrapper flag and the user-facing execution mode.
420-432: Type annotation for `func` is explicit and unambiguous; no type guard needed.
The field `func` at line 279-280 is explicitly typed as `func: LazyJITFunc[_KP, _T]`, not a union or generic allowing `PrimFunc`. The `jit` decorator always wraps the input function with `prim_func(..., lazy_jit=True)` before passing it to `JITImpl`, ensuring `self.func` is always `LazyJITFunc` at runtime. The `parse_args` call at line 421 is safe. The defensive `isinstance` checks in the `get_tir` method are unrelated to the `__call__` code path.
Likely an incorrect or invalid review comment.
tilelang/language/v2/builder.py (1)
859-924: LGTM!
The `TirTemplate` refactoring cleanly separates lazy and eager paths. The `from_lazy_style` factory method and the short-circuit in `get_tir` are well-designed for the dual execution model.
Actionable comments posted: 2
🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 428-433: parse_args can return None for kernel_args when there are
no tensor arguments, but the code unconditionally calls
kernel(*kernel_args.values()), which will raise when kernel_args is None; update
the call site in the wrapper (the block that gets kernel from _kernel_cache,
calls self.compile, then invokes kernel) to handle the None case by invoking
kernel() when kernel_args is falsy (e.g., use kernel(*kernel_args.values()) if
kernel_args else kernel()), or change parse_args to always return an empty dict;
ensure this fix touches the call that currently does
kernel(*kernel_args.values()) and keep the compile/_kernel_cache logic
unchanged.
- Around line 277-279: The runtime isinstance checks in get_tir and
_infer_jit_mode don't match the declared JITImpl.func type (LazyJITFunc[_KP,
_T]) and are redundant because the decorator always wraps the input with
prim_func(..., lazy_jit=True); update the code to either remove the
PrimFunc/Callable isinstance branches in get_tir and _infer_jit_mode and treat
func as LazyJITFunc (using LazyJITFunc methods/properties directly), or change
the annotation of JITImpl.func to PrimFunc[_KP,_T] | LazyJITFunc[_KP,_T] |
Callable[_KP,_T] and keep the checks—pick one consistent approach and apply it
to both get_tir and _infer_jit_mode, referencing JITImpl.func, LazyJITFunc,
PrimFunc, prim_func and the decorator wrapping logic to ensure types and runtime
checks align.
🧹 Nitpick comments (1)
tilelang/jit/__init__.py (1)
295-302: Consider using `TypeError` for type validation.
When the function type is invalid, `TypeError` is more semantically appropriate than `ValueError` for type-related errors.
Proposed fix
             else:
    -            raise ValueError(f"Invalid function type: {type(self.func)}")
    +            raise TypeError(f"Invalid function type: {type(self.func)}")
Based on static analysis hints.
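The distinction the reviewer is drawing can be shown with a small validator. `check_inputs` and `VALID_MODES` are hypothetical names used only for this sketch: a wrong kind of object raises `TypeError`, while a value of the right type that is not acceptable raises `ValueError`.

```python
VALID_MODES = ("auto", "lazy", "eager")


def check_inputs(func: object, mode: str) -> str:
    """Validate a (func, mode) pair and return the mode unchanged."""
    if not callable(func):
        # Wrong kind of object entirely: a type problem.
        raise TypeError(f"Invalid function type: {type(func)}")
    if mode not in VALID_MODES:
        # Right type (str), but not one of the accepted values.
        raise ValueError(f"Invalid mode: {mode!r}")
    return mode


print(check_inputs(len, "lazy"))  # lazy

for bad_args in ((42, "lazy"), (len, "fast")):
    try:
        check_inputs(*bad_args)
    except (TypeError, ValueError) as exc:
        print(type(exc).__name__)  # TypeError, then ValueError
```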
📜 Review details
📒 Files selected for processing (2)
- 3rdparty/tvm
- tilelang/jit/__init__.py
✅ Files skipped from review due to trivial changes (1)
- 3rdparty/tvm
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/__init__.py (1)
- tilelang/language/v2/builder.py (9): LazyJITFunc (928-1036), PrimFunc (690-699), _is_lazy_style (967-990), prim_func (182-189), prim_func (1061-1099), get_tir (915-924), get_tir (1022-1029), parse_args (1011-1020), get (219-220)
🪛 Ruff (0.14.10)
tilelang/jit/__init__.py
300-300: Prefer TypeError exception for invalid type
(TRY004)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
370-373: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (6)
tilelang/jit/__init__.py (6)
195-265: Excellent documentation of the unified JIT interface.
The docstring clearly explains the lazy vs eager execution modes with concrete examples. This will significantly help users understand when each mode is triggered and how to use them effectively.
304-316: Mode inference logic is sound.
The method correctly infers the execution mode by:
- Returning existing mode if already set
- Delegating to `_is_lazy_style()` to determine if the function returns a PrimFunc (lazy) or uses builder pattern (eager)
This aligns well with the PR objectives for automatic mode detection.
435-440: Clean mode-based execution logic.
The conditional execution based on mode is well-structured:
- Eager mode: Executes kernel immediately and returns result
- Lazy mode: Returns kernel object for manual invocation
This aligns perfectly with the PR objectives for unified behavior.
476-501: Well-documented decorator interface.
The docstring clearly explains:
- Automatic mode inference
- Parameter purposes and constraints
- That `out_idx` only works in lazy mode
This provides good guidance for users migrating from the old two-decorator system.
364-376: Mode inference and validation logic is well-implemented.
The compile method correctly:
- Infers mode on first compilation (lines 365-366)
- Validates that `out_idx` is only used in lazy mode (lines 369-373)
- Propagates the mode to the wrapped function (line 375)
This ensures consistent behavior across the JIT lifecycle.
514-523: Parameter naming is correct; no issue identified.
Line 516 correctly passes `lazy_jit=True` to `prim_func`. The `prim_func` function signature in `tilelang/language/v2/builder.py` defines this parameter as `lazy_jit: bool = False`, so the call matches the actual function definition. The `is_lazy_style` naming exists separately as internal fields and methods within the `LazyJITFunc` class, not as the `prim_func` parameter name.
         mode: Literal["auto", "lazy", "eager"]
         # place func at the last element for better __repr__
    -    func: Callable[_P, _T] | PrimFunc[_KP, _T] | LazyJITFunc[_KP, _T]
    +    func: LazyJITFunc[_KP, _T]
Remove or update runtime type checks in get_tir and _infer_jit_mode to match type annotation.
The func field is declared as LazyJITFunc[_KP, _T] (line 279), and the decorator always wraps the input function in LazyJITFunc via prim_func(..., lazy_jit=True) before instantiation (lines 514-522). However, get_tir() (lines 295-297) and _infer_jit_mode() (line 314) include runtime type checks for PrimFunc and generic Callable that don't match this guaranteed type.
Since JITImpl.func is always LazyJITFunc at instantiation, remove the unnecessary isinstance checks for other types, or update the type annotation to PrimFunc[_KP, _T] | LazyJITFunc[_KP, _T] | Callable[_KP, _T] if these cases are intentionally supported.
        kwargs.update(kwargs.pop("__tune_params", {}))
        key, kernel_args = self.func.parse_args(*args, **kwargs)
        kernel = self._kernel_cache.get(key, None)
        if kernel is None:
            kernel = self.compile(*args, **kwargs)
            self._kernel_cache[key] = kernel
Fix the None case handling for kernel_args when calling kernel.
The parse_args method can return None as kernel_args when no tensor arguments are present (line 429), but line 437 unconditionally calls kernel(*kernel_args.values()) without checking for None. This will crash at runtime in eager mode when the decorated function has no tensor arguments.
Either handle the None case:
kernel(*kernel_args.values()) if kernel_args else kernel()
Or verify that parse_args always returns a dict (even if empty).
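A minimal, self-contained sketch of the first option follows. `call_kernel` is a hypothetical helper, and the lambda stands in for a compiled kernel; the only behavior taken from the comment above is that `kernel_args` may be `None` and is otherwise a mapping whose values are passed positionally.

```python
from typing import Any, Callable, Mapping, Optional


def call_kernel(kernel: Callable[..., Any],
                kernel_args: Optional[Mapping[str, Any]]) -> Any:
    """Invoke kernel with the mapping's values, or with no arguments if it is None."""
    if kernel_args is None:
        return kernel()
    return kernel(*kernel_args.values())


# Stand-in kernel that simply echoes its positional arguments.
echo = lambda *xs: xs

print(call_kernel(echo, None))              # ()
print(call_kernel(echo, {}))                # ()
print(call_kernel(echo, {"A": 1, "B": 2}))  # (1, 2)
```

Checking `is None` explicitly, rather than relying on truthiness, keeps the empty-dict case on the normal path; both forms end up calling the kernel with no arguments.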
Actionable comments posted: 1
🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 425-437: parse_args can return None for kernel_args, causing
kernel(*kernel_args.values()) in the eager branch to raise an AttributeError;
update the eager execution path in __call__ (or the surrounding function) to
handle kernel_args being None by calling kernel() when kernel_args is None or
calling kernel(*kernel_args.values()) when not None, and ensure kernel_arg
handling mirrors the lazy path so mode == "eager" correctly invokes the compiled
kernel with zero or more args.
🧹 Nitpick comments (1)
tilelang/jit/__init__.py (1)
304-316: Simplify mode inference logic by removing redundant type check.
Since `self.func` is always `LazyJITFunc[_KP, _T]` (guaranteed by the decorator at line 513), the `isinstance` check on line 314 and the early return on line 315 are unnecessary. The logic can be simplified to directly call `self.func._is_lazy_style()`.
🔎 Proposed simplification
         def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
             """
             Infer the JIT execution mode based on function behavior.

             Returns "lazy" if the function explicitly returns a PrimFunc,
             or "eager" if it uses the DSL builder pattern.
             """
             if self.mode in ("lazy", "eager"):
                 return self.mode
             # auto: infer by checking if function returns PrimFunc directly
    -        if not isinstance(self.func, LazyJITFunc):
    -            return "lazy"
             return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
📜 Review details
📒 Files selected for processing (1)
tilelang/jit/__init__.py
🧰 Additional context used
🪛 Ruff (0.14.10)
tilelang/jit/__init__.py
300-300: Prefer TypeError exception for invalid type
(TRY004)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
370-370: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
tilelang/jit/__init__.py (3)
195-265: Excellent documentation for the unified JIT decorator.
The comprehensive docstring clearly explains both execution modes with concrete examples and describes the automatic inference behavior. This will help users understand the migration from separate decorators.
461-522: Well-structured unified decorator implementation.
The decorator correctly wraps all functions with `prim_func(..., lazy_jit=True)` to create a `LazyJITFunc`, then initializes `JITImpl` with `mode="auto"` for runtime inference. This approach cleanly unifies the previous separate decorators while maintaining backward compatibility.
363-398: No action needed.
The `set_mode` method is properly implemented in `LazyJITFunc` at `tilelang/language/v2/builder.py:1034`. Line 372's call to `self.func.set_mode(self.mode)` is valid.
Likely an incorrect or invalid review comment.
Actionable comments posted: 6
🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 425-443: parse_args can return kernel_args = None (see
LazyJITFunc.parse_args), so calling kernel(*kernel_args.values()) will crash;
update the eager path in __call__ (after key/kernel lookup in __init__.py) to
detect if kernel_args is None and invoke kernel() with no args, otherwise call
kernel(*kernel_args.values()). Keep existing logic for mode inference
(self._infer_jit_mode, self.func.set_mode) and kernel caching (_kernel_cache,
compile) unchanged.
- Around line 304-316: The isinstance check for LazyJITFunc in _infer_jit_mode
is redundant because self.func is always a LazyJITFunc; remove the entire branch
`if not isinstance(self.func, LazyJITFunc): return "lazy"` and simplify the
method to first honor explicit self.mode ("lazy" or "eager") and otherwise
return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager", keeping
the same return type Literal["lazy","eager"] and preserving the call to
_is_lazy_style on the LazyJITFunc instance.
- Around line 467-478: The decorator signature incorrectly allows PrimFunc but
the implementation calls prim_func(func, lazy_jit=True) which expects a Python
function and then runs mutate(), so either remove PrimFunc from the type union
in the jit signature or add an explicit runtime branch in jit that detects a
PrimFunc and handles it safely (e.g., pass-through return or wrap appropriately)
instead of calling prim_func/mutate on it; reference symbols: jit, prim_func,
PrimFunc, mutate.
In @tilelang/language/v2/builder.py:
- Line 1057: The signature of prim_func declares func with a default None but
its type omits None; update the annotation for the parameter func in prim_func
to explicitly allow None (e.g., use Optional[Callable[_P, _T]] or Callable[_P,
_T] | None) so the default value matches the type; keep the return annotation
as-is and modify only the func parameter type in the prim_func definition.
- Around line 961-984: The _is_lazy_style function currently catches a bare
Exception when calling self.orig_func, which can hide programming errors; change
the except to only catch expected runtime invocation errors (e.g., except
(TypeError, AttributeError, ValueError) as e) and log the exception details via
logger.debug including e, and do not catch BaseException (so
KeyboardInterrupt/SystemExit propagate); keep the existing return False behavior
for those specific caught exceptions to treat the function as eager style.
- Around line 826-842: The assertions in const() are contradictory: they assert
builder.lazy_jit but the error text says "eager mode". Fix by making the check
and message consistent — either test the actual JIT mode flag set in
_build_tir_template() (e.g., use the mode field on the LazyJITFunc or rename
builder.lazy_jit to a clearer name like builder.is_jit) or simply update the
assertion to assert builder.lazy_jit and change the error text to "T.const() can
only be used inside @tilelang.jit (JIT/eager mode)"; reference const(),
Builder.current(), builder.lazy_jit and _build_tir_template() to locate and
update the check and message accordingly.
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)
909-918: Verify `matcher` is unused in lazy-style mode.
When `is_lazy_style` is `True`, `get_tir()` returns the PrimFunc directly without substitution (lines 910-911), bypassing the `matcher` logic. Confirm that `matcher` is always `None` for lazy-style templates to avoid confusion.
🔎 Optional assertion to enforce invariant
 def get_tir(self, **kwargs):
     if self.is_lazy_style:
+        assert self.matcher is None, "Lazy-style templates should not have a matcher"
         return self.prim_func
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/jit/__init__.py
tilelang/language/v2/builder.py
🧰 Additional context used
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py
982-982: Do not catch blind exception: Exception
(BLE001)
1003-1003: Avoid specifying long messages outside the exception class
(TRY003)
1057-1057: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/jit/__init__.py
300-300: Prefer TypeError exception for invalid type
(TRY004)
300-300: Avoid specifying long messages outside the exception class
(TRY003)
370-370: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (3)
tilelang/jit/__init__.py (2)
363-398: Verify mode inference happens before `get_tir()` uses it.
The `compile()` method infers the mode (lines 365-366) and sets it on `self.func` (line 372) before calling `get_tir()` (line 373). However, `get_tir()` calls `self.func(*args, **kwargs)`, which invokes `LazyJITFunc.__call__` → `get_tir()` → `_build_tir_template()`, and that method requires `self.mode` to be "lazy" or "eager" (builder.py line 1003).
Confirm that `LazyJITFunc.mode` is set by the time `_build_tir_template()` is called, or the code will raise ValueError for invalid mode.
The flow looks correct: mode is inferred and set on line 372 before get_tir is called on line 373. However, verify with a test case that has `mode="auto"` initially.
195-265: Excellent documentation for dual execution modes.
The docstring clearly explains the difference between lazy and eager modes with concrete examples. This will help users understand when each mode is triggered and how to use them effectively.
tilelang/language/v2/builder.py (1)
986-1003: Remove this concern: mode is properly set before `_build_tir_template()` is called.
Both `JITImpl.__call__()` and `JITImpl.compile()` ensure `self.mode` is inferred and set on the `LazyJITFunc` instance before calling methods that depend on it:
- `__call__()` (lines 428–430): Infers mode if "auto", then calls `self.func.set_mode()` before `self.func.parse_args()` (line 432)
- `compile()` (lines 366–372): Infers mode if "auto", then calls `self.func.set_mode()` before `self.get_tir()` (line 373)
The `_infer_jit_mode()` method always returns "lazy" or "eager" (never "auto"), so `self.mode` is guaranteed to be a valid state before `_build_tir_template()` executes. The existing comment at line 1013 confirms the developer's awareness of this dependency.
tilelang/jit/__init__.py
Outdated
def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
    """
    Infer the JIT execution mode based on function behavior.

    Returns "lazy" if the function explicitly returns a PrimFunc,
    or "eager" if it uses the DSL builder pattern.
    """
    if self.mode in ("lazy", "eager"):
        return self.mode
    # auto: infer by checking if function returns PrimFunc directly
    if not isinstance(self.func, LazyJITFunc):
        return "lazy"
    return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
Remove redundant type check in _infer_jit_mode().
Line 314 checks if not isinstance(self.func, LazyJITFunc) and returns "lazy", but self.func is annotated as LazyJITFunc[_KP, _T] (line 279) and the decorator ensures it's always wrapped in LazyJITFunc (line 519). This check is redundant.
🔎 Proposed fix
 def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
     """
     Infer the JIT execution mode based on function behavior.
     Returns "lazy" if the function explicitly returns a PrimFunc,
     or "eager" if it uses the DSL builder pattern.
     """
     if self.mode in ("lazy", "eager"):
         return self.mode
     # auto: infer by checking if function returns PrimFunc directly
-    if not isinstance(self.func, LazyJITFunc):
-        return "lazy"
     return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
Based on past review comments.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
-    """
-    Infer the JIT execution mode based on function behavior.
-    Returns "lazy" if the function explicitly returns a PrimFunc,
-    or "eager" if it uses the DSL builder pattern.
-    """
-    if self.mode in ("lazy", "eager"):
-        return self.mode
-    # auto: infer by checking if function returns PrimFunc directly
-    if not isinstance(self.func, LazyJITFunc):
-        return "lazy"
-    return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
+def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
+    """
+    Infer the JIT execution mode based on function behavior.
+    Returns "lazy" if the function explicitly returns a PrimFunc,
+    or "eager" if it uses the DSL builder pattern.
+    """
+    if self.mode in ("lazy", "eager"):
+        return self.mode
+    # auto: infer by checking if function returns PrimFunc directly
+    return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 304 - 316, The isinstance check for
LazyJITFunc in _infer_jit_mode is redundant because self.func is always a
LazyJITFunc; remove the entire branch `if not isinstance(self.func,
LazyJITFunc): return "lazy"` and simplify the method to first honor explicit
self.mode ("lazy" or "eager") and otherwise return "lazy" if
self.func._is_lazy_style(*args, **kwargs) else "eager", keeping the same return
type Literal["lazy","eager"] and preserving the call to _is_lazy_style on the
LazyJITFunc instance.
     """
     Declare constexpr variables for dynamic tensor dimensions (eager mode only).

     In eager mode, use T.const() to declare shape dimensions that will be
     inferred from actual tensor arguments at runtime.

     Example::

         @tilelang.jit
         def kernel(A, B):
             M, N = T.const("M, N")
             A: T.Tensor[[M, N], T.float32]
             ...
     """
     builder = Builder.current()
-    assert builder is not None, "const can only be used inside `tilelang.lazy_jit` function"
-    assert builder.lazy_jit, "const can only be used inside `tilelang.lazy_jit` function"
+    assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
+    assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
Fix confusing assertion logic in const() function.
Line 842 asserts that builder.lazy_jit is True, but the error message claims T.const() can only be used in "eager mode". This is contradictory and confusing.
Looking at line 994 in _build_tir_template(), builder.lazy_jit = True is set for eager-style execution (when mode is "eager"), which suggests the field name lazy_jit is misleading—it actually indicates "uses JIT" rather than "is lazy style".
Either:
- Rename `builder.lazy_jit` to something clearer like `builder.is_jit` or `builder.from_jit`
- Or invert the assertion to check the mode field on LazyJITFunc instead
🔎 Suggested fix: Check mode directly
-    assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
+    # T.const() is only valid in eager mode where we infer shapes from tensors
+    # In lazy mode, shape parameters are explicit function arguments
Then verify the mode when it matters, or accept that any JIT context allows const().
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In @tilelang/language/v2/builder.py around lines 826 - 842, The assertions in
const() are contradictory: they assert builder.lazy_jit but the error text says
"eager mode". Fix by making the check and message consistent — either test the
actual JIT mode flag set in _build_tir_template() (e.g., use the mode field on
the LazyJITFunc or rename builder.lazy_jit to a clearer name like
builder.is_jit) or simply update the assertion to assert builder.lazy_jit and
change the error text to "T.const() can only be used inside @tilelang.jit
(JIT/eager mode)"; reference const(), Builder.current(), builder.lazy_jit and
_build_tir_template() to locate and update the check and message accordingly.
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit=False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
Use explicit Optional for optional parameter.
The func parameter has a default of None but the type annotation doesn't include None, violating PEP 484. Use explicit Optional or union with None.
🔎 Proposed fix
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Callable[_P, _T] | None = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
Based on static analysis hints.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Callable[_P, _T] | None = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
🧰 Tools
🪛 Ruff (0.14.10)
1057-1057: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🤖 Prompt for AI Agents
In @tilelang/language/v2/builder.py at line 1057, The signature of prim_func
declares func with a default None but its type omits None; update the annotation
for the parameter func in prim_func to explicitly allow None (e.g., use
Optional[Callable[_P, _T]] or Callable[_P, _T] | None) so the default value
matches the type; keep the return annotation as-is and modify only the func
parameter type in the prim_func definition.
…le detection
- Removed unnecessary parameters from the autotune decorator in example_mha_sink_fwd_bhsd.py.
- Updated JITImpl to infer execution mode using static AST analysis instead of runtime checks.
- Added a new function to detect lazy style in AST, enhancing the clarity of function behavior.
- Refactored LazyJITFunc to utilize the new lazy style detection mechanism.

- Simplified the autotune decorator in example_mha_sink_fwd_bhsd.py by removing unnecessary parameters.
- Improved JITImpl to utilize static AST analysis for mode inference, eliminating the need for runtime checks.
- Introduced a new function for lazy style detection in AST, enhancing clarity and functionality.
- Updated LazyJITFunc to leverage the new lazy style detection mechanism.
Actionable comments posted: 1
🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 430-448: parse_args() can return kernel_args == None for no-tensor
calls; change the eager-path in the JIT wrapper (the block after
self._kernel_cache lookup) to handle that: if kernel_args is None call kernel()
with no positional args, otherwise call kernel(*kernel_args.values()). Update
the code around self.func.parse_args(...) and the eager branch that currently
does kernel(*kernel_args.values()) so it checks kernel_args is not None before
accessing .values(), leaving cache key/compile behavior unchanged.
🧹 Nitpick comments (1)
tilelang/language/v2/ast.py (1)
587-657: Add parentheses to clarify operator precedence.
The boolean expressions at lines 606-610 chain `and` and `or` operators without parentheses. While the precedence is correct (`and` binds tighter than `or`), explicit parentheses improve readability and address the static analysis warning.
🔎 Proposed fix
 for decorator in node.decorator_list:
     if (
-        isinstance(decorator, ast.Attribute)
-        and decorator.attr == "prim_func"
-        or isinstance(decorator, ast.Name)
-        and decorator.id == "prim_func"
+        (isinstance(decorator, ast.Attribute)
+         and decorator.attr == "prim_func")
+        or (isinstance(decorator, ast.Name)
+            and decorator.id == "prim_func")
     ):
         has_inner_prim_func = True
Based on static analysis hints.
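As a small illustration of the precedence point in this nitpick, the sketch below uses only Python's standard `ast` module to show that the parenthesized and unparenthesized forms parse to the same tree, so the suggested change should be purely cosmetic. This is not code from the PR: `is_prim_func_decorator` is a hypothetical helper name introduced here for illustration.

```python
import ast


def is_prim_func_decorator(decorator: ast.expr) -> bool:
    """Hypothetical helper: match `@prim_func` or `@<anything>.prim_func`."""
    return (
        (isinstance(decorator, ast.Attribute) and decorator.attr == "prim_func")
        or (isinstance(decorator, ast.Name) and decorator.id == "prim_func")
    )


# `and` binds tighter than `or`, so both spellings produce the same AST.
unparenthesized = ast.dump(ast.parse("a and b or c and d", mode="eval"))
parenthesized = ast.dump(ast.parse("(a and b) or (c and d)", mode="eval"))
same_tree = unparenthesized == parenthesized
```

Because the trees are identical, adding the parentheses only changes how the condition reads, not what it evaluates.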
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/attention_sink/example_mha_sink_fwd_bhsd.py
examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
testing/python/language/test_tilelang_language_subtype.py
tilelang/jit/__init__.py
tilelang/language/v2/ast.py
tilelang/language/v2/builder.py
💤 Files with no reviewable changes (1)
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_subtype.py
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
testing/python/language/test_tilelang_language_subtype.py
📚 Learning: 2025-12-15T07:48:45.785Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: src/target/codegen_cutedsl.cc:789-793
Timestamp: 2025-12-15T07:48:45.785Z
Learning: In tilelang/contrib/cutedsl, the `tl.make_rmem_tensor` function accepts both an Integer and a Tuple of Integer for its shape parameter. Therefore, both `tl.make_rmem_tensor(N, ...)` and `tl.make_rmem_tensor((N,), ...)` are valid syntaxes in CuteDSL-generated code.
Applied to files:
testing/python/language/test_tilelang_language_subtype.py
🧬 Code graph analysis (3)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
tilelang/autotuner/tuner.py (1)
autotune (691-786)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
get_configs (14-16)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
get_configs (11-13)
examples/flash_attention/example_mha_fwd_varlen.py (1)
get_configs (15-17)
examples/flash_attention/example_mha_fwd_bshd.py (1)
get_configs (11-13)
testing/python/language/test_tilelang_language_subtype.py (1)
tilelang/jit/__init__.py (3)
jit (455-455), jit (459-469), jit (472-536)
tilelang/jit/__init__.py (2)
tilelang/language/v2/builder.py (10)
LazyJITFunc (922-1042), set_mode (1028-1030), PrimFunc (684-693), _is_lazy_style (961-982), prim_func (182-189), prim_func (1067-1108), get_tir (909-918), get_tir (1016-1023), parse_args (1003-1014), get (219-220)
tilelang/jit/adapter/wrapper.py (2)
prim_func (575-585), prim_func (839-849)
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py
1001-1001: Avoid specifying long messages outside the exception class
(TRY003)
1042-1042: Avoid specifying long messages outside the exception class
(TRY003)
1067-1067: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/jit/__init__.py
305-305: Prefer TypeError exception for invalid type
(TRY004)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
375-375: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/v2/ast.py
607-608: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
609-610: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (12)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)
18-18: No changes needed: the warmup reduction is intentional and appropriate.
The file pair demonstrates differentiated tuning strategies:
- `example_mha_sink_fwd_bhsd.py` (non-pipelined): uses `warmup=25` (default)
- `example_mha_sink_fwd_bhsd_wgmma_pipelined.py` (pipelined): retains `warmup=500`
This reflects that simpler, non-pipelined kernels don't require extended warmup periods for stabilization, while the more complex pipelined implementation maintains the longer warmup for reliable autotuning. The reduction is targeted, not a blanket standardization.
tilelang/language/v2/ast.py (2)
584-584: LGTM!
The `is_lazy_style` field is well-integrated with the new detection logic and provides a clear semantic flag for distinguishing lazy vs eager execution modes.
690-691: LGTM!
The integration of lazy-style detection into the `mutate()` function is clean and correct. Detection happens before AST transformation, ensuring the original function structure is analyzed.
Also applies to: 717-717
testing/python/language/test_tilelang_language_subtype.py (1)
9-9: LGTM!
The decorator changes from `@tilelang.lazy_jit` to `@tilelang.jit` are straightforward and align with the PR's unification objective. The test logic remains unchanged, ensuring behavioral consistency.
Also applies to: 18-18, 98-98, 108-108, 119-119, 130-130, 142-142
tilelang/language/v2/builder.py (5)
853-895: LGTM!
The `TirTemplate` enhancements cleanly separate lazy-style (direct PrimFunc return) from eager-style (constexpr substitution) execution paths. The `is_lazy_style` flag and `from_lazy_style()` constructor provide clear semantics, and the `get_tir()` short-circuit for lazy style is efficient.
Also applies to: 909-912
944-944: Excellent improvement: static AST-based mode detection.
The new `_is_lazy_style()` implementation correctly leverages the static AST analysis performed during `mutate()`, eliminating the need for runtime PrimFunc checks and exception handling. This is more reliable and efficient than the previous approach.
Also applies to: 961-982
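To make the idea of static lazy-style detection concrete, here is a minimal sketch using only the standard `ast` module. It is an assumption about the general approach, not the PR's implementation: the helper name `looks_lazy_style` and the heuristic itself (an outer function that defines an inner function decorated with `prim_func` and returns it by name) are hypothetical.

```python
import ast
import textwrap


def looks_lazy_style(source: str) -> bool:
    """Hypothetical heuristic: the outer function defines an inner function
    decorated with `prim_func` and returns that inner function by name."""
    tree = ast.parse(textwrap.dedent(source))
    outer = tree.body[0]
    if not isinstance(outer, ast.FunctionDef):
        return False

    # Collect names of nested functions carrying a `prim_func` decorator.
    inner_names = set()
    for node in ast.walk(outer):
        if node is outer or not isinstance(node, ast.FunctionDef):
            continue
        for dec in node.decorator_list:
            name = dec.attr if isinstance(dec, ast.Attribute) else getattr(dec, "id", None)
            if name == "prim_func":
                inner_names.add(node.name)

    # Lazy style if any `return <name>` refers to one of those functions.
    for node in ast.walk(outer):
        if isinstance(node, ast.Return) and isinstance(node.value, ast.Name):
            if node.value.id in inner_names:
                return True
    return False


lazy_src = """
def kernel(M, N):
    @T.prim_func
    def main(A):
        pass
    return main
"""

eager_src = """
def kernel(A, B):
    x = A
"""
```

Since the source is only parsed and never executed, a check like this cannot raise from user code, which is the property the comment above praises over a runtime probe.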
984-1023: LGTM!
The `_build_tir_template()` method cleanly centralizes template construction logic based on execution mode. The integration with `parse_args()` and `get_tir()` is correct, and the caching mechanism ensures efficient reuse.
Also applies to: 1028-1030
1032-1042: LGTM!
The attribute proxying mechanism correctly exposes the original function's attributes for compatibility with autotuner and `inspect` module. The implementation properly handles `__wrapped__` and delegates other attributes to `orig_func`.
826-842: Confusing assertion logic in `const()` function.
Line 842 asserts that `builder.lazy_jit` is `True`, but the error message claims `T.const()` can only be used in "eager mode". This is contradictory and confusing. Looking at line 992, `builder.lazy_jit = True` is set in what the code treats as eager-style execution, suggesting the field name is misleading.
Consider either:
- Renaming `builder.lazy_jit` to something clearer like `builder.is_jit` or `builder.from_jit`
- Or updating the assertion message to reflect what the check actually does
Likely an incorrect or invalid review comment.
tilelang/jit/__init__.py (3)
193-265: LGTM!
The extensive docstring update provides clear, concrete examples of both lazy and eager execution modes. The simplified type for `func` (now always `LazyJITFunc[_KP, _T]`) correctly reflects that the decorator always wraps functions in `LazyJITFunc`.
Also applies to: 277-279
368-403: LGTM!
The `compile()` method correctly infers execution mode, validates that `out_idx` is only used in lazy mode (an important constraint), and properly sets the mode on `func` before retrieving TIR. The error message for eager mode with `out_idx` is helpful.
472-483: Do not allow `PrimFunc` in the `jit` decorator signature.
The type annotation `func: Callable[_P, _T] | PrimFunc | None` suggests `PrimFunc` input is supported, but the decorator calls `prim_func(func, lazy_jit=True)` without checking input type. The `prim_func()` function expects a Python function to analyze via `mutate()` (AST transformation), not a `PrimFunc` object. Passing a `PrimFunc` would fail because `mutate()` cannot perform AST analysis on TVM IR objects.
Either remove `PrimFunc` from the type annotation or add an explicit check to handle `PrimFunc` separately.
🔎 Proposed fix to remove PrimFunc from type
 def jit(
-    func: Callable[_P, _T] | PrimFunc | None = None,
+    func: Callable[_P, _T] | None = None,
     *,  # Indicates subsequent arguments are keyword-only
     out_idx: list[int] | int | None = None,
Also applies to: 522-536
Likely an incorrect or invalid review comment.
+        kwargs.update(kwargs.pop("__tune_params", {}))
+
+        # infer mode early, before parse_args needs it
+        if self.mode == "auto":
+            self.mode = self._infer_jit_mode()
+            self.func.set_mode(self.mode)
+
+        key, kernel_args = self.func.parse_args(*args, **kwargs)
+        kernel = self._kernel_cache.get(key, None)
+        if kernel is None:
+            kernel = self.compile(*args, **kwargs)
+            self._kernel_cache[key] = kernel
+
+        # eager mode: execute kernel immediately and return result
+        # lazy mode: return kernel object for manual invocation
+        if self.mode == "eager":
+            return kernel(*kernel_args.values())
+        else:
-        key = self.parse_cache_key(*args, **kwargs)
-        tune_params = kwargs.pop("__tune_params", {})
-        kernel = self._kernel_cache.get(key, None)
-        if kernel is None:
-            kernel = self.compile(*args, **kwargs, **tune_params)
-            self._kernel_cache[key] = kernel
+            return kernel
Critical: Handle None return from parse_args() to prevent crash.
LazyJITFunc.parse_args() can return None for kernel_args when there are no tensor arguments (see builder.py line 1007: if not tensor_args: return (p1_key, None), kwargs). However, line 446 unconditionally calls kernel(*kernel_args.values()), which will raise AttributeError: 'NoneType' object has no attribute 'values' in eager mode when the function has no tensor arguments.
🔎 Proposed fix to handle None case
     key, kernel_args = self.func.parse_args(*args, **kwargs)
     kernel = self._kernel_cache.get(key, None)
     if kernel is None:
         kernel = self.compile(*args, **kwargs)
         self._kernel_cache[key] = kernel
     # eager mode: execute kernel immediately and return result
     # lazy mode: return kernel object for manual invocation
     if self.mode == "eager":
-        return kernel(*kernel_args.values())
+        return kernel(*kernel_args.values()) if kernel_args else kernel()
     else:
         return kernel
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 430 - 448, parse_args() can return
kernel_args == None for no-tensor calls; change the eager-path in the JIT
wrapper (the block after self._kernel_cache lookup) to handle that: if
kernel_args is None call kernel() with no positional args, otherwise call
kernel(*kernel_args.values()). Update the code around self.func.parse_args(...)
and the eager branch that currently does kernel(*kernel_args.values()) so it
checks kernel_args is not None before accessing .values(), leaving cache
key/compile behavior unchanged.
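The guard requested in this thread can be sketched in isolation. The helper below is a hypothetical stand-in (`run_eager` is not a tilelang function); it only assumes what the comment states, namely that `kernel_args` is either `None` or a mapping whose values are the positional arguments for the kernel.

```python
from typing import Any, Callable, Mapping, Optional


def run_eager(kernel: Callable[..., Any], kernel_args: Optional[Mapping[str, Any]]) -> Any:
    """Hypothetical helper mirroring the eager branch discussed above."""
    if kernel_args is None:
        # No tensor arguments were parsed: call the kernel with no args
        # instead of dereferencing None.
        return kernel()
    # Otherwise forward the parsed tensors positionally, in mapping order.
    return kernel(*kernel_args.values())


no_arg_result = run_eager(lambda: "no-args", None)
two_arg_result = run_eager(lambda a, b: a + b, {"A": 1, "B": 2})
```

The proposed one-liner uses a truthiness test (`if kernel_args`) rather than `is None`; for an empty mapping both forms end up calling the kernel with no arguments, so the difference is stylistic.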
- Removed the `@simplify_prim_func` decorator from `tl_matmul` in `example_gemm_intrinsics.py`.
- Introduced a new `JITMode` literal in `__init__.py` to clarify execution modes.
- Updated the `jit` function to include a `mode` parameter, defaulting to "auto", enhancing flexibility in execution mode selection.
Actionable comments posted: 1
🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 437-446: parse_args can return kernel_args as None when there are
no tensor args, so calling kernel(*kernel_args.values()) will crash; update the
eager-mode branch in the function that calls self.func.parse_args to check
whether kernel_args is None (or falsy) and, if so, invoke kernel() with no
positional args, otherwise call kernel(*kernel_args.values()); reference the
symbols self.func.parse_args, kernel_args, kernel, and self.mode to locate and
modify the code path.
🧹 Nitpick comments (1)
tilelang/jit/__init__.py (1)
378-401: Local variable shadows imported `prim_func` function.
Line 378 assigns `prim_func = self.get_tir(...)`, which shadows the imported `prim_func` function from `tilelang.language.v2`. While this works because the import isn't used after this point in the method, consider renaming to `tir_func` or `pf` for clarity.
🔎 Suggested rename
-        prim_func = self.get_tir(*args, **kwargs)
+        tir_func = self.get_tir(*args, **kwargs)
         kernel_result = compile(
-            prim_func,
+            tir_func,
             out_idx=self.out_idx,
             ...
         )
 ...
         with open(path.join(self.debug_root_path, program_file), "w") as f:
-            print(prim_func.script(), file=f)
+            print(tir_func.script(), file=f)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
examples/gemm/example_gemm_intrinsics.py
tilelang/jit/__init__.py
💤 Files with no reviewable changes (1)
- examples/gemm/example_gemm_intrinsics.py
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/__init__.py (3)
tilelang/language/v2/builder.py (10)
LazyJITFunc (922-1042), set_mode (1028-1030), PrimFunc (684-693), _is_lazy_style (961-982), prim_func (182-189), prim_func (1067-1108), get_tir (909-918), get_tir (1016-1023), parse_args (1003-1014), get (219-220)
tilelang/jit/kernel.py (1)
out_idx (609-610)
tilelang/jit/adapter/wrapper.py (2)
prim_func (575-585), prim_func (839-849)
🪛 Ruff (0.14.10)
tilelang/jit/__init__.py
305-305: Prefer TypeError exception for invalid type
(TRY004)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
375-375: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
tilelang/jit/__init__.py (2)
529-542: LGTM on decorator implementation.
The unified decorator correctly wraps functions via `prim_func()`, captures source and signature, and passes the new `mode` parameter to `JITImpl`. The pattern of returning `decorator(func) if func is not None else decorator` properly supports both `@jit` and `@jit(...)` usage.
195-238: Clear and comprehensive documentation for the unified JIT interface.
The docstring provides excellent examples differentiating lazy and eager modes, making migration straightforward for users. The mode inference behavior is well-explained.
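The dual-usage decorator pattern praised in the 529-542 comment can be shown with a toy example. This is a sketch only: `jit_like` is a hypothetical stand-in that records a `mode` attribute and does no compilation, so it says nothing about how `tilelang.jit` is actually implemented beyond the `decorator(func) if func is not None else decorator` idiom quoted above.

```python
import functools
from typing import Any, Callable, Optional


def jit_like(func: Optional[Callable[..., Any]] = None, *, mode: str = "auto"):
    """Hypothetical stand-in for a unified decorator; not the tilelang API."""

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        wrapper.mode = mode  # remember the requested mode on the wrapper
        return wrapper

    # Bare use (@jit_like) passes the function directly; parameterized use
    # (@jit_like(mode=...)) passes func=None and expects a decorator back.
    return decorator(func) if func is not None else decorator


@jit_like
def add(a, b):
    return a + b


@jit_like(mode="eager")
def mul(a, b):
    return a * b
```

Making every option keyword-only (the bare `*` in the signature) is what keeps the two call forms unambiguous: the only positional argument the decorator can ever receive is the function being wrapped.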
key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
    kernel = self.compile(*args, **kwargs)
    self._kernel_cache[key] = kernel

# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
    return kernel(*kernel_args.values())
Critical: kernel_args can be None, causing crash in eager mode.
LazyJITFunc.parse_args() returns (p1_key, None), kwargs when there are no tensor arguments (see builder.py line 1007). Line 446 unconditionally calls kernel(*kernel_args.values()), which will raise AttributeError: 'NoneType' object has no attribute 'values' when the decorated function has no tensor arguments.
🔎 Proposed fix
     # eager mode: execute kernel immediately and return result
     # lazy mode: return kernel object for manual invocation
     if self.mode == "eager":
-        return kernel(*kernel_args.values())
+        return kernel(*kernel_args.values()) if kernel_args else kernel()
     else:
         return kernel
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| key, kernel_args = self.func.parse_args(*args, **kwargs) | |
| kernel = self._kernel_cache.get(key, None) | |
| if kernel is None: | |
| kernel = self.compile(*args, **kwargs) | |
| self._kernel_cache[key] = kernel | |
| # eager mode: execute kernel immediately and return result | |
| # lazy mode: return kernel object for manual invocation | |
| if self.mode == "eager": | |
| return kernel(*kernel_args.values()) | |
| key, kernel_args = self.func.parse_args(*args, **kwargs) | |
| kernel = self._kernel_cache.get(key, None) | |
| if kernel is None: | |
| kernel = self.compile(*args, **kwargs) | |
| self._kernel_cache[key] = kernel | |
| # eager mode: execute kernel immediately and return result | |
| # lazy mode: return kernel object for manual invocation | |
| if self.mode == "eager": | |
| return kernel(*kernel_args.values()) if kernel_args else kernel() |
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 437 - 446, parse_args can return
kernel_args as None when there are no tensor args, so calling
kernel(*kernel_args.values()) will crash; update the eager-mode branch in the
function that calls self.func.parse_args to check whether kernel_args is None
(or falsy) and, if so, invoke kernel() with no positional args, otherwise call
kernel(*kernel_args.values()); reference the symbols self.func.parse_args,
kernel_args, kernel, and self.mode to locate and modify the code path.
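The guard described in this prompt can be sketched in isolation as follows. `run_eager`, `no_arg_kernel`, and `two_arg_kernel` are hypothetical stand-ins rather than TileLang functions; the sketch only assumes that the argument mapping may be `None` when no tensor arguments were collected.

```python
from typing import Any, Callable, Dict, Optional


def run_eager(kernel: Callable[..., Any], kernel_args: Optional[Dict[str, Any]]) -> Any:
    """Call `kernel`, tolerating a missing argument mapping."""
    if kernel_args is None:
        # No tensor arguments were collected, so call with no positionals.
        return kernel()
    return kernel(*kernel_args.values())


def no_arg_kernel() -> str:
    return "ran without tensors"


def two_arg_kernel(a: int, b: int) -> int:
    return a + b


print(run_eager(no_arg_kernel, None))
print(run_eager(two_arg_kernel, {"A": 1, "B": 2}))
```

An explicit `is None` check and the truthiness check in the proposed fix behave the same here, since unpacking an empty mapping's values also yields a zero-argument call.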
…tection - Updated test functions to remove unnecessary `@T.prim_func` decorators and streamline variable assignments. - Enhanced lazy style detection in AST analysis to improve clarity and functionality. - Added logging to warn users when JIT mode cannot be auto-detected, encouraging explicit mode specification. - Refactored `LazyJITFunc` to return itself after setting the JIT mode for better chaining.
Actionable comments posted: 2
🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 476-488: The jit() signature claims func can be a PrimFunc but the
implementation always calls prim_func(func, lazy_jit=True) which expects a
Python function and will break for an already-constructed PrimFunc; either
remove PrimFunc from the type union in jit()’s func parameter, or add a branch
in jit() that detects isinstance(func, PrimFunc) (or checks type) and skips
calling prim_func/AST mutation, passing the PrimFunc through to JITImpl
creation; update references to prim_func, mutate, and the type annotation
accordingly so PrimFunc values are handled safely.
- Around line 437-448: LazyJITFunc can receive kernel_args==None from
LazyJITFunc.parse_args, so calling kernel(*kernel_args.values()) in the eager
branch can raise AttributeError; update the eager execution in the function that
uses self.mode (tilelang/jit/__init__.py) to handle a None kernel_args by
invoking kernel() with no arguments when kernel_args is None (instead of
unconditionally using kernel_args.values()), e.g., check if kernel_args is None
and call kernel() else call kernel(*kernel_args.values()) so functions without
tensor args execute correctly.
🧹 Nitpick comments (3)
tilelang/language/v2/ast.py (1)
624-630: Add parentheses to clarify operator precedence in condition. The condition mixes `and` and `or` operators without explicit parentheses. While Python's precedence rules make this work correctly, the intent is clearer with explicit grouping.
🔎 Proposed fix
if (
-    isinstance(decorator, ast.Attribute)
-    and decorator.attr == "prim_func"
-    or isinstance(decorator, ast.Name)
-    and decorator.id == "prim_func"
+    (isinstance(decorator, ast.Attribute) and decorator.attr == "prim_func")
+    or (isinstance(decorator, ast.Name) and decorator.id == "prim_func")
):
tilelang/language/v2/builder.py (2)
840-842: Confusing field name `builder.lazy_jit` for eager mode context. The field `builder.lazy_jit = True` is set when using eager mode (line 992), but the name suggests "lazy JIT". The assertion message correctly says "eager mode", but the field name is misleading. Consider renaming to `builder.is_jit` or `builder.from_jit_decorator` for clarity.
1028-1031: Type mismatch in `set_mode` parameter. The `mode` attribute is typed as `Literal["auto", "lazy", "eager"]` (line 944), but `set_mode` accepts `Literal["lazy", "eager"]`. This is intentional since "auto" shouldn't be set after inference, but consider adding a comment or updating the docstring to clarify this.
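One way to make the narrower `set_mode` contract explicit is sketched below. `ModeHolder` and `ResolvedMode` are hypothetical names introduced for illustration, and the runtime check is an assumption rather than something the PR is known to include.

```python
from typing import Literal, get_args

JITMode = Literal["auto", "lazy", "eager"]
ResolvedMode = Literal["lazy", "eager"]  # hypothetical alias for the narrower set


class ModeHolder:
    """Hypothetical holder; `mode` stays "auto" until it is resolved."""

    def __init__(self) -> None:
        self.mode: JITMode = "auto"

    def set_mode(self, mode: ResolvedMode) -> "ModeHolder":
        # "auto" is deliberately rejected: it is only a pre-inference placeholder.
        if mode not in get_args(ResolvedMode):
            raise ValueError(f"mode must be one of {get_args(ResolvedMode)}, got {mode!r}")
        self.mode = mode
        return self  # returning self allows chaining


holder = ModeHolder()
print(holder.set_mode("lazy").mode)
```

Naming the narrower alias documents the intent in the signature itself, so a separate comment explaining why "auto" is excluded becomes optional.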
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
testing/python/language/test_tilelang_language_frontend_v2.py
tilelang/jit/__init__.py
tilelang/language/v2/ast.py
tilelang/language/v2/builder.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.
Applied to files:
testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
testing/python/language/test_tilelang_language_frontend_v2.py
🧬 Code graph analysis (3)
tilelang/language/v2/ast.py (3)
tilelang/language/ast/ir.py (1)
func_name(206-214)
tilelang/jit/__init__.py (1)
decorator(529-541)
tilelang/language/v2/builder.py (1)
source(707-708)
tilelang/language/v2/builder.py (3)
tilelang/jit/adapter/cython/adapter.py (1)
prim_func(356-358)
tilelang/language/v2/ast.py (1)
IRGenerator(583-587)
tilelang/jit/__init__.py (1)
get_tir(291-307)
testing/python/language/test_tilelang_language_frontend_v2.py (2)
tilelang/jit/__init__.py (3)
jit(458-458)
jit(462-473)
jit(476-543)
tilelang/language/allocate.py (5)
empty(267-267)
empty(270-278)
alloc_var(86-86)
alloc_var(90-90)
alloc_var(93-147)
🪛 Ruff (0.14.10)
tilelang/language/v2/ast.py
625-626: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
627-628: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear
Parenthesize the and subexpression
(RUF021)
tilelang/language/v2/builder.py
1001-1001: Avoid specifying long messages outside the exception class
(TRY003)
1043-1043: Avoid specifying long messages outside the exception class
(TRY003)
1068-1068: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
tilelang/jit/__init__.py
305-305: Prefer TypeError exception for invalid type
(TRY004)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
375-375: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (18)
tilelang/language/v2/ast.py (5)
3-3: LGTM: Logger setup for lazy-style detection warnings. The logging module import and logger initialization are appropriate for providing diagnostic warnings when the JIT mode cannot be auto-detected.
Also applies to: 20-21
587-588: LGTM: Added `is_lazy_style` flag to IRGenerator. The new field with default `False` maintains backward compatibility while enabling downstream components to tailor behavior based on detected style.
590-598: LGTM: Helper to detect return statements with values. The function correctly identifies functions with meaningful return values while filtering out explicit `return None` statements.
663-670: LGTM: Helpful warning for ambiguous style detection. When the function has a return value but no recognizable lazy/eager patterns, the warning appropriately guides users to specify the mode explicitly.
703-704: LGTM: Lazy style detection integrated into mutate flow. The detection is performed before AST transformation and correctly propagated to the resulting `IRGenerator`.
Also applies to: 730-730
tilelang/jit/__init__.py (2)
195-265: LGTM: Clear documentation of lazy vs eager execution modes. The updated docstring provides excellent examples for both modes, and the attribute annotations correctly reflect the new mode-aware interface.
368-403: LGTM: Mode-aware compilation with proper validation. The method correctly infers mode, validates that `out_idx` is only used in lazy mode, and propagates mode to the wrapped function before compilation.
tilelang/language/v2/builder.py (6)
20-20: LGTM: Extended typing imports for mode support. The addition of `Literal` and `get_origin` supports the new mode type hints and callable annotation handling.
853-918: LGTM: TirTemplate extended with lazy-style support. The addition of `is_lazy_style` flag, `from_lazy_style()` factory method, and short-circuit in `get_tir()` cleanly separates lazy and eager template handling.
921-944: LGTM: Updated LazyJITFunc documentation and mode support. The docstring clearly explains the two execution styles, and the mode attribute enables explicit mode control when auto-detection is insufficient.
961-982: Simplified `_is_lazy_style()` using static AST analysis. The method now delegates to `ir_gen.is_lazy_style`, which was computed during `mutate()`. This is more reliable than the previous runtime detection approach.
984-1001: LGTM: Mode-aware TIR template building. The `_build_tir_template` method correctly routes to `TirTemplate.from_lazy_style()` for lazy mode and the Builder-based path for eager mode.
1082-1086: LGTM: Safe callable annotation evaluation. The check using `get_origin()` correctly skips typing generics like `Optional`, `Union`, `List` which are callable but cannot be instantiated.
testing/python/language/test_tilelang_language_frontend_v2.py (5)
206-221: LGTM: Test migrated to unified `@tilelang.jit` decorator. The test correctly uses eager mode with `T.empty()` for output allocation and returns the result tensor.
258-292: LGTM: Stepped serial tests migrated to unified decorator. Both positive and negative step tests correctly use the eager mode pattern with explicit returns.
294-319: LGTM: Swap tests migrated with in-place tensor modification. The tests correctly use eager mode with tensor type annotations inside the function body, demonstrating in-place modification semantics.
324-338: LGTM: While loop test migrated to unified decorator. The test correctly uses eager mode with `T.empty()` output and validates while loop accumulation.
457-477: LGTM: Constexpr if test demonstrates mode parameterization. The test correctly shows eager mode with a boolean parameter controlling compile-time branching, verifying both true and false paths.
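The return-value helper noted for `ast.py` (590-598) above could look roughly like the sketch below, which uses only the standard `ast` module. `has_return_value` is a hypothetical name, and unlike the real helper it may behave differently for nested functions, since `ast.walk` visits every node in the source.

```python
import ast
import textwrap


def has_return_value(source: str) -> bool:
    """True if any `return` in `source` carries a value other than literal None."""
    tree = ast.parse(textwrap.dedent(source))
    for node in ast.walk(tree):
        if isinstance(node, ast.Return) and node.value is not None:
            # `return None` parses as a Constant whose value is None; skip it.
            if isinstance(node.value, ast.Constant) and node.value.value is None:
                continue
            return True
    return False


print(has_return_value("def f():\n    return 1\n"))
print(has_return_value("def f():\n    return None\n"))
```

A bare `return` has `node.value is None` at the AST level, so it is filtered by the first condition, while an explicit `return None` needs the extra `Constant` check.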
| key, kernel_args = self.func.parse_args(*args, **kwargs) | ||
| kernel = self._kernel_cache.get(key, None) | ||
| if kernel is None: | ||
| kernel = self.compile(*args, **kwargs) | ||
| self._kernel_cache[key] = kernel | ||
|
|
||
| # eager mode: execute kernel immediately and return result | ||
| # lazy mode: return kernel object for manual invocation | ||
| if self.mode == "eager": | ||
| return kernel(*kernel_args.values()) | ||
| else: | ||
| key = self.parse_cache_key(*args, **kwargs) | ||
| tune_params = kwargs.pop("__tune_params", {}) | ||
| kernel = self._kernel_cache.get(key, None) | ||
| if kernel is None: | ||
| kernel = self.compile(*args, **kwargs, **tune_params) | ||
| self._kernel_cache[key] = kernel | ||
| return kernel |
Critical: kernel_args can be None, causing crash in eager mode.
LazyJITFunc.parse_args() returns (p1_key, None) when there are no tensor arguments (see builder.py line 1006-1007). Line 446 unconditionally calls kernel(*kernel_args.values()), which will raise AttributeError: 'NoneType' object has no attribute 'values' when the decorated function has no tensor arguments.
🔎 Proposed fix
# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
- return kernel(*kernel_args.values())
+ return kernel(*kernel_args.values()) if kernel_args else kernel()
else:
return kernel
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 437 - 448, LazyJITFunc can receive
kernel_args==None from LazyJITFunc.parse_args, so calling
kernel(*kernel_args.values()) in the eager branch can raise AttributeError;
update the eager execution in the function that uses self.mode
(tilelang/jit/__init__.py) to handle a None kernel_args by invoking kernel()
with no arguments when kernel_args is None (instead of unconditionally using
kernel_args.values()), e.g., check if kernel_args is None and call kernel() else
call kernel(*kernel_args.values()) so functions without tensor args execute
correctly.
| def jit( | ||
| func: Callable[_P, _T] | PrimFunc | None = None, | ||
| *, # Indicates subsequent arguments are keyword-only | ||
| out_idx: list[int] | int | None = None, | ||
| target: str | Target | None = None, | ||
| target_host: str | Target | None = None, | ||
| execution_backend: ExecutionBackend | None = None, | ||
| verbose: bool | None = None, | ||
| pass_configs: dict[str, Any] | None = None, | ||
| debug_root_path: str | None = None, | ||
| compile_flags: list[str] | str | None = None, | ||
| ): | ||
| mode: JITMode = "auto", | ||
| ) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]: |
Type annotation allows PrimFunc but implementation doesn't handle it.
The signature includes func: Callable[_P, _T] | PrimFunc | None, but line 530 calls prim_func(func, lazy_jit=True) without checking if func is already a PrimFunc. The prim_func() function performs AST transformation via mutate(), which expects a Python function, not a TVM IR object. Either remove PrimFunc from the type annotation or add a branch to handle it.
🔎 Proposed fix: Remove PrimFunc from signature
def jit(
- func: Callable[_P, _T] | PrimFunc | None = None,
+ func: Callable[_P, _T] | None = None,
*, # Indicates subsequent arguments are keyword-only
📝 Committable suggestion
| def jit( | |
| func: Callable[_P, _T] | PrimFunc | None = None, | |
| *, # Indicates subsequent arguments are keyword-only | |
| out_idx: list[int] | int | None = None, | |
| target: str | Target | None = None, | |
| target_host: str | Target | None = None, | |
| execution_backend: ExecutionBackend | None = None, | |
| verbose: bool | None = None, | |
| pass_configs: dict[str, Any] | None = None, | |
| debug_root_path: str | None = None, | |
| compile_flags: list[str] | str | None = None, | |
| ): | |
| mode: JITMode = "auto", | |
| ) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]: | |
| def jit( | |
| func: Callable[_P, _T] | None = None, | |
| *, # Indicates subsequent arguments are keyword-only | |
| out_idx: list[int] | int | None = None, | |
| target: str | Target | None = None, | |
| target_host: str | Target | None = None, | |
| execution_backend: ExecutionBackend | None = None, | |
| verbose: bool | None = None, | |
| pass_configs: dict[str, Any] | None = None, | |
| debug_root_path: str | None = None, | |
| compile_flags: list[str] | str | None = None, | |
| mode: JITMode = "auto", | |
| ) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]: |
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 476 - 488, The jit() signature claims
func can be a PrimFunc but the implementation always calls prim_func(func,
lazy_jit=True) which expects a Python function and will break for an
already-constructed PrimFunc; either remove PrimFunc from the type union in
jit()’s func parameter, or add a branch in jit() that detects isinstance(func,
PrimFunc) (or checks type) and skips calling prim_func/AST mutation, passing the
PrimFunc through to JITImpl creation; update references to prim_func, mutate,
and the type annotation accordingly so PrimFunc values are handled safely.
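The second option in this prompt, branching on the argument type, can be sketched as below. `FakePrimFunc`, `trace_python_function`, and `normalize` are hypothetical stand-ins; the sketch does not use TVM or TileLang and only illustrates the pass-through idea.

```python
from typing import Callable, Union


class FakePrimFunc:
    """Hypothetical stand-in for an already-built IR function object."""

    def __init__(self, name: str) -> None:
        self.name = name


def trace_python_function(fn: Callable) -> FakePrimFunc:
    """Hypothetical stand-in for the AST-based path that needs Python source."""
    return FakePrimFunc(f"traced:{fn.__name__}")


def normalize(func: Union[Callable, FakePrimFunc]) -> FakePrimFunc:
    if isinstance(func, FakePrimFunc):
        # Already IR: pass it through and skip source-level transformation.
        return func
    return trace_python_function(func)


def my_kernel():
    pass


prebuilt = FakePrimFunc("prebuilt")
print(normalize(prebuilt).name)
print(normalize(my_kernel).name)
```

If supporting prebuilt IR objects is not a goal, narrowing the type annotation as in the proposed fix is the smaller change.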
Summary
This PR unifies `@tilelang.jit` and `@tilelang.lazy_jit` into a single `@tilelang.jit` decorator that automatically infers the execution mode.
Motivation
Previously, users had to choose between two decorators:
- `@tilelang.jit` - for "lazy" style (returns kernel object)
- `@tilelang.lazy_jit` - for "eager" style (executes immediately)
This was confusing because:
- (`lazy_jit` actually executes eagerly)
Changes
API Simplification
Now there's only one decorator, `@tilelang.jit`, that automatically detects the style:
Lazy mode - function explicitly returns a PrimFunc:
Eager mode - function uses DSL builder pattern:
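As a rough illustration of how the two styles might be told apart statically, the sketch below inspects source text with the standard `ast` module and treats a function as lazy when it defines an inner function decorated with `prim_func`. The helper names and the sample sources are hypothetical, and TileLang's actual detection logic may use different or additional signals.

```python
import ast
import textwrap


def is_prim_func_decorator(decorator: ast.expr) -> bool:
    # Matches both `@T.prim_func` (Attribute) and bare `@prim_func` (Name).
    return (
        (isinstance(decorator, ast.Attribute) and decorator.attr == "prim_func")
        or (isinstance(decorator, ast.Name) and decorator.id == "prim_func")
    )


def looks_lazy(source: str) -> bool:
    """Guess "lazy" if the outer function defines a prim_func-decorated inner function."""
    tree = ast.parse(textwrap.dedent(source))
    outer = tree.body[0]
    if not isinstance(outer, ast.FunctionDef):
        return False
    for node in ast.walk(outer):
        if node is outer or not isinstance(node, ast.FunctionDef):
            continue
        if any(is_prim_func_decorator(d) for d in node.decorator_list):
            return True
    return False


lazy_src = """
def build(M):
    @T.prim_func
    def kernel(A):
        pass
    return kernel
"""

eager_src = """
def run(A):
    out = A
    return out
"""

print(looks_lazy(lazy_src), looks_lazy(eager_src))
```

Because the check works on the parsed source, names such as `T` never need to resolve, which is what allows the mode to be inferred before any compilation happens.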
Internal Naming Improvements
- `v1`/`v2` to `lazy`/`eager` for clarity
- `legacy_jit` → `is_lazy_style`
- `_is_legacy_jit()` → `_is_lazy_style()`
Files Changed
- `tilelang/__init__.py` - removed `lazy_jit` export
- `tilelang/jit/__init__.py` - unified decorator implementation
- `tilelang/language/v2/builder.py` - updated internal naming and docs
- `testing/python/language/test_tilelang_language_lazy_jit.py` - updated to use `@tilelang.jit`
- `testing/python/layout/test_tilelang_annotate_loop_layout.py` - updated to use `@tilelang.jit`
- `examples/lazy_jit/lazyjit.*.ipynb` - updated examples
Backward Compatibility
- `@tilelang.lazy_jit` is removed (breaking change)
- `@tilelang.jit` code continues to work unchanged
- `@tilelang.lazy_jit` with `@tilelang.jit`
Test Plan
🤖 Generated with Claude Code
Summary by CodeRabbit
Breaking Changes
New Features
Updates
Other