Skip to content

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Jan 6, 2026

Summary

This PR unifies @tilelang.jit and @tilelang.lazy_jit into a single @tilelang.jit decorator that automatically infers the execution mode.

Motivation

Previously, users had to choose between two decorators:

  • @tilelang.jit - for "lazy" style (returns kernel object)
  • @tilelang.lazy_jit - for "eager" style (executes immediately)

This was confusing because:

  1. The naming was counterintuitive (lazy_jit actually executes eagerly)
  2. Users had to understand the internal differences to choose correctly
  3. Two separate APIs for similar functionality added cognitive overhead

Changes

API Simplification

Now there's only one decorator @tilelang.jit that automatically detects the style:

Lazy mode - function explicitly returns a PrimFunc:

@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K):
    @T.prim_func
    def kernel(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype)):
        ...
    return kernel  # explicitly return PrimFunc

kernel = matmul(1024, 1024, 1024, 128, 128, 32)  # returns kernel object
result = kernel(a, b)  # execute separately

Eager mode - function uses DSL builder pattern:

@tilelang.jit
def gemm(A, B, C, block_M: int = 64):
    M, N, K = T.const("M N K")
    A: T.Tensor[[M, K], dtype]
    B: T.Tensor[[K, N], dtype]
    C: T.Tensor[[M, N], dtype]
    with T.Kernel(...):
        ...
    # no return, builder constructs TIR implicitly

gemm(A, B, C)  # compiles and executes immediately

Internal Naming Improvements

  • Renamed internal mode from v1/v2 to lazy/eager for clarity
  • legacy_jitis_lazy_style
  • _is_legacy_jit()_is_lazy_style()
  • Added comprehensive docstrings explaining the two modes

Files Changed

  • tilelang/__init__.py - removed lazy_jit export
  • tilelang/jit/__init__.py - unified decorator implementation
  • tilelang/language/v2/builder.py - updated internal naming and docs
  • testing/python/language/test_tilelang_language_lazy_jit.py - updated to use @tilelang.jit
  • testing/python/layout/test_tilelang_annotate_loop_layout.py - updated to use @tilelang.jit
  • examples/lazy_jit/lazyjit.*.ipynb - updated examples

Backward Compatibility

  • @tilelang.lazy_jit is removed (breaking change)
  • All existing @tilelang.jit code continues to work unchanged
  • Migration: simply replace @tilelang.lazy_jit with @tilelang.jit

Test Plan

  • Existing tests pass with updated decorator
  • Example notebooks work correctly
  • Both lazy and eager modes function as expected

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Breaking Changes

    • Replaced lazy-specific decorator with unified @tilelang.jit and removed lazy_jit from public exports
  • New Features

    • Mode-aware JIT execution with "auto", "lazy", and "eager" modes
  • Updates

    • Simplified jit decorator surface and unified runtime behavior; many examples and tests updated to use @tilelang.jit
    • Frontend/tests adjusted to return/allocate tensors inline
  • Other

    • Minor tuning/decorator cleanups in examples

✏️ Tip: You can customize this high-level summary in your review settings.

…nd tests

Updated all instances of the @tilelang.lazy_jit decorator to @tilelang.jit in the lazyjit example notebooks and related test files. This change aligns the code with the new JIT compilation approach, enhancing consistency across the codebase. Additionally, removed the lazy_jit import from the module initialization to streamline the API.
…l loops

Updated the T.copy function to accept a new keyword-only parameter, loop_layout, allowing users to specify layout hints for the outermost parallel loop. Enhanced layout annotation handling in CopyNode and AtomicAddNode classes to ensure compatibility with SIMT operations. Added tests to validate the functionality of loop layout annotations, improving robustness and clarity in layout management for parallel loops.
@github-actions
Copy link

github-actions bot commented Jan 6, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 6, 2026

📝 Walkthrough

Walkthrough

Replaces usages of @tilelang.lazy_jit with @tilelang.jit across examples and tests, removes lazy_jit from tilelang exports, and refactors JIT and v2 frontend components to a mode-aware execution model (mode: "auto" | "lazy" | "eager") with unified jit APIs and updated get_tir/compile flows.

Changes

Cohort / File(s) Change Summary
Example Notebooks
examples/lazy_jit/lazyjit.en.ipynb, examples/lazy_jit/lazyjit.zh.ipynb
Replaced all @tilelang.lazy_jit occurrences with @tilelang.jit; notebook logic unchanged.
Tests
testing/python/language/test_tilelang_language_lazy_jit.py, testing/python/language/test_tilelang_language_subtype.py, testing/python/layout/test_tilelang_annotate_loop_layout.py, testing/python/language/test_tilelang_language_frontend_v2.py
Replaced @tilelang.lazy_jit with @tilelang.jit and adapted several test functions to the new jit/prim_func semantics (inlined allocations/returns).
Public API Surface
tilelang/__init__.py
Removed lazy_jit from public exports; now re-exports jit, JITKernel, compile, par_compile.
JIT Core Implementation
tilelang/jit/__init__.py
Reworked JITImpl to use mode: Literal["auto","lazy","eager"]; removed boolean lazy_jit; added _infer_jit_mode, mode-aware get_tir, compile, and __call__ flows; simplified public jit interface and overloads.
Language V2 Builder & AST
tilelang/language/v2/builder.py, tilelang/language/v2/ast.py
Added lazy-vs-eager detection (is_lazy_style), TirTemplate.from_lazy_style, LazyJITFunc.mode and helpers (set_mode, _build_tir_template, _is_lazy_style); parse/get_tir now handle lazy vs eager template flows and caching.
Frontend v2 Tests / Call Sites
testing/python/language/test_tilelang_language_frontend_v2.py
Converted several prim_func-backed tests into in-function @tilelang.jit functions that allocate and return tensors; adjusted signatures and call sites accordingly.
Examples / Minor Edits
examples/attention_sink/example_mha_sink_fwd_bhsd.py, examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py, examples/gemm/example_gemm_intrinsics.py
Removed warmup/rep args from an @autotune(...) decorator; removed @simplify_prim_func import/use in two gemm examples.
Submodule Pointer
3rdparty/tvm
Updated git submodule pointer to a newer commit; no source edits.

Sequence Diagram(s)

sequenceDiagram
  participant User
  participant Decorator as "tilelang.jit (decorator)"
  participant JITImpl
  participant Builder
  participant Backend as "Compiler/Backend"

  User->>Decorator: apply `@tilelang.jit`(func)
  Decorator->>JITImpl: construct JITImpl(mode="auto", func=LazyJITFunc)
  User->>JITImpl: call(...) or get_tir(...)
  JITImpl->>JITImpl: _infer_jit_mode(args...) 
  alt inferred == "lazy"
    JITImpl->>Builder: request PrimFunc / TirTemplate (lazy-style)
    Builder-->>JITImpl: PrimFunc or TirTemplate
    JITImpl->>User: return kernel object / PrimFunc (deferred)
  else inferred == "eager"
    JITImpl->>Builder: trace/build PrimFunc (eager-style)
    Builder-->>JITImpl: PrimFunc
    JITImpl->>Backend: compile & execute
    Backend-->>JITImpl: result
    JITImpl->>User: return result
  end
  note right of JITImpl: mode cached after first inference
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Poem

🐇 I swapped my lazy hat for one more spry,
Modes triple-checked beneath the winter sky,
Decorators united, kernels learn to pick,
Templates whisper "lazy" or "eager" quick,
Hop, compile, return — a tiny joyful trick.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 24.68% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately and concisely describes the main change: unifying two decorators (@jit and @lazy_jit) into a single @jit decorator with automatic mode inference.
✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/jit/__init__.py (2)

363-378: Potential AttributeError when func is a PrimFunc.

The get_tir method (line 295-296) allows self.func to be a PrimFunc directly, but line 367 unconditionally calls self.func.set_mode(self.mode). If self.func is a PrimFunc instead of a LazyJITFunc, this will raise an AttributeError since PrimFunc doesn't have a set_mode method.

🔎 Proposed fix
     def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         # infer jit mode on first compile
         if self.mode == "auto":
             self.mode = self._infer_jit_mode(*args, **kwargs)
-        self.func.set_mode(self.mode)
+        if hasattr(self.func, 'set_mode'):
+            self.func.set_mode(self.mode)
         prim_func = self.get_tir(*args, **kwargs)

188-189: Unreachable code: duplicate return statement.

