diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
index f36f02908..c5588d516 100644
--- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
+++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
@@ -110,11 +110,11 @@ def test_fp4_fp16_convert_close():
 
 
 def get_configs():
-    block_M = [128]
-    block_N = [128, 256]
-    block_K = [128]
-    num_stages = [2]
-    threads = [256]
+    block_M = [64, 128]
+    block_N = [64, 128]
+    block_K = [128, 256]
+    num_stages = [1, 2]
+    threads = [128, 256]
     splits = [1]
     _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
 
@@ -239,11 +239,7 @@ def main(
 
     if tune:
 
-        @autotune(
-            configs=get_configs(),
-            keys=["block_M", "block_N", "block_K", "num_stages", "threads", "split"],
-            warmup=10,
-            rep=10)
+        @autotune(configs=get_configs(), warmup=10, rep=10)
         @tilelang.jit(out_idx=[2])
         def kernel(block_M=None,
                    block_N=None,
@@ -251,7 +247,7 @@ def kernel(block_M=None,
                    num_stages=None,
                    threads=None,
                    split=None):
-            return kernel_func(block_M, block_N, block_K, num_stages, threads, split)
+            return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
 
         return kernel()
     else:
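For reference, the widened fp4 sweep above is still a small grid; a standalone sketch (illustrative only, not part of the patch, mirroring the lists in `get_configs()`) shows how many candidates the autotuner now benchmarks:

```python
# Enumerate the same grid get_configs() builds above: 2 * 2 * 2 * 2 * 2 * 1 = 32 configs.
import itertools

block_M, block_N, block_K = [64, 128], [64, 128], [128, 256]
num_stages, threads, splits = [1, 2], [128, 256], [1]

configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
print(len(configs))  # 32 candidates, each benchmarked with warmup=10, rep=10
```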
diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py
new file mode 100644
index 000000000..52ee8216f
--- /dev/null
+++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py
@@ -0,0 +1,200 @@
+import tilelang
+import tilelang.language as T
+from tilelang.autotuner import *
+from tvm import tir
+import itertools
+import torch
+import argparse
+
+
+def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
+    assert nbit == 4
+    assert dtype == "int8"
+    assert val.dtype == "uint8"
+
+    mask = tir.const((1 << nbit) - 1, "uint8")
+
+    i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask
+
+    i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
+    i8 = i8_shifted >> tir.const(4, "int8")
+    return i8
+
+
+def get_configs():
+    iter_params = dict(
+        block_M=[64, 128],
+        block_N=[64, 128],
+        block_K=[128, 256],
+        num_stages=[1, 2],
+        threads=[128, 256, 512],
+    )
+    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
+
+
+@tilelang.jit(out_idx=[1])
+def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "uint8"
+    B_shape = (N, K // num_elems_per_byte)
+    B_shared_shape = (block_N, block_K // num_elems_per_byte)
+    B_dequantize_shared_shape = (block_N, block_K)
+
+    @T.prim_func
+    def main(
+            B: T.Tensor(B_shape, storage_dtype),
+            C: T.Tensor((N, K), in_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
+            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
+            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
+            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
+
+            for k in T.Pipelined(0, T.ceildiv(K, block_K), num_stages=1):
+                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
+                T.copy(B_shared, B_local)
+                for i, j in T.Parallel(block_N, block_K):
+                    B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
+                        num_bits,
+                        B_local[i, j // num_elems_per_byte],
+                        j % num_elems_per_byte,
+                        dtype=in_dtype,
+                    )
+                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])
+
+    return main
+
+
+def torch_convert(tensor):
+
+    def _convert(val, pos):
+        assert val.dtype == torch.uint8
+        val = val.view(torch.int8)
+        mask = (1 << 4) - 1
+        i4_shifted = ((val >> (pos * 4)) & mask)
+        i4 = ((i4_shifted << 4) >> 4)
+
+        return i4.view(torch.int8)
+
+    N = tensor.shape[0]
+    K = tensor.shape[1]
+    new_tensor = torch.empty(N, K * 2, dtype=torch.int8, device=tensor.device)
+    for i in range(new_tensor.shape[0]):
+        for j in range(new_tensor.shape[1]):
+            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
+    return new_tensor
+
+
+def ref_program(A, qB):
+    dtypeC = "int32"
+    B = torch_convert(qB)
+    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
+    C = C.to(torch.__getattribute__(dtypeC))
+    return C.transpose(0, 1)
+
+
+def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
+
+    @tilelang.jit(out_idx=[2])
+    def kernel_func(block_M, block_N, block_K, num_stages, threads):
+        num_elems_per_byte = 8 // num_bits
+        storage_dtype = "uint8"
+        A_shape = (M, K)
+        B_shape = (N, K // num_elems_per_byte)
+        A_shared_shape = (block_M, block_K)
+        B_shared_shape = (block_N, block_K // num_elems_per_byte)
+        B_dequantize_local_shape = (block_N, block_K)
+
+        assert K % (block_K) == 0
+
+        @T.prim_func
+        def main(
+                A: T.Tensor(A_shape, in_dtype),
+                B: T.Tensor(B_shape, storage_dtype),
+                Ct: T.Tensor((N, M), out_dtype),
+        ):
+            with T.Kernel(
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
+                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
+                B_dequantize_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
+                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_local_shape, in_dtype)
+                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
+                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
+
+                T.annotate_layout({
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
+                })
+
+                T.clear(Ct_local)
+                for k in T.Pipelined(K // block_K, num_stages=num_stages):
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
+                    T.copy(B_shared, B_local)
+                    for i, j in T.Parallel(block_N, block_K):
+                        B_dequantize_local[i, j] = _tir_u8_to_i4_to_i8(
+                            num_bits,
+                            B_local[i, j // num_elems_per_byte],
+                            j % num_elems_per_byte,
+                            dtype=in_dtype,
+                        )
+                    T.copy(B_dequantize_local, B_dequantize_prev_local)
+                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
+                T.copy(Ct_local, Ct_shared)
+                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
+                                     by * block_M:(by + 1) * block_M])
+
+        return main
+
+    if tune:
+
+        @autotune(configs=get_configs(), warmup=10, rep=10)
+        @tilelang.jit(out_idx=[2])
+        def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None):
+            return kernel_func(block_M, block_N, block_K, num_stages, threads).prim_func
+
+        return kernel()
+
+    else:
+
+        def kernel(block_M, block_N, block_K, num_stages, threads):
+            return kernel_func(block_M, block_N, block_K, num_stages, threads)
+
+        return kernel
+
+
+def main(m=128, n=256, k=256, tune=False):
+    total_flops = 2 * m * n * k
+    if (not tune):
+        kernel = matmul_int8xint4(
+            m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
+                block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
+        profiler = kernel.get_profiler()
+        profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
+        print("All checks pass.")
+
+        latency = profiler.do_bench(warmup=50)
+        print(f"Tilelang: {latency} ms")
+
+    else:
+        best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
+        best_latency = best_result.latency
+        best_config = best_result.config
+        print(f"Best latency: {best_latency}")
+        print(f"Best config: {best_config}")
+        print(f"Best tflops: {total_flops / best_latency * 1e-9}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--m", type=int, default=512, help="Matrix dimension M")
+    parser.add_argument("--n", type=int, default=512, help="Matrix dimension N")
+    parser.add_argument("--k", type=int, default=512, help="Matrix dimension K")
+    parser.add_argument("--tune", action="store_true", help="Enable tuning")
+    args = parser.parse_args()
+
+    M, N, K = args.m, args.n, args.k
+    main(M, N, K, args.tune)
+    # main(M, N, K, True)
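The nibble convention the new example assumes (element j of a row lives in nibble j % 2 of byte j // 2, low nibble first, with sign extension done by a left-shift/arithmetic-right-shift pair) can be sanity-checked in isolation. A minimal sketch: `pack_int4` is a hypothetical helper written only for this round-trip test, and the script assumes it is run from `examples/dequantize_gemm` so the new module imports:

```python
# Round-trip check of the int4 packing layout decoded by torch_convert /
# _tir_u8_to_i4_to_i8: two signed 4-bit values per uint8 byte, low nibble first.
import torch

from example_dequant_gemm_w4a8 import torch_convert  # the new example module


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # x: int8 tensor with values in [-8, 7] and an even number of columns.
    lo = (x[:, 0::2] & 0xF).to(torch.uint8)  # even columns -> low nibble
    hi = (x[:, 1::2] & 0xF).to(torch.uint8)  # odd columns  -> high nibble
    return lo | (hi << 4)


x = torch.randint(-8, 8, (4, 16), dtype=torch.int8)
assert torch.equal(torch_convert(pack_int4(x)), x)  # decode(pack(x)) == x
print("int4 pack/unpack round-trip OK")
```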
diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py
index 6276f57ef..9ced0a8ed 100644
--- a/examples/dequantize_gemm/test_example_dequantize_gemm.py
+++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py
@@ -4,6 +4,7 @@
 import example_dequant_gemm_fp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper_tma
+import example_dequant_gemm_w4a8
 
 
 @tilelang.testing.requires_cuda
@@ -29,5 +30,10 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
     example_dequant_gemm_bf16_mxfp4_hopper_tma.main()
 
 
+@tilelang.testing.requires_cuda
+def test_example_dequant_gemm_w4a8():
+    example_dequant_gemm_w4a8.main()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/requirements.txt b/requirements.txt
index 35945f839..f69a5259a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,6 +4,6 @@ numpy>=1.23.5
 tqdm>=4.62.3
 typing_extensions>=4.10.0
 cloudpickle
-ml_dtypes
+ml_dtypes>=0.5.3
 psutil
 torch
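To exercise the new example outside the test harness, the entry point can be called directly; a minimal invocation sketch (assumes a CUDA device, since the test above is gated on `requires_cuda`, and a working directory of `examples/dequantize_gemm`):

```python
# Run the new W4A8 example end to end (same entry point the test calls).
import example_dequant_gemm_w4a8

# Correctness check against ref_program plus a latency report.
example_dequant_gemm_w4a8.main(m=128, n=256, k=256, tune=False)

# Or sweep the get_configs() grid with the autotuner.
# example_dequant_gemm_w4a8.main(m=512, n=512, k=512, tune=True)
```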