[Example] add w4a8 gemm kernel #815
Conversation
The version of ml_dtypes should be pinned in the dependency specification. If the version of ml_dtypes is too low, it may result in errors such as fp4 not being defined.
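A minimal way to surface this constraint early (a sketch; the PR pins the requirement in requirements.txt, and `importlib.metadata` is just one way to check it):

```python
# requirements.txt, as pinned by this PR:
#   ml_dtypes>=0.5.3
from importlib.metadata import version

# Surfacing the installed version early gives a clearer failure than an
# AttributeError (e.g. a missing fp4 dtype) deep inside the examples.
print(version("ml_dtypes"))  # expect >= 0.5.3
```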
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough
Adds a new INT4 (w4a8) dequantization + autotuned GEMM example and CUDA test, expands the autotuner search space and changes wrapper return behavior in an existing FP4 Hopper example, and pins the ml_dtypes requirement to >=0.5.3.
Sequence Diagram(s)
```mermaid
sequenceDiagram
    autonumber
    actor User
    participant CLI as __main__
    participant Main as main()
    participant Kernel as matmul_int8xint4()
    participant Auto as autotune()
    participant Ref as ref_program()
    participant Torch as torch_convert()
    User->>CLI: Provide m,n,k and --tune
    CLI->>Main: main(m,n,k,tune)
    alt tune == True
        Main->>Kernel: matmul_int8xint4(..., tune=True)
        Kernel->>Auto: search(get_configs())
        Auto-->>Kernel: best config + prim_func
        Main-->>User: report best config, TFLOPs
    else tune == False
        Main->>Kernel: matmul_int8xint4(..., tune=False)
        Kernel-->>Main: prim_func (untuned)
        Main->>Torch: torch_convert(qB)
        Torch-->>Ref: B_int8
        Main->>Ref: ref_program(A, B_int8)
        Ref-->>Main: reference output
        Main-->>User: latency, validation status
    end
```
```mermaid
sequenceDiagram
    autonumber
    participant Tile as Tile Kernel
    participant SM as Shared/Local Buffers
    participant Unpack as _tir_u8_to_i4_to_i8
    participant GEMM as T.gemm
    Tile->>SM: Load A/B tiles into shared/local
    Tile->>Unpack: Unpack uint8 INT4 -> int8 per element
    Unpack-->>SM: Provide dequantized B_local
    Tile->>GEMM: Accumulate A_tile x B_tile^T -> Ct
    Tile-->>Tile: Iterate over K blocks (num_stages, threads)
    Tile-->>Caller: Store Ct -> C
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested reviewers
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before requesting review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @Cunxiao2002, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new optimized General Matrix Multiply (GEMM) kernel specifically designed for 4-bit weights and 8-bit activations (W4A8), enhancing the library's capabilities for quantized operations. It includes a dedicated example to demonstrate its implementation and performance tuning. Concurrently, an existing FP4 GEMM example has been improved by expanding its autotuning search space, aiming for better performance discovery, and adapting to recent API changes in the tilelang framework. Additionally, a minor dependency version constraint has been updated.
Highlights
- New W4A8 GEMM Kernel: A new example has been added for a W4A8 (Weight 4-bit, Activation 8-bit) General Matrix Multiply (GEMM) kernel. This includes a `_tir_u8_to_i4_to_i8` function for dequantizing packed `uint8` values into `int8` within the kernel (see the packing sketch after this list).
- Expanded Autotuning Configurations: The autotuning configuration space for the existing FP4 GEMM example (`example_dequant_gemm_fp4_hopper.py`) has been significantly broadened. This includes a wider range of `block_M`, `block_N`, `block_K`, `num_stages`, and `threads` parameters to explore for optimal performance.
- Autotune API Refinement: Updates were made to the `@autotune` decorator usage in the FP4 GEMM example. The explicit `keys` argument was removed, and the return value of `kernel_func` now correctly accesses its `.prim_func` attribute, aligning with potential `tilelang` API changes.
- Dependency Update: The `ml_dtypes` dependency in `requirements.txt` has been updated to specify a minimum version of `0.5.3`.
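As a rough illustration of the packing convention these examples assume (two signed 4-bit values per byte, low nibble first; this layout is inferred from the `torch_convert` reference quoted later in this review, so treat it as an assumption):

```python
import torch

# Two int4 weights packed into one uint8 byte, low nibble first.
w = torch.tensor([-3, 5], dtype=torch.int8)
packed = (w[0].item() & 0x0F) | ((w[1].item() & 0x0F) << 4)  # -> 0x5D

# Unpacking reverses this: mask out each nibble, then sign-extend 4 -> 8 bits.
low = packed & 0x0F          # 13, which sign-extends to -3
high = (packed >> 4) & 0x0F  # 5, which stays 5
```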
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Code Review
This pull request introduces a new example for a w4a8 GEMM kernel and updates an existing fp4 example. While the changes to the existing fp4_hopper.py example are improvements to its autotuning configuration, the new example_dequant_gemm_w4a8.py file contains several issues. My review includes feedback on a critical bug that will cause an AttributeError, incorrect command-line argument handling, a fragile and inefficient reference implementation, and other opportunities for improving code quality and maintainability.
```python
            block_K=None,
            num_stages=None,
            threads=None):
        return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
```
The `kernel_func` function returns a `tvm.tir.PrimFunc` object. `PrimFunc` objects do not have a `.prim_func` attribute, so accessing it here will raise an `AttributeError` at runtime. You should return the `PrimFunc` object directly.
```diff
-        return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
+        return kernel_func(block_M, block_N, block_K, num_stages, threads)
```
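A defensive variant (the `getattr` guard is also suggested later in this review) that works whether the wrapped callable exposes `prim_func` or is already a `PrimFunc`; a sketch, with `_as_prim_func` being an illustrative helper rather than code from the PR:

```python
def _as_prim_func(obj):
    # If obj is a JITKernel-like wrapper, unwrap it; if it is already a
    # tvm.tir.PrimFunc (which has no .prim_func attribute), return it as-is.
    return getattr(obj, "prim_func", obj)

# usage inside kernel():
#   return _as_prim_func(kernel_func(block_M, block_N, block_K, num_stages, threads))
```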
```python
def torch_convert(tensor):

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        i4_shifted = ((val >> (pos * 4)) & mask)
        i4 = ((i4_shifted << 4) >> 4)
        return i4.view(torch.int8)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor
```
The `torch_convert` function has a potential bug and is inefficient.
- Potential Bug: On line 78, `val.view(torch.int8)` reinterprets the bits of a `torch.uint8` scalar as `torch.int8`. This is incorrect for values > 127. While current tests use a narrow range of inputs [0, 15], this is fragile and will produce incorrect results for the full int4 range.
- Inefficiency: The implementation uses nested Python loops, which is very slow for tensor operations in PyTorch. This can be vectorized for a significant performance improvement.
Consider replacing this function with a vectorized version like the one below, which is both correct for all input values and much faster:
```python
def torch_convert(tensor: torch.Tensor) -> torch.Tensor:
    """Vectorized conversion of a uint8 tensor (packing two int4s) to an int8 tensor."""
    # tensor is (N, K_packed) uint8
    # output is (N, K_packed * 2) int8
    # Isolate low and high 4-bit nibbles
    low_nibbles = tensor & 0x0F
    high_nibbles = tensor >> 4
    # Sign-extend nibbles from 4-bit to 8-bit.
    # Values >= 8 (0b1000) are negative in int4.
    low_nibbles = low_nibbles - (low_nibbles >= 8) * 16
    high_nibbles = high_nibbles - (high_nibbles >= 8) * 16
    # Interleave the nibbles into a new last dimension
    unpacked = torch.stack([low_nibbles, high_nibbles], dim=-1)
    # Reshape to final (N, K_packed * 2)
    return unpacked.view(tensor.shape[0], -1).to(torch.int8)
```
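A quick single-byte sanity check of the sketch above (illustrative values only):

```python
import torch

t = torch.tensor([[0xF1]], dtype=torch.uint8)
# low nibble 0x1 -> 1, high nibble 0xF -> -1 (two's-complement int4)
print(torch_convert(t))  # expected: tensor([[ 1, -1]], dtype=torch.int8)
```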
```python
    M, N, K = args.m, args.n, args.k
    # main(M, N, K, args.tune)
    main(M, N, K, True)
```
| """ | ||
| 使用左移右移方法将打包在uint8中的4位有符号整数(INT4)转换为8位有符号整数(INT8) | ||
| """ |
```python
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(torch.__getattribute__(dtypeC))
```
The reference program uses floating-point arithmetic for what should be an integer matrix multiplication. This can introduce unnecessary precision errors and is less efficient. It's better to perform the multiplication using integer types to ensure correctness and get a more reliable reference result.
```diff
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
-    C = C.to(torch.__getattribute__(dtypeC))
+    C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))
+    C = C.to(torch.__getattribute__(dtypeC))
```
```python
            T.copy(B_dequantize_local, B_dequantize_prev_local)
            T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
```
The use of B_dequantize_prev_local and the explicit copy on line 165 suggest manual software pipelining. This is likely redundant because the T.Pipelined construct is designed to manage this automatically. This manual approach can be removed to simplify the code and rely on the framework's pipelining. You can remove B_dequantize_prev_local and use B_dequantize_local directly in T.gemm.
```diff
-            T.copy(B_dequantize_local, B_dequantize_prev_local)
-            T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+            T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
```
```python
    best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
    best_latency = best_result.latency
    best_config = best_result.config
    print(f"Bset latency: {best_latency}")
```
Actionable comments posted: 1
🧹 Nitpick comments (6)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (6)
11-12: Non-English comments should be in English for better collaboration. The function docstring contains Chinese text. Consider using English for better accessibility to all contributors.
```diff
-    """
-    使用左移右移方法将打包在uint8中的4位有符号整数(INT4)转换为8位有符号整数(INT8)
-    """
+    """
+    Convert packed 4-bit signed integers (INT4) in uint8 to 8-bit signed integers (INT8)
+    using shift operations for sign extension.
+    """
```

23-28: Comments should be in English for consistency. The implementation comments are in Chinese. Consider translating them to English for better maintainability.
```diff
-    # 方法3: 使用左移和算术右移
-    # 1. 先将4位数据左移4位,放到int8的高4位
-    # 2. 然后算术右移4位,自动进行符号扩展
+    # Method 3: Use left shift and arithmetic right shift
+    # 1. First shift the 4-bit data left by 4 bits to place it in the high 4 bits of int8
+    # 2. Then arithmetic right shift by 4 bits for automatic sign extension
     i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
-    i8 = i8_shifted >> tir.const(4, "int8")  # 注意这里使用int8类型的右移(算术右移)
+    i8 = i8_shifted >> tir.const(4, "int8")  # Note: using int8 type for arithmetic right shift
```

41-42: Remove or uncomment the test function call. There's a commented-out test function call. Either remove it or uncomment if it should be executed.
```diff
-# test_convert_int4_to_int8()
-
```

102-103: Remove debug print statements before merging. Debug print statements should be removed from production code.
```diff
-    print(tl_out)
-    print(ref_out)
     assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
```

120-122: Comments should be in English for consistency. The comments explaining K dimensions are in Chinese.
```diff
-    # K是解包后的K
-    # 因此解包前是K // 2
+    # K is the unpacked dimension
+    # Therefore, the packed dimension is K // 2
```

205-205: Fix typo in print statement. There's a typo in "Bset" which should be "Best".
```diff
-    print(f"Bset latency: {best_latency}")
+    print(f"Best latency: {best_latency}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2 hunks)
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1 hunks)
- requirements.txt (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (3)
- examples/seer_attention/block_sparse_attn_tilelang.py (1): kernel_func (47-150)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1): kernel_func (19-185)
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (1): kernel_func (46-167)

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (8)
- tilelang/language/tir/op.py (1): reinterpret (1816-1835)
- tilelang/jit/__init__.py (1): jit (237-310)
- tilelang/language/allocate.py (2): alloc_shared (21-36), alloc_fragment (53-64)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/copy.py (1): copy (84-152)
- tilelang/language/parallel.py (1): Parallel (8-28)
- tilelang/autotuner/tuner.py (1): autotune (692-785)
- tilelang/profiler/__init__.py (1): assert_allclose (76-137)
🪛 GitHub Actions: CI Test on AMD
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
[error] 1-1: Code formatting changes detected by clang-format. Please review and stage the changes.
examples/dequantize_gemm/example_dequant_gemm_w4a8.py
[error] 1-1: Code formatting changes detected by clang-format. Please review and stage the changes.
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py
3-3: from tilelang.autotuner import * used; unable to detect undefined names
(F403)
24-24: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
25-25: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
27-27: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
27-27: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
174-174: autotune may be undefined, or defined from star imports
(F405)
🔇 Additional comments (8)
requirements.txt (1)
7-7: LGTM! The addition of a minimum version constraint for `ml_dtypes` is a good practice for dependency management and ensures compatibility with features used in the dequantization examples.

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2)

112-117: Search space expansion looks reasonable. The expanded autotuner search space provides more granular control over performance tuning with additional block size and thread configuration options.

253-253: Verified: autotune-decorated kernel returns .prim_func (intentional & consistent). Confirmed: examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py returns kernel_func(...).prim_func in the tuned branch (line 253) while the non-tuned branch returns the callable (line 259). Same pattern in examples/dequantize_gemm/example_dequant_gemm_w4a8.py (lines 181/187). The autotuner implementation consumes jit_kernel.prim_func (tilelang/autotuner/tuner.py lines 429, 534). Keep as-is if the API change is intended; revert to returning the callable if not.

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (5)

61-61: Consider using the full Pipelined signature for clarity. Using `T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1)` with an explicit start parameter is clearer than relying on positional arguments.

9-30: Sign extension implementation looks correct. The INT4 to INT8 conversion using shift operations for sign extension is a standard and efficient approach. The implementation correctly handles the sign bit through arithmetic right shift.

1-219: Run clang-format on examples/dequantize_gemm/example_dequant_gemm_w4a8.py. CI reported formatting issues for this file; the verification script only echoed instructions and did not apply formatting — run `clang-format -i examples/dequantize_gemm/example_dequant_gemm_w4a8.py`, commit & push the change, then re-run CI.

37-37: No change required — 512 is valid. 512 ≤ 1024 and the same [128, 256, 512] config appears across multiple examples; no change needed.

181-181: Return-value pattern verified — tuning returns .prim_func; non-tuning returns kernel callable. Confirmed in examples/dequantize_gemm: tuning paths return .prim_func (e.g., example_dequant_gemm_fp4_hopper.py ~lines 251–253; example_dequant_gemm_w4a8.py ~lines 179–181) while non-tuning branches return the kernel callable — no change required.
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)
227-229: Respect the `--tune` flag. Don't hardcode `True`; this was raised previously.
```diff
-    # main(M, N, K, args.tune)
-    main(M, N, K, True)
+    main(M, N, K, args.tune)
```
🧹 Nitpick comments (9)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
242-242: Avoid star import; explicitly import `autotune` to fix Ruff F405/F403. Star import from `tilelang.autotuner` can hide undefined names and triggers Ruff. Import only what you use. Apply at file top:
```diff
-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
```

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (8)

3-3: Avoid star import; import only what you need. Prevents F403/F405 and keeps the namespace clean.
```diff
-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
```

10-31: Document in English and clarify the `pos` domain. The implementation is correct; tighten docs for maintainers and avoid full-width punctuation.
```diff
-def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
-    """
-    使用左移右移方法将打包在uint8中的4位有符号整数(INT4)转换为8位有符号整数(INT8)
-    """
+def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
+    """
+    Convert a 4-bit signed integer (INT4) packed in a uint8 to an 8-bit signed integer (INT8)
+    via left-shift then arithmetic right-shift (sign extension).
+    Notes:
+    - nbit must be 4; dtype must be "int8"; val.dtype must be "uint8".
+    - pos selects the nibble within the byte: pos ∈ {0 (low), 1 (high)}.
+    """
```

65-65: Minor: simplify the Pipelined loop header. Pure style; matches conventions used elsewhere in the repo.
```diff
-        for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
+        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
```

80-98: Vectorize the PyTorch reference conversion; the current version is slow. Removes Python loops and avoids signed-view gotchas.
```diff
-def torch_convert(tensor):
-
-    def _convert(val, pos):
-        assert val.dtype == torch.uint8
-        val = val.view(torch.int8)
-        mask = (1 << 4) - 1
-        i4_shifted = ((val >> (pos * 4)) & mask)
-        i4 = ((i4_shifted << 4) >> 4)
-        return i4.view(torch.int8)
-
-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
-    return new_tensor
+def torch_convert(tensor: torch.Tensor) -> torch.Tensor:
+    # tensor: (N, K_packed) uint8, each byte packs [low_nibble, high_nibble]
+    low = tensor & 0x0F
+    high = tensor >> 4
+    # Sign-extend 4-bit two's complement to int8
+    low = (low - (low >= 8) * 16).to(torch.int8)
+    high = (high - (high >= 8) * 16).to(torch.int8)
+    return torch.stack([low, high], dim=-1).view(tensor.shape[0], -1)
```

100-112: Pack both nibbles in tests; drop massive prints. The current `torch.randint(0, 16, (N, K // 2))` only populates the low nibble, leaving the high nibble zero. Pack two random nibbles into each byte and avoid printing huge tensors.
```diff
-    B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
-    tl_out = tl_convert_kernel(B)
-    ref_out = torch_convert(B)
-
-    print(tl_out)
-    print(ref_out)
+    low = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
+    high = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda")
+    B = (low | (high << 4)).to(torch.uint8)
+    tl_out = tl_convert_kernel(B)
+    ref_out = torch_convert(B)
```

114-120: Use an integer reference GEMM to match kernel semantics. Float matmul + cast can hide issues. Do integer matmul directly.
```diff
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
-    C = C.to(torch.__getattribute__(dtypeC))
+    C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))
```

149-152: Remove the redundant staging buffer; let the pipeline handle it. `B_dequantize_prev_local` plus the extra copy is unnecessary with `T.Pipelined`.
```diff
-            B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
@@
-            T.copy(B_dequantize_local, B_dequantize_prev_local)
-            T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+            T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
```
Also applies to: 171-173

213-215: Fix typo in log output. "Bset" → "Best".
```diff
-    print(f"Bset latency: {best_latency}")
+    print(f"Best latency: {best_latency}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2 hunks)
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2)
- tilelang/autotuner/tuner.py (1): autotune (692-785)
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py (4): get_configs (33-41), kernel (183-184), kernel (190-191), kernel_func (125-177)

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (7)
- tilelang/language/tir/op.py (1): reinterpret (1816-1835)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (6): get_configs (112-129), torch_convert (31-58), _convert (38-50), matmul (132-258), kernel (244-250), kernel (255-256)
- tilelang/language/allocate.py (2): alloc_shared (21-36), alloc_fragment (53-64)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/copy.py (1): copy (84-152)
- tilelang/language/parallel.py (1): Parallel (8-28)
- tilelang/language/__init__.py (1): annotate_layout (103-141)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
242-242: autotune may be undefined, or defined from star imports
(F405)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py
3-3: from tilelang.autotuner import * used; unable to detect undefined names
(F403)
25-25: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
26-26: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
28-28: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
28-28: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
181-181: autotune may be undefined, or defined from star imports
(F405)
🔇 Additional comments (4)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (2)
112-118: Search space expansion looks good; please sanity-check constraints. LGTM to broaden configs. Please confirm each tuple honors kernel invariants at runtime (e.g., K % block_K == 0, smem/register pressure within limits for threads ∈ {128, 256}). Also consider capping configs for CI runtime.

250-250: Confirm `.prim_func` is the intended return for the autotuner path. File: examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py:250 — repo search shows only `@T.prim_func` decorator uses (no other `.prim_func` attribute usages); tilelang's JITKernel docs do expose a `prim_func` property, but older adapter/attribute names have changed in the past. Ensure the project pins a tilelang version that provides `JITKernel.prim_func`; otherwise return the `PrimFunc` directly or guard with `getattr(kernel_func(...), "prim_func", kernel_func(...))`. (tilelang.com)

examples/dequantize_gemm/example_dequant_gemm_w4a8.py (2)

33-42: Config space OK; verify 512-thread blocks on target GPUs. 512 threads/block may exceed chosen arch occupancy once smem/regs are accounted for. If CI is slow, consider gating 512 behind a flag.

181-186: Confirm `.prim_func` return is correct for your TileLang version. Same note as the FP4 example: ensure `kernel_func(...).prim_func` exists and is what `@tilelang.jit`/`@autotune` expect. If not, return the `PrimFunc` directly.
Actionable comments posted: 2
♻️ Duplicate comments (5)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (5)
88-94: Use integer matmul for the reference to avoid FP roundoff and be consistent with the kernel. This also speeds up and matches intended semantics.
Apply:
```diff
-    dtypeC = "int32"
-    B = torch_convert(qB)
-    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
-    C = C.to(torch.__getattribute__(dtypeC))
-    return C.transpose(0, 1)
+    B = torch_convert(qB)
+    C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))
+    return C.transpose(0, 1).to(torch.int32)
```

121-145: Remove redundant B_dequantize_prev_local and extra copy; let Pipelined handle staging. Simplifies and reduces smem traffic; matches fp4 example guidance.
Apply:
```diff
-            B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
@@
-            T.copy(B_dequantize_local, B_dequantize_prev_local)
-            T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+            T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
```
If you intentionally double-buffer, please add a brief comment and keep it.

185-185: Fix typo in log message.
Apply:
```diff
-    print(f"Bset latency: {best_latency}")
+    print(f"Best latency: {best_latency}")
```

153-158: `.prim_func` attribute access is invalid; PrimFunc has no such attribute. This raises AttributeError at runtime; return the PrimFunc directly.
Apply:
```diff
-    def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
-        return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
+    def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
+        return kernel_func(block_M, block_N, block_K, num_stages, threads)
```

68-85: torch_convert is incorrect for values > 127 and very slow. Replace with a vectorized, sign-correct conversion. The current path reinterprets uint8 as int8 before shifting; this breaks for general inputs and loops on the Python side.
Apply:
```diff
-def torch_convert(tensor):
-
-    def _convert(val, pos):
-        assert val.dtype == torch.uint8
-        val = val.view(torch.int8)
-        mask = (1 << 4) - 1
-        i4_shifted = ((val >> (pos * 4)) & mask)
-        i4 = ((i4_shifted << 4) >> 4)
-
-        return i4.view(torch.int8)
-
-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
-    return new_tensor
+def torch_convert(tensor: torch.Tensor) -> torch.Tensor:
+    # tensor: (N, K_packed) uint8 -> (N, K_packed*2) int8
+    assert tensor.dtype == torch.uint8
+    low = tensor & 0x0F
+    high = tensor >> 4
+    # sign-extend 4-bit to 8-bit
+    low = (low - (low >= 8).to(tensor.dtype) * 16).to(torch.int8)
+    high = (high - (high >= 8).to(tensor.dtype) * 16).to(torch.int8)
+    unpacked = torch.stack([low, high], dim=-1)  # (N, K_packed, 2)
+    return unpacked.view(tensor.shape[0], -1)
```
🧹 Nitpick comments (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)
3-3: Avoid star import; import only what you use. Prevents F403/F405 and clarifies symbols.
Apply:
```diff
-from tilelang.autotuner import *
+from tilelang.autotuner import autotune
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (6)
- tilelang/language/tir/op.py (1): reinterpret (1816-1835)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (5): main (70-89), main (197-233), main (269-291), torch_convert (31-58), _convert (38-50)
- tilelang/language/allocate.py (2): alloc_shared (21-36), alloc_fragment (53-64)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/copy.py (1): copy (84-152)
- tilelang/language/parallel.py (1): Parallel (8-28)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py
3-3: from tilelang.autotuner import * used; unable to detect undefined names
(F403)
153-153: autotune may be undefined, or defined from star imports
(F405)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (1)
examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)
10-21: Sign-extension logic looks correct. Bit-pack extract + reinterpret + arithmetic shift yields proper int4→int8 sign extension.
Please confirm numerics with a quick TIR unit test across all 16 nibble values. I can add it if useful.
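Such a check could look like the following sketch (plain PyTorch standing in for a TIR unit test; it exercises the same shift-based sign extension over all 16 nibble values, and `sign_extend_nibble` is an illustrative helper, not code from the PR):

```python
import torch

def sign_extend_nibble(nibbles: torch.Tensor) -> torch.Tensor:
    # Same trick as _tir_u8_to_i4_to_i8: place the nibble in the high 4 bits
    # of an int8, then arithmetic-right-shift to sign-extend.
    return ((nibbles.to(torch.uint8) << 4).view(torch.int8)) >> 4

nibbles = torch.arange(16, dtype=torch.uint8)  # all 16 possible 4-bit patterns
got = sign_extend_nibble(nibbles)
# Expected two's-complement int4 values: 0..7 stay, 8..15 map to -8..-1
expected = torch.where(nibbles >= 8,
                       nibbles.to(torch.int16) - 16,
                       nibbles.to(torch.int16)).to(torch.int8)
assert torch.equal(got, expected), (got, expected)
print("all 16 nibble values sign-extend correctly")
```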
```python
        for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
            T.copy(B_shared, B_local)
            for i, j in T.Parallel(block_N, block_K):
                B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                    num_bits,
                    B_local[i, j // num_elems_per_byte],
                    j % num_elems_per_byte,
                    dtype=in_dtype,
                )
            T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
```
Tail handling missing in _convert_test; potential OOB reads.
You launch ceildiv tiles but copy full tiles without guards; assert divisibility or add masked/tail copies.
Minimal safeguard:
```diff
 def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
+    assert N % block_N == 0, "For now _convert_test requires N divisible by block_N"
+    assert K % block_K == 0, "For now _convert_test requires K divisible by block_K"
```
If you prefer tails, I can draft masked copies.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    assert N % block_N == 0, "For now _convert_test requires N divisible by block_N"
    assert K % block_K == 0, "For now _convert_test requires K divisible by block_K"
    # ... (kernel definition as above) ...
        for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
            T.copy(B_shared, B_local)
            for i, j in T.Parallel(block_N, block_K):
                B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                    num_bits,
                    B_local[i, j // num_elems_per_byte],
                    j % num_elems_per_byte,
                    dtype=in_dtype,
                )
            T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
```
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_w4a8.py around lines 53 to 64,
the loop uses T.ceildiv for k but always copies/iterates full block_K/block_N
tiles which can cause out-of-bounds reads on the last (tail) tile; either assert
that K (and N if needed) is divisible by block_K/block_N at the start, or
implement guarded/masked copies and loop bounds for the final tile: limit the
copied range to the actual remaining elements (compute tail_k = K - k*block_K
and use a masked copy or conditional bounds for B_shared/B_local and for the
inner dequantize loop), and ensure the final T.copy into C writes only the valid
tail width.
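For illustration, the tail arithmetic described in the prompt above can be sketched in plain Python (hypothetical sizes; this is not TileLang API, just the index math a masked copy would use):

```python
import math

K, block_K = 1000, 128               # K deliberately not divisible by block_K
num_tiles = math.ceil(K / block_K)   # what T.ceildiv(K, block_K) yields
for k in range(num_tiles):
    tail_k = min(block_K, K - k * block_K)  # valid width of this tile
    # A masked copy would touch only columns [k*block_K, k*block_K + tail_k);
    # the last tile here has tail_k == 1000 - 7*128 == 104, not 128.
```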
```python
        assert K % (block_K) == 0
```
Guard all tile boundaries in GEMM; M and N tails can read/write out of bounds.
Grid is ceildiv on M/N, but copies assume full tiles. Either add masks or assert divisibility.
Apply minimal asserts:
```diff
-    assert K % (block_K) == 0
+    assert K % block_K == 0, "K must be divisible by block_K"
+    assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
+    assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"
```
I can provide a masked-tail version if you want to support arbitrary sizes.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    assert K % block_K == 0, "K must be divisible by block_K"
    assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
    assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"
```
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_w4a8.py around lines 108-109,
the code only asserts K is divisible by block_K but leaves M and N tile tails
unguarded which can cause out-of-bounds reads/writes; add minimal guards by
asserting M % block_M == 0 and N % block_N == 0 (or replace with masked-tail
logic if you need to support arbitrary sizes), so copies operate only on full
tiles or implement masking for the M/N tails.
* [Bugfix] fix autotune bug
* [Example] add w4a8 gemm kernel
* fix lint: pinned the version of `ml_dtypes` (the version of ml_dtypes should be pinned in the dependency specification; if the version of ml_dtypes is too low, it may result in errors such as fp4 not being defined)
* Renames example for dequantization GEMM
* format
* add w4a8 example to ci
* fix lint
Summary by CodeRabbit
New Features
Performance
Tests
Chores