@LeiWang1999 LeiWang1999 commented Jan 7, 2026

Summary

This PR unifies @tilelang.jit and @tilelang.lazy_jit into a single @tilelang.jit decorator that automatically infers the execution mode based on function behavior.

Two JIT Writing Styles

TileLang now supports two kernel writing styles with a unified decorator:

1. Lazy Style (explicit PrimFunc return)

@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K):
    @T.prim_func
    def kernel(A: T.Tensor((M, K), "float16"), B: T.Tensor((K, N), "float16"), C: T.Tensor((M, N), "float16")):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128):
            ...
    return kernel  # explicitly return PrimFunc

# Returns kernel object, execute separately
kernel = matmul(1024, 1024, 1024, 128, 128, 32)
result = kernel(a, b)

2. Eager Style (DSL builder pattern)

@tilelang.jit
def gemm(A, B, C, block_M: int = 64):
    M, N, K = T.const("M N K")  # constexpr for static shapes
    A: T.Tensor[[M, K], "float16"]  # tensor shape via annotation
    B: T.Tensor[[K, N], "float16"]
    C: T.Tensor[[M, N], "float16"]
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128):
        ...
    # no return - builder constructs TIR implicitly

# Compiles and executes immediately, returns result
gemm(A, B, C)

How Mode Inference Works

The decorator automatically distinguishes between lazy and eager styles by:

  1. Calling the original function with provided arguments
  2. Checking the return value:
    • If it returns a PrimFunc → lazy mode
    • If it raises JITNoBuilderError → eager mode

The key insight is that eager-style functions use features like T.const() and T.Kernel() which require an active Builder context. When called directly (without Builder), these functions raise JITNoBuilderError, signaling that the function is eager-style.
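
The probing logic described above can be sketched in isolation. This is a minimal illustration under stated assumptions, not TileLang's actual implementation: `JITNoBuilderError` is redefined locally as a stand-in, and `PrimFuncStub`, `require_builder`, and `infer_mode` are hypothetical names introduced only for this example.

```python
class JITNoBuilderError(Exception):
    """Local stand-in for the exception raised when no Builder is active."""


class PrimFuncStub:
    """Hypothetical placeholder for a PrimFunc object."""


def require_builder():
    # Stand-in for a DSL call such as T.const() or T.Kernel()
    # made while no Builder context is active.
    raise JITNoBuilderError("no active Builder")


def infer_mode(func, *args, **kwargs):
    """Call func once and classify it by what happens."""
    try:
        result = func(*args, **kwargs)
    except JITNoBuilderError:
        return "eager"  # the body needed a Builder
    if isinstance(result, PrimFuncStub):
        return "lazy"   # the body returned a PrimFunc explicitly
    raise TypeError(f"cannot infer mode from {type(result).__name__}")


def lazy_style(m, n):
    return PrimFuncStub()


def eager_style(a, b):
    require_builder()


print(infer_mode(lazy_style, 4, 4))   # lazy
print(infer_mode(eager_style, 1, 2))  # eager
```

Using an exception as the signal keeps the decorator free of source inspection, at the cost of executing the function body once during inference.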

New Exception Classes

Added tilelang/jit/exceptions.py with:

  • JITNoBuilderError: Raised when T.const() or T.Kernel() is called outside JIT/prim_func context
  • EagerJITBuildError: General error during eager-style kernel construction

Key Changes

  • Removed lazy_jit export from tilelang/__init__.py
  • Unified JITImpl to handle both modes via mode attribute ("auto", "lazy", "eager")
  • Added _infer_jit_mode() method for automatic mode detection
  • Added Builder existence check in T.Kernel() function
  • Updated all examples and tests to use unified @tilelang.jit

Test plan

  • Existing lazy-style tests pass
  • Existing eager-style tests pass
  • Mode inference correctly identifies lazy vs eager functions
  • out_idx validation (only allowed in lazy mode)
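
The out_idx rule in the last item could look roughly like the sketch below. The function name `validate_out_idx`, the use of `ValueError`, and the message text are assumptions made for illustration; the PR description only states that out_idx is allowed in lazy mode.

```python
from typing import Optional, Sequence


def validate_out_idx(mode: str, out_idx: Optional[Sequence[int]]) -> None:
    """Reject out_idx once the mode has resolved to eager (illustrative rule)."""
    if mode not in ("lazy", "eager"):
        raise ValueError(f"mode must be resolved before validation, got {mode!r}")
    if mode == "eager" and out_idx is not None:
        # Hypothetical error type and message; the PR does not specify them.
        raise ValueError("out_idx is only supported in lazy mode")


validate_out_idx("lazy", [-1])   # accepted
validate_out_idx("eager", None)  # accepted
try:
    validate_out_idx("eager", [-1])
except ValueError as exc:
    print(exc)  # out_idx is only supported in lazy mode
```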

Summary by CodeRabbit

  • New Features

    • Added eager JIT execution mode with automatic mode inference; new JITFunc class supporting lazy/eager execution models.
    • Introduced runtime guards enforcing proper Builder context for kernel operations.
    • Added JITNoBuilderError and EagerJITBuildError exception types for clearer error diagnostics.
  • Refactoring

    • Migrated decorator from @tilelang.lazy_jit to unified @tilelang.jit with configurable execution modes.
    • Reorganized module structure: moved core APIs from v2 to eager backend.
    • Removed simplify_prim_func decorator across examples and tests.
  • Improvements

    • Enhanced validation for constant shapes in layout operations.
    • Updated test suite with CUDA availability checks.


…nd tests

Updated all instances of the @tilelang.lazy_jit decorator to @tilelang.jit in the lazyjit example notebooks and related test files. This change aligns the code with the new JIT compilation approach, enhancing consistency across the codebase. Additionally, removed the lazy_jit import from the module initialization to streamline the API.
…l loops

Updated the T.copy function to accept a new keyword-only parameter, loop_layout, allowing users to specify layout hints for the outermost parallel loop. Enhanced layout annotation handling in CopyNode and AtomicAddNode classes to ensure compatibility with SIMT operations. Added tests to validate the functionality of loop layout annotations, improving robustness and clarity in layout management for parallel loops.
…ernel and T.const functions. Ensure proper error handling when Builder is unavailable, improving robustness in JIT context management.
… to improve performance metrics output by adding average FLOPS calculations. Modify layout_inference.cc to refine buffer layout inference checks by ensuring fragment buffers are considered.
github-actions bot commented Jan 7, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot commented Jan 7, 2026

Walkthrough

This PR consolidates TileLang's JIT system by replacing the lazy_jit decorator with a unified mode-based jit supporting both lazy and eager execution. It migrates module imports from tilelang.language.v2 to tilelang.language.eager, adds runtime Builder context guards, and updates all test and example decorators accordingly. New exception classes and JITFunc class enable mode inference and execution control.

Changes

| Cohort | File(s) | Summary |
| --- | --- | --- |
| JIT Core Architecture | tilelang/jit/__init__.py, tilelang/jit/exceptions.py | Replaced LazyJITFunc with JITFunc; refactored JITImpl to support mode attribute ("auto", "lazy", "eager"); added _infer_jit_mode() and initialize_jit_mode() methods; modified __call__() to infer mode and return results immediately (eager) or kernel objects (lazy); updated compile()/par_compile() signatures; added JITNoBuilderError and EagerJITBuildError exception classes. |
| Language Builder Refactor | tilelang/language/eager/builder.py, tilelang/language/v2/__init__.py | Introduced JITFunc class replacing LazyJITFunc; refactored prim_func() to accept eager_jit parameter and return PrimFunc or JITFunc; reworked TirTemplate with is_lazy_style flag and from_lazy_style() constructor; updated const() to enforce Builder context with runtime guard; replaced lazy-only approach with dual lazy/eager paths. |
| Public API Consolidation | tilelang/__init__.py, tilelang/language/__init__.py, tilelang/language/eager/__init__.py | Removed lazy_jit from public exports; changed dtypes import source from v2 to eager; switched primary module re-export from .v2 to .eager; added public exports in eager module for prim_func, macro, PrimFunc, JITFunc, Ref, const. |
| Builder Context Guards | tilelang/language/kernel.py, tilelang/language/proxy.py | Added runtime checks in Kernel() and make_tensor() to require active Builder context; raises JITNoBuilderError when Builder.current() is None. |
| Import Path Updates (v2→eager) | tilelang/language/loop.py, tilelang/language/print_op.py, tilelang/language/allocate.py, tilelang/jit/adapter/tvm_ffi.py | Migrated imports from tilelang.language.v2 to tilelang.language.eager for SerialForWithStep, UnrollForWithStep, Builder, and dtype. |
| Test Decorator Migrations | testing/python/language/test_tilelang_language_lazy_jit.py, testing/python/language/test_tilelang_language_subtype.py, testing/python/layout/test_tilelang_annotate_loop_layout.py | Replaced all @tilelang.lazy_jit decorators with @tilelang.jit; updated par_compile() call signatures to pass two tensors per copy. |
| Example Decorator Removals | examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py, examples/gemm/example_gemm_intrinsics.py | Removed simplify_prim_func import and @simplify_prim_func decorator. |
| Frontend Test Refactor | testing/python/language/test_tilelang_language_frontend_v2.py | Replaced macro/prim_func wrappers with @tilelang.jit; inlined tensor allocations into test bodies; converted test signatures to parameterless forms; updated return statements and call sites; refactored boolean logic to use TVM/TIR operators (Not, Or, tir_all). |
| Example Output Adjustments | examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py | Removed hard-coded block_H assignment and debug prints; added Average FLOPS output after performance measurements. |
| Test Coverage Updates | testing/python/language/test_tilelang_language_ptr.py, testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py, testing/python/issue/test_tilelang_issue_1549.py, testing/python/issue/test_tilelang_issue_1601.py, testing/python/analysis/test_tilelang_fragment_loop_checker.py | Added backend variants (cython, tvm_ffi); added CUDA requirement decorators; commented out test invocations; relocated simplify_prim_func imports and decorator usage. |
| Core Transform Validation | src/transform/layout_inference.cc, src/transform/lower_tile_op.cc | Narrowed warning condition in BufferUseDefCollector to fragment buffers only; added shape constancy check in makeBufferWithLayout. |
| Notebook Updates | examples/lazy_jit/lazyjit.en.ipynb, examples/lazy_jit/lazyjit.zh.ipynb | Replaced @tilelang.lazy_jit with @tilelang.jit decorators across code cells. |
| Expression Construction | testing/python/arith/test_arith_hard.py | Replaced Python and with TVM/TIR tir_all() for condition composition; removed @T.macro decorators from expression definitions. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • tile-ai/tilelang#1480: Overlapping changes to JIT/prim_func/lazy-vs-eager handling in builder and jit layer
  • tile-ai/tilelang#1337: Introduces lazy_jit experimental interface, which this PR removes from public exports
  • tile-ai/tilelang#1120: Touches language/JIT frontend surfaces and v2 builder types that this PR refactors

Suggested reviewers

  • tzj-fxz

Poem

🐰 From lazy dreams to eager schemes, a unified mode now gleams,
No more split paths through v2 dreams—just jit() with graceful themes!
Guards protect the Builder's keep, while modes run deep,
Eager results or lazy kernels sleep... hop, hop! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 17.83% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and accurately summarizes the main objective: unifying two decorators (@jit and @lazy_jit) into a single @jit decorator with automatic mode inference. |

@LeiWang1999 LeiWang1999 requested a review from Copilot January 7, 2026 08:59
Copilot AI left a comment

Pull request overview

This PR unifies @tilelang.lazy_jit and @tilelang.jit into a single @tilelang.jit decorator that automatically infers execution mode based on function behavior. The unified decorator supports two styles: lazy mode (functions that explicitly return a PrimFunc) and eager mode (functions that use the DSL builder pattern with T.const() and T.Kernel()).

Key changes:

  • Added automatic mode inference by detecting whether functions return a PrimFunc or raise JITNoBuilderError
  • Introduced custom exception classes (JITNoBuilderError, EagerJITBuildError) for distinguishing execution modes
  • Removed lazy_jit from public exports
  • Updated all examples and tests to use the unified @tilelang.jit decorator

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| tilelang/jit/__init__.py | Unified JIT decorator implementation with automatic mode inference and removed lazy_jit function |
| tilelang/jit/exceptions.py | Added new exception classes for JIT mode detection |
| tilelang/language/v2/builder.py | Enhanced LazyJITFunc with mode inference logic and updated TirTemplate to support lazy-style functions |
| tilelang/language/kernel.py | Added Builder existence check to T.Kernel() for eager mode validation |
| tilelang/__init__.py | Removed lazy_jit from public API exports |
| testing/python/layout/test_tilelang_annotate_loop_layout.py | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| testing/python/language/test_tilelang_language_subtype.py | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| testing/python/language/test_tilelang_language_lazy_jit.py | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| examples/lazy_jit/lazyjit.zh.ipynb | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| examples/lazy_jit/lazyjit.en.ipynb | Updated decorator usage from @tilelang.lazy_jit to @tilelang.jit |
| examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py | Updated decorator usage and removed unused import |
| examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py | Code cleanup: removed unused variables and debug print statements |
| src/transform/layout_inference.cc | Fixed buffer validation logic in layout inference |
| 3rdparty/tvm | Updated TVM submodule commit reference |

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In
@examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py:
- Around line 391-393: The printed "GFLOPS" are incorrect because avg_flops =
total_flops / avg_time yields FLOPS; convert to GFLOPS by dividing by 1e9 and
update the formatted value and label accordingly for both the main measurement
(avg_flops, total_flops, avg_time) and the reference implementation's
corresponding variables (e.g., ref_total_flops/ref_avg_flops or whatever the
reference block uses around lines ~405-407), ensuring you compute gflops =
total_flops / avg_time / 1e9 and print "GFLOPS" with that converted value.
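
The unit conversion requested above can be sketched as follows. This assumes avg_time is measured in seconds (the comment does not state the unit; a millisecond timer would need an extra factor of 1e3), and the helper name `to_gflops` plus the sample numbers are illustrative only.

```python
def to_gflops(total_flops: float, avg_time_s: float) -> float:
    """Convert an operation count and a time in seconds to GFLOPS."""
    flops_per_second = total_flops / avg_time_s
    return flops_per_second / 1e9


# Illustrative numbers: 2e9 floating-point operations in 0.5 s.
total_flops = 2e9
avg_time_s = 0.5
print(f"Average GFLOPS: {to_gflops(total_flops, avg_time_s):.2f}")  # Average GFLOPS: 4.00
```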
🧹 Nitpick comments (3)
tilelang/language/v2/builder.py (1)

1027-1034: Consider simplifying redundant cache check.

The cache check at line 1029 appears redundant since parse_args() (called at line 1028) already populates p1_cache at lines 1019-1023. The comment mentions "legacy gemm" but the logic path where p1_cache would be empty after parse_args is unclear.

♻️ Suggested simplification
     def get_tir(self, *args, **kwargs):
         (p1_key, _), tensor_args = self.parse_args(*args, **kwargs)
-        if p1_key not in self.p1_cache:
-            # in legacy gemm, we use lazy tir template to build the tir
-            tir_temp = self._build_tir_template(*args, **kwargs)
-            self.p1_cache[p1_key] = tir_temp
-            return tir_temp.get_tir(**tensor_args)
         return self.p1_cache[p1_key].get_tir(**tensor_args)

If there's a specific legacy case requiring this, please add a more detailed comment explaining when parse_args wouldn't populate the cache.

tilelang/jit/__init__.py (2)

295-302: Consider using TypeError for invalid type.

Per Python conventions, TypeError is more appropriate than ValueError when the issue is an unexpected type.

♻️ Suggested fix
         if isinstance(self.func, PrimFunc):
             tir = self.func
         elif isinstance(self.func, (LazyJITFunc, Callable)):
             tir = self.func(*args, **kwargs)
         else:
-            raise ValueError(f"Invalid function type: {type(self.func)}")
+            raise TypeError(f"Invalid function type: {type(self.func)}")
         assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
         return tir

438-443: Pass kernel arguments as keyword arguments instead of unpacking dict values as positional arguments.

At line 441, kernel(*kernel_args.values()) unpacks dictionary values positionally. While Python 3.7+ guarantees dict insertion order, this assumes the dict value order matches the kernel function's parameter order. To be explicit and avoid fragility, pass arguments as keywords: kernel(**kernel_args).
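
A small sketch of the fragility described above, using a hypothetical three-parameter `kernel` function rather than a real TileLang kernel. Positional unpacking follows dict insertion order, while keyword unpacking binds by name; the keyword form assumes the dict keys match the parameter names.

```python
def kernel(a, b, out):
    """Hypothetical kernel stand-in that reports how its arguments were bound."""
    return f"a={a}, b={b}, out={out}"


# The dict happens to be built in a different order than the signature.
kernel_args = {"out": "C", "a": "A", "b": "B"}

positional = kernel(*kernel_args.values())  # binds by insertion order
keyword = kernel(**kernel_args)             # binds by parameter name

print(positional)  # a=C, b=A, out=B
print(keyword)     # a=A, b=B, out=C
```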

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 566d8f2 and edb8500.

📒 Files selected for processing (14)
  • 3rdparty/tvm
  • examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
  • examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
  • examples/lazy_jit/lazyjit.en.ipynb
  • examples/lazy_jit/lazyjit.zh.ipynb
  • src/transform/layout_inference.cc
  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/language/test_tilelang_language_subtype.py
  • testing/python/layout/test_tilelang_annotate_loop_layout.py
  • tilelang/__init__.py
  • tilelang/jit/__init__.py
  • tilelang/jit/exceptions.py
  • tilelang/language/kernel.py
  • tilelang/language/v2/builder.py
💤 Files with no reviewable changes (1)
  • examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
🧰 Additional context used
🧠 Learnings (5)
📚 Learning: 2025-12-26T06:45:51.789Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1483
File: tilelang/jit/adapter/cutedsl/adapter.py:93-95
Timestamp: 2025-12-26T06:45:51.789Z
Learning: For the CuTeDSL backend in tilelang/jit/adapter/cutedsl/adapter.py, the host_kernel_source and device_kernel_source have the same value.

Applied to files:

  • examples/lazy_jit/lazyjit.en.ipynb
  • examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • examples/lazy_jit/lazyjit.en.ipynb
  • examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • examples/lazy_jit/lazyjit.en.ipynb
  • testing/python/language/test_tilelang_language_subtype.py
  • examples/lazy_jit/lazyjit.zh.ipynb
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_subtype.py
  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/layout/test_tilelang_annotate_loop_layout.py
📚 Learning: 2025-12-15T07:48:45.785Z
Learnt from: cherichy
Repo: tile-ai/tilelang PR: 1421
File: src/target/codegen_cutedsl.cc:789-793
Timestamp: 2025-12-15T07:48:45.785Z
Learning: In tilelang/contrib/cutedsl, the `tl.make_rmem_tensor` function accepts both an Integer and a Tuple of Integer for its shape parameter. Therefore, both `tl.make_rmem_tensor(N, ...)` and `tl.make_rmem_tensor((N,), ...)` are valid syntaxes in CuteDSL-generated code.

Applied to files:

  • testing/python/language/test_tilelang_language_subtype.py
🧬 Code graph analysis (5)
testing/python/language/test_tilelang_language_subtype.py (1)
tilelang/jit/__init__.py (3)
  • jit (450-450)
  • jit (454-464)
  • jit (467-531)
testing/python/language/test_tilelang_language_lazy_jit.py (2)
tilelang/jit/__init__.py (3)
  • jit (450-450)
  • jit (454-464)
  • jit (467-531)
tilelang/language/v2/builder.py (1)
  • const (826-854)
src/transform/layout_inference.cc (2)
src/op/utils.h (1)
  • IsFragmentBuffer (33-35)
src/transform/common/loop_fusion_utils.h (1)
  • IsFragmentBuffer (70-75)
tilelang/language/v2/builder.py (1)
tilelang/jit/exceptions.py (2)
  • JITNoBuilderError (4-13)
  • EagerJITBuildError (16-24)
tilelang/language/kernel.py (2)
tilelang/jit/exceptions.py (1)
  • JITNoBuilderError (4-13)
tilelang/language/v2/builder.py (2)
  • Builder (166-676)
  • current (178-180)
🪛 Ruff (0.14.10)
tilelang/__init__.py

145-145: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/v2/builder.py

845-845: Avoid specifying long messages outside the exception class

(TRY003)


1012-1012: Avoid specifying long messages outside the exception class

(TRY003)


1053-1053: Avoid specifying long messages outside the exception class

(TRY003)


1078-1078: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/jit/__init__.py

300-300: Prefer TypeError exception for invalid type

(TRY004)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


370-370: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/kernel.py

290-290: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: CodeQL analysis (python)
🔇 Additional comments (19)
3rdparty/tvm (1)

1-1: Verify submodule compatibility with JIT refactoring changes.

Submodule reference updated from e47e76a... to 001022b.... Since this PR introduces mode-inference changes to the JIT system and adds builder-aware error handling, please confirm that the new TVM commit is compatible with these tilelang changes and does not introduce breaking changes or regressions.

examples/lazy_jit/lazyjit.zh.ipynb (1)

56-56: LGTM! Decorator migration is consistent throughout the notebook.

All @tilelang.lazy_jit usages have been correctly updated to @tilelang.jit. The notebook content properly demonstrates eager-style DSL patterns (using T.const, T.Tensor annotations, T.Kernel), which aligns with the unified JIT decorator's auto-detection behavior.

Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860

src/transform/layout_inference.cc (1)

225-231: LGTM! Warning condition correctly narrowed to fragment buffers only.

The updated condition !use_list_.count(buffer) && IsFragmentBuffer(buffer) appropriately limits the warning to cases where fragment buffers are missing from use_list_. Non-fragment buffers (global/shared memory) not appearing in use_list_ is expected behavior and doesn't warrant a warning.

Note: Line 234 still accesses use_list_[buffer] even when the buffer isn't in use_list_. This is safe because std::unordered_map::operator[] creates an empty vector for missing keys, and iterating an empty vector is a no-op.

testing/python/layout/test_tilelang_annotate_loop_layout.py (1)

7-7: LGTM! Test file correctly migrated to unified @tilelang.jit decorator.

All three kernel functions (loop_layout_kernel, copy_with_layout_kernel, replicate_loop_layout_kernel) have been updated from @tilelang.lazy_jit to @tilelang.jit. The test assertions remain unchanged, ensuring the loop layout annotation functionality continues to work correctly with the unified decorator.

Also applies to: 54-54, 82-82

examples/lazy_jit/lazyjit.en.ipynb (1)

56-56: LGTM! English notebook correctly migrated to unified @tilelang.jit decorator.

All decorator usages have been consistently updated from @tilelang.lazy_jit to @tilelang.jit, matching the changes in the Chinese version of the notebook.

Also applies to: 212-212, 251-251, 310-310, 362-362, 424-424, 473-473, 518-518, 580-580, 800-800, 860-860

tilelang/language/kernel.py (2)

10-10: LGTM! Import for the new exception type.

The import of JITNoBuilderError is correctly placed and supports the new runtime guard for enforcing Builder context.


283-291: LGTM! Runtime guard correctly enforces Builder context for T.Kernel().

The guard ensures T.Kernel() can only be called within a proper JIT/prim_func context by checking for an active Builder. The lazy import of Builder appropriately avoids circular import issues.

The inline error message is clear and actionable. While static analysis (Ruff TRY003) suggests moving long messages into the exception class, the context-specific message here is acceptable since JITNoBuilderError is reused across multiple call sites (e.g., T.const()) that may need different messages.

tilelang/jit/exceptions.py (1)

1-24: LGTM! Well-structured custom exceptions for JIT error handling.

Both exception classes are cleanly defined with descriptive docstrings that clearly explain their purpose:

  • JITNoBuilderError: Raised when Builder-dependent operations (like T.Kernel(), T.const()) are called outside a JIT/prim_func context
  • EagerJITBuildError: Raised for failures during eager-style kernel construction

The separation allows callers to catch and handle these distinct error conditions appropriately.

tilelang/__init__.py (1)

145-145: LGTM! Public API correctly updated to remove lazy_jit export.

Removing lazy_jit from the public exports aligns with the PR objective to unify JIT functionality under the single @tilelang.jit decorator. Users should migrate from @tilelang.lazy_jit to @tilelang.jit.

Note: The static analysis hint about unused noqa: F401 appears to be a false positive — the directive correctly suppresses the "imported but unused" warning for these intentional re-exports.

This is a breaking change for any external code using tilelang.lazy_jit. Please ensure the migration guide or release notes document this API removal.

testing/python/language/test_tilelang_language_subtype.py (1)

9-10: LGTM!

The decorator migration from @tilelang.lazy_jit to @tilelang.jit is consistent across all kernel functions. These kernels use the eager-style DSL pattern (with T.dynamic, tensor annotations, and T.Kernel context), which the unified decorator will correctly auto-detect.

Also applies to: 18-19, 98-99, 108-109, 119-120, 130-131, 142-143

testing/python/language/test_tilelang_language_lazy_jit.py (1)

9-9: LGTM!

The decorator migrations from @tilelang.lazy_jit to @tilelang.jit are correct across all kernel functions. All use the eager-style DSL pattern with T.const, tensor annotations, and T.Kernel context managers.

Also applies to: 48-48, 105-105, 112-112, 119-119, 126-126, 134-134, 141-141, 178-178, 184-184, 189-189, 195-195, 202-202, 208-208

tilelang/language/v2/builder.py (4)

841-845: LGTM with a note on naming.

The error handling correctly enforces that T.const() is only usable in eager mode with an active Builder. The internal flag name builder.lazy_jit being True for eager mode is counterintuitive but is existing behavior maintained for backward compatibility.


857-923: LGTM!

The TirTemplate class correctly distinguishes between lazy-style (direct PrimFunc return, no substitution needed) and eager-style (constexpr variable substitution required). The from_lazy_style factory and early return in get_tir are clean implementations.


965-993: Verify side-effect safety of mode inference probing.

The _is_lazy_style() method probes the original function by calling self.orig_func(*args, **kwargs) to detect the execution style. This works correctly for the detection logic, but be aware that:

  1. Any side effects (e.g., prints, logging, external calls) in the decorated function will execute during mode inference.
  2. The function may be called twice: once for probing and once for actual TIR building.

This is likely acceptable since:

  • Lazy-style functions typically just construct and return a PrimFunc without side effects.
  • Eager-style functions will raise JITNoBuilderError early before reaching user code.

Please confirm this probing behavior is intentional and document it if users might have functions with side effects.
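
To make the double-call concern concrete, here is a minimal sketch with stand-in names (`lazy_style`, `probe_then_build`, and `probe_and_reuse` are hypothetical and do not mirror the actual builder code). It shows a side effect firing twice when the probe result is discarded, and once when the probe result is reused.

```python
calls = []


def lazy_style(n):
    calls.append(n)           # observable side effect (stands in for print/logging)
    return ("prim_func", n)   # stand-in for a returned PrimFunc


def probe_then_build(func, *args):
    func(*args)               # first call: mode inference only, result discarded
    return func(*args)        # second call: actual construction


def probe_and_reuse(func, *args):
    return func(*args)        # single call: the probe result is the build result


probe_then_build(lazy_style, 8)
print(calls)  # [8, 8]

calls.clear()
probe_and_reuse(lazy_style, 8)
print(calls)  # [8]
```

Reusing the probed value is only an option for lazy-style functions, since the eager path raises before producing anything to reuse.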


1091-1096: LGTM!

The get_origin(annot[k]) is None check correctly prevents calling typing generics (like Optional[int], Union[...], List[...]) which are callable but cannot be instantiated. This is a proper fix for handling modern type annotations.
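
The effect of that check can be illustrated with the standard library alone. The `coerce` helper below is a hypothetical reduction of the pattern, not the code under review: it calls an annotation as a constructor only when `typing.get_origin` reports that the annotation is not a parameterized generic.

```python
from typing import List, Optional, get_origin


def coerce(annot, value):
    """Call annot(value) only for plain callables such as int or float."""
    if get_origin(annot) is None and callable(annot):
        return annot(value)
    return value  # parameterized generics are left untouched


print(get_origin(int) is None)            # True
print(get_origin(Optional[int]) is None)  # False
print(get_origin(List[int]) is None)      # False

print(coerce(int, "3"))            # 3
print(coerce(Optional[int], "3"))  # 3 (still the string "3", not converted)
```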

tilelang/jit/__init__.py (4)

192-265: LGTM!

Excellent documentation for JITImpl. The docstring clearly explains both execution modes with practical examples, making the API behavior easy to understand for users.


304-316: LGTM!

The mode inference logic is clean and correct:

  • Short-circuits for explicitly set modes
  • Handles raw PrimFunc inputs by defaulting to lazy mode
  • Delegates to LazyJITFunc._is_lazy_style() for automatic detection

363-398: LGTM!

The compile method correctly:

  1. Infers mode on first compilation
  2. Validates that out_idx is only used in lazy mode (as per PR objectives)
  3. Propagates mode to the underlying function
  4. Handles debug output appropriately

467-531: LGTM!

The unified jit decorator is well-implemented:

  • Supports both @jit and @jit(options=...) usage patterns
  • Initializes mode as "auto" for automatic detection
  • Correctly wraps functions via prim_func(func, lazy_jit=True) to create LazyJITFunc
  • Preserves source code and signature for debugging and introspection

…ation and improve error handling for eager mode. Update matmul tests to utilize both Cython and TVM FFI JIT kernels, ensuring consistency in output validation.
…m_intrinsics.py and related test files. Enhance layout.cc to simplify integer set analysis for multi-threaded compilation, ensuring constant integer outputs. Update lower_tile_op.cc to enforce constant integer checks in layout output shapes. Improve error handling in builder.py for JIT context management.
…d proxy.py for improved readability. Add missing import in test_tilelang_kernel_int4_gemm_mma.py and ensure consistent spacing in builder.py.
…ons within test_arith_hard.py for improved clarity and consistency. Update main execution flow to call test_hard_prove instead of tilelang.testing.main.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py (1)

18-159: Add @simplify_prim_func decorator to tl_matmul for consistency.

The tl_matmul function (line 18) is missing the @simplify_prim_func decorator that appears on tl_matmul_weight_only_transform (line 203) and on tl_matmul in all other similar test files (test_tilelang_kernel_fp8_gemm_mma.py, test_tilelang_kernel_bf16_gemm_mma.py, test_tilelang_kernel_gemm_mma_intrinsic.py, etc.). Apply the decorator to maintain consistency and ensure uniform IR simplification across equivalent kernels.

🤖 Fix all issues with AI agents
In @tilelang/jit/__init__.py:
- Line 298: The isinstance check mixes LazyJITFunc and typing.Callable
erroneously; remove Callable from the tuple so the branch only matches
LazyJITFunc (i.e., change isinstance(self.func, (LazyJITFunc, Callable)) to
isinstance(self.func, LazyJITFunc)), and then remove the now-unused Callable
import if it becomes unused so imports stay clean.
🧹 Nitpick comments (3)
testing/python/language/test_tilelang_language_lazy_jit.py (1)

229-230: Commented test runner should be cleaned up.

The main test runner tilelang.testing.main() is commented out and replaced with a direct call to test_jit2_return(). This appears to be temporary debugging code that should either be reverted or properly justified.

🧹 Suggested cleanup
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_jit2_return()
+    tilelang.testing.main()
testing/python/language/test_tilelang_language_frontend_v2.py (2)

294-320: LGTM! In-place tensor modification pattern is correct.

Both functions correctly use the eager-mode pattern with tensor annotations inside the function body (A: T.Tensor[(2,), T.float32]), which is the expected TileLang syntax for eager execution with external tensor arguments.

Note: Line 313 contains an extraneous blank line that could be removed for consistency.


324-339: LGTM! Consider consistent use of .item() for tensor element access.

The function correctly implements the eager-mode pattern. Line 338 uses .item() to extract the scalar value for comparison, which is more explicit than the direct comparison pattern used in test_var_assign (lines 220-221). While both approaches work with PyTorch, using .item() is more robust and explicit.

For consistency, consider using .item() in test_var_assign as well:

assert res[0].item() == 1
assert res[1].item() == 2
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between edb8500 and f7e00eb.

📒 Files selected for processing (11)
  • examples/gemm/example_gemm_intrinsics.py
  • src/layout/layout.cc
  • src/transform/lower_tile_op.cc
  • testing/python/arith/test_arith_hard.py
  • testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
  • testing/python/language/test_tilelang_language_frontend_v2.py
  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/language/test_tilelang_language_ptr.py
  • tilelang/jit/__init__.py
  • tilelang/language/proxy.py
  • tilelang/language/v2/builder.py
💤 Files with no reviewable changes (1)
  • examples/gemm/example_gemm_intrinsics.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_ptr.py
  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/language/test_tilelang_language_frontend_v2.py
  • testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/language/test_tilelang_language_ptr.py
  • testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • testing/python/language/test_tilelang_language_lazy_jit.py
  • testing/python/language/test_tilelang_language_frontend_v2.py
  • testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
🧬 Code graph analysis (6)
tilelang/language/proxy.py (2)
tilelang/jit/exceptions.py (1)
  • JITNoBuilderError (4-13)
tilelang/language/v2/builder.py (2)
  • Builder (167-679)
  • current (179-181)
testing/python/arith/test_arith_hard.py (1)
tilelang/language/tir/op.py (1)
  • all (1913-1930)
tilelang/jit/__init__.py (3)
tilelang/jit/adapter/tvm_ffi.py (2)
  • func (206-260)
  • prim_func (319-321)
tilelang/language/v2/builder.py (10)
  • LazyJITFunc (935-1062)
  • PrimFunc (688-697)
  • _is_lazy_style (974-1002)
  • set_mode (1048-1050)
  • prim_func (184-193)
  • prim_func (1087-1128)
  • get_tir (922-931)
  • get_tir (1036-1043)
  • parse_args (1023-1034)
  • get (223-224)
tilelang/jit/adapter/wrapper.py (2)
  • prim_func (575-585)
  • prim_func (839-849)
testing/python/language/test_tilelang_language_lazy_jit.py (2)
tilelang/jit/__init__.py (4)
  • jit (451-451)
  • jit (455-465)
  • jit (468-532)
  • get_tir (291-303)
tilelang/language/copy_op.py (1)
  • copy (14-116)
testing/python/language/test_tilelang_language_frontend_v2.py (2)
tilelang/jit/__init__.py (3)
  • jit (451-451)
  • jit (455-465)
  • jit (468-532)
tilelang/language/allocate.py (3)
  • alloc_var (86-86)
  • alloc_var (90-90)
  • alloc_var (93-147)
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py (1)
tilelang/transform/simplify.py (1)
  • simplify_prim_func (53-58)
🪛 Ruff (0.14.10)
tilelang/language/proxy.py

282-282: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/__init__.py

301-301: Prefer TypeError exception for invalid type

(TRY004)


301-301: Avoid specifying long messages outside the exception class

(TRY003)


324-324: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/v2/builder.py

191-191: Avoid specifying long messages outside the exception class

(TRY003)


717-717: Avoid specifying long messages outside the exception class

(TRY003)


854-854: Avoid specifying long messages outside the exception class

(TRY003)


1021-1021: Avoid specifying long messages outside the exception class

(TRY003)


1062-1062: Avoid specifying long messages outside the exception class

(TRY003)


1087-1087: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (12)
src/layout/layout.cc (1)

126-144: LGTM! Robust multi-level fallback for extent calculation.

The simplification attempt with cascading fallbacks (IntImm check → ConstIntBound → input_size_ → Integer(1)) provides good defensive handling for multi-threaded compilation scenarios where int_set analysis may yield unsimplified symbolic expressions. The logic aligns well with the existing fallback path for unbounded cases (lines 107-124).

src/transform/lower_tile_op.cc (1)

69-70: LGTM! Essential validation for constant layout shapes.

The ICHECK correctly enforces that layout output shapes must be constant integers before dereferencing on line 71, preventing potential null pointer crashes. The error message clearly indicates the requirement and shows the problematic value.

testing/python/arith/test_arith_hard.py (2)

6-6: LGTM: Import addition is appropriate.

The import of tir_all is necessary for constructing proper TIR expressions. The alias avoids shadowing Python's built-in all function.


27-59: LGTM: Correct usage of tir_all for symbolic expressions.

Replacing Python's and operator with tir_all(...) is the correct approach for combining symbolic boolean expressions in TIR. Python's and operator cannot be overloaded: it evaluates the truthiness of its left operand and returns one of the operands, so it never builds a TIR expression node. Using tir_all constructs a single TIR conjunction (logical AND) expression that the Analyzer can reason about.
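
The difference between `and` and an explicit conjunction helper can be illustrated without TVM. The sketch below is only an analogy: `Expr` and `and_all` are hypothetical stand-ins (not TVM classes or the real `tir_all`), and the real error raised by TVM may differ in type and message.

```python
from functools import reduce


class Expr:
    """Hypothetical symbolic boolean expression (stand-in for a TIR expression)."""

    def __init__(self, text):
        self.text = text

    def __bool__(self):
        # A symbolic condition has no concrete truth value when it is built.
        raise TypeError(f"cannot convert symbolic expression {self.text!r} to bool")


def and_all(*conds):
    """Hypothetical analogue of tir_all: fold all conditions into one conjunction."""
    return reduce(lambda lhs, rhs: Expr(f"({lhs.text} && {rhs.text})"), conds)


a, b = Expr("x > 0"), Expr("y < 8")

combined = and_all(a, b)  # builds a new symbolic node

try:
    a and b  # Python calls bool(a) first; `and` itself cannot be overloaded
    and_raised = False
except TypeError:
    and_raised = True
```

In this sketch `combined` carries both conditions, while `a and b` never gets as far as building an expression.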

testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py (1)

9-9: LGTM: Import correctly added.

The import of simplify_prim_func is properly placed and necessary for the decorator usage on line 203.

tilelang/language/proxy.py (1)

278-282: LGTM! Builder context check added for make_tensor.

The runtime guard correctly enforces that make_tensor requires an active Builder context. The error message clearly guides users to use the function within @tilelang.jit or @T.prim_func context.

testing/python/language/test_tilelang_language_ptr.py (1)

44-58: LGTM! Multi-backend validation improves test coverage.

The test now validates the matmul kernel across both Cython and TVM FFI execution backends, comparing results against a reference implementation and cross-validating between backends. This strengthens confidence in backend consistency.

testing/python/language/test_tilelang_language_lazy_jit.py (2)

9-9: LGTM! Decorator updated to unified jit API.

The change from @tilelang.lazy_jit to @tilelang.jit aligns with the PR objective of unifying the JIT decorator. The mode will be auto-inferred based on function behavior.


149-149: LGTM! par_compile updated for new API.

The par_compile call now correctly uses get_tir() to obtain TIR functions with dummy tensor arguments for each copy function. This aligns with the new mode-based JIT implementation.

testing/python/language/test_tilelang_language_frontend_v2.py (3)

206-222: LGTM! Eager-mode pattern correctly implemented.

The function correctly uses the unified @tilelang.jit decorator with eager-style execution: it allocates the tensor internally, performs operations, and returns the result.


258-292: LGTM! Both nested test functions follow the eager-mode pattern.

Both stepped_serial and stepped_serial_neg correctly use the unified decorator, allocate tensors internally, and return results for direct verification.


457-478: LGTM! Conditional logic correctly implemented for eager mode.

The probe function correctly uses the unified decorator with eager-mode semantics. The tensor annotation inside the function body and the constexpr conditional logic based on the tmp parameter are both properly implemented.

…nctions in test_tilelang_issue_1549.py and test_tilelang_issue_1601.py to ensure CUDA requirements are explicitly stated for relevant tests.
…setting max value from ist.max(). Update test_valid_loop in test_tilelang_fragment_loop_checker.py to disable certain tests and ensure proper execution flow. Enhance error handling in builder.py to correctly manage PrimFunc return types.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @testing/python/analysis/test_tilelang_fragment_loop_checker.py:
- Around line 145-147: Two test functions valid_loop_not_use_loop_var() and
valid_loop_serial() are defined but never invoked, leaving dead code and reduced
coverage; either re-enable them by adding calls (uncomment the calls currently
commented out near valid_loop_not_frag()) and fix any failures they expose, or
delete the unused function definitions and add a brief TODO/issue reference
comment explaining why they were removed; ensure that the final change either
calls valid_loop_not_use_loop_var() and valid_loop_serial() so all three
valid-loop scenarios run, or removes both functions and updates test coverage
expectations accordingly.
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)

1008-1026: Validate mode before building TIR template.

The method expects mode to be either "lazy" or "eager", but there's no validation that it's not "auto". Consider adding an assertion at the start:

assert self.mode in ("lazy", "eager"), f"Mode must be set before building TIR template, got: {self.mode}"

This would make the requirement explicit and provide a clearer error if the caller forgets to set the mode.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b3c4d77 and 62beb83.

📒 Files selected for processing (2)
  • testing/python/analysis/test_tilelang_fragment_loop_checker.py
  • tilelang/language/v2/builder.py
🧰 Additional context used
🧬 Code graph analysis (2)
testing/python/analysis/test_tilelang_fragment_loop_checker.py (1)
tilelang/env.py (1)
  • disable_cache (297-298)
tilelang/language/v2/builder.py (1)
tilelang/jit/exceptions.py (2)
  • JITNoBuilderError (4-13)
  • EagerJITBuildError (16-24)
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py

191-191: Avoid specifying long messages outside the exception class

(TRY003)


717-717: Avoid specifying long messages outside the exception class

(TRY003)


854-854: Avoid specifying long messages outside the exception class

(TRY003)


999-999: Consider moving this statement to an else block

(TRY300)


1025-1025: Avoid specifying long messages outside the exception class

(TRY003)


1066-1066: Avoid specifying long messages outside the exception class

(TRY003)


1091-1091: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (13)
testing/python/analysis/test_tilelang_fragment_loop_checker.py (1)

150-153: Verify if main block changes are intentional.

The main block now bypasses the standard test runner (tilelang.testing.main()) and directly calls test_valid_loop(). This pattern is unusual for committed test code and suggests debugging/manual testing.

Please clarify:

  1. Is this change intentional for production, or was it left in during debugging?
  2. If intentional, why is the standard test runner bypassed?
  3. Should tilelang.disable_cache() be in the main block, or should it be a test fixture?

If this was debugging code, consider reverting to the standard pattern:

♻️ Revert to standard test execution pattern
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    tilelang.disable_cache()
-    test_valid_loop()
+    tilelang.testing.main()
tilelang/language/v2/builder.py (12)

21-21: LGTM!

The new imports support the lazy/eager mode detection logic and builder context validation.

Also applies to: 32-32


186-193: Good defensive cleanup pattern.

The try/finally ensures thread_local_storage.builder is always cleared, preventing builder leakage if an exception occurs during function body execution.
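
A minimal sketch of this cleanup pattern, assuming a `threading.local` object holds the active builder. The names `_tls`, `current_builder`, and `run_with_builder` are hypothetical and are not taken from builder.py.

```python
import threading

_tls = threading.local()  # hypothetical stand-in for thread_local_storage


def current_builder():
    return getattr(_tls, "builder", None)


def run_with_builder(builder, body):
    """Install `builder` while `body` runs, and clear it even if `body` raises."""
    _tls.builder = builder
    try:
        return body()
    finally:
        _tls.builder = None


def failing_body():
    raise RuntimeError("error while tracing the function body")


try:
    run_with_builder(object(), failing_body)
except RuntimeError:
    pass

leaked = current_builder() is not None  # False: the finally block cleared it
```

Without the `finally`, the failed call would leave a stale builder visible to the next invocation on the same thread.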


715-718: LGTM!

Enforcing builder existence for macros is correct, as they require an active JIT context to generate IR.


835-864: LGTM!

The updated docstring clearly explains eager mode usage, and the explicit exception handling is more appropriate than assertions for user-facing error cases.


867-907: Lazy-style templates bypass matcher safely.

The is_lazy_style field distinguishes between eager and lazy templates. For lazy-style templates created via from_lazy_style(), the matcher is None. This is safe because get_tir() (line 923) returns early for lazy-style templates, so _parse_phase2_key() (line 909) is never invoked.


922-931: LGTM!

The early return for lazy-style templates correctly bypasses shape substitution, as lazy functions return PrimFunc directly without constexpr variables.


974-1006: Mode inference through controlled exception handling.

The approach is clever: calling the function and inspecting the outcome (PrimFunc vs. JITNoBuilderError) distinguishes lazy from eager style. The caching at lines 996-998 for lazy-style avoids redundant PrimFunc construction.
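
The trial-call idea can be sketched in plain Python. Everything here is a hypothetical stand-in (`NoBuilderError` for JITNoBuilderError, `FakePrimFunc` for a PrimFunc, `require_builder` for builder-only features), and the real `_is_lazy_style` may treat other return values or exceptions differently.

```python
class NoBuilderError(RuntimeError):
    """Hypothetical stand-in for JITNoBuilderError."""


class FakePrimFunc:
    """Hypothetical stand-in for a TIR PrimFunc."""


def require_builder():
    # Stand-in for eager-only features such as T.const() or T.Kernel(),
    # which need an active builder; in this sketch there never is one.
    raise NoBuilderError("no active builder")


def is_lazy_style(func, *args, **kwargs):
    """Trial-call `func` and classify it by the outcome."""
    try:
        result = func(*args, **kwargs)
    except NoBuilderError:
        return False  # eager style: the body used builder-only features
    return isinstance(result, FakePrimFunc)  # lazy style: returned a PrimFunc


def lazy_kernel(n):
    return FakePrimFunc()


def eager_kernel(n):
    require_builder()
```

Here `is_lazy_style(lazy_kernel, 4)` classifies the function as lazy, while `is_lazy_style(eager_kernel, 4)` classifies it as eager.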


1027-1038: LGTM!

The caching logic preserves performance by avoiding redundant TIR construction. The comment at line 1034 documents the contract with JITImpl to set mode before calling this method.


1040-1047: LGTM!

The fallback at lines 1042-1046 handles the case where parse_args() returns early without caching (line 1030), ensuring the TIR template is always built when needed.


1052-1054: LGTM!

The setter allows JITImpl to configure the execution mode after inference, correctly typed to exclude "auto".


1056-1066: LGTM!

Proxying function attributes is necessary for compatibility with introspection tools and the autotuner. The implementation correctly handles __wrapped__ as a special case and delegates other introspection attributes to the original function.
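
A rough sketch of attribute proxying through `__getattr__`. The `FuncProxy` class and its attribute set are hypothetical; the actual `_PROXIED_ATTRS` in builder.py may list different names.

```python
import inspect


class FuncProxy:
    """Hypothetical wrapper that forwards selected attributes to the wrapped function."""

    _PROXIED_ATTRS = frozenset({"__closure__", "__code__", "__name__", "__globals__"})

    def __init__(self, func):
        self._func = func

    def __getattr__(self, name):
        # __getattr__ runs only when normal attribute lookup fails.
        if name == "__wrapped__":
            return self._func
        if name in self._PROXIED_ATTRS:
            return getattr(self._func, name)
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")


def add(a, b=1):
    return a + b


proxy = FuncProxy(add)
original = inspect.unwrap(proxy)  # follows __wrapped__ back to `add`
```

Exposing `__wrapped__` is what lets standard tools such as `inspect.unwrap` recover the original function.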


1091-1131: Critical fix for type annotation handling.

The get_origin() check at line 1108 prevents attempting to instantiate typing generics (like Optional[int], Union[...]) which are callable but should not be called. This correctly distinguishes between factory functions and typing constructs.
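
The effect of the `get_origin()` guard can be shown with the standard `typing` module alone. In this sketch, `resolve_annotation` and `make_default` are hypothetical names that only mirror the described check; they are not the code in builder.py.

```python
from typing import List, Optional, get_origin


def make_default():
    """Hypothetical annotation factory: a plain callable that should be invoked."""
    return "tensor-annotation"


def resolve_annotation(annot):
    """Call factory-style annotations, but leave typing generics untouched."""
    if callable(annot) and get_origin(annot) is None:
        return annot()
    return annot


optional_int = Optional[int]
list_of_int = List[int]

resolved_factory = resolve_annotation(make_default)   # factory is called
resolved_optional = resolve_annotation(optional_int)  # returned unchanged
resolved_list = resolve_annotation(list_of_int)       # returned unchanged
```

`get_origin` returns a non-`None` origin for parameterized generics such as `Optional[int]` and `List[int]`, which is what keeps them from being called.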

@LeiWang1999
Member Author

@regression-perf

…hecker.py to call tilelang.testing.main. Modify JIT implementation to handle callable functions more generically. Enhance TirTemplate class in builder.py to return an empty tuple if matcher is None, improving robustness.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)

1093-1093: Minor: Consider explicit Optional type annotation.

The static analysis hint (RUF013) notes that PEP 484 prohibits implicit Optional. Line 1093 has:

def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> ...

For strict PEP 484 compliance, this should be:

def prim_func(func: Optional[Callable[_P, _T]] = None, *, lazy_jit: bool = False) -> ...

However, this is a minor stylistic issue and doesn't affect runtime behavior. Type checkers differ on this pattern: mypy, for example, has rejected implicit Optional by default since version 0.990.

📝 Optional fix for PEP 484 compliance
+from typing import Optional
+
-def prim_func(func: Callable[_P, _T] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
+def prim_func(func: Optional[Callable[_P, _T]] = None, *, lazy_jit: bool = False) -> PrimFunc[_P, _T] | LazyJITFunc[_P, _T]:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4093759 and d2d0ce6.

📒 Files selected for processing (3)
  • testing/python/analysis/test_tilelang_fragment_loop_checker.py
  • tilelang/jit/__init__.py
  • tilelang/language/v2/builder.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • testing/python/analysis/test_tilelang_fragment_loop_checker.py
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/v2/builder.py (2)
tilelang/jit/__init__.py (3)
  • jit (451-451)
  • jit (455-465)
  • jit (468-532)
tilelang/jit/exceptions.py (2)
  • JITNoBuilderError (4-13)
  • EagerJITBuildError (16-24)
tilelang/jit/__init__.py (1)
tilelang/language/v2/builder.py (10)
  • LazyJITFunc (937-1068)
  • PrimFunc (688-697)
  • _is_lazy_style (976-1008)
  • set_mode (1054-1056)
  • prim_func (184-193)
  • prim_func (1093-1134)
  • get_tir (924-933)
  • get_tir (1042-1049)
  • parse_args (1029-1040)
  • get (223-224)
🪛 Ruff (0.14.10)
tilelang/language/v2/builder.py

191-191: Avoid specifying long messages outside the exception class

(TRY003)


717-717: Avoid specifying long messages outside the exception class

(TRY003)


854-854: Avoid specifying long messages outside the exception class

(TRY003)


1001-1001: Consider moving this statement to an else block

(TRY300)


1027-1027: Avoid specifying long messages outside the exception class

(TRY003)


1068-1068: Avoid specifying long messages outside the exception class

(TRY003)


1093-1093: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/jit/__init__.py

324-324: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (14)
tilelang/jit/__init__.py (5)

277-279: LGTM! Mode attribute and func type narrowing are well-structured.

The addition of the mode attribute and narrowing func to LazyJITFunc[_KP, _T] properly support the new unified JIT workflow. The type narrowing is sound given that the decorator always creates a LazyJITFunc via prim_func(func, lazy_jit=True) at line 520.


373-399: Approve compile method changes.

The refactoring to use prim_func (line 374) instead of func and prim_func.script() (line 397) instead of func.script() is correct and aligns with the new mode-based workflow.


426-444: Approve __call__ implementation with minor observation.

The early mode inference (lines 429-431) and conditional execution based on mode (lines 439-444) are well-structured. The eager mode execution path correctly invokes the kernel immediately with kernel(*kernel_args.values()).

One minor observation: At line 442, you're unpacking kernel_args.values() with *. Ensure that the ordering of kernel_args is deterministic and matches the expected parameter order for the kernel. Based on the code in builder.py where tensor_args is built from a dict comprehension, dictionary order is preserved (Python 3.7+), so this should be fine.
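
The ordering argument can be checked with a small sketch. The `kernel` function, `param_names`, and `provided` below are hypothetical placeholders, not the actual objects in tilelang/jit/__init__.py.

```python
def kernel(a, b, c):
    """Hypothetical kernel that reports the positional order it received."""
    return (a, b, c)


param_names = ["a", "b", "c"]
provided = {"c": 3, "a": 1, "b": 2}  # caller-supplied, arbitrary order

# Building the dict by iterating over the parameter names fixes the insertion
# order, and dicts preserve insertion order (guaranteed since Python 3.7).
kernel_args = {name: provided[name] for name in param_names}

result = kernel(*kernel_args.values())
```

The unpacking is only safe because the comprehension iterates in parameter order; a dict built in caller order would pass the values in the wrong positions.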


518-530: Approve unified jit decorator implementation.

The decorator properly creates a LazyJITFunc with lazy_jit=True (line 520) and initializes the mode to "auto" (line 519), which will be inferred on first call. The removal of the legacy lazy_jit parameter and the unified approach is clean.


295-325: The mode inference and initialization logic is sound and handles edge cases appropriately.

The three points are verified as non-blocking:

  1. Defaulting to "lazy" mode (line 315-316): Safe and correct. All decorated functions pass through prim_func(..., lazy_jit=True) during initialization, so this default only applies to edge cases where a non-LazyJITFunc is used directly—a rare scenario where "lazy" is the appropriate fallback.

  2. Error message for out_idx in eager mode (line 323-324): The message is clear and helpful. Since it appears only once in the file, extracting to a constant is optional rather than necessary.

  3. Mode state mutation (lines 321, 430): Intentional and safe. Mode is inferred once on the first call (when self.mode == "auto") and cached for subsequent calls. All operations that depend on mode occur after initialization, making the caching behavior correct and efficient.

tilelang/language/v2/builder.py (9)

186-193: Excellent: Builder cleanup in finally block.

Wrapping the cleanup in a finally block (line 192-193) ensures thread_local_storage.builder is always cleared even when exceptions occur. This prevents stale Builder references from leaking across prim_func invocations.


714-721: Approve runtime guard for Macro usage.

The check at lines 716-717 correctly enforces that macros can only be used within a JIT context by raising JITNoBuilderError when no active Builder exists. This aligns with the PR's goal of adding Builder existence checks.


835-863: Approve eager-mode enforcement in const().

The updated docstring (lines 836-849) clearly documents that T.const() is for eager mode only. The runtime check (lines 853-854) properly raises JITNoBuilderError when called outside of eager JIT context.

The commented-out assertions (lines 851-852) should ideally be removed rather than left as comments, but since they're immediately followed by the replacement code, this is acceptable for clarity during development.


904-907: LGTM: TirTemplate.from_lazy_style factory method.

The new factory method correctly creates a lazy-style template with is_lazy_style=True and no matcher, which aligns with lazy mode where PrimFunc is used directly without substitution.


976-1008: Approve mode inference logic with strong exception-based detection.

The _is_lazy_style method (lines 976-1008) implements clever mode detection:

  1. Lines 994-1000: If the function returns a PrimFunc, it's lazy style. The early caching at lines 998-999 is a nice optimization.

  2. Lines 1002-1008: Catching JITNoBuilderError and EagerJITBuildError to detect eager mode is elegant. The comment clearly explains that eager-only features (like T.const() or T.Kernel()) raise these exceptions when no Builder exists during the trial call.

  3. Line 1001: The static analysis hint (TRY300) suggests moving return False to an else block. However, the current structure is clear and the early return True makes the logic easier to follow. The hint can be safely ignored.


1010-1028: LGTM: _build_tir_template implements mode-based template construction.

The method correctly branches on mode:

  • Lazy mode (lines 1012-1014): Calls the function directly to get the PrimFunc.
  • Eager mode (lines 1015-1025): Traces through Builder to construct TIR, similar to the original prim_func flow.

The error message at line 1027 clearly indicates invalid mode values. Per the static analysis hint (TRY003), this could be a custom exception, but for an internal method, a ValueError with a clear message is acceptable.


1029-1056: Approve parse_args and set_mode methods.

  1. parse_args (lines 1029-1040): Correctly handles both lazy and eager styles, with proper caching and template building. The comment at line 1036 helpfully notes that mode should be set by JITImpl before calling this method.

  2. set_mode (lines 1054-1056): Simple setter for internal use. The docstring clarifies this is internal-only.


1058-1069: Excellent: Attribute proxying for autotuner compatibility.

The __getattr__ implementation (lines 1063-1068) with _PROXIED_ATTRS (line 1061) elegantly proxies function attributes like __closure__, __code__, etc., to the original function. The comment (lines 1058-1060) explains this is needed for autotuner and inspection, which is valuable context.

The special handling of __wrapped__ (lines 1065-1066) follows Python conventions for wrapped functions.


1107-1111: Approve callable annotation handling with get_origin check.

The updated logic (lines 1107-1111) correctly distinguishes between:

  • Callable factory functions (should be called)
  • Typing generics like Optional[int], Union[...], List[...] (should NOT be called)

The get_origin(annot[k]) is None check ensures typing constructs are not mistakenly invoked. This is a solid fix for a subtle bug.

@github-actions

github-actions bot commented Jan 8, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/20806003264

Results

| File | Original Latency | Current Latency | Speedup |
| --- | --- | --- | --- |
| example_warp_specialize_gemm_copy_0_gemm_1 | 0.038272 | 0.039937 | 0.958309 |
| example_warp_specialize_gemm_softpipe_stage2 | 0.037633 | 0.038912 | 0.967131 |
| example_warp_specialize_gemm_copy_1_gemm_0 | 0.037793 | 0.038913 | 0.971218 |
| example_warp_specialize_gemm_barrierpipe_stage2 | 0.038561 | 0.039072 | 0.986922 |
| example_topk | 0.01104 | 0.011136 | 0.991379 |
| block_sparse_attn_tilelang | 0.0102819 | 0.0102871 | 0.999499 |
| example_gqa_bwd_wgmma_pipelined | 0.0741775 | 0.0742058 | 0.999619 |
| example_fusedmoe_tilelang | 0.131665 | 0.131693 | 0.999788 |
| example_mha_bwd_bshd_wgmma_pipelined | 0.0256385 | 0.0256382 | 1.00001 |
| example_mha_bwd_bshd | 0.0408033 | 0.0408028 | 1.00001 |
| example_elementwise_add | 0.297576 | 0.297569 | 1.00002 |
| example_vertical_slash_sparse_attn | 0.237286 | 0.237265 | 1.00009 |
| example_mha_bwd_bhsd | 0.0400041 | 0.039999 | 1.00013 |
| example_mha_fwd_varlen | 0.0452824 | 0.0452745 | 1.00018 |
| example_tilelang_gemm_splitk_vectorize_atomicadd | 1.40883 | 1.40854 | 1.00021 |
| example_linear_attn_fwd | 0.0365619 | 0.0365539 | 1.00022 |
| example_gqa_bwd_tma_reduce_varlen | 0.0636842 | 0.0636641 | 1.00032 |
| example_tilelang_gemm_splitk | 1.40786 | 1.40733 | 1.00038 |
| example_linear_attn_bwd | 0.152154 | 0.15209 | 1.00042 |
| tilelang_example_sparse_tensorcore | 0.0150624 | 0.015054 | 1.00055 |
| example_gemv | 0.288977 | 0.288763 | 1.00074 |
| example_gqa_bwd | 0.0498224 | 0.0497614 | 1.00123 |
| example_dynamic | 0.656108 | 0.65451 | 1.00244 |
| example_mha_inference | 0.073376 | 0.073026 | 1.00479 |
| example_convolution_autotune | 1.00284 | 0.993874 | 1.00902 |
| example_dequant_gemv_fp16xint4 | 0.0287392 | 0.028473 | 1.00935 |
| example_gqa_decode | 0.049345 | 0.048769 | 1.01181 |
| example_per_token_cast_to_fp8 | 0.00747453 | 0.0073717 | 1.01395 |
| example_dequant_groupedgemm_bf16_mxfp4_hopper | 3.54498 | 3.48524 | 1.01714 |
| example_tilelang_nsa_fwd | 0.00712952 | 0.00700405 | 1.01791 |
| example_tilelang_nsa_decode | 0.00688601 | 0.00674453 | 1.02098 |
| example_tilelang_block_sparse_attn | 0.0104721 | 0.0101714 | 1.02956 |
| example_dequant_gemm_w4a8 | 5.56274 | 5.39339 | 1.0314 |
| example_mha_sink_fwd_bhsd_wgmma_pipelined | 0.015958 | 0.0154315 | 1.03412 |
| example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window | 0.0160085 | 0.015472 | 1.03468 |
| example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window | 0.0150846 | 0.0145692 | 1.03538 |
| example_mha_sink_fwd_bhsd_sliding_window | 0.0162621 | 0.0157061 | 1.0354 |
| example_gqa_sink_fwd_bhsd_wgmma_pipelined | 0.0149447 | 0.0144271 | 1.03588 |
| example_mha_sink_fwd_bhsd | 0.0164691 | 0.0158569 | 1.03861 |
| example_mha_sink_bwd_bhsd_sliding_window | 0.0462213 | 0.0445018 | 1.03864 |
| example_tilelang_sparse_gqa_decode_varlen_indice | 0.01787 | 0.0172031 | 1.03877 |
| example_dequant_gemm_bf16_fp4_hopper | 0.602763 | 0.580265 | 1.03877 |
| example_blocksparse_gemm | 0.0235786 | 0.0226701 | 1.04007 |
| example_tilelang_sparse_gqa_decode_varlen_mask | 0.0243567 | 0.023384 | 1.04159 |
| example_group_per_split_token_cast_to_fp8 | 0.0107028 | 0.0102405 | 1.04514 |
| example_mha_sink_bwd_bhsd | 0.0643951 | 0.0614604 | 1.04775 |
| example_gqa_sink_bwd_bhsd_sliding_window | 0.0267944 | 0.0255642 | 1.04812 |
| example_gqa_sink_bwd_bhsd | 0.0437833 | 0.0416606 | 1.05095 |
| example_convolution | 1.40008 | 1.33135 | 1.05163 |
| example_dequant_gemm_bf16_mxfp4_hopper | 0.536424 | 0.508618 | 1.05467 |
| example_mla_decode | 0.48746 | 0.461319 | 1.05667 |
| example_dequant_gemm_fp4_hopper | 1.12457 | 1.06175 | 1.05917 |

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

…ss the codebase, updating related function signatures and logic to support eager JIT compilation mode. This change enhances the clarity and functionality of JIT operations.
…escription and remove outdated comment regarding implicit TIR construction.
@LeiWang1999
Member Author

Changes applied, please take a look @SiriusNEO

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @tilelang/jit/__init__.py:
- Around line 318-324: initialize_jit_mode currently calls
self.func.set_mode(self.mode) unconditionally which will raise AttributeError
when self.func is a PrimFunc; check the actual type before calling set_mode
(e.g., use isinstance(self.func, JITFunc) or mirror the guard used in
_infer_jit_mode) and only call self.func.set_mode(self.mode) when self.func is a
JITFunc; apply the same guarded check to the other unconditional call to
self.func.set_mode at the other location (around the later call referenced at
line 430) so PrimFunc paths do not attempt to call set_mode.
- Around line 438-443: The eager-mode return path assumes kernel_args is a dict
but parse_args can return None; update the eager branch in __init__.py to handle
kernel_args being None by calling kernel() with no arguments when kernel_args is
None (or using an empty list), e.g. compute args = list(kernel_args.values()) if
kernel_args else [] and then return kernel(*args); reference symbols:
kernel_args, parse_args, kernel, self.mode.
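The two guards described above can be sketched in isolation. This is a minimal stand-in and not the actual tilelang code: `PrimFunc`, `JITFunc`, `apply_mode`, and `call_eager` are hypothetical placeholders that only mirror the names used in the comment.

```python
class PrimFunc:
    """Stand-in for a lazy-style PrimFunc, which has no set_mode()."""


class JITFunc:
    """Stand-in for the JIT wrapper that carries a mode."""

    def __init__(self):
        self.mode = "auto"

    def set_mode(self, mode):
        self.mode = mode


def apply_mode(func, mode):
    # Only JITFunc exposes set_mode(); skipping PrimFunc avoids AttributeError.
    if isinstance(func, JITFunc):
        func.set_mode(mode)
    return func


def call_eager(kernel, kernel_args):
    # The parsed arguments may be None, so fall back to an empty argument list.
    args = list(kernel_args.values()) if kernel_args else []
    return kernel(*args)
```

The `isinstance` check is one option; a `hasattr(func, "set_mode")` check would behave the same for these two stand-in types.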
🧹 Nitpick comments (4)
tilelang/language/v2/builder.py (3)

866-871: Remove commented-out code.

The commented-out assertions on lines 867-868 are dead code that should be removed since the proper check is already implemented on lines 869-870.

♻️ Proposed fix
     builder = Builder.current()
-    # assert builder is not None, "T.const() can only be used inside @tilelang.jit (eager mode)"
-    # assert builder.eager_jit, "T.const() can only be used inside @tilelang.jit (eager mode)"
     if builder is None or not builder.eager_jit:
         raise JITNoBuilderError("T.const() can only be used inside @tilelang.jit (eager mode)")

1058-1065: Consider clarifying the cache check redundancy.

After parse_args populates p1_cache (lines 1051-1054), the check at line 1060 if p1_key not in self.p1_cache seems redundant. The comment mentions "legacy gemm" - if this is for backward compatibility, consider making it more explicit or removing if the legacy path is no longer needed.


1109-1109: Type annotation uses implicit Optional.

Per PEP 484, func: Callable[_P, _T] = None should be func: Callable[_P, _T] | None = None for type checker compatibility. This is a minor typing issue that doesn't affect runtime behavior.

♻️ Proposed fix
-def prim_func(func: Callable[_P, _T] = None, *, eager_jit: bool = False) -> PrimFunc[_P, _T] | JITFunc[_P, _T]:
+def prim_func(func: Callable[_P, _T] | None = None, *, eager_jit: bool = False) -> PrimFunc[_P, _T] | JITFunc[_P, _T]:
tilelang/jit/__init__.py (1)

426-431: Mode inference logic is duplicated.

Lines 428-430 duplicate the mode inference from initialize_jit_mode (lines 318-321). Consider extracting to a shared helper or calling initialize_jit_mode here to ensure consistent behavior, especially for the out_idx validation at line 322-323 which is skipped in __call__.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2d0ce6 and d33550b.

📒 Files selected for processing (3)
  • tilelang/jit/__init__.py
  • tilelang/language/v2/__init__.py
  • tilelang/language/v2/builder.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/v2/__init__.py (1)
tilelang/language/v2/builder.py (7)
  • prim_func (189-198)
  • prim_func (1109-1150)
  • macro (201-228)
  • macro (746-782)
  • PrimFunc (704-713)
  • JITFunc (953-1084)
  • const (851-879)
tilelang/jit/__init__.py (7)
tilelang/language/v2/builder.py (5)
  • PrimFunc (704-713)
  • prim_func (189-198)
  • prim_func (1109-1150)
  • JITFunc (953-1084)
  • get (230-231)
tilelang/jit/adapter/wrapper.py (2)
  • prim_func (575-585)
  • prim_func (839-849)
tilelang/jit/adapter/cython/adapter.py (1)
  • prim_func (356-358)
tilelang/jit/adapter/tvm_ffi.py (2)
  • prim_func (319-321)
  • func (206-260)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • prim_func (267-269)
src/transform/simplify.cc (2)
  • func (222-279)
  • func (223-225)
tilelang/jit/kernel.py (1)
  • out_idx (609-610)
tilelang/language/v2/builder.py (2)
tilelang/jit/__init__.py (3)
  • jit (450-450)
  • jit (454-464)
  • jit (467-531)
tilelang/jit/exceptions.py (2)
  • JITNoBuilderError (4-13)
  • EagerJITBuildError (16-24)
🪛 Ruff (0.14.10)
tilelang/language/v2/__init__.py

1-1: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/__init__.py

323-323: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/v2/builder.py

196-196: Avoid specifying long messages outside the exception class

(TRY003)


733-733: Avoid specifying long messages outside the exception class

(TRY003)


870-870: Avoid specifying long messages outside the exception class

(TRY003)


1017-1017: Consider moving this statement to an else block

(TRY300)


1043-1043: Avoid specifying long messages outside the exception class

(TRY003)


1084-1084: Avoid specifying long messages outside the exception class

(TRY003)


1109-1109: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
🔇 Additional comments (11)
tilelang/language/v2/builder.py (6)

21-21: LGTM: Import additions for mode-based JIT support.

The imports of Literal, get_origin from typing and exception classes from tilelang.jit.exceptions are appropriate for the new unified JIT behavior.

Also applies to: 32-32


188-198: Good: Proper cleanup with try/finally pattern.

The cleanup of thread_local_storage.builder in the finally block ensures the builder is always cleared, preventing stale state. This is a solid improvement over the previous implementation.
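The cleanup pattern can be illustrated with a small sketch. It assumes a thread-local slot similar to `thread_local_storage.builder`; the `Builder` class, `_tls`, and `run_with_builder` here are hypothetical stand-ins, not the code under review.

```python
import threading

_tls = threading.local()  # stand-in for thread_local_storage


class Builder:
    """Hypothetical stand-in for the real Builder class."""

    @staticmethod
    def current():
        # Returns the active builder for this thread, or None.
        return getattr(_tls, "builder", None)


def run_with_builder(body):
    _tls.builder = Builder()
    try:
        return body()
    finally:
        # Runs on success and on exception, so no stale builder is left behind.
        _tls.builder = None
```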


730-737: Good: Runtime guard for macro invocation context.

Raising JITNoBuilderError when no Builder is active ensures macros are only used within a proper JIT context, enabling the mode inference mechanism.


882-950: LGTM: TirTemplate properly handles lazy vs eager styles.

The is_lazy_style field and from_lazy_style factory method cleanly separate the two code paths. The early returns in _parse_phase2_key and get_tir for lazy style are correct.


1010-1024: Mode inference via exception catching is a reasonable pattern, but document the side-effect risk.

Calling self.orig_func(*args, **kwargs) to determine if the function returns a PrimFunc works, but note that:

  1. If the function has side effects before returning, those will execute during mode inference
  2. The result is cached in p1_cache on success (line 1015), so subsequent calls won't re-execute

This pattern is acceptable given the design constraints, but consider adding a comment noting that the original function should be side-effect-free up to the point of returning or raising.
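A minimal sketch of this inference pattern, under stated assumptions: `JITNoBuilderError` and `PrimFunc` below are local stand-ins for the real classes, `infer_mode` is a hypothetical helper, and the final `TypeError` is an illustrative choice rather than the behavior of the code under review.

```python
class JITNoBuilderError(Exception):
    """Stand-in for tilelang.jit.exceptions.JITNoBuilderError."""


class PrimFunc:
    """Stand-in for the object a lazy-style function returns."""


def infer_mode(func, *args, **kwargs):
    # Calling func runs any side effects it has before it returns or raises.
    try:
        result = func(*args, **kwargs)
    except JITNoBuilderError:
        # Builder-only features were used outside a Builder: eager style.
        return "eager"
    if isinstance(result, PrimFunc):
        return "lazy"
    raise TypeError(f"expected a PrimFunc, got {type(result).__name__}")
```

Because the probe call really executes the function body, anything it does before returning (logging, allocation, I/O) happens once during inference, which is the side-effect risk noted above.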


1122-1127: Good: Properly handles typing generics.

The get_origin(annot[k]) is None check correctly prevents calling typing constructs like Optional[int] or Union[...] which are callable but cannot be instantiated as factories.
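The distinction can be checked with the standard library alone. In this sketch, `is_plain_factory` is a hypothetical helper name, and the `callable(...)` half of the condition is an assumption about how such a check might be combined; only `typing.get_origin` is taken from the comment.

```python
from typing import Optional, Union, get_origin


def is_plain_factory(annotation):
    # Plain classes such as int have no origin; typing generics such as
    # Optional[int] or Union[int, str] report their origin and are rejected.
    return callable(annotation) and get_origin(annotation) is None
```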

tilelang/language/v2/__init__.py (1)

1-2: LGTM: Export updated from LazyJITFunc to JITFunc.

The public API correctly exports the new JITFunc class. The noqa: F401 directive suppresses unused import warnings for these re-exports, which is appropriate. The Ruff warning about "unused noqa" suggests F401 may not be enabled in the project's Ruff configuration.

tilelang/jit/__init__.py (4)

27-27: LGTM: Import updated for unified JITFunc.

The import correctly brings in JITFunc to replace the previous LazyJITFunc.


192-264: Excellent: Comprehensive documentation for the unified JIT interface.

The docstring clearly explains the two execution modes with concrete examples, making it easy for users to understand the behavior difference between lazy and eager styles.


517-529: LGTM: Decorator correctly initializes unified JIT with auto mode.

The decorator always creates a JITFunc via prim_func(func, eager_jit=True) and sets mode="auto" for later inference. This unified approach correctly defers mode determination to runtime.


449-464: LGTM: Overload signatures for decorator patterns.

The overloads correctly handle both @jit (no arguments) and @jit(out_idx=...) (with keyword arguments) usage patterns.

…ate imports across multiple files to replace references to the deprecated v2 module with the new eager module, ensuring compatibility with the latest JIT compilation features. Remove obsolete v2 module files to streamline the codebase.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tilelang/jit/__init__.py (2)

156-190: par_compile has a syntax/runtime blocker (results = [... for _ in futures]) and dead return

If this is literal code, it won’t run. Also, return results after the with is unreachable.

Proposed fix
@@
-        results = [... for _ in futures]
+        results: list[JITKernel[_KP, _T] | None] = [None] * len(futures)
@@
-        return results
-    return results
+        return results

47-107: compile() type contract is inconsistent (func: PrimFunc = None but immediately asserted)

Either make func required, or raise a clear TypeError when None is passed (the current signature suggests None is acceptable).

Proposed fix
@@
-def compile(
-    func: PrimFunc[_KP, _T] = None,
+def compile(
+    func: PrimFunc[_KP, _T],
@@
-) -> JITKernel[_KP, _T]:
+) -> JITKernel[_KP, _T]:
@@
-    assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
+    if not isinstance(func, PrimFunc):
+        raise TypeError(f"target function must be a PrimFunc but got {type(func)}")
tilelang/language/eager/builder.py (1)

981-991: Root cause of eager kwargs bug: _parse_phase1_key mutates kwargs via pop()

_parse_phase1_key currently pop()s tensor kwargs from the passed dict. Since JITImpl.__call__ passes its own kwargs into parse_args, this can delete user-supplied tensor kwargs before later compilation/execution steps.

Fix by copying kwargs at the start of _parse_phase1_key (or at least inside parse_args) so caller-owned dicts aren’t mutated.

Proposed fix
@@
     def _parse_phase1_key(self, *args, **kwargs):
-        kwargs.update({k: v for k, v in zip(self.arg_names, args)})
+        kwargs = dict(kwargs)  # do not mutate caller kwargs
+        kwargs.update({k: v for k, v in zip(self.arg_names, args)})
         tensor_args = {}
         for k in self.tensor_args:
             if k in kwargs:
                 tensor_args[k] = kwargs.pop(k)
             elif k in self.tensor_args_defaults:
                 tensor_args[k] = self.tensor_args_defaults[k]
         p1_key = tuple(sorted(kwargs.items()))
         return p1_key, tensor_args, kwargs

Also applies to: 1045-1057

🤖 Fix all issues with AI agents
In @testing/python/language/test_tilelang_language_frontend_v2.py:
- Around line 208-223: The assertions in test_var_assign compare tensor elements
to Python ints, which can yield ambiguous truth values; change the checks to
extract Python scalars or compare tensors explicitly: call .item() on res[0] and
res[1] (e.g., res[0].item() == 1, res[1].item() == 2) or use torch.equal/tensor
comparison APIs to assert equality of tensors returned by test_var_assign.
- Around line 260-288: In both JIT tests, the tensor A is T.int32 but the code
writes float literals (1.0, 2.0) into it; change the assignments in
stepped_serial to use integer literals (1 and 2) so writes match A's dtype and
avoid backend-dependent casting, and verify stepped_serial_neg already writes
integer i (no change needed) — update A[i] = 1.0 -> A[i] = 1 and A[i] = 2.0 ->
A[i] = 2 inside the stepped_serial function.

In @tilelang/__init__.py:
- Line 145: Remove the ineffective "# noqa: F401" markers on the re-export
import lines for jit, JITKernel, compile, and par_compile in tilelang.__init__;
either delete those trailing noqa comments on the import statements (and let
lint report/unify behavior) or instead enable the F401/unused-import rule in
your Ruff/linters if you intentionally want to suppress unused-import warnings;
apply the same change to the other identical import line referenced in the
comment.
- Line 145: The public export of lazy_jit was removed but undocumented; either
add a release note directing users to replace lazy_jit with
@tilelang.jit(mode='lazy') or restore a deprecated alias: re-export lazy_jit in
tilelang.__init__ that points to the existing jit wrapper (e.g., def
lazy_jit(*args, **kwargs): warnings.warn("lazy_jit is deprecated; use
tilelang.jit(mode='lazy')", DeprecationWarning, stacklevel=2); return jit(*args,
mode='lazy', **kwargs)), and ensure the alias appears alongside jit, JITKernel,
compile, par_compile so imports continue to work while emitting the deprecation
warning.

In @tilelang/jit/__init__.py:
- Around line 318-325: The jit decorator claims to accept PrimFunc but always
funnels to prim_func(func, eager_jit=True) which expects a Python callable;
update the wrapper for jit to detect PrimFunc inputs (e.g., isinstance(func,
PrimFunc) or an equivalent check) and raise a clear, targeted error telling
users to use tilelang.jit.compile(...) (or the appropriate compile path) instead
of passing a PrimFunc, and apply the same check in the other jit entry path
referenced around the 467-529 region so PrimFunc is rejected consistently rather
than hitting inspect.signature in prim_func.
- Around line 318-325: The eager-path bug comes from JITFunc._parse_phase1_key
popping tensor kwargs out of the dict passed from JITImpl.__call__, which
mutates kwargs before compile/execute; fix by making defensive shallow copies in
JITImpl.__call__: create kw_for_parse = dict(kwargs) and pass that to
self.func.parse_args(...), then create kw_for_compile = dict(kwargs) (or reuse
the original kwargs untouched) when calling self.compile(*args,
**kw_for_compile) / self.execute(...). Also ensure eager arg ordering is stable
by building the eager positional list from args plus a deterministic walk of
parameter names (e.g., for name in self.func.param_names_in_order: if name in
kw_for_compile: append(kw_for_compile[name])) so keyword tensors are preserved
and ordered consistently.
- Around line 318-325: initialize_jit_mode currently unconditionally calls
self.func.set_mode(...) which will fail when self.func is a PrimFunc; guard the
call by checking the callable's capability (e.g., if hasattr(self.func,
"set_mode") or isinstance(self.func, the expected wrapper) ) and only call
set_mode when present, otherwise skip setting mode on PrimFunc (or call the
appropriate PrimFunc API if one exists). Also ensure the rest of
initialize_jit_mode behavior (mode inference via _infer_jit_mode, out_idx check,
and returning self.mode) remains unchanged.

In @tilelang/language/__init__.py:
- Around line 14-16: The wildcard re-export from .eager (from .eager import *)
is triggering the wrong noqa and risking API drift; either change the noqa
comment to suppress the wildcard-import error (use # noqa: F403 or # noqa:
F401,F403) or replace the wildcard with explicit names and a curated __all__
list (import the public symbols from .eager and .tir.ir and set __all__ =
['SymbolA', 'SymbolB', 'Layout', 'Fragment', ...]) so the public surface is
pinned and Ruff no longer flags unused/wildcard re-exports.

In @tilelang/language/eager/builder.py:
- Around line 731-737: The guard that raises JITNoBuilderError when
Builder.current() is None is correct, but the error text is misleading; update
the raised message (the string passed to JITNoBuilderError in the block that
checks Builder.current()) to mention the supported builder contexts (for
example: "T.macro can only be used inside a Builder context (e.g. @tilelang.jit
or @T.prim_func)") so users in a prim_func context aren’t misled.
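The deprecated-alias option suggested for tilelang/__init__.py above could look roughly like the sketch below. The `jit` function here is a hypothetical stand-in that only records the mode; whether the real `tilelang.jit` accepts a `mode` keyword in this form is an assumption taken from the comment, not verified against the code.

```python
import warnings


def jit(func=None, *, mode="auto", **options):
    """Hypothetical stand-in for tilelang.jit; it only records the mode."""

    def decorate(f):
        f.jit_mode = mode
        return f

    # Support both @jit and @jit(...) usage.
    return decorate(func) if func is not None else decorate


def lazy_jit(*args, **kwargs):
    # Keep old imports working while steering users to the unified decorator.
    warnings.warn(
        "lazy_jit is deprecated; use tilelang.jit(mode='lazy')",
        DeprecationWarning,
        stacklevel=2,
    )
    return jit(*args, mode="lazy", **kwargs)
```

With `stacklevel=2` the warning points at the caller's decorator line rather than at the alias itself.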
🧹 Nitpick comments (5)
tilelang/language/proxy.py (1)

278-283: Runtime guard enforces Builder context requirement.

The guard ensures make_tensor is only called within a JIT or prim_func context by checking for an active Builder. The lazy import of Builder avoids circular dependencies.

Note: The static analysis tool flags TRY003 (long message outside exception class). While the descriptive error message is helpful, consider defining it within the JITNoBuilderError class if this pattern becomes widespread.

tilelang/language/kernel.py (1)

283-291: Runtime guard enforces Builder context requirement.

The guard ensures T.Kernel() is only called within a JIT or prim_func context. The implementation mirrors the pattern in proxy.py (lines 278-283) with a lazy Builder import to avoid circular dependencies.

Note: The static analysis tool flags TRY003 (long message outside exception class). Consider consolidating error messages into the exception class definition if this pattern recurs frequently.
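One way to consolidate the messages, sketched under assumptions: the class below is a stand-in for `JITNoBuilderError`, its `RuntimeError` base and `api_name` parameter are hypothetical, and the message wording follows the suggestion made elsewhere in this review.

```python
class JITNoBuilderError(RuntimeError):
    """Stand-in exception that owns its message text."""

    def __init__(self, api_name):
        self.api_name = api_name
        # Building the message here keeps raise sites short (addresses TRY003).
        super().__init__(
            f"{api_name} can only be used inside a Builder context "
            "(e.g. @tilelang.jit or @T.prim_func)"
        )
```

A raise site then shrinks to `raise JITNoBuilderError("T.Kernel()")`, and every guard reports the same wording.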

tilelang/language/eager/__init__.py (1)

1-2: Public re-export module: define __all__ (or fix noqa) to stabilize the API surface

Since this module is now the public “surface” re-exported by tilelang.language, consider adding __all__ to prevent accidental symbol leaks as eager evolves. At minimum, align the noqa with what’s actually triggered (F403 for wildcard), since # noqa: F401 is currently flagged as unused.

testing/python/language/test_tilelang_language_frontend_v2.py (1)

42-57: test_expr: direct import of _all_dtypes is brittle

Importing a private underscore symbol (_all_dtypes) from tilelang.language.eager.dtypes couples the test to internal layout. Prefer a public export (e.g., tilelang.dtypes or a documented list) if one exists.

tilelang/language/eager/builder.py (1)

1109-1149: Typing: func: Callable = None should be Callable | None (or Optional[Callable])

Ruff’s RUF013 note is valid: the annotation currently implies func is always callable even though None is a supported value.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d33550b and 36f37f4.

📒 Files selected for processing (15)
  • testing/python/language/test_tilelang_language_frontend_v2.py
  • tilelang/__init__.py
  • tilelang/jit/__init__.py
  • tilelang/jit/adapter/tvm_ffi.py
  • tilelang/language/__init__.py
  • tilelang/language/allocate.py
  • tilelang/language/eager/__init__.py
  • tilelang/language/eager/ast.py
  • tilelang/language/eager/builder.py
  • tilelang/language/eager/dtypes.py
  • tilelang/language/eager/utils.py
  • tilelang/language/kernel.py
  • tilelang/language/loop.py
  • tilelang/language/print_op.py
  • tilelang/language/proxy.py
✅ Files skipped from review due to trivial changes (1)
  • tilelang/language/loop.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2026-01-06T05:20:45.325Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1606
File: testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py:30-30
Timestamp: 2026-01-06T05:20:45.325Z
Learning: In `testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py`, the test validates that the `hoist_broadcast_values` transformation pass correctly identifies and hoists broadcast operations by checking for patterns in the generated kernel source code. The specific literal values used (e.g., 430) are not important for the test's purpose, as it does not validate numerical precision or actual stored tensor values.

Applied to files:

  • testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/language/test_tilelang_language_frontend_v2.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/language/test_tilelang_language_frontend_v2.py
🧬 Code graph analysis (7)
tilelang/jit/adapter/tvm_ffi.py (1)
tilelang/language/eager/dtypes.py (1)
  • dtype (14-15)
tilelang/language/eager/builder.py (1)
tilelang/jit/exceptions.py (2)
  • JITNoBuilderError (4-13)
  • EagerJITBuildError (16-24)
testing/python/language/test_tilelang_language_frontend_v2.py (6)
tilelang/language/tir/op.py (1)
  • all (1913-1930)
tilelang/jit/__init__.py (3)
  • jit (450-450)
  • jit (454-464)
  • jit (467-531)
tilelang/language/allocate.py (3)
  • alloc_var (86-86)
  • alloc_var (90-90)
  • alloc_var (93-147)
tilelang/language/eager/dtypes.py (3)
  • int32 (245-245)
  • dtype (14-15)
  • float32 (300-300)
tilelang/language/kernel.py (2)
  • Kernel (229-312)
  • threads (215-219)
tilelang/language/proxy.py (1)
  • Tensor (233-233)
tilelang/language/print_op.py (1)
tilelang/language/eager/builder.py (1)
  • Builder (167-695)
tilelang/language/allocate.py (2)
tilelang/language/eager/dtypes.py (1)
  • dtype (14-15)
tilelang/language/eager/builder.py (1)
  • OutTensor (114-120)
tilelang/language/eager/__init__.py (3)
tilelang/jit/adapter/tvm_ffi.py (1)
  • prim_func (319-321)
tilelang/language/eager/builder.py (8)
  • prim_func (189-198)
  • prim_func (1109-1150)
  • macro (201-228)
  • macro (746-782)
  • PrimFunc (704-713)
  • JITFunc (953-1084)
  • Ref (124-135)
  • const (851-879)
tilelang/language/proxy.py (2)
  • Ref (245-245)
  • Ref (253-253)
tilelang/language/proxy.py (2)
tilelang/jit/exceptions.py (1)
  • JITNoBuilderError (4-13)
tilelang/language/eager/builder.py (2)
  • Builder (167-695)
  • current (184-186)
🪛 Ruff (0.14.10)
tilelang/language/kernel.py

290-290: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/jit/__init__.py

323-323: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/__init__.py

145-145: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


163-163: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/eager/builder.py

196-196: Avoid specifying long messages outside the exception class

(TRY003)


733-733: Avoid specifying long messages outside the exception class

(TRY003)


870-870: Avoid specifying long messages outside the exception class

(TRY003)


1017-1017: Consider moving this statement to an else block

(TRY300)


1043-1043: Avoid specifying long messages outside the exception class

(TRY003)


1084-1084: Avoid specifying long messages outside the exception class

(TRY003)


1109-1109: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/language/__init__.py

14-14: from .eager import * used; unable to detect undefined names

(F403)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/eager/__init__.py

1-1: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


2-2: from .dtypes import * used; unable to detect undefined names

(F403)

tilelang/language/proxy.py

282-282: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (11)
tilelang/jit/adapter/tvm_ffi.py (1)

22-22: LGTM: Import migration to eager path.

The import path update from v2.dtypes to eager.dtypes aligns with the broader migration to the eager backend. The functional behavior remains unchanged.

tilelang/language/proxy.py (1)

13-13: LGTM: Import for new runtime guard.

The JITNoBuilderError import supports the new Builder context enforcement added below.

tilelang/language/kernel.py (1)

10-10: LGTM: Import for new runtime guard.

The JITNoBuilderError import supports the Builder context enforcement added in the Kernel function.

tilelang/language/print_op.py (1)

6-6: LGTM: Import migration to eager path.

The Builder import path update from v2.builder to eager.builder aligns with the broader migration to the eager backend. The functional behavior at line 128 (Builder.current().get_fileline_stack) remains unchanged.

tilelang/language/allocate.py (1)

31-33: LGTM: Import migrations to eager path.

All three import path updates (dtypes, tl_dtype, OutTensor) consistently migrate from the v2 to eager backend. The functional behavior throughout the file (lines 144, 159, 232, 249, 253, 257, 262, 267, 272, 274, 276) remains unchanged.

testing/python/language/test_tilelang_language_frontend_v2.py (4)

7-9: Boolean expression test: good shift to TIR ops, but keep operator choice consistent

Switching cond() to Or(Not(tir_all(...)), ...) is the right direction for TIR-safe boolean composition.

Also applies to: 451-455


295-321: swap_var / swap_idx: eager-style annotation placement looks correct

Annotating A: T.Tensor[...] inside the jitted function body is consistent with the eager builder pattern used elsewhere in this PR.


323-340: while_loop: return + .item() usage looks correct

This test is aligned with the eager execution behavior (jitted function returns the output tensor directly).


457-477: probe(A, tmp: bool): constexpr branching looks intentional

Using a Python bool to control compile-time branching is consistent with “constexpr if” behavior; call sites (probe(A, True/False)) validate both branches.

tilelang/language/eager/builder.py (2)

189-199: Builder context cleanup in prim_func: good hardening

The try/finally restoring thread_local_storage.builder and the “all outputs returned” validation make eager/lazy behavior much less footgun-prone.


851-880: T.const() eager-only guard: good, but keep exception type stable

Enforcing builder is None or not builder.eager_jit → JITNoBuilderError matches the mode-inference strategy. This is a good contract for the “call once outside builder to detect eager” trick.

@LeiWang1999 LeiWang1999 requested a review from SiriusNEO January 9, 2026 07:08
…ove compatibility with eager execution. Refactor related function signatures and logic, ensuring seamless integration with the latest compilation features. This change aims to streamline JIT operations and enhance overall code clarity.
Collaborator

@kurisu6912 kurisu6912 left a comment

LGTM! Maybe we can rename examples/lazy_jit to examples/eager_jit (XD: lazy is eager, non-lazy is lazy)

Collaborator

@SiriusNEO SiriusNEO left a comment

LGTM, just two small comments

@SiriusNEO SiriusNEO self-requested a review January 9, 2026 07:40
@LeiWang1999 LeiWang1999 merged commit 6e43953 into tile-ai:main Jan 9, 2026
6 checks passed
@Da1sypetals

This feature seems to be broken. In lazy mode, some syntax errors are falsely reported as ValueError: out_idx is only supported in lazy mode. In eager mode, use T.empty() to declare output tensors instead.

@senlyu163
Contributor

This feature seems to be broken. In lazy mode, some syntax errors are falsely reported as ValueError: out_idx is only supported in lazy mode. In eager mode, use T.empty() to declare output tensors instead.

I ran into the same problem.

@SiriusNEO
Collaborator

@Da1sypetals @senlyu163 Hi! The message suggests that you are using out_idx in eager mode (i.e., no PrimFunc return). Could you provide a script that triggers the false report? The current frontend works on my side.

@senlyu163
Contributor

@SiriusNEO thx, it works fine. I made a mistake.

@Da1sypetals

@Da1sypetals @senlyu163 Hi! The message suggests that you are using out_idx in eager mode (i.e., no PrimFunc return). Could you provide a script that triggers the false report? The current frontend works on my side.

If there is a tilelang error inside @T.prim_func, this error is wrongly raised. A simple example is creating a buffer without a dtype.

from icecream import ic
import torch
import einops as ein

import tilelang
import tilelang.language as T


@tilelang.jit(
    out_idx=[0],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def example():
    @T.prim_func
    def main(x: T.Tensor([8, 8], dtype=T.float32)):
        with T.Kernel(1, threads=128) as i_seq:
            R = T.alloc_shared([8, 8])  # dtype omitted on purpose to trigger the error

    return main


kernel = example()
kernel()

@LeiWang1999
Member Author

@Da1sypetals it works for me on the upstream commit.

R = T.alloc_shared([8, 8])
    ^^^^^^^^^^^^^^^^^^^^^^
TypeError: alloc_shared() missing 1 required positional argument: 'dtype'

@Da1sypetals

@Da1sypetals it works for me on the upstream commit.

R = T.alloc_shared([8, 8])
    ^^^^^^^^^^^^^^^^^^^^^^
TypeError: alloc_shared() missing 1 required positional argument: 'dtype'

Thanks! I'll test it when the next release comes out.


5 participants