
[Lint] Phaseout Yapf format and embrace ruff format #1417

Merged
LeiWang1999 merged 3 commits into tile-ai:main from LeiWang1999:ruff_1212 on Dec 12, 2025

Conversation

LeiWang1999 (Member) commented on Dec 12, 2025

This pull request primarily makes code style and formatting improvements across several benchmarking scripts, as well as updates to the pre-commit configuration to streamline linting tools. The changes focus on improving code readability and consistency, and on consolidating formatting tools in the development workflow.

Pre-commit configuration updates:

  • Switched from using both yapf and ruff for code formatting to using only ruff (ruff-format), and removed the yapf hooks from .pre-commit-config.yaml for a simpler and unified formatting approach.

Code formatting and consistency improvements:

  • Reformatted tensor and function arguments to be more concise and consistent in the following files:

    • benchmark/blocksparse_attention/benchmark_library_dense_fmha.py [1] [2]
    • benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py [1] [2] [3] [4] [5] [6]
    • benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py [1] [2]
    • benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py [1] [2] [3] [4] [5] [6] [7]
    • benchmark/mamba2/benchmark_mamba_chunk_scan.py [1] [2] [3] [4]
  • Standardized string literals to use double quotes and improved the formatting of einsum and tensor operations for better clarity and alignment with project style guides. [1] [2] [3] [4]

These changes do not affect the logic or functionality of the code, but they enhance maintainability and ensure a consistent developer experience.
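For illustration, a minimal hypothetical sketch (not code taken from this repository) of the style ruff-format settles on: double-quoted string literals and short call sites consolidated onto one line.

import torch

# Hypothetical example only: after ruff-format, string literals use double quotes
# and a short einsum call fits on a single consolidated line.
def attention_scores(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    return torch.einsum("bhsd,bhkd->bhsk", query, key)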

Summary by CodeRabbit

Release Notes

  • Chores

    • Updated code formatting toolchain: replaced Yapf with ruff-format in pre-commit configuration for consistent Python code style.
    • Applied comprehensive code style standardization across the codebase, including quote normalization and spacing consistency.
  • Documentation

    • Added configuration for Sphinx autodoc documentation generation with improved type hints display.
  • Style

    • Normalized Python code formatting across all benchmark and example files for improved readability and consistency.


github-actions bot commented

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Dec 12, 2025

Important

Review skipped

Review was skipped as selected files did not have any reviewable changes.

💤 Files selected but had no reviewable changes (2)
  • testing/python/language/test_tilelang_language_intrinsics_codegen.py
  • testing/python/language/test_tilelang_language_vectorize.py

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR applies widespread code formatting and linting updates across the repository. It replaces Google Yapf with ruff-format in the pre-commit configuration and applies consistent formatting (quotes, spacing, line consolidation) throughout benchmark and example files. Minor functional changes include forcing roller-based search in the matmul intrinsic benchmark and reorganizing in-kernel data access patterns in a deepseek kernel.

Changes

  • Pre-commit and build configuration (.pre-commit-config.yaml): Removed Google Yapf hook; added ruff-format hook alongside existing ruff-check.
  • Documentation configuration (docs/conf.py): Updated string literals from single to double quotes; added variables (language, exclude_patterns, pygments_style, todo_include_todos); reformatted multi-line dicts/lists to inline style.
  • Benchmark: Blocksparse Attention (benchmark/blocksparse_attention/*): Cosmetic formatting: consolidated dense_mask construction, normalized quotation marks, adjusted spacing in slice expressions and function arguments; function signature compacted in the triton benchmark.
  • Benchmark: Mamba2 (benchmark/mamba2/benchmark_mamba_chunk_scan.py): Mostly formatting and whitespace adjustments; preserved core computations; minor dtype assertion update.
  • Benchmark: Matmul (benchmark/matmul/benchmark_matmul.py, benchmark_matmul_intrinsic.py, benchmark_matmul_sp.py): Formatting adjustments and blank-line management; matmul_intrinsic now forcefully enables roller-based search via with_roller = True.
  • Benchmark: Matmul FP8 (benchmark/matmul_fp8/benchmark_matmul.py): Minor formatting and decorator adjustment; one-liner return simplified.
  • Examples: AMD Flash Attention (examples/amd/example_amd_flash_attn_*.py): String literal normalization; spacing in slice expressions; parameter formatting; main function signatures adjusted.
  • Examples: Analysis (examples/analyze/example_conv_analyze.py, example_gemm_analyze.py): Minor formatting and indentation adjustments; conv kernel adds internal dtype/accum_dtype overrides and is_hopper flag.
  • Examples: Attention Sink (examples/attention_sink/*): Formatting and style adjustments: multi-line to single-line argument calls, quote normalization, spacing updates; no functional changes.
  • Examples: BitNet 1.58b (examples/bitnet-1.58b/*): Widespread formatting updates: string literal quotes, blank-line adjustments, multi-line constructs reflowed; configuration_bitnet.py simplifies error messages into single-line f-strings.
  • Examples: Blocksparse Attention (examples/blocksparse_attention/*): Formatting and cosmetic refactors: dense_mask construction single-lined, quote normalization, heuristic.py function signature collapsed.
  • Examples: Blocksparse GEMM (examples/blocksparse_gemm/example_blocksparse_gemm.py): Cosmetic refactors: multi-line parser arguments consolidated, config construction reflowed, decorator formatting adjusted.
  • Examples: Cast (examples/cast/example_*.py, test_example_cast.py): Formatting adjustments to function signatures, layout annotations, and slice expressions; test call reflowed to single-line.
  • Examples: Compile Flags (examples/compile_flags/usecase.py): Parameter indentation adjusted; multi-line compile call collapsed to single-line.
  • Examples: conftest (examples/conftest.py): Multi-line conditional sum check collapsed to single-line; error message f-string consolidated.
  • Examples: Convolution (examples/convolution/example_convolution*.py): Function signatures reflowed; internal dtype/accum_dtype overrides added in convolution; layout annotations and in-bound checks reformatted.
  • Examples: Deepseek DeepGEMM (examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py): Whitespace/formatting adjustments and minor slice formatting; no logic changes.
  • Examples: Deepseek MLA (examples/deepseek_mla/amd/benchmark_mla_decode_amd_*.py, benchmark_mla_decode_amd_triton.py): Formatting and call-site adjustments; torch and triton benchmarks add an extra h_q parameter to function signatures, affecting arity.
  • Examples: Deepseek MLA (main) (examples/deepseek_mla/benchmark_mla.py, example_mla_decode*.py, torch_refs.py): Formatting and style adjustments: decorator reflow, spacing normalization, quote consistency; no semantic changes except minor double-quote updates in torch_refs.
  • Examples: Deepseek NSA (examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py, example_*.py, reference.py): Decorator reformatting, function signature multi-line expansions with explicit type hints, string literal normalization; token_indices parameter additions in some forward paths.
  • Examples: Deepseek V32 (examples/deepseek_v32/fp8_lighting_indexer.py, sparse_mla_*.py, topk_selector.py, utils.py, test_tilelang_example_deepseek_v32.py): Formatting adjustments across kernel definitions and test scaffolding; sparse_mla_fwd_pipelined.py reorganizes in-kernel data access patterns and shared-memory KV indexing; minor numeric literal changes (1.0 vs 1.).
  • Examples: DSA Sparse Finetune (examples/dsa_sparse_finetune/dsa.py, index.py, indexer_*.py, sparse_mla_*.py, utils.py): Formatting and minor functional tweaks: multi-line signatures reflowed, string quoting normalized, minor slice spacing adjusted; dsa.py adds explicit dim_v parameter in backward calls.
  • Examples: Dynamic Shape (examples/dynamic_shape/example_dynamic.py): Function signature and call site formatting adjusted from multi-line to single-line.
  • Examples: Elementwise (examples/elementwise/example_elementwise_add.py): Removed @T.prim_func decorator; adjusted loop syntax; reformatted autotuner construction to multi-line parenthesized style.
  • Examples: Flash Attention (examples/flash_attention/bert_padding.py): Cosmetic refactoring: multi-line statements reflowed to single-line; no semantic or control-flow changes.
  • Examples: Dequantize GEMM (examples/dequantize_gemm/dequantize_utils.py, example_dequant_*.py): Formatting and style updates: quote normalization, line-wrapping adjustments, docstring additions; example_dequant_gemm_fine_grained.py expands B_shape to 4D with updated indexing paths.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

Areas requiring extra attention:

  • examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py — The added h_q parameter changes function arity across multiple signatures; verify parameter threading is correct.
  • examples/deepseek_v32/sparse_mla_fwd_pipelined.py — Reorganized in-kernel shared-memory KV data access patterns; verify memory layout and synchronization semantics are preserved.
  • examples/dequantize_gemm/example_dequant_gemm_fine_grained.py — Expanded B_shape from 3D to 4D with cascading indexing changes; verify all access paths correctly handle new dimension.
  • benchmark/matmul/benchmark_matmul_intrinsic.py — Forced with_roller = True; confirm this override is intentional and desired.
  • examples/elementwise/example_elementwise_add.py — Removed @T.prim_func decorator; verify impact on kernel registration and compilation.

Possibly related PRs

Suggested reviewers

  • chengyupku

🐰 Hops through the code with glee,
Ruff formats it perfectly—no Yapf we see!
Quotes now consistent, spacing so clean,
The prettiest codebase you've ever seen!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 15.14%, which is insufficient; the required threshold is 80.00%. Resolution: you can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title check (✅ Passed): The title clearly and specifically describes the main change: removing Yapf formatting and adopting ruff-format, which aligns with the extensive code reformatting across the entire PR.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 10

Note

Due to the large number of review comments, comments of Critical and Major severity were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (22)
examples/bitnet-1.58b/vllm_workspace/utils.py (1)

6-19: Minor: output_ids mismatch assert message prints strings, not IDs (misleading).
On Line 18, when output_ids_0 != output_ids_1, the f-string prints output_str_* rather than the ID lists, which makes debugging tokenization/logprob issues harder.

-        assert output_ids_0 == output_ids_1, f"Test{prompt_idx}:\n{name_0}:\t{output_str_0!r}\n{name_1}:\t{output_str_1!r}"
+        assert output_ids_0 == output_ids_1, (
+            f"Test{prompt_idx}:\n"
+            f"{name_0}:\tids={output_ids_0}\tstr={output_str_0!r}\n"
+            f"{name_1}:\tids={output_ids_1}\tstr={output_str_1!r}"
+        )
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)

315-324: Fix Ruff ARG001: unused args in ref_program_fa (or keep but underscore).
If this signature is used by a shared harness, keep it but underscore unused params to satisfy lint.

-def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
+def ref_program_fa(query, key, value, _block_indices, cache_seqlens, _max_cache_seqlen, _num_blocks, _block_size):

35-106: Use cur_kv_head instead of hid for block_mask indexing on line 76.

The block_mask tensor has shape [batch, heads_kv, num_blocks], but line 76 indexes it with hid, which ranges from 0 to (heads // valid_block_H) - 1. When valid_block_H = min(block_H, kv_group_num) and block_H < kv_group_num, the kernel's y-dimension exceeds heads_kv, causing out-of-bounds access. Use cur_kv_head (defined on line 64) instead, which correctly maps to the [0, heads_kv) range:

if block_mask[bid, cur_kv_head, start + k]:
examples/amd/example_amd_flash_attn_bwd.py (2)

237-260: Ruff E741: rename ambiguous parameter O in preprocess prim_func.
O is easily confused with 0 and is flagged by ruff; renaming improves readability (esp. in a math-heavy kernel).

@@
-    def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
+    def flash_bwd_prep(out: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
@@
-                T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
+                T.copy(out[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)

448-578: Ruff RUF059: don’t bind unused *_mean_diff variables.
These are unpacked but never used; prefix with _ (or drop from return) to keep CI clean.

@@
-    dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
+    dq_close, dq_max_diff, _dq_mean_diff = debug_tensor_comparison(dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
@@
-    dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
+    dk_close, dk_max_diff, _dk_mean_diff = debug_tensor_comparison(dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
@@
-    dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
+    dv_close, dv_max_diff, _dv_mean_diff = debug_tensor_comparison(dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
examples/bitnet-1.58b/benchmark_generate.py (1)

33-44: Benchmark timing: add CUDA synchronize + guard division by zero.
Without a sync, generation_time can under-report GPU work (async launches), skewing tokens/sec.

     start_time = time.time()
     output_ids = model.generate(input_ids, generation_config=generation_config)
-    end_time = time.time()
+    if input_ids.is_cuda:
+        torch.cuda.synchronize()
+    end_time = time.time()
...
-    tokens_per_second = num_tokens / generation_time
+    tokens_per_second = num_tokens / max(generation_time, 1e-9)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)

311-335: Likely precision bug: dsinks should be accum_dtype, not dtype (fp16/bf16).

Right now dsink_fragment[i] = ... * delta_fragment[i] (accumulation-type math) is stored into dtype, then later reduced; this is prone to large numeric error.

 def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"):
     accum_dtype = "float"
@@
     def flash_bwd_dsink(
         Sinks: T.Tensor([heads], dtype),  # type: ignore
         Delta: T.Tensor(shape, accum_dtype),  # type: ignore
         lse: T.Tensor(shape, accum_dtype),  # type: ignore
-        dsinks: T.Tensor(shape, dtype),  # type: ignore
+        dsinks: T.Tensor(shape, accum_dtype),  # type: ignore
     ):
@@
-            dsink_fragment = T.alloc_fragment([block], dtype)
+            dsink_fragment = T.alloc_fragment([block], accum_dtype)

148-166: Ruff E741: rename ambiguous O / l identifiers (if ruff is enforced on examples).

-    def flash_bwd_prep(
-        O: T.Tensor(shape, dtype),  # type: ignore
+    def flash_bwd_prep(
+        out: T.Tensor(shape, dtype),  # type: ignore
         dO: T.Tensor(shape, dtype),  # type: ignore
         Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
     ):
@@
-                T.copy(O[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
+                T.copy(out[bz, bx, by * blk : (by + 1) * blk, k * blk : (k + 1) * blk], o)
 def make_dq_layout(dQ):
     # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
-    return T.Layout(dQ.shape, lambda b, h, l, d: [b, h, l // 8, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
+    return T.Layout(dQ.shape, lambda b, h, seq, d: [b, h, seq // 8, d // 8, (d % 2), 4 * (seq % 8) + (d % 8) // 2])

Confirm whether E741 is expected to pass on examples/** in CI/pre-commit.

Also applies to: 171-173

examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (1)

1-16: Remove duplicate import torch (and consider narrowing the file-wide # ruff: noqa).
Line 6 re-imports torch (already imported at Line 2). Also, # ruff: noqa disables all ruff rules for the file, which somewhat cuts against the “ruff as single source of truth” goal—if you only need to suppress specific rules, prefer a targeted noqa list.

Suggested diff:

 # ruff: noqa
 import torch
 from typing import Optional, Union
 from packaging.version import parse

-import torch
 import triton
 import triton.language as tl

 import fla
examples/cast/example_group_per_split_token_cast_to_fp8.py (1)

19-61: Enforce batch_sizes ordering constraint to prevent out-of-bounds reads.

The kernel iterates T.ceildiv(M_max, blk_m) row tiles for every batch group and unconditionally reads a full M_max-height window from the flat input X (lines 46–49). For any batch group bg, the read range is [row_offset[0] + row·blk_m, row_offset[0] + (row+1)·blk_m], where row_offset[0] = sum(batch_sizes[0:bg]). When batch_sizes[bg] < M_max, the final tiles read past the group's actual data. If the largest batch is not last, reads exceed X.shape[0]—e.g., with batch_sizes=[6144, 2048], M_max=6144, and M=8192, the second group reads indices 6144 to 12288, overrunning the input by 4096 elements.

Add the precondition to main():

 def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
     if batch_sizes is None:
         batch_sizes = [2048, 6144]
+    assert batch_sizes[-1] == max(batch_sizes), "batch_sizes must be sorted with largest group last"

Separately, line 53 sets y_s_local[i]=0 for out-of-bounds rows, and line 55 divides by it (y_local[i, j] / y_s_local[i]), producing NaN/Inf before masking the result at line 58. Consider a non-zero default scale to avoid unnecessary invalid arithmetic.

benchmark/matmul/benchmark_matmul_intrinsic.py (1)

301-302: Remove unused CLI argument --with_roller or explain why roller is forced.

The --with_roller argument is parsed at line 301 but immediately overwritten at line 302. Users cannot disable roller-based search via CLI. Either remove the unused argument entirely, or add a comment explaining why roller is forced for this intrinsic benchmark:

-    parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")

Or:

     with_roller = args.with_roller
-    with_roller = True
+    with_roller = True  # Force roller-based search for intrinsic benchmark
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (4)

158-209: Blocker: H is undefined in tilelang_kernel_bwd_dkv (kernel launch + indexing).
with T.Kernel(..., B * H, ...) and i_bh // H will raise at build-time/runtime unless H is defined. It looks like H should be KV heads (heads // groups).

 def tilelang_kernel_bwd_dkv(
@@
 ):
@@
-    heads_kv = heads // groups
+    heads_kv = heads // groups
+    H = heads_kv
@@
-    print("NV", NV, "NS", NS, "B", B, "H", H)
+    # print("NV", NV, "NS", NS, "B", B, "H", H)
@@
-        with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
+        with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
             ...
-            i_b, i_h = i_bh // H, i_bh % H
+            i_b, i_h = i_bh // H, i_bh % H

(If you actually meant H == heads here, then adjust the Q/K/V shapes + indexing accordingly; current i_h * G : (i_h + 1) * G strongly suggests KV-head indexing.)

Also applies to: 210-315


328-414: Blocker: H is also undefined in tilelang_kernel_bwd_dqkv.
Same issue pattern as bwd_dkv (B * H, i_bh // H), and same likely fix (H = heads // groups).

 def tilelang_kernel_bwd_dqkv(
@@
 ):
@@
-    heads_kv = heads // groups
+    heads_kv = heads // groups
+    H = heads_kv
@@
-        with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
+        with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
             ...
-            i_b, i_h = i_bh // H, i_bh % H
+            i_b, i_h = i_bh // H, i_bh % H

Also applies to: 415-495


618-632: Fix block_counts handling before calling .to(...) (crashes for None/int).
parallel_nsa_bwd accepts block_counts: Union[torch.LongTensor, int], and forward passes ctx.block_counts through—so this can be None, int, or a tensor depending on call-site. Normalize it before passing into the kernel:

-    block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool)
+    if isinstance(block_counts, int):
+        block_counts_i32 = torch.full((B, T, H), block_counts, device=block_indices.device, dtype=torch.int32)
+    elif block_counts is None:
+        # default: all blocks available (or use S)
+        block_counts_i32 = torch.full((B, T, H), S, device=block_indices.device, dtype=torch.int32)
+    else:
+        block_counts_i32 = block_counts.to(torch.int32)
+
+    block_mask = (
+        tilelang_kernel_block_mask(B, H, T, S, BS)(
+            block_indices.to(torch.int32),
+            block_counts_i32,
+        )
+        .to(torch.bool)
+    )

(If tilelang_kernel_block_mask is intended to support “no BlockCounts”, consider changing its signature to omit BlockCounts entirely when unused.)


9-15: Guard the fla import or fail with a clearer message.

The import fla at line 9 is unconditional; if fla is not installed, the module will fail to import. While fla is listed in examples/deepseek_nsa/requirements.txt, users who run this file without installing from requirements will encounter an unclear ModuleNotFoundError. Consider either moving the import inside if __name__ == "__main__": or wrapping it in a try-except that provides an actionable error message pointing to the requirements file.

The version split (< 0.2.1) appears intentional and consistent across other deepseek_nsa examples; no change needed there.

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)

220-235: Silence ruff ARG001 for unused args in main_no_split without changing the call interface.
Keeping params for signature compatibility is reasonable; just mark them unused.

 def main_no_split(
         Q: T.Tensor([batch, heads, dim], dtype),
         Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
         KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
         K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
-        glse: T.Tensor([batch, heads, num_split], dtype),
-        Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
+        _glse: T.Tensor([batch, heads, num_split], dtype),
+        _output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
         Output: T.Tensor([batch, heads, dim], dtype),
 ):
     flash_attn(Q, Q_pe, KV, K_pe, Output)
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1)

10-49: scale argument is unused but docstring claims it’s applied—make this consistent.
Either apply scale in the helper (and adjust the downstream multiply), or mark it intentionally unused + fix the docstring.

Minimal lint/doc fix (no behavior change):

-def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
+def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, _scale: tir.PrimExpr, dtype: str):
     """
-    Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
+    Convert a 4-bit field packed in a uint8 into a bfloat16 value.
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)

194-208: Tensor layout change detected - beyond formatting scope.

The B_shape has been restructured from a 2D layout to a 4D micro-tiled layout:

# New 4D layout
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte)

This is a functional change that affects:

  1. How B tensors are indexed (lines 277-281)
  2. Shared memory access patterns
  3. Output write patterns (lines 340-342)

While this may be intentional refactoring, it exceeds the "formatting-only" scope described in the PR objectives. Please ensure:

  1. This change is intentional and tested
  2. The PR description is updated to note functional changes
  3. Related documentation reflects the new layout
examples/analyze/example_conv_analyze.py (1)

28-35: dtype / accum_dtype params are misleading because they’re overwritten unconditionally. Either honor the passed values or drop them from the signature to avoid confusion.

-def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
+def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads):
@@
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = "float16"
+    accum_dtype = "float"
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)

310-340: is_casual typo in public-ish interfaces: rename to is_causal (keep backward compat). Keyword callers will trip on this.

 def sparse_mla_fwd_interface(
-    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
+    q,
+    kv,
+    indices,
+    q_start_index_s,
+    kv_stride,
+    sm_scale=None,
+    is_causal=True,
+    return_kernel=False,
+    print_kernel=False,
 ):
@@
-    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)
+    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_causal, CP0)
@@
-def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True):
+def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_causal=True):

Also applies to: 343-380

examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (2)

90-98: Likely wrong BlockCounts indexing vs bos (varlen correctness).

BlockIndices and Q/K/V loads use bos + i_t, but NS = BlockCounts[i_t, i_h] ignores bos. That makes counts come from the wrong row for any segment with bos != 0.

-            NS = BlockCounts[i_t, i_h]
+            NS = BlockCounts[bos + i_t, i_h]
             T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared)

148-190: parallel_nsa_fwd is broken: undefined D, optional inputs used unconditionally, and V-dim ambiguity.

  • Line 181-183: D is undefined.
  • Line 162 and 187-189: offsets / token_indices can be None but you call len(offsets) and .to(...).
  • The TileLang kernel assumes V-dim == K-dim (single dim), but the wrapper currently treats V = v.shape[-1] as potentially different.

Concrete fix (minimal, keeps current kernel assumption V == K):

 def parallel_nsa_fwd(
@@
 ):
-    B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
+    B, C_SEQ_LEN, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
+    D = K
+    if V != D:
+        raise ValueError(f"TileLang NSA varlen kernel currently assumes V==K, got V={V}, K={D}")
+
+    # Allow non-varlen call sites by defaulting to a single segment.
+    if offsets is None:
+        offsets = torch.tensor([0, C_SEQ_LEN], device=q.device, dtype=torch.long)
+    if token_indices is None:
+        token_indices = prepare_token_indices(offsets)
 
-    batch = len(offsets) - 1
+    batch = len(offsets) - 1
@@
-    kernel(
-        q.view(C_SEQ_LEN, HQ, D),
-        k.view(C_SEQ_LEN, H, D),
-        v.view(C_SEQ_LEN, H, D),
+    kernel(
+        q.view(C_SEQ_LEN, HQ, D),
+        k.view(C_SEQ_LEN, H, D),
+        v.view(C_SEQ_LEN, H, D),
         o_slc.view(C_SEQ_LEN, HQ, V),
         block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
-        block_counts.to(torch.int32).view(C_SEQ_LEN, H),
-        offsets.to(torch.int32),
-        token_indices.to(torch.int32),
+        (torch.full((C_SEQ_LEN, H), S, device=q.device, dtype=torch.int32) if block_counts is None else block_counts.to(torch.int32).view(C_SEQ_LEN, H)),
+        offsets.to(torch.int32),
+        token_indices.to(torch.int32),
     )
     return o_slc
🟡 Minor comments (13)
examples/deepseek_mla/example_mla_decode.py-219-221 (1)

219-221: Remove or annotate unused function parameters in main_no_split.

The parameters glse and Output_partial (lines 219–220) are not used in the function body. While this may be intentional for API consistency with main_split (which uses both), the Ruff linter correctly flags them as unused.

To silence the warnings, prefix with underscore (_glse, _Output_partial) or add a # noqa: ARG001 comment to clarify the intentional design.

  def main_no_split(
      Q: T.Tensor([batch, heads, dim], dtype),
      Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
      KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
      K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
-     glse: T.Tensor([batch, heads, num_split], dtype),
-     Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
+     _glse: T.Tensor([batch, heads, num_split], dtype),
+     _Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
      Output: T.Tensor([batch, heads, dim], dtype),
  ):
      flash_attn(Q, Q_pe, KV, K_pe, Output)

Alternatively, add a comment above the function:

# noqa: ARG001
def main_no_split(...):
examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py-272-304 (1)

272-304: Minor: Unused reference implementation parameters.

The ref_program_torch function declares max_cache_seqlen (line 272) and unpacks heads (line 273) but neither is used in the function body. This is acceptable for a reference implementation, but consider prefixing unused variables with _ (e.g., _, heads, dim = query.shape) if they should not be referenced in future revisions.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py-280-312 (1)

280-312: Fix Ruff ARG001/RUF059: unused args/vars in ref_program_torch.
max_cache_seqlen is unused, and heads from batch, heads, dim = query.shape is unused. Either remove them (if not needed by callers) or prefix with _ to keep the signature stable for benchmarks.

-def ref_program_torch(query, key, value, block_mask, cache_seqlens, max_cache_seqlen, num_blocks, block_size):
-    batch, heads, dim = query.shape
+def ref_program_torch(query, key, value, block_mask, cache_seqlens, _max_cache_seqlen, num_blocks, block_size):
+    batch, _heads, dim = query.shape
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py-339-344 (1)

339-344: Fix Ruff RUF046: remove redundant int(...).
math.ceil(...) already returns an int.

-    max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
+    max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py-425-426 (1)

425-426: Fix ruff RUF003: replace fullwidth parentheses in comment.
Line 425 uses fullwidth parentheses in the trailing comment （padding_M,）, which ruff flags. Replace them with ASCII ( and ) to keep pre-commit green.

-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # (padding_M,)
examples/deepseek_mla/example_mla_decode_ws.py-519-529 (1)

519-529: Fix/silence Ruff ARG001 on intentionally-unused args in main_no_split.
If main_no_split must keep the same signature as main_split, explicitly mark the unused args to keep lint clean.

Proposed diff (pick one approach):

 def main_no_split(
     Q: T.Tensor([batch, heads, dim], dtype),
     Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
     KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
     K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
-    glse: T.Tensor([batch, heads, num_split], dtype),
-    Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
+    glse: T.Tensor([batch, heads, num_split], dtype),  # noqa: ARG001
+    Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),  # noqa: ARG001
     Output: T.Tensor([batch, heads, dim], dtype),
 ):
     flash_attn(Q, Q_pe, KV, K_pe, Output)

or:

-    glse: T.Tensor([batch, heads, num_split], dtype),
-    Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
+    _glse: T.Tensor([batch, heads, num_split], dtype),
+    _Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
examples/dynamic_shape/example_dynamic.py-56-61 (1)

56-61: Run ruff format on this file: lines 56, 58, 60, and 105 exceed the configured line-length of 140 characters and require reformatting.

The file currently fails ruff format --check. These long function signatures and calls should be wrapped across multiple lines to comply with ruff formatting standards.

examples/bitnet-1.58b/configuration_bitnet.py-182-183 (1)

182-183: Fix typo in error message + avoid long inline exception strings (TRY003).
Line 183 has a user-facing typo (“with with”) and triggers Ruff TRY003; moving the message to a constant also keeps the line readable.

+_ROPE_SCALING_DICT_ERROR = (
+    "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {rope_scaling}"
+)
 ...
         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
-            raise ValueError(f"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {self.rope_scaling}")
+            raise ValueError(_ROPE_SCALING_DICT_ERROR.format(rope_scaling=self.rope_scaling))
examples/blocksparse_attention/block_sparse_attn_triton.py-152-152 (1)

152-152: Address ruff ARG001 for unused ctx in _forward.
If ctx is intentionally unused, either rename to _ctx or explicitly delete it to satisfy linting.

-def _forward(ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
+def _forward(_ctx, q, k, v, block_sparse_mask, sm_scale, BLOCK_M=64, BLOCK_N=64, num_warps=None, num_stages=1, out=None):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1]

Also applies to: 190-190, 195-195

examples/deepseek_v32/sparse_mla_fwd.py-28-33 (1)

28-33: Fix misleading assert message for tail_dim.
Line 29’s f-string says dim={tail_dim} which is confusing when debugging padding issues.

-    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
+    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, tail_dim={tail_dim}"
examples/convolution/example_convolution_autotune.py-77-79 (1)

77-79: Unused parameter enable_rasteration detected.

The parameter enable_rasteration is declared but never used in the function body. Additionally, note the spelling: both the parameter name (line 78) and the config key (line 52) use enable_rasteration, which appears to be a legacy typo for "rasterization."

Consider either:

  1. Using the parameter in the kernel logic
  2. Removing it if it's no longer needed
  3. Prefixing with underscore (_enable_rasteration) to indicate it's intentionally unused

Based on static analysis hints.

examples/convolution/example_convolution_autotune.py-133-145 (1)

133-145: Unused parameter with_roller detected.

The with_roller parameter (line 144) is defined but never referenced in the function body. This suggests either:

  1. Incomplete implementation
  2. Dead parameter that should be removed

Consider removing the parameter or implementing the intended roller-based search space logic.

Based on static analysis hints.

examples/convolution/example_convolution.py-80-90 (1)

80-90: CLI flag change is user-facing (long-only). If you want “non-breaking”, keep short aliases too (-n + --n, etc.).

-    parser.add_argument("--n", type=int, default=128, help="n")
+    parser.add_argument("-n", "--n", type=int, default=128, help="n")
@@
-    parser.add_argument("--p", type=int, default=1, help="p")
+    parser.add_argument("-p", "--p", type=int, default=1, help="p")

Committable suggestion skipped: line range outside the PR's diff.

🧹 Nitpick comments (48)
examples/bitnet-1.58b/load_from_quantized.py (1)

52-58: Pure formatting change; semantics preserved.
The parenthesized multi-line chain keeps the same call order and arguments (from_quantized(saved_model_path) -> cuda() -> half()).

Optional: consider a device guard for CPU-only environments.
This script still hard-requires CUDA via .cuda(); if examples are expected to run without GPUs, switch to .to(device) and pick device = "cuda" if torch.cuda.is_available() else "cpu".

examples/bitnet-1.58b/vllm_workspace/utils.py (1)

6-12: Optional: add context to the length mismatch assertions.
The bare assert len(outputs_0_lst) == len(outputs_1_lst) gives little signal in failures; consider including name_* and both lengths.

-    assert len(outputs_0_lst) == len(outputs_1_lst)
+    assert len(outputs_0_lst) == len(outputs_1_lst), (
+        f"{name_0} has {len(outputs_0_lst)} outputs, {name_1} has {len(outputs_1_lst)} outputs"
+    )

Also applies to: 24-30

examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (2)

90-96: Line 95 return is very dense; consider a small readability split (no behavior change).

-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
+    x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n)
+    x_scale = (x_amax / 448.0).view(m, -1)
+    return x_fp8, x_scale

98-107: Padding/reshape formatting looks equivalent; consider torch.empty instead of zeros if you always overwrite full padded region.

Right now you do overwrite x_padded[:m, :n], but the rest of the padded buffer stays zero and affects x_amax/scaling (so this may be intentionally zero-padded). If padding must be zero, keep zeros; otherwise empty is cheaper.

examples/deepseek_nsa/reference.py (1)

126-132: Consider re-wrapping the gather(...) + mask expression for debuggability (optional). The one-liner is correct but hard to step through when shape issues arise.

examples/deepseek_v32/utils.py (1)

26-27: Avoid hard-to-read mega-isinstance lines (ruff may reflow, but this is borderline).
Line 26 is correct but pretty dense; consider extracting the “safe scalar types” tuple to a module-level constant for readability (and reuse in _is_equal).
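A minimal sketch of that extraction, using a hypothetical type tuple (the real set of types and _is_equal live in examples/deepseek_v32/utils.py):

# Hypothetical sketch; the actual type set belongs in examples/deepseek_v32/utils.py.
_SAFE_SCALAR_TYPES = (int, float, bool, str, bytes, type(None))


def _is_safe_scalar(value) -> bool:
    # One reusable check instead of a long inline isinstance(...) chain.
    return isinstance(value, _SAFE_SCALAR_TYPES)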

examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)

69-71: Consider factoring the sliding-window condition to keep the T.if_then_else readable (optional).

The single-line condition is correct but hard to scan. A tiny helper boolean (or split lines) would improve readability without changing codegen intent.

examples/deepseek_v32/topk_selector.py (2)

130-132: De-duplicate the repeated “byte extract” expression to avoid future drift
You compute the same ((convert_to_uint32(...) >> (24 - round * 8)) & 0xFF) twice; consider assigning it once (per use-site) before T.Cast(...) to improve readability and reduce divergence risk.

-                        l_bin_id32 = T.Cast(
-                            "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
-                        )
+                        byte_u32 = (convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF
+                        l_bin_id32 = T.Cast("int32", byte_u32)
-                        l_bin_id32 = T.Cast(
-                            "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
-                        )
+                        byte_u32 = (convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF
+                        l_bin_id32 = T.Cast("int32", byte_u32)

Also applies to: 159-161


211-211: Guard the ratio print against len(set_ref) == 0 (debug safety)
Probably fine given current topk, but this will throw if set_ref is empty (e.g., if this helper is reused with topk=0).

-        print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))
+        denom = len(set_ref)
+        print("selected/all:", len(intersection), "/", denom, "=", (len(intersection) / denom) if denom else 0.0)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)

238-238: Optional cleanup: Address unused parameters flagged by static analysis.

Ruff identifies several pre-existing issues that could be cleaned up in a follow-up:

  • Line 238: max_cache_seqlen is unused in sparse_gqa_decode_varlen_indice
  • Line 292: max_cache_seqlen and num_blocks are unused in ref_program_torch
  • Line 293: heads is unpacked but never used
  • Line 326: Multiple unused arguments in ref_program_fa (block_indices, max_cache_seqlen, num_blocks, block_size)
  • Line 352: Redundant int() cast—math.ceil() already returns int in Python 3

Consider prefixing unused parameters with _ or removing them if they're not part of a required interface.

Also applies to: 292-293, 326-326, 352-352

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (5)

28-37: Keep get_configs() return readable (and within any configured ruff line-length).
Line 37’s one-liner comprehension is hard to scan and may trip line-length lint depending on your ruff config. Consider expanding it for readability.

-    return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
+    return [
+        {k: v for k, v in zip(iter_params, values)}
+        for values in itertools.product(*iter_params.values())
+    ]

372-377: ref_moe scaling line is dense—consider factoring for readability.
Line 372 is correct-looking but hard to audit; a couple temporaries (e.g., scale_idx = ...) would make it easier to maintain and reduce mistake risk if modified later.


417-417: Long torch.cat(...) line may fight lint; consider wrapping.
Even if ruff-format allows it, ruff lint rules (if enabled) may enforce line length separately. Wrapping also improves readability.


499-500: do_bench(lambda: kernel(...)) one-liner: consider naming the callable for clarity.
Optional, but avoids an extremely long line and makes profiling code easier to tweak.
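A hedged sketch of the reflow, assuming triton.testing.do_bench and a placeholder workload in place of the example's compiled kernel:

import torch
from triton.testing import do_bench

# Placeholder inputs; the real example benchmarks its compiled TileLang kernel instead.
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)


def run_kernel():
    # Stand-in workload so the benchmarked callable has a descriptive name and short lines.
    return a @ b


latency_ms = do_bench(run_kernel)
print(f"latency: {latency_ms:.3f} ms")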


513-522: CLI call site is a bit dense; otherwise fine.
If you want to keep within line-length limits, wrap the final main(...) call across multiple lines—no functional issues.

examples/deepseek_mla/example_mla_decode_ws.py (2)

480-485: T.annotate_layout reformat is fine; consider a tiny readability tweak.
You could inline the one-entry dict to reduce vertical space, but current version is totally acceptable.


598-606: CLI flag normalization to --long form looks good; consider reflowing the final tuple unpack.
The one-line multi-assign at Line 605 is a bit dense; optional to break into multiple lines for easier diffs.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)

244-249: Python version compatibility: tuple[...] return annotation (Line 244).
This requires Python ≥ 3.9; if the repo still supports 3.8, use typing.Tuple[...] instead.

-from typing import Optional
+from typing import Optional, Tuple
@@
-def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+def gen_inputs(B, H, Sq, Skv, D, dtype=torch.float16) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
examples/amd/example_amd_flash_attn_bwd.py (3)

46-62: Config generation loop looks fine; consider yielding instead of building a giant list (optional).
This can generate a large valid_configs list (potentially heavy for import-time/CLI startup). If this is only used for autotune enumeration, a generator can reduce memory.


213-234: BWD config generation loop looks fine; same optional generator note applies.


581-590: CLI arg formatting looks good; consider matching main() defaults for seq_len/dim (optional).
main() defaults to seq_len=4096, dim=128, but argparse defaults are 1024/64. If intentional for faster runs, maybe mention it in help= to avoid confusion.

examples/dsa_sparse_finetune/sparse_mla_bwd.py (2)

101-102: Consider using idiomatic boolean check.

is_causal == True can be simplified to is_causal or is_causal is True for identity comparison.

-    assert is_causal == True, "non-casual is not supported now"
+    assert is_causal, "non-casual is not supported now"

259-259: Pre-existing typo: is_casual should be is_causal.

The parameter name is_casual appears to be a typo for is_causal. This is pre-existing and out of scope for this formatting PR, but consider fixing in a follow-up.

examples/dynamic_shape/example_dynamic.py (1)

56-56: Prefer a wrapped signature/call (or a config object) for matmul_dynamic(...) to keep examples readable.

-def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads):
+def matmul_dynamic(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    threads,
+):
@@
-    kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
+    kernel = matmul_dynamic_mnk(
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        in_dtype,
+        out_dtype,
+        accum_dtype,
+        num_stages,
+        threads,
+    )

Also applies to: 60-60

examples/bitnet-1.58b/vllm_workspace/inference_with_native_format.py (1)

35-43: Formatting-only change looks safe; consider keywording the first arg for clarity.
ckpt_path is passed positionally into VllmRunner—if this is a “model_name/model path” parameter, model_name=ckpt_path (or model=...) is a bit clearer and more grep-friendly.

examples/bitnet-1.58b/vllm_workspace/conftest.py (2)

366-368: Good: bias getattr-guard increases compatibility across model heads.
Minor improvement: cache emb = self.model.get_output_embeddings() once to avoid repeated virtual calls.

-                if getattr(self.model.get_output_embeddings(), "bias", None) is not None:
-                    logits += self.model.get_output_embeddings().bias.unsqueeze(0)
+                emb = self.model.get_output_embeddings()
+                if getattr(emb, "bias", None) is not None:
+                    logits += emb.bias.unsqueeze(0)

392-394: Line 393 and line 519 should be refactored for readability and to avoid formatter churn if line-length is reduced. Currently, these lines are 113 characters and fall within the configured limit of 140 characters in pyproject.toml, but they are close enough that future adjustments to line-length could trigger reformatting. Consider breaking the list comprehensions into multiple lines for better maintainability.

Also applies to: 516, 517

examples/bitnet-1.58b/benchmark_generate.py (3)

38-40: Prefer tokenizer.batch_decode() over per-row decode.
Cleaner and usually faster for batched outputs.

-    generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids]
+    generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

51-71: profile(): drop numpy (mean of a scalar) or actually collect samples.
times is a single float; np.mean(times) is redundant and adds an import.

 def profile(model, input_data):
-    import numpy as np
-
     model = model.cuda()
     model.eval()
@@
-        times = get_runtime(num_repeats)
-    return np.mean(times)
+        runtime_ms = get_runtime(num_repeats)
+    return runtime_ms

74-96: Model init: .half() is likely redundant when torch_dtype=torch.float16 is set.
This is harmless but can be removed for clarity.

     model = (
         BitnetForCausalLM.from_pretrained(
             model_path,
             use_flash_attention_2=True,
             torch_dtype=torch.float16,
         )
         .cuda()
-        .half()
     )
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)

49-52: sm_scale expression change is formatting-only; consider de-duping log2(e) constant.
No semantic change; optional readability improvement is to name LOG2E = 1.44269504 and reuse it (also used again on Line 192).

examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py (1)

36-36: Optional: Fix typo in error message.

The error message contains "TRue" instead of "True". While this line wasn't changed in this PR, it's a minor typo worth correcting for consistency.

Apply this diff to fix the typo:

-    assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"
+    assert trans_B is True, "Dequantize only implement for trans_B=True currently"
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)

146-164: E741 is explicitly ignored in the repository's Ruff configuration, so renaming O and l is not a linting requirement. This is an optional style improvement rather than an enforced standard.

benchmark/mamba2/benchmark_mamba_chunk_scan.py (2)

129-170: Helion kernel local casting changes look safe; consider hoisting repeated .to(torch.float32) where possible.
No correctness issues spotted, but repeated casts inside tiled loops can be a hotspot depending on Helion lowering.


178-180: iter_params one-liner is OK; a bit dense for future edits.
If this becomes a common edit point, multi-line dict literal may stay more readable.

examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (2)

62-66: Prefer splitting the combined tl.load tuple assignment for readability/debugging.
The one-liner at Line 63 is correct, but it’s harder to instrument/inspect when debugging kernel index math.


254-328: parallel_nsa(...) formatting looks fine; consider clarifying cu_seqlens vs offsets naming at the autograd boundary.
ParallelNSAFunction.forward(..., offsets) is called with cu_seqlens (Line 321). If cu_seqlens is intentionally the same “offsets” tensor, consider renaming the forward parameter for clarity (no functional change, but reduces confusion).

examples/bitnet-1.58b/modeling_bitnet.py (1)

964-1127: Prefer explicit boolean masking over bool * bool for padding_mask (Torch compatibility).

On Lines 1114-1115, eq(...) * eq(...) works but can be brittle (dtype / deprecation warnings) compared to explicit boolean ops.

-                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                padding_mask = causal_mask[..., :mask_length].eq(0.0) & attention_mask[:, None, None, :].eq(0.0)
                 causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
examples/cast/example_per_token_cast_to_fp8.py (1)

16-45: Consider explicit shape/parameter guards (blk_m + tail tiles).
If T.copy(...) isn’t automatically bounds-masked for T.ceildiv(...) tiles, this kernel will read/write past tensor extents when M % blk_m != 0 or N % 128 != 0. Also, forward_thread_fn uses (blk_m // 4) (Line 30), which is unsafe for blk_m < 4 and likely unintended for blk_m % 4 != 0.

Suggested minimal guardrails (example-level):

 def main(M=8192, N=8192, blk_m=8):
+    assert blk_m >= 4 and blk_m % 4 == 0
+    assert N % 128 == 0
+    assert M % blk_m == 0
     kernel = per_token_cast_to_fp8(M, N, blk_m)
benchmark/matmul/benchmark_matmul_intrinsic.py (1)

293-294: Argparse formatting approved, but note the --with_roller argument type issue.

The argument uses type=bool which doesn't work as expected with argparse — passing --with_roller False would still be truthy because the string "False" is truthy. However, since the value is overwritten anyway (line 302), this is moot unless the override is removed.

If the override is removed, consider using action="store_true" instead:

-    parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")
+    parser.add_argument("--with_roller", action="store_true", help="Whether to use roller to deduce search spaces")
examples/dequantize_gemm/dequantize_utils.py (2)

42-46: Consider re-wrapping res3 / nested torch.where for maintainability.
Single-line bit-math + nested conditional is harder to audit when someone needs to adjust masks/shifts.


123-148: assert_similar diagnostics: OK, but printing diff unconditionally can be noisy.
If this utility is used in benchmarks/tests, consider gating print(f"{diff=}") behind a flag or switching to logging.
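A small sketch of the logging-based variant (a boolean verbose flag would look similar); the diff value here is a placeholder for whatever assert_similar computes:

import logging

logger = logging.getLogger(__name__)


def report_diff(diff) -> None:
    # Emitted only when the caller enables DEBUG logging, instead of printing unconditionally.
    logger.debug("diff=%s", diff)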

examples/deepseek_v32/sparse_mla_fwd.py (1)

210-213: Replace 1 - 1 placeholders with explicit 0 (or remove).
As written, torch.arange(1 - 1, ...) and mask[:, :, : 1 - 1, 0] = True are hard to read and the slice is a no-op (:0). If these are meant to be constants, make them explicit; otherwise, consider deleting.

-    compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
-        1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
-    ).view(1, -1)
+    compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
+        0, sk, 1, dtype=torch.int32, device="cuda"
+    ).view(1, -1)

-    mask[:, :, : 1 - 1, 0] = True
+    # (optional) remove if not needed; this was a no-op slice (":0")
+    # mask[:, :, :0, 0] = True

Also applies to: 217-217

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (1)

10-49: Unused scale parameter in _tir_u8_to_f4_to_bf16.

The static analysis correctly identifies that the scale parameter is not used in the function body (the commented-out line 43 would have used it). The docstring describes its purpose, but the implementation doesn't actually apply scaling.

This appears to be pre-existing technical debt rather than something introduced by this formatting PR, so it's acceptable to defer fixing it.

Consider either:

  1. Removing the unused scale parameter if scaling is not needed
  2. Uncommenting line 43 to apply the scale if it should be used
 def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
     ...
     # Scale is the exponential part, within the representation of uint8
     # To handle the overflow, we may use the min function to limit the exponential part to 8 bits
-    # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
+    e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
examples/bitnet-1.58b/eval_gpu_memory.py (1)

8-10: --hf_path is defined but not used (hardcoded model id in main()). This is confusing for users running the script.

-def main():
-    model = BitnetForCausalLM.from_pretrained(
-        "1bitLLM/bitnet_b1_58-3B",
+def main(args):
+    model = BitnetForCausalLM.from_pretrained(
+        args.hf_path,
         device_map="auto",
         low_cpu_mem_usage=True,
         use_flash_attention_2=True,
         torch_dtype=torch.float16,
     ).half()
@@
-if __name__ == "__main__":
-    main()
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)

Also applies to: 37-45, 51-52

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)

9-49: Reflowed compile flags/asserts look fine; minor typo in error text (“non-casual”).

examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (1)

110-129: Confusing shadowed ParallelNSAFunction (and a mismatched parallel_nsa_fwd call); consider removing or renaming it.

This block appears effectively dead due to the later ParallelNSAFunction redefinition; keeping both increases the chance of accidental import/use.

examples/deepseek_nsa/example_triton_nsa_fwd.py (1)

116-135: Shadowed ParallelNSAFunction block is misleading (and the call doesn’t match the later parallel_nsa_fwd signature).

Consider deleting or renaming the earlier class to avoid future “wrong one imported” mistakes.
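For reference, a tiny demonstration of why the earlier block is effectively dead: a later class statement simply rebinds the module-level name, so only the last definition is reachable by that name.

class ParallelNSAFunction:
    tag = "early"

class ParallelNSAFunction:  # noqa: F811 -- rebinds the name; the first class is now unreachable
    tag = "late"

print(ParallelNSAFunction.tag)  # prints "late"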

Comment on lines +54 to +58
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
out = torch.einsum(
"bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n benchmark/mamba2/benchmark_mamba_chunk_scan.py | sed -n '54,58p'

Repository: tile-ai/tilelang

Length of output: 420


🏁 Script executed:

python3 << 'EOF'
import torch

# Test 1: dtype=bool (Python bool)
try:
    result1 = torch.ones(2, 2, dtype=bool)
    print(f"dtype=bool works: {result1.dtype}")
except Exception as e:
    print(f"dtype=bool error: {type(e).__name__}: {e}")

# Test 2: dtype=torch.bool
try:
    result2 = torch.ones(2, 2, dtype=torch.bool)
    print(f"dtype=torch.bool works: {result2.dtype}")
except Exception as e:
    print(f"dtype=torch.bool error: {type(e).__name__}: {e}")

# Test 3: Check if torch.bool is the correct type
print(f"torch.bool type: {type(torch.bool)}")
print(f"Python bool type: {type(bool)}")
EOF

Repository: tile-ai/tilelang

Length of output: 177


🌐 Web query:

PyTorch torch.ones dtype parameter bool vs torch.bool

💡 Result:

Use torch.bool. The dtype argument expects a torch.dtype (examples listed in the docs include torch.bool); torch.bool is the boolean dtype. The torch.ones doc and the dtype list show dtype should be a torch.dtype (e.g. torch.bool). [1][2]

References:

  • torch.ones signature/dtype parameter. [1]
  • tensor dtypes list (shows torch.bool). [2]

Fix invalid Torch dtype in causal_mask construction (runtime TypeError).

Line 54 uses dtype=bool (Python type), but PyTorch's dtype parameter expects a torch.dtype object. Use dtype=torch.bool instead.

- causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
+ causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=torch.bool), diagonal=0)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
out = torch.einsum(
"bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=torch.bool), diagonal=0)
scores_decay = scores_decay.masked_fill(~causal_mask, 0)
out = torch.einsum(
"bchls,bhcs,bcshp->bclhp", scores_decay.to(x.dtype), dt.to(x.dtype), rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)
)
🤖 Prompt for AI Agents
In benchmark/mamba2/benchmark_mamba_chunk_scan.py around lines 54 to 58, the
causal_mask is constructed using the Python type bool for the dtype which causes
a runtime TypeError; change the dtype argument to the PyTorch dtype torch.bool
and ensure torch is imported/available in this scope so causal_mask =
torch.tril(torch.ones(chunk_size, chunk_size, device=x.device,
dtype=torch.bool), diagonal=0) produces a proper boolean mask before subsequent
masked_fill and einsum operations.

Comment on lines +212 to 220
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
):

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file context around the lines in question
cat -n benchmark/mamba2/benchmark_mamba_chunk_scan.py | head -230 | tail -50

Repository: tile-ai/tilelang

Length of output: 2167


🏁 Script executed:

# Let's search for T.Tensor usage patterns in the file
rg "T\.Tensor" benchmark/mamba2/benchmark_mamba_chunk_scan.py -A 1 | head -50

Repository: tile-ai/tilelang

Length of output: 725


🏁 Script executed:

# Check what T is imported as
head -50 benchmark/mamba2/benchmark_mamba_chunk_scan.py | grep -E "^import|^from"

Repository: tile-ai/tilelang

Length of output: 272


🏁 Script executed:

# Look for similar tensor shape annotations in the codebase to understand patterns
rg "T\.Tensor\(\(" --no-filename -B 1 -A 1 | head -100

Repository: tile-ai/tilelang

Length of output: 2156


🏁 Script executed:

# Check if there's any T.Tensor with single dimension without trailing comma elsewhere
rg "T\.Tensor\(\([a-zA-Z_][a-zA-Z0-9_]*\)," --no-filename | head -20

Repository: tile-ai/tilelang

Length of output: 698


🏁 Script executed:

# Search specifically for patterns like (name) without trailing comma in T.Tensor calls
rg "T\.Tensor\(\([a-zA-Z_][a-zA-Z0-9_]*\)[,\)]" benchmark/mamba2/benchmark_mamba_chunk_scan.py

Repository: tile-ai/tilelang

Length of output: 113


🏁 Script executed:

# Get more context around the instances of (M), (N) without trailing comma
rg "T\.Tensor\(\((M|N|padding_M|nheads)\)," -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 4565


🏁 Script executed:

# Check if there are instances of (N,) WITH trailing comma to understand if both patterns coexist
rg "T\.Tensor\(\((M|N|padding_M|nheads),\)," -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 14761


🏁 Script executed:

# Search for the T.Tensor API definition or documentation
fd ".*\.py$" | xargs grep -l "def Tensor" | head -5

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if there's any documentation about T.Tensor that specifies shape requirements
rg "class Tensor|def Tensor" --type-list python -B 2 -A 10 | head -100

Repository: tile-ai/tilelang

Length of output: 449


🏁 Script executed:

# Verify the actual instances - count how many have trailing comma vs not
echo "=== Instances WITH trailing comma (correct) ==="
rg "T\.Tensor\(\([a-zA-Z_][a-zA-Z0-9_]*,\)" --no-filename | wc -l
echo "=== Instances WITHOUT trailing comma (incorrect) ==="
rg "T\.Tensor\(\([a-zA-Z_][a-zA-Z0-9_]*\)[,\)]" --no-filename | wc -l

Repository: tile-ai/tilelang

Length of output: 166


🏁 Script executed:

# Check if there are any related issues or comments about this pattern
rg "nheads.*Tensor|D.*Tensor.*nheads" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 1205


🏁 Script executed:

# Verify the D tensor is actually used with shape (nheads,) in actual code
rg "D\[" benchmark/mamba2/benchmark_mamba_chunk_scan.py -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 466


Fix D tensor shape annotation: (nheads) must be (nheads,) for 1D tuple syntax.
In Python, (nheads) evaluates to the variable itself, not a 1-element tuple; tensor shape annotations require the trailing comma to form a proper tuple. The correct codebase pattern uses (N,) for all 1D tensors (95+ instances), while this file incorrectly omits the comma. This same error also appears in examples/linear_attention/example_mamba_chunk_scan.py.

-        D: T.Tensor((nheads), dtype),  # type: ignore
+        D: T.Tensor((nheads,), dtype),  # type: ignore
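As a two-line reminder of the underlying Python rule:

nheads = 8
print((nheads), type((nheads)))    # 8 <class 'int'>  -- parentheses alone are just grouping
print((nheads,), type((nheads,)))  # (8,) <class 'tuple'>  -- the trailing comma makes the tuple
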
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
):
cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore
x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore
C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore
prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore
D: T.Tensor((nheads,), dtype), # type: ignore
Output: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore
):
🤖 Prompt for AI Agents
In benchmark/mamba2/benchmark_mamba_chunk_scan.py around lines 212 to 220, the
shape annotation for tensor D uses (nheads) which is not a 1-tuple; change it to
(nheads,) to make a proper 1D shape tuple. Update the same pattern in
examples/linear_attention/example_mamba_chunk_scan.py if present, run project
lint/type checks, and ensure all 1D tensor shapes follow the (N,) convention
consistently.

Comment on lines 13 to 29
def ref_program(Q, K, V, is_causal, groups=1):
assert Q.size(
2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(
2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}"
assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}"
dim = Q.size(-1)
K_ref = K.repeat_interleave(groups, dim=2)
V_ref = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K_ref)
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K_ref)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V_ref)
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_ref)
lse = torch.logsumexp(scores, dim=-1).float()
return output, lse

⚠️ Potential issue | 🔴 Critical

Avoid dtype/device mismatch in ref_program scaling; use scores.new_tensor(dim) (or math.sqrt).
torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) creates a CPU scalar by default; dividing a CUDA scores tensor by a CPU tensor can error or trigger implicit transfers depending on the PyTorch version. Prefer a device-aware scalar.

@@
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    scores = scores / torch.sqrt(scores.new_tensor(dim))
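Either form keeps the divisor consistent with scores; a quick CPU-only sanity sketch of the two equivalent options:

import math
import torch

scores = torch.randn(2, 3)
dim = 64
a = scores / math.sqrt(dim)                      # plain Python float, no extra tensor at all
b = scores / torch.sqrt(scores.new_tensor(dim))  # scalar created on scores' device with scores' dtype
print(torch.allclose(a, b))                      # True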
🤖 Prompt for AI Agents
examples/amd/example_amd_flash_attn_bwd.py lines 13-29: the scaling uses
torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) which creates a CPU tensor and
can cause dtype/device mismatches when dividing a CUDA `scores`; replace that
expression with a device-aware scalar such as scores.new_tensor(dim).sqrt() (or
use math.sqrt(dim) to get a Python float) so the divisor matches `scores`'s
device and dtype before performing the division.

Comment on lines 75 to 84
sinks = T.alloc_fragment([heads], dtype)

T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]


🛠️ Refactor suggestion | 🟠 Major

Fix potential OOB: sinks fragment shape doesn’t match its indexing.

-            sinks = T.alloc_fragment([heads], dtype)
+            sinks = T.alloc_fragment([block_M], dtype)
@@
             T.fill(scores_max, -T.infinity(accum_dtype))
             for i in T.Parallel(block_M):
                 sinks[i] = Sinks[by]

Also applies to: 125-127

Comment on lines 72 to 81
sinks = T.alloc_fragment([heads], dtype)

T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
for i in T.Parallel(block_M):
sinks[i] = Sinks[by]


🛠️ Refactor suggestion | 🟠 Major

Fix potential OOB: sinks fragment shape doesn’t match its indexing.

sinks is allocated as [heads] but indexed by i in T.Parallel(block_M) and later read as sinks[i], so this can go out-of-bounds when block_M > heads.

-            sinks = T.alloc_fragment([heads], dtype)
+            sinks = T.alloc_fragment([block_M], dtype)

             T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
             T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
@@
             T.fill(scores_max, -T.infinity(accum_dtype))
             for i in T.Parallel(block_M):
                 sinks[i] = Sinks[by]

Also applies to: 122-124

🤖 Prompt for AI Agents
In examples/attention_sink/example_mha_sink_bwd_bhsd.py around lines 72-81 (also
apply same fix at 122-124): sinks is allocated with shape [heads] but indexed
over range(block_M), causing potential OOB when block_M > heads; change the
allocation or the loop to make sizes consistent — either allocate sinks with
length block_M (e.g., T.alloc_fragment([block_M], dtype)) if you need per-i
storage, or restrict the loop to range(heads) (and adjust later reads
accordingly); update both occurrences so the allocation shape matches the
indexing used.

Comment on lines 13 to 16
def generate_text_batch(model, tokenizer, prompts, max_length=100):
# Encode the input prompts as a batch
input_ids = tokenizer(
prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)


⚠️ Potential issue | 🟠 Major

Pass attention_mask when using padded batches (padding=True).
Right now you pad prompts but only pass input_ids into generate(), which can yield incorrect generation if the model/tokenizer doesn’t infer an attention mask the way you expect.

-    input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
+    enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
+    input_ids = enc.input_ids
+    attention_mask = enc.attention_mask
...
-    output_ids = model.generate(input_ids, generation_config=generation_config)
+    output_ids = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config)

Also applies to: 33-35
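A self-contained sketch of the intended pattern, using generic Hugging Face classes and "gpt2" purely as a stand-in checkpoint (the BitNet example would keep its own model and tokenizer):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 has no pad token by default
tokenizer.padding_side = "left"             # left-pad for decoder-only generation
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["Hello world", "A much longer prompt that forces padding of the short one"]
enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
output_ids = model.generate(
    input_ids=enc.input_ids.to(model.device),
    attention_mask=enc.attention_mask.to(model.device),  # marks which positions are real tokens
    max_new_tokens=16,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))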

🤖 Prompt for AI Agents
In examples/bitnet-1.58b/benchmark_generate.py around lines 13 to 16 (and
similarly lines 33 to 35), you encode prompts with padding but only pass
input_ids to model.generate; this can produce incorrect output because the model
needs the attention_mask for padded batches. Update the code to capture the
tokenizer output (input_ids and attention_mask), move tensors to model.device,
and pass attention_mask alongside input_ids into generate() so the model ignores
padded tokens during attention.

Comment on lines 184 to +189
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor,
float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

⚠️ Potential issue | 🟠 Major

rope_scaling.factor validation likely rejects common JSON configs (int vs float).
Lines 188-189 currently require factor to be a float; configs loaded from JSON often yield an int (e.g., 2), which would now raise. Prefer accepting real numbers (while excluding bool).

+import numbers
+
+_ROPE_SCALING_TYPE_ERROR = (
+    "`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+)
+_ROPE_SCALING_FACTOR_ERROR = (
+    "`rope_scaling`'s factor field must be a number > 1, got {rope_scaling_factor}"
+)
 ...
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_factor = self.rope_scaling.get("factor", None)
         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}")
+            raise ValueError(_ROPE_SCALING_TYPE_ERROR.format(rope_scaling_type=rope_scaling_type))
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
-            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+            raise ValueError(_ROPE_SCALING_FACTOR_ERROR.format(rope_scaling_factor=rope_scaling_factor))
+        if isinstance(rope_scaling_factor, bool) or not isinstance(rope_scaling_factor, numbers.Real) or rope_scaling_factor <= 1:
+            raise ValueError(_ROPE_SCALING_FACTOR_ERROR.format(rope_scaling_factor=rope_scaling_factor))

(You can keep the single check; shown expanded for clarity.)
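The explicit bool exclusion matters because bool is a subclass of int (and therefore of numbers.Real); a quick check:

import numbers

for value in (2, 2.0, True):
    ok = isinstance(value, numbers.Real) and not isinstance(value, bool) and value > 1
    print(value, ok)  # 2 True / 2.0 True / True False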

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.8)

187-187: Avoid specifying long messages outside the exception class

(TRY003)


189-189: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In examples/bitnet-1.58b/configuration_bitnet.py around lines 184 to 189, the
current validation requires rope_scaling.factor to be a float which will reject
JSON integers (e.g., 2); update the check to accept any real number (int or
float) but exclude bool, and ensure its numeric value is > 1.0—either by using
isinstance(value, numbers.Real) and not isinstance(value, bool) or by allowing
int/float and then converting to float for the comparison and error message.

Comment on lines +105 to 116
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
bx,
by,
):

⚠️ Potential issue | 🔴 Critical

Add bounds handling (N-tail and K-tail): the current indexing can go OOB.
With a grid of T.ceildiv(N, n_partition) blocks along bx and col = bx * n_partition + ni, col can be >= N on the last block(s), yet it is used for both B[...] reads and C[...] writes. More critically, block_K = reduce_thread * micro_size_k can exceed K (e.g., the included __main__ call uses K=256 while the int8 settings yield block_K=512), so the A[...] / B[...] loads can index past K in the very first ko tile.

Please either (a) add predication/guarded loads + guarded store, or (b) assert supported shapes and adjust the example to match those constraints.

A minimal shape-safe direction (sketch; adapt to tilelang idioms for predication):

@@
         with T.Kernel(
             T.ceildiv(N, n_partition),
             M,
             threads=(reduce_thread, n_partition),
         ) as (
             bx,
             by,
         ):
+            col = bx * n_partition + ni
@@
-            for ko in T.serial(T.ceildiv(K, block_K)):
+            for ko in T.serial(T.ceildiv(K, block_K)):
                 for v in T.vectorized(micro_size_k):
-                    A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
+                    k = ko * block_K + kr * micro_size_k + v
+                    A_local[v] = T.if_then_else(k < K, A[by, k], T.Cast(in_dtype, 0))
@@
                 for v in T.vectorized(micro_size_k_compressed):
-                    B_quant_local[v] = B[bx * n_partition + ni, ko * (...) + kr * (...) + v]
+                    kc = ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v
+                    B_quant_local[v] = T.if_then_else(
+                        (col < N) & (kc < B_shape[1]),
+                        B[col, kc],
+                        T.Cast(storage_dtype, 0),
+                    )
@@
-            if kr == 0:
-                C[by, bx * n_partition + ni] = reduced_accum_res[0]
+            if (kr == 0) & (col < N):
+                C[by, col] = reduced_accum_res[0]

(If T.if_then_else isn’t available/appropriate in this context, the alternative is to early-exit the compute for invalid col and/or add explicit shape asserts so the kernel is only generated for safe shapes.)

Also applies to: 134-137, 129-137, 172-174, 261-262

🤖 Prompt for AI Agents
In
examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py
around lines 105-116, the indexing can go out-of-bounds for both N (col) and K
(block_K), so add explicit bounds handling: guard each B read and C write with a
predicate col < N, and guard each A/B load with k_index < K (or zero-fill when
out-of-bounds) so no memory access happens past shapes; alternatively,
enforce/validate shapes at kernel entry (assert block_K <= K and that bx
computation cannot produce col >= N) and update the example shapes accordingly;
apply the same guarded-load/store or shape-assert changes to the other affected
ranges (lines 129-137, 134-137, 172-174, 261-262).

Comment on lines 442 to 446
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
o_slc = rearrange(o_slc, "b t h d -> b h t d")
o_swa = rearrange(o_swa, "b t h d -> b h t d")

return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)

⚠️ Potential issue | 🔴 Critical

Bug: naive_nsa rearranges o_swa even when it is None (head_first + no window).

When window_size == 0, o_swa is None; rearrange(o_swa, ...) will throw.

-    if head_first:
-        o_slc = rearrange(o_slc, "b t h d -> b h t d")
-        o_swa = rearrange(o_swa, "b t h d -> b h t d")
+    if head_first:
+        o_slc = rearrange(o_slc, "b t h d -> b h t d")
+        if o_swa is not None:
+            o_swa = rearrange(o_swa, "b t h d -> b h t d")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
o_slc = rearrange(o_slc, "b t h d -> b h t d")
o_swa = rearrange(o_swa, "b t h d -> b h t d")
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
if head_first:
o_slc = rearrange(o_slc, "b t h d -> b h t d")
if o_swa is not None:
o_swa = rearrange(o_swa, "b t h d -> b h t d")
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
🤖 Prompt for AI Agents
In examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py around lines 442 to 446,
the code unconditionally calls rearrange(o_swa, ...) when head_first is True,
but o_swa can be None (window_size == 0) which causes an exception; change the
logic so you only call rearrange on o_swa when o_swa is not None (e.g., perform
o_slc = rearrange(...) unconditionally under head_first, then if o_swa is not
None do o_swa = rearrange(...)); similarly ensure the final return sums
converted tensors only when o_swa is present (return o_slc.to(dtype) +
o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)).

Comment on lines +91 to 99
def cal_cu_seqlen_ks_for_q(
cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int
) -> torch.IntTensor:
cu_seqlen_ks_for_each_q = torch.gather(
input=torch.cat([
cu_seqlens_ks,
torch.full((1,),
torch.iinfo(torch.int32).max,
dtype=torch.int32,
device=cu_seqlens_qs.device)
]),
input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
)
return cu_seqlen_ks_for_each_q.int()

⚠️ Potential issue | 🟠 Major

Potential dtype mismatch in torch.cat([...]) used for gather inputs.
cal_cu_seqlen_ks_for_q (Line 95) concatenates cu_seqlens_ks with an int32 tensor, and cal_cu_seqlen_ke_for_q (Line 113) concatenates cu_seqlens_ke with an int32 tensor. torch.cat requires identical dtypes; if cu_seqlens_ks/ke are truly LongTensor as annotated, this would raise at runtime. Either ensure callers pass int32 here (and fix type hints), or cast before concat.

Proposed minimal fix (keeps behavior, enforces dtype consistency):

-        input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]),
+        input=torch.cat(
+            [
+                cu_seqlens_ks.to(torch.int32),
+                torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device),
+            ]
+        ),
...
-        input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
+        input=torch.cat(
+            [
+                cu_seqlens_ke.to(torch.int32),
+                torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device),
+            ]
+        ),

Also applies to: 103-116

@LeiWang1999
Member Author

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Nice work!

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

@LeiWang1999 LeiWang1999 merged commit 2905143 into tile-ai:main Dec 12, 2025
0 of 2 checks passed
chengyupku added a commit to tile-ai/tilescale that referenced this pull request Feb 6, 2026
* [Example] Add GQA decoding kernel with varlen page table (#1265)

* [Example] Add page table for gqa decode

* [Example] Page table for varlen decoding

* [Lint]

* [Refactor] Remove redundant code

* [Lint]

* [Lint]

* [Lint]

* [Refactor] add support for numpy dtype conversion (#1255)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files

* [EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max numerical stability (#1148)

* Keep the max of all blocks seen in scores_max for stability

* ruff formatting

* [Docs] Improve Installation Guide (#1270)

* [Docs] Improve installation guide

* address comments

* [Enhancement] Keep max score attention across blocks in FlashAttention for better numerical stablity (#1269)

* Implement max score retention across blocks in FlashAttention for improved stability

* fix manual pipeline parameters

* Update examples/flash_attention/example_gqa_fwd_varlen.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix typo

* more

* fix a previous typo

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* [Bugfix] Fix multiple cg defination when using T.sync_grid (#1272)

* [Minor] Remove from __future__ import annotations for python 3.8 (#1273)

* [BugFix] Adding extra parameters into autotune hashkey (#1274)

* [BugFix] Adding extra parameters into autotune hashkey

* lint

* None check

* check serializable

* Fix various issues under `int64_t` static and dynamic shape. (#1218)

* Fix various issues under int64_t static and dynamic shape.

* Resolve reviewed issues.

* Add unit test.

* fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* Bug fix for Gated Delta Net benchmark script (#1267)

* fix argument order for fla chunk_gated_delta_rule_fwd_h

* explicit import assert_similar from utils

* rename utils module to avoid name clash

* set store_final_state and save_new_value to True

* fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* [Bugfix] Minor fix for some cases (#1278)

* [Language] Add shape check in `T.view/reshape` (#1277)

* [Language] Add shape check in T.view/reshape

* address comments

* [FFI] Use tvm ffi as the default execution backend (#1259)

* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.

* [Refactor] Clean up imports and improve code formatting

* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.

* Update execution backend options and improve resolution logic

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target.
- Updated documentation to reflect changes in execution backend options and their defaults.

* lint fix

* fix

* Enhance argument handling in CUDA and HIP runtime modules

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.

* lint fix

* lint fix

* lint fix

* lint fix

* minor fix

* fix

* recover check

* Refactor argument binding and validation in `arg_binder.cc`

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.

* lint fix

* stride fix

* minor fix

* fix

* lint fix

* lint fix

* Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.

* check with symbolic

* support null ptr

* Update CMakeLists and lower.py for code generation and subproject status

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility.
- Marked the TVM subproject as dirty to indicate local modifications.

* lint fix

* Update comments for clarity in quickstart.py

* [Bugfix] Supply missing `T.print` for bool type (#1279)

* fix for bool dtype

* lint fix

* fix

* ci fix

* [Fix] Fix memory leak bug (#1281)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files

* fix memory leak bug

* fix lint error

* add comments

* fix lint error

* remove duplicated, because tilelang doesn't dependent deprecated

* [Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283)

- Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options.
- Introduced handling for fast math and PTXAS options based on the provided pass configuration.
- Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings.
- Refactored NVCC command construction to use a dedicated function for better clarity and maintainability.

* Fix the bug in issue #1266 (#1284)

Co-authored-by: cheeryBloosm <liu_yu_hao@126.com>

* [Language][UX] Nested loop checker in pre-lowering stage (#1288)

* [Language][UX] Nested loop checker in pre-lowering stage

* rename

* comment

* address comments

* [Compatibility] Support CUDA 11.3 (#1290)

* [Feat] Add support for using `T.Tensor(n * 2 + 1)` in function annotation (#1285)

* [Feature] Add support for A: T.Tensor(n + 1) and A: T.Tensor(2*n)

* issue fix

* fix

* fix

* decreate nproc for debugging

---------

Co-authored-by: Lei Wang <leiwang1999@outlook.com>

* [Feat] add support for passing reference in T.Var annotation (#1291)

* [Enhancement] Shared Memory Size Can be Dynamic (#1294)

* bugfix

* lint fix

* test

* lint fix

* increate procs

* recover

* [Fix] Remove unused let_bindings_ in CodeGenC to fix #1300 (#1305)

* [Feat] add missing support of uint32x2

* [Feat] Add `T.Ref` annotation and tests

* fix lint error

* minor update for error message on twice decl

* Remove unused let_bindings_ in CodeGenC to fix #1300

* [Bugfix] Fallback to the old AtomicAdd implementation for legacy architectures (#1306)

* [Fix] Fix frame scope error in T.macro (#1308)

* [Fix] Fix #1307 by adding macro inside function

* fix lint error

* add comments and fix lint error

* Remove debug print from enter_frame method

Removed debug print statement from enter_frame method.

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [WIP] support more dtypes for tcgen05 (#1229)

support ld with pack for fp32 dtype

add dump

add tempalte expand

remove unused dtype and change to rebased apis

* Improve memory access safety and `T.assume` handling (#1292)

* Improve memory access safety and T.assume handling

* Improve memory access safety and T.assume handling

* bugfix

* lint fix

* bugfix

* bugfix

* refactor legalize safe memory access pass

---------

Co-authored-by: Lei Wang <leiwang1999@outlook.com>

* [Bugfix] Fix autotune cache (#1315)

* [Refactor] Backup Analyzer to get the appropriate arith informations (#1311)

* [Refactor] Update Vectorization Functions to Accept Analyzer Parameter

- Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization.
- Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness.
- Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities.

* [Fix] Corrected PostOrderVisit call in loop_vectorize.cc

- Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis.

* fix

* lint fix

* fix

* Revert "[WIP] support more dtypes for tcgen05 (#1229)" (#1323)

This reverts commit 0d101c110f74ebf2ef8c11a5ece9dfb314b48baa.

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>

* [CI]: Bump actions/checkout from 5 to 6 (#1319)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [CI]: Bump pypa/cibuildwheel from 3.2 to 3.3 (#1318)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Installation] Fix building using customized TVM path (#1326)

* [Release] Allow developer with write permission to trigger wheel release (#1322)

* [Feat] Support warp reduce (#1316)

* [Feat] Support warp reduce

* lint

* add test

* lint

* [Enhancement] Support more dtype in `T.print` (#1329)

* [Enhancement] Support more dtype in `T.print`

* upd

* upd

* [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape (#1321)

* [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape

* remove debug lines

* remove rubbish

* Fix decorator syntax for atomic_different_memory_orders_program

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Fix] fix wrong uint narrowing bug in tvm in #1310 (#1320)

* [Refactor] Disable strided buffer load inside tvm (#1301) (#1332)

* [Refactor] Moving `NormalizeToBufferRegion` and `MakeAccessPtrFromRegion` to utils (#1333)

* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* [Fix] Fix bug copying from or to local buffer (#1304) (#1324)

* [Fix] fix copy from or to local buffer (#1304)

* fix lint error

* minor fix testing script

* [Language][UX] Semantic check for parallel fragment access (#1338)

* Add unit tests for T.assume (#1341)

* Add test for T.assume

* Add unit test for T.assume

* Add unit test for T.assume

* Add unit tests for T.assume

* Remove debug print for kernel source

Remove print statement for kernel source in tests.

* Update test_tilelang_language_assume.py

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Feat] Extend LegalizeNegativeIndex to support buffer store stmts (#1339)

This commit enhances the LegalizeNegativeIndex transformation pass to handle
both buffer load and store operations with negative indices and adds some
test cases.

* [Refactor] Phaseout vmap for Tile Operators (#1334)

* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>

* [Enhancement] add more dtype and fix mma.ws for fp16 for tcgen05 (#1327)

* feat: add fp8 variants; add placeholder for fp6/fp4 in meta

support ld with pack for fp32 dtype

add dump

add tempalte expand

remove unused dtype and change to rebased apis

* fix: when atom-m!=128, enable_ws

* fix: typo in tcgen05 meta; dispatch in gemm sm100

* [Refactor] Enhance CopyNode's IterVar Creation and Range Handling (#1346)

* [Refactor] Enhance CopyNode's IterVar Creation and Range Handling

This commit refines the `MakeIterVars` method in `CopyNode` to select base ranges based on memory scope levels, ensuring that the chosen ranges are not smaller than the original source ranges. Additionally, it updates the Python `copy` function to clarify range handling, including broadcasting logic and extent alignment. These changes improve the robustness and clarity of the copy operation's implementation.

* test fix

* [Fix] Fix missing `not` rewrite in frontend (#1348)

* [Enhancement] Add support for k_pack in gemm_mfma (#1344)

* add support for k_pack

* support benchmark on ROCm

* fix format

* Add sparse fine-tuning kernel for deepseek sparse attention to example (#1296)

* [EXAMPLE] add example for dsa sparse finetuning

* [Refactor]

* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352)

* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder

This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase.

* [Enhancement] Update matmul kernel and optimize argument binding

This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code.

* lint fix

* [Enhancement] Add tensor checks documentation and improve argument binding assertions

This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code.

* [Enhancement] Update .gitignore and refine matmul kernel for improved performance

This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users.

* lint fix

* lint fix

* [Refactor] Simplify tensor_null_test function and remove ptr_null_test

This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations.

* lint fix

* fix

* [Refactor] Simplify index sign state handling in LegalizeNegativeIndex (#1354)

This commit refines the logic for determining the sign state of indices in the LegalizeNegativeIndex transformation. It prioritizes vector patterns, specifically Ramp and Broadcast nodes, to avoid compile-time lane queries. The handling of scalar indices is also streamlined, ensuring clearer diagnostics when non-negativity cannot be proven. These changes enhance the robustness and clarity of index handling in the transformation pass.

* [Enhancement] Improve error handling and assertion messages across runtime and argument binding (#1356)

This commit enhances the error handling mechanisms in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in the CodeGenCHost and ArgBinder. It includes structured packed error messages for various conditions, improving clarity in diagnostics. Additionally, the CMake configuration is updated to always include necessary runtime helpers, ensuring consistent error reporting. The changes aim to provide clearer feedback during runtime errors and improve the overall robustness of the argument binding process.

* [Bugfix] Disable floordiv optimization due to integer overflow risk (#1355)

* disable overflow-prone floordiv optimization in lower_intrin.cc

* disable overflow-prone floordiv optimization in lower_intrin.cc

* [Bugfix] Fix the jit_kernel issue (#1357)

* [Bugfix] Fix the jit_kernel issue

* Update README.md

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Refactor] Update Fragment Indexing in ParallelOpNode's InferLayout Method (#1359)

This commit refines the Fragment creation process in the InferLayout method of ParallelOpNode. It removes the unnecessary forward_index array and utilizes default fragment indexing for consistency with other operations. Additionally, it binds the thread range to enhance comparability across different operations.

* [Analysis] Enhance NestedLoopChecker with tile op cases (#1358)

* [Analysis] Enhance NestedLoopChecker with tile op cases

* fix tileop issue

* [Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056)

* [misc] add a cpp side wrapper for gemm_sp_py

* [misc] typing

* [IR] bind GemmSPWarpPolicy

* [chore] add wrapper code

* [IR] fix GemmSPWarpPolicy

* [codegen] apply ptxas instructions

* [intrinsic] add typical (unused) mma layout

* [template] add uint16 debug func

* [intrinsic] add b matrix layout

* [gemm_sp] enable fp16/bf16 on sm8x

* [layout] refactor fp16/bf16 layout

* [gemm_sp] enable int8

* [chore] update test case dtype

* [gemm_sp] enable fp32

* [layout] refactor layouts

* [intrinsic] enable ldmatrix for mat A

* [layout] enable ldsm for matrix b

* [layout] add ldmatrix for fp32 and fp8

* [chore] refine

* [chore] refactor

* [chore] add fp8 efactor

* [chore] refactor

* [chore] add remove negative zero util

* [example] add a custom compress kernel

* [chore] minor update

* [test] refactor gemm_sp test

* [refactor] make metadata layout func

* [example] add option for using cutlass layout

* [doc] add a gemm_sp doc

* [doc] minor polish

* [chore] remove unused

* [bugfix] fix non replicate b case

* [test] refactor

* [chore] add a check

* [bugfix] fix util bug

* [wip] init a new test case for v2

* [chore] minor refactor

* [chore] minor update

* [bugfix] enable 16bit rs

* [language] enable rs

* [language] enable gemm_sp_sr

* [language] enable gemm_sp_rr

* [test] enable more tests

* [tvm] update ffi binding

* [chore] remove print

* [chore] fix benchmark script

* [lint] precommit lint

* [chore] apply feedback

* [test] use arch 8.0

* [chore] rollback ::ordered_metadata for backward compatibility

* [bugfix] fix captialized

* [example] keep gemm_sp on hopper

* [test] fix no fp8 normal kernel

* [test] reduce matmul size to satisfy accum error

* [test] use cal_diff for assertion

* [bugfix] expand float8 type

* [lib] add make_int4 for short type

* [language] add transpose E

* [bugfix] fix wrong var

* [format] format

* [chore] refactor binding

* [chore] fix wrong passing var

* [Bugfix] Update TIR registration for GemmSPPy to use tile operation (#1361)

* [Enhancement] Implement dynamic unroll factor in CUDA code generation (#1360)

* [Enhancement] Implement dynamic unroll factor in CUDA code generation

This commit introduces support for specifying a dynamic unroll factor in the CUDA code generation. The `unroll_factor` map is added to store unroll factors for loop variables, allowing for more flexible and optimized loop unrolling. Additionally, the `unroll` function is integrated into the loop language, enabling users to define unroll factors directly in their code. This enhancement improves performance by allowing tailored unrolling strategies based on specific loop characteristics.

* lint fix

* [Bugfix] Correct initialization of non-zero counters in custom compress kernel and update TIR registration for gemm_sp_py to use the correct tile operation

* [CI] [pre-commit.ci] autoupdate (#1362)

updates:
- [github.com/pre-commit/mirrors-clang-format: v21.1.2 → v21.1.6](https://github.com/pre-commit/mirrors-clang-format/compare/v21.1.2...v21.1.6)
- [github.com/astral-sh/ruff-pre-commit: v0.14.3 → v0.14.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.14.3...v0.14.7)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Bugfix] Remove debug print in PyStmtFunctionVisitor  (#1363)

* [Debug] Always include line info in NVCC command for improved profiling and mapping (#1364)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py (#1365)

* [Enhancement] Add DISABLE_CACHE environment variables (#1368)

* [Refactor]: Remove useless include in atomicadd_vectorize.h (#1371)

* [Refactor] Generalize fp8 process (#1372)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.

* lint fix

* [Layout] Enhance Free Layout Inference (#1375)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.

* lint fix

* [Enhancement] Add injective layout detection and exception handling

- Introduced `DetectInjective` method in `FragmentNode` to check for injective layouts.
- Added `LoopLayoutInjectiveException` to handle errors related to non-injective layouts.
- Updated `InferLayout` methods in `ParallelOpNode` to utilize injective checks and log relevant information.
- Refactored layout inference queue management to use `std::deque` for improved performance and added prioritization logic for buffer layouts.

* remove debug print

* remove debug print

* remove debug print

* minor layout fix

* fix for T.view

* [Enhancement] Improve injective layout detection in FragmentNode

- Updated the `DetectInjective` method to handle symbolic dimensions more effectively by introducing a mechanism to collect symbolic shapes and adjust the detection level accordingly.
- Added logging for cases where the layout detection falls back to NoCheck due to symbolic dimensions.
- Minor update to the test file to include the tilelang testing module.

* [Refactor] Simplify layout inference for bulk copy operations

- Removed unnecessary conditions for bulk load/store operations in the layout inference logic.
- Streamlined the handling of layout application for bulk copy instances to enhance clarity and maintainability.

* remove debug print

* [Enhancement] Introduce layout-related exceptions and improve error handling

- Added `LayoutConflictException` and `LoopLayoutInjectiveException` classes for better exception management in layout operations.
- Updated `InferLayout` method in `ParallelOpNode` to throw `LoopLayoutInjectiveException` with detailed error information when injective layout checks fail.
- Removed redundant exception class definitions from `parallel.h` to streamline code organization.

* [Enhancement] Introduce buffer var lca analysis for pass plan buffer allocations (#1376)

* Update submodule TVM to latest commit and add PlanAndUpdateBufferAllocationLocation function to transform module

- Updated the TVM submodule to commit 3a32b763.
- Added a new function `PlanAndUpdateBufferAllocationLocation` in the transform module to facilitate buffer allocation planning within PrimFuncs.

* Refactor buffer allocation code for improved readability and consistency

- Updated formatting and spacing in `plan_update_buffer_allocation_location.cc` for better code clarity.
- Standardized the use of pointer and reference syntax across various class methods.
- Enhanced comments for better understanding of buffer allocation logic.
- Removed unnecessary lines and improved overall code structure.

* Refactor buffer allocation checks for improved clarity

- Replaced size checks with empty checks for `ffi::Array<Buffer>` in `plan_update_buffer_allocation_location.cc` to enhance code readability.
- Updated conditions in multiple methods to use `empty()` instead of comparing size to zero, streamlining the logic.

* [Tool] Provide layout visualization tool (#1353)

* Provide layout visualization tool

Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation.

This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance.

The visualization can be enabled via a pass config option.

* format

* add layout visual example

* Adds vis extra with matplotlib dependency

* refactor pass config name

* fix lint

* Enables configurable layout visualization formats

Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option.

This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs.

It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false".
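
A minimal sketch of how these options might be passed, assuming hypothetical key names (`tl.enable_layout_visualization`, `tl.layout_visualization_formats`) and that the tool is toggled through the usual `pass_configs` mechanism:

```python
import tilelang

# Hypothetical key spellings; the real pass-config names live in tilelang's
# pass-config definitions and may differ.
pass_configs = {
    "tl.enable_layout_visualization": True,        # boolean enable flag
    "tl.layout_visualization_formats": "png,svg",  # requested output formats
}

# Assuming `my_func` is an existing TileLang PrimFunc:
# kernel = tilelang.compile(my_func, pass_configs=pass_configs)
```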

* Adds visual layout inference tool docs

* fix lint

* fix lint

* Refactor configurable layout visualization formats

* fix lint

* fix typo

* add some comments

* fix lints

* add some warnings for user

* Moves layout visualization

* Refactors layout visualization pass configuration

Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats.

* Enables multiple layout visualization formats

* Updates layout visualization docs

* Moves layout visualization to analysis

* [Release] Relax constraint of tvm-ffi to compatible version (#1373)

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* [Language] Tilelang LazyJIT Experimental Version (#1337)

* initial step

* modify builder

* scratch version of new frontend

* write some tests

* add many tests

* add typing stub for tir.ir

* remove idents

* minor update

* minor update

* First version of jitv2 (renamed to LazyJIT)

* fix pre-commit error

* minor fix

* fix lint error

* fix lint error

* Fix conditional check for PrimFunc instance

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Builder] Enhance variable name binding and scope management (#1378)

- Improved handling of TVM Var/Buffer names to prevent out-of-scope errors when reusing Python names across different for-frames.
- Added assertions to ensure variables are defined within the correct control flow frame, enhancing error checking and code reliability.

* [Bugfix] make cuda driver api compat with cuda12/13, along with tests (#1379)

* [Fix] typo in cuda attr (#1380)

* [Bugfix] make cuda driver api compat with cuda12/13, along with tests

* fix typo in cudaDevAttr

* [Language V2] Minor fix for complex annotations (#1381)

* [Release] Bump Version into 0.1.7 (#1377)

* Update VERSION to 0.1.7

* Update Python version in distribution scripts to support CPython 3.9 and log output

* [Typing] Enhance compatibility for advanced typing features in Python (#1382)

- Updated `allocate.py` and `annot.py` to improve compatibility with Python 3.9 and later by conditionally importing advanced typing features such as `TypeVarTuple`, `Unpack`, and `ParamSpec`.
- Added fallback imports from `typing_extensions` for environments using earlier Python versions.
- Improved handling of generic alias detection to ensure consistent behavior across different Python versions.

* [Bugfix][Build] Update CMake configuration to remove project root injection for sys.path (#1385)

* [Build] Update CMake configuration for tilelang_cython_wrapper installation

- Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib.
- Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules.
- Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects.

* [Build] Standardize output directories for tilelang libraries

- Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds.
- This change enhances organization and ensures that all build artifacts are located in a unified directory structure.

* [BugFix] Fix split kernel layout bug of GQA decode (#1386)

* [BugFix] Fix split kernel layout bug of GQA decode

* [BugFix] Avoid local with Parallel; use robust fragment instead

* [Enhancement] Add debug output methods for Layout and Fragment classes (#1392)

* [Doc] Update logging docs (#1395)

* [Enhancement] Refactor inflight computing to support dynamic pipeline extents (#1399)

* [Build] Update CMake configuration for tilelang_cython_wrapper installation

- Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib.
- Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules.
- Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects.

* [Build] Standardize output directories for tilelang libraries

- Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds.
- This change enhances organization and ensures that all build artifacts are located in a unified directory structure.

* [Refactor] Update TVM subproject and enhance pipeline loop handling

- Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0.
- Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management.
- Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts.
- Simplified access index calculations and strengthened analyzer constraints for loop bounds.

* [Cleanup] Remove license block and unused includes from inject_pipeline.cc

- Eliminated the Apache license block from the top of the file to streamline the code.
- Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies.

* [Refactor] Enhance transformation pipeline and test execution

- Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization.
- Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation.

* [AMD] Fix 3 bugs when build docker on amd mi3x gpu (#1401)

* [Typo] Fix tilelang link in README.md (#1402)

* [Dependency] Update apache-tvm-ffi version to >=0.1.2 (#1400)

* [Dependency] Update apache-tvm-ffi version to >=0.1.2 in project files

* [Dependency] Update subproject commit for TVM to latest version afc07935

* [Enhancement] Add support for optional step parameter in loop constructs

- Updated loop creation functions to accept an optional step parameter, enhancing flexibility in loop definitions.
- Modified ForFrame implementations to utilize the new step parameter across various loop types including serial, parallel, and pipelined loops.
- Adjusted related vectorization transformations to accommodate the step parameter, ensuring consistent behavior in loop vectorization processes.
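
A minimal sketch of the new optional step, assuming it is exposed as a `step=` keyword on the loop constructors (the exact spelling and position of the argument are assumptions):

```python
import tilelang.language as T

@T.prim_func
def strided_copy(A: T.Tensor((1024,), "float16"), B: T.Tensor((1024,), "float16")):
    with T.Kernel(1, threads=1) as bx:
        # Assumed keyword: visit indices 0, 2, 4, ... instead of every element.
        for i in T.serial(0, 1024, step=2):
            B[bx + i] = A[bx + i]
```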

* lint fix

* [AMD] Enable FA2 fwd on AMD MI300X (#1406)

* enable FA2 on AMD MI300X

* make lint happy

* [TypoFix] fix typo for SM120 (#1408)

* [Doc] Minor documentation update (#1410)

* [Dependency] Add torch-c-dlpack-ext to project requirements (#1403)

* [Dependency] Add torch-c-dlpack-ext to project requirements

* Added torch-c-dlpack-ext to both pyproject.toml and requirements.txt to provide prebuilt torch extensions, which may prevent JIT compilation on first import of TVM FFI.

* [Build] Update manylinux images in project configuration

* Changed the manylinux image for x86_64 from "manylinux2014" to "manylinux_2_28" in both pyproject.toml and the Dockerfile to align with updated standards for compatibility and performance.

* [Build] Update CUDA repository configuration in pyproject.toml

* Changed the package manager command from `yum-config-manager` to `dnf config-manager` for adding the CUDA repository, ensuring compatibility with newer systems.

* fix

* [Build] Update CUDA repository to RHEL 8

* Changed the CUDA repository configuration in both pyproject.toml and the manylinux Dockerfile from RHEL 7 to RHEL 8, ensuring compatibility with newer systems.

* test: run out of space

* use cu130 to reduce size

* upd

* upd comment

* upd

---------

Co-authored-by: Your Name <wenji.yyc@alibaba-inc.com>

* [Dependency] Update TVM subproject to latest commit 2b1ead1a (#1412)

* [Enhancement] Introduce `T.__ldg` (#1414)

* [Enhancement] Add __ldg intrinsic for CUDA read-only cache loads

* Introduced the __ldg intrinsic to enable explicit read-only cached loads from global memory in CUDA.
* Updated the corresponding documentation and added support in both CUDA and HIP code generation.
* Enhanced the Python interface for __ldg to accept BufferLoad and Buffer types, improving usability.
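
A minimal usage sketch, assuming `T.__ldg` wraps a regular `BufferLoad` and is otherwise a drop-in replacement for a plain read (shapes and the kernel body are illustrative):

```python
import tilelang
import tilelang.language as T

@tilelang.jit
def ldg_copy(M: int = 1024, dtype: str = "float16"):
    @T.prim_func
    def main(A: T.Tensor((M,), dtype), B: T.Tensor((M,), dtype)):
        with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:
            for i in T.Parallel(128):
                # Explicit read-only cached load; lowers to __ldg() in CUDA codegen.
                B[bx * 128 + i] = T.__ldg(A[bx * 128 + i])
    return main
```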

* [Enhancement] Update formatting and linting rules in pyproject.toml; minor test adjustment

* Added new formatting rules in pyproject.toml to enforce consistent code style, including hanging indents and argument splitting.
* Updated test_tilelang_language_intrinsics_codegen.py to improve readability by adding a blank line before the main execution block.
* Refactored error messages in builtin.py for better clarity and consistency, ensuring proper formatting in function definitions and raising ValueErrors.

* lint fix

* [Enhancement] Improve vectorization invariant check (#1398)

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Add some vectorize tests and comments

* [Lint] Phaseout Yapf format and embrace ruff format (#1417)

* [Atomic] Use ptr for atomicAdd dst instead of reference (#1425)

* [Enhancement] Update AtomicAdd function signature to accept pointer to destination

* Modified AtomicAdd in CUDA to take a pointer instead of a reference for the destination argument.
* Updated related code in atomicadd_vectorize.cc to ensure compatibility with the new signature.
* Adjusted Python interface in atomic.py to pass the destination by pointer, aligning with device function requirements.

* [Enhancement] Refactor AtomicAddRet function signature to accept pointer

* Updated AtomicAddRet in both CUDA and HIP to take a pointer instead of a reference for the address argument, improving consistency with the AtomicAdd function.
* Adjusted the implementation to ensure proper reinterpretation of the address type for atomic operations.

* lint fix

* [Enhancement] Refactor AtomicAddNode::MakeSIMTLoop to use destination pointer

* Updated the MakeSIMTLoop function to build a pointer to the destination element using tvm_access_ptr instead of loading the destination value directly.
* Simplified the handling of source and destination predicates, improving clarity and maintainability of the code.
* Ensured compatibility with the new pointer-based approach for atomic operations.

* lint fix

* test fix

* lint fix

* [CUDA] Add read-only parameter annotation for CUDA codegen (#1416)

* [Enhancement] Add read-only parameter annotation for CUDA codegen

* Introduced the `AnnotateReadOnlyParams` transformation to annotate read-only handle parameters in PrimFuncs, enabling the generation of `const` qualifiers in CUDA codegen.
* Updated `PrintFunctionSignature` and `AddFunction` methods to utilize the new attribute `tl.readonly_param_indices`, enhancing performance by allowing read-only cache loads.
* Modified the optimization pipeline to include the new annotation step, improving the overall efficiency of the code generation process.

* lint fix

* [Dependency] Update apache-tvm-ffi version to >=0.1.3

* Updated the version of apache-tvm-ffi in pyproject.toml, requirements.txt, and requirements-dev.txt to ensure compatibility with the latest features and fixes.
* Made adjustments in CUDA and HIP template files to use `const` qualifiers for global pointer parameters, enhancing code safety and clarity.

* lint fix

* [Enhancement] Refactor ReadWriteMarker for improved parameter handling

* Updated the ReadWriteMarker class to accept a set of parameter or data variables, enhancing its ability to track written variables.
* Introduced a new method, ResolveDataVarFromPtrArg, to resolve underlying buffer data from pointer-like arguments, improving accuracy in identifying written variables.
* Modified the MarkReadOnlyParams function to gather handle parameters and their corresponding buffer data variables, streamlining the process of determining read-only parameters.
* Enhanced the logic for identifying written variables to account for aliased data variables, ensuring comprehensive tracking of modifications.

* lint fix

* Update tma_load function to use const qualifier for global memory pointer

* Changed the parameter type of gmem_ptr in the tma_load function from void* to void const* to enhance type safety and clarity in memory operations.
* This modification ensures that the function correctly handles read-only global memory pointers, aligning with best practices in CUDA programming.

* Remove commented-out code and reorder transformations in OptimizeForTarget function for clarity

* Refactor buffer marking logic in annotate_read_only_params.cc to improve accuracy in identifying written variables. Update OptimizeForTarget function to reorder transformations for better clarity.

* [Refactor] Phase out the primitives folder since its design has been merged into tileop (#1429)

* Phase out primitives

* revert changes

* Refactor GemmWarpPolicy method signature for clarity

Updated the `from_warp_partition` method in the `GemmWarpPolicy` class to return the type `GemmWarpPolicy` instead of a string, enhancing type safety and clarity in the codebase. Removed an unnecessary blank line for improved readability.

* fix

* [CI]: Bump actions/upload-artifact from 5 to 6 (#1431)

Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 5 to 6.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [CI]: Bump actions/download-artifact from 6 to 7 (#1432)

Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 6 to 7.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Bugfix] Convey  `compile_flags` to ffi compilation path with pass_configs (#1434)

* [Enhancement] Add device compile flags support in pass configuration

* Introduced `kDeviceCompileFlags` option in the pass configuration to allow additional device compiler flags for CUDA compilation.
* Updated the `tilelang_callback_cuda_compile` function to merge extra flags from the pass configuration, enhancing flexibility in compiler options.
* Modified the `JITKernel` class to handle device compile flags appropriately, ensuring they are included during compilation.
* Documented the new pass configuration key for clarity on usage and expected input formats.
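
A sketch of how the extra flags might be supplied, assuming the pass-config key behind `kDeviceCompileFlags` is spelled `"tl.device_compile_flags"` (hypothetical) and that `compile_flags` is the JIT keyword referenced in this series:

```python
import tilelang
import tilelang.language as T

# Hypothetical key string for kDeviceCompileFlags; the real name may differ.
pass_configs = {"tl.device_compile_flags": ["-lineinfo", "--use_fast_math"]}

@tilelang.jit(pass_configs=pass_configs, compile_flags=["-O3"])
def add_one(N: int = 1024, dtype: str = "float32"):
    @T.prim_func
    def main(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):
        with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
            for i in T.Parallel(128):
                B[bx * 128 + i] = A[bx * 128 + i] + 1.0
    return main
```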

* lint fix

* [Refactor] Simplify compile_flags handling in JIT functions

* Removed redundant string check for compile_flags in the compile, jit, and lazy_jit functions, ensuring compile_flags is consistently treated as a list.
* Updated the JITKernel class to handle compile_flags as a list when a string is provided, enhancing code clarity and maintainability.

* lint fix

* fix

* [Enhancement] Improve buffer usage tracking in MakePackedAPI (#1435)

* Added detailed logging for data and shape variable parameters during buffer usage detection in the MakePackedAPI function.
* Refactored the UsedBufferDetector to differentiate between used parameters by data and shape variables, enhancing clarity in buffer management.
* Updated logic to ensure minimal carrier buffers are selected for shape symbols, improving the efficiency of parameter handling.

* [Enhancement] Improve InjectAssumes logic and make assumes work after SplitHostDevice (#1405)

* [Refactor] Refactor InjectAssumes logic and make assumes work after SplitHostDevice

* address comments

* fix

* fix submodule

* fix

* fix 3rdparty

* [Enhancement] Include PrimFunc name in memory cache logs for better debugging (#1437)

* Added the `get_prim_func_name` utility to extract human-readable function names from TVM PrimFuncs.
* Updated memory cache logging in `AutoTuner` and `KernelCache` classes to include the kernel name, improving clarity during cache hits.
* Enhanced debug logging to provide more informative messages when checking disk cache for kernels.

* [CI] Update lint dependencies and fix lint on trunk (#1433)

* [CI] Update pre-commit hooks

* [Lint] Pass correct `exclude-header-filter` to `clang-tidy`

* [Lint] Download latest `run-clang-tidy` script

* [CI] Show compile commands

* [CI] Add output grouping to GHA

* [Lint] Re-order pre-commit hooks

* [Enhancement] Refactor vectorization checks in loop_vectorize (#1440)

* Introduced a new function, IsExprInvariantInVectorBoundary, to encapsulate the logic for checking if an expression is invariant within vector boundaries, improving code clarity and reusability.
* Updated the existing vectorization logic to utilize this new function, streamlining the process of determining vectorization feasibility based on boundary conditions.
* Enhanced comments for better understanding of the vectorization criteria and mathematical rationale behind the checks.

* Enhance vectorized conversion support (#1438)

* [Feature] Support region as input of T.cumsum (#1426)

* [Feature] Support region as input of T.cumsum

- Extend T.cumsum to accept BufferRegion and BufferLoad inputs in addition to Buffer
- This enables operations on buffer slices/regions like:
  T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0)
- Update cumsum_fragment to handle region inputs properly
- Add comprehensive tests for 1D and 2D region inputs including normal and reverse modes

Fixes #879
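
A minimal sketch of the region form under these assumptions (shared-memory staging, divisible sizes, and in-place accumulation are illustrative choices, not requirements):

```python
import tilelang
import tilelang.language as T

@tilelang.jit
def chunked_cumsum(N: int = 256, chunk_size: int = 64, dtype: str = "float32"):
    num_chunks = N // chunk_size

    @T.prim_func
    def main(X: T.Tensor((N,), dtype), Y: T.Tensor((N,), dtype)):
        with T.Kernel(1, threads=128):
            X_shared = T.alloc_shared((N,), dtype)
            T.copy(X, X_shared)
            for i in T.serial(num_chunks):
                # Cumulative sum over a buffer slice (BufferRegion input).
                T.cumsum(X_shared[i * chunk_size:(i + 1) * chunk_size], dim=0)
            T.copy(X_shared, Y)
    return main
```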

* Fix formatting and add docstring for cumsum_fragment

- Add comprehensive docstring for cumsum_fragment function
- Format code according to ruff style guidelines

* Fix CodeRabbit review issues

- Fix negative dimension bounds check (dim < -len(shape) instead of dim <= -len(shape))
- Add src/dst shape compatibility validation for out-of-place cumsum
- Update copy() type annotation to accept BufferRegion as dst parameter
- Fix test in-place mutation issues by using out-of-place cumsum operations
- Add non-divisible size test cases for tail region coverage

* Fix out-of-bounds access in region tests

- Add bounds clamping using T.min() for chunk_end calculations
- Prevents accessing beyond tensor bounds for non-divisible sizes
- Matches reference implementation behavior
- Fixes both 1D and 2D region test cases

* Fix region test: use simple slice expressions instead of T.min()

- Remove T.min() which cannot be used directly in slice indices
- Use chunk_start + chunk_size form instead
- Rely on system's automatic bounds checking for non-divisible sizes
- Update comments to reflect this approach

* Fix cumsum region: use region extents in lowering and update tests for shared memory

* Simplify fragment scope check using is_fragment()

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* [Fix] Fix analyzer bind conflicting (#1446)

* [Refactor] Reduce direct dependency on PyTorch due to its limited type support (#1444)

* [Enhancement] Update KernelParam to use tvm.DataType directly and add torch_dtype conversion method

- Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions.
- Added a new method, torch_dtype, to convert tvm.DataType back to torch.dtype for tensor creation.
- Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization.

* [Enhancement] Refactor CUDA type handling and add support for FP4 and FP8 types

- Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively.
- Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64).
- Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions.
- Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing.
- Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch.
- Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively.

* lint fix

* Remove unnecessary logging statements from CUDA code generation and delete obsolete matrix multiplication test file.

* [Enhancement] Add support for FP4 and FP8 types in CUDA code generation

- Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types.
- Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch.
- Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations.
- Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations.
- Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows.

---------

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>

* [Refactor] Use `pytest.mark.parametrize` to speed up parallel testing (#1447)

* Refactor GEMM tests to use parameterized pytest fixtures

- Converted multiple test cases for GEMM operations in `test_tilelang_tilelibrary_gemm_sp.py` to use `pytest.mark.parametrize` for better maintainability and readability.
- Similar refactoring applied to `test_tilelang_tilelibrary_gemm_sp_v2.py`, consolidating test cases for `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` into parameterized tests.
- This change reduces code duplication and enhances the clarity of test configurations.
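
An illustrative sketch of the parametrized style; the shapes, flags, and the signature of `run_gemm_ss` (the helper named above) are placeholders, not the actual test matrix:

```python
import pytest

@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B",
    [
        (512, 512, 512, False, False),
        (512, 512, 512, False, True),
        (1024, 1024, 1024, True, False),
    ],
)
def test_gemm_ss(M, N, K, trans_A, trans_B):
    # Each tuple becomes an independent test case that pytest-xdist can
    # schedule in parallel; run_gemm_ss is assumed to live in the test module.
    run_gemm_ss(M, N, K, trans_A, trans_B)
```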

* Update testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* [Docs] Improve installation instructions for developers (#1450)

* [Feat] Integrate Z3 in TVM Arith Analyzer (#1367)

* [Bugfix] Improve autotune from elementwise_add function in examples (#1445)

* Remove JIT decorator from elementwise_add function in examples

* fix kernel compilation without autotune

* Refactor main function to accept parameters and update tests for autotune option

* Refactor autotune test function for modern style

* [Language] Introduce `T.annotate_restrict_buffers` (#1428)

* [Enhancement] Introduce non-restrict parameter support in code generation

- Added a new PrimFunc-level attribute `tl.non_restrict_params` to specify handle Vars that should not be marked with the restrict qualifier during code generation.
- Updated `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to handle non-restrict parameters, ensuring proper treatment of overlapping buffer aliases.
- Implemented a new annotation function `annotate_restrict_buffers` to facilitate the marking of buffer parameters as non-restrict.
- Enhanced the `SplitHostDevice` transformation to propagate non-restrict parameters from host to device functions.
- Added a new transform function `HoistNonRestrictParams` to manage non-restrict parameters effectively.
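
A sketch of the intended use, assuming `T.annotate_restrict_buffers` is called inside the kernel body on the parameters that may alias (the exact call site and signature are assumptions):

```python
import tilelang
import tilelang.language as T

@tilelang.jit
def inplace_scale(N: int = 1024, dtype: str = "float32"):
    @T.prim_func
    def main(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):
        with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
            # Assumed call form: A and B may overlap, so codegen must not
            # emit the restrict qualifier for their handles.
            T.annotate_restrict_buffers(A, B)
            for i in T.Parallel(128):
                B[bx * 128 + i] = A[bx * 128 + i] * 2.0
    return main
```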

* [Enhancement] Improve HoistNonRestrictParams transformation

- Updated the HoistNonRestrictParams function to recursively collect all `tl.non_restrict_params` annotations from nested blocks, enhancing flexibility in annotation placement.
- Introduced a new NonRestrictCollector class to manage the collection and deduplication of non-restrict parameters.
- Modified the SplitHostDevice transformation to remove the non-restrict attribute from the host-side PrimFunc after propagation to device kernels.
- Adjusted the LowerAndLegalize function to directly apply the HoistNonRestrictParams transformation without exception handling, streamlining the process.

* [Refactor] Simplify non-restrict parameter handling in code generation

- Removed unnecessary normalization logic and associated data structures from `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP`.
- Streamlined the handling of non-restrict parameters by directly inserting them into the `non_restrict` set, improving code clarity and maintainability.
- Updated conditional checks to eliminate redundant checks against normalized names, enhancing performance and readability.

* [Dependency] Update TVM subproject to latest commit 68aa8461

- Updated the TVM subproject to the latest commit, ensuring compatibility with recent changes and improvements.
- Refactored non-restrict parameter handling in `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to enhance code clarity and maintainability.
- Adjusted the `SplitHostDevice` transformation to streamline the propagation of non-restrict parameters.

* fix

* [Analyzer] Require loop extent > 0 when entering loop (#1451)

* Update ROCm CI to Nightly-ROCm-7.1 (#1449)

* [Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.
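
Both spellings denote the same underlying data type; a minimal sketch of the two equivalent forms, under the assumption that both the string and the `T.float16` dtype object are accepted by `T.Tensor`:

```python
import tilelang.language as T

@T.prim_func
def copy_str(A: T.Tensor((128,), "float16"), B: T.Tensor((128,), "float16")):
    with T.Kernel(1, threads=128):
        for i in T.Parallel(128):
            B[i] = A[i]

@T.prim_func
def copy_obj(A: T.Tensor((128,), T.float16), B: T.Tensor((128,), T.float16)):
    with T.Kernel(1, threads=128):
        for i in T.Parallel(128):
            B[i] = A[i]
```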

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.

* [Issue Template] Enable blank issues in GitHub issue template (#1453)

* [CI] Moved the clang-tidy step to after pip install (#1456)

* [Bug] Fix tvm build script when patchelf is not found (#1459)

* [Analyzer] Fix floordiv & floormod bug in z3 prover (#1458)

* fix floordiv & floormod in z3 prover

* fix lint error

* [Cache] Rename sparse compress cache directory (#1460)

* Enhance cache directory structure by including version information in sparse.py to ensure separate caches for different versions.

* Fix formatting in sparse.py by adding a newline for improved readability and consistency.

* [Language]Adds a random number generation capability through curand_kernel (#1461)

* add curand.{curand_init, curand}

* run format.sh

* add default value for curand_init & add test for curand
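
A speculative sketch only: the commit adds `curand_init` and `curand` wrappers over `curand_kernel`, but the namespace, default-seed behaviour, and return dtype shown here are assumptions.

```python
import tilelang
import tilelang.language as T

@tilelang.jit
def random_fill(N: int = 1024):
    @T.prim_func
    def main(Out: T.Tensor((N,), "uint32")):
        with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
            # Assumed spellings: curand_init seeds a per-thread Philox state
            # (with a default seed per this commit), and curand draws the next
            # 32-bit value from it.
            T.curand_init()
            for i in T.Parallel(128):
                Out[bx * 128 + i] = T.curand()
    return main
```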

* Update testing/python/language/test_rand.py

Remove unused thread binding

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* remove unused library

* enable tilelang cache for testing

* run format.sh

* Revert "run format.sh"

This reverts commit 5afaff782f31cdf653e2c45b469da8dead228b8a.

* Revert "enable tilelang cache for testing"

This reverts commit c277a43e77938bd88d47a108dd1bd65734d4a1ae.

* Revert "remove unused library"

This reverts commit 568ad20611f039380113937fd131151a2bffd801.

* run format.sh

* ensure FreshName for __philox_state

* ensure FreshName for __philox_state
…
chengyupku added a commit to tile-ai/tilescale that referenced this pull request Feb 6, 2026
* Enhance threadblock swizzle templates with default offset parameter and streamline parser.py for better readability

* [Cache] Rename sparse compress cache directory

* Temporarily exclude sink tests from non-distributed example tests in CI to address timeout issues

* [DeepEP] Move deepep benchmark to example and allow compatible with new version DeepEP

* [Feat] Enhance `T.st` to support intra-node store to peer's symm memory

* use strided loop to simplify get_dispatch a bit

* [Feat] Support warp reduce operators

* draft notify dispatch

* rename and refactor `T.barrier/sync_blocks`

* fix prev typo

* [Feat] Add `get_device_tensor` function and related test

* support elect_one_sync() and add test

* draft dispatch

* support ld, st, warp_sync, continue and add test

* support warp vote and add test

* support device-side wait_ne

* refactor T.wait_* and refine dispatch test logic

* intra-node dispatch test passed

* draft combine

* support message-only debug print

* intra-node combine test passed

* unify dispatch, migrate topk_idx to u64, support cached dispatch

* Refactor to pre-alloc buffers and expose interface, add benchmark

* remove redundant test

* update doc

* use int4 vectorization for dispatch

* use comm_stream for comm kernels

* optimize dispatch perf via skipping tensor validation

* add dispatch benchmark result

* make rank as an argument of the kernel

* use cuda postproc for vectorization in combine

* support int4 ld/st ptx in cuda template

* [Feat] Support auto vectorization for ld/st to optimize combine to surpass deepep

* lint

* upd doc

* make ci happy

* fix review issues

* fix import error

* Add DeepEP submodule and installation script for CI

* fix ci bug

* [Sync] Merge mainstream TileLang TVM-FFI features into TileScale (#47)

* [Example] Add GQA decoding kernel with varlen page table (#1265)

* [Example] Add page table for gqa decode

* [Example] Page table for varlen decoding

* [Lint]

* [Refactor] Remove redundant code

* [Lint]

* [Lint]

* [Lint]

* [Refactor] add support for numpy dtype conversion (#1255)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files

* [EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max for numerical stability (#1148)

* Keep the max of all blocks seen in scores_max for stability

* ruff formatting

* [Docs] Improve Installation Guide (#1270)

* [Docs] Improve installation guide

* address comments

* [Enhancement] Keep max score attention across blocks in FlashAttention for better numerical stability (#1269)

* Implement max score retention across blocks in FlashAttention for improved stability

* fix manual pipeline parameters

* Update examples/flash_attention/example_gqa_fwd_varlen.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix typo

* more

* fix a previous typo

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* [Bugfix] Fix multiple cg defination when using T.sync_grid (#1272)

* [Minor] Remove from __future__ import annotations for python 3.8 (#1273)

* [BugFix] Adding extra parameters into autotune hashkey (#1274)

* [BugFix] Adding extra parameters into autotune hashkey

* lint

* None check

* check serializable

* Fix various issues under `int64_t` static and dynamic shape. (#1218)

* Fix various issues under int64_t static and dynamic shape.

* Resolve reviewed issues.

* Add unit test.

* fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* Bug fix for Gated Delta Net benchmark script (#1267)

* fix argument order for fla chunk_gated_delta_rule_fwd_h

* explicit import assert_similar from utils

* rename utils module to avoid name clash

* set store_final_state and save_new_value to True

* fix

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* [Bugfix] Minor fix for some cases (#1278)

* [Language] Add shape check in `T.view/reshape` (#1277)

* [Language] Add shape check in T.view/reshape

* address comments

* [FFI] Use tvm ffi as the default execution backend (#1259)

* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.

* [Refactor] Clean up imports and improve code formatting

* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.

* Update execution backend options and improve resolution logic

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target.
- Updated documentation to reflect changes in execution backend options and their defaults.

* lint fix

* fix

* Enhance argument handling in CUDA and HIP runtime modules

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.

* lint fix

* lint fix

* lint fix

* lint fix

* minor fix

* fix

* recover check

* Refactor argument binding and validation in `arg_binder.cc`

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.

* lint fix

* stride fix

* minor fix

* fix

* lint fix

* lint fix

* Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.

* check with symbolic

* support null ptr

* Update CMakeLists and lower.py for code generation and subproject status

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility.
- Marked the TVM subproject as dirty to indicate local modifications.

* lint fix

* Update comments for clarity in quickstart.py

* [Bugfix] Supply missing `T.print` for bool type (#1279)

* fix for bool dtype

* lint fix

* fix

* ci fix

* [Fix] Fix memory leak bug (#1281)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files

* fix memory leak bug

* fix lint error

* add comments

* fix lint error

* remove duplicated code, because tilelang doesn't depend on the deprecated one

* [Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283)

- Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options.
- Introduced handling for fast math and PTXAS options based on the provided pass configuration.
- Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings.
- Refactored NVCC command construction to use a dedicated function for better clarity and maintainability.

* Fix the bug in issue #1266 (#1284)

Co-authored-by: cheeryBloosm <liu_yu_hao@126.com>

* [Language][UX] Nested loop checker in pre-lowering stage (#1288)

* [Language][UX] Nested loop checker in pre-lowering stage

* rename

* comment

* address comments

* [Compatibility] Support CUDA 11.3 (#1290)

* [Feat] Add support for using `T.Tensor(n * 2 + 1)` in function annotation (#1285)

* [Feature] Add support for A: T.Tensor(n + 1) and A: T.Tensor(2*n)
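
A minimal sketch of an affine symbolic shape in the annotation, assuming `T.symbolic` is the helper used to declare the dynamic symbol (names and the bounds guard are illustrative):

```python
import tilelang
import tilelang.language as T

@tilelang.jit
def window_copy(dtype: str = "float16"):
    n = T.symbolic("n")  # assumed helper for a dynamic dimension symbol

    @T.prim_func
    def main(A: T.Tensor((2 * n + 1,), dtype), B: T.Tensor((2 * n + 1,), dtype)):
        with T.Kernel(T.ceildiv(2 * n + 1, 128), threads=128) as bx:
            for i in T.Parallel(128):
                if bx * 128 + i < 2 * n + 1:
                    B[bx * 128 + i] = A[bx * 128 + i]
    return main
```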

* issue fix

* fix

* fix

* decrease nproc for debugging

---------

Co-authored-by: Lei Wang <leiwang1999@outlook.com>

* [Feat] add support for passing reference in T.Var annotation (#1291)

* [Enhancement] Shared Memory Size Can be Dynamic (#1294)

* bugfix

* lint fix

* test

* lint fix

* increase procs

* recover

* [Fix] Remove unused let_bindings_ in CodeGenC to fix #1300 (#1305)

* [Feat] add missing support of uint32x2

* [Feat] Add `T.Ref` annotation and tests

* fix lint error

* minor update for error message on twice decl

* Remove unused let_bindings_ in CodeGenC to fix #1300

* [Bugfix] Fallback to the old AtomicAdd implementation for legacy architectures (#1306)

* [Fix] Fix frame scope error in T.macro (#1308)

* [Fix] Fix #1307 by adding macro inside function

* fix lint error

* add comments and fix lint error

* Remove debug print from enter_frame method

Removed debug print statement from enter_frame method.

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [WIP] support more dtypes for tcgen05 (#1229)

support ld with pack for fp32 dtype

add dump

add template expand

remove unused dtype and change to rebased apis

* Improve memory access safety and `T.assume` handling (#1292)

* Improve memory access safety and T.assume handling

* Improve memory access safety and T.assume handling

* bugfix

* lint fix

* bugfix

* bugfix

* refactor legalize safe memory access pass

---------

Co-authored-by: Lei Wang <leiwang1999@outlook.com>

* [Bugfix] Fix autotune cache (#1315)

* [Refactor] Backup Analyzer to get the appropriate arith informations (#1311)

* [Refactor] Update Vectorization Functions to Accept Analyzer Parameter

- Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization.
- Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness.
- Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities.

* [Fix] Corrected PostOrderVisit call in loop_vectorize.cc

- Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis.

* fix

* lint fix

* fix

* Revert "[WIP] support more dtypes for tcgen05 (#1229)" (#1323)

This reverts commit 0d101c110f74ebf2ef8c11a5ece9dfb314b48baa.

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>

* [CI]: Bump actions/checkout from 5 to 6 (#1319)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [CI]: Bump pypa/cibuildwheel from 3.2 to 3.3 (#1318)

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Installation] Fix building using customized TVM path (#1326)

* [Release] Allow developer with write permission to trigger wheel release (#1322)

* [Feat] Support warp reduce (#1316)

* [Feat] Support warp reduce

* lint

* add test

* lint

* [Enhancement] Support more dtype in `T.print` (#1329)

* [Enhancement] Support more dtype in `T.print`

* upd

* upd

* [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape (#1321)

* [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape

* remove debug lines

* remove rubbish

* Fix decorator syntax for atomic_different_memory_orders_program

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Fix] fix wrong uint narrowing bug in tvm in #1310 (#1320)

* [Refactor] Disable strided buffer load inside tvm (#1301) (#1332)

* [Refactor] Moving `NormalizeToBufferRegion` and `MakeAccessPtrFromRegion` to utils (#1333)

* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* [Fix] Fix bug copying from or to local buffer (#1304) (#1324)

* [Fix] fix copy from or to local buffer (#1304)

* fix lint error

* minor fix testing script

* [Language][UX] Semantic check for parallel fragment access (#1338)

* Add unit tests for T.assume (#1341)

* Add test for T.assume

* Add unit test for T.assume

* Add unit test for T.assume

* Add unit tests for T.assume

* Remove debug print for kernel source

Remove print statement for kernel source in tests.

* Update test_tilelang_language_assume.py

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Feat] Extend LegalizeNegativeIndex to support buffer store stmts (#1339)

This commit enhances the LegalizeNegativeIndex transformation pass to handle
both buffer load and store operations with negative indices and adds some
test cases.
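
A minimal sketch of what the pass now legalizes, assuming Python-style negative indices on both the load and the store side (the rewrite shown in the comment is the expected outcome, not verified output):

```python
import tilelang.language as T

@T.prim_func
def tail_update(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
    with T.Kernel(1, threads=1):
        # Negative indices on the store (B[-1]) and the load (A[-2]) are
        # rewritten by LegalizeNegativeIndex into extent-relative offsets,
        # i.e. B[127] = A[126].
        B[-1] = A[-2]
```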

* [Refactor] Phaseout vmap for Tile Operators (#1334)

* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>

* [Enhancement] add more dtype and fix mma.ws for fp16 for tcgen05 (#1327)

* feat: add fp8 variants; add placeholder for fp6/fp4 in meta

support ld with pack for fp32 dtype

add dump

add template expand

remove unused dtype and change to rebased apis

* fix: when atom-m!=128, enable_ws

* fix: typo in tcgen05 meta; dispatch in gemm sm100

* [Refactor] Enhance CopyNode's IterVar Creation and Range Handling (#1346)

* [Refactor] Enhance CopyNode's IterVar Creation and Range Handling

This commit refines the `MakeIterVars` method in `CopyNode` to select base ranges based on memory scope levels, ensuring that the chosen ranges are not smaller than the original source ranges. Additionally, it updates the Python `copy` function to clarify range handling, including broadcasting logic and extent alignment. These changes improve the robustness and clarity of the copy operation's implementation.

* test fix

* [Fix] Fix missing `not` rewrite in frontend (#1348)

* [Enhancement] Add support for k_pack in gemm_mfma (#1344)

* add support for k_pack

* support benchmark on ROCm

* fix format

* Add sparse fine-tuning kernel for deepseek sparse attention to example (#1296)

* [EXAMPLE] add example for dsa sparse finetuning

* [Refactor]

* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352)

* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder

This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase.

* [Enhancement] Update matmul kernel and optimize argument binding

This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code.

* lint fix

* [Enhancement] Add tensor checks documentation and improve argument binding assertions

This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code.

* [Enhancement] Update .gitignore and refine matmul kernel for improved performance

This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users.

* lint fix

* lint fix

* [Refactor] Simplify tensor_null_test function and remove ptr_null_test

This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations.

* lint fix

* fix

* [Refactor] Simplify index sign state handling in LegalizeNegativeIndex (#1354)

This commit refines the logic for determining the sign state of indices in the LegalizeNegativeIndex transformation. It prioritizes vector patterns, specifically Ramp and Broadcast nodes, to avoid compile-time lane queries. The handling of scalar indices is also streamlined, ensuring clearer diagnostics when non-negativity cannot be proven. These changes enhance the robustness and clarity of index handling in the transformation pass.

* [Enhancement] Improve error handling and assertion messages across runtime and argument binding (#1356)

This commit enhances the error handling mechanisms in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in the CodeGenCHost and ArgBinder. It includes structured packed error messages for various conditions, improving clarity in diagnostics. Additionally, the CMake configuration is updated to always include necessary runtime helpers, ensuring consistent error reporting. The changes aim to provide clearer feedback during runtime errors and improve the overall robustness of the argument binding process.

* [Bugfix] Disable floordiv optimization due to integer overflow risk (#1355)

* disable overflow-prone floordiv optimization in lower_intrin.cc

* disable overflow-prone floordiv optimization in lower_intrin.cc

* [Bugfix] Fix the jit_kernel issue (#1357)

* [Bugfix] Fix the jit_kernel issue

* Update README.md

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Refactor] Update Fragment Indexing in ParallelOpNode's InferLayout Method (#1359)

This commit refines the Fragment creation process in the InferLayout method of ParallelOpNode. It removes the unnecessary forward_index array and utilizes default fragment indexing for consistency with other operations. Additionally, it binds the thread range to enhance comparability across different operations.

* [Analysis] Enhance NestedLoopChecker with tile op cases (#1358)

* [Analysis] Enhance NestedLoopChecker with tile op cases

* fix tileop issue

* [Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056)

* [misc] add a cpp side wrapper for gemm_sp_py

* [misc] typing

* [IR] bind GemmSPWarpPolicy

* [chore] add wrapper code

* [IR] fix GemmSPWarpPolicy

* [codegen] apply ptxas instructions

* [intrinsic] add typical (unused) mma layout

* [template] add uint16 debug func

* [intrinsic] add b matrix layout

* [gemm_sp] enable fp16/bf16 on sm8x

* [layout] refactor fp16/bf16 layout

* [gemm_sp] enable int8

* [chore] update test case dtype

* [gemm_sp] enable fp32

* [layout] refactor layouts

* [intrinsic] enable ldmatrix for mat A

* [layout] enable ldsm for matrix b

* [layout] add ldmatrix for fp32 and fp8

* [chore] refine

* [chore] refactor

* [chore] add fp8 refactor

* [chore] refactor

* [chore] add remove negative zero util

* [example] add a custom compress kernel

* [chore] minor update

* [test] refactor gemm_sp test

* [refactor] make metadata layout func

* [example] add option for using cutlass layout

* [doc] add a gemm_sp doc

* [doc] minor polish

* [chore] remove unused

* [bugfix] fix non replicate b case

* [test] refactor

* [chore] add a check

* [bugfix] fix util bug

* [wip] init a new test case for v2

* [chore] minor refactor

* [chore] minor update

* [bugfix] enable 16bit rs

* [language] enable rs

* [language] enable gemm_sp_sr

* [language] enable gemm_sp_rr

* [test] enable more tests

* [tvm] update ffi binding

* [chore] remove print

* [chore] fix benchmark script

* [lint] precommit lint

* [chore] apply feedback

* [test] use arch 8.0

* [chore] rollback ::ordered_metadata for backward compatibility

* [bugfix] fix capitalized

* [example] keep gemm_sp on hopper

* [test] fix no fp8 normal kernel

* [test] reduce matmul size to satisfy accum error

* [test] use cal_diff for assertion

* [bugfix] expand float8 type

* [lib] add make_int4 for short type

* [language] add transpose E

* [bugfix] fix wrong var

* [format] format

* [chore] refactor binding

* [chore] fix wrong passing var

* [Bugfix] Update TIR registration for GemmSPPy to use tile operation (#1361)

* [Enhancement] Implement dynamic unroll factor in CUDA code generation (#1360)

* [Enhancement] Implement dynamic unroll factor in CUDA code generation

This commit introduces support for specifying a dynamic unroll factor in the CUDA code generation. The `unroll_factor` map is added to store unroll factors for loop variables, allowing for more flexible and optimized loop unrolling. Additionally, the `unroll` function is integrated into the loop language, enabling users to define unroll factors directly in their code. This enhancement improves performance by allowing tailored unrolling strategies based on specific loop characteristics.

* lint fix

* [Bugfix] Correct initialization of non-zero counters in custom compress kernel and update TIR registration for gemm_sp_py to use the correct tile operation

* [CI] [pre-commit.ci] autoupdate (#1362)

updates:
- [github.com/pre-commit/mirrors-clang-format: v21.1.2 → v21.1.6](https://github.com/pre-commit/mirrors-clang-format/compare/v21.1.2...v21.1.6)
- [github.com/astral-sh/ruff-pre-commit: v0.14.3 → v0.14.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.14.3...v0.14.7)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Bugfix] Remove debug print in PyStmtFunctionVisitor (#1363)

* [Debug] Always include line info in NVCC command for improved profiling and mapping (#1364)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py (#1365)

* [Enhancement] Add DISABLE_CACHE environment variables (#1368)

* [Refactor]: Remove useless include in atomicadd_vectorize.h (#1371)

* [Refactor] Generalize fp8 process (#1372)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.
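
The two new dtype names appear to match the ones PyTorch exposes, so the change can be sanity-checked from the Python side. A minimal illustration (the tensor shapes and the cast are assumptions for demonstration; only the dtype names come from the bullets above):

```python
import torch

# Both float8 variants named above exist as torch dtypes in recent PyTorch releases.
print(torch.float8_e4m3fn, torch.float8_e5m2fnuz)

# Casting down to float8; real kernels would feed tensors like this into the GEMM path.
x = torch.randn(16, 16, dtype=torch.float32)
x_fp8 = x.to(torch.float8_e4m3fn)
print(x_fp8.dtype)  # torch.float8_e4m3fn
```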

* lint fix

* [Layout] Enhance Free Layout Inference (#1375)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.

* lint fix

* [Enhancement] Add injective layout detection and exception handling

- Introduced `DetectInjective` method in `FragmentNode` to check for injective layouts.
- Added `LoopLayoutInjectiveException` to handle errors related to non-injective layouts.
- Updated `InferLayout` methods in `ParallelOpNode` to utilize injective checks and log relevant information.
- Refactored layout inference queue management to use `std::deque` for improved performance and added prioritization logic for buffer layouts.

* remove debug print

* remove debug print

* remove debug print

* minor layout fix

* fix for T.view

* [Enhancement] Improve injective layout detection in FragmentNode

- Updated the `DetectInjective` method to handle symbolic dimensions more effectively by introducing a mechanism to collect symbolic shapes and adjust the detection level accordingly.
- Added logging for cases where the layout detection falls back to NoCheck due to symbolic dimensions.
- Minor update to the test file to include the tilelang testing module.

* [Refactor] Simplify layout inference for bulk copy operations

- Removed unnecessary conditions for bulk load/store operations in the layout inference logic.
- Streamlined the handling of layout application for bulk copy instances to enhance clarity and maintainability.

* remove debug print

* [Enhancement] Introduce layout-related exceptions and improve error handling

- Added `LayoutConflictException` and `LoopLayoutInjectiveException` classes for better exception management in layout operations.
- Updated `InferLayout` method in `ParallelOpNode` to throw `LoopLayoutInjectiveException` with detailed error information when injective layout checks fail.
- Removed redundant exception class definitions from `parallel.h` to streamline code organization.

* [Enhancement] Introduce buffer var lca analysis for pass plan buffer allocations (#1376)

* Update submodule TVM to latest commit and add PlanAndUpdateBufferAllocationLocation function to transform module

- Updated the TVM submodule to commit 3a32b763.
- Added a new function `PlanAndUpdateBufferAllocationLocation` in the transform module to facilitate buffer allocation planning within PrimFuncs.

* Refactor buffer allocation code for improved readability and consistency

- Updated formatting and spacing in `plan_update_buffer_allocation_location.cc` for better code clarity.
- Standardized the use of pointer and reference syntax across various class methods.
- Enhanced comments for better understanding of buffer allocation logic.
- Removed unnecessary lines and improved overall code structure.

* Refactor buffer allocation checks for improved clarity

- Replaced size checks with empty checks for `ffi::Array<Buffer>` in `plan_update_buffer_allocation_location.cc` to enhance code readability.
- Updated conditions in multiple methods to use `empty()` instead of comparing size to zero, streamlining the logic.

* [Tool] Provide layout visualization tool (#1353)

* Provide layout visualization tool

Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation.

This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance.

The visualization can be enabled via a pass config option.

* format

* add layout visual example

* Adds vis extra with matplotlib dependency

* refactor pass config name

* fix lint

* Enables configurable layout visualization formats

Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option.

This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs.

It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false".

* Adds visual layout inference tool docs

* fix lint

* fix lint

* Refactor configurable layout visualization formats

* fix lint

* fix typo

* add some comments

* fix lints

* add some warnings for user

* Moves layout visualization

* Refactors layout visualization pass configuration

Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats.

* Enables multiple layout visualization formats

* Updates layout visualization docs

* Moves layout visualization to analysis

* [Release] Relax constraint of tvm-ffi to compatible version (#1373)

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* [Language] Tilelang LazyJIT Experimental Version (#1337)

* initial step

* modify builder

* scratch version of new frontend

* write some tests

* add many tests

* add typing stub for tir.ir

* remove idents

* minor update

* minor update

* First version of jitv2 (renamed to LazyJIT)

* fix pre-commit error

* minor fix

* fix lint error

* fix lint error

* Fix conditional check for PrimFunc instance

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

* [Builder] Enhance variable name binding and scope management (#1378)

- Improved handling of TVM Var/Buffer names to prevent out-of-scope errors when reusing Python names across different for-frames.
- Added assertions to ensure variables are defined within the correct control flow frame, enhancing error checking and code reliability.

* [Bugfix] make cuda driver api compat with cuda12/13, along with tests (#1379)

* [Fix] typo in cuda attr (#1380)

* [Bugfix] make cuda driver api compat with cuda12/13, along with tests

* fix typo in cudaDevAttr

* [Language V2] Minor fix for complex annotations (#1381)

* [Release] Bump Version into 0.1.7 (#1377)

* Update VERSION to 0.1.7

* Update Python version in distribution scripts to support CPython 3.9 and log output

* [Typing] Enhance compatibility for advanced typing features in Python (#1382)

- Updated `allocate.py` and `annot.py` to improve compatibility with Python 3.9 and later by conditionally importing advanced typing features such as `TypeVarTuple`, `Unpack`, and `ParamSpec`.
- Added fallback imports from `typing_extensions` for environments using earlier Python versions.
- Improved handling of generic alias detection to ensure consistent behavior across different Python versions.
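
For readers unfamiliar with the pattern, the conditional imports described above typically look like the following sketch (the version thresholds reflect when each symbol landed in the standard library; the exact arrangement in `allocate.py`/`annot.py` may differ):

```python
import sys

# ParamSpec is stdlib from Python 3.10; TypeVarTuple and Unpack from 3.11.
if sys.version_info >= (3, 11):
    from typing import ParamSpec, TypeVarTuple, Unpack
elif sys.version_info >= (3, 10):
    from typing import ParamSpec
    from typing_extensions import TypeVarTuple, Unpack
else:
    from typing_extensions import ParamSpec, TypeVarTuple, Unpack
```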

* [Bugfix][Build] Update CMake configuration to remove project root injection for sys.path (#1385)

* [Build] Update CMake configuration for tilelang_cython_wrapper installation

- Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib.
- Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules.
- Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects.

* [Build] Standardize output directories for tilelang libraries

- Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds.
- This change enhances organization and ensures that all build artifacts are located in a unified directory structure.

* [BugFix] Fix split kernel layout bug of GQA decode (#1386)

* [BugFix] Fix split kernel layout bug of GQA decode

* [BugFix] Avoid local with Parallel; use robust fragment instead

* [Enhancement] Add debug output methods for Layout and Fragment classes (#1392)

* [Doc] Update logging docs (#1395)

* [Enhancement] Refactor inflight computing to support dynamic pipeline extents (#1399)

* [Build] Update CMake configuration for tilelang_cython_wrapper installation

- Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib.
- Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules.
- Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects.

* [Build] Standardize output directories for tilelang libraries

- Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds.
- This change enhances organization and ensures that all build artifacts are located in a unified directory structure.

* [Refactor] Update TVM subproject and enhance pipeline loop handling

- Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0.
- Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management.
- Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts.
- Simplified access index calculations and strengthened analyzer constraints for loop bounds.

* [Cleanup] Remove license block and unused includes from inject_pipeline.cc

- Eliminated the Apache license block from the top of the file to streamline the code.
- Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies.

* [Refactor] Enhance transformation pipeline and test execution

- Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization.
- Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation.

* [AMD] Fix 3 bugs when build docker on amd mi3x gpu (#1401)

* [Typo] Fix tilelang link in README.md (#1402)

* [Dependency] Update apache-tvm-ffi version to >=0.1.2 (#1400)

* [Dependency] Update apache-tvm-ffi version to >=0.1.2 in project files

* [Dependency] Update subproject commit for TVM to latest version afc07935

* [Enhancement] Add support for optional step parameter in loop constructs

- Updated loop creation functions to accept an optional step parameter, enhancing flexibility in loop definitions.
- Modified ForFrame implementations to utilize the new step parameter across various loop types including serial, parallel, and pipelined loops.
- Adjusted related vectorization transformations to accommodate the step parameter, ensuring consistent behavior in loop vectorization processes.

* lint fix

* [AMD] Enable FA2 fwd on AMD MI300X (#1406)

* enable FA2 on AMD MI300X

* make lint happy

* [TypoFix] fix typo for SM120 (#1408)

* [Doc] Minor documentation update (#1410)

* [Dependency] Add torch-c-dlpack-ext to project requirements (#1403)

* [Dependency] Add torch-c-dlpack-ext to project requirements

* Added torch-c-dlpack-ext to both pyproject.toml and requirements.txt to provide prebuilt torch extensions, which may prevent JIT compilation on first import of TVM FFI.

* [Build] Update manylinux images in project configuration

* Changed the manylinux image for x86_64 from "manylinux2014" to "manylinux_2_28" in both pyproject.toml and the Dockerfile to align with updated standards for compatibility and performance.

* [Build] Update CUDA repository configuration in pyproject.toml

* Changed the package manager command from `yum-config-manager` to `dnf config-manager` for adding the CUDA repository, ensuring compatibility with newer systems.

* fix

* [Build] Update CUDA repository to RHEL 8

* Changed the CUDA repository configuration in both pyproject.toml and the manylinux Dockerfile from RHEL 7 to RHEL 8, ensuring compatibility with newer systems.

* test: run out of space

* use cu130 to reduce size

* upd

* upd comment

* upd

---------

Co-authored-by: Your Name <wenji.yyc@alibaba-inc.com>

* [Dependency] Update TVM subproject to latest commit 2b1ead1a (#1412)

* [Enhancement] Introduce `T.__ldg` (#1414)

* [Enhancement] Add __ldg intrinsic for CUDA read-only cache loads

* Introduced the __ldg intrinsic to enable explicit read-only cached loads from global memory in CUDA.
* Updated the corresponding documentation and added support in both CUDA and HIP code generation.
* Enhanced the Python interface for __ldg to accept BufferLoad and Buffer types, improving usability.
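
A hedged sketch of a call site: the kernel scaffolding (`T.Kernel`, `T.Parallel`, `T.Tensor` annotations) follows TileLang's public examples and is assumed here; only the `T.__ldg` call itself reflects this change.

```python
import tilelang.language as T

N, threads = 1024, 256

@T.prim_func
def ldg_copy(A: T.Tensor((N,), "float32"), B: T.Tensor((N,), "float32")):
    with T.Kernel(N // threads, threads=threads) as bx:
        for i in T.Parallel(threads):
            # Load A through the read-only (non-coherent) cache path.
            B[bx * threads + i] = T.__ldg(A[bx * threads + i])
```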

* [Enhancement] Update formatting and linting rules in pyproject.toml; minor test adjustment

* Added new formatting rules in pyproject.toml to enforce consistent code style, including hanging indents and argument splitting.
* Updated test_tilelang_language_intrinsics_codegen.py to improve readability by adding a blank line before the main execution block.
* Refactored error messages in builtin.py for better clarity and consistency, ensuring proper formatting in function definitions and raising ValueErrors.

* lint fix

* [Enhancement] Improve vectorization invariant check (#1398)

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Improve loop vectorize

* Add some vectorize tests and comments

* [Lint] Phaseout Yapf format and embrace ruff format (#1417)

* [Atomic] Use ptr for atomicAdd dst instead of reference (#1425)

* [Enhancement] Update AtomicAdd function signature to accept pointer to destination

* Modified AtomicAdd in CUDA to take a pointer instead of a reference for the destination argument.
* Updated related code in atomicadd_vectorize.cc to ensure compatibility with the new signature.
* Adjusted Python interface in atomic.py to pass the destination by pointer, aligning with device function requirements.

* [Enhancement] Refactor AtomicAddRet function signature to accept pointer

* Updated AtomicAddRet in both CUDA and HIP to take a pointer instead of a reference for the address argument, improving consistency with the AtomicAdd function.
* Adjusted the implementation to ensure proper reinterpretation of the address type for atomic operations.

* lint fix

* [Enhancement] Refactor AtomicAddNode::MakeSIMTLoop to use destination pointer

* Updated the MakeSIMTLoop function to build a pointer to the destination element using tvm_access_ptr instead of loading the destination value directly.
* Simplified the handling of source and destination predicates, improving clarity and maintainability of the code.
* Ensured compatibility with the new pointer-based approach for atomic operations.

* lint fix

* test fix

* lint fix

* [CUDA] Add read-only parameter annotation for CUDA codegen (#1416)

* [Enhancement] Add read-only parameter annotation for CUDA codegen

* Introduced the `AnnotateReadOnlyParams` transformation to annotate read-only handle parameters in PrimFuncs, enabling the generation of `const` qualifiers in CUDA codegen.
* Updated `PrintFunctionSignature` and `AddFunction` methods to utilize the new attribute `tl.readonly_param_indices`, enhancing performance by allowing read-only cache loads.
* Modified the optimization pipeline to include the new annotation step, improving the overall efficiency of the code generation process.

* lint fix

* [Dependency] Update apache-tvm-ffi version to >=0.1.3

* Updated the version of apache-tvm-ffi in pyproject.toml, requirements.txt, and requirements-dev.txt to ensure compatibility with the latest features and fixes.
* Made adjustments in CUDA and HIP template files to use `const` qualifiers for global pointer parameters, enhancing code safety and clarity.

* lint fix

* [Enhancement] Refactor ReadWriteMarker for improved parameter handling

* Updated the ReadWriteMarker class to accept a set of parameter or data variables, enhancing its ability to track written variables.
* Introduced a new method, ResolveDataVarFromPtrArg, to resolve underlying buffer data from pointer-like arguments, improving accuracy in identifying written variables.
* Modified the MarkReadOnlyParams function to gather handle parameters and their corresponding buffer data variables, streamlining the process of determining read-only parameters.
* Enhanced the logic for identifying written variables to account for aliased data variables, ensuring comprehensive tracking of modifications.

* lint fix

* Update tma_load function to use const qualifier for global memory pointer

* Changed the parameter type of gmem_ptr in the tma_load function from void* to void const* to enhance type safety and clarity in memory operations.
* This modification ensures that the function correctly handles read-only global memory pointers, aligning with best practices in CUDA programming.

* Remove commented-out code and reorder transformations in OptimizeForTarget function for clarity

* Refactor buffer marking logic in annotate_read_only_params.cc to improve accuracy in identifying written variables. Update OptimizeForTarget function to reorder transformations for better clarity.

* [Refactor] Phase out the primitives folder since its design has been merged into tileop (#1429)

* Phase out primitives

* revert changes

* Refactor GemmWarpPolicy method signature for clarity

Updated the `from_warp_partition` method in the `GemmWarpPolicy` class to return the type `GemmWarpPolicy` instead of a string, enhancing type safety and clarity in the codebase. Removed an unnecessary blank line for improved readability.

* fix

* [CI]: Bump actions/upload-artifact from 5 to 6 (#1431)

Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 5 to 6.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [CI]: Bump actions/download-artifact from 6 to 7 (#1432)

Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 6 to 7.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* [Bugfix] Convey `compile_flags` to ffi compilation path with pass_configs (#1434)

* [Enhancement] Add device compile flags support in pass configuration

* Introduced `kDeviceCompileFlags` option in the pass configuration to allow additional device compiler flags for CUDA compilation.
* Updated the `tilelang_callback_cuda_compile` function to merge extra flags from the pass configuration, enhancing flexibility in compiler options.
* Modified the `JITKernel` class to handle device compile flags appropriately, ensuring they are included during compilation.
* Documented the new pass configuration key for clarity on usage and expected input formats.

* lint fix

* [Refactor] Simplify compile_flags handling in JIT functions

* Removed redundant string check for compile_flags in the compile, jit, and lazy_jit functions, ensuring compile_flags is consistently treated as a list.
* Updated the JITKernel class to handle compile_flags as a list when a string is provided, enhancing code clarity and maintainability.
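
A usage sketch consistent with the bullets above; the flag values are ordinary NVCC options chosen for illustration, and the kernel body is a plain copy just to keep the example self-contained.

```python
import tilelang
import tilelang.language as T

# Per the commits above, compile_flags is forwarded to the device compiler and a
# bare string is normalized to a one-element list.
@tilelang.jit(compile_flags=["-O3", "--use_fast_math"])
def make_copy(N=1024, block=256, dtype="float32"):
    @T.prim_func
    def main(A: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype)):
        with T.Kernel(N // block, threads=block) as bx:
            for i in T.Parallel(block):
                B[bx * block + i] = A[bx * block + i]

    return main
```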

* lint fix

* fix

* [Enhancement] Improve buffer usage tracking in MakePackedAPI (#1435)

* Added detailed logging for data and shape variable parameters during buffer usage detection in the MakePackedAPI function.
* Refactored the UsedBufferDetector to differentiate between used parameters by data and shape variables, enhancing clarity in buffer management.
* Updated logic to ensure minimal carrier buffers are selected for shape symbols, improving the efficiency of parameter handling.

* [Enhancement] Improve InjectAssumes logic and make assumes work after SplitHostDevice (#1405)

* [Refactor] Refactor InjectAssumes logic and make assumes work after SplitHostDevice

* address comments

* fix

* fix submodule

* fix

* fix 3rdparty

* [Enhancement] Include PrimFunc name in memory cache logs for better debugging (#1437)

* Added the `get_prim_func_name` utility to extract human-readable function names from TVM PrimFuncs.
* Updated memory cache logging in `AutoTuner` and `KernelCache` classes to include the kernel name, improving clarity during cache hits.
* Enhanced debug logging to provide more informative messages when checking disk cache for kernels.

* [CI] Update lint dependencies and fix lint on trunk (#1433)

* [CI] Update pre-commit hooks

* [Lint] Pass correct `exclude-header-filter` to `clang-tidy`

* [Lint] Download latest `run-clang-tidy` script

* [CI] Show compile commands

* [CI] Add output grouping to GHA

* [Lint] Re-order pre-commit hooks

* [Enhancement] Refactor vectorization checks in loop_vectorize (#1440)

* Introduced a new function, IsExprInvariantInVectorBoundary, to encapsulate the logic for checking if an expression is invariant within vector boundaries, improving code clarity and reusability.
* Updated the existing vectorization logic to utilize this new function, streamlining the process of determining vectorization feasibility based on boundary conditions.
* Enhanced comments for better understanding of the vectorization criteria and mathematical rationale behind the checks.

* Enhance vectorized conversion support (#1438)

* [Feature] Support region as input of T.cumsum (#1426)

* [Feature] Support region as input of T.cumsum

- Extend T.cumsum to accept BufferRegion and BufferLoad inputs in addition to Buffer
- This enables operations on buffer slices/regions like:
  T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0)
- Update cumsum_fragment to handle region inputs properly
- Add comprehensive tests for 1D and 2D region inputs including normal and reverse modes

Fixes #879
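
A minimal sketch of the region form under the usual TileLang scaffolding (the shared-memory staging, `T.Kernel`, and `T.copy` calls are assumptions taken from public examples); only the sliced `T.cumsum(...)` call mirrors the usage quoted above.

```python
import tilelang.language as T

N, chunks = 1024, 4
chunk_size = N // chunks

@T.prim_func
def chunked_cumsum(A: T.Tensor((N,), "float32"), Out: T.Tensor((N,), "float32")):
    with T.Kernel(1, threads=128) as bx:
        buf = T.alloc_shared((N,), "float32")
        T.copy(A, buf)
        for i in T.serial(chunks):
            # Cumulative sum over a buffer slice (a BufferRegion) rather than a whole buffer.
            T.cumsum(buf[i * chunk_size:(i + 1) * chunk_size], dim=0)
        T.copy(buf, Out)
```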

* Fix formatting and add docstring for cumsum_fragment

- Add comprehensive docstring for cumsum_fragment function
- Format code according to ruff style guidelines

* Fix CodeRabbit review issues

- Fix negative dimension bounds check (dim < -len(shape) instead of dim <= -len(shape))
- Add src/dst shape compatibility validation for out-of-place cumsum
- Update copy() type annotation to accept BufferRegion as dst parameter
- Fix test in-place mutation issues by using out-of-place cumsum operations
- Add non-divisible size test cases for tail region coverage
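
The corrected negative-dimension bound is easiest to see in plain Python; this is only an illustration of the check described above, not the repository's actual helper.

```python
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative dim into [0, ndim), Python-indexing style."""
    # dim == -ndim is valid (it refers to dim 0), so the rejection test must use
    # strict '<' on the negative side rather than '<='.
    if dim < -ndim or dim >= ndim:
        raise ValueError(f"dim {dim} out of range for a {ndim}-dimensional buffer")
    return dim % ndim


assert normalize_dim(-2, 2) == 0  # would have been rejected by the old '<=' check
assert normalize_dim(-1, 2) == 1
```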

* Fix out-of-bounds access in region tests

- Add bounds clamping using T.min() for chunk_end calculations
- Prevents accessing beyond tensor bounds for non-divisible sizes
- Matches reference implementation behavior
- Fixes both 1D and 2D region test cases

* Fix region test: use simple slice expressions instead of T.min()

- Remove T.min() which cannot be used directly in slice indices
- Use chunk_start + chunk_size form instead
- Rely on system's automatic bounds checking for non-divisible sizes
- Update comments to reflect this approach

* Fix cumsum region: use region extents in lowering and update tests for shared memory

* Simplify fragment scope check using is_fragment()

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>

* [Fix] Fix analyzer bind conflicting (#1446)

* [Refactor] Reduce direct dependency on PyTorch due to its limited type support (#1444)

* [Enhancement] Update KernelParam to use tvm.DataType directly and add torch_dtype conversion method

- Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions.
- Added a new method, torch_dtype, to convert tvm.DataType back to torch.dtype for tensor creation.
- Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization.
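
For readers outside the codebase, the round trip can be pictured as a small name-based mapping; this is illustrative only, and the real `torch_dtype` method may cover more types or be implemented differently.

```python
import torch
from tvm import DataType

# A few common cases; tvm.DataType stringifies to names like "float16".
_TORCH_DTYPES = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "int8": torch.int8,
    "int32": torch.int32,
    "bool": torch.bool,
}


def torch_dtype_of(dtype: DataType) -> torch.dtype:
    return _TORCH_DTYPES[str(dtype)]


assert torch_dtype_of(DataType("float16")) is torch.float16
```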

* [Enhancement] Refactor CUDA type handling and add support for FP4 and FP8 types

- Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively.
- Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64).
- Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions.
- Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing.
- Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch.
- Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively.

* lint fix

* Remove unnecessary logging statements from CUDA code generation and delete obsolete matrix multiplication test file.

* [Enhancement] Add support for FP4 and FP8 types in CUDA code generation

- Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types.
- Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch.
- Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations.
- Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations.
- Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows.

---------

Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>

* [Refactor] Use `pytest.mark.parameterize` to speedup parallel testing (#1447)

* Refactor GEMM tests to use parameterized pytest fixtures

- Converted multiple test cases for GEMM operations in `test_tilelang_tilelibrary_gemm_sp.py` to use `pytest.mark.parametrize` for better maintainability and readability.
- Similar refactoring applied to `test_tilelang_tilelibrary_gemm_sp_v2.py`, consolidating test cases for `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` into parameterized tests.
- This change reduces code duplication and enhances the clarity of test configurations.
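
A generic illustration of the parameterization pattern described above; the placeholder runner stands in for helpers such as `run_gemm_ss` and is not the repository's actual test code.

```python
import pytest


def run_gemm_ss(M, N, K, dtype):
    # Stand-in for the real runner, which would build and check a TileLang GEMM.
    assert M > 0 and N > 0 and K > 0 and dtype in ("float16", "bfloat16")


@pytest.mark.parametrize(
    "M,N,K,dtype",
    [
        (1024, 1024, 1024, "float16"),
        (2048, 2048, 2048, "bfloat16"),
    ],
)
def test_gemm_ss(M, N, K, dtype):
    run_gemm_ss(M, N, K, dtype)
```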

* Update testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* [Docs] Improve installation instructions for developers (#1450)

* [Feat] Integrate Z3 in TVM Arith Analyzer (#1367)

* [Bugfix] Improve autotune from elementwise_add function in examples (#1445)

* Remove JIT decorator from elementwise_add function in examples

* fix kernel compilation without autotune

* Refactor main function to accept parameters and update tests for autotune option

* Refactor autotune test function for modern style

* [Language] Introduce `T.annotate_restrict_buffers` (#1428)

* [Enhancement] Introduce non-restrict parameter support in code generation

- Added a new PrimFunc-level attribute `tl.non_restrict_params` to specify handle Vars that should not be marked with the restrict qualifier during code generation.
- Updated `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to handle non-restrict parameters, ensuring proper treatment of overlapping buffer aliases.
- Implemented a new annotation function `annotate_restrict_buffers` to facilitate the marking of buffer parameters as non-restrict.
- Enhanced the `SplitHostDevice` transformation to propagate non-restrict parameters from host to device functions.
- Added a new transform function `HoistNonRestrictParams` to manage non-restrict parameters effectively.

* [Enhancement] Improve HoistNonRestrictParams transformation

- Updated the HoistNonRestrictParams function to recursively collect all `tl.non_restrict_params` annotations from nested blocks, enhancing flexibility in annotation placement.
- Introduced a new NonRestrictCollector class to manage the collection and deduplication of non-restrict parameters.
- Modified the SplitHostDevice transformation to remove the non-restrict attribute from the host-side PrimFunc after propagation to device kernels.
- Adjusted the LowerAndLegalize function to directly apply the HoistNonRestrictParams transformation without exception handling, streamlining the process.

* [Refactor] Simplify non-restrict parameter handling in code generation

- Removed unnecessary normalization logic and associated data structures from `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP`.
- Streamlined the handling of non-restrict parameters by directly inserting them into the `non_restrict` set, improving code clarity and maintainability.
- Updated conditional checks to eliminate redundant checks against normalized names, enhancing performance and readability.

* [Dependency] Update TVM subproject to latest commit 68aa8461

- Updated the TVM subproject to the latest commit, ensuring compatibility with recent changes and improvements.
- Refactored non-restrict parameter handling in `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to enhance code clarity and maintainability.
- Adjusted the `SplitHostDevice` transformation to streamline the propagation of non-restrict parameters.

* fix

* [Analyzer] Require loop extent > 0 when entering loop (#1451)

* Update ROCm CI to Nightly-ROCm-7.1 (#1449)

* [Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update…