Skip to content
18 changes: 7 additions & 11 deletions examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def test_fp4_fp16_convert_close():


def get_configs():
block_M = [128]
block_N = [128, 256]
block_K = [128]
num_stages = [2]
threads = [256]
block_M = [64, 128]
block_N = [64, 128]
block_K = [128, 256]
num_stages = [1, 2]
threads = [128, 256]
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))

Expand Down Expand Up @@ -239,19 +239,15 @@ def main(

if tune:

@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
warmup=10,
rep=10)
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func

return kernel()
else:
Expand Down
200 changes: 200 additions & 0 deletions examples/dequantize_gemm/example_dequant_gemm_w4a8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse


def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    """Extract the ``pos``-th packed 4-bit field of a uint8 and sign-extend it to int8.

    Only nbit=4, uint8 storage and int8 output are supported.
    """
    assert nbit == 4
    assert dtype == "int8"
    assert val.dtype == "uint8"

    # Bring the target nibble down to the low bits and mask off everything else.
    nibble_mask = tir.const((1 << nbit) - 1, "uint8")
    shift_amount = pos.astype("uint8") * tir.const(nbit, "uint8")
    nibble = (val >> shift_amount) & nibble_mask

    # Sign-extend: park the nibble in the top half of the byte, reinterpret the
    # bits as signed, then arithmetic-shift back so nibble bit 3 becomes the sign.
    widened = tir.reinterpret("int8", nibble << tir.const(4, "uint8"))
    return widened >> tir.const(4, "int8")


def get_configs():
    """Enumerate the autotuning search space.

    Returns a list of config dicts, one per point in the Cartesian product of
    the per-parameter candidate lists below.
    """
    param_grid = {
        "block_M": [64, 128],
        "block_N": [64, 128],
        "block_K": [128, 256],
        "num_stages": [1, 2],
        "threads": [128, 256, 512],
    }
    names = list(param_grid)
    configs = []
    for combo in itertools.product(*param_grid.values()):
        configs.append(dict(zip(names, combo)))
    return configs


@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)

@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)

for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

Comment on lines +53 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Tail handling missing in _convert_test; potential OOB reads.

You launch ceildiv tiles but copy full tiles without guards; assert divisibility or add masked/tail copies.

Minimal safeguard:

 def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
+    assert N % block_N == 0, "For now _convert_test requires N divisible by block_N"
+    assert K % block_K == 0, "For now _convert_test requires K divisible by block_K"

If you prefer tails, I can draft masked copies.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
assert N % block_N == 0, "For now _convert_test requires N divisible by block_N"
assert K % block_K == 0, "For now _convert_test requires K divisible by block_K"
for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_w4a8.py around lines 53 to 64,
the loop uses T.ceildiv for k but always copies/iterates full block_K/block_N
tiles which can cause out-of-bounds reads on the last (tail) tile; either assert
that K (and N if needed) is divisible by block_K/block_N at the start, or
implement guarded/masked copies and loop bounds for the final tile: limit the
copied range to the actual remaining elements (compute tail_k = K - k*block_K
and use a masked copy or conditional bounds for B_shared/B_local and for the
inner dequantize loop), and ensure the final T.copy into C writes only the valid
tail width.

return main


def torch_convert(tensor):
    """Unpack a (N, K//2) uint8 tensor of packed int4 pairs into (N, K) int8.

    Each input byte holds two signed 4-bit values: the low nibble maps to the
    even output column, the high nibble to the odd one. Each nibble is
    sign-extended to int8 (range [-8, 7]).

    The original implementation looped over every output element in Python,
    which is O(N*K) interpreter iterations; this vectorized version computes
    the same result with a handful of tensor ops on the input's device.
    """
    assert tensor.dtype == torch.uint8
    N, K_packed = tensor.shape
    out = torch.empty(N, K_packed * 2, dtype=torch.int8, device=tensor.device)

    # Reinterpret the raw bytes as int8 so the right-shift is arithmetic,
    # matching the per-element reference (`val.view(torch.int8)` + shifts).
    v = tensor.view(torch.int8)
    # Low nibble -> even columns, high nibble -> odd columns.
    # `(x & 0xF) << 4 >> 4` sign-extends the 4-bit field within int8.
    out[:, 0::2] = ((v & 0xF) << 4) >> 4
    out[:, 1::2] = (((v >> 4) & 0xF) << 4) >> 4
    return out


def ref_program(A, qB):
dtypeC = "int32"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
Comment on lines +91 to +92
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The reference program uses floating-point arithmetic for what should be an integer matrix multiplication. This can introduce unnecessary precision errors and is less efficient. It's better to perform the multiplication using integer types to ensure correctness and get a more reliable reference result.

Suggested change
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
C = torch.matmul(A.to(torch.int32), B.T.to(torch.int32))
C = C.to(torch.__getattribute__(dtypeC))

return C.transpose(0, 1)


def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):

@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_local_shape = (block_N, block_K)

assert K % (block_K) == 0

Comment on lines +108 to +109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Guard all tile boundaries in GEMM; M and N tails can read/write out of bounds.

Grid is ceildiv on M/N, but copies assume full tiles. Either add masks or assert divisibility.

Apply minimal asserts:

-        assert K % (block_K) == 0
+        assert K % block_K == 0, "K must be divisible by block_K"
+        assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
+        assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"

I can provide a masked-tail version if you want to support arbitrary sizes.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert K % (block_K) == 0
assert K % block_K == 0, "K must be divisible by block_K"
assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_w4a8.py around lines 108-109,
the code only asserts K is divisible by block_K but leaves M and N tile tails
unguarded which can cause out-of-bounds reads/writes; add minimal guards by
asserting M % block_M == 0 and N % block_N == 0 (or replace with masked-tail
logic if you need to support arbitrary sizes), so copies operate only on full
tiles or implement masking for the M/N tails.

@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})

T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
num_bits,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
dtype=in_dtype,
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
Comment on lines +143 to +144
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of B_dequantize_prev_local and the explicit copy on line 165 suggest manual software pipelining. This is likely redundant because the T.Pipelined construct is designed to manage this automatically. This manual approach can be removed to simplify the code and rely on the framework's pipelining. You can remove B_dequantize_prev_local and use B_dequantize_local directly in T.gemm.

Suggested change
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)

T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])

return main

if tune:

@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The kernel_func function returns a tvm.tir.PrimFunc object. PrimFunc objects do not have a .prim_func attribute, so accessing it here will raise an AttributeError at runtime. You should return the PrimFunc object directly.

Suggested change
return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
return kernel_func(block_M, block_N, block_K, num_stages, threads)


return kernel()

else:

def kernel(block_M, block_N, block_K, num_stages, threads):
return kernel_func(block_M, block_N, block_K, num_stages, threads)

return kernel


def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul_int8xint4(
m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
print("All checks pass.")

latency = profiler.do_bench(warmup=50)
print(f"Tilelang: {latency} ms")

else:
best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Bset latency: {best_latency}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a typo in the print statement. "Bset" should be "Best".

Suggested change
print(f"Bset latency: {best_latency}")
print(f"Best latency: {best_latency}")

print(f"Best config: {best_config}")
print(f"Best tflops: {total_flops / best_latency * 1e-9}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
parser.add_argument("--tune", action="store_true", help="Enable tuning")
args = parser.parse_args()

M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
# main(M, N, K, True)
6 changes: 6 additions & 0 deletions examples/dequantize_gemm/test_example_dequantize_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper
import example_dequant_gemm_bf16_mxfp4_hopper_tma
import example_dequant_gemm_w4a8


@tilelang.testing.requires_cuda
Expand All @@ -29,5 +30,10 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
example_dequant_gemm_bf16_mxfp4_hopper_tma.main()


@tilelang.testing.requires_cuda
def test_example_dequant_gemm_w4a8():
    # Smoke-test the W4A8 dequantize-GEMM example with its default sizes
    # (requires a CUDA device).
    example_dequant_gemm_w4a8.main()


if __name__ == "__main__":
tilelang.testing.main()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ numpy>=1.23.5
tqdm>=4.62.3
typing_extensions>=4.10.0
cloudpickle
ml_dtypes
ml_dtypes>=0.5.3
psutil
torch
Loading