
Conversation


@Cunxiao2002 Cunxiao2002 commented Sep 15, 2025

Summary by CodeRabbit

  • New Features

    • Added an end-to-end INT4 dequantization + GEMM example with on-the-fly unpack/dequantize, validation against a reference, autotuning support, and a CLI for tuning and benchmarking.
  • Performance

    • Broadened the autotuner search space to improve tuning coverage and help discover better-performing configurations.
  • Tests

    • Added a CUDA-required validation test exercising the new INT4 example.
  • Chores

    • Pinned ml_dtypes dependency to >=0.5.3.

The version of ml_dtypes should be pinned in the dependency specification. If the version of ml_dtypes is too low, it may result in errors such as fp4 not being defined.
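For illustration, a minimal import-time check makes the failure mode explicit; this is only a sketch, and it assumes the fp4 type the examples rely on is ml_dtypes.float4_e2m1fn:

import ml_dtypes

# Fail fast with a clear message if the installed ml_dtypes is too old to
# expose an fp4 type (the attribute name checked here is an assumption).
if not hasattr(ml_dtypes, "float4_e2m1fn"):
    raise ImportError(
        f"ml_dtypes {ml_dtypes.__version__} has no fp4 dtype; "
        "install ml_dtypes>=0.5.3 as pinned in requirements.txt"
    )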

coderabbitai bot commented Sep 15, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new INT4 (w4a8) dequantization + autotuned GEMM example and CUDA test, expands the autotuner search space and changes wrapper return behavior in an existing FP4 Hopper example, and pins the ml_dtypes requirement to >=0.5.3.

Changes

  • Autotuner adjustments (FP4 Hopper example): examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
    Broadened tuning grids (block_M, block_N, block_K, num_stages, threads), removed keys from the autotune decorator call, and changed the decorated wrapper to return kernel_func(...).prim_func instead of the callable.
  • New INT4 dequantization + GEMM example: examples/dequantize_gemm/example_dequant_gemm_w4a8.py
    New example implementing an INT4-packed-in-uint8 unpack helper (_tir_u8_to_i4_to_i8), a TileLang/TIR conversion test and Torch reference converter, a validation test, a reference program, an on-the-fly dequantize GEMM kernel generator (matmul_int8xint4), autotune configs, tuned and untuned execution paths, and a CLI.
  • New CUDA test for the example: examples/dequantize_gemm/test_example_dequantize_gemm.py
    Added test_example_dequant_gemm_w4a8(), which imports the new example and invokes its main() (CUDA required).
  • Dependency update: requirements.txt
    Pinned ml_dtypes to ml_dtypes>=0.5.3.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant CLI as __main__
  participant Main as main()
  participant Kernel as matmul_int8xint4()
  participant Auto as autotune()
  participant Ref as ref_program()
  participant Torch as torch_convert()

  User->>CLI: Provide m,n,k and --tune
  CLI->>Main: main(m,n,k,tune)
  alt tune == True
    Main->>Kernel: matmul_int8xint4(..., tune=True)
    Kernel->>Auto: search(get_configs())
    Auto-->>Kernel: best config + prim_func
    Main-->>User: report best config, TFLOPs
  else tune == False
    Main->>Kernel: matmul_int8xint4(..., tune=False)
    Kernel-->>Main: prim_func (untuned)
    Main->>Torch: torch_convert(qB)
    Torch-->>Ref: B_int8
    Main->>Ref: ref_program(A, B_int8)
    Ref-->>Main: reference output
    Main-->>User: latency, validation status
  end
sequenceDiagram
  autonumber
  participant Tile as Tile Kernel
  participant SM as Shared/Local Buffers
  participant Unpack as _tir_u8_to_i4_to_i8
  participant GEMM as T.gemm

  Tile->>SM: Load A/B tiles into shared/local
  Tile->>Unpack: Unpack uint8 INT4 -> int8 per element
  Unpack-->>SM: Provide dequantized B_local
  Tile->>GEMM: Accumulate A_tile x B_tile^T -> Ct
  Tile-->>Tile: Iterate over K blocks (num_stages, threads)
  Tile-->>Caller: Store Ct -> C

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • LeiWang1999

Poem

I nibble four bits from a cozy little byte,
Unpack them softly, ready for the fight.
Tiles hop in chorus, threads hum like spring rain,
Autotune finds rhythm, kernels sing again.
Hop, validate, benchmark — carrots for the brain. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: you can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title "[Example] add w4a8 gemm kernel" is concise and directly reflects the primary change in the diff—adding an example implementing a w4a8 (INT4-packed) GEMM kernel and associated test—so it accurately summarizes the main intent and is clear for teammates scanning history.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @Cunxiao2002, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new optimized General Matrix Multiply (GEMM) kernel specifically designed for 4-bit weights and 8-bit activations (W4A8), enhancing the library's capabilities for quantized operations. It includes a dedicated example to demonstrate its implementation and performance tuning. Concurrently, an existing FP4 GEMM example has been improved by expanding its autotuning search space, aiming for better performance discovery, and adapting to recent API changes in the tilelang framework. Additionally, a minor dependency version constraint has been updated.

Highlights

  • New W4A8 GEMM Kernel: A new example has been added for a W4A8 (Weight 4-bit, Activation 8-bit) General Matrix Multiply (GEMM) kernel. This includes a _tir_u8_to_i4_to_i8 function for dequantizing packed uint8 values into int8 within the kernel.
  • Expanded Autotuning Configurations: The autotuning configuration space for the existing FP4 GEMM example (example_dequant_gemm_fp4_hopper.py) has been significantly broadened. This includes a wider range of block_M, block_N, block_K, num_stages, and threads parameters to explore for optimal performance.
  • Autotune API Refinement: Updates were made to the @autotune decorator usage in the FP4 GEMM example. The explicit keys argument was removed, and the return value of kernel_func now correctly accesses its .prim_func attribute, aligning with potential tilelang API changes.
  • Dependency Update: The ml_dtypes dependency in requirements.txt has been updated to specify a minimum version of 0.5.3.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new example for a w4a8 GEMM kernel and updates an existing fp4 example. While the changes to the existing fp4_hopper.py example are improvements to its autotuning configuration, the new example_dequant_gemm_w4a8.py file contains several issues. My review includes feedback on a critical bug that will cause an AttributeError, incorrect command-line argument handling, a fragile and inefficient reference implementation, and other opportunities for improving code quality and maintainability.

block_K=None,
num_stages=None,
threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func

critical

The kernel_func function returns a tvm.tir.PrimFunc object. PrimFunc objects do not have a .prim_func attribute, so accessing it here will raise an AttributeError at runtime. You should return the PrimFunc object directly.

Suggested change
-return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
+return kernel_func(block_M, block_N, block_K, num_stages, threads)

Comment on lines 75 to 91
def torch_convert(tensor):
    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        i4_shifted = ((val >> (pos * 4)) & mask)
        i4 = ((i4_shifted << 4) >> 4)

        return i4.view(torch.int8)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor

high

The torch_convert function has a potential bug and is inefficient.

  1. Potential Bug: On line 78, val.view(torch.int8) reinterprets the bits of a torch.uint8 scalar as torch.int8. This is incorrect for values > 127. While current tests use a narrow range of inputs [0, 15], this is fragile and will produce incorrect results for the full int4 range.
  2. Inefficiency: The implementation uses nested Python loops, which is very slow for tensor operations in PyTorch. This can be vectorized for a significant performance improvement.

Consider replacing this function with a vectorized version like the one below, which is both correct for all input values and much faster:

def torch_convert(tensor: torch.Tensor) -> torch.Tensor:
    """Vectorized conversion of a uint8 tensor (packing two int4s) to an int8 tensor."""
    # tensor is (N, K_packed) uint8
    # output is (N, K_packed * 2) int8
    
    # Isolate low and high 4-bit nibbles
    low_nibbles = tensor & 0x0F
    high_nibbles = tensor >> 4
    
    # Sign-extend nibbles from 4-bit to 8-bit.
    # Values >= 8 (0b1000) are negative in int4.
    low_nibbles = low_nibbles - (low_nibbles >= 8) * 16
    high_nibbles = high_nibbles - (high_nibbles >= 8) * 16
    
    # Interleave the nibbles into a new last dimension
    unpacked = torch.stack([low_nibbles, high_nibbles], dim=-1)
    
    # Reshape to final (N, K_packed * 2)
    return unpacked.view(tensor.shape[0], -1).to(torch.int8)
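For quick confidence, a roundtrip check can pack known int4 values and confirm that the vectorized torch_convert above recovers them; a small illustrative sketch (pack_int4_pairs is a helper invented here, not part of the example):

import torch

def pack_int4_pairs(low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    # Pack two int8 tensors holding values in [-8, 7] into one uint8 tensor:
    # low value in the low nibble, high value in the high nibble.
    lo = low.to(torch.int16) & 0x0F
    hi = high.to(torch.int16) & 0x0F
    return ((hi << 4) | lo).to(torch.uint8)

low = torch.arange(-8, 8, dtype=torch.int8).reshape(1, -1)  # (1, 16)
high = low.flip(-1)                                          # same values, reversed
packed = pack_int4_pairs(low, high)                          # (1, 16) uint8
unpacked = torch_convert(packed)                             # (1, 32) int8
assert torch.equal(unpacked[:, 0::2], low)                   # even columns: low nibbles
assert torch.equal(unpacked[:, 1::2], high)                  # odd columns: high nibbles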


M, N, K = args.m, args.n, args.k
# main(M, N, K, args.tune)
main(M, N, K, True)

high

The main function is hardcoded to run with tune=True, which ignores the --tune command-line argument. The intended behavior, as suggested by the commented-out line above, is to use the parsed argument from args.tune.

Suggested change
-main(M, N, K, True)
+main(M, N, K, args.tune)

Comment on lines 10 to 12
"""
使用左移右移方法将打包在uint8中的4位有符号整数(INT4)转换为8位有符号整数(INT8)
"""

medium

The comments in this file are in Chinese. To ensure consistency and maintainability for all contributors, please translate them to English. This also applies to other Chinese comments in this file (e.g., lines 23-27, 120-121).

Comment on lines +111 to +112
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))

medium

The reference program uses floating-point arithmetic for what should be an integer matrix multiplication. This can introduce unnecessary precision errors and is less efficient. It's better to perform the multiplication using integer types to ensure correctness and get a more reliable reference result.

Suggested change
-C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
+C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))
 C = C.to(torch.__getattribute__(dtypeC))

Comment on lines +165 to +166
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)

medium

The use of B_dequantize_prev_local and the explicit copy on line 165 suggest manual software pipelining. This is likely redundant because the T.Pipelined construct is designed to manage this automatically. This manual approach can be removed to simplify the code and rely on the framework's pipelining. You can remove B_dequantize_prev_local and use B_dequantize_local directly in T.gemm.

Suggested change
-T.copy(B_dequantize_local, B_dequantize_prev_local)
-T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)

best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Bset latency: {best_latency}")

medium

There is a typo in the print statement. "Bset" should be "Best".

Suggested change
print(f"Bset latency: {best_latency}")
print(f"Best latency: {best_latency}")


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (6)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (6)

11-12: Non-English comments should be in English for better collaboration.

The function docstring contains Chinese text. Consider using English for better accessibility to all contributors.

-    """
-    使用左移右移方法将打包在uint8中的4位有符号整数(INT4)转换为8位有符号整数(INT8)
-    """
+    """
+    Convert packed 4-bit signed integers (INT4) in uint8 to 8-bit signed integers (INT8)
+    using shift operations for sign extension.
+    """

23-28: Comments should be in English for consistency.

The implementation comments are in Chinese. Consider translating them to English for better maintainability.

-    # 方法3: 使用左移和算术右移
-    # 1. 先将4位数据左移4位,放到int8的高4位
-    # 2. 然后算术右移4位,自动进行符号扩展
+    # Method 3: Use left shift and arithmetic right shift
+    # 1. First shift the 4-bit data left by 4 bits to place it in the high 4 bits of int8
+    # 2. Then arithmetic right shift by 4 bits for automatic sign extension
    i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
-    i8 = i8_shifted >> tir.const(4, "int8")  # 注意这里使用int8类型的右移(算术右移)
+    i8 = i8_shifted >> tir.const(4, "int8")  # Note: using int8 type for arithmetic right shift

41-42: Remove or uncomment the test function call.

There's a commented-out test function call. Either remove it or uncomment if it should be executed.

-# test_convert_int4_to_int8()
-

102-103: Remove debug print statements before merging.

Debug print statements should be removed from production code.

-    print(tl_out)
-    print(ref_out)
     assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)

120-122: Comments should be in English for consistency.

The comments explaining K dimensions are in Chinese.

-        # K是解包后的K
-        # 因此解包前是K // 2
+        # K is the unpacked dimension
+        # Therefore, the packed dimension is K // 2

205-205: Fix typo in print statement.

There's a typo in "Bset" which should be "Best".

-        print(f"Bset latency: {best_latency}")
+        print(f"Best latency: {best_latency}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0d6669 and 1e3fd5c.

📒 Files selected for processing (3)
  • examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2 hunks)
  • examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1 hunks)
  • requirements.txt (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (3)
examples/seer_attention/block_sparse_attn_tilelang.py (1)
  • kernel_func (47-150)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)
  • kernel_func (19-185)
examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (1)
  • kernel_func (46-167)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (8)
tilelang/language/tir/op.py (1)
  • reinterpret (1816-1835)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
tilelang/profiler/__init__.py (1)
  • assert_allclose (76-137)
🪛 GitHub Actions: CI Test on AMD
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py

[error] 1-1: Code formatting changes detected by clang-format. Please review and stage the changes.

examples/dequantize_gemm/example_dequant_gemm_w4a8.py

[error] 1-1: Code formatting changes detected by clang-format. Please review and stage the changes.

🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py

3-3: from tilelang.autotuner import * used; unable to detect undefined names

(F403)


24-24: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)

25-25: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)

27-27: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)

27-27: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)


174-174: autotune may be undefined, or defined from star imports

(F405)

🔇 Additional comments (8)
requirements.txt (1)

7-7: LGTM!

The addition of a minimum version constraint for ml_dtypes is a good practice for dependency management and ensures compatibility with features used in the dequantization examples.

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2)

112-117: Search space expansion looks reasonable.

The expanded autotuner search space provides more granular control over performance tuning with additional block size and thread configuration options.


253-253: Verified: autotune-decorated kernel returns .prim_func (intentional & consistent)

Confirmed: examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py returns kernel_func(...).prim_func in the tuned branch (line 253) while the non‑tuned branch returns the callable (line 259). Same pattern in examples/dequantize_gemm/example_dequant_gemm_w4a8.py (lines 181/187). The autotuner implementation consumes jit_kernel.prim_func (tilelang/autotuner/tuner.py lines 429, 534). Keep as-is if the API change is intended; revert to returning the callable if not.

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (5)

61-61: Consider using the full Pipelined signature for clarity.

Using T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1) with explicit start parameter is clearer than relying on positional arguments.


9-30: Sign extension implementation looks correct.

The INT4 to INT8 conversion using shift operations for sign extension is a standard and efficient approach. The implementation correctly handles the sign bit through arithmetic right shift.


1-219: Run clang-format on examples/dequantize_gemm/example_dequant_gemm_w4a8.py

CI reported formatting issues for this file; the verification script only echoed instructions and did not apply formatting — run clang-format -i examples/dequantize_gemm/example_dequant_gemm_w4a8.py, commit & push the change, then re-run CI.


37-37: No change required — 512 is valid. 512 ≤ 1024 and the same [128, 256, 512] config appears across multiple examples; no change needed.


181-181: Return-value pattern verified — tuning returns .prim_func; non-tuning returns kernel callable.

Confirmed in examples/dequantize_gemm: tuning paths return .prim_func (e.g., example_dequant_gemm_fp4_hopper.py ~lines 251–253; example_dequant_gemm_w4a8.py ~lines 179–181) while non-tuning branches return the kernel callable — no change required.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)

227-229: Respect the --tune flag.

Don’t hardcode True; this was raised previously.

-    # main(M, N, K, args.tune)
-    main(M, N, K, True)
+    main(M, N, K, args.tune)
🧹 Nitpick comments (9)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)

242-242: Avoid star import; explicitly import autotune to fix Ruff F405/F403.

Star import from tilelang.autotuner can hide undefined names and triggers Ruff. Import only what you use.

Apply at file top:

-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (8)

3-3: Avoid star import; import only what you need.

Prevents F403/F405 and keeps the namespace clean.

-from tilelang.autotuner import *
+from tilelang.autotuner import autotune

10-31: Document in English and clarify pos domain.

The implementation is correct; tighten docs for maintainers and avoid full‑width punctuation.

-def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
-    """
-    使用左移右移方法将打包在uint8中的4位有符号整数(INT4)转换为8位有符号整数(INT8)
-    """
+def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
+    """
+    Convert a 4-bit signed integer (INT4) packed in a uint8 to an 8-bit signed integer (INT8)
+    via left-shift then arithmetic right-shift (sign extension).
+    Notes:
+      - nbit must be 4; dtype must be "int8"; val.dtype must be "uint8".
+      - pos selects the nibble within the byte: pos ∈ {0 (low), 1 (high)}.
+    """

65-65: Minor: simplify Pipelined loop header.

Pure style; matches conventions used elsewhere in repo.

-            for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):

80-98: Vectorize the PyTorch reference conversion; current version is slow.

Removes Python loops and avoids signed-view gotchas.

-def torch_convert(tensor):
-
-    def _convert(val, pos):
-        assert val.dtype == torch.uint8
-        val = val.view(torch.int8)
-        mask = (1 << 4) - 1
-        i4_shifted = ((val >> (pos * 4)) & mask)
-        i4 = ((i4_shifted << 4) >> 4)
-        return i4.view(torch.int8)
-
-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
-    return new_tensor
+def torch_convert(tensor: torch.Tensor) -> torch.Tensor:
+    # tensor: (N, K_packed) uint8, each byte packs [low_nibble, high_nibble]
+    low = tensor & 0x0F
+    high = tensor >> 4
+    # Sign-extend 4-bit two's complement to int8
+    low = (low - (low >= 8) * 16).to(torch.int8)
+    high = (high - (high >= 8) * 16).to(torch.int8)
+    return torch.stack([low, high], dim=-1).view(tensor.shape[0], -1)

100-112: Pack both nibbles in tests; drop massive prints.

Current torch.randint(0, 16, (N, K//2)) only populates the low nibble, leaving the high nibble zero. Pack two random nibbles into each byte and avoid printing huge tensors.

-    B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
-    tl_out = tl_convert_kernel(B)
-    ref_out = torch_convert(B)
-
-    print(tl_out)
-    print(ref_out)
+    low = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
+    high = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
+    B = (low | (high << 4)).to(torch.uint8)
+    tl_out = tl_convert_kernel(B)
+    ref_out = torch_convert(B)

114-120: Use integer reference GEMM to match kernel semantics.

Float matmul + cast can hide issues. Do integer matmul directly.

-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
-    C = C.to(torch.__getattribute__(dtypeC))
+    C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))

149-152: Remove redundant staging buffer; let the pipeline handle it.

B_dequantize_prev_local + extra copy is unnecessary with T.Pipelined.

-                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
@@
-                    T.copy(B_dequantize_local, B_dequantize_prev_local)
-                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+                    T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)

Also applies to: 171-173


213-215: Fix typo in log output.

"Bset" → "Best".

-        print(f"Bset latency: {best_latency}")
+        print(f"Best latency: {best_latency}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1e3fd5c and 893ec2d.

📒 Files selected for processing (2)
  • examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2 hunks)
  • examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (4)
  • get_configs (33-41)
  • kernel (183-184)
  • kernel (190-191)
  • kernel_func (125-177)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (7)
tilelang/language/tir/op.py (1)
  • reinterpret (1816-1835)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (6)
  • get_configs (112-129)
  • torch_convert (31-58)
  • _convert (38-50)
  • matmul (132-258)
  • kernel (244-250)
  • kernel (255-256)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
tilelang/language/__init__.py (1)
  • annotate_layout (103-141)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py

242-242: autotune may be undefined, or defined from star imports

(F405)

examples/dequantize_gemm/example_dequant_gemm_w4a8.py

3-3: from tilelang.autotuner import * used; unable to detect undefined names

(F403)


25-25: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)

26-26: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)

28-28: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)

28-28: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)


181-181: autotune may be undefined, or defined from star imports

(F405)

🔇 Additional comments (4)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2)

112-118: Search space expansion looks good; please sanity-check constraints.

LGTM to broaden configs. Please confirm each tuple honors kernel invariants at runtime (e.g., K % block_K == 0, smem/register pressure within limits for threads ∈ {128, 256}). Also consider capping configs for CI runtime.
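One lightweight way to honor such invariants is to filter the Cartesian product before handing configs to the autotuner; a hedged sketch follows (the candidate values and the divisibility check are illustrative, not the example's actual grids):

import itertools

def get_pruned_configs(K: int):
    # Drop tuples that violate simple invariants (here: K divisible by block_K)
    # so the autotuner never compiles configurations that cannot run.
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [1, 2, 3]
    threads = [128, 256]
    configs = []
    for bm, bn, bk, stages, t in itertools.product(block_M, block_N, block_K,
                                                   num_stages, threads):
        if K % bk != 0:
            continue
        configs.append(dict(block_M=bm, block_N=bn, block_K=bk,
                            num_stages=stages, threads=t))
    return configs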


250-250: Confirm .prim_func is the intended return for the autotuner path.

File: examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py:250 — repo search shows only @T.prim_func decorator uses (no other .prim_func attribute usages); tilelang's JITKernel docs do expose a prim_func property, but older adapter/attribute names have changed in the past. Ensure the project pins a tilelang version that provides JITKernel.prim_func; otherwise return the PrimFunc directly or guard with getattr(kernel_func(...), "prim_func", kernel_func(...)). (tilelang.com)
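Spelled out, that getattr fallback would look like the following sketch, with kernel_func and the tuning parameters as in the example:

def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
    # Use .prim_func when the installed tilelang exposes it; otherwise fall
    # back to returning the kernel object itself.
    result = kernel_func(block_M, block_N, block_K, num_stages, threads)
    return getattr(result, "prim_func", result)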

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (2)

33-42: Config space OK; verify 512-thread blocks on target GPUs.

512 threads/block may exceed chosen arch occupancy once smem/regs are accounted for. If CI is slow, consider gating 512 behind a flag.


181-186: Confirm .prim_func return is correct for your TileLang version.

Same note as the FP4 example: ensure kernel_func(...).prim_func exists and is what @tilelang.jit/@autotune expect. If not, return the PrimFunc directly.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (5)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (5)

88-94: Use integer matmul for the reference to avoid FP roundoff and be consistent with the kernel.

This also speeds up and matches intended semantics.

Apply:

-    dtypeC = "int32"
-    B = torch_convert(qB)
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
-    C = C.to(torch.__getattribute__(dtypeC))
-    return C.transpose(0, 1)
+    B = torch_convert(qB)
+    C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))
+    return C.transpose(0, 1).to(torch.int32)

121-145: Remove redundant B_dequantize_prev_local and extra copy; let Pipelined handle staging.

Simplifies and reduces smem traffic; matches fp4 example guidance.

Apply:

-                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
@@
-                    T.copy(B_dequantize_local, B_dequantize_prev_local)
-                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+                    T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)

If you intentionally double‑buffer, please add a brief comment and keep it.


185-185: Fix typo in log message.

Apply:

-        print(f"Bset latency: {best_latency}")
+        print(f"Best latency: {best_latency}")

153-158: .prim_func attribute access is invalid; PrimFunc has no such attribute.

This raises AttributeError at runtime; return the PrimFunc directly.

Apply:

-        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
-            return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
+        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
+            return kernel_func(block_M, block_N, block_K, num_stages, threads)

68-85: torch_convert is incorrect for values >127 and very slow. Replace with vectorized, sign‑correct conversion.

Current path reinterprets uint8 as int8 before shifting; this breaks for general inputs and loops on Python side.

Apply:

-def torch_convert(tensor):
-
-    def _convert(val, pos):
-        assert val.dtype == torch.uint8
-        val = val.view(torch.int8)
-        mask = (1 << 4) - 1
-        i4_shifted = ((val >> (pos * 4)) & mask)
-        i4 = ((i4_shifted << 4) >> 4)
-
-        return i4.view(torch.int8)
-
-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
-    return new_tensor
+def torch_convert(tensor: torch.Tensor) -> torch.Tensor:
+    # tensor: (N, K_packed) uint8 -> (N, K_packed*2) int8
+    assert tensor.dtype == torch.uint8
+    low = tensor & 0x0F
+    high = tensor >> 4
+    # sign-extend 4-bit to 8-bit
+    low = (low - (low >= 8).to(tensor.dtype) * 16).to(torch.int8)
+    high = (high - (high >= 8).to(tensor.dtype) * 16).to(torch.int8)
+    unpacked = torch.stack([low, high], dim=-1)  # (N, K_packed, 2)
+    return unpacked.view(tensor.shape[0], -1)
🧹 Nitpick comments (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)

3-3: Avoid star import; import only what you use.

Prevents F403/F405 and clarifies symbols.

Apply:

-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b2cb897 and 8282f86.

📒 Files selected for processing (1)
  • examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (6)
tilelang/language/tir/op.py (1)
  • reinterpret (1816-1835)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (5)
  • main (70-89)
  • main (197-233)
  • main (269-291)
  • torch_convert (31-58)
  • _convert (38-50)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py

3-3: from tilelang.autotuner import * used; unable to detect undefined names

(F403)


153-153: autotune may be undefined, or defined from star imports

(F405)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: bot-task
  • GitHub Check: format-check
🔇 Additional comments (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)

10-21: Sign-extension logic looks correct.

Bit-pack extract + reinterpret + arithmetic shift yields proper int4→int8 sign extension.

Please confirm numerics with a quick TIR unit test across all 16 nibble values. I can add it if useful.
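For reference, a host-side sketch of that check in plain PyTorch, sweeping all 16 nibble values for both byte positions (it exercises the same shift-based sign extension; wiring it to the actual TIR helper would go through the example's _convert_test kernel):

import torch

def sign_extend_int4(nibble: torch.Tensor) -> torch.Tensor:
    # Same trick as _tir_u8_to_i4_to_i8: place the nibble in the high 4 bits
    # of an int8, then arithmetic-right-shift by 4 to sign-extend.
    return (nibble.to(torch.uint8) << 4).view(torch.int8) >> 4

nibbles = torch.arange(16, dtype=torch.uint8)  # 0b0000 .. 0b1111
expected = torch.tensor(list(range(8)) + list(range(-8, 0)), dtype=torch.int8)

# Low-nibble path (pos == 0) and high-nibble path (pos == 1) of the unpack logic.
packed_low = nibbles                 # byte with the value in the low nibble
packed_high = nibbles << 4           # byte with the value in the high nibble
assert torch.equal(sign_extend_int4(packed_low & 0x0F), expected)
assert torch.equal(sign_extend_int4(packed_high >> 4), expected)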

Comment on lines +53 to +64
for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
    T.copy(B_shared, B_local)
    for i, j in T.Parallel(block_N, block_K):
        B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
            num_bits,
            B_local[i, j // num_elems_per_byte],
            j % num_elems_per_byte,
            dtype=in_dtype,
        )
    T.copy(B_dequantize_local, C[bx * block_N, k * block_K])


⚠️ Potential issue

Tail handling missing in _convert_test; potential OOB reads.

You launch ceildiv tiles but copy full tiles without guards; assert divisibility or add masked/tail copies.

Minimal safeguard:

 def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
+    assert N % block_N == 0, "For now _convert_test requires N divisible by block_N"
+    assert K % block_K == 0, "For now _convert_test requires K divisible by block_K"

If you prefer tails, I can draft masked copies.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
-    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
-    T.copy(B_shared, B_local)
-    for i, j in T.Parallel(block_N, block_K):
-        B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
-            num_bits,
-            B_local[i, j // num_elems_per_byte],
-            j % num_elems_per_byte,
-            dtype=in_dtype,
-        )
-    T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
+def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
+    assert N % block_N == 0, "For now _convert_test requires N divisible by block_N"
+    assert K % block_K == 0, "For now _convert_test requires K divisible by block_K"
+    for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
+        T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
+        T.copy(B_shared, B_local)
+        for i, j in T.Parallel(block_N, block_K):
+            B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
+                num_bits,
+                B_local[i, j // num_elems_per_byte],
+                j % num_elems_per_byte,
+                dtype=in_dtype,
+            )
+        T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_w4a8.py around lines 53 to 64,
the loop uses T.ceildiv for k but always copies/iterates full block_K/block_N
tiles which can cause out-of-bounds reads on the last (tail) tile; either assert
that K (and N if needed) is divisible by block_K/block_N at the start, or
implement guarded/masked copies and loop bounds for the final tile: limit the
copied range to the actual remaining elements (compute tail_k = K - k*block_K
and use a masked copy or conditional bounds for B_shared/B_local and for the
inner dequantize loop), and ensure the final T.copy into C writes only the valid
tail width.

Comment on lines +108 to +109
assert K % (block_K) == 0


⚠️ Potential issue

Guard all tile boundaries in GEMM; M and N tails can read/write out of bounds.

Grid is ceildiv on M/N, but copies assume full tiles. Either add masks or assert divisibility.

Apply minimal asserts:

-        assert K % (block_K) == 0
+        assert K % block_K == 0, "K must be divisible by block_K"
+        assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
+        assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"

I can provide a masked-tail version if you want to support arbitrary sizes.
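If kernel-side masking is not worth the complexity, a host-side padding workaround is another option: zero-pad the operands up to tile multiples, run a kernel built for the padded shapes, and crop the output. A rough sketch under those assumptions (this is not the in-kernel masked-tail version mentioned above; kernel stands for the compiled TileLang kernel, and the C^T output shape follows the example):

import torch

def pad_to_multiple(x: torch.Tensor, dim: int, multiple: int) -> torch.Tensor:
    # Zero-pad dimension `dim` of x up to the next multiple of `multiple`.
    pad = (-x.shape[dim]) % multiple
    if pad == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = pad
    zeros = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, zeros], dim=dim)

def run_padded(kernel, A, qB, block_M, block_N, block_K):
    # A: (M, K) int8 activations, qB: (N, K // 2) uint8 packed weights.
    # Zero bytes unpack to int4 zeros and zero activations contribute nothing,
    # so padding K (and M/N) does not change the valid part of the result.
    M = A.shape[0]
    N = qB.shape[0]
    A_pad = pad_to_multiple(pad_to_multiple(A, 0, block_M), 1, block_K)
    qB_pad = pad_to_multiple(pad_to_multiple(qB, 0, block_N), 1, block_K // 2)
    Ct_pad = kernel(A_pad, qB_pad)  # kernel compiled for the padded M, N, K
    return Ct_pad[:N, :M]           # crop back; the example's kernel emits C^T (N, M)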

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-assert K % (block_K) == 0
+assert K % block_K == 0, "K must be divisible by block_K"
+assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
+assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_w4a8.py around lines 108-109,
the code only asserts K is divisible by block_K but leaves M and N tile tails
unguarded which can cause out-of-bounds reads/writes; add minimal guards by
asserting M % block_M == 0 and N % block_N == 0 (or replace with masked-tail
logic if you need to support arbitrary sizes), so copies operate only on full
tiles or implement masking for the M/N tails.

RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* [Bugfix] fix autotune bug

* [Example] add w4a8 gemm kernel

* fix lint: pinned the version of `ml_dtypes`
The version of ml_dtypes should be pinned in the dependency specification. If the version of ml_dtypes is too low, it may result in errors such as fp4 not being defined.

* Renames example for dequantization GEMM

* format

* add w4a8 example to ci

* fix lint