Line 189 is unreachable because the function already returns on line 188. This appears to be leftover code.

🔎 Proposed fix
         for future in tqdm(
             concurrent.futures.as_completed(futures),
             total=len(futures),
             desc="Parallel Compiling",
         ):
             idx = future_map[future]
             if ignore_error:
                 try:
                     results[idx] = future.result()
                 except Exception as e:
                     logger.warning(f"Error compiling function at index {idx}: {e}")
                     results[idx] = None
             else:
                 results[idx] = future.result()
         return results
-    return results
🧹 Nitpick comments (8)
examples/lazy_jit/lazyjit.en.ipynb (1)

24-24: Update notebook documentation to reflect the unified decorator.

The notebook title and several text cells still reference "Lazy JIT" terminology, which may confuse users now that lazy_jit has been removed from the public API. Consider updating:

  • Line 24: Title "Tilelang Lazy JIT"
  • Line 40: Reference to "Tilelang Lazy JIT"
  • Line 554: "LazyJIT has very small overhead"
  • Line 621: "Both lazyjit and the original jit"

These should be updated to reflect the unified @tilelang.jit decorator and explain that the mode (lazy vs eager) is automatically inferred.

Also applies to: 40-40, 554-554, 621-621

tilelang/__init__.py (1)

145-145: Remove the unused noqa directive.

The static analysis tool correctly identified that the # noqa: F401 comment is unnecessary. Since all imported items are now used or intentionally exported, the directive can be removed.

🔎 Proposed fix
-    from .jit import jit, JITKernel, compile, par_compile  # noqa: F401
+    from .jit import jit, JITKernel, compile, par_compile
examples/lazy_jit/lazyjit.zh.ipynb (1)

24-24: Update Chinese notebook documentation to reflect the unified decorator.

Similar to the English version, the Chinese notebook's title and text still reference "Lazy JIT" terminology. Consider updating the documentation to reflect that @tilelang.jit is now the unified decorator that automatically infers execution mode.

Also applies to: 40-40, 554-554, 621-621

tilelang/jit/__init__.py (1)

299-301: Consider using TypeError for invalid type.

Per Python conventions (PEP), TypeError is more appropriate when an argument has an incorrect type.

🔎 Proposed fix
-        else:
-            raise ValueError(f"Invalid function type: {type(self.func)}")
+        else:
+            raise TypeError(f"Invalid function type: {type(self.func)}")
tilelang/language/v2/builder.py (4)

846-848: Consider improving assertion error messages.

Both assertions produce the same error message but check different conditions. The first checks if we're inside a JIT context, the second checks if the builder is in JIT mode. Clearer messages would help debugging.

🔎 Proposed fix
     builder = Builder.current()
-    assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
-    assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
+    assert builder is not None, "T.const() must be called inside a @tilelang.jit decorated function"
+    assert builder.lazy_jit, "T.const() requires eager mode (DSL builder pattern with tensor annotations)"

985-990: Broad exception catch in style detection.

Catching all Exception types is intentional here (to treat any failure as "not lazy style"), but it may mask genuine bugs in user code. Consider logging at a higher level or limiting to specific exceptions if possible.

🔎 Proposed fix
         try:
             result = self.orig_func(*args, **kwargs)
             return isinstance(result, PrimFunc)
-        except Exception:
-            logger.debug("Function doesn't return PrimFunc directly, treating as eager style")
+        except Exception as e:
+            logger.debug(f"Function doesn't return PrimFunc directly, treating as eager style: {e}")
             return False

1024-1028: Outdated comment referencing "legacy gemm".

The comment on line 1025 references "legacy gemm" which doesn't align with the new lazy/eager terminology. Consider updating for consistency.

🔎 Proposed fix
         (p1_key, _), tensor_args = self.parse_args(*args, **kwargs)
         if p1_key not in self.p1_cache:
-            # in legacy gemm, we use lazy tir template to build the tir
+            # build and cache the TIR template on first call with this config
             tir_temp = self._build_tir_template(*args, **kwargs)
             self.p1_cache[p1_key] = tir_temp
             return tir_temp.get_tir(**tensor_args)

1061-1061: Implicit Optional type hint.

PEP 484 recommends explicit Optional[X] rather than X = None for default None parameters. This improves type checker compatibility.

🔎 Proposed fix
+from typing import Optional
+
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Optional[Callable[_P, _T]] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9d446f3 and bfecb9b.

📒 Files selected for processing (7)
  • examples/lazy_jit/lazyjit.en.ipynb
  • examples/lazy_jit/lazyjit.zh.ipynb
  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/layout/test_tilelang_annotate_loop_layout.py
  • tilelang/__init__.py
  • tilelang/jit/__init__.py
  • tilelang/language/v2/builder.py
🧰 Additional context used
🧠 Learnings (4)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/layout/test_tilelang_annotate_loop_layout.py
  • examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2025-12-26T06:45:51.789Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:51.789Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.

Applied to files:

  • examples/lazy_jit/lazyjit.en.ipynb
  • examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/lazy_jit/lazyjit.en.ipynb
  • examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/lazy_jit/lazyjit.en.ipynb
  • testing/python/layout/test_tilelang_annotate_loop_layout.py
  • examples/lazy_jit/lazyjit.zh.ipynb
🧬 Code graph analysis (4)
tilelang/__init__.py (2)
tilelang/jit/__init__.py (5)
  • jit (439-439)
  • jit (443-453)
  • jit (456-496)
  • compile (47-107)
  • compile (363-393)
tilelang/jit/kernel.py (1)
  • JITKernel (38-767)
testing/python/language/test_tilelang_language_lazy_jit.py (3)
tilelang/jit/__init__.py (3)
  • jit (439-439)
  • jit (443-453)
  • jit (456-496)
tilelang/language/v2/builder.py (1)
  • const (831-856)
tilelang/language/symbolics.py (1)
  • dynamic (12-29)
testing/python/layout/test_tilelang_annotate_loop_layout.py (1)
tilelang/jit/__init__.py (3)
  • jit (439-439)
  • jit (443-453)
  • jit (456-496)
tilelang/language/v2/builder.py (1)
tilelang/jit/__init__.py (1)
  • get_tir (291-302)
🪛 Ruff (0.14.10)
tilelang/__init__.py

145-145: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/__init__.py

300-300: Prefer TypeError exception for invalid type

(TRY004)


300-300: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/v2/builder.py

988-988: Do not catch blind exception: Exception

(BLE001)


1009-1009: Avoid specifying long messages outside the exception class

(TRY003)


1061-1061: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (10)
examples/lazy_jit/lazyjit.en.ipynb (1)

56-56: LGTM! Decorator changes are consistent.

All decorator changes from @tilelang.lazy_jit to @tilelang.jit are correctly applied.

Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860

tilelang/__init__.py (1)

145-145: LGTM! Breaking change correctly implements the API unification.

The removal of lazy_jit from the public exports is correct and aligns with the PR's stated goal of unifying the decorators. Users should migrate to @tilelang.jit as documented in the PR objectives.

testing/python/language/test_tilelang_language_lazy_jit.py (1)

9-9: LGTM! Test updates are comprehensive and correct.

All test function decorators have been correctly updated from @tilelang.lazy_jit to @tilelang.jit. The test coverage spans various annotation patterns (T.const, T.dynamic, T.Tensor, T.StridedTensor), ensuring the unified decorator works correctly across different usage scenarios.

Also applies to: 48-48, 105-105, 112-112, 119-119, 126-126, 134-134, 141-141, 178-178, 184-184, 189-189, 195-195, 202-202, 208-208

examples/lazy_jit/lazyjit.zh.ipynb (1)

56-56: LGTM! Decorator changes are consistent.

All decorator changes from @tilelang.lazy_jit to @tilelang.jit are correctly applied throughout the Chinese notebook.

Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860

testing/python/layout/test_tilelang_annotate_loop_layout.py (1)

7-7: LGTM! Layout test updates are correct.

All three kernel function decorators have been correctly updated from @tilelang.lazy_jit to @tilelang.jit. The tests cover loop layout annotation scenarios and will verify that the unified decorator works correctly with the layout system.

Also applies to: 54-54, 82-82

tilelang/jit/__init__.py (4)

192-279: LGTM!

The comprehensive docstring with examples for both lazy and eager modes significantly improves developer experience. The attribute documentation is clear and the type annotations are well-defined.


304-316: LGTM!

The mode inference logic is clean with appropriate early returns for explicit modes and sensible fallback behavior.


485-494: LGTM!

The decorator correctly wraps the function in a LazyJITFunc and initializes JITImpl with mode="auto" for automatic inference. The comment on line 486 could be slightly clarified to distinguish between the internal lazy_jit wrapper flag and the user-facing execution mode.


420-432: Type annotation for func is explicit and unambiguous—no type guard needed.

The field func at line 279-280 is explicitly typed as func: LazyJITFunc[_KP, _T], not a union or generic allowing PrimFunc. The jit decorator always wraps the input function with prim_func(..., lazy_jit=True) before passing it to JITImpl, ensuring self.func is always LazyJITFunc at runtime. The parse_args call at line 421 is safe. The defensive isinstance checks in the get_tir method are unrelated to the __call__ code path.

Likely an incorrect or invalid review comment.

tilelang/language/v2/builder.py (1)

859-924: LGTM!

The TirTemplate refactoring cleanly separates lazy and eager paths. The from_lazy_style factory method and the short-circuit in get_tir are well-designed for the dual execution model.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 428-433: parse_args can return None for kernel_args when there are
no tensor arguments, but the code unconditionally calls
kernel(*kernel_args.values()), which will raise when kernel_args is None; update
the call site in the wrapper (the block that gets kernel from _kernel_cache,
calls self.compile, then invokes kernel) to handle the None case by invoking
kernel() when kernel_args is falsy (e.g., use kernel(*kernel_args.values()) if
kernel_args else kernel()), or change parse_args to always return an empty dict;
ensure this fix touches the call that currently does
kernel(*kernel_args.values()) and keep the compile/_kernel_cache logic
unchanged.
- Around line 277-279: The runtime isinstance checks in get_tir and
_infer_jit_mode don't match the declared JITImpl.func type (LazyJITFunc[_KP,
_T]) and are redundant because the decorator always wraps the input with
prim_func(..., lazy_jit=True); update the code to either remove the
PrimFunc/Callable isinstance branches in get_tir and _infer_jit_mode and treat
func as LazyJITFunc (using LazyJITFunc methods/properties directly), or change
the annotation of JITImpl.func to PrimFunc[_KP,_T] | LazyJITFunc[_KP,_T] |
Callable[_KP,_T] and keep the checks—pick one consistent approach and apply it
to both get_tir and _infer_jit_mode, referencing JITImpl.func, LazyJITFunc,
PrimFunc, prim_func and the decorator wrapping logic to ensure types and runtime
checks align.
🧹 Nitpick comments (1)
tilelang/jit/__init__.py (1)

295-302: Consider using TypeError for type validation.

When the function type is invalid, TypeError is more semantically appropriate than ValueError for type-related errors.

Proposed fix
         else:
-            raise ValueError(f"Invalid function type: {type(self.func)}")
+            raise TypeError(f"Invalid function type: {type(self.func)}")

Based on static analysis hints.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bfecb9b and d173672.

📒 Files selected for processing (2)
  • 3rdparty/tvm
  • tilelang/jit/__init__.py
✅ Files skipped from review due to trivial changes (1)
  • 3rdparty/tvm
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/__init__.py (1)
tilelang/language/v2/builder.py (9)
  • LazyJITFunc (928-1036)
  • PrimFunc (690-699)
  • _is_lazy_style (967-990)
  • prim_func (182-189)
  • prim_func (1061-1099)
  • get_tir (915-924)
  • get_tir (1022-1029)
  • parse_args (1011-1020)
  • get (219-220)
🪛 Ruff (0.14.10)
tilelang/jit/__init__.py

300-300: Prefer TypeError exception for invalid type

(TRY004)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


370-373: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (6)
tilelang/jit/__init__.py (6)

195-265: Excellent documentation of the unified JIT interface.

The docstring clearly explains the lazy vs eager execution modes with concrete examples. This will significantly help users understand when each mode is triggered and how to use them effectively.


304-316: Mode inference logic is sound.

The method correctly infers the execution mode by:

  1. Returning existing mode if already set
  2. Delegating to _is_lazy_style() to determine if the function returns a PrimFunc (lazy) or uses builder pattern (eager)

This aligns well with the PR objectives for automatic mode detection.


435-440: Clean mode-based execution logic.

The conditional execution based on mode is well-structured:

  • Eager mode: Executes kernel immediately and returns result
  • Lazy mode: Returns kernel object for manual invocation

This aligns perfectly with the PR objectives for unified behavior.


476-501: Well-documented decorator interface.

The docstring clearly explains:

  • Automatic mode inference
  • Parameter purposes and constraints
  • That out_idx only works in lazy mode

This provides good guidance for users migrating from the old two-decorator system.


364-376: Mode inference and validation logic is well-implemented.

The compile method correctly:

  • Infers mode on first compilation (line 365-366)
  • Validates that out_idx is only used in lazy mode (lines 369-373)
  • Propagates the mode to the wrapped function (line 375)

This ensures consistent behavior across the JIT lifecycle.


514-523: Parameter naming is correct; no issue identified.

Line 516 correctly passes lazy_jit=True to prim_func. The prim_func function signature in tilelang/language/v2/builder.py defines this parameter as lazy_jit: bool = False, so the call matches the actual function definition. The is_lazy_style naming exists separately as internal fields and methods within the LazyJITFunc class, not as the prim_func parameter name.

Comment on lines +277 to +279
mode: Literal["auto", "lazy", "eager"]
# place func at the last element for better __repr__
func: Callable[_P, _T] | PrimFunc[_KP, _T] | LazyJITFunc[_KP, _T]
func: LazyJITFunc[_KP, _T]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's locate the file and check its size
find . -name "__init__.py" -path "*tilelang/jit*" -type f

Repository: tile-ai/tilelang

Length of output: 288


🏁 Script executed:

# Read the relevant sections of the file
head -300 tilelang/jit/__init__.py | tail -50

Repository: tile-ai/tilelang

Length of output: 2177


🏁 Script executed:

# Get context around lines 277-279 and beyond
sed -n '270,330p' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 2818


🏁 Script executed:

# Search for JITImpl class definition and instantiation patterns
rg -n -B 5 -A 10 'class JITImpl' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 657


🏁 Script executed:

# Search for where func is used and what types it might receive
rg -n 'func.*LazyJITFunc|func.*PrimFunc|func.*Callable' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 1193


🏁 Script executed:

# Look at the decorator implementation around line 514 where JITImpl is likely created
sed -n '514,570p' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 521


🏁 Script executed:

# Search for JITImpl instantiation to see how func is passed
rg -n 'JITImpl(' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 126


🏁 Script executed:

# Search for all JITImpl instantiations with a simpler pattern
rg -n 'JITImpl\[' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 267


🏁 Script executed:

# Look for places where JITImpl is directly instantiated outside the decorator
rg -n 'return JITImpl' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 87


🏁 Script executed:

# Check if there are any other files that might instantiate JITImpl
rg -r 'JITImpl(' tilelang/ --type=py | head -20

Repository: tile-ai/tilelang

Length of output: 1115


🏁 Script executed:

# Look at get_tir method more carefully - what does it expect?
sed -n '292,310p' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 885


🏁 Script executed:

# Check line 389 to understand the PrimFunc check
sed -n '385,395p' tilelang/jit/__init__.py

Repository: tile-ai/tilelang

Length of output: 536


Remove or update runtime type checks in get_tir and _infer_jit_mode to match type annotation.

The func field is declared as LazyJITFunc[_KP, _T] (line 279), and the decorator always wraps the input function in LazyJITFunc via prim_func(..., lazy_jit=True) before instantiation (lines 514-522). However, get_tir() (lines 295-297) and _infer_jit_mode() (line 314) include runtime type checks for PrimFunc and generic Callable that don't match this guaranteed type.

Since JITImpl.func is always LazyJITFunc at instantiation, remove the unnecessary isinstance checks for other types, or update the type annotation to PrimFunc[_KP, _T] | LazyJITFunc[_KP, _T] | Callable[_KP, _T] if these cases are intentionally supported.

🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 277 - 279, The runtime isinstance
checks in get_tir and _infer_jit_mode don't match the declared JITImpl.func type
(LazyJITFunc[_KP, _T]) and are redundant because the decorator always wraps the
input with prim_func(..., lazy_jit=True); update the code to either remove the
PrimFunc/Callable isinstance branches in get_tir and _infer_jit_mode and treat
func as LazyJITFunc (using LazyJITFunc methods/properties directly), or change
the annotation of JITImpl.func to PrimFunc[_KP,_T] | LazyJITFunc[_KP,_T] |
Callable[_KP,_T] and keep the checks—pick one consistent approach and apply it
to both get_tir and _infer_jit_mode, referencing JITImpl.func, LazyJITFunc,
PrimFunc, prim_func and the decorator wrapping logic to ensure types and runtime
checks align.

Comment on lines 428 to 433
kwargs.update(kwargs.pop("__tune_params", {}))
key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs)
self._kernel_cache[key] = kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, find the LazyJITFunc class and parse_args method
fd tilelang -type f -name "*.py" | head -20

Repository: tile-ai/tilelang

Length of output: 230


🏁 Script executed:

# Search for LazyJITFunc class definition
rg "class LazyJITFunc" -A 50

Repository: tile-ai/tilelang

Length of output: 3575


🏁 Script executed:

# Search for parse_args method definition
rg "def parse_args" -B 5 -A 20

Repository: tile-ai/tilelang

Length of output: 2085


🏁 Script executed:

# Check if parse_cache_key is used anywhere
rg "parse_cache_key" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 1602


🏁 Script executed:

# Look at the specific file to understand context
cat tilelang/jit/__init__.py | head -n 450 | tail -n 50

Repository: tile-ai/tilelang

Length of output: 2083


Fix the None case handling for kernel_args when calling kernel.

The parse_args method can return None as kernel_args when no tensor arguments are present (line 429), but line 437 unconditionally calls kernel(*kernel_args.values()) without checking for None. This will crash at runtime in eager mode when the decorated function has no tensor arguments.

Either handle the None case:

  • kernel(*kernel_args.values()) if kernel_args else kernel()

Or verify that parse_args always returns a dict (even if empty).

🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 428 - 433, parse_args can return None
for kernel_args when there are no tensor arguments, but the code unconditionally
calls kernel(*kernel_args.values()), which will raise when kernel_args is None;
update the call site in the wrapper (the block that gets kernel from
_kernel_cache, calls self.compile, then invokes kernel) to handle the None case
by invoking kernel() when kernel_args is falsy (e.g., use
kernel(*kernel_args.values()) if kernel_args else kernel()), or change
parse_args to always return an empty dict; ensure this fix touches the call that
currently does kernel(*kernel_args.values()) and keep the compile/_kernel_cache
logic unchanged.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 425-437: parse_args can return None for kernel_args, causing
kernel(*kernel_args.values()) in the eager branch to raise an AttributeError;
update the eager execution path in __call__ (or the surrounding function) to
handle kernel_args being None by calling kernel() when kernel_args is None or
calling kernel(*kernel_args.values()) when not None, and ensure kernel_arg
handling mirrors the lazy path so mode == "eager" correctly invokes the compiled
kernel with zero or more args.
🧹 Nitpick comments (1)
tilelang/jit/__init__.py (1)

304-316: Simplify mode inference logic by removing redundant type check.

Since self.func is always LazyJITFunc[_KP, _T] (guaranteed by the decorator at line 513), the isinstance check on line 314 and the early return on line 315 are unnecessary. The logic can be simplified to directly call self.func._is_lazy_style().

🔎 Proposed simplification
 def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
     """
     Infer the JIT execution mode based on function behavior.

     Returns "lazy" if the function explicitly returns a PrimFunc,
     or "eager" if it uses the DSL builder pattern.
     """
     if self.mode in ("lazy", "eager"):
         return self.mode
     # auto: infer by checking if function returns PrimFunc directly
-    if not isinstance(self.func, LazyJITFunc):
-        return "lazy"
     return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d173672 and 5740b4e.

📒 Files selected for processing (1)
  • tilelang/jit/__init__.py
🧰 Additional context used
🪛 Ruff (0.14.10)
tilelang/jit/__init__.py

300-300: Prefer TypeError exception for invalid type

(TRY004)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


370-370: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
tilelang/jit/__init__.py (3)

195-265: Excellent documentation for the unified JIT decorator.

The comprehensive docstring clearly explains both execution modes with concrete examples and describes the automatic inference behavior. This will help users understand the migration from separate decorators.


461-522: Well-structured unified decorator implementation.

The decorator correctly wraps all functions with prim_func(..., lazy_jit=True) to create a LazyJITFunc, then initializes JITImpl with mode="auto" for runtime inference. This approach cleanly unifies the previous separate decorators while maintaining backward compatibility.


363-398: No action needed. The set_mode method is properly implemented in LazyJITFunc at tilelang/language/v2/builder.py:1034. Line 372's call to self.func.set_mode(self.mode) is valid.

Likely an incorrect or invalid review comment.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 425-443: parse_args can return kernel_args = None (see
LazyJITFunc.parse_args), so calling kernel(*kernel_args.values()) will crash;
update the eager path in __call__ (after key/kernel lookup in __init__.py) to
detect if kernel_args is None and invoke kernel() with no args, otherwise call
kernel(*kernel_args.values()). Keep existing logic for mode inference
(self._infer_jit_mode, self.func.set_mode) and kernel caching (_kernel_cache,
compile) unchanged.
- Around line 304-316: The isinstance check for LazyJITFunc in _infer_jit_mode
is redundant because self.func is always a LazyJITFunc; remove the entire branch
`if not isinstance(self.func, LazyJITFunc): return "lazy"` and simplify the
method to first honor explicit self.mode ("lazy" or "eager") and otherwise
return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager", keeping
the same return type Literal["lazy","eager"] and preserving the call to
_is_lazy_style on the LazyJITFunc instance.
- Around line 467-478: The decorator signature incorrectly allows PrimFunc but
the implementation calls prim_func(func, lazy_jit=True) which expects a Python
function and then runs mutate(), so either remove PrimFunc from the type union
in the jit signature or add an explicit runtime branch in jit that detects a
PrimFunc and handles it safely (e.g., pass-through return or wrap appropriately)
instead of calling prim_func/mutate on it; reference symbols: jit, prim_func,
PrimFunc, mutate.

In @tilelang/language/v2/builder.py:
- Line 1057: The signature of prim_func declares func with a default None but
its type omits None; update the annotation for the parameter func in prim_func
to explicitly allow None (e.g., use Optional[Callable[_P, _T]] or Callable[_P,
_T] | None) so the default value matches the type; keep the return annotation
as-is and modify only the func parameter type in the prim_func definition.
- Around line 961-984: The _is_lazy_style function currently catches a bare
Exception when calling self.orig_func, which can hide programming errors; change
the except to only catch expected runtime invocation errors (e.g., except
(TypeError, AttributeError, ValueError) as e) and log the exception details via
logger.debug including e, and do not catch BaseException (so
KeyboardInterrupt/SystemExit propagate); keep the existing return False behavior
for those specific caught exceptions to treat the function as eager style.
- Around line 826-842: The assertions in const() are contradictory: they assert
builder.lazy_jit but the error text says "eager mode". Fix by making the check
and message consistent — either test the actual JIT mode flag set in
_build_tir_template() (e.g., use the mode field on the LazyJITFunc or rename
builder.lazy_jit to a clearer name like builder.is_jit) or simply update the
assertion to assert builder.lazy_jit and change the error text to "T.const() can
only be used inside @tilelang.jit (JIT/eager mode)"; reference const(),
Builder.current(), builder.lazy_jit and _build_tir_template() to locate and
update the check and message accordingly.
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)

909-918: Verify matcher is unused in lazy-style mode.

When is_lazy_style is True, get_tir() returns the PrimFunc directly without substitution (lines 910-911), bypassing the matcher logic. Confirm that matcher is always None for lazy-style templates to avoid confusion.

🔎 Optional assertion to enforce invariant
     def get_tir(self, **kwargs):
         if self.is_lazy_style:
+            assert self.matcher is None, "Lazy-style templates should not have a matcher"
             return self.prim_func
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5740b4e and b62ec8e.

📒 Files selected for processing (2)
  • tilelang/jit/__init__.py
  • tilelang/language/v2/builder.py
🧰 Additional context used
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py

982-982: Do not catch blind exception: Exception

(BLE001)


1003-1003: Avoid specifying long messages outside the exception class

(TRY003)


1057-1057: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/jit/__init__.py

300-300: Prefer TypeError exception for invalid type

(TRY004)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


370-370: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (3)
tilelang/jit/__init__.py (2)

363-398: Verify mode inference happens before get_tir() uses it.

The compile() method infers the mode (lines 365-366) and sets it on self.func (line 372) before calling get_tir() (line 373). However, get_tir() calls self.func(*args, **kwargs) which invokes LazyJITFunc.__call__get_tir()_build_tir_template(), and that method requires self.mode to be "lazy" or "eager" (builder.py line 1003).

Confirm that LazyJITFunc.mode is set by the time _build_tir_template() is called, or the code will raise ValueError for invalid mode.

The flow looks correct: mode is inferred and set on line 372 before get_tir is called on line 373. However, verify with a test case that has mode="auto" initially.


195-265: Excellent documentation for dual execution modes.

The docstring clearly explains the difference between lazy and eager modes with concrete examples. This will help users understand when each mode is triggered and how to use them effectively.

tilelang/language/v2/builder.py (1)

986-1003: Remove this concern — mode is properly set before _build_tir_template() is called.

Both JITImpl.__call__() and JITImpl.compile() ensure self.mode is inferred and set on the LazyJITFunc instance before calling methods that depend on it:

  • __call__() (line 428–430): Infers mode if "auto", then calls self.func.set_mode() before self.func.parse_args() (line 432)
  • compile() (line 366–372): Infers mode if "auto", then calls self.func.set_mode() before self.get_tir() (line 373)

The _infer_jit_mode() method always returns "lazy" or "eager" (never "auto"), so self.mode is guaranteed to be a valid state before _build_tir_template() executes. The existing comment at line 1013 confirms the developer's awareness of this dependency.

Comment on lines 304 to 316
def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
"""
Infer the JIT execution mode based on function behavior.

Returns "lazy" if the function explicitly returns a PrimFunc,
or "eager" if it uses the DSL builder pattern.
"""
if self.mode in ("lazy", "eager"):
return self.mode
# auto: infer by checking if function returns PrimFunc directly
if not isinstance(self.func, LazyJITFunc):
return "lazy"
return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove redundant type check in _infer_jit_mode().

Line 314 checks if not isinstance(self.func, LazyJITFunc) and returns "lazy", but self.func is annotated as LazyJITFunc[_KP, _T] (line 279) and the decorator ensures it's always wrapped in LazyJITFunc (line 519). This check is redundant.

🔎 Proposed fix
     def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
         """
         Infer the JIT execution mode based on function behavior.
 
         Returns "lazy" if the function explicitly returns a PrimFunc,
         or "eager" if it uses the DSL builder pattern.
         """
         if self.mode in ("lazy", "eager"):
             return self.mode
         # auto: infer by checking if function returns PrimFunc directly
-        if not isinstance(self.func, LazyJITFunc):
-            return "lazy"
         return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"

Based on past review comments.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
"""
Infer the JIT execution mode based on function behavior.
Returns "lazy" if the function explicitly returns a PrimFunc,
or "eager" if it uses the DSL builder pattern.
"""
if self.mode in ("lazy", "eager"):
return self.mode
# auto: infer by checking if function returns PrimFunc directly
if not isinstance(self.func, LazyJITFunc):
return "lazy"
return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
def _infer_jit_mode(self, *args: _P.args, **kwargs: _P.kwargs) -> Literal["lazy", "eager"]:
"""
Infer the JIT execution mode based on function behavior.
Returns "lazy" if the function explicitly returns a PrimFunc,
or "eager" if it uses the DSL builder pattern.
"""
if self.mode in ("lazy", "eager"):
return self.mode
# auto: infer by checking if function returns PrimFunc directly
return "lazy" if self.func._is_lazy_style(*args, **kwargs) else "eager"
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 304 - 316, The isinstance check for
LazyJITFunc in _infer_jit_mode is redundant because self.func is always a
LazyJITFunc; remove the entire branch `if not isinstance(self.func,
LazyJITFunc): return "lazy"` and simplify the method to first honor explicit
self.mode ("lazy" or "eager") and otherwise return "lazy" if
self.func._is_lazy_style(*args, **kwargs) else "eager", keeping the same return
type Literal["lazy","eager"] and preserving the call to _is_lazy_style on the
LazyJITFunc instance.

Comment on lines +826 to +842
"""
Declare constexpr variables for dynamic tensor dimensions (eager mode only).

In eager mode, use T.const() to declare shape dimensions that will be
inferred from actual tensor arguments at runtime.

Example::

@tilelang.jit
def kernel(A, B):
M, N = T.const("M, N")
A: T.Tensor[[M, N], T.float32]
...
"""
builder = Builder.current()
assert builder is not None, "const can only be used inside `tilelang.lazy_jit` function"
assert builder.lazy_jit, "const can only be used inside `tilelang.lazy_jit` function"
assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix confusing assertion logic in const() function.

Line 842 asserts that builder.lazy_jit is True, but the error message claims T.const() can only be used in "eager mode". This is contradictory and confusing.

Looking at line 994 in _build_tir_template(), builder.lazy_jit = True is set for eager-style execution (when mode is "eager"), which suggests the field name lazy_jit is misleading—it actually indicates "uses JIT" rather than "is lazy style".

Either:

  • Rename builder.lazy_jit to something clearer like builder.is_jit or builder.from_jit
  • Or invert the assertion to check the mode field on LazyJITFunc instead
🔎 Suggested fix: Check mode directly
-    assert builder.lazy_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
+    # T.const() is only valid in eager mode where we infer shapes from tensors
+    # In lazy mode, shape parameters are explicit function arguments

Then verify the mode when it matters, or accept that any JIT context allows const().

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In @tilelang/language/v2/builder.py around lines 826 - 842, The assertions in
const() are contradictory: they assert builder.lazy_jit but the error text says
"eager mode". Fix by making the check and message consistent — either test the
actual JIT mode flag set in _build_tir_template() (e.g., use the mode field on
the LazyJITFunc or rename builder.lazy_jit to a clearer name like
builder.is_jit) or simply update the assertion to assert builder.lazy_jit and
change the error text to "T.const() can only be used inside @tilelang.jit
(JIT/eager mode)"; reference const(), Builder.current(), builder.lazy_jit and
_build_tir_template() to locate and update the check and message accordingly.



def prim_func(func: Callable[_P, _T] = None, *, lazy_jit=False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Use explicit Optional for optional parameter.

The func parameter has a default of None but the type annotation doesn't include None, violating PEP 484. Use explicit Optional or union with None.

🔎 Proposed fix
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Callable[_P, _T] | None = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:

Based on static analysis hints.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
def prim_func(func: Callable[_P, _T] | None = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
🧰 Tools
🪛 Ruff (0.14.10)

1057-1057: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🤖 Prompt for AI Agents
In @tilelang/language/v2/builder.py at line 1057, The signature of prim_func
declares func with a default None but its type omits None; update the annotation
for the parameter func in prim_func to explicitly allow None (e.g., use
Optional[Callable[_P, _T]] or Callable[_P, _T] | None) so the default value
matches the type; keep the return annotation as-is and modify only the func
parameter type in the prim_func definition.

…le detection

- Removed unnecessary parameters from the autotune decorator in example_mha_sink_fwd_bhsd.py.
- Updated JITImpl to infer execution mode using static AST analysis instead of runtime checks.
- Added a new function to detect lazy style in AST, enhancing the clarity of function behavior.
- Refactored LazyJITFunc to utilize the new lazy style detection mechanism.
- Simplified the autotune decorator in example_mha_sink_fwd_bhsd.py by removing unnecessary parameters.
- Improved JITImpl to utilize static AST analysis for mode inference, eliminating the need for runtime checks.
- Introduced a new function for lazy style detection in AST, enhancing clarity and functionality.
- Updated LazyJITFunc to leverage the new lazy style detection mechanism.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 430-448: parse_args() can return kernel_args == None for no-tensor
calls; change the eager-path in the JIT wrapper (the block after
self._kernel_cache lookup) to handle that: if kernel_args is None call kernel()
with no positional args, otherwise call kernel(*kernel_args.values()). Update
the code around self.func.parse_args(...) and the eager branch that currently
does kernel(*kernel_args.values()) so it checks kernel_args is not None before
accessing .values(), leaving cache key/compile behavior unchanged.
🧹 Nitpick comments (1)
tilelang/language/v2/ast.py (1)

587-657: Add parentheses to clarify operator precedence.

The boolean expressions at lines 606-610 chain and and or operators without parentheses. While the precedence is correct (and binds tighter than or), explicit parentheses improve readability and address the static analysis warning.

🔎 Proposed fix
             for decorator in node.decorator_list:
                 if (
-                    isinstance(decorator, ast.Attribute)
-                    and decorator.attr == "prim_func"
-                    or isinstance(decorator, ast.Name)
-                    and decorator.id == "prim_func"
+                    (isinstance(decorator, ast.Attribute)
+                    and decorator.attr == "prim_func")
+                    or (isinstance(decorator, ast.Name)
+                    and decorator.id == "prim_func")
                 ):
                     has_inner_prim_func = True

Based on static analysis hints.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b62ec8e and 1f93975.

📒 Files selected for processing (6)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py
  • examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
  • testing/python/language/test_tilelang_language_subtype.py
  • tilelang/jit/__init__.py
  • tilelang/language/v2/ast.py
  • tilelang/language/v2/builder.py
💤 Files with no reviewable changes (1)
  • examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_subtype.py
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • testing/python/language/test_tilelang_language_subtype.py
📚 Learning: 2025-12-15T07:48:45.785Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: src/target/codegen_cutedsl.cc:789-793
Timestamp: 2025-12-15T07:48:45.785Z
Learning: In tilelang/contrib/cutedsl, the `tl.make_rmem_tensor` function accepts both an Integer and a Tuple of Integer for its shape parameter. Therefore, both `tl.make_rmem_tensor(N, ...)` and `tl.make_rmem_tensor((N,), ...)` are valid syntaxes in CuteDSL-generated code.

Applied to files:

  • testing/python/language/test_tilelang_language_subtype.py
🧬 Code graph analysis (3)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (5)
tilelang/autotuner/tuner.py (1)
  • autotune (691-786)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
  • get_configs (14-16)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
  • get_configs (11-13)
examples/flash_attention/example_mha_fwd_varlen.py (1)
  • get_configs (15-17)
examples/flash_attention/example_mha_fwd_bshd.py (1)
  • get_configs (11-13)
testing/python/language/test_tilelang_language_subtype.py (1)
tilelang/jit/__init__.py (3)
  • jit (455-455)
  • jit (459-469)
  • jit (472-536)
tilelang/jit/__init__.py (2)
tilelang/language/v2/builder.py (10)
  • LazyJITFunc (922-1042)
  • set_mode (1028-1030)
  • PrimFunc (684-693)
  • _is_lazy_style (961-982)
  • prim_func (182-189)
  • prim_func (1067-1108)
  • get_tir (909-918)
  • get_tir (1016-1023)
  • parse_args (1003-1014)
  • get (219-220)
tilelang/jit/adapter/wrapper.py (2)
  • prim_func (575-585)
  • prim_func (839-849)
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py

1001-1001: Avoid specifying long messages outside the exception class

(TRY003)


1042-1042: Avoid specifying long messages outside the exception class

(TRY003)


1067-1067: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/jit/__init__.py

305-305: Prefer TypeError exception for invalid type

(TRY004)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


375-375: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/v2/ast.py

607-608: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


609-610: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (12)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)

18-18: No changes needed—the warmup reduction is intentional and appropriate.

The file pair demonstrates differentiated tuning strategies:

  • example_mha_sink_fwd_bhsd.py (non-pipelined): uses warmup=25 (default)
  • example_mha_sink_fwd_bhsd_wgmma_pipelined.py (pipelined): retains warmup=500

This reflects that simpler, non-pipelined kernels don't require extended warmup periods for stabilization, while the more complex pipelined implementation maintains the longer warmup for reliable autotuning. The reduction is targeted, not a blanket standardization.

tilelang/language/v2/ast.py (2)

584-584: LGTM!

The is_lazy_style field is well-integrated with the new detection logic and provides a clear semantic flag for distinguishing lazy vs eager execution modes.


690-691: LGTM!

The integration of lazy-style detection into the mutate() function is clean and correct. Detection happens before AST transformation, ensuring the original function structure is analyzed.

Also applies to: 717-717

testing/python/language/test_tilelang_language_subtype.py (1)

9-9: LGTM!

The decorator changes from @tilelang.lazy_jit to @tilelang.jit are straightforward and align with the PR's unification objective. The test logic remains unchanged, ensuring behavioral consistency.

Also applies to: 18-18, 98-98, 108-108, 119-119, 130-130, 142-142

tilelang/language/v2/builder.py (5)

853-895: LGTM!

The TirTemplate enhancements cleanly separate lazy-style (direct PrimFunc return) from eager-style (constexpr substitution) execution paths. The is_lazy_style flag and from_lazy_style() constructor provide clear semantics, and the get_tir() short-circuit for lazy style is efficient.

Also applies to: 909-912


944-944: Excellent improvement: static AST-based mode detection.

The new _is_lazy_style() implementation correctly leverages the static AST analysis performed during mutate(), eliminating the need for runtime PrimFunc checks and exception handling. This is more reliable and efficient than the previous approach.

Also applies to: 961-982


984-1023: LGTM!

The _build_tir_template() method cleanly centralizes template construction logic based on execution mode. The integration with parse_args() and get_tir() is correct, and the caching mechanism ensures efficient reuse.

Also applies to: 1028-1030


1032-1042: LGTM!

The attribute proxying mechanism correctly exposes the original function's attributes for compatibility with autotuner and inspect module. The implementation properly handles __wrapped__ and delegates other attributes to orig_func.


826-842: Confusing assertion logic in const() function.

Line 842 asserts that builder.lazy_jit is True, but the error message claims T.const() can only be used in "eager mode". This is contradictory and confusing. Looking at line 992, builder.lazy_jit = True is set in what the code treats as eager-style execution, suggesting the field name is misleading.

Consider either:

  • Renaming builder.lazy_jit to something clearer like builder.is_jit or builder.from_jit
  • Or updating the assertion message to reflect what the check actually does

Likely an incorrect or invalid review comment.

tilelang/jit/__init__.py (3)

193-265: LGTM!

The extensive docstring update provides clear, concrete examples of both lazy and eager execution modes. The simplified type for func (now always LazyJITFunc[_KP, _T]) correctly reflects that the decorator always wraps functions in LazyJITFunc.

Also applies to: 277-279


368-403: LGTM!

The compile() method correctly infers execution mode, validates that out_idx is only used in lazy mode (an important constraint), and properly sets the mode on func before retrieving TIR. The error message for eager mode with out_idx is helpful.


472-483: Do not allow PrimFunc in the jit decorator signature.

The type annotation func: Callable[_P, _T] | PrimFunc | None suggests PrimFunc input is supported, but the decorator calls prim_func(func, lazy_jit=True) without checking input type. The prim_func() function expects a Python function to analyze via mutate() (AST transformation), not a PrimFunc object. Passing a PrimFunc would fail because mutate() cannot perform AST analysis on TVM IR objects.

Either remove PrimFunc from the type annotation or add an explicit check to handle PrimFunc separately.

🔎 Proposed fix to remove PrimFunc from type
 def jit(
-    func: Callable[_P, _T] | PrimFunc | None = None,
+    func: Callable[_P, _T] | None = None,
     *,  # Indicates subsequent arguments are keyword-only
     out_idx: list[int] | int | None = None,

Also applies to: 522-536

Likely an incorrect or invalid review comment.

Comment on lines 430 to 448
kwargs.update(kwargs.pop("__tune_params", {}))

# infer mode early, before parse_args needs it
if self.mode == "auto":
self.mode = self._infer_jit_mode()
self.func.set_mode(self.mode)

key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs)
self._kernel_cache[key] = kernel

# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
return kernel(*kernel_args.values())
else:
key = self.parse_cache_key(*args, **kwargs)
tune_params = kwargs.pop("__tune_params", {})
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs, **tune_params)
self._kernel_cache[key] = kernel
return kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Handle None return from parse_args() to prevent crash.

LazyJITFunc.parse_args() can return None for kernel_args when there are no tensor arguments (see builder.py line 1007: if not tensor_args: return (p1_key, None), kwargs). However, line 446 unconditionally calls kernel(*kernel_args.values()), which will raise AttributeError: 'NoneType' object has no attribute 'values' in eager mode when the function has no tensor arguments.

🔎 Proposed fix to handle None case
         key, kernel_args = self.func.parse_args(*args, **kwargs)
         kernel = self._kernel_cache.get(key, None)
         if kernel is None:
             kernel = self.compile(*args, **kwargs)
             self._kernel_cache[key] = kernel
 
         # eager mode: execute kernel immediately and return result
         # lazy mode: return kernel object for manual invocation
         if self.mode == "eager":
-            return kernel(*kernel_args.values())
+            return kernel(*kernel_args.values()) if kernel_args else kernel()
         else:
             return kernel
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 430 - 448, parse_args() can return
kernel_args == None for no-tensor calls; change the eager-path in the JIT
wrapper (the block after self._kernel_cache lookup) to handle that: if
kernel_args is None call kernel() with no positional args, otherwise call
kernel(*kernel_args.values()). Update the code around self.func.parse_args(...)
and the eager branch that currently does kernel(*kernel_args.values()) so it
checks kernel_args is not None before accessing .values(), leaving cache
key/compile behavior unchanged.

- Removed the `@simplify_prim_func` decorator from `tl_matmul` in `example_gemm_intrinsics.py`.
- Introduced a new `JITMode` literal in `__init__.py` to clarify execution modes.
- Updated the `jit` function to include a `mode` parameter, defaulting to "auto", enhancing flexibility in execution mode selection.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 437-446: parse_args can return kernel_args as None when there are
no tensor args, so calling kernel(*kernel_args.values()) will crash; update the
eager-mode branch in the function that calls self.func.parse_args to check
whether kernel_args is None (or falsy) and, if so, invoke kernel() with no
positional args, otherwise call kernel(*kernel_args.values()); reference the
symbols self.func.parse_args, kernel_args, kernel, and self.mode to locate and
modify the code path.
🧹 Nitpick comments (1)
tilelang/jit/__init__.py (1)

378-401: Local variable shadows imported prim_func function.

Line 378 assigns prim_func = self.get_tir(...) which shadows the imported prim_func function from tilelang.language.v2. While this works because the import isn't used after this point in the method, consider renaming to tir_func or pf for clarity.

🔎 Suggested rename
-        prim_func = self.get_tir(*args, **kwargs)
+        tir_func = self.get_tir(*args, **kwargs)
         kernel_result = compile(
-            prim_func,
+            tir_func,
             out_idx=self.out_idx,
             ...
         )
         ...
         with open(path.join(self.debug_root_path, program_file), "w") as f:
-                print(prim_func.script(), file=f)
+                print(tir_func.script(), file=f)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f93975 and df50a0b.

📒 Files selected for processing (2)
  • examples/gemm/example_gemm_intrinsics.py
  • tilelang/jit/__init__.py
💤 Files with no reviewable changes (1)
  • examples/gemm/example_gemm_intrinsics.py
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/__init__.py (3)
tilelang/language/v2/builder.py (10)
  • LazyJITFunc (922-1042)
  • set_mode (1028-1030)
  • PrimFunc (684-693)
  • _is_lazy_style (961-982)
  • prim_func (182-189)
  • prim_func (1067-1108)
  • get_tir (909-918)
  • get_tir (1016-1023)
  • parse_args (1003-1014)
  • get (219-220)
tilelang/jit/kernel.py (1)
  • out_idx (609-610)
tilelang/jit/adapter/wrapper.py (2)
  • prim_func (575-585)
  • prim_func (839-849)
🪛 Ruff (0.14.10)
tilelang/jit/__init__.py

305-305: Prefer TypeError exception for invalid type

(TRY004)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


375-375: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (2)
tilelang/jit/__init__.py (2)

529-542: LGTM on decorator implementation.

The unified decorator correctly wraps functions via prim_func(), captures source and signature, and passes the new mode parameter to JITImpl. The pattern of returning decorator(func) if func is not None else decorator properly supports both @jit and @jit(...) usage.


195-238: Clear and comprehensive documentation for the unified JIT interface.

The docstring provides excellent examples differentiating lazy and eager modes, making migration straightforward for users. The mode inference behavior is well-explained.

Comment on lines +437 to +446
key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs)
self._kernel_cache[key] = kernel

# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
return kernel(*kernel_args.values())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: kernel_args can be None, causing crash in eager mode.

LazyJITFunc.parse_args() returns (p1_key, None), kwargs when there are no tensor arguments (see builder.py line 1007). Line 446 unconditionally calls kernel(*kernel_args.values()), which will raise AttributeError: 'NoneType' object has no attribute 'values' when the decorated function has no tensor arguments.

🔎 Proposed fix
         # eager mode: execute kernel immediately and return result
         # lazy mode: return kernel object for manual invocation
         if self.mode == "eager":
-            return kernel(*kernel_args.values())
+            return kernel(*kernel_args.values()) if kernel_args else kernel()
         else:
             return kernel
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs)
self._kernel_cache[key] = kernel
# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
return kernel(*kernel_args.values())
key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs)
self._kernel_cache[key] = kernel
# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
return kernel(*kernel_args.values()) if kernel_args else kernel()
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 437 - 446, parse_args can return
kernel_args as None when there are no tensor args, so calling
kernel(*kernel_args.values()) will crash; update the eager-mode branch in the
function that calls self.func.parse_args to check whether kernel_args is None
(or falsy) and, if so, invoke kernel() with no positional args, otherwise call
kernel(*kernel_args.values()); reference the symbols self.func.parse_args,
kernel_args, kernel, and self.mode to locate and modify the code path.

…tection

- Updated test functions to remove unnecessary `@T.prim_func` decorators and streamline variable assignments.
- Enhanced lazy style detection in AST analysis to improve clarity and functionality.
- Added logging to warn users when JIT mode cannot be auto-detected, encouraging explicit mode specification.
- Refactored `LazyJITFunc` to return itself after setting the JIT mode for better chaining.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Fix all issues with AI Agents
In @tilelang/jit/__init__.py:
- Around line 476-488: The jit() signature claims func can be a PrimFunc but the
implementation always calls prim_func(func, lazy_jit=True) which expects a
Python function and will break for an already-constructed PrimFunc; either
remove PrimFunc from the type union in jit()’s func parameter, or add a branch
in jit() that detects isinstance(func, PrimFunc) (or checks type) and skips
calling prim_func/AST mutation, passing the PrimFunc through to JITImpl
creation; update references to prim_func, mutate, and the type annotation
accordingly so PrimFunc values are handled safely.
- Around line 437-448: LazyJITFunc can receive kernel_args==None from
LazyJITFunc.parse_args, so calling kernel(*kernel_args.values()) in the eager
branch can raise AttributeError; update the eager execution in the function that
uses self.mode (tilelang/jit/__init__.py) to handle a None kernel_args by
invoking kernel() with no arguments when kernel_args is None (instead of
unconditionally using kernel_args.values()), e.g., check if kernel_args is None
and call kernel() else call kernel(*kernel_args.values()) so functions without
tensor args execute correctly.
🧹 Nitpick comments (3)
tilelang/language/v2/ast.py (1)

624-630: Add parentheses to clarify operator precedence in condition.

The condition mixes and and or operators without explicit parentheses. While Python's precedence rules make this work correctly, the intent is clearer with explicit grouping.

🔎 Proposed fix
                 if (
-                    isinstance(decorator, ast.Attribute)
-                    and decorator.attr == "prim_func"
-                    or isinstance(decorator, ast.Name)
-                    and decorator.id == "prim_func"
+                    (isinstance(decorator, ast.Attribute) and decorator.attr == "prim_func")
+                    or (isinstance(decorator, ast.Name) and decorator.id == "prim_func")
                 ):
tilelang/language/v2/builder.py (2)

840-842: Confusing field name builder.lazy_jit for eager mode context.

The field builder.lazy_jit = True is set when using eager mode (line 992), but the name suggests "lazy JIT". The assertion message correctly says "eager mode", but the field name is misleading. Consider renaming to builder.is_jit or builder.from_jit_decorator for clarity.


1028-1031: Type mismatch in set_mode parameter.

The mode attribute is typed as Literal["auto", "lazy", "eager"] (line 944), but set_mode accepts Literal["lazy", "eager"]. This is intentional since "auto" shouldn't be set after inference, but consider adding a comment or updating the docstring to clarify this.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between df50a0b and 7432b46.

📒 Files selected for processing (4)
  • testing/python/language/test_tilelang_language_frontend_v2.py
  • tilelang/jit/__init__.py
  • tilelang/language/v2/ast.py
  • tilelang/language/v2/builder.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/language/test_tilelang_language_frontend_v2.py
🧬 Code graph analysis (3)
tilelang/language/v2/ast.py (3)
tilelang/language/ast/ir.py (1)
  • func_name (206-214)
tilelang/jit/__init__.py (1)
  • decorator (529-541)
tilelang/language/v2/builder.py (1)
  • source (707-708)
tilelang/language/v2/builder.py (3)
tilelang/jit/adapter/cython/adapter.py (1)
  • prim_func (356-358)
tilelang/language/v2/ast.py (1)
  • IRGenerator (583-587)
tilelang/jit/__init__.py (1)
  • get_tir (291-307)
testing/python/language/test_tilelang_language_frontend_v2.py (2)
tilelang/jit/__init__.py (3)
  • jit (458-458)
  • jit (462-473)
  • jit (476-543)
tilelang/language/allocate.py (5)
  • empty (267-267)
  • empty (270-278)
  • alloc_var (86-86)
  • alloc_var (90-90)
  • alloc_var (93-147)
🪛 Ruff (0.14.10)
tilelang/language/v2/ast.py

625-626: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


627-628: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)

tilelang/language/v2/builder.py

1001-1001: Avoid specifying long messages outside the exception class

(TRY003)


1043-1043: Avoid specifying long messages outside the exception class

(TRY003)


1068-1068: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/jit/__init__.py

305-305: Prefer TypeError exception for invalid type

(TRY004)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


375-375: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (18)
tilelang/language/v2/ast.py (5)

3-3: LGTM: Logger setup for lazy-style detection warnings.

The logging module import and logger initialization are appropriate for providing diagnostic warnings when the JIT mode cannot be auto-detected.

Also applies to: 20-21


587-588: LGTM: Added is_lazy_style flag to IRGenerator.

The new field with default False maintains backward compatibility while enabling downstream components to tailor behavior based on detected style.


590-598: LGTM: Helper to detect return statements with values.

The function correctly identifies functions with meaningful return values while filtering out explicit return None statements.


663-670: LGTM: Helpful warning for ambiguous style detection.

When the function has a return value but no recognizable lazy/eager patterns, the warning appropriately guides users to specify the mode explicitly.


703-704: LGTM: Lazy style detection integrated into mutate flow.

The detection is performed before AST transformation and correctly propagated to the resulting IRGenerator.

Also applies to: 730-730

tilelang/jit/__init__.py (2)

195-265: LGTM: Clear documentation of lazy vs eager execution modes.

The updated docstring provides excellent examples for both modes, and the attribute annotations correctly reflect the new mode-aware interface.


368-403: LGTM: Mode-aware compilation with proper validation.

The method correctly infers mode, validates that out_idx is only used in lazy mode, and propagates mode to the wrapped function before compilation.

tilelang/language/v2/builder.py (6)

20-20: LGTM: Extended typing imports for mode support.

The addition of Literal and get_origin supports the new mode type hints and callable annotation handling.


853-918: LGTM: TirTemplate extended with lazy-style support.

The addition of is_lazy_style flag, from_lazy_style() factory method, and short-circuit in get_tir() cleanly separates lazy and eager template handling.


921-944: LGTM: Updated LazyJITFunc documentation and mode support.

The docstring clearly explains the two execution styles, and the mode attribute enables explicit mode control when auto-detection is insufficient.


961-982: Simplified _is_lazy_style() using static AST analysis.

The method now delegates to ir_gen.is_lazy_style, which was computed during mutate(). This is more reliable than the previous runtime detection approach.


984-1001: LGTM: Mode-aware TIR template building.

The _build_tir_template method correctly routes to TirTemplate.from_lazy_style() for lazy mode and the Builder-based path for eager mode.


1082-1086: LGTM: Safe callable annotation evaluation.

The check using get_origin() correctly skips typing generics like Optional, Union, List which are callable but cannot be instantiated.

testing/python/language/test_tilelang_language_frontend_v2.py (5)

206-221: LGTM: Test migrated to unified @tilelang.jit decorator.

The test correctly uses eager mode with T.empty() for output allocation and returns the result tensor.


258-292: LGTM: Stepped serial tests migrated to unified decorator.

Both positive and negative step tests correctly use the eager mode pattern with explicit returns.


294-319: LGTM: Swap tests migrated with in-place tensor modification.

The tests correctly use eager mode with tensor type annotations inside the function body, demonstrating in-place modification semantics.


324-338: LGTM: While loop test migrated to unified decorator.

The test correctly uses eager mode with T.empty() output and validates while loop accumulation.


457-477: LGTM: Constexpr if test demonstrates mode parameterization.

The test correctly shows eager mode with a boolean parameter controlling compile-time branching, verifying both true and false paths.

Comment on lines +437 to 448
key, kernel_args = self.func.parse_args(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs)
self._kernel_cache[key] = kernel

# eager mode: execute kernel immediately and return result
# lazy mode: return kernel object for manual invocation
if self.mode == "eager":
return kernel(*kernel_args.values())
else:
key = self.parse_cache_key(*args, **kwargs)
tune_params = kwargs.pop("__tune_params", {})
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs, **tune_params)
self._kernel_cache[key] = kernel
return kernel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: kernel_args can be None, causing crash in eager mode.

LazyJITFunc.parse_args() returns (p1_key, None) when there are no tensor arguments (see builder.py line 1006-1007). Line 446 unconditionally calls kernel(*kernel_args.values()), which will raise AttributeError: 'NoneType' object has no attribute 'values' when the decorated function has no tensor arguments.

🔎 Proposed fix
         # eager mode: execute kernel immediately and return result
         # lazy mode: return kernel object for manual invocation
         if self.mode == "eager":
-            return kernel(*kernel_args.values())
+            return kernel(*kernel_args.values()) if kernel_args else kernel()
         else:
             return kernel
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 437 - 448, LazyJITFunc can receive
kernel_args==None from LazyJITFunc.parse_args, so calling
kernel(*kernel_args.values()) in the eager branch can raise AttributeError;
update the eager execution in the function that uses self.mode
(tilelang/jit/__init__.py) to handle a None kernel_args by invoking kernel()
with no arguments when kernel_args is None (instead of unconditionally using
kernel_args.values()), e.g., check if kernel_args is None and call kernel() else
call kernel(*kernel_args.values()) so functions without tensor args execute
correctly.

Comment on lines +476 to +488
def jit(
func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: list[int] | int | None = None,
target: str | Target | None = None,
target_host: str | Target | None = None,
execution_backend: ExecutionBackend | None = None,
verbose: bool | None = None,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None,
):
mode: JITMode = "auto",
) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Type annotation allows PrimFunc but implementation doesn't handle it.

The signature includes func: Callable[_P, _T] | PrimFunc | None, but line 530 calls prim_func(func, lazy_jit=True) without checking if func is already a PrimFunc. The prim_func() function performs AST transformation via mutate(), which expects a Python function, not a TVM IR object. Either remove PrimFunc from the type annotation or add a branch to handle it.

🔎 Proposed fix: Remove PrimFunc from signature
 def jit(
-    func: Callable[_P, _T] | PrimFunc | None = None,
+    func: Callable[_P, _T] | None = None,
     *,  # Indicates subsequent arguments are keyword-only
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def jit(
func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: list[int] | int | None = None,
target: str | Target | None = None,
target_host: str | Target | None = None,
execution_backend: ExecutionBackend | None = None,
verbose: bool | None = None,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None,
):
mode: JITMode = "auto",
) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]:
def jit(
func: Callable[_P, _T] | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: list[int] | int | None = None,
target: str | Target | None = None,
target_host: str | Target | None = None,
execution_backend: ExecutionBackend | None = None,
verbose: bool | None = None,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None,
mode: JITMode = "auto",
) -> Callable[[Callable[_P, _T]], JITImpl[_KP, _KP, _T, _T]]:
🤖 Prompt for AI Agents
In @tilelang/jit/__init__.py around lines 476 - 488, The jit() signature claims
func can be a PrimFunc but the implementation always calls prim_func(func,
lazy_jit=True) which expects a Python function and will break for an
already-constructed PrimFunc; either remove PrimFunc from the type union in
jit()’s func parameter, or add a branch in jit() that detects isinstance(func,
PrimFunc) (or checks type) and skips calling prim_func/AST mutation, passing the
PrimFunc through to JITImpl creation; update references to prim_func, mutate,
and the type annotation accordingly so PrimFunc values are handled safely.

@LeiWang1999 LeiWang1999 closed this Jan 7, 2026
@LeiWang1999 LeiWang1999 deleted the refactor/unify-jit-decorator branch January 12, 2026 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant