[Example] add w4a8 gemm kernel #815
@@ -0,0 +1,200 @@
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from tvm import tir
import itertools
import torch
import argparse


def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    # Unpack the `pos`-th signed 4-bit value from a packed uint8 byte and sign-extend it to int8.
    assert nbit == 4
    assert dtype == "int8"
    assert val.dtype == "uint8"

    mask = tir.const((1 << nbit) - 1, "uint8")

    # Select the requested nibble.
    i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask

    # Sign-extend: shift the nibble into the int8 sign bit, then arithmetic-shift back down.
    i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
    i8 = i8_shifted >> tir.const(4, "int8")
    return i8


def get_configs():
    iter_params = dict(
        block_M=[64, 128],
        block_N=[64, 128],
        block_K=[128, 256],
        num_stages=[1, 2],
        threads=[128, 256, 512],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
            B: T.Tensor(B_shape, storage_dtype),
            C: T.Tensor((N, K), in_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)

            for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main


def torch_convert(tensor):

    def _convert(val, pos):
        # Unpack the `pos`-th 4-bit value of a packed uint8 byte and sign-extend it.
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        i4_shifted = ((val >> (pos * 4)) & mask)
        i4 = ((i4_shifted << 4) >> 4)
        return i4.view(torch.int8)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor


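Note: the element-wise loops in `torch_convert` run in Python and become slow for larger shapes. As an illustrative sketch only (not part of this change; the name `torch_convert_vectorized` is invented here), the same unpacking, low nibble first with 4-bit sign extension, can be expressed without Python loops:

```python
def torch_convert_vectorized(tensor):
    # Same layout as torch_convert: each uint8 byte holds two int4 values, low nibble first.
    assert tensor.dtype == torch.uint8

    def sext4(nibble_u8):
        # Reinterpret as int8, shift the nibble into the sign bit, then arithmetic-shift back.
        return (nibble_u8.view(torch.int8) << 4) >> 4

    out = torch.empty(tensor.shape[0], tensor.shape[1] * 2, dtype=torch.int8, device=tensor.device)
    out[:, 0::2] = sext4(tensor & 0x0F)  # even columns: low nibbles
    out[:, 1::2] = sext4(tensor >> 4)    # odd columns: high nibbles
    return out
```
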
def ref_program(A, qB):
    dtypeC = "int32"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(torch.__getattribute__(dtypeC))

Contributor comment on lines +91 to +92:
The reference program uses floating-point arithmetic for what should be an integer matrix multiplication. This can introduce unnecessary precision errors and is less efficient. It's better to perform the multiplication using integer types to ensure correctness and get a more reliable reference result. (An illustrative integer-based sketch follows the end of ref_program below.)
    return C.transpose(0, 1)

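As an illustration of the integer-based reference suggested in the comment above — a sketch under assumptions, not the committed suggestion: the name `ref_program_int` is invented here, and the matmul is moved to the CPU because PyTorch does not provide integer matmul on CUDA tensors.

```python
def ref_program_int(A, qB):
    B = torch_convert(qB)  # unpacked (N, K) int8 weights
    # Accumulate exactly in int32 on the CPU, then return the (N, M) result on A's device.
    C = torch.matmul(A.to(torch.int32).cpu(), B.T.to(torch.int32).cpu())
    return C.transpose(0, 1).to(A.device)
```
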
def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):

    @tilelang.jit(out_idx=[2])
    def kernel_func(block_M, block_N, block_K, num_stages, threads):
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_local_shape = (block_N, block_K)

        assert K % (block_K) == 0

Contributor comment on lines +108 to +109:
Guard all tile boundaries in GEMM; M and N tails can read/write out of bounds. Grid is ceildiv on M/N, but copies assume full tiles. Either add masks or assert divisibility. Apply minimal asserts:

    -        assert K % (block_K) == 0
    +        assert K % block_K == 0, "K must be divisible by block_K"
    +        assert M % block_M == 0, "M must be divisible by block_M (no tail handling yet)"
    +        assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"

I can provide a masked-tail version if you want to support arbitrary sizes.

        @T.prim_func
        def main(
                A: T.Tensor(A_shape, in_dtype),
                B: T.Tensor(B_shape, storage_dtype),
                Ct: T.Tensor((N, M), out_dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })

                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
                                     by * block_M:(by + 1) * block_M])

        return main

    if tune:

        @autotune(configs=get_configs(), warmup=10, rep=10)
        @tilelang.jit(out_idx=[2])
        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func

        return kernel()

    else:

        def kernel(block_M, block_N, block_K, num_stages, threads):
            return kernel_func(block_M, block_N, block_K, num_stages, threads)

        return kernel


def main(m=128, n=256, k=256, tune=False):
    total_flops = 2 * m * n * k
    if (not tune):
        kernel = matmul_int8xint4(
            m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
                block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
        profiler = kernel.get_profiler()
        profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
        print("All checks pass.")

        latency = profiler.do_bench(warmup=50)
        print(f"Tilelang: {latency} ms")

    else:
        best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
        best_latency = best_result.latency
        best_config = best_result.config
        print(f"Best latency: {best_latency}")
        print(f"Best config: {best_config}")
        print(f"Best tflops: {total_flops / best_latency * 1e-9}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
    parser.add_argument("--tune", action="store_true", help="Enable tuning")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k
    main(M, N, K, args.tune)
    # main(M, N, K, True)

@@ -4,6 +4,6 @@ numpy>=1.23.5
 tqdm>=4.62.3
 typing_extensions>=4.10.0
 cloudpickle
-ml_dtypes
+ml_dtypes>=0.5.3
 psutil
 torch

Review comment on _convert_test:
Tail handling missing in _convert_test; potential OOB reads. You launch ceildiv tiles but copy full tiles without guards; assert divisibility or add masked/tail copies. Minimal safeguard: see the sketch below. If you prefer tails, I can draft masked copies.
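A minimal sketch of the kind of safeguard described above, mirroring the divisibility asserts proposed for the GEMM kernel (an assumption, not the reviewer's committed suggestion):

```python
@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    # Refuse shapes that would need masked/tail copies until those are implemented.
    assert N % block_N == 0, "N must be divisible by block_N (no tail handling yet)"
    assert K % block_K == 0, "K must be divisible by block_K (no tail handling yet)"
    ...  # rest of _convert_test unchanged
```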