diff --git a/README.md b/README.md index 0c0769e7d..0a3cf381b 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ import tilelang.language as T # target currently can be "cuda" or "hip" or "cpu". # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def matmul_relu_kernel( diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index fff65b44f..e645ae147 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -40,9 +40,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "bool" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool def kernel_func(block_M, block_N, num_stages, threads): @T.macro diff --git a/benchmark/mamba2/benchmark_mamba_chunk_scan.py b/benchmark/mamba2/benchmark_mamba_chunk_scan.py index a3ed72b1d..c9f5cec67 100644 --- a/benchmark/mamba2/benchmark_mamba_chunk_scan.py +++ b/benchmark/mamba2/benchmark_mamba_chunk_scan.py @@ -202,8 +202,8 @@ def chunk_scan_fwd( num_stages=2, threads=128, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py index 6ca1402d7..643c1fd5e 100644 --- a/benchmark/matmul/benchmark_matmul.py +++ b/benchmark/matmul/benchmark_matmul.py @@ -62,9 +62,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -155,8 +155,8 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( diff --git a/benchmark/matmul/benchmark_matmul_intrinsic.py b/benchmark/matmul/benchmark_matmul_intrinsic.py index 010ce87f7..4ef860c21 100644 --- a/benchmark/matmul/benchmark_matmul_intrinsic.py +++ b/benchmark/matmul/benchmark_matmul_intrinsic.py @@ -49,22 +49,22 @@ def tl_matmul( enable_rasteration=False, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config - # chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" block_M = block_row_warps * warp_row_tiles @@ -194,9 +194,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + 
accum_dtype=T.float16, ).with_arch(arch) func = carve_template.equivalent_function() @@ -251,9 +251,9 @@ def matmul( M, N, K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, with_roller=False, block_row_warps=None, block_col_warps=None, @@ -295,9 +295,9 @@ def kernel(): args = parser.parse_args() M, N, K = args.m, args.n, args.k - in_dtype = args.dtype - out_dtype = "float32" if in_dtype == "int8" else "float16" - accum_dtype = "float32" if in_dtype == "int8" else "float16" + in_dtype = T.dtype(args.dtype) + out_dtype = T.float32 if in_dtype == T.int8 else T.float16 + accum_dtype = T.float32 if in_dtype == T.int8 else T.float16 with_roller = args.with_roller with_roller = True # Compute total floating-point operations diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py index 22b5d13cf..7ecffc26a 100644 --- a/benchmark/matmul/benchmark_matmul_sp.py +++ b/benchmark/matmul/benchmark_matmul_sp.py @@ -262,7 +262,7 @@ def main( total_flops = 2 * M * N * K # matmul(...) returns (best_latency, best_config, ref_latency) - best_result = matmul_sp(M, N, K, "float16", args.accum_dtype) + best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype) best_latency = best_result.latency best_config = best_result.config A = torch.randn(M, K, dtype=torch.float16, device="cuda") diff --git a/benchmark/matmul_fp8/benchmark_matmul.py b/benchmark/matmul_fp8/benchmark_matmul.py index 930e8a6d1..e2e62812f 100644 --- a/benchmark/matmul_fp8/benchmark_matmul.py +++ b/benchmark/matmul_fp8/benchmark_matmul.py @@ -63,9 +63,9 @@ def get_configs(args, kwargs): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -159,8 +159,8 @@ def matmul( # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3" - accum_dtype = "float" + dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn + accum_dtype = T.float32 @T.prim_func def main( diff --git a/docs/deeplearning_operators/elementwise.md b/docs/deeplearning_operators/elementwise.md index 5e1243c26..f3543c02f 100644 --- a/docs/deeplearning_operators/elementwise.md +++ b/docs/deeplearning_operators/elementwise.md @@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles ## Elementwise add in TileLang ```python -def elementwise_add(N, threads=256, dtype="bfloat16"): +def elementwise_add(N, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th The program can be compiled using the following code: ```python -program = elementwise_add(1024, threads=256, dtype="bfloat16") +program = elementwise_add(1024, threads=256, dtype=T.bfloat16) kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") ``` Launching the kernel is straightforward, just call it directly like a function: @@ -89,7 +89,7 @@ def elementwise_add( In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. 
So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this: ```python -program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16") +program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16) kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython") ``` @@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this When compiling the example below, let's set `N` to 2047: ```python -def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design. ```python -def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): @@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations. 
```python -def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"): +def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16): @T.prim_func def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)): diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index c0335492b..788aec367 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -87,8 +87,8 @@ def fast_flashattn( head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 vec_size = qk_coalesced_width v_vec_size = v_coalesced_width @@ -109,7 +109,7 @@ def main( num_q_blocks = T.ceildiv(seq_len, block_M) - bx_loop_var = T.alloc_var("int32") + bx_loop_var = T.alloc_var(T.int32) bx_loop_var = b_split with T.While(bx_loop_var < num_q_blocks): @@ -236,8 +236,8 @@ def get_bwd_configs(): @tilelang.jit(out_idx=[2]) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @@ -280,8 +280,8 @@ def flashattn_bwd( head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd_kernel( @@ -368,8 +368,8 @@ def flash_bwd_kernel( @tilelang.jit(out_idx=[1]) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 64 diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index bbb275578..ca9c361ff 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -100,8 +100,8 @@ def fast_flashattn( head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 vec_size = qk_coalesced_width v_vec_size = v_coalesced_width @@ -121,7 +121,7 @@ def main( num_q_blocks = T.ceildiv(seq_len, block_M) - bx = T.alloc_var("int32") + bx = T.alloc_var(T.int32) bx = b_split with T.While(bx < num_q_blocks): diff --git a/examples/analyze/README.md b/examples/analyze/README.md index 8171d8826..9ec0a6875 100644 --- a/examples/analyze/README.md +++ b/examples/analyze/README.md @@ -21,9 +21,9 @@ M = N = K = 1024 def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128): @T.prim_func - def main(A: T.Tensor((M, K), "float16"), - B: T.Tensor((N, K), "float16"), - C: T.Tensor((M, N), "float")): + def main(A: T.Tensor((M, K), T.float16), + B: T.Tensor((N, K), T.float16), + C: T.Tensor((M, N), T.float)): # ... (kernel definition) return main @@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128): @T.prim_func - def main(data: T.Tensor((N, H, W, C), "float16"), - kernel: T.Tensor((K, K, C, F), "float16"), - out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")): + def main(data: T.Tensor((N, H, W, C), T.float16), + kernel: T.Tensor((K, K, C, F), T.float16), + out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)): # ... 
(convolution kernel definition) return main diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py index b90be1435..db21e02f6 100644 --- a/examples/analyze/example_conv_analyze.py +++ b/examples/analyze/example_conv_analyze.py @@ -25,12 +25,12 @@ def check_hopper(): return False -def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): +def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py index e28440e1b..0367af126 100644 --- a/examples/analyze/example_gemm_analyze.py +++ b/examples/analyze/example_gemm_analyze.py @@ -15,8 +15,8 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def matmul( diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py index 3538adc38..211ef1d18 100644 --- a/examples/attention_sink/benchmark_gqa_sink_fwd.py +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -1,6 +1,7 @@ import torch import argparse from tilelang.profiler import do_bench +from tilelang import language as T import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor @@ -135,7 +136,8 @@ def main( dtype: str = "float16", tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= seq_q diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py index 76997d84b..50747e6b0 100644 --- a/examples/attention_sink/benchmark_mha_sink_fwd.py +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -1,6 +1,7 @@ import torch import argparse from tilelang.profiler import do_bench +from tilelang import language as T import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor @@ -131,7 +132,8 @@ def main( dtype: str = "float16", tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= seq_q diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index 5af787a12..541baca04 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -37,7 +37,7 @@ def flashattn_fwd( block_N=64, num_stages=1, threads=128, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @@ -49,7 +49,7 @@ def flashattn_fwd( head_kv = heads // groups q_shape = [batch, heads, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ 
-140,8 +140,8 @@ def flash_fwd( tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @@ -179,8 +179,8 @@ def make_dq_layout(dQ): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @@ -204,7 +204,7 @@ def flash_bwd_post( tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, } ) -def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"): # None for full attention +def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention if sm_scale is None: sm_scale = (1.0 / dim) ** 0.5 scale = sm_scale * 1.44269504 # log2(e) @@ -212,7 +212,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale head_kv = heads // groups q_shape = [batch, heads, seq_len, dim] kv_shape = [batch, head_kv, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 block_M, block_N, num_stages, threads = get_bwd_configs() @@ -309,8 +309,8 @@ def flash_bwd( @tilelang.jit(out_idx=-1) -def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len] @T.prim_func @@ -346,7 +346,7 @@ def maybe_contiguous(x): q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)] BATCH, H, N_CTX, D_HEAD = q.shape - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) o, lse = kernel(q, k, v, sinks) ctx.save_for_backward(q, k, v, sinks, o, lse) @@ -359,7 +359,7 @@ def backward(ctx, do): q, k, v, sinks, o, lse = ctx.saved_tensors BATCH, H, N_CTX, D_HEAD = q.shape groups = ctx.groups - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) @@ -440,7 +440,8 @@ def main( window_size: Optional[int] = None, dtype: str = "float16", ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= N_CTX @@ -472,8 +473,8 @@ def main( # Checks rtol, atol = { - "float16": (1e-2, 1e-2), - "bfloat16": (2e-2, 2e-2), + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), }[dtype] assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index feb5844f7..df157cd0f 100644 --- 
a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -41,7 +41,7 @@ def flashattn( block_N=128, num_stages=2, threads=256, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @@ -53,7 +53,7 @@ def flashattn( head_kv = heads // groups q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, head_kv, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" @@ -263,10 +263,11 @@ def main( dim: int = 128, groups: int = 8, window_size: Optional[int] = None, - dtype: str = "float16", + dtype: T.dtype = T.float16, tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= seq_q diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index 155c488e6..be405e8bc 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -36,7 +36,7 @@ def flashattn_fwd( block_N=64, num_stages=1, threads=128, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @@ -46,7 +46,7 @@ def flashattn_fwd( scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -137,8 +137,8 @@ def flash_fwd( tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @@ -176,8 +176,8 @@ def make_dq_layout(dQ): tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }, ) -def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @@ -208,7 +208,7 @@ def flashattn_bwd( dim, window_size=None, # None for full attention sm_scale=None, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): block_M, block_N, num_stages, threads = get_bwd_configs() @@ -217,7 +217,7 @@ def flashattn_bwd( scale = sm_scale * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - accum_dtype = "float" + accum_dtype = T.float32 if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @@ -315,8 +315,8 @@ def flash_bwd( @tilelang.jit(out_idx=-1) -def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"): - accum_dtype = "float" +def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16): + accum_dtype = T.float32 shape = [batch, heads, seq_len] @T.prim_func @@ -346,7 +346,7 @@ class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, sinks, window_size): BATCH, H, N_CTX, D_HEAD = q.shape - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + 
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) o, lse = kernel(q, k, v, sinks) ctx.save_for_backward(q, k, v, sinks, o, lse) @@ -364,7 +364,7 @@ def maybe_contiguous(x): return x do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)] - dtype = "float16" if q.dtype == torch.float16 else "bfloat16" + dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) delta = kernel_prep(o, do) @@ -433,8 +433,9 @@ def ref_program( return output.transpose(1, 2).contiguous() -def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] +def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16): + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= N_CTX @@ -466,8 +467,8 @@ def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window # Checks rtol, atol = { - "float16": (1e-2, 1e-2), - "bfloat16": (2e-2, 2e-2), + T.float16: (1e-2, 1e-2), + T.bfloat16: (2e-2, 2e-2), }[dtype] assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 78ac443b2..f6754bd94 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -35,7 +35,7 @@ def flashattn( block_N=64, num_stages=1, threads=128, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @@ -45,7 +45,7 @@ def flashattn( scale = sm_scale * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" @@ -246,10 +246,11 @@ def main( seq_kv: int = 256, dim: int = 128, window_size: Optional[int] = None, - dtype: str = "float16", + dtype: T.dtype = T.float16, tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= seq_q @@ -308,7 +309,7 @@ def main( parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") parser.add_argument("--dim", type=int, default=128, help="dim") parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") - parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, 
args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index decdc8f4f..ecaf2ce33 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -36,7 +36,7 @@ def flashattn( block_N=128, num_stages=2, threads=256, - dtype: str = "float16", + dtype: T.dtype = T.float16, ): if window_size is not None: assert window_size % block_N == 0, "window_size must be divisible by block_N" @@ -47,7 +47,7 @@ def flashattn( q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - accum_dtype = "float" + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" @@ -256,10 +256,11 @@ def main( seq_kv: int = 256, dim: int = 128, window_size: Optional[int] = None, - dtype: str = "float16", + dtype: T.dtype = T.float16, tune: bool = False, ): - torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + dtype = T.dtype(dtype) + torch_dtype = dtype.as_torch() if window_size is not None: print("Using sliding window attention.") assert window_size <= seq_q @@ -315,7 +316,7 @@ def main( parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value") parser.add_argument("--dim", type=int, default=128, help="dim") parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") - parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") + parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16") parser.add_argument("--tune", action="store_true", help="tune") args = parser.parse_args() main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py index 35a044e50..7b8b7b95c 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -76,13 +76,13 @@ def bitnet_158_int8xint2_decode( reduce_thread=32, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" storage_nbit = 8 num_bits = 2 @@ -94,7 +94,7 @@ def bitnet_158_int8xint2_decode( MAX_TRANSACTION_SIZE_IN_BITS = 128 micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits micro_size_k_compressed = micro_size_k // num_elems_per_byte - storage_dtype = "int8" + storage_dtype = T.int8 block_K = reduce_thread * micro_size_k use_dp4a = True @@ -194,12 +194,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): # interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] # reinterpret the data type of qweight to int32 qweight = qweight.view(np.int32) 
new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 + bits_stride = 8 if target_dtype == T.int8 else 16 mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f num_groups = 32 // bits_stride elems_per_group = bits_stride // nbits @@ -209,7 +209,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - if nbits == 1 and target_dtype == "int8": + if nbits == 1 and target_dtype == T.int8: # special handling for 1b interleave n16_weight = new_qweight & np.int32(0xF0F00F0F) n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 @@ -217,12 +217,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": + elif nbits == 2 and target_dtype == T.float16: n8_weight = new_qweight & np.int32(0xFF0000FF) n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 return n8_weight.view(np.int8) - elif nbits == 1 and target_dtype == "float16": + elif nbits == 1 and target_dtype == T.float16: n8_weight = new_qweight & 0xF000000F n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 @@ -259,4 +259,4 @@ def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype, if __name__ == "__main__": - assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, "int8", "int32", "int32") + assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index d68a01286..8c3373982 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -88,9 +88,9 @@ def bitnet_158_int8xint2_prefill( Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C. The returned prim_func expects: - - A: shape (M, K) with dtype `in_dtype` ("float16" or "int8"). + - A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8). - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte). - - C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32"). + - C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32). Details: - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter. @@ -98,15 +98,15 @@ def bitnet_158_int8xint2_prefill( - block_row_warps, block_col_warps: number of warps per block in row/col. - warp_row_tiles, warp_col_tiles: tiles per warp. - chunk: K-sized chunk per block (block_K). - - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32"). 
+ - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32). - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior. - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values. Parameters: M, N, K (int): Global matrix dimensions. - in_dtype (str): Input and decoded B element dtype; "float16" or "int8". - out_dtype (str): Output C dtype; one of "float16", "float32", "int32". - accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32"). + in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8. + out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32. + accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32). fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used). block_row_warps (int): Warps in block row dimension. block_col_warps (int): Warps in block column dimension. @@ -118,18 +118,18 @@ def bitnet_158_int8xint2_prefill( T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution. """ assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if accum_dtype == "int32": + if accum_dtype == T.int32: micro_size_k = 32 num_elems_per_byte = 4 @@ -138,7 +138,7 @@ def bitnet_158_int8xint2_prefill( local_size_compressed = local_size // num_elems_per_byte shared_scope = "shared.dyn" - storage_dtype = "int8" + storage_dtype = T.int8 # Pipeline Stage stage = 2 @@ -317,12 +317,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): # interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] +def interleave_weight(qweight, nbits=4, target_dtype=T.float16): + assert target_dtype in [T.float16, T.int8] # reinterpret the data type of qweight to int32 qweight = qweight.view(np.int32) new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 + bits_stride = 8 if target_dtype == T.int8 else 16 mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f num_groups = 32 // bits_stride elems_per_group = bits_stride // nbits @@ -332,7 +332,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - if nbits == 1 and target_dtype == "int8": + if nbits == 1 and target_dtype == T.int8: # special handling for 1b interleave n16_weight = new_qweight & np.int32(0xF0F00F0F) n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 @@ -340,12 +340,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": + elif nbits == 2 and target_dtype == T.float16: n8_weight = new_qweight & np.int32(0xFF0000FF) n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 return n8_weight.view(np.int8) - elif nbits == 1 and 
target_dtype == "float16": + elif nbits == 1 and target_dtype == T.float16: n8_weight = new_qweight & 0xF000000F n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 @@ -382,4 +382,4 @@ def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype if __name__ == "__main__": - assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, "int8", "int32", "int32") + assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32) diff --git a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py index f2a0e2e7e..e3d35df4b 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py @@ -38,18 +38,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -57,7 +57,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -183,7 +183,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None print(src_code) - if in_dtype == "int8": + if in_dtype == T.int8: A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) else: @@ -209,12 +209,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) if __name__ == "__main__": # bitblas.testing.main() - # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") - assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + # assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32) diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index afb4cc888..934b0b25e 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -41,9 +41,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "bool" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.bool def kernel_func(block_M, 
block_N, num_stages, threads): @T.macro diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 99418d5fd..77a29ebe2 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -14,8 +14,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv @tilelang.jit( @@ -43,9 +43,9 @@ def flash_attn_split( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), glse: T.Tensor([batch, heads, num_split], accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype), ): @@ -139,7 +139,7 @@ def combine( lse_logsum_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], "int32") + max_split = T.alloc_local([1], T.int32) T.annotate_layout( { @@ -177,9 +177,9 @@ def main( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - block_table: T.Tensor(shape_block_table, "int32"), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), glse: T.Tensor([batch, heads, num_split], accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype), Output: T.Tensor(shape_o, dtype), diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index 8b5cde38d..257f41543 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -11,8 +11,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv @tilelang.jit( @@ -35,9 +35,9 @@ def flash_attn_split( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), glse: T.Tensor([batch, heads, num_split], accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype), ): @@ -128,7 +128,7 @@ def combine( lse_logsum_local = T.alloc_local([1], accum_dtype) lse_max_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_local([1], accum_dtype) - max_split = T.alloc_local([1], "int32") + max_split = T.alloc_local([1], T.int32) T.annotate_layout( { @@ -166,9 +166,9 @@ def main( Q: T.Tensor(shape_q, dtype), 
K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, "int32"), - cache_seqlens: T.Tensor([batch], "int32"), - # actual_num_blocks: T.Tensor([batch], "int32"), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), glse: T.Tensor([batch, heads, num_split], accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype), Output: T.Tensor(shape_o, dtype), diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 0d759211a..2957f8c97 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -13,8 +13,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // heads_kv @tilelang.jit( @@ -37,8 +37,8 @@ def flash_attn_split( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), glse: T.Tensor([batch, heads, num_split], accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype), ): @@ -156,8 +156,8 @@ def main( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, "bool"), - cache_seqlens: T.Tensor([batch], "int32"), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), glse: T.Tensor([batch, heads, num_split], accum_dtype), Output_partial: T.Tensor(part_shape, accum_dtype), Output: T.Tensor(shape_o, dtype), diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 0cbef5e0c..b8a34e45d 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -93,7 +93,7 @@ def supply_program(params: List[KernelParam]): ) @tilelang.jit(out_idx=[-1]) def blocksparse_matmul( - M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float" + M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 ): block_mask_shape = (M // block_M, N // block_N, K // block_K) diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index ec15b292e..6bde50c51 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -5,8 +5,8 @@ from tilelang.utils.tensor import torch_assert_close # support bfloat16, float, float16 -dtype = "bfloat16" -accum_dtype = "float" +dtype = T.bfloat16 +accum_dtype = T.float32 @tilelang.jit(out_idx=[2, 3]) @@ -18,8 +18,8 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): @T.prim_func def group_per_split_token_cast( X: T.Tensor((M, N), dtype), - batch_sizes: T.Tensor((BG,), "int32"), - X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"), + batch_sizes: T.Tensor((BG,), T.int32), + X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn), X_amax: 
T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype), ): with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz): @@ -30,8 +30,8 @@ def group_per_split_token_cast( y_amax_local = T.alloc_fragment((blk_m,), accum_dtype) y_s_local = T.alloc_fragment((blk_m,), accum_dtype) y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype) - y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") - row_offset = T.alloc_fragment((1,), "int32") + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) + row_offset = T.alloc_fragment((1,), T.int32) T.annotate_layout( { @@ -163,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tenso def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): if batch_sizes is None: batch_sizes = [2048, 6144] - if dtype == "float": + if dtype == T.float: x = torch.randn(M, N, device="cuda", dtype=torch.float32) - elif dtype == "float16": + elif dtype == T.float16: x = torch.randn(M, N, device="cuda", dtype=torch.float16) - elif dtype == "bfloat16": + elif dtype == T.bfloat16: x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) else: raise ValueError(f"Unsupported dtype: {dtype}") diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 45281ab14..aa6d14884 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -7,14 +7,14 @@ @tilelang.jit(out_idx=[1, 2]) def per_token_cast_to_fp8(M, N, blk_m): - dtype = "float" + dtype = T.float group_size = 128 fp8_min = -448.0 fp8_max = 448.0 @T.prim_func def per_token_cast( - X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) + X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype) ): with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by): row = bx @@ -23,7 +23,7 @@ def per_token_cast( y_amax_local = T.alloc_fragment((blk_m,), dtype) y_s_local = T.alloc_fragment((blk_m,), dtype) y_q_local = T.alloc_fragment((blk_m, group_size), dtype) - y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3") + y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn) T.annotate_layout( { diff --git a/examples/compile_flags/usecase.py b/examples/compile_flags/usecase.py deleted file mode 100644 index 80e2b784b..000000000 --- a/examples/compile_flags/usecase.py +++ /dev/null @@ -1,54 +0,0 @@ -import tilelang -import tilelang.language as T - - -# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Initialize Kernel Context - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - T.copy(A[by * block_M, ko * block_K], A_shared) - T.copy(B[ko * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -M = 1024 -N = 
1024 -K = 1024 -block_M = 128 -block_N = 128 -block_K = 32 - -func = matmul(M, N, K, block_M, block_N, block_K) - -jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr") -# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"]) -# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"]) - -import torch - -a = torch.randn(M, K, device="cuda", dtype=torch.float16) -b = torch.randn(K, N, device="cuda", dtype=torch.float16) - -c = jit_kernel(a, b) - -print(c) - -ref_c = a @ b - -torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) -print("Kernel output matches PyTorch reference.") diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index a84e5878a..ffd3972fb 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -25,12 +25,12 @@ def main(A, B): @tilelang.jit(out_idx=[2]) -def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 600b608a3..59588ac4f 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -75,13 +75,13 @@ def get_heuristic_config() -> dict: @tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[2]) def convolution( - N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float" + N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32 ): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 is_hopper = check_hopper() @T.prim_func diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 8aba91406..18467a811 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -20,11 +20,11 @@ def tl_gemm( accum_dtype, ): assert in_dtype in [ - "float8_e4m3", + T.float8_e4m3fn, ], "Currently only float8_e4m3 is supported" assert out_dtype in [ - "bfloat16", - "float32", + T.bfloat16, + T.float32, ], "Currently only float16 and float32 are supported" group_size = 128 @@ -44,14 +44,14 @@ def main( A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype), - scales_a: T.Tensor(Scales_A_shape, "float32"), - scales_b: T.Tensor(Scales_B_shape, "float32"), + scales_a: T.Tensor(Scales_A_shape, T.float32), + scales_b: T.Tensor(Scales_B_shape, T.float32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): A_shared = 
T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - Scale_C_shared = T.alloc_shared((block_M), "float32") + Scale_C_shared = T.alloc_shared((block_M), T.float32) C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) @@ -176,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp def main(): - assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32") + assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32) if __name__ == "__main__": - for dtype in ["float8_e4m3"]: - for out_dtype in ["bfloat16", "float32"]: + for dtype in [T.float8_e4m3fn]: + for out_dtype in [T.bfloat16, T.float32]: for block_N in [16, 32, 64, 128]: - assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32") + assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32) diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py index 499583798..a9035793b 100644 --- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py +++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py @@ -36,8 +36,8 @@ def get_configs(): ) def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128): scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 733ae3c46..0d141b4b3 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -15,8 +15,8 @@ ) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py index dee05c1e9..23001bde8 100644 --- a/examples/deepseek_mla/example_mla_decode_paged.py +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -17,8 +17,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc if softmax_scale is None: softmax_scale = (dv + dpe) ** -0.5 scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = h_q // h_kv VALID_BLOCK_H = min(block_H, kv_group_num) assert h_kv == 1, "h_kv must be 1" @@ -30,8 +30,8 @@ def flash_mla_kernel( Q_pe: T.Tensor([batch, h_q, dpe], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + CACHE_SEQLENS: T.Tensor([batch], T.int32), Output: 
T.Tensor([batch, h_q, dv], dtype), ): with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): @@ -103,8 +103,8 @@ def flash_mla_split_kv_kernel( Q_pe: T.Tensor([batch, h_q, dpe], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - CACHE_SEQLENS: T.Tensor([batch], "int32"), + BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + CACHE_SEQLENS: T.Tensor([batch], T.int32), glse: T.Tensor([batch, h_q, num_split], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), ): @@ -224,8 +224,8 @@ def main_split( Q_pe: T.Tensor([batch, h_q, dpe], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), glse: T.Tensor([batch, h_q, num_split], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), Output: T.Tensor([batch, h_q, dv], dtype), @@ -239,8 +239,8 @@ def main_no_split( Q_pe: T.Tensor([batch, h_q, dpe], dtype), KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype), K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype), - block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"), - cache_seqlens: T.Tensor([batch], "int32"), + block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32), + cache_seqlens: T.Tensor([batch], T.int32), glse: T.Tensor([batch, h_q, num_split], dtype), Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype), Output: T.Tensor([batch, h_q, dv], dtype), diff --git a/examples/deepseek_mla/example_mla_decode_persistent.py b/examples/deepseek_mla/example_mla_decode_persistent.py index 305fd30ed..b6a1300a2 100644 --- a/examples/deepseek_mla/example_mla_decode_persistent.py +++ b/examples/deepseek_mla/example_mla_decode_persistent.py @@ -16,8 +16,8 @@ ) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index 3fb90a556..8e317fa00 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -27,8 +27,8 @@ ) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): sm_scale = float(softmax_scale * 1.44269504) # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" diff --git a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py index 4a1a84cf1..fa39fa498 100644 --- a/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py +++ b/examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py @@ -15,9 +15,9 @@ ) def flashattn(batch, heads, kv_head_num, 
seqlen_kv, dim, pe_dim, block_N, block_H): scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - q_dtype = "float8_e4m3" - accum_dtype = "float" + dtype = T.float16 + q_dtype = T.float8_e4m3fn + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index ea3f72c50..dadb4b4cb 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -479,9 +479,9 @@ def tilelang_sparse_attention( q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(block_T, tilelang.math.next_power_of_2(dim)) @@ -876,7 +876,7 @@ def run_benchmark_suite(impl="all"): parser.add_argument("--dim", type=int, default=128, help="Head dimension") parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks") parser.add_argument("--block_size", type=int, default=32, help="Block size") - parser.add_argument("--dtype", type=str, default="float16", help="Data type (float16 or float32)") + parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)") parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor") parser.add_argument("--iterations", type=int, default=100, help="Number of iterations") parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") @@ -901,7 +901,7 @@ def run_benchmark_suite(impl="all"): if args.suite: run_benchmark_suite(impl=args.impl) else: - dtype = torch.float16 if args.dtype == "float16" else torch.float32 + dtype = torch.float16 if args.dtype == T.float16 else torch.float32 if args.impl in ["tilelang", "all"]: print("Benchmarking TileLang implementation:") diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 56e98a95b..41f1dd86b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -49,9 +49,9 @@ def tilelang_kernel_fwd( o_slc_shape = [batch, seq_len, heads, dim] lse_slc_shape = [batch, seq_len, heads] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) @@ -170,8 +170,8 @@ def tilelang_kernel_bwd_dkv( block_size=64, groups=1, selected_blocks=16, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): if scale is None: sm_scale = (1.0 / dim) ** 0.5 @@ -217,7 +217,7 @@ def flash_bwd_dkv( DO_slc: T.Tensor(do_slc_shape, dtype), DK: T.Tensor(dk_shape, dtype), DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + BlockMask: T.Tensor(block_mask_shape, T.int32), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -340,8 +340,8 @@ def tilelang_kernel_bwd_dqkv( block_size=64, groups=1, selected_blocks=16, - 
dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): if scale is None: sm_scale = (1.0 / dim) ** 0.5 @@ -388,7 +388,7 @@ def flash_bwd_dqkv( DQ: T.Tensor(dq_shape, dtype), DK: T.Tensor(dk_shape, dtype), DV: T.Tensor(dv_shape, dtype), - BlockMask: T.Tensor(block_mask_shape, "int32"), + BlockMask: T.Tensor(block_mask_shape, T.int32), ): with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh): K_shared = T.alloc_shared([BS, BK], dtype) @@ -505,8 +505,8 @@ def tilelang_kernel_preprocess( heads, seq_len, dim, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, blk=32, ): from tilelang import language as T @@ -548,7 +548,7 @@ def tilelang_kernel_block_mask( seq_len, selected_blocks, block_size, - dtype="int32", + dtype=T.int32, ): from tilelang import language as T diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 38fc51a9f..b7eea5804 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -35,9 +35,9 @@ def native_sparse_attention( q_shape = [batch, 1, heads, dim] # Changed seq_len to 1 kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1 - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index a8dd26b63..ad36b1040 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -26,9 +26,9 @@ def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, b q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] block_indices_shape = [batch, seq_len, head_kv, selected_blocks] - block_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index af87db8b2..b52ebe42e 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -38,12 +38,12 @@ def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scal block_counts_shape = [c_seq_len, head_kv] offsets_shape = [batch + 1] token_indices_shape = [c_seq_len, 2] - block_indices_dtype = "int32" - block_counts_dtype = "int32" - offsets_dtype = "int32" - token_indices_dtype = "int32" - dtype = "float16" - accum_dtype = "float" + block_indices_dtype = T.int32 + block_counts_dtype = T.int32 + offsets_dtype = T.int32 + token_indices_dtype = T.int32 + dtype = T.float16 + accum_dtype = T.float32 block_S = block_size block_T = min(128, tilelang.math.next_power_of_2(dim)) diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index 305e2afc4..01ad0a734 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -97,9 +97,9 @@ def mqa_attn_return_logits( ): if block_Q is None: block_Q = 128 // 
heads - dtype = "float8_e4m3" - accum_dtype = "float" - index_dtype = "int32" + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") @@ -178,8 +178,8 @@ def clean_logits_( seq_len = T.dynamic("seq_len") seq_len_kv = T.dynamic("seq_len_kv") - dtype = "float" - indices_dtype = "int32" + dtype = T.float + indices_dtype = T.int32 @T.prim_func def clean_logits_kernel( diff --git a/examples/deepseek_v32/inference/kernel.py b/examples/deepseek_v32/inference/kernel.py index 262343536..25abf15d5 100644 --- a/examples/deepseek_v32/inference/kernel.py +++ b/examples/deepseek_v32/inference/kernel.py @@ -11,21 +11,21 @@ tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, } -FP8 = "float8_e4m3" -BF16 = "bfloat16" -FP32 = "float32" +FP8 = T.float8_e4m3fn +BF16 = T.bfloat16 +FP32 = T.float32 def fast_log2_ceil(x): - bits_x = T.reinterpret("uint32", x) + bits_x = T.reinterpret(T.uint32, x) exp_x = (bits_x >> 23) & 0xFF man_bits = bits_x & ((1 << 23) - 1) - return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) def fast_pow2(x): bits_x = (x + 127) << 23 - return T.reinterpret("float32", bits_x) + return T.reinterpret(T.float32, bits_x) def fast_round_scale(amax, fp8_max_inv): @@ -107,8 +107,8 @@ def act_quant(x: torch.Tensor, @tilelang.jit(pass_configs=pass_configs) -def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"): - assert out_dtype in [BF16, "float32"] +def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32): + assert out_dtype in [BF16, T.float32] M = T.dynamic("M") group_size = 128 diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index 1266e70ed..d8035c1ba 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -13,11 +13,11 @@ def preprocess( D, block_ND=32, num_stages=5, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 shape = [B, S, H, D] @T.prim_func @@ -52,11 +52,11 @@ def postprocess( kv_group=1, block_N=64, threads=128, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 dkv_shape = [B, S_kv, kv_group, D + D_tail] @T.prim_func @@ -95,15 +95,15 @@ def bwd( block_size=32, num_stages=0, threads=256, - indices_dtype="int32", - dtype="bfloat16", - accum_dtype="float", + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, ): assert is_causal == True, "non-casual is not supported now" assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" - assert dtype == "bfloat16" - assert accum_dtype == "float" - assert indices_dtype == "int32" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 if sm_scale is None: sm_scale = (D + D_tail) ** (-0.5) @@ -116,9 +116,9 @@ def bwd( indices_shape = [B, S, kv_group, topk] delta_shape = [B, S, H] lse_shape = [B, S, H] - assert indices_dtype == "int32" - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 H = H_kv padded_H = 
max(tilelang.math.next_power_of_2(H_kv), 16) diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index 3b963c751..f9c4d2f04 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -44,9 +44,9 @@ def sparse_mla_fwd( o_shape = [batch, seq_len, heads, dim] indices_shape = [batch, seq_len, kv_group, topk] lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 972160c99..54e1a7209 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -53,9 +53,9 @@ def sparse_mla_fwd( o_shape = [batch, seq_len, heads, dim] indices_shape = [batch, seq_len, kv_group, topk] lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index cf87f526d..244f74c69 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -8,24 +8,24 @@ def convert_to_uint16(x): - hval = T.Cast("float16", x) - bits_uint = T.reinterpret("uint16", hval) + hval = T.Cast(T.float16, x) + bits_uint = T.reinterpret(T.uint16, hval) bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000)) return bits_uint >> 8 def convert_to_uint32(x): - bits_uint = T.reinterpret("uint32", x) + bits_uint = T.reinterpret(T.uint32, x) bits_uint = T.if_then_else( x < 0, - ~bits_uint & T.Cast("uint32", (0xFFFFFFFF)), - bits_uint | T.Cast("uint32", (0x80000000)), + ~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)), + bits_uint | T.Cast(T.uint32, (0x80000000)), ) return bits_uint @tilelang.jit(pass_configs=pass_configs) -def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"): +def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32): batch = T.dynamic("batch") seq_len = T.dynamic("seq_len") RADIX = 1 << 8 @@ -42,20 +42,20 @@ def tl_topk_kernel( with T.Kernel(batch, threads=BLOCK_SIZE) as (bx): tx = T.get_thread_binding() - s_threshold_bin_id = T.alloc_shared([1], "int32") - s_histogram = T.alloc_shared([RADIX + 1], "int32") - s_num_input = T.alloc_shared([2], "int32") - s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32") - - l_threshold_bin_id = T.alloc_var("int32") - l_new_topk = T.alloc_var("int32") - l_num_input = T.alloc_var("int32") - l_bin_id32 = T.alloc_var("int32") - l_val = T.alloc_var("int32") - l_start_pos = T.alloc_var("int32") - l_start_idx = T.alloc_var("int32") - l_end_idx = T.alloc_var("int32") - l_out_pos = T.alloc_var("int32") + s_threshold_bin_id = T.alloc_shared([1], T.int32) + s_histogram = T.alloc_shared([RADIX + 1], T.int32) + s_num_input = T.alloc_shared([2], T.int32) + s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32) + + l_threshold_bin_id = T.alloc_var(T.int32) + l_new_topk = T.alloc_var(T.int32) + l_num_input = T.alloc_var(T.int32) + l_bin_id32 = T.alloc_var(T.int32) + l_val = T.alloc_var(T.int32) + l_start_pos = T.alloc_var(T.int32) + l_start_idx = T.alloc_var(T.int32) + l_end_idx = T.alloc_var(T.int32) + l_out_pos = T.alloc_var(T.int32) l_new_topk = topk l_start_idx = starts[bx] @@ -99,7 +99,7 @@ 
def tl_topk_kernel( input_idx = s * BLOCK_SIZE + tx if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: bin_id = convert_to_uint16(input[bx, input_idx]) - l_bin_id32 = T.Cast("int32", bin_id) + l_bin_id32 = T.Cast(T.int32, bin_id) if l_bin_id32 > l_threshold_bin_id: # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) @@ -128,7 +128,7 @@ def tl_topk_kernel( for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: l_bin_id32 = T.Cast( - "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) ) T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() @@ -157,7 +157,7 @@ def tl_topk_kernel( T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: l_bin_id32 = T.Cast( - "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) ) if l_bin_id32 > l_threshold_bin_id: pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index ba3e0b4a7..2d9c945b3 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -50,7 +50,7 @@ def matmul( in_dtype, out_dtype, accum_dtype, - source_format="uint", + source_format=T.uint32, num_bits=4, fast_dequant=True, block_M=256, @@ -90,7 +90,7 @@ def matmul( A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel. """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte @@ -121,7 +121,7 @@ def matmul( assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin. @@ -131,13 +131,13 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel. Notes and preconditions: - - Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`. + - Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`. - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel. - The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly. 
- The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -193,7 +193,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16. @@ -204,7 +204,7 @@ def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the dequantized bfloat16 block into B_dequantize_shared. Constraints: - - Supports only in_dtype="fp4" and out_dtype="bfloat16". + - Supports only in_dtype="fp4" and out_dtype=T.bfloat16. - The helper assumes nbit == 4 and produces bfloat16 values. - The macro uses a fixed test-scale of 0 (no per-element scaling) as written. @@ -212,7 +212,7 @@ def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str): """ @@ -228,32 +228,32 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale val (tir.PrimExpr): A uint8 value containing packed FP4 elements. pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. - dtype (str): Target dtype string; must be "bfloat16". + dtype (str): Target dtype string; must be T.bfloat16. Returns: tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. Notes: - - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". + - The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8. - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 bit fields and clamps the computed exponent to fit into 8 bits. 
""" assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we use the max function to limit the exponential part to 8 bits - e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") + e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret( - "bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), ) return val_bf16 @@ -364,7 +364,7 @@ def ref_program_twiddling(A, qB): Returns: torch.Tensor: Result matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -384,7 +384,7 @@ def ref_program_simple(A, qB): Returns: torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N). """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -410,15 +410,15 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): """ total_flops = 2 * m * n * k if tune: - kernel = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant) + kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, fast_dequant=fast_dequant, block_M=256, diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index 1091306c6..cc0375a1d 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -20,31 +20,31 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + dtype (str): Destination dtype string (must be T.bfloat16). Returns: tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". 
+ - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. """ assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") + m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret( - "bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), ) return val_bf16 @@ -90,7 +90,7 @@ def matmul( in_dtype, out_dtype, accum_dtype, - source_format="uint", + source_format=T.uint32, num_bits=4, scale_size=32, fast_dequant=True, @@ -116,7 +116,7 @@ def matmul( Parameters: M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). + out_dtype (str): output tensor element type (e.g., T.bfloat16). accum_dtype (str): accumulation type used for the inner GEMM. source_format (str, optional): format string passed to intrinsic selector (default "uint"). num_bits (int, optional): number of bits per quantized element in B (default 4). @@ -141,7 +141,7 @@ def matmul( - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shape = (M, K) @@ -170,7 +170,7 @@ def matmul( assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. @@ -181,12 +181,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. 
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -262,19 +262,19 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. Notes: - - Only supports in_dtype="fp4" and out_dtype="bfloat16". + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k): @@ -394,7 +394,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) @@ -417,7 +417,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias @@ -441,7 +441,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): No in-place modification is performed on inputs (a local floating copy of B is scaled). """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) @@ -469,7 +469,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): No in-place modification is performed on inputs (a local floating copy of B is scaled). 
""" - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias @@ -498,16 +498,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias ) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, scale_size=scale_size, block_M=256, diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py index 12395df0a..9e90418bc 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -20,31 +20,31 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). - dtype (str): Destination dtype string (must be "bfloat16"). + dtype (str): Destination dtype string (must be T.bfloat16). Returns: tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. Notes: - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". + - Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8. - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. 
""" assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we may use the min function to limit the exponential part to 8 bits # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") + m_f4 = f4 & tir.const(1, T.uint16) val_bf16 = tir.reinterpret( - "bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"), + T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16), ) return val_bf16 @@ -90,7 +90,7 @@ def matmul( in_dtype, out_dtype, accum_dtype, - source_format="uint", + source_format=T.uint32, num_bits=4, scale_size=32, fast_dequant=True, @@ -116,7 +116,7 @@ def matmul( Parameters: M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split). in_dtype (str): element type of A (e.g., "fp4" in this file). - out_dtype (str): output tensor element type (e.g., "bfloat16"). + out_dtype (str): output tensor element type (e.g., T.bfloat16). accum_dtype (str): accumulation type used for the inner GEMM. source_format (str, optional): format string passed to intrinsic selector (default "uint"). num_bits (int, optional): number of bits per quantized element in B (default 4). @@ -141,7 +141,7 @@ def matmul( - An assertion enforces that K % (block_K * split) == 0. """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shape = (M, K) @@ -170,7 +170,7 @@ def matmul( assert func_name is not None, "mxfp_intrin_info is not found" import_source = import_source - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. @@ -181,12 +181,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. 
""" assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -262,19 +262,19 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16. Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared. Notes: - - Only supports in_dtype="fp4" and out_dtype="bfloat16". + - Only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion. - Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper. """ assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): @@ -402,7 +402,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): @@ -427,7 +427,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): Returns: torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16. """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert_bit_twiddling(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): @@ -453,7 +453,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): No in-place modification is performed on inputs (a local floating copy of B is scaled). """ - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): @@ -483,7 +483,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): No in-place modification is performed on inputs (a local floating copy of B is scaled). 
""" - dtypeC = "bfloat16" + dtypeC = T.bfloat16 B = torch_convert(qB) for i in range(B.shape[0]): for j in range(B.shape[1]): @@ -514,16 +514,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, if tune: kernel = matmul( - m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias + m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias ) else: kernel = matmul( m, n, k, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=4, scale_size=scale_size, block_M=256, diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index c2b972a09..37826874b 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -26,7 +26,7 @@ def matmul( from tilelang.quantize import _tir_packed_to_unsigned_convert num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) A_shape = (M, K) @@ -149,21 +149,21 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" num_bits = 4 num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -182,7 +182,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( block_M = block_row_warps * warp_row_tiles block_N = block_col_warps * warp_col_tiles - block_K = 32 if in_dtype == "float16" else 64 + block_K = 32 if in_dtype == T.float16 else 64 chunk = block_K // reduce_k is_smooth_a = False @@ -365,7 +365,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct assert src_code is not None num_bits = 4 num_elems_per_byte = 8 // num_bits - storage_dtype = "int8" + storage_dtype = T.int8 A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) @@ -417,13 +417,13 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct @tilelang.testing.requires_package("bitblas") def test_run_dequantize_gemm(): - run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128) - run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128) + run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128) @tilelang.testing.requires_package("bitblas") def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): - assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3) + 
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3) def main(): diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index 352637de5..79345771d 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -9,22 +9,22 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "float16" - assert val.dtype == "uint8" + assert dtype == T.float16 + assert val.dtype == T.uint8 # e_f4 == 0 -> e_f16 = 0 # e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14 # s1e2m1 - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") - e_f16 = e_f4 + tir.const(14, "uint16") - m_f4 = f4 & tir.const(1, "uint16") + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) + e_f16 = e_f4 + tir.const(14, T.uint16) + m_f4 = f4 & tir.const(1, T.uint16) m_f16 = m_f4 val_f16 = tir.reinterpret( - "float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16") + T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16) ) - # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16) return val_f16 @@ -60,7 +60,7 @@ def _convert(val, pos): @tilelang.jit(out_idx=[1]) def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 B_shape = (N, K // num_elems_per_byte) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @@ -98,7 +98,7 @@ def test_fp4_fp16_convert_close(): K, block_N, block_K, - "float16", + T.float16, ) B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8) @@ -125,7 +125,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) @@ -241,7 +241,7 @@ def kernel(block_M, block_N, block_K, num_stages, threads, split=1): def ref_program(A, qB): - dtypeC = "float16" + dtypeC = T.float16 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -252,7 +252,7 @@ def main(m=256, n=256, k=256, tune=False): total_flops = 2 * m * n * k if not tune: - kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( + kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)( block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1 ) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) @@ -265,7 +265,7 
@@ def main(m=256, n=256, k=256, tune=False): print("Tile-lang: {:.2f} ms".format(latency)) print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) else: - best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune) + best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Best latency: {best_latency}") diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 3ff726738..61baa668e 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -9,15 +9,15 @@ def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "int8" - assert val.dtype == "uint8" + assert dtype == T.int8 + assert val.dtype == T.uint8 - mask = tir.const((1 << nbit) - 1, "uint8") + mask = tir.const((1 << nbit) - 1, T.uint8) - i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask + i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask - i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8")) - i8 = i8_shifted >> tir.const(4, "int8") + i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8)) + i8 = i8_shifted >> tir.const(4, T.int8) return i8 @@ -35,7 +35,7 @@ def get_configs(): @tilelang.jit(out_idx=[1]) def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 B_shape = (N, K // num_elems_per_byte) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) @@ -85,7 +85,7 @@ def _convert(val, pos): def ref_program(A, qB): - dtypeC = "int32" + dtypeC = T.int32 B = torch_convert(qB) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) @@ -96,7 +96,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune @tilelang.jit(out_idx=[2]) def kernel_func(block_M, block_N, block_K, num_stages, threads): num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) @@ -166,7 +166,7 @@ def kernel(block_M, block_N, block_K, num_stages, threads): def main(m=128, n=256, k=256, tune=False): total_flops = 2 * m * n * k if not tune: - kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( + kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)( block_M=32, block_N=32, block_K=128, num_stages=1, threads=128 ) profiler = kernel.get_profiler() @@ -177,7 +177,7 @@ def main(m=128, n=256, k=256, tune=False): print(f"Tilelang: {latency} ms") else: - best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune) + best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune) best_latency = best_result.latency best_config = best_result.config print(f"Bset latency: {best_latency}") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index 3f1214670..dea2e5ddd 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -17,7 +17,7 @@ def dequantize_gemv( 
out_dtype: str, accum_dtype: str, num_bits: int = 4, - storage_dtype: str = "int8", + storage_dtype: T.dtype = T.int8, source_format: str = "uint", n_partition: int = 4, reduce_thread: int = 32, @@ -51,7 +51,7 @@ def dequantize_gemv( C_shape = (M, N) dp4a_size = 4 - use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 import_source: Optional[str] = None func_name: str = "" @@ -159,11 +159,11 @@ def main() -> None: M = 1 N = 1024 K = 1024 - in_dtype = "float16" - out_dtype = "float16" - accum_dtype = "float16" + in_dtype = T.float16 + out_dtype = T.float16 + accum_dtype = T.float16 num_bits = 4 - storage_dtype = "int8" + storage_dtype = T.int8 source_format = "uint" n_partition = 4 reduce_thread = 32 diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index 098f814c2..9921c6bfe 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -49,7 +49,7 @@ def matmul( in_dtype, out_dtype, accum_dtype, - source_format="uint", + source_format=T.uint32, num_bits=4, scale_size=32, fast_dequant=True, @@ -83,8 +83,8 @@ def matmul( topk (int): number of experts selected per token. E (int): number of experts. padding_M (int): padded number of tokens after grouping and block alignment. - in_dtype (str): element type of A (e.g., "bfloat16"). - out_dtype (str): output tensor element type (e.g., "bfloat16"). + in_dtype (str): element type of A (e.g., T.bfloat16). + out_dtype (str): output tensor element type (e.g., T.bfloat16). accum_dtype (str): accumulation type used for the inner GEMM. source_format (str, optional): format string passed to intrinsic selector (default "uint"). num_bits (int, optional): number of bits per quantized element in B (default 4). @@ -111,7 +111,7 @@ def matmul( """ num_elems_per_byte = 8 // num_bits - storage_dtype = "uint8" + storage_dtype = T.uint8 QK = K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte A_shared_shape = (block_M, block_K) @@ -137,7 +137,7 @@ def matmul( import_source = import_source # the dequant part is the same as in dequant_gemm - def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16): """ Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16. The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and: @@ -147,12 +147,12 @@ def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"): - Writes the scaled BF16 results into B_dequantize_shared. Notes: - - This factory only supports in_dtype="fp4" and out_dtype="bfloat16". + - This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16. - The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro. - The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime. 
""" assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] # Some variables for dequantization in each thread MAX_TRANSACTION_SIZE_BITS = 128 @@ -227,9 +227,9 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared, return fast_dequant_bf16_fp4_twiddling - def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): + def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16): assert in_dtype in ["fp4"] - assert out_dtype in ["bfloat16"] + assert out_dtype in [T.bfloat16] @T.macro def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): @@ -259,8 +259,8 @@ def main( Bias: T.Tensor((E, N), out_dtype), # Add fusedmoe tensors topk_weights: T.Tensor((M * topk), out_dtype), - sorted_token_ids: T.Tensor((padding_M), "int32"), - expert_ids: T.Tensor((padding_M // block_M), "int32"), + sorted_token_ids: T.Tensor((padding_M), T.int32), + expert_ids: T.Tensor((padding_M // block_M), T.int32), C: T.Tensor((M, topk, N), out_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by): @@ -271,8 +271,8 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype) topk_weights_shared = T.alloc_shared((block_M), out_dtype) - sorted_token_ids_shared = T.alloc_shared((block_M), "int32") - expert_id = T.alloc_local((1), "int32") # the expert id for the current block + sorted_token_ids_shared = T.alloc_shared((block_M), T.int32) + expert_id = T.alloc_local((1), T.int32) # the expert id for the current block # To use 1D TMA, the last dim of Scale_shared must have stride=1 # May use much more shared memory than necessary Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) @@ -346,7 +346,7 @@ def main( def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256): - dtypeC = "bfloat16" + dtypeC = T.bfloat16 M, K = A.shape E, N, QK = qB.shape topk = topk_weights.shape[0] // M @@ -451,9 +451,9 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi topk, E, padding_M, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=num_bits, scale_size=scale_size, fast_dequant=fast_dequant, @@ -467,9 +467,9 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi topk, E, padding_M, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, num_bits=num_bits, scale_size=scale_size, fast_dequant=fast_dequant, diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py index 5d8132d9b..68508ad4e 100644 --- a/examples/dsa_sparse_finetune/indexer_bwd.py +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -9,9 +9,9 @@ from utils import get_abs_err, get_err_ratio -BF16 = "bfloat16" -FP32 = "float32" -INT32 = "int32" +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 pass_configs = { tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py index 8e2f82ba6..d76eb0272 100644 --- a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -10,9 +10,9 @@ from utils import get_abs_err, get_err_ratio -BF16 = "bfloat16" -FP32 = "float32" -INT32 = "int32" +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 pass_configs = { 
tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py index 0b085516e..8b76dbca1 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_bwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -13,11 +13,11 @@ def preprocess( D, block_ND=32, num_stages=5, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 S = T.symbolic("S") @@ -53,11 +53,11 @@ def postprocess( kv_group=1, block_N=64, threads=128, - dtype="bfloat16", - accum_dtype="float", + dtype=T.bfloat16, + accum_dtype=T.float32, ): - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 S_kv = T.symbolic("S_kv") dkv_shape = [S_kv, kv_group, D + D_tail] @@ -94,15 +94,15 @@ def bwd( block_size=32, num_stages=0, threads=128, - indices_dtype="int32", - dtype="bfloat16", - accum_dtype="float", + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, ): assert is_causal == True, "non-casual is not supported now" assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" - assert dtype == "bfloat16" - assert accum_dtype == "float" - assert indices_dtype == "int32" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 if sm_scale is None: sm_scale = (D + D_tail) ** (-0.5) @@ -119,9 +119,9 @@ def bwd( lse_shape = [S, H] offsets_shape = [B_plus_one] token_indices_shape = [S, 2] - assert indices_dtype == "int32" - assert dtype == "bfloat16" - assert accum_dtype == "float" + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 H = H_kv padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py index 6ec3caa7b..d87523695 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_fwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -47,9 +47,9 @@ def sparse_mla_fwd( lse_shape = [seq_len, heads] offsets_shape = [batch_plus_one] token_indices_shape = [seq_len, 2] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py index 6675215c7..a03bc74f5 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -8,9 +8,9 @@ from index import prepare_token_indices from utils import get_abs_err, get_err_ratio -BF16 = "bfloat16" -FP32 = "float32" -INT32 = "int32" +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 pass_configs = { tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, @@ -41,9 +41,9 @@ def tl_sparse_mla_topk_reducesum_impl( seq_len_kv = T.symbolic("seq_len_kv") head_kv = heads // kv_group - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 G = kv_group H = head_kv diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index 97ce7d9b3..598c9edf2 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ 
b/examples/dynamic_shape/example_dynamic.py @@ -98,8 +98,8 @@ def ref_program(A, B): def main(M=16384, N=16384, K=16384): block_M, block_N, block_K = 128, 128, 32 trans_A, trans_B = False, False - in_dtype, out_dtype = "float16", "float16" - accum_dtype = "float32" + in_dtype, out_dtype = T.float16, T.float16 + accum_dtype = T.float32 num_stages = 3 threads = 128 matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 72459459b..f075c64fd 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -43,11 +43,11 @@ def main(M=1024, N=1024, use_autotune=False): b = torch.randn(M, N, dtype=torch.float32, device="cuda") if use_autotune: - kernel = elementwise_add(M, N, in_dtype="float32", out_dtype="float32") + kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32) else: # Default config config = {"block_M": 32, "block_N": 32, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) out = kernel(a, b) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index d1f5843e3..89c116669 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -17,8 +17,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -89,8 +89,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @@ -129,8 +129,8 @@ def make_dq_layout(dQ): }, ) def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_qk] blk = 64 @@ -161,8 +161,8 @@ def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, bl q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( @@ -256,8 +256,8 @@ def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M v_shape = [batch, seq_len, head_kv, dim_v] dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index c6cf336df..07586f99f 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -20,8 +20,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, 
is_causal, block_M, bloc q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -94,8 +94,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @@ -134,8 +134,8 @@ def make_dq_layout(dQ): }, ) def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] @@ -178,8 +178,8 @@ def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, bl q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( @@ -276,8 +276,8 @@ def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal v_shape = [batch, seq_len, head_kv, dim_v] dk_shape = [groups, batch, seq_len, head_kv, dim_qk] # sum after kernel dv_shape = [groups, batch, seq_len, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index 3501df1d7..cc88b64da 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -33,16 +33,16 @@ def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, d k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] o_shape = [total_q, heads, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( Q: T.Tensor(q_shape, dtype), # type: ignore K: T.Tensor(k_shape, dtype), # type: ignore V: T.Tensor(v_shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore Output: T.Tensor(o_shape, dtype), # type: ignore lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): @@ -143,8 +143,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [total_q, heads, dim_v] blk = 32 @@ -152,7 +152,7 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v): def flash_bwd_prep( O: T.Tensor(shape, dtype), # type: ignore dO: T.Tensor(shape, dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore ): with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz): @@ -198,8 +198,8 @@ def make_dq_layout(dQ): }, ) def 
flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 q_shape = [total_q, heads, dim_qk] k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] @@ -245,8 +245,8 @@ def flashattn_bwd_atomic_add( k_shape = [total_kv, head_kv, dim_qk] v_shape = [total_kv, head_kv, dim_v] do_shape = [total_q, heads, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( @@ -256,8 +256,8 @@ def flash_bwd( dO: T.Tensor(do_shape, dtype), # type: ignore lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dK: T.Tensor(k_shape, accum_dtype), # type: ignore dV: T.Tensor(v_shape, accum_dtype), # type: ignore @@ -386,8 +386,8 @@ def flashattn_bwd_split( do_shape = [total_q, heads, dim_v] dk_shape = [groups, total_kv, head_kv, dim_qk] # sum after kernel dv_shape = [groups, total_kv, head_kv, dim_v] # sum after kernel - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( @@ -397,8 +397,8 @@ def flash_bwd( dO: T.Tensor(do_shape, dtype), # type: ignore lse: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore Delta: T.Tensor([batch, heads, N_CTX], accum_dtype), # type: ignore - cu_seqlens_q: T.Tensor([batch + 1], "int32"), # type: ignore - cu_seqlens_k: T.Tensor([batch + 1], "int32"), # type: ignore + cu_seqlens_q: T.Tensor([batch + 1], T.int32), # type: ignore + cu_seqlens_k: T.Tensor([batch + 1], T.int32), # type: ignore dQ: T.Tensor(q_shape, accum_dtype), # type: ignore dK: T.Tensor(dk_shape, dtype), # type: ignore dV: T.Tensor(dv_shape, dtype), # type: ignore diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index adb7e06a8..f4e2de277 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -17,8 +17,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -89,8 +89,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim_v] blk = 32 @@ -129,8 +129,8 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc q_shape = [batch, seq_len, heads, dim_qk] k_shape = [batch, seq_len, head_kv, dim_qk] v_shape = [batch, seq_len, head_kv, dim_v] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index 408d6e507..5005435ea 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ 
b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -70,8 +70,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.macro def MMA0( diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 3492be764..7b7a71b17 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -45,8 +45,8 @@ def flashattn( head_kv = heads // groups q_shape = [batch, seq_len, heads, dim] kv_shape = [batch, seq_len, head_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.macro def MMA0( diff --git a/examples/flash_attention/example_gqa_fwd_varlen.py b/examples/flash_attention/example_gqa_fwd_varlen.py index 87b11f71b..b02345d93 100644 --- a/examples/flash_attention/example_gqa_fwd_varlen.py +++ b/examples/flash_attention/example_gqa_fwd_varlen.py @@ -65,16 +65,16 @@ def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, bl q_shape = [UQ, heads, dim] kv_shape = [UKV, head_kv, dim] o_shape = [UQ, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( Q_unpad: T.Tensor(q_shape, dtype), K_unpad: T.Tensor(kv_shape, dtype), V_unpad: T.Tensor(kv_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), max_seqlen_q: T.int32, Output_unpad: T.Tensor(o_shape, dtype), ): diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index 81eb6d1e5..835a31596 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -15,8 +15,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -91,8 +91,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 32 @@ -131,8 +131,8 @@ def make_dq_layout(dQ): }, ) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, heads, seq_len, dim] blk = 64 @@ -160,8 +160,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): sm_scale = (1.0 / dim) ** 0.5 scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, heads, seq_len, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py index 427a0f694..c0620bde0 100644 --- a/examples/flash_attention/example_mha_bwd_bshd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -15,8 +15,8 @@ def flashattn_fwd(batch, heads, 
seq_len, dim, is_causal, block_M, block_N): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -87,8 +87,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @@ -127,8 +127,8 @@ def make_dq_layout(dQ): }, ) def flashattn_bwd_postprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 64 @@ -156,8 +156,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): sm_scale = (1.0 / dim) ** 0.5 scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index 813f379ca..34a8d69ce 100644 --- a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -15,8 +15,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_fwd( @@ -88,8 +88,8 @@ def flash_fwd( }, ) def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 shape = [batch, seq_len, heads, dim] blk = 32 @@ -125,8 +125,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): sm_scale = (1.0 / dim) ** 0.5 scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def flash_bwd( diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index 7fa5549d0..e70d17bf8 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -24,8 +24,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=6 scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index 440a2cd74..b8c4d81ec 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -24,8 +24,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N= scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = 
seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index 888914c9b..248073f79 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -23,8 +23,8 @@ def get_configs(): def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.macro def MMA0( diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index b54d3e626..ab2aab44f 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -23,8 +23,8 @@ def get_configs(): def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.macro def MMA0( diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index f7bb36f71..6ba2e8ab4 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -80,16 +80,16 @@ def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64 v_shape = [UKV, heads, dim] o_shape = [UQ, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( Q_unpad: T.Tensor(q_shape, dtype), K_unpad: T.Tensor(k_shape, dtype), V_unpad: T.Tensor(v_shape, dtype), - cu_seqlens_q: T.Tensor([batch_size + 1], "int32"), - cu_seqlens_k: T.Tensor([batch_size + 1], "int32"), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), max_seqlen_q: T.int32, Output_unpad: T.Tensor(o_shape, dtype), ): diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 136a51292..ee42df208 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -53,8 +53,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, shape_k = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim] shape_o = [batch, heads, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // groups part_shape = [batch, heads, num_split, dim] diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 0fdd52919..ef3d8baed 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -209,8 +209,8 @@ def flashattn( shape_v = [total_seqlen_k, k_heads, dim] shape_o = [batch, heads, dim] shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // k_heads valid_block_H = min(block_H, kv_group_num) @@ -221,8 +221,8 @@ def flash_attn( Q: 
T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), Output: T.Tensor([batch, heads, dim], dtype), S: T.Tensor(shape_s, dtype), ): @@ -241,7 +241,7 @@ def flash_attn( logsum = T.alloc_fragment([block_H], accum_dtype) S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) - s_aux_shared = T.alloc_shared([block_H], "float32") + s_aux_shared = T.alloc_shared([block_H], T.float32) T.annotate_layout( { @@ -321,8 +321,8 @@ def flashattn_gqa_decode_no_split( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), Output: T.Tensor(shape_o, dtype), S: T.Tensor(shape_s, dtype), ): @@ -449,7 +449,7 @@ def test_equal_seqlen_decode_main(args): real_max_k_seqlen = args.k_seqlen head_size = args.head_size block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 # For decode, query is just 1 token per batch q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) @@ -568,7 +568,7 @@ def test_varlen_decode_main(args): real_max_k_seqlen = args.k_seqlen head_size = args.head_size block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") @@ -789,7 +789,7 @@ def speed_benchmark_decode_comparison(args): max_k_seqlen = args.k_seqlen head_size = args.head_size block_size = args.block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 print("\n=== Decode Speed Benchmark Comparison ===") print("Configuration:") @@ -890,7 +890,7 @@ def speed_benchmark_decode_comparison(args): parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") parser.add_argument("--block_size", type=int, default=64, help="Block size for computation") - parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") @@ -898,7 +898,7 @@ def speed_benchmark_decode_comparison(args): args = parser.parse_args() args.test_sink = True args.test_varlen = False - args.dtype = "float16" + args.dtype = T.float16 args.num_split = 1 if args.benchmark: diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py index 3537e5af0..0984e7075 100644 --- 
a/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py @@ -45,8 +45,8 @@ def flashattn( shape_v = [total_seqlen_k, k_heads, dim] shape_o = [batch, heads, dim] shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // k_heads assert page_block_size >= block_N and page_block_size % block_N == 0, ( "page_block_size must be larger than block_N and a multiple of block_N" @@ -60,9 +60,9 @@ def flash_attn( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32), Output: T.Tensor([batch, heads, dim], dtype), S: T.Tensor(shape_s, dtype), ): @@ -80,7 +80,7 @@ def flash_attn( scores_sum = T.alloc_fragment([block_H], accum_dtype) logsum = T.alloc_fragment([block_H], accum_dtype) S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) - s_aux_shared = T.alloc_shared([block_H], "float32") + s_aux_shared = T.alloc_shared([block_H], T.float32) bid = bx hid = by @@ -146,9 +146,9 @@ def flashattn_gqa_decode_no_split( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), - cu_seqlens_k: T.Tensor([batch + 1], "int32"), - s_aux: T.Tensor([heads], "float32"), - BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"), + cu_seqlens_k: T.Tensor([batch + 1], T.int32), + s_aux: T.Tensor([heads], T.float32), + BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32), Output: T.Tensor(shape_o, dtype), S: T.Tensor(shape_s, dtype), ): @@ -211,7 +211,7 @@ def test_equal_seqlen_decode_main(args): head_size = args.head_size block_size = args.block_size page_block_size = args.page_block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 # For decode, query is just 1 token per batch q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) @@ -341,7 +341,7 @@ def test_varlen_decode_main(args): head_size = args.head_size block_size = args.block_size page_block_size = args.page_block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") @@ -549,7 +549,7 @@ def speed_benchmark_decode_comparison(args): head_size = args.head_size block_size = args.block_size page_block_size = args.page_block_size - dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 print("\n=== Decode Speed Benchmark Comparison ===") print("Configuration:") @@ -659,7 +659,7 @@ def speed_benchmark_decode_comparison(args): parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length") parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension") parser.add_argument("--block_size", type=int, default=128, help="Block size for computation") - 
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type") + parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type") parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths") parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism") parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark") @@ -668,7 +668,7 @@ def speed_benchmark_decode_comparison(args): args = parser.parse_args() args.test_sink = True args.test_varlen = True - args.dtype = "float16" + args.dtype = T.float16 args.num_split = 1 if args.benchmark: diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index d0381bc4a..5b243d695 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -14,8 +14,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ shape_q = [batch, seqlen_q, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim] part_shape = [batch, seqlen_q, heads, num_split, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.macro def MMA0( diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index b737f30aa..36c6ef3dc 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -33,7 +33,7 @@ def moe_forward_tilelang_shared( shared_W_up_shape = (dexpert, dhidden) shared_W_down_shape = (dhidden, dexpert) - accum_type = "float32" + accum_type = T.float32 @T.prim_func def kernel_shared( @@ -121,7 +121,7 @@ def moe_forward_tilelang_routed( # group_count = len(group_sizes_list) # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list]) M = math.ceil(group_sum / block_token) + group_count - accum_dtype = "float32" + accum_dtype = T.float32 # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm input_shape = (group_sum, dhidden) @@ -139,10 +139,10 @@ def kernel( routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore - group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore - group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore + group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore + group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore output: T.Tensor(input_shape, dtype), # type: ignore ): @@ -155,8 +155,8 @@ def kernel( gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], "int32") - cur_group_size = T.alloc_local([1], "int32") + cur_group_idx = T.alloc_local([1], T.int32) + cur_group_size = 
T.alloc_local([1], T.int32) T.use_swizzle(10, enable=True) @@ -208,8 +208,8 @@ def kernel( routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) - cur_group_idx = T.alloc_local([1], "int32") - cur_group_size = T.alloc_local([1], "int32") + cur_group_idx = T.alloc_local([1], T.int32) + cur_group_size = T.alloc_local([1], T.int32) T.use_swizzle(10, enable=True) @@ -464,7 +464,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: """ input_tensor, weights, config = data - dtype_str = "float16" + dtype_str = T.float16 shared_kernel = moe_forward_tilelang_shared( config["d_hidden"], diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index ecda7e41b..39450bc5f 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -250,13 +250,13 @@ def kernel( dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) - dO_shared_t = T.alloc_shared((block_DV, block_S), dtype="float32") - dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32") - dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32") + dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32) + dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32) + dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) - Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32") + Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32) W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) G_last_local = T.alloc_local((1), dtype=gate_dtype) @@ -592,11 +592,11 @@ def main(): H=8, DK=DK, DV=128, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, scale=DK**-0.5, use_g=True, diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index 43f1e972b..d316a6211 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -387,11 +387,11 @@ def main(): H=32, DK=128, DV=128, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, use_g=True, use_initial_state=False, diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index bd1e9aa23..815368159 100644 --- a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -230,10 +230,10 @@ def main(): DK=128, DV=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, use_g=True, block_DK=128, block_DV=128, diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index 66cb6942e..97e2f4f01 100644 --- 
a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -505,11 +505,11 @@ def main(): H=8, DK=DK, DV=DV, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, scale=DK**-0.5, # scale=1, diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index af2b08e57..e8ef17e3f 100644 --- a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -57,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd( H, DK, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, use_g=True, # kernel config block_S=64, @@ -183,9 +183,9 @@ def main(): H=32, DK=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, use_g=True, block_DK=64, threads=128, diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 13547cd60..0760b4964 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -32,8 +32,8 @@ def tilelang_chunk_local_cumsum_scalar( is_varlen=False, head_first=False, reverse=False, - input_dtype="float16", - output_dtype="float32", + input_dtype=T.float16, + output_dtype=T.float32, # kernel config block_S=64, threads=256, @@ -154,8 +154,8 @@ def main(): chunk_size=64, reverse=True, head_first=False, - input_dtype="float32", - output_dtype="float32", + input_dtype=T.float32, + output_dtype=T.float32, threads=256, use_fragment=False, ) diff --git a/examples/gdn/example_wy_fast.py b/examples/gdn/example_wy_fast.py index 874e25c3b..9ac086ca7 100644 --- a/examples/gdn/example_wy_fast.py +++ b/examples/gdn/example_wy_fast.py @@ -205,10 +205,10 @@ def main(): DK=128, DV=128, chunk_size=64, - input_dtype="bfloat16", - output_dtype="bfloat16", - gate_dtype="float32", - accum_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + gate_dtype=T.float32, + accum_dtype=T.float32, block_DK=64, block_DV=32, threads=128, diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index 5b0230e5c..de8afc2b7 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -518,11 +518,11 @@ def main(): H=8, DK=DK, DV=DV, - input_dtype="bfloat16", - output_dtype="bfloat16", - accum_dtype="float32", - gate_dtype="float32", - state_dtype="float32", + input_dtype=T.bfloat16, + output_dtype=T.bfloat16, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, chunk_size=64, block_DK=32, block_DV=32, diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index a51936ef8..e749fa087 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -1,16 +1,17 @@ -import tilelang.testing import torch +import tilelang.testing +from tilelang import language as T B = 1 S = 1024 # small but for test only. 
H = 32 DK = 128 DV = 128 -input_dtype = "bfloat16" -output_dtype = "bfloat16" -accum_dtype = "float32" -gate_dtype = "float32" -state_dtype = "float32" +input_dtype = T.bfloat16 +output_dtype = T.bfloat16 +accum_dtype = T.float32 +gate_dtype = T.float32 +state_dtype = T.float32 chunk_size = 64 use_g = True use_initial_state = True diff --git a/examples/gemm/README.md b/examples/gemm/README.md index d7833c97d..9ab7fb661 100644 --- a/examples/gemm/README.md +++ b/examples/gemm/README.md @@ -53,7 +53,7 @@ import tilelang from tilelang import Profiler import tilelang.language as T -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -176,7 +176,7 @@ import tilelang.language as T # that helps align data for MMA (Matrix Multiply-Accumulate) operations. from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -265,18 +265,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index 2c234d122..906a55d5d 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -3,7 +3,7 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm( A: T.Tensor((M, K), dtype), diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index badc33402..ca3222173 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, ).with_arch(arch) func = carve_template.equivalent_function() @@ -116,8 +116,8 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "bfloat16" - accum_dtype = "float" + dtype = T.bfloat16 + accum_dtype = T.float32 @T.prim_func def main( @@ -178,7 +178,7 @@ def get_heuristic_config() -> dict: @tl.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm_autotune( A: T.Tensor((M, K), dtype), diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 488e5bf6b..746e6ec01 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -35,18 +35,18 @@ def tl_matmul( accum_dtype, ): assert in_dtype in 
[ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -54,7 +54,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - # chunk = 32 if in_dtype == "float16" else 64 + # chunk = 32 if in_dtype == T.float16 else 64 chunk = 32 shared_scope = "shared.dyn" @@ -163,7 +163,7 @@ def ref_program(A, B): def main(M=4096, N=4096, K=4096): - in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32 kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() # src_code is the generated cuda source diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index 6fc0e5aac..30f55de6a 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -5,7 +5,7 @@ @tilelang.jit(out_idx=[-1]) -def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"): +def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -34,7 +34,7 @@ def main( @tilelang.jit(out_idx=[-1]) def matmul_persistent( - M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True + M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True ): sm_num = driver.get_num_sms() m_blocks = T.ceildiv(M, block_M) diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index d1eb11df5..8663c878d 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -3,7 +3,7 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm_schedule( A: T.Tensor((M, K), dtype), diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 4c58144e4..93f8c4980 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -53,8 +53,8 @@ def get_configs(): ) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): - dtype = "float8_e4m3fnuz" - accum_dtype = "float" + dtype = T.float8_e4m3fnuz + accum_dtype = T.float32 @T.prim_func def gemm_fp8_rs( diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index 1ecd344bc..1b440a795 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -1,7 +1,6 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type def calc_diff(x, y): @@ -12,7 +11,7 @@ def calc_diff(x, y): @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): +def matmul(M, N, 
K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): @T.prim_func def gemm_fp8( A: T.Tensor((M, K), dtype), @@ -36,7 +35,7 @@ def gemm_fp8( def test_gemm_fp8(M, N, K, dtype): - torch_dtype = map_torch_type(dtype) + torch_dtype = T.dtype(dtype).as_torch() kernel = matmul(M, N, K, 128, 128, 64, dtype) @@ -56,8 +55,8 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, "float8_e4m3") - test_gemm_fp8(1024, 1024, 1024, "float8_e5m2") + test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2) if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 3af4c3d6d..1c5d84d72 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -1,11 +1,10 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32): # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128. # if block_K < 128, promote after 128/block_K iters. # if block_K > 128, promote after every iter. @@ -55,7 +54,7 @@ def calc_diff(x, y): def test_gemm_fp8(M, N, K, dtype): - torch_dtype = map_torch_type(dtype) + torch_dtype = T.dtype(dtype).as_torch() kernel = matmul(M, N, K, 128, 128, 64, dtype) @@ -74,8 +73,8 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, "float8_e4m3") - test_gemm_fp8(1024, 1024, 8192, "float8_e5m2") + test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn) + test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2) if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 6e2d41be8..7ecde7c1b 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -39,26 +39,25 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "float8_e4m3", - "float8_e5m2", - "int8", + T.float16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 is_float8 = in_dtype in [ - "float8_e4m3", - "float8_e5m2", - "float8_e4m3fn", - "float8_e5m2fnuz", + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e5m2fnuz, ] - if out_dtype == "int32" or is_float8: + if out_dtype == T.int32 or is_float8: micro_size_k = 32 # This is a debug config @@ -66,7 +66,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 32 warp_col_tiles = 32 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -220,8 +220,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def main(): - assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32") - assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) if
__name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py index 5cb42e328..aa7e8b360 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -73,8 +73,8 @@ def calc_diff(x, y): trans_A, trans_B = False, True num_stages = 2 threads = 256 -for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: - for tvm_acc_dtype in ["float16", "float32"]: # , torch.float16]: +for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: + for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) torch_acc_dtype = map_torch_type(tvm_acc_dtype) print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md index 73dd76c30..28bb611bf 100644 --- a/examples/gemm_sm100/README.md +++ b/examples/gemm_sm100/README.md @@ -40,19 +40,19 @@ import tilelang.language as T @T.prim_func def main( - A: T.Tensor((M, K), "bfloat16"), - B: T.Tensor((N, K), "bfloat16"), - C: T.Tensor((M, N), "bfloat16"), + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((N, K), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): # 1. Allocate memory buffers - A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory - B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory - C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory + B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory + C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory mbar = T.alloc_barrier(1) # mbarrier synchronization primitive - C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage - C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory + C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage + C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory # 2. 
Main computation loop for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index be43f4ec4..226e33c01 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -4,7 +4,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 88614f561..523a94fea 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -54,7 +54,7 @@ def main( M, N, K = 4096, 4096, 8192 block_M, block_N, block_K = 128, 256, 128 trans_A, trans_B = False, True -in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" +in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float num_stages = 2 threads = 256 diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py index fe3b15233..7f18523b7 100644 --- a/examples/gemm_sp/example_custom_compress.py +++ b/examples/gemm_sp/example_custom_compress.py @@ -17,7 +17,7 @@ DEFAULT_CONFIG = { # take best config from autotune script "4090": { - "float": { + T.float: { "block_M": 128, "block_N": 64, "block_K": 64, @@ -26,7 +26,7 @@ "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True, }, - "float16": { + T.float16: { "block_M": 256, "block_N": 128, "block_K": 64, @@ -37,7 +37,7 @@ }, }, "h20": { - "float": { + T.float: { "block_M": 128, "block_N": 64, "block_K": 128, @@ -46,7 +46,7 @@ "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True, }, - "float16": { + T.float16: { "block_M": 128, "block_N": 64, "block_K": 128, @@ -65,26 +65,26 @@ def matmul_sp_fp16_custom_compress( M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout ): - e_factor, e_dtype = (16, "int16") + e_factor, e_dtype = (16, T.int16) @T.prim_func def gemm_sp_fp16_custom_compress( - A_sparse: T.Tensor((M, K // 2), "float16"), + A_sparse: T.Tensor((M, K // 2), T.float16), E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), "float16"), + B: T.Tensor((K, N), T.float16), C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), "float16") + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), "float16") + B_shared = T.alloc_shared((block_K, block_N), T.float16) C_shared = T.alloc_shared((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) if use_cutlass_layout: T.annotate_layout( { - E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), } ) T.clear(C_local) @@ -253,15 +253,15 @@ def kernel( if use_cutlass_layout: T.annotate_layout( { - E: 
make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K), + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K), } ) T.clear(A_sp_shared) T.clear(E_shared) # TODO: alloc_var seems buggy here - non_zero_cnt = T.alloc_local((1,), dtype="uint8") - non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8") + non_zero_cnt = T.alloc_local((1,), dtype=T.uint8) + non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8) T.copy(A[bx * block_M, by * block_K], A_shared) for tm in T.Parallel(block_M): for g_i in range(0, block_K // group): @@ -300,7 +300,7 @@ def main(): parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") - parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") args = parser.parse_args() kernel = matmul_sp_fp16_custom_compress( @@ -314,7 +314,7 @@ def main(): assert not args.use_cutlass_layout, "torch sparse must be used with naive layout" a_sparse, e = torch_compress(a) else: - a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a) + a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a) c = kernel(a_sparse, e, b) diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 828ca43a2..708bc7231 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -16,7 +16,7 @@ DEFAULT_CONFIG = { # take best config from autotune script "4090": { - "float": { + T.float: { "block_M": 128, "block_N": 64, "block_K": 64, @@ -25,7 +25,7 @@ "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True, }, - "float16": { + T.float16: { "block_M": 256, "block_N": 128, "block_K": 64, @@ -36,7 +36,7 @@ }, }, "h20": { - "float": { + T.float: { "block_M": 128, "block_N": 64, "block_K": 128, @@ -45,7 +45,7 @@ "policy": T.GemmWarpPolicy.Square, "enable_rasterization": True, }, - "float16": { + T.float16: { "block_M": 128, "block_N": 64, "block_K": 128, @@ -66,15 +66,15 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, @T.prim_func def gemm_sp_fp16( - A_sparse: T.Tensor((M, K // 2), "float16"), + A_sparse: T.Tensor((M, K // 2), T.float16), E: T.Tensor((M, K // e_factor), e_dtype), - B: T.Tensor((K, N), "float16"), + B: T.Tensor((K, N), T.float16), C: T.Tensor((M, N), accum_dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K // 2), "float16") + A_shared = T.alloc_shared((block_M, block_K // 2), T.float16) E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype) - B_shared = T.alloc_shared((block_K, block_N), "float16") + B_shared = T.alloc_shared((block_K, block_N), T.float16) C_shared = T.alloc_shared((block_M, block_N), accum_dtype) 
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -83,8 +83,8 @@ def gemm_sp_fp16( T.use_swizzle(panel_size=10, enable=enable_rasterization) T.annotate_layout( { - E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch), + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch), } ) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): @@ -104,7 +104,7 @@ def main(): parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") args = parser.parse_args() kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype]) diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index 320a699c5..62073c5bd 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -3,7 +3,7 @@ @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"): +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): splitK = K // split_k @T.prim_func diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index dfd847101..83e83b5d2 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -3,7 +3,7 @@ @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"): +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32): splitK = K // split_k @T.prim_func diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 2d83586e5..7ec1541ea 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -87,8 +87,8 @@ def compute_first_wave( C: T.Tensor, C_local: T.LocalBuffer, ): - start_iter = T.alloc_fragment((1,), "int32", "local") - end_iter = T.alloc_fragment((1,), "int32", "local") + start_iter = T.alloc_fragment((1,), T.int32, "local") + end_iter = T.alloc_fragment((1,), T.int32, "local") start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) @@ -179,9 +179,9 @@ def main(): BLOCK_SIZE_K, False, True, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, 2, 64, ) diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py 
index 00cbac067..9dd0e4dd9 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -17,8 +17,8 @@ def naive_gemv( K: int, BLOCK_N: int, BLOCK_K: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): @T.prim_func def main( @@ -49,8 +49,8 @@ def naive_splitk_gemv( K: int, BLOCK_N: int, BLOCK_K: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): @T.prim_func def main( @@ -85,8 +85,8 @@ def splitk_gemv( BLOCK_N: int, BLOCK_K: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): TILE_K = T.ceildiv(BLOCK_K, reduce_threads) @@ -124,8 +124,8 @@ def splitk_gemv_vectorized( K: int, BLOCK_N: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits @@ -165,8 +165,8 @@ def splitk_gemv_vectorized_tvm( K: int, BLOCK_N: int, reduce_threads: int, - dtype: str = "float16", - accum_dtype: str = "float", + dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float, ): MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits @@ -233,7 +233,9 @@ def get_block_template_configs(): }, out_idx=[2], ) -def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"): +def gemv_alloc_reducer( + M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float +): @T.prim_func def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m: @@ -274,8 +276,8 @@ def get_autotuned_kernel( BLOCK_N=None, reduce_threads=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 MAX_TRANSACTION_SIZE_IN_BITS = 128 TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits BLOCK_K = reduce_threads * TILE_K diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py index b1af5360c..bb57c6073 100644 --- a/examples/grouped_gemm/example_grouped_gemm_bwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py @@ -6,29 +6,29 @@ @tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): +def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). b (torch.Tensor): Input tensor of shape (G, K, N). 
""" - accum_dtype = "float32" + accum_dtype = T.float32 @T.prim_func def kernel( A: T.Tensor([batch_sum, K], dtype), # type: ignore B: T.Tensor([batch_count, K, N], dtype), # type: ignore C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], "int32") - cur_batch_size = T.alloc_local([1], "int32") + cur_batch_idx = T.alloc_local([1], T.int32) + cur_batch_size = T.alloc_local([1], T.int32) m_start_padded = bx * block_M @@ -158,21 +158,21 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype): @tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) -def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): +def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). b (torch.Tensor): Input tensor of shape (G, K, N). """ - accum_dtype = "float32" + accum_dtype = T.float32 @T.prim_func def kernel( A: T.Tensor([batch_sum, M], dtype), # type: ignore B: T.Tensor([batch_sum, N], dtype), # type: ignore C: T.Tensor([batch_count, M, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz): A_shared = T.alloc_shared([block_K, block_M], dtype) diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd.py b/examples/grouped_gemm/example_grouped_gemm_fwd.py index 8f7710512..48d916051 100644 --- a/examples/grouped_gemm/example_grouped_gemm_fwd.py +++ b/examples/grouped_gemm/example_grouped_gemm_fwd.py @@ -37,7 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False): @tilelang.jit(out_idx=[2]) -def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"): +def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): """ args: a (torch.Tensor): Input tensor of shape (M, K). 
@@ -45,7 +45,7 @@ def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2 """ batch_sum = sum(batch_sizes_list) batch_count = len(batch_sizes_list) - accum_dtype = "float32" + accum_dtype = T.float32 total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list) @T.prim_func @@ -53,16 +53,16 @@ def kernel( A: T.Tensor([batch_sum, K], dtype), # type: ignore B: T.Tensor([batch_count, K, N], dtype), # type: ignore C: T.Tensor([batch_sum, N], dtype), # type: ignore - batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore - batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore - batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore + batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore + batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore + batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore ): with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): A_shared = T.alloc_shared([block_M, block_K], dtype) B_shared = T.alloc_shared([block_K, block_N], dtype) C_local = T.alloc_fragment([block_M, block_N], accum_dtype) - cur_batch_idx = T.alloc_local([1], "int32") - cur_batch_size = T.alloc_local([1], "int32") + cur_batch_idx = T.alloc_local([1], T.int32) + cur_batch_size = T.alloc_local([1], T.int32) m_start_padded = bx * block_M diff --git a/examples/hadamard_transform/example_hadamard.py b/examples/hadamard_transform/example_hadamard.py index 64eb9bbdb..65f463b71 100644 --- a/examples/hadamard_transform/example_hadamard.py +++ b/examples/hadamard_transform/example_hadamard.py @@ -17,7 +17,7 @@ def is_pow_of_2(n): def hadamard(b, n, dtype): assert is_pow_of_2(n), "n must be a power of 2" assert 2 <= n <= 32768, "n must be in [2, 32768]" - elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype] + elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype] logN = int(math.log2(n)) threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN] @@ -138,7 +138,7 @@ def main(): B, D = args.batch, args.dim x = torch.randn((B, D), device="cuda") - kernel = hadamard(B, D, "float32") + kernel = hadamard(B, D, T.float32) y = kernel(x) y_ref = ref_program(x) torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) diff --git a/examples/lazy_jit/lazyjit.en.ipynb b/examples/lazy_jit/lazyjit.en.ipynb index 196ddfc46..99cb977f0 100644 --- a/examples/lazy_jit/lazyjit.en.ipynb +++ b/examples/lazy_jit/lazyjit.en.ipynb @@ -552,7 +552,7 @@ { "data": { "text/plain": [ - "# from tvm.script import tir as T\n", + "# import tilelang.language as T\n", "\n", "@T.prim_func\n", "def foo(x_handle: T.handle):\n", @@ -723,7 +723,7 @@ { "data": { "text/plain": [ - "# from tvm.script import tir as T\n", + "# import tilelang.language as T\n", "\n", "@T.prim_func\n", "def foo():\n", @@ -786,4 +786,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/lazy_jit/lazyjit.zh.ipynb b/examples/lazy_jit/lazyjit.zh.ipynb index d6db4c76e..601c5c5d2 100644 --- a/examples/lazy_jit/lazyjit.zh.ipynb +++ b/examples/lazy_jit/lazyjit.zh.ipynb @@ -552,7 +552,7 @@ { "data": { "text/plain": [ - "# from tvm.script import tir as T\n", + "# import tilelang.language as T\n", "\n", "@T.prim_func\n", "def foo(x_handle: T.handle):\n", @@ -723,7 +723,7 @@ { "data": { "text/plain": [ - "# from tvm.script import tir as T\n", + "# import tilelang.language as T\n", "\n", "@T.prim_func\n", "def foo():\n", @@ -786,4 +786,4 @@ }, "nbformat": 4, 
"nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 7cbfc465a..397ec7bdf 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -21,12 +21,12 @@ def tl_fused_chunk_bwd_kernel( H, DK, DV, - dtype: str = "float16", + dtype: T.dtype = T.float16, scale: float = None, ) -> torch.Tensor: if scale is None: scale = DK**-0.5 - accum_dtype = "float" + accum_dtype = T.float32 chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 3d28f92b0..849841e51 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -22,12 +22,12 @@ def tl_fused_chunk_fwd_kernel( H, DK, DV, - dtype: str = "float16", + dtype: T.dtype = T.float16, scale: float = None, ) -> torch.Tensor: if scale is None: scale = DK**-0.5 - accum_dtype = "float" + accum_dtype = T.float32 chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA diff --git a/examples/linear_attention/example_mamba_chunk_scan.py b/examples/linear_attention/example_mamba_chunk_scan.py index 53b6cf9fb..1958dfb5a 100644 --- a/examples/linear_attention/example_mamba_chunk_scan.py +++ b/examples/linear_attention/example_mamba_chunk_scan.py @@ -89,8 +89,8 @@ def chunk_scan_fwd( num_stages=2, threads=128, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 diff --git a/examples/linear_attention/example_mamba_chunk_state.py b/examples/linear_attention/example_mamba_chunk_state.py index 6aefde7bb..fb766d5e9 100644 --- a/examples/linear_attention/example_mamba_chunk_state.py +++ b/examples/linear_attention/example_mamba_chunk_state.py @@ -55,8 +55,8 @@ def get_configs(): def chunk_state_fwd( batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128 ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 nchunks = T.ceildiv(seqlen, chunk_size) p = 1.44269504 diff --git a/examples/linear_attention/example_retention_fwd.py b/examples/linear_attention/example_retention_fwd.py index ccb11fe1b..f45e38388 100644 --- a/examples/linear_attention/example_retention_fwd.py +++ b/examples/linear_attention/example_retention_fwd.py @@ -13,12 +13,12 @@ def chunk_retention_fwd_kernel( H, DK, DV, - dtype: str = "float16", + dtype: T.dtype = T.float16, scale: float = None, ) -> torch.Tensor: if scale is None: scale = DK**-0.5 - accum_dtype = "float" + accum_dtype = T.float32 chunk_size = 64 BK = BV = 64 # Set to 128 can be faster, but has some numerical differences with FLA @@ -37,7 +37,7 @@ def chunk_retention_fwd( with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H - log_decay = T.alloc_var("float32") + log_decay = T.alloc_var(T.float32) log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h)) # Head-specific log decay q = T.alloc_shared([chunk_size, BK], dtype) diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index 6600bb5ed..f96e73ae5 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ 
b/examples/minference/example_vertical_slash_sparse_attn.py @@ -31,9 +31,9 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size) - dtype = "float16" - accum_dtype = "float" - int_dtype = "int32" + dtype = T.float16 + accum_dtype = T.float32 + int_dtype = T.int32 def kernel_func(block_M, block_N, num_stages, threads): @T.macro diff --git a/examples/norm/rms_norm.py b/examples/norm/rms_norm.py index a7a06b9c6..57bccc1a0 100644 --- a/examples/norm/rms_norm.py +++ b/examples/norm/rms_norm.py @@ -4,7 +4,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -35,7 +35,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) def rms_norm(M, N, blk_m): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): diff --git a/examples/norm/test_rms_norm.py b/examples/norm/test_rms_norm.py index 124a212f6..53db03d98 100644 --- a/examples/norm/test_rms_norm.py +++ b/examples/norm/test_rms_norm.py @@ -5,7 +5,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -35,7 +35,7 @@ def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): def rms_norm(M, N, blk_m): - dtype = "float" + dtype = T.float @T.prim_func def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): diff --git a/examples/online_softmax/online_softmax.py b/examples/online_softmax/online_softmax.py index 32f1c001f..811870e44 100644 --- a/examples/online_softmax/online_softmax.py +++ b/examples/online_softmax/online_softmax.py @@ -9,12 +9,12 @@ def softmax_kernel( M, N, - dtype: str = "float16", + dtype: T.dtype = T.float16, ) -> "Callable": BN = min(tl.next_power_of_2(N), 8192) NN = tl.cdiv(N, BN) - accum_dtype = "float" + accum_dtype = T.float32 scale = 1.44269504 # log2(e) diff --git a/examples/plot_layout/README.md b/examples/plot_layout/README.md index a65d771c2..8204e93d8 100644 --- a/examples/plot_layout/README.md +++ b/examples/plot_layout/README.md @@ -10,7 +10,7 @@ from typing import Literal, Callable from tilelang.intrinsics.utils import get_mma_micro_size from tilelang.tools import plot_layout -def make_mma_load_base_layout(dtype: str = "float16", +def make_mma_load_base_layout(dtype: str = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: """ @@ -69,7 +69,7 @@ def make_mma_load_base_layout(dtype: str = "float16", micro_size_s, _, micro_size_r = get_mma_micro_size(dtype) transform_func = transform_func - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ -94,7 +94,7 @@ def make_mma_load_base_layout(dtype: str = "float16", # Create a 16×16 matrix layout for ldmatrix operations -base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False) +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) # Print the layout structure (optional for debugging) print(base_layout) diff --git a/examples/plot_layout/fragment_mfma_load_a.py 
b/examples/plot_layout/fragment_mfma_load_a.py index a7e8f8909..d45cc227b 100644 --- a/examples/plot_layout/fragment_mfma_load_a.py +++ b/examples/plot_layout/fragment_mfma_load_a.py @@ -12,7 +12,7 @@ def make_mfma_load_base_layout( - dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False + dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False ) -> T.Fragment: """ Create a layout function for storing MFMA results into a fragment buffer. @@ -79,7 +79,7 @@ def make_mfma_load_base_layout( else: raise ValueError(f"Unsupported matrix {matrix}") - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ -112,7 +112,7 @@ def forward_index(i: int, j: int) -> int: from tilelang.tools import plot_layout # ldmatrix layout 16x16 -base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False) +base_layout = make_mfma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) print(base_layout) plot_layout(base_layout, name="base_layout") diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index 17d1c6d51..df4a0b887 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -5,7 +5,7 @@ from tilelang.intrinsics.utils import get_mma_micro_size -def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: +def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment: """ Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with `inverse_mma_store_layout` to @@ -74,7 +74,7 @@ def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] else: raise ValueError(f"Unsupported matrix {matrix}") - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ -107,7 +107,7 @@ def forward_index(i: int, j: int) -> int: from tilelang.tools import plot_layout # ldmatrix layout 16x16 -base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False) +base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False) print(base_layout) plot_layout(base_layout, name="base_layout") diff --git a/examples/quickstart.py b/examples/quickstart.py index 4b765ce17..e99fc0dbc 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -6,7 +6,7 @@ # target currently can be "cuda" or "hip" or "cpu". 
# if not specified, it will be inferred from the input tensors during compile time @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def matmul_relu_kernel( A: T.Tensor((M, K), dtype), diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index f5f7fe7ba..25741f97c 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -42,9 +42,9 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c kv_shape = [batch, heads, seq_kv, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "float16" - accum_dtype = "float" - block_mask_dtype = "int8" + dtype = T.float16 + accum_dtype = T.float32 + block_mask_dtype = T.int8 def kernel_func(block_M, block_N, num_stages, threads): @T.macro diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 6c37dc09c..14339ff02 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -2,6 +2,7 @@ import tilelang from tilelang.utils.sparse import compress_sm90 from tilelang.layout import make_cutlass_metadata_layout +from tilelang import language as T import tilelang.testing @@ -24,8 +25,6 @@ def matmul_sp( A_shared_shape = (block_M, block_K // 2) B_shared_shape = (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), @@ -40,8 +39,8 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.annotate_layout( { - E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K), - E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K), + E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K), + E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K), } ) T.clear(C_local) @@ -111,7 +110,7 @@ def run_gemm_sp( def main(): - run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) + run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128) if __name__ == "__main__": diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py index c0cf09bc0..d4f0c8bfb 100644 --- a/examples/topk/example_topk.py +++ b/examples/topk/example_topk.py @@ -22,19 +22,19 @@ def tl_topk( blk_m, threads=128, ): - dtype = "float32" + dtype = T.float32 @T.prim_func def topk_kernel( logits: T.Tensor([M, N], dtype), topk_gates: T.Tensor([M, topk], dtype), - topk_indices: T.Tensor([M, topk], "int32"), + topk_indices: T.Tensor([M, topk], T.int32), ): with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx: logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype) max_val = T.alloc_fragment([blk_m], dtype=dtype) - expand_max_idx = T.alloc_fragment([blk_m, N], "int32") - max_idx = T.alloc_fragment([blk_m], "int32") + expand_max_idx = T.alloc_fragment([blk_m, N], T.int32) + max_idx = T.alloc_fragment([blk_m], T.int32) T.copy(logits[bx * blk_m, 0], logits_frag) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 
dbb39f789..8fa1eaf85 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -10,7 +10,7 @@ tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg", }, ) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm( A: T.Tensor((M, K), dtype), diff --git a/examples/warp_specialize/example_warp_specialize_flashmla.py b/examples/warp_specialize/example_warp_specialize_flashmla.py index 4f4417e75..6dcd51aa7 100644 --- a/examples/warp_specialize/example_warp_specialize_flashmla.py +++ b/examples/warp_specialize/example_warp_specialize_flashmla.py @@ -10,8 +10,8 @@ @tilelang.jit(out_idx=[6]) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split): scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e) - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 kv_group_num = heads // kv_head_num VALID_BLOCK_H = min(block_H, kv_group_num) assert kv_head_num == 1, "kv_head_num must be 1" diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py index 5d438b5de..4a2aa00d9 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -7,7 +7,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): num_stages = 2 mbarrier_list = [128, 128] * num_stages diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py index 03ddf8122..7b2278432 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py @@ -5,7 +5,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py index 63aed2bed..02d88019c 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py @@ -5,7 +5,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py 
b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py index f24d76a22..5468aa6ea 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_gemm_0_1.py @@ -10,7 +10,7 @@ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, }, ) -def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): warp_group_num = 2 threads = 128 * warp_group_num diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py index f3f8a665b..31d156f32 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py @@ -5,7 +5,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @tilelang.jit(out_idx=[2]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor[(M, K), dtype], diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py index e7a822544..44441cdeb 100644 --- a/maint/gemm_v2/correctness_evaluation.py +++ b/maint/gemm_v2/correctness_evaluation.py @@ -2,6 +2,8 @@ import pytest from tilelang import tvm as tvm import tilelang.testing +from tilelang import language as T +import torch def matmul( @@ -24,8 +26,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -74,13 +74,11 @@ def _compile_and_check( profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) def ref_program(A, B): - import torch - if trans_A: A = A.T if trans_B: B = B.T - if in_dtype == "float32": + if in_dtype == T.float32: A = (A.view(torch.int32) - 0x1000).view(torch.float32) B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) @@ -148,8 +146,6 @@ def matmul_rs( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) A_frag_shape = A_shared_shape - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -235,8 +231,6 @@ def matmul_sr( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_frag_shape = B_shared_shape - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -323,8 +317,6 @@ def matmul_rr( A_frag_shape = A_shared_shape B_frag_shape = B_shared_shape - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -399,9 +391,9 @@ def run_gemm_rr( [ pytest.param( k, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, id=f"K{k}-float16-float16-float16", ) for k in K_VALUES @@ -409,9 +401,9 @@ def run_gemm_rr( + [ pytest.param( k, - "int8", - "int32", - "int32", + T.int8, + T.int32, + T.int32, id="K32-int8-int32-int32", ) for k in K_VALUES_8Bit @@ -419,9 +411,9 @@ def run_gemm_rr( + [ pytest.param( k, - "float8_e5m2", - "float32", - "float32", + T.float8_e5m2, + T.float32, + T.float32, 
id="K32-float8_e5m2-float32-float32", ) for k in K_VALUES_8Bit @@ -429,9 +421,9 @@ def run_gemm_rr( + [ pytest.param( k, - "float8_e4m3", - "float32", - "float32", + T.float8_e4m3fn, + T.float32, + T.float32, id="K32-float8_e4m3-float32-float32", ) for k in K_VALUES_8Bit @@ -452,15 +444,15 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_rs_false_false(m, n, k): - run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k) + run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) def run_gemm_rs_true_false(m, n, k): - run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k) + run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) def run_gemm_rs_true_true(m, n, k): - run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k) + run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): @@ -468,15 +460,15 @@ def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_sr_false_false(m, n, k): - run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k) + run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) def run_gemm_sr_true_false(m, n, k): - run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k) + run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) def run_gemm_sr_true_true(m, n, k): - run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k) + run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): @@ -484,15 +476,15 @@ def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_rr_false_false(m, n, k): - run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k) + run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k) def run_gemm_rr_true_false(m, n, k): - run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k) + run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k) def run_gemm_rr_true_true(m, n, k): - run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k) + run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k) TRANS_CASES = [ @@ -548,9 +540,9 @@ def test_gemm_false_false(m, n, k): k * 3, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, m, n, k, @@ -567,9 +559,9 @@ def test_gemm_true_false(m, n, k): k * 3, True, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, m, n, k, @@ -586,9 +578,9 @@ def test_gemm_true_true(m, n, k): k * 3, True, True, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, m, n, k, @@ -607,7 +599,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_rs_false_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rs_false_false(m, n, k) @@ -615,7 +607,7 @@ def test_gemm_rs_false_false(m, n, k): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def 
test_gemm_rs_true_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rs_true_false(m, n, k) @@ -623,7 +615,7 @@ def test_gemm_rs_true_false(m, n, k): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_rs_true_true(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rs_true_true(m, n, k) @@ -639,7 +631,7 @@ def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_sr_false_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_sr_false_false(m, n, k) @@ -647,7 +639,7 @@ def test_gemm_sr_false_false(m, n, k): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_sr_true_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_sr_true_false(m, n, k) @@ -655,7 +647,7 @@ def test_gemm_sr_true_false(m, n, k): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_sr_true_true(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_sr_true_true(m, n, k) @@ -671,7 +663,7 @@ def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_rr_false_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rr_false_false(m, n, k) @@ -679,7 +671,7 @@ def test_gemm_rr_false_false(m, n, k): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_rr_true_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rr_true_false(m, n, k) @@ -687,7 +679,7 @@ def test_gemm_rr_true_false(m, n, k): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_rr_true_true(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rr_true_true(m, n, k) @@ -699,7 +691,7 @@ def test_gemm_rr_true_true(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") # # Test Pass @@ -707,7 +699,7 @@ def test_gemm_rr_true_true(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: # print(f"======================= Test {m} {n} {k} False False =============================") - # run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") # # Test Pass @@ -715,7 +707,7 @@ def test_gemm_rr_true_true(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: # print(f"======================= Test {m} {n} {k} True False =============================") - # 
run_gemm(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m}, {n} {k} Pass") # print(f"Test {n} Pass") @@ -724,7 +716,7 @@ def test_gemm_rr_true_true(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: # print(f"======================= Test {m} {n} {k} True True =============================") - # run_gemm(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m}, {n} {k} Pass") # print(f"Test {n} Pass") @@ -733,15 +725,15 @@ def test_gemm_rr_true_true(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm_rs(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: - # run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256) + # run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) # print(f"Test {64} {n} {k} Pass") # for n in [16, 32, 64, 128]: # for k in [16, 32, 64, 128]: - # run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256) + # run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256) # print(f"Test {64} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_sm70.py b/maint/gemm_v2/correctness_evaluation_sm70.py index 3b4503d4e..606d10261 100644 --- a/maint/gemm_v2/correctness_evaluation_sm70.py +++ b/maint/gemm_v2/correctness_evaluation_sm70.py @@ -2,6 +2,7 @@ import pytest from tilelang import tvm as tvm import tilelang.testing +from tilelang import language as T def matmul( @@ -24,8 +25,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -81,7 +80,7 @@ def ref_program(A, B): A = A.T if trans_B: B = B.T - if in_dtype == "float32": + if in_dtype == T.float32: A = (A.view(torch.int32) - 0x1000).view(torch.float32) B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) @@ -147,8 +146,6 @@ def matmul_rs( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) A_frag_shape = A_shared_shape - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -217,18 +214,18 @@ def run_gemm_rs( FALSE_TRUE_CASES = [ pytest.param( k, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, id=f"K{k}-float16-float16-float16", ) for k in K_VALUES ] + [ pytest.param( k, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, id=f"K{k}-float16-float16-float32", ) for k in K_VALUES @@ -248,7 +245,7 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): def run_gemm_rs_false_false(m, n, k): - run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128) + run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) TRANS_CASES = [ @@ -306,9 +303,9 @@ def 
test_gemm_false_false(m, n, k): k * 3, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, m, n, k, @@ -329,7 +326,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): @pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}") @pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}") def test_gemm_rs_false_false(m, n, k): - _ensure_torch_dtypes("float16") + _ensure_torch_dtypes(T.float16) run_gemm_rs_false_false(m, n, k) @@ -341,7 +338,7 @@ def test_gemm_rs_false_false(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64]: # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") # # Test Pass @@ -349,5 +346,5 @@ def test_gemm_rs_false_false(m, n, k): # for n in [16, 32, 64, 128]: # for k in [16, 32, 64]: # print(f"======================= Test {m} {n} {k} False False =============================") - # run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") diff --git a/maint/gemm_v2/correctness_evaluation_tcgen05.py b/maint/gemm_v2/correctness_evaluation_tcgen05.py index 4ce8691ec..8d9728182 100644 --- a/maint/gemm_v2/correctness_evaluation_tcgen05.py +++ b/maint/gemm_v2/correctness_evaluation_tcgen05.py @@ -80,7 +80,7 @@ def ref_program(A, B): A = A.T if trans_B: B = B.T - if in_dtype == "float32": + if in_dtype == T.float32: A = (A.view(torch.int32) - 0x1000).view(torch.float32) B = (B.view(torch.int32) - 0x1000).view(torch.float32) C = torch.matmul(A.to(torch.float), B.to(torch.float)) @@ -134,18 +134,18 @@ def run_gemm( FALSE_TRUE_CASES = [ pytest.param( k, - "float16", - "float32", - "float32", + T.float16, + T.float32, + T.float32, id=f"K{k}-float16-float-float", ) for k in K_VALUES ] + [ pytest.param( k, - "float8_e5m2", - "float32", - "float32", + T.float8_e5m2, + T.float32, + T.float32, id="K32-float8_e5m2-float32-float32", ) for k in K_VALUES_8Bit @@ -195,7 +195,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): # if m in [32, 64] and (n not in [64, 128, 256]): # continue # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128) + # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128) # print(f"Test {m} {n} {k} Pass") # # Test Pass @@ -205,7 +205,7 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): # if m in [32, 64] and (n not in [64, 128, 256]): # continue # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 256) + # run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256) # print(f"Test {m} {n} {k} Pass") # # Test Pass @@ -215,4 +215,4 @@ def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype): # if m in [32, 64] and (n not in [64, 128, 256]): # continue # print(f"======================= Test {m} {n} {k} False True =============================") - # run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128) + # run_gemm(m, n, k * 
3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128) diff --git a/maint/gemm_v2/latency.py b/maint/gemm_v2/latency.py index 4dcb7cf9a..b7b2a2af9 100644 --- a/maint/gemm_v2/latency.py +++ b/maint/gemm_v2/latency.py @@ -13,7 +13,7 @@ # target currently can be "cuda" or "hip" or "cpu". # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def matmul_relu_kernel( A: T.Tensor((M, K), dtype), diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py index a66167d4b..5f0450e02 100644 --- a/maint/gemm_v2/latency_gemm.py +++ b/maint/gemm_v2/latency_gemm.py @@ -13,7 +13,7 @@ # target currently can be "cuda" or "hip" or "cpu". # if not specified, it will be inferred from the input tensors during compile time @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def matmul_relu_kernel( A: T.Tensor((M, K), dtype), diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py index 3fd560012..7a83d7cec 100644 --- a/maint/gemm_v2/latency_mha_fwd_bhsd.py +++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py @@ -38,8 +38,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=6 scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) q_shape = [batch, heads, seq_q, dim] kv_shape = [batch, heads, seq_kv, dim] - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 past_len = seq_kv - seq_q assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" diff --git a/maint/host_checks/common.py b/maint/host_checks/common.py index 649527d4a..3dbac5481 100644 --- a/maint/host_checks/common.py +++ b/maint/host_checks/common.py @@ -3,7 +3,7 @@ import torch -def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"): +def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), diff --git a/maint/precision/compare_ops.py b/maint/precision/compare_ops.py index 985c3bd96..c77a67cfd 100644 --- a/maint/precision/compare_ops.py +++ b/maint/precision/compare_ops.py @@ -186,8 +186,8 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool = @T.prim_func def tilelang_unary_kernel( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), ): with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by): for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N): @@ -224,9 +224,9 @@ def make_tilelang_binary_kernel(M: int, N: int): @T.prim_func def tilelang_binary_kernel( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), - C: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), + C: T.Tensor((M, N), T.float32), ): with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by): for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N): diff --git a/maint/scripts/performance.py b/maint/scripts/performance.py index 849bcf362..d53c227f5 
100644 --- a/maint/scripts/performance.py +++ b/maint/scripts/performance.py @@ -30,8 +30,8 @@ def kernel( thread_num=None, enable_rasteration=None, ): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 408b16ccf..702ae0175 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1124,8 +1124,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { } // Handle conversion from float32 to float8 (E4M3/E5M2) - if (from_ty.is_float() && - (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) { + if (from_ty.is_float() && (target_ty.is_float8())) { + bool target_type_is_e4m3 = target_ty.is_float8_e4m3() || + target_ty.is_float8_e4m3fn() || + target_ty.is_float8_e4m3fnuz(); // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion // (float2 -> fp8x2) if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { @@ -1134,8 +1136,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast(&(" << src << ")), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; os << sret; return; } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) { @@ -1144,14 +1145,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = " << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src << ")), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = " << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src << "))+1), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; os << sret; return; } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) { @@ -1160,33 +1159,31 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = " << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src << ")), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = " << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src << "))+1), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = " << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src << "))+2), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = " << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src << "))+3), __NV_SATFINITE, " - << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") - << ");\n"; + << (target_type_is_e4m3 ? 
"__NV_E4M3" : "__NV_E5M2") << ");\n"; os << sret; return; } } - if ((from_ty.is_float8_e4m3() || from_ty.is_float8_e5m2()) && - target_ty.is_float()) { + if (from_ty.is_float8() && target_ty.is_float()) { + bool from_type_is_e4m3 = from_ty.is_float8_e4m3() || + from_ty.is_float8_e4m3fn() || + from_ty.is_float8_e4m3fnuz(); // FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion // (fp8x2 -> float2) if (from_ty.lanes() == 2 && target_ty.lanes() == 2) { @@ -1196,8 +1193,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { << ")) = " "__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_" "t*>(&(" - << src << ")), " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << src << ")), " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; os << sret; return; @@ -1206,14 +1202,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { PrintIndent(); stream << "*(float2*)(&" << sret << ") = " << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[0], " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "*((float2*)(&" << sret << ")+1) = " << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[1], " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; os << sret; return; @@ -1222,26 +1216,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { PrintIndent(); stream << "*(float2*)(&" << sret << ") = " << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[0], " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "*((float2*)(&" << sret << ")+1) = " << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[1], " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "*((float2*)(&" << sret << ")+2) = " << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[2], " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << "))[2], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n"; PrintIndent(); stream << "*((float2*)(&" << sret << ")+3) = " << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src - << "))[3], " - << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") + << "))[3], " << (from_type_is_e4m3 ? 
"__NV_E4M3" : "__NV_E5M2") << ");\n"; os << sret; return; diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 337312851..0c6a76377 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -1179,10 +1179,10 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { // Check if this is a non-reducer store with Cast operation DataType src_type = cast->value.dtype(); DataType dst_type = cast->dtype; - bool src_ok = src_type.is_float() || src_type.is_bfloat() || - src_type.is_float8_e4m3() || src_type.is_float8_e5m2(); - bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() || - dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2(); + bool src_ok = + src_type.is_float() || src_type.is_bfloat() || src_type.is_float8(); + bool dst_ok = + dst_type.is_float() || dst_type.is_bfloat() || dst_type.is_float8(); if (src_ok && dst_ok && TargetIsCuda(Target::Current())) { has_cast_operations = true; } diff --git a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py index 65b2d5cff..b26354830 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py @@ -26,7 +26,7 @@ def tl_matmul( ): micro_size_x = micro_size_y = micro_size_k = 16 - if in_dtype in {"float8_e4m3fnuz", "int8"}: + if in_dtype in {T.float8_e4m3fnuz, T.int8}: micro_size_k = 32 block_row_warps = 2 @@ -160,7 +160,7 @@ def main( return main -def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1): +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.float32, a_transposed=False, b_transposed=True, k_pack=1): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack) print(matmul) kernel = tilelang.compile(matmul) @@ -169,10 +169,10 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa assert src_code is not None A_shape = (K, M) if a_transposed else (M, K) B_shape = (N, K) if b_transposed else (K, N) - if in_dtype == "int8": + if in_dtype == T.int8: A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8) B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8) - elif in_dtype == "float8_e4m3fnuz": + elif in_dtype == T.float8_e4m3fnuz: A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) else: @@ -211,15 +211,15 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa @pytest.mark.parametrize( "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack", [ - (128, 128, 128, "float16", "float16", "float32", False, True, 1), - (128, 256, 256, "float16", "float32", "float32", False, True, 1), - (128, 256, 256, "float16", "float32", "float32", False, True, 2), - (128, 128, 128, "int8", "int32", "int32", False, True, 1), - (128, 256, 256, "int8", "int32", "int32", False, True, 1), - (128, 256, 256, "int8", "int32", "int32", False, True, 2), - (128, 256, 256, "int8", "int32", "int32", False, False, 1), - (128, 256, 256, "int8", "int32", "int32", False, False, 2), - (128, 128, 128, "float8_e4m3fnuz", "float16", "float32", False, True, 1), + (128, 128, 128, T.float16, T.float16, T.float32, False, True, 1), + (128, 256, 256, T.float16, T.float32, T.float32, False, True, 1), + (128, 
256, 256, T.float16, T.float32, T.float32, False, True, 2), + (128, 128, 128, T.int8, T.int32, T.int32, False, True, 1), + (128, 256, 256, T.int8, T.int32, T.int32, False, True, 1), + (128, 256, 256, T.int8, T.int32, T.int32, False, True, 2), + (128, 256, 256, T.int8, T.int32, T.int32, False, False, 1), + (128, 256, 256, T.int8, T.int32, T.int32, False, False, 2), + (128, 128, 128, T.float8_e4m3fnuz, T.float16, T.float32, False, True, 1), ], ) @tilelang.testing.requires_rocm @@ -235,10 +235,10 @@ def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transpose b_transposed=b_transposed, k_pack=k_pack, ) - assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32") - assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2) - assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False) - assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, k_pack=2) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, b_transposed=False) + assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, b_transposed=False, k_pack=2) if __name__ == "__main__": diff --git a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py index eb2c6cbca..dc95eb701 100644 --- a/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py +++ b/testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py @@ -26,7 +26,7 @@ def tl_matmul( ): micro_size_x = micro_size_y = micro_size_k = 16 - if in_dtype in {"float8_e4m3fnuz", "int8"}: + if in_dtype in {T.float8_e4m3fnuz, T.int8}: micro_size_k = 32 block_row_warps = 2 @@ -196,7 +196,7 @@ def assert_tl_matmul_correctness( K, in_dtype, out_dtype, - accum_dtype="float32", + accum_dtype=T.float32, a_transposed=False, b_transposed=True, k_pack=1, @@ -211,10 +211,10 @@ def assert_tl_matmul_correctness( assert src_code is not None A_shape = (K, M) if a_transposed else (M, K) B_shape = (N, K) if b_transposed else (K, N) - if in_dtype == "int8": + if in_dtype == T.int8: A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8) B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8) - elif in_dtype == "float8_e4m3fnuz": + elif in_dtype == T.float8_e4m3fnuz: A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype)) else: @@ -261,14 +261,14 @@ def assert_tl_matmul_correctness( @pytest.mark.parametrize( "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load", [ - (256, 256, 512, "int8", "int32", "int32", False, True, 1, True, False), - (256, 256, 512, "int8", "int32", "int32", False, False, 1, True, False), - (256, 256, 512, "int8", "int32", "int32", False, True, 2, True, False), - (256, 256, 512, "int8", "int32", "int32", False, False, 2, True, False), - (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 1, True, False), - (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 1, True, False), - (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 2, True, False), - (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 2, True, False), + (256, 256, 512, 
T.int8, T.int32, T.int32, False, True, 1, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, False, 1, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, True, 2, True, False), + (256, 256, 512, T.int8, T.int32, T.int32, False, False, 2, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, True, 1, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, False, 1, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, True, 2, True, False), + (256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, False, 2, True, False), ], ) @tilelang.testing.requires_rocm diff --git a/testing/python/amd/test_tilelang_test_amd.py b/testing/python/amd/test_tilelang_test_amd.py index c9c3bedbb..4035c299c 100644 --- a/testing/python/amd/test_tilelang_test_amd.py +++ b/testing/python/amd/test_tilelang_test_amd.py @@ -108,7 +108,7 @@ def ref_program(A, B): ) @tilelang.testing.requires_rocm def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack): - run_gemm(1024, 1024, 1024, trans_A, trans_B, "float16", "float32", "float32", 128, 128, 32, k_pack=k_pack) + run_gemm(1024, 1024, 1024, trans_A, trans_B, T.float16, T.float32, T.float32, 128, 128, 32, k_pack=k_pack) @pytest.mark.parametrize( @@ -123,7 +123,7 @@ def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack): ) @tilelang.testing.requires_rocm def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack): - run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=k_pack) + run_gemm(1024, 1024, 1024, trans_A, trans_B, T.bfloat16, T.float32, T.float32, 128, 128, 32, k_pack=k_pack) @pytest.mark.parametrize( @@ -138,7 +138,7 @@ def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack): ) @tilelang.testing.requires_rocm def test_gemm_bf16bf16f32(trans_A, trans_B, k_pack): - run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=k_pack) + run_gemm(1024, 1024, 1024, trans_A, trans_B, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32, k_pack=k_pack) def matmul_rs( @@ -241,24 +241,24 @@ def ref_program(A, B): # @tilelang.testing.requires_rocm # def test_gemm_rs_f16f32f32_nt(): -# run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, False, T.float16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, True, T.float16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, True, T.float16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, False, T.float16, T.float32, T.float32, 128, 128, 32) # @tilelang.testing.requires_rocm # def test_gemm_rs_bf16f32f32_nt(): -# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, False, T.bfloat16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, True, 
T.bfloat16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, True, T.bfloat16, T.float32, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, False, T.bfloat16, T.float32, T.float32, 128, 128, 32) # @tilelang.testing.requires_rocm # def test_gemm_rs_bf16bf16f32_nt(): -# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32) -# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, False, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) +# run_gemm_rs(1024, 1024, 1024, True, False, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32) if __name__ == "__main__": tilelang.testing.main() diff --git a/testing/python/analysis/test_tilelang_fragment_loop_checker.py b/testing/python/analysis/test_tilelang_fragment_loop_checker.py index 85aa51895..99458f1c8 100644 --- a/testing/python/analysis/test_tilelang_fragment_loop_checker.py +++ b/testing/python/analysis/test_tilelang_fragment_loop_checker.py @@ -5,7 +5,7 @@ @tilelang.jit -def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): +def simple_invalid_loop(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): A = T.dynamic("A") @T.prim_func @@ -26,7 +26,7 @@ def main( @tilelang.jit -def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): +def nested_invalid_loop(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): A = T.dynamic("A") @T.prim_func @@ -48,7 +48,7 @@ def main( @tilelang.jit -def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): +def invalid_loop_with_complex_dataflow(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): A = T.dynamic("A") @T.prim_func @@ -69,7 +69,7 @@ def main( @tilelang.jit -def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): +def valid_loop_not_use_loop_var(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): A = T.dynamic("A") @T.prim_func @@ -91,7 +91,7 @@ def main( @tilelang.jit -def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): +def valid_loop_not_frag(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): A = T.dynamic("A") @T.prim_func @@ -112,7 +112,7 @@ def main( @tilelang.jit -def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128): +def valid_loop_serial(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128): A = T.dynamic("A") @T.prim_func diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py index e282c8e34..664fda5b8 100644 --- a/testing/python/analysis/test_tilelang_nested_loop_checker.py +++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -29,7 +29,7 @@ def 
_require_cuda_tensor(shape, dtype=torch.float32): @tilelang.jit(out_idx=[1]) -def nested_continuous_parallels(length=256, block=16, dtype="float32"): +def nested_continuous_parallels(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -44,7 +44,7 @@ def main( @tilelang.jit(out_idx=[1]) -def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"): +def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -60,7 +60,7 @@ def main( @tilelang.jit(out_idx=[1]) -def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"): +def nested_noncontinuous_parallels(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -149,9 +149,9 @@ def run_gemm_nested_pipelines( block_K = 32 trans_A = False trans_B = False - in_dtype = "float16" - out_dtype = "float16" - dtypeAccum = "float32" + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 num_threads = 128 program = matmul_nested_pipelines( M, @@ -188,7 +188,7 @@ def ref_program(A, B): A = A.T if trans_B: B = B.T - if in_dtype == "float32": + if in_dtype == T.float32: # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas A = (A.view(torch.int32) - 0x1000).view(torch.float32) @@ -215,7 +215,7 @@ def test_nested_pipelines(): @tilelang.jit(out_idx=[1]) -def nested_continuous_serials(length=256, block=16, dtype="float32"): +def nested_continuous_serials(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -230,7 +230,7 @@ def main( @tilelang.jit(out_idx=[1]) -def nested_noncontinuous_serials(length=256, block=16, dtype="float32"): +def nested_noncontinuous_serials(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -272,7 +272,7 @@ def test_nested_serials(): @tilelang.jit(out_idx=[1]) -def nested_continuous_sp(length=256, block=16, dtype="float32"): +def nested_continuous_sp(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -287,7 +287,7 @@ def main( @tilelang.jit(out_idx=[1]) -def nested_continuous_ps(length=256, block=16, dtype="float32"): +def nested_continuous_ps(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -302,7 +302,7 @@ def main( @tilelang.jit(out_idx=[1]) -def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"): +def nested_continuous_psp(length=256, block1=8, block2=2, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -318,7 +318,7 @@ def main( @tilelang.jit(out_idx=[1]) -def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"): +def nested_continuous_sps(length=256, block1=8, block2=2, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -469,9 +469,9 @@ def run_gemm_mixed_pp( block_M = 128 block_N = 128 block_K = 32 - in_dtype = "float16" - out_dtype = "float16" - dtypeAccum = "float32" + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 num_threads = 128 program = matmul_nested_pipa( @@ -502,7 +502,7 @@ def run_gemm_mixed_pp( def ref_program(A, B): import torch - if in_dtype == "float32": + if in_dtype == T.float32: # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas A = (A.view(torch.int32) - 0x1000).view(torch.float32) @@ -603,9 
+603,9 @@ def run_gemm_tiled_op_with_parallel( block_M = 128 block_N = 128 block_K = 32 - in_dtype = "float16" - out_dtype = "float16" - dtypeAccum = "float32" + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 num_threads = 128 program = matmul_nested_pipa( @@ -636,7 +636,7 @@ def run_gemm_tiled_op_with_parallel( def ref_program(A, B): import torch - if in_dtype == "float32": + if in_dtype == T.float32: # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas A = (A.view(torch.int32) - 0x1000).view(torch.float32) @@ -673,7 +673,7 @@ def ref_program(A, B): @tilelang.jit(out_idx=[1]) -def tir_op_with_parallel(length=256, block=16, dtype="float32"): +def tir_op_with_parallel(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -688,7 +688,7 @@ def main( @tilelang.jit(out_idx=[1]) -def customize_op_with_parallel(length=256, block=16, dtype="float32"): +def customize_op_with_parallel(length=256, block=16, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), diff --git a/testing/python/autotune/test_tilelang_autotune.py b/testing/python/autotune/test_tilelang_autotune.py index 3e6a05a24..53707ca34 100644 --- a/testing/python/autotune/test_tilelang_autotune.py +++ b/testing/python/autotune/test_tilelang_autotune.py @@ -57,9 +57,9 @@ def get_configs(M, N, K, with_roller=False): M=M, N=N, K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float16, ).with_arch(arch) func = carve_template.equivalent_function() @@ -187,8 +187,8 @@ def kernel( """ # Use half-precision for input data to reduce memory bandwidth, # accumulate in float for better numerical accuracy - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( diff --git a/testing/python/autotune/test_tilelang_autotune_with_inputs.py b/testing/python/autotune/test_tilelang_autotune_with_inputs.py index 8f9a6098d..4edea0b88 100644 --- a/testing/python/autotune/test_tilelang_autotune_with_inputs.py +++ b/testing/python/autotune/test_tilelang_autotune_with_inputs.py @@ -39,8 +39,8 @@ def get_configs(): ) @tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False): - dtype = "float16" - accum_dtype = "float" + dtype = T.float16 + accum_dtype = T.float32 @T.prim_func def main( diff --git a/testing/python/carver/test_tilelang_carver_generate_hints.py b/testing/python/carver/test_tilelang_carver_generate_hints.py index 313dc857f..ea674f7c7 100644 --- a/testing/python/carver/test_tilelang_carver_generate_hints.py +++ b/testing/python/carver/test_tilelang_carver_generate_hints.py @@ -3,19 +3,20 @@ from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge from tilelang.carver.arch import auto_infer_current_arch from tvm import te +from tilelang.language import dtypes as T def run_general_matmul_emit_configs(M, N, K, topk: int = 20): arch = auto_infer_current_arch() def gemm(M, N, K): - A = te.placeholder((M, K), name="A", dtype="float16") - B = te.placeholder((N, K), name="B", dtype="float16") + A = te.placeholder((M, K), name="A", dtype=T.float16) + B = te.placeholder((N, K), name="B", dtype=T.float16) # Describe the matrix multiplication in TE k = te.reduce_axis((0, K), name="k") - C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), 
name="C") + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C") return A, B, C @@ -55,13 +56,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20): arch = auto_infer_current_arch() def gemm(M, N, K): - A = te.placeholder((M, K), name="A", dtype="float16") - B = te.placeholder((N, K), name="B", dtype="float16") + A = te.placeholder((M, K), name="A", dtype=T.float16) + B = te.placeholder((N, K), name="B", dtype=T.float16) # Describe the matrix multiplication in TE k = te.reduce_axis((0, K), name="k") - C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") + C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C") return A, B, C diff --git a/testing/python/carver/test_tilelang_carver_recommend_hints.py b/testing/python/carver/test_tilelang_carver_recommend_hints.py index 4973c24d9..3a060f532 100644 --- a/testing/python/carver/test_tilelang_carver_recommend_hints.py +++ b/testing/python/carver/test_tilelang_carver_recommend_hints.py @@ -1,10 +1,11 @@ import tilelang.testing from tilelang import carver +from tilelang.language import dtypes as T from tilelang.carver.arch import auto_infer_current_arch from typing import List -def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20): +def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20): arch = auto_infer_current_arch() carve_template = carver.GeneralReductionTemplate( structure=structure, @@ -20,12 +21,12 @@ def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[in def test_general_reduction_recommend_hints(): - run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], "float16") - run_general_reduction_recommend_hints("SS", [1024, 1024], "float16") - run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16") + run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], T.float16) + run_general_reduction_recommend_hints("SS", [1024, 1024], T.float16) + run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], T.float16) -def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20): +def run_elementwise_recommend_hints(shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20): arch = auto_infer_current_arch() carve_template = carver.ElementwiseTemplate( shape=shape, @@ -40,18 +41,18 @@ def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float def test_elementwise_recommend_hints(): - run_elementwise_recommend_hints([1024, 1024], "float16") - run_elementwise_recommend_hints([1024], "float16") - run_elementwise_recommend_hints([1024, 1024, 1024], "float16") + run_elementwise_recommend_hints([1024, 1024], T.float16) + run_elementwise_recommend_hints([1024], T.float16) + run_elementwise_recommend_hints([1024, 1024, 1024], T.float16) def run_matmul_recommend_hints( M: int = 1024, N: int = 1024, K: int = 1024, - in_dtype: str = "float16", - out_dtype: str = "float16", - accum_dtype: str = "float16", + in_dtype: T.dtype = T.float16, + out_dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float16, ): arch = auto_infer_current_arch() carve_template = carver.MatmulTemplate( @@ -71,13 +72,13 @@ def run_matmul_recommend_hints( def 
test_matmul_recommend_hints(): - run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float16", "float16") - run_matmul_recommend_hints(1024, 1024, 1024, "int8", "int32", "int32") - run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16") + run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float16, T.float16) + run_matmul_recommend_hints(1024, 1024, 1024, T.int8, T.int32, T.int32) + run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float32, T.float16) def run_gemv_recommend_hints( - N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16" + N: int = 1024, K: int = 1024, in_dtype: T.dtype = T.float16, out_dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float16 ): arch = auto_infer_current_arch() carve_template = carver.GEMVTemplate( @@ -96,9 +97,9 @@ def run_gemv_recommend_hints( def test_gemv_recommend_hints(): - run_gemv_recommend_hints(1024, 1024, "float16", "float16", "float16") - run_gemv_recommend_hints(1024, 1024, "int8", "int32", "int32") - run_gemv_recommend_hints(1024, 1024, "float16", "float32", "float16") + run_gemv_recommend_hints(1024, 1024, T.float16, T.float16, T.float16) + run_gemv_recommend_hints(1024, 1024, T.int8, T.int32, T.int32) + run_gemv_recommend_hints(1024, 1024, T.float16, T.float32, T.float16) def run_fmha_recommend_hints( @@ -107,9 +108,9 @@ def run_fmha_recommend_hints( seq_length: int = 512, seq_kv_length: int = 512, head_dim: int = 128, - in_dtype: str = "float16", - accum_dtype: str = "float16", - out_dtype: str = "float16", + in_dtype: T.dtype = T.float16, + accum_dtype: T.dtype = T.float16, + out_dtype: T.dtype = T.float16, ): arch = auto_infer_current_arch() carve_template = carver.FlashAttentionTemplate( @@ -133,8 +134,8 @@ def run_fmha_recommend_hints( def test_fmha_recommend_hints(): - run_fmha_recommend_hints(4, 32, 512, 512, 128, "float16", "float16", "float16") - run_fmha_recommend_hints(4, 32, 512, 512, 128, "int8", "int32", "int32") + run_fmha_recommend_hints(4, 32, 512, 512, 128, T.float16, T.float16, T.float16) + run_fmha_recommend_hints(4, 32, 512, 512, 128, T.int8, T.int32, T.int32) if __name__ == "__main__": diff --git a/testing/python/components/test_storage_rewrite_detect_inplace.py b/testing/python/components/test_storage_rewrite_detect_inplace.py index bd0a64d39..4c4f4e5f3 100644 --- a/testing/python/components/test_storage_rewrite_detect_inplace.py +++ b/testing/python/components/test_storage_rewrite_detect_inplace.py @@ -8,12 +8,12 @@ def _compile_kernel_without_inplace(): num_tokens = T.symbolic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]): + def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]): with T.Kernel(num_tokens, threads=32) as pid: - read = T.alloc_var("int") + read = T.alloc_var(T.int) read = x[pid] - write = T.alloc_var("int") + write = T.alloc_var(T.int) write = read * 2 x[pid] = write @@ -29,12 +29,12 @@ def _compile_kernel_with_inplace(): num_tokens = T.symbolic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]): + def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]): with T.Kernel(num_tokens, threads=32) as pid: - read = T.alloc_var("int") + read = T.alloc_var(T.int) read = x[pid] - write = T.alloc_var("int") + write = T.alloc_var(T.int) write = read * 2 x[pid] = write diff --git a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py index 
323f76458..d599e581a 100644 --- a/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py +++ b/testing/python/components/test_tilelang_pass_config_disable_warp_specialized.py @@ -1,5 +1,6 @@ -from tilelang import tvm as tvm import tilelang.testing +from tilelang import language as T +import torch def matmul( @@ -22,8 +23,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -93,8 +92,6 @@ def run_gemm( profiler = kernel.get_profiler() def ref_program(A, B): - import torch - if trans_A: A = A.T if trans_B: @@ -114,9 +111,9 @@ def test_gemm_f16f16f16_nn(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -129,9 +126,9 @@ def test_gemm_f16f16f16_nn(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, diff --git a/testing/python/cpu/test_tilelang_cpu_gemm.py b/testing/python/cpu/test_tilelang_cpu_gemm.py index 4a878f328..55646622e 100644 --- a/testing/python/cpu/test_tilelang_cpu_gemm.py +++ b/testing/python/cpu/test_tilelang_cpu_gemm.py @@ -5,7 +5,7 @@ import torch -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): num_stages = 0 @T.prim_func @@ -61,7 +61,7 @@ def test_matmul_codegen(): def test_matmul_compile(): - def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): # a simple kernel just for jit test @T.prim_func def matmul( @@ -103,7 +103,7 @@ def matmul( with tvm.target.Target("c"): complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes") - in_dtype = "float16" + in_dtype = T.float16 A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)) B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)) diff --git a/testing/python/debug/test_tilelang_debug_print.py b/testing/python/debug/test_tilelang_debug_print.py index e26296613..3483cffc0 100644 --- a/testing/python/debug/test_tilelang_debug_print.py +++ b/testing/python/debug/test_tilelang_debug_print.py @@ -5,7 +5,7 @@ import tilelang.language as T -def debug_print_buffer(M=16, N=16, dtype="float16"): +def debug_print_buffer(M=16, N=16, dtype=T.float16): @T.prim_func def program(Q: T.Tensor((M, N), dtype)): with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): @@ -18,28 +18,28 @@ def program(Q: T.Tensor((M, N), dtype)): def test_debug_print_buffer(): - debug_print_buffer(dtype="bool") - debug_print_buffer(dtype="int8") - debug_print_buffer(dtype="int16") - debug_print_buffer(dtype="int32") - debug_print_buffer(dtype="int64") - debug_print_buffer(dtype="uint8") - debug_print_buffer(dtype="uint16") - debug_print_buffer(dtype="uint32") - debug_print_buffer(dtype="uint64") - debug_print_buffer(dtype="float16") - debug_print_buffer(dtype="float32") - debug_print_buffer(dtype="float64") - debug_print_buffer(dtype="bfloat16") - debug_print_buffer(dtype="float8_e4m3") - debug_print_buffer(dtype="float8_e4m3fn") - debug_print_buffer(dtype="float8_e4m3fnuz") - debug_print_buffer(dtype="float8_e5m2") - debug_print_buffer(dtype="float8_e5m2fnuz") + debug_print_buffer(dtype=T.bool) + debug_print_buffer(dtype=T.int8) + 
debug_print_buffer(dtype=T.int16) + debug_print_buffer(dtype=T.int32) + debug_print_buffer(dtype=T.int64) + debug_print_buffer(dtype=T.uint8) + debug_print_buffer(dtype=T.uint16) + debug_print_buffer(dtype=T.uint32) + debug_print_buffer(dtype=T.uint64) + debug_print_buffer(dtype=T.float16) + debug_print_buffer(dtype=T.float32) + debug_print_buffer(dtype=T.float64) + debug_print_buffer(dtype=T.bfloat16) + debug_print_buffer(dtype=T.float8_e4m3fn) + debug_print_buffer(dtype=T.float8_e4m3fn) + debug_print_buffer(dtype=T.float8_e4m3fnuz) + debug_print_buffer(dtype=T.float8_e5m2) + debug_print_buffer(dtype=T.float8_e5m2fnuz) def debug_print_buffer_conditional(M=16, N=16): - dtype = "float16" + dtype = T.float16 @T.prim_func def program(Q: T.Tensor((M, N), dtype)): @@ -59,7 +59,7 @@ def test_debug_print_buffer_conditional(): def debug_print_value_conditional(M=16, N=16): - dtype = "float16" + dtype = T.float16 @T.prim_func def program(Q: T.Tensor((M, N), dtype)): @@ -78,7 +78,7 @@ def test_debug_print_value_conditional(): def debug_print_register_files(M=16, N=16): - dtype = "float16" + dtype = T.float16 @T.prim_func def program(Q: T.Tensor((M, N), dtype)): @@ -97,7 +97,7 @@ def test_debug_print_register_files(): def debug_print_msg(M=16, N=16): - dtype = "float16" + dtype = T.float16 @T.prim_func def program(Q: T.Tensor((M, N), dtype)): diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py index 8e50a2759..f93c330c8 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic.py @@ -33,18 +33,18 @@ def tl_matmul_macro( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -52,7 +52,7 @@ def tl_matmul_macro( block_col_warps = 1 warp_row_tiles = 16 warp_col_tiles = 16 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -453,36 +453,36 @@ def ref_program(A, B): def test_assert_tl_matmul_macro(): - assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_macro_correctness(32, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_macro_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_macro_correctness(66, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_macro_correctness(32, 128, 128, T.float16, T.float16, T.float16) def test_assert_tl_matmul_block(): - assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) - assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) - assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32) + assert_tl_matmul_block_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32) + 
assert_tl_matmul_block_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32) def test_assert_tl_matmul_block_all_dynamic(): - assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) - assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) - assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32) + assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32) def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 + 128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8 ) assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 + 64, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8 ) assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4 + 64, 128, 60, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=4 ) # Tail split is enabled with dynamic alignment 0 assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( - 64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0 + 64, 128, 64, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=0 ) diff --git a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py index 1bee1356f..ea6efadbc 100644 --- a/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py +++ b/testing/python/dynamic/test_tilelang_dynamic_symbolic_bench.py @@ -437,7 +437,7 @@ def ref_program(A, B): def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K): - assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32") + assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, T.float16, T.float16, T.float32) def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): @@ -450,9 +450,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): block_K, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, ) assert_tl_matmul_block_dynamic_m( @@ -464,9 +464,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): block_K, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, pass_configs={"tl.disable_dynamic_tail_split": False}, ) @@ -481,9 +481,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): block_K, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, 
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, ) assert_tl_matmul_block_dynamic_mn( @@ -495,9 +495,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): block_K, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, pass_configs={"tl.disable_dynamic_tail_split": False}, ) @@ -512,9 +512,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): block_K, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4}, ) assert_tl_matmul_block_dynamic_mnk( @@ -526,9 +526,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): block_K, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, pass_configs={"tl.disable_dynamic_tail_split": False}, ) diff --git a/testing/python/fastmath/test_mathops_fastmath.py b/testing/python/fastmath/test_mathops_fastmath.py index 72eddd960..e181eb4df 100644 --- a/testing/python/fastmath/test_mathops_fastmath.py +++ b/testing/python/fastmath/test_mathops_fastmath.py @@ -50,7 +50,7 @@ def check_non_fastmath_usage(source, mathop_name): check_fastmath_usage(source, mathop_name, expect_fastmath=False) -def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test single-argument mathops. T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) @@ -86,7 +86,7 @@ def main( print(f"✓ {mathop_name} compilation and execution test passed") -def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test two-argument mathops to ensure they generate non-fastmath CUDA code. """ @@ -134,7 +134,7 @@ def main( check_non_fastmath_usage(source_fastmath, mathop_name) # Test numerical correctness - torch_dtype = getattr(torch, dtype) + torch_dtype = dtype.as_torch() a = torch.randn(M, N, device="cuda", dtype=torch_dtype) b = torch.randn(M, N, device="cuda", dtype=torch_dtype) @@ -160,8 +160,8 @@ def run_abs_test(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): @@ -189,7 +189,7 @@ def main( print("✓ abs numerical test passed") -def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). 
""" @@ -222,7 +222,7 @@ def main( check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) # Test numerical correctness - torch_dtype = getattr(torch, dtype) + torch_dtype = dtype.as_torch() a = torch.randn(M, N, device="cuda", dtype=torch_dtype) # Ensure positive values for functions that need them @@ -272,7 +272,7 @@ def main( @tilelang.testing.requires_cuda def test_mathops_generate_no_fastmath(name, func): """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)""" - run_single_arg_mathop_test(name, func, dtype="float32") + run_single_arg_mathop_test(name, func, dtype=T.float32) print(f"✓ {name} test passed") @@ -286,7 +286,7 @@ def test_mathops_generate_no_fastmath(name, func): @tilelang.testing.requires_cuda def test_two_arg_mathops_fastmath(name, func): """Test all two-argument mathops""" - run_two_arg_mathop_test(name, func, dtype="float32") + run_two_arg_mathop_test(name, func, dtype=T.float32) @tilelang.testing.requires_cuda @@ -311,7 +311,7 @@ def test_abs_maps_to_fabs(): @tilelang.testing.requires_cuda def test_fastmath_versions(name, func): """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code""" - run_fastmath_mathop_test(name, func, dtype="float32") + run_fastmath_mathop_test(name, func, dtype=T.float32) print(f"✓ {name} test passed") diff --git a/testing/python/issue/test_tilelang_issue_1001.py b/testing/python/issue/test_tilelang_issue_1001.py index a4283daa5..f2315ef21 100644 --- a/testing/python/issue/test_tilelang_issue_1001.py +++ b/testing/python/issue/test_tilelang_issue_1001.py @@ -14,9 +14,9 @@ def _cumsum_view_infer_layout(hidden): num_tokens = T.dynamic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]): + def buggy_kernel(x: T.Tensor[(num_tokens, hidden), T.float]): with T.Kernel(num_tokens, threads=128) as pid: - smem = T.alloc_shared((hidden,), dtype="float") + smem = T.alloc_shared((hidden,), dtype=T.float32) T.copy(x[pid, :], smem) T.cumsum(T.view(smem, (1, hidden)), dim=1) diff --git a/testing/python/issue/test_tilelang_issue_1008.py b/testing/python/issue/test_tilelang_issue_1008.py index 2d86d1645..a35a18449 100644 --- a/testing/python/issue/test_tilelang_issue_1008.py +++ b/testing/python/issue/test_tilelang_issue_1008.py @@ -33,7 +33,7 @@ def _fill_with_dynamic_region_kernel(): @T.prim_func def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 with T.Kernel(num_tokens, threads=128) as _: - a, b = T.alloc_var("int"), T.alloc_var("int") + a, b = T.alloc_var(T.int), T.alloc_var(T.int) T.fill(x[a:b], 0) return buggy_kernel diff --git a/testing/python/issue/test_tilelang_issue_1115.py b/testing/python/issue/test_tilelang_issue_1115.py index ce21a3b05..658c126a0 100644 --- a/testing/python/issue/test_tilelang_issue_1115.py +++ b/testing/python/issue/test_tilelang_issue_1115.py @@ -9,7 +9,7 @@ def set_cache_kernel( S, D, pos_ty="int64", - dtype="float32", + dtype=T.float32, ): @T.prim_func def main( @@ -36,7 +36,7 @@ def main( pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64) pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32) kernel_int64 = set_cache_kernel(S, D, "int64") - kernel_int32 = set_cache_kernel(S, D, "int32") + kernel_int32 = set_cache_kernel(S, D, T.int32) kernel_int64(pos_int64, value, cache) torch.testing.assert_close(cache, value) kernel_int32(pos_int32, value, cache) diff --git a/testing/python/issue/test_tilelang_issue_1198.py b/testing/python/issue/test_tilelang_issue_1198.py index 
08f36822b..e6330e435 100644 --- a/testing/python/issue/test_tilelang_issue_1198.py +++ b/testing/python/issue/test_tilelang_issue_1198.py @@ -9,7 +9,7 @@ def foo( [ 32, ], - "int32", + T.int32, ), ): pass diff --git a/testing/python/issue/test_tilelang_issue_1210.py b/testing/python/issue/test_tilelang_issue_1210.py index 971fb8193..2e141d782 100644 --- a/testing/python/issue/test_tilelang_issue_1210.py +++ b/testing/python/issue/test_tilelang_issue_1210.py @@ -4,10 +4,10 @@ def _make_kernel(M, N): - dtype = "bfloat16" + dtype = T.bfloat16 @T.prim_func - def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")): + def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)): with T.Kernel(4, threads=1): A = T.alloc_shared([N], dtype) B = T.alloc_shared([N], dtype) diff --git a/testing/python/issue/test_tilelang_issue_1237.py b/testing/python/issue/test_tilelang_issue_1237.py index a9aadc5ee..bb936e468 100644 --- a/testing/python/issue/test_tilelang_issue_1237.py +++ b/testing/python/issue/test_tilelang_issue_1237.py @@ -7,12 +7,12 @@ def test_issue_1237_dynamic_copy_extent_builds(): # The goal is to ensure T.copy correctly handles dynamic extents # (e.g., src slice length vs. static dst buffer size) during prim_func building. - length = T.symbolic("len", dtype="int32") + length = T.symbolic("len", dtype=T.int32) @T.prim_func - def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821 + def sample_kernel(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821 with T.Kernel(1, threads=32): - buffer_shared = T.alloc_shared((1024,), dtype="int32") + buffer_shared = T.alloc_shared((1024,), dtype=T.int32) T.copy(global_tensor[0:length], buffer_shared) # Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute. 
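
For reference, below is a minimal, hypothetical sketch of the dtype-object pattern these test updates apply: kernels are parameterized with dtype objects such as T.float32 instead of string literals, and host code maps them to torch dtypes via as_torch(), as in the fastmath test changes above. The scale_by_two helper is illustrative only and is not part of this patch; it simply mirrors the T.Kernel / T.Parallel tiling pattern used in the tests.

import torch
import tilelang
import tilelang.language as T


@tilelang.jit
def scale_by_two(M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
    # dtype is a tilelang dtype object (e.g. T.float32), not a string.
    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] * 2

    return main


kernel = scale_by_two()
# as_torch() converts the tilelang dtype object to the matching torch dtype
# (assumed available, following the fastmath test changes in this diff).
a = torch.randn(128, 128, device="cuda", dtype=T.float32.as_torch())
b = torch.empty_like(a)
kernel(a, b)
torch.testing.assert_close(b, a * 2)
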
diff --git a/testing/python/issue/test_tilelang_issue_814.py b/testing/python/issue/test_tilelang_issue_814.py index a202bd960..f9f94bd74 100644 --- a/testing/python/issue/test_tilelang_issue_814.py +++ b/testing/python/issue/test_tilelang_issue_814.py @@ -5,7 +5,7 @@ @tilelang.jit -def _tmp_var_kernel(N, block_N, dtype="float"): +def _tmp_var_kernel(N, block_N, dtype=T.float32): @T.prim_func def kernel( A: T.Tensor((N,), dtype), diff --git a/testing/python/issue/test_tilelang_issue_830.py b/testing/python/issue/test_tilelang_issue_830.py index 74ceed3d9..1a2a909d2 100644 --- a/testing/python/issue/test_tilelang_issue_830.py +++ b/testing/python/issue/test_tilelang_issue_830.py @@ -34,7 +34,7 @@ def _empty_with_dead_code_kernel(): num_tokens = T.dynamic("num_tokens") @T.prim_func - def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]): + def buggy_kernel(x: T.Tensor[(num_tokens,), T.float32]): with T.Kernel(num_tokens, threads=32) as pid: y = x[pid] diff --git a/testing/python/issue/test_tilelang_issue_96.py b/testing/python/issue/test_tilelang_issue_96.py index 6ab7fe479..9bf5c69bd 100644 --- a/testing/python/issue/test_tilelang_issue_96.py +++ b/testing/python/issue/test_tilelang_issue_96.py @@ -4,7 +4,7 @@ import torch -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), diff --git a/testing/python/issue/test_tilelang_issue_merge_if.py b/testing/python/issue/test_tilelang_issue_merge_if.py index fa9432fc8..e3b1e3082 100644 --- a/testing/python/issue/test_tilelang_issue_merge_if.py +++ b/testing/python/issue/test_tilelang_issue_merge_if.py @@ -8,10 +8,10 @@ def merge_if_test(): @T.prim_func def main(): - A = T.alloc_fragment((1,), "float16") - B = T.alloc_fragment((1,), "float16") - C = T.alloc_fragment((1,), "float16") - D = T.alloc_fragment((1,), "float16") + A = T.alloc_fragment((1,), T.float16) + B = T.alloc_fragment((1,), T.float16) + C = T.alloc_fragment((1,), T.float16) + D = T.alloc_fragment((1,), T.float16) if A[0] == 0: A[0] = 0 if B[0] == 0: diff --git a/testing/python/jit/test_tilelang_jit_callback.py b/testing/python/jit/test_tilelang_jit_callback.py index 7d76a64d1..98b88820c 100644 --- a/testing/python/jit/test_tilelang_jit_callback.py +++ b/testing/python/jit/test_tilelang_jit_callback.py @@ -1,4 +1,4 @@ -from tilelang import tvm as tvm +from tilelang import language as T import tilelang.testing import tilelang from tilelang.engine.callback import register_cuda_postproc_callback @@ -25,8 +25,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -107,9 +105,9 @@ def test_gemm_f16f16f16_nn(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -137,8 +135,6 @@ def matmu_jit_kernel( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -207,8 +203,6 @@ def run_gemm_jit_kernel( B = B.T def ref_program(A, B): - import torch - C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -226,9 +220,9 @@ def 
test_gemm_jit_kernel(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, diff --git a/testing/python/jit/test_tilelang_jit_gemm.py b/testing/python/jit/test_tilelang_jit_gemm.py index 153f06cb1..97391f26f 100644 --- a/testing/python/jit/test_tilelang_jit_gemm.py +++ b/testing/python/jit/test_tilelang_jit_gemm.py @@ -1,4 +1,4 @@ -from tilelang import tvm as tvm +from tilelang import language as T import tilelang.testing import tilelang import torch @@ -27,8 +27,6 @@ def matmul_kernel_jit( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -95,8 +93,6 @@ def run_gemm_kernel_jit( B = B.T def ref_program(A, B): - import torch - C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = C.to(torch.__getattribute__(out_dtype)) return C @@ -114,9 +110,9 @@ def test_gemm_f16f16f16_nn_kernel_jit(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 128, 32, diff --git a/testing/python/jit/test_tilelang_jit_gemm_cython.py b/testing/python/jit/test_tilelang_jit_gemm_cython.py index 4ea4ba88d..c5399fc51 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_cython.py +++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -104,9 +104,9 @@ def test_gemm_f16f16f16_nn(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -226,9 +226,9 @@ def test_gemm_jit_kernel(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -278,7 +278,7 @@ def run_cython_kernel_do_bench( def test_cython_kernel_do_bench(): - run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_cython_kernel_multi_stream( @@ -322,7 +322,7 @@ def run_cython_kernel_multi_stream( def test_cython_kernel_multi_stream(): - run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_cython_dynamic_shape( @@ -371,11 +371,11 @@ def run_cython_dynamic_shape( def test_cython_dynamic_shape(): - run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) - run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) - run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_cython_dynamic_shape_with_out_idx( @@ -424,7 +424,7 @@ def run_cython_dynamic_shape_with_out_idx( def test_cython_dynamic_shape_with_out_idx(): - run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", 
"float16", 128, 256, 32, 2) + run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def matmul_int_variable( @@ -495,7 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B def test_matmul_int_variable(): - run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128) + run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128) def matmul_float_variable( @@ -566,7 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans def test_matmul_float_variable(): - run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128) + run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128) if __name__ == "__main__": diff --git a/testing/python/jit/test_tilelang_jit_nullptr.py b/testing/python/jit/test_tilelang_jit_nullptr.py index 8965e2ad3..a9edb5e93 100644 --- a/testing/python/jit/test_tilelang_jit_nullptr.py +++ b/testing/python/jit/test_tilelang_jit_nullptr.py @@ -7,7 +7,7 @@ @tl.jit -def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False): +def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, with_bias=False): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -38,7 +38,7 @@ def main( return main -def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def run_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) diff --git a/testing/python/jit/test_tilelang_jit_nvrtc.py b/testing/python/jit/test_tilelang_jit_nvrtc.py index 2b1502772..b6823b8cc 100644 --- a/testing/python/jit/test_tilelang_jit_nvrtc.py +++ b/testing/python/jit/test_tilelang_jit_nvrtc.py @@ -104,9 +104,9 @@ def test_gemm_f16f16f16_nn(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -224,9 +224,9 @@ def test_gemm_jit_kernel(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -269,7 +269,7 @@ def run_nvrtc_kernel_do_bench( def test_nvrtc_kernel_do_bench(): - run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_nvrtc_kernel_multi_stream( @@ -311,7 +311,7 @@ def run_nvrtc_kernel_multi_stream( def test_nvrtc_kernel_multi_stream(): - run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_nvrtc_dynamic_shape( @@ -360,11 +360,11 @@ def run_nvrtc_dynamic_shape( def test_nvrtc_dynamic_shape(): - run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) - run_nvrtc_dynamic_shape(T.dynamic("m"), 
T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) - run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def check_hopper(): @@ -375,7 +375,7 @@ def check_hopper(): return compute_capability == (9, 0) -def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -463,7 +463,7 @@ def elementwise_add_with_l2_cache( M, N, block_size=256, - dtype="float32", + dtype=T.float32, ): @T.prim_func def kernel( diff --git a/testing/python/jit/test_tilelang_jit_parcompile.py b/testing/python/jit/test_tilelang_jit_parcompile.py index 0a6e9062c..56201e1cc 100644 --- a/testing/python/jit/test_tilelang_jit_parcompile.py +++ b/testing/python/jit/test_tilelang_jit_parcompile.py @@ -1,6 +1,7 @@ import tilelang.testing import tilelang import torch +from tilelang import language as T @tilelang.jit( @@ -16,9 +17,9 @@ def matmul_kernel_jit( block_K, trans_A=False, trans_B=True, - in_dtype="float16", - out_dtype="float32", - accum_dtype="float32", + in_dtype=T.float16, + out_dtype=T.float32, + accum_dtype=T.float32, num_stages=2, threads=128, ): diff --git a/testing/python/jit/test_tilelang_jit_tvm_ffi.py b/testing/python/jit/test_tilelang_jit_tvm_ffi.py index 5daaf3083..a0df27192 100644 --- a/testing/python/jit/test_tilelang_jit_tvm_ffi.py +++ b/testing/python/jit/test_tilelang_jit_tvm_ffi.py @@ -162,9 +162,9 @@ def test_gemm_jit_kernel(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -207,7 +207,7 @@ def run_tvm_ffi_kernel_do_bench( def test_tvm_ffi_kernel_do_bench(): - run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_tvm_ffi_kernel_multi_stream( @@ -249,7 +249,7 @@ def run_tvm_ffi_kernel_multi_stream( def test_tvm_ffi_kernel_multi_stream(): - run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) def run_tvm_ffi_dynamic_shape( @@ -298,12 +298,12 @@ def run_tvm_ffi_dynamic_shape( def test_tvm_ffi_dynamic_shape(): - run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) - run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2) run_tvm_ffi_dynamic_shape( - T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", 
"float16", 128, 256, 32, 2 + T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2 ) @@ -315,7 +315,7 @@ def check_hopper(): return compute_capability == (9, 0) -def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"): +def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 @@ -403,7 +403,7 @@ def elementwise_add_with_l2_cache( M, N, block_size=256, - dtype="float32", + dtype=T.float32, ): @T.prim_func def kernel( diff --git a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py index e7d7021c5..97d050b73 100644 --- a/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py @@ -39,27 +39,27 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "bfloat16", - "float8_e4m3", - "float8_e5m2", - "int8", + T.float16, + T.bfloat16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 is_float8 = in_dtype in [ - "float8_e4m3", - "float8_e5m2", - "float8_e4m3fn", - "float8_e5m2fnuz", + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, ] - if out_dtype == "int32" or is_float8: + if out_dtype == T.int32 or is_float8: micro_size_k = 32 # This is a debug config @@ -67,7 +67,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 32 warp_col_tiles = 32 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -221,7 +221,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 0) def test_assert_tl_matmul_bfloat16(): - assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32") + assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32) if __name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py index 52763c817..501b38fda 100644 --- a/testing/python/kernel/test_tilelang_kernel_element_wise_add.py +++ b/testing/python/kernel/test_tilelang_kernel_element_wise_add.py @@ -1,5 +1,5 @@ -from tilelang import tvm as tvm import tilelang.testing +from tilelang import language as T import torch @@ -12,8 +12,6 @@ def elementwise_add( out_dtype, threads, ): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((M, N), in_dtype), @@ -67,8 +65,8 @@ def test_elementwise_add_f32(): run_elementwise_add( 512, 1024, - "float32", - "float32", + T.float32, + T.float32, 128, 256, ) @@ -78,8 +76,8 @@ def test_elementwise_add_f16(): run_elementwise_add( 512, 1024, - "float16", - "float16", + T.float16, + T.float16, 128, 256, ) @@ -89,8 +87,8 @@ def test_elementwise_add_i32(): run_elementwise_add( 512, 1024, - "int32", - "int32", + T.int32, + T.int32, 128, 256, ) @@ -100,8 +98,8 @@ def test_elementwise_add_f32f16(): run_elementwise_add( 512, 1024, 
- "float32", - "float16", + T.float32, + T.float16, 128, 256, ) diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py index 63c821202..276083b26 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm.py @@ -54,8 +54,8 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_ @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(9) def test_assert_matmul(): - assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e4m3", "float32", "float32") - assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e5m2", "float32", "float32") + assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e4m3fn, T.float32, T.float32) + assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e5m2, T.float32, T.float32) if __name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py index eec3a9caf..9ba369b6b 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemm_mma.py @@ -39,26 +39,26 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "float8_e4m3", - "float8_e5m2", - "int8", + T.float16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 is_float8 = in_dtype in [ - "float8_e4m3", - "float8_e5m2", - "float8_e4m3fn", - "float8_e5m2fnuz", + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, ] - if out_dtype == "int32" or is_float8: + if out_dtype == T.int32 or is_float8: micro_size_k = 32 # This is a debug config @@ -66,7 +66,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 32 warp_col_tiles = 32 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -221,8 +221,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 9) def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32") - assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) if __name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py index 4a48b656f..7b757992a 100644 --- a/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py @@ -46,7 +46,7 @@ def gemv_simt( C_shape = (M, N) dp4a_size = 4 - use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 @T.prim_func def main( @@ -164,8 +164,8 @@ def evaluate_gemv_simt( @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 9) def test_gemv_simt(): - evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", 
with_bias=False) - evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False) if __name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_gemm.py b/testing/python/kernel/test_tilelang_kernel_gemm.py index 6c01297a1..6dc95e98a 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm.py @@ -1,5 +1,6 @@ from tilelang import tvm as tvm import tilelang.testing +import tilelang.language as T def matmul( @@ -22,8 +23,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -92,7 +91,7 @@ def ref_program(A, B): A = A.T if trans_B: B = B.T - if in_dtype == "float32": + if in_dtype == T.float32: # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas A = (A.view(torch.int32) - 0x1000).view(torch.float32) @@ -111,9 +110,9 @@ def test_gemm_f16f16f16_nn(): 768, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 128, 32, @@ -128,9 +127,9 @@ def test_gemm_f16f16f32_nn(): 768, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, 128, 128, 32, @@ -144,9 +143,9 @@ def test_gemm_bf16bf16f32_nn(): 768, False, False, - "bfloat16", - "bfloat16", - "float32", + T.bfloat16, + T.bfloat16, + T.float32, 128, 128, 32, @@ -160,9 +159,9 @@ def test_gemm_f32f32f32_nn(): 768, False, False, - "float32", - "float32", - "float32", + T.float32, + T.float32, + T.float32, 64, 128, 32, @@ -176,9 +175,9 @@ def test_gemm_f16f16f16_tn(): 768, True, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 128, 32, @@ -193,9 +192,9 @@ def test_gemm_f16f16f16_nt(): 768, False, True, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 128, 32, @@ -204,15 +203,15 @@ def test_gemm_f16f16f16_nt(): def test_gemm_i8i8i32_nt(): - run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64) + run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64) def test_gemm_i8i8i32_tn(): - run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64) + run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64) def test_gemm_f64f64f64_nt(): - run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16) + run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16) def test_gemm_f32f32f32_nt(): @@ -222,9 +221,9 @@ def test_gemm_f32f32f32_nt(): 768, False, True, - "float32", - "float32", - "float32", + T.float32, + T.float32, + T.float32, 64, 128, 32, @@ -238,9 +237,9 @@ def test_gemm_f32f32f32_tn(): 768, True, False, - "float32", - "float32", - "float32", + T.float32, + T.float32, + T.float32, 64, 128, 32, @@ -254,9 +253,9 @@ def test_pad_aligned_f16f16f16_nn(): 768 - 24, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -271,9 +270,9 @@ def test_pad_f16f16f16_nn(): 768 - 5, False, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 256, 32, @@ -288,9 +287,9 @@ def 
test_pad_f16f16f32_nn(): 768 + 15, False, False, - "float16", - "float16", - "float32", + T.float16, + T.float16, + T.float32, 128, 64, 32, @@ -407,9 +406,9 @@ def test_gemm_f16f16f16_sr(): 768, False, True, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 128, 32, @@ -526,9 +525,9 @@ def test_gemm_f16f16f16_rs(): 768, True, False, - "float16", - "float16", - "float16", + T.float16, + T.float16, + T.float16, 128, 128, 32, diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py index 3633d3ece..dd1b75ebc 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_mma_intrinsic.py @@ -39,27 +39,27 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "bfloat16", - "float8_e4m3", - "float8_e5m2", - "int8", + T.float16, + T.bfloat16, + T.float8_e4m3fn, + T.float8_e5m2, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 is_float8 = in_dtype in [ - "float8_e4m3", - "float8_e5m2", - "float8_e4m3fn", - "float8_e5m2fnuz", + T.float8_e4m3fn, + T.float8_e5m2, + T.float8_e4m3fn, + T.float8_e5m2fnuz, ] - if out_dtype == "int32" or is_float8: + if out_dtype == T.int32 or is_float8: micro_size_k = 32 # This is a debug config @@ -67,7 +67,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 32 warp_col_tiles = 32 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -219,22 +219,22 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 0) def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") - assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) + assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 0) def test_assert_tl_matmul_bfloat16(): - assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32") + assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 9) def test_assert_tl_matmul_fp8(): - assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32") - assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") + assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) + assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) if __name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py index e4da44b26..584aa854a 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_simt.py @@ -35,13 +35,13 @@ def tl_matmul_simt( 
accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" # This is a debug config @@ -72,7 +72,7 @@ def tl_matmul_simt( micro_size_k = 128 // DataType(in_dtype).bits dp4a_size = 4 - use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 @T.prim_func def main( @@ -139,7 +139,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # src_code is the generated cuda source assert src_code is not None - if in_dtype == "int8": + if in_dtype == T.int8: A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8) B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8) else: @@ -161,9 +161,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def test_assert_tl_matmul(): - assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") - assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") - assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16) + assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32) + assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32) if __name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py index 2def480db..1f7660032 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm_with_stride.py @@ -4,7 +4,7 @@ import torch -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), diff --git a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py index 5825f695c..b4a5c8249 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemv_simt.py +++ b/testing/python/kernel/test_tilelang_kernel_gemv_simt.py @@ -46,7 +46,7 @@ def gemv_simt( C_shape = (M, N) dp4a_size = 4 - use_dp4a = in_dtype == "int8" and accum_dtype == "int32" + use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32 @T.prim_func def main( @@ -164,15 +164,15 @@ def evaluate_gemv_simt( @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 0) def test_gemv_simt(): - evaluate_gemv_simt(1, 1024, 1024, "float16", "float16", "float16", with_bias=False) - evaluate_gemv_simt(1, 1024, 1024, "int8", "int32", "int32", with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float16, T.float16, T.float16, with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.int8, T.int32, T.int32, with_bias=False) @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version(8, 9) def test_gemv_simt_fp8(): - evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False) - evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False) + evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False) if 
__name__ == "__main__": diff --git a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py index affeb3ddf..9d60e5229 100644 --- a/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py +++ b/testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py @@ -26,20 +26,20 @@ def tl_matmul( accum_dtype, ): assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" K = K // 2 micro_size_x = micro_size_y = micro_size_k = 16 - if accum_dtype == "int32": + if accum_dtype == T.int32: micro_size_k = 32 # This is a debug config @@ -47,7 +47,7 @@ def tl_matmul( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -197,8 +197,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def test_assert_tl_matmul_correctness(): - assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") - assert_tl_matmul_correctness(128, 128, 64, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) + assert_tl_matmul_correctness(128, 128, 64, T.int8, T.int32, T.int32) @simplify_prim_func @@ -212,18 +212,18 @@ def tl_matmul_weight_only_transform( ): K = K // 2 assert in_dtype in [ - "float16", - "int8", + T.float16, + T.int8, ], "Currently only float16 and int8 are supported" assert out_dtype in [ - "float16", - "float32", - "int32", + T.float16, + T.float32, + T.int32, ], "Currently only float16, float32 and int32 are supported" micro_size_x = micro_size_y = micro_size_k = 16 - if out_dtype == "int32": + if out_dtype == T.int32: micro_size_k = 32 transform_b = 3 @@ -233,7 +233,7 @@ def tl_matmul_weight_only_transform( block_col_warps = 2 warp_row_tiles = 64 warp_col_tiles = 64 - chunk = 32 if in_dtype == "float16" else 64 + chunk = 32 if in_dtype == T.float16 else 64 shared_scope = "shared.dyn" # Pipeline Stage @@ -375,8 +375,8 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ladder_permutate_config = bitblas.ops.LadderPermutateConfig( M=N, N=(K // 2), - datatype="int8", - storage_dtype="int8", + datatype=T.int8, + storage_dtype=T.int8, transform_kind=transform_b, transpose_matrix=True, ) @@ -400,9 +400,9 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt @tilelang.testing.requires_package("bitblas") @tilelang.testing.requires_llvm def test_assert_tl_matmul_weight_only_transform(): - assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, T.int8, T.int32, T.int32) if __name__ == "__main__": # tilelang.testing.main() - assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32) diff --git a/testing/python/language/test_tilelang_language_alias.py b/testing/python/language/test_tilelang_language_alias.py index f55d9e85e..48fe1ac4d 100644 --- a/testing/python/language/test_tilelang_language_alias.py +++ b/testing/python/language/test_tilelang_language_alias.py @@ -4,7 +4,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def 
matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -43,7 +43,7 @@ def main( return main -def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) kernel = tilelang.compile(program, out_idx=[2], target="cuda") kernel.run_once() diff --git a/testing/python/language/test_tilelang_language_all_of.py b/testing/python/language/test_tilelang_language_all_of.py index 48412127b..db694d337 100644 --- a/testing/python/language/test_tilelang_language_all_of.py +++ b/testing/python/language/test_tilelang_language_all_of.py @@ -31,8 +31,8 @@ def blocksparse_matmul_global( num_stages, thread_num, enable_rasteration, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @@ -75,8 +75,8 @@ def blocksparse_matmul_shared( num_stages, thread_num, enable_rasteration, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @@ -124,8 +124,8 @@ def blocksparse_matmul_local( num_stages, thread_num, enable_rasteration, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) diff --git a/testing/python/language/test_tilelang_language_alloc.py b/testing/python/language/test_tilelang_language_alloc.py index 6695e9348..883f65c3c 100644 --- a/testing/python/language/test_tilelang_language_alloc.py +++ b/testing/python/language/test_tilelang_language_alloc.py @@ -1,4 +1,5 @@ import tilelang.testing +from tilelang import language as T def alloc_var( @@ -6,8 +7,6 @@ def alloc_var( block_N, dtype, ): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N,), dtype), @@ -38,7 +37,7 @@ def run_alloc_var( def test_alloc_var(): - run_alloc_var(1024, 128, "float16") + run_alloc_var(1024, 128, T.float16) def alloc_var_add( @@ -78,7 +77,7 @@ def run_alloc_var_add( def test_alloc_var_add(): - run_alloc_var_add(1024, 128, "float16") + run_alloc_var_add(1024, 128, T.float16) def alloc_var_with_initializer( @@ -117,7 +116,7 @@ def run_alloc_var_with_initializer( def test_alloc_var_with_initializer(): - run_alloc_var_with_initializer(256, 64, "int32", 5) + run_alloc_var_with_initializer(256, 64, T.int32, 5) def alloc_multi_vars_with_initializer( @@ -156,7 +155,7 @@ def run_alloc_multi_vars_with_initializer( def test_alloc_multi_vars_with_initializer(): - run_alloc_multi_vars_with_initializer(256, 64, "int32") + run_alloc_multi_vars_with_initializer(256, 64, T.int32) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_annotate_safe_value.py b/testing/python/language/test_tilelang_language_annotate_safe_value.py index 442172b6f..3c8239a15 100644 --- a/testing/python/language/test_tilelang_language_annotate_safe_value.py +++ b/testing/python/language/test_tilelang_language_annotate_safe_value.py @@ -6,7 +6,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0): +def tilelang_copy(M, N, 
block_M, block_N, dtype=T.float16, pad_value=0): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -26,7 +26,7 @@ def main( return main -def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0): +def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16, pad_value=0): program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value) kernel = tilelang.compile( program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} diff --git a/testing/python/language/test_tilelang_language_any_of.py b/testing/python/language/test_tilelang_language_any_of.py index 37605e5a0..74db94f7c 100644 --- a/testing/python/language/test_tilelang_language_any_of.py +++ b/testing/python/language/test_tilelang_language_any_of.py @@ -31,8 +31,8 @@ def blocksparse_matmul_global( num_stages, thread_num, enable_rasteration, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @@ -75,8 +75,8 @@ def blocksparse_matmul_shared( num_stages, thread_num, enable_rasteration, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) @@ -124,8 +124,8 @@ def blocksparse_matmul_local( num_stages, thread_num, enable_rasteration, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) diff --git a/testing/python/language/test_tilelang_language_assume.py b/testing/python/language/test_tilelang_language_assume.py index 32e6b1c31..06e92dfa9 100644 --- a/testing/python/language/test_tilelang_language_assume.py +++ b/testing/python/language/test_tilelang_language_assume.py @@ -9,7 +9,7 @@ def kernel_with_assume(): N = T.dynamic("N") @T.prim_func - def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32): + def main(A: T.Tensor((N,), T.float32), l: T.int32, r: T.int32): with T.Kernel(1, threads=32) as _: for i in T.serial(r - l + 1): T.assume(l + i >= 0 and l + i < N) @@ -31,8 +31,8 @@ def kernel_vectorize(M): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), ): with T.Kernel(1, threads=32) as _: tid = T.get_thread_binding() @@ -60,8 +60,8 @@ def kernel_complex(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), ): with T.Kernel(1, threads=32) as _: tid = T.get_thread_binding() diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py index eaf5ae1ed..fa4dff7b3 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic_add.py @@ -3,7 +3,7 @@ @tilelang.jit -def atomic_add_program(K, M, N, block_M, block_N, dtype="float"): +def atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): @@ -17,7 +17,7 @@ def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): return atomic_add -def run_atomic_add(K, M, N, block_M, 
block_N, dtype="float32"): +def run_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) import torch @@ -36,7 +36,7 @@ def ref_program(A, B): @tilelang.jit -def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"): +def tile_atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): @@ -49,7 +49,7 @@ def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): return atomic_add -def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"): +def run_tile_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) print(kernel.get_kernel_source()) import torch @@ -71,7 +71,7 @@ def ref_program(A, B): @tilelang.jit -def atomic_max_program(K, M, N, block_M, block_N, dtype="float"): +def atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): @@ -85,7 +85,7 @@ def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): return atomic_max -def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"): +def run_atomic_max(K, M, N, block_M, block_N, dtype=T.float32): kernel = atomic_max_program(K, M, N, block_M, block_N, dtype=dtype) import torch @@ -104,7 +104,7 @@ def ref_program(A, B): @tilelang.jit -def atomic_min_program(K, M, N, block_M, block_N, dtype="float"): +def atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): @@ -118,7 +118,7 @@ def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): return atomic_min -def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"): +def run_atomic_min(K, M, N, block_M, block_N, dtype=T.float32): kernel = atomic_min_program(K, M, N, block_M, block_N, dtype=dtype) import torch @@ -137,7 +137,7 @@ def ref_program(A, B): @tilelang.jit -def atomic_load_store_program(M, N, block_M, block_N, dtype="float"): +def atomic_load_store_program(M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -151,7 +151,7 @@ def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): return atomic_load_store -def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"): +def run_atomic_load_store(M, N, block_M, block_N, dtype=T.float32): kernel = atomic_load_store_program(M, N, block_M, block_N, dtype=dtype) import torch @@ -162,7 +162,7 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"): @tilelang.jit -def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"): +def atomic_memory_order_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): @@ -176,7 +176,7 @@ def atomic_with_memory_order(A: T.Tensor((K, 
M, N), dtype), B: T.Tensor((M, N), return atomic_with_memory_order -def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"): +def run_atomic_memory_order(K, M, N, block_M, block_N, dtype=T.float32): kernel = atomic_memory_order_program(K, M, N, block_M, block_N, dtype=dtype) import torch @@ -197,7 +197,7 @@ def ref_program(A, B): @tilelang.jit def atomic_addx2_program(M, N, block_M, block_N): @T.prim_func - def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")): + def atomic_addx2(A: T.Tensor((M, N), T.float16), B: T.Tensor((M, N), T.float16)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): for i, j in T.Parallel(block_M, block_N // 2): idx_i = bx * block_M + i @@ -248,7 +248,7 @@ def test_atomic_addx2(): @tilelang.jit -def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"): +def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_different_orders( A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype) @@ -266,7 +266,7 @@ def atomic_different_orders( return atomic_different_orders -def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"): +def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype=T.float32): kernel = atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=dtype) import torch @@ -285,7 +285,7 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"): @tilelang.jit def atomic_addx4_program(M, N, block_M, block_N): @T.prim_func - def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")): + def atomic_addx4(A: T.Tensor((M, N), T.float32), B: T.Tensor((M, N), T.float32)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): for i, j in T.Parallel(block_M, block_N // 4): idx_i = bx * block_M + i @@ -315,7 +315,7 @@ def run_atomic_addx4(M, N, block_M, block_N): @tilelang.jit -def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"): +def atomic_return_prev_program(M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)): with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by): @@ -328,7 +328,7 @@ def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtyp return atomic_with_return_prev -def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"): +def run_atomic_return_prev(M, N, block_M, block_N, dtype=T.float32): kernel = atomic_return_prev_program(M, N, block_M, block_N, dtype=dtype) import torch @@ -344,9 +344,9 @@ def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"): def test_atomic_different_memory_orders(): - run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float") - run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float16") - run_atomic_different_memory_orders(32, 32, 8, 8, dtype="bfloat16") + run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float32) + run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float16) + run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.bfloat16) def test_atomic_addx4(): diff --git a/testing/python/language/test_tilelang_language_ceildiv.py b/testing/python/language/test_tilelang_language_ceildiv.py index 66215abc5..f5af31b83 100644 --- 
a/testing/python/language/test_tilelang_language_ceildiv.py +++ b/testing/python/language/test_tilelang_language_ceildiv.py @@ -6,7 +6,7 @@ @tilelang.jit(out_idx=[-1]) def _ceildiv_kernel(a: int, b: int): @T.prim_func - def ceildiv_kernel(A: T.Tensor((1,), "int32")): + def ceildiv_kernel(A: T.Tensor((1,), T.int32)): with T.Kernel(1, threads=1) as _: A[0] = T.ceildiv(T.int32(a), T.int32(b)) @@ -30,7 +30,7 @@ def test_ceildiv(): @tilelang.jit def _ceildiv_kernel_dyn(b: int): @T.prim_func - def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32): + def ceildiv_kernel(A: T.Tensor((1,), T.int32), a: T.int32): with T.Kernel(1, threads=1) as _: A[0] = T.ceildiv(T.int32(a), T.int32(b)) diff --git a/testing/python/language/test_tilelang_language_chain_equal.py b/testing/python/language/test_tilelang_language_chain_equal.py index 0a9623fa9..083eefdcb 100644 --- a/testing/python/language/test_tilelang_language_chain_equal.py +++ b/testing/python/language/test_tilelang_language_chain_equal.py @@ -10,7 +10,7 @@ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, ) -def chain_equal(N, block_size, dtype="float32"): +def chain_equal(N, block_size, dtype=T.float32): @T.prim_func def main( A: T.Tensor((N,), dtype), @@ -25,7 +25,7 @@ def main( return main -def run_chain_equal(N=128, block_size=64, dtype="float32"): +def run_chain_equal(N=128, block_size=64, dtype=T.float32): kernel = chain_equal(N, block_size, dtype) A = torch.zeros((N,), dtype=torch.float32, device="cuda") B = torch.zeros((N,), dtype=torch.float32, device="cuda") diff --git a/testing/python/language/test_tilelang_language_clamp.py b/testing/python/language/test_tilelang_language_clamp.py index 06e558fda..372d74784 100644 --- a/testing/python/language/test_tilelang_language_clamp.py +++ b/testing/python/language/test_tilelang_language_clamp.py @@ -1,5 +1,5 @@ import tilelang.testing -from tilelang.utils.tensor import map_torch_type +from tilelang import language as T def clamp_within_bounds( @@ -91,7 +91,7 @@ def run_clamp_value_range( import torch # Convert string dtype to torch.dtype - torch_dtype = map_torch_type(dtype) + torch_dtype = dtype.as_torch() def ref_program(A): min_val = torch.min(A) * 0.5 @@ -107,10 +107,10 @@ def ref_program(A): def test_clamp(): # clamp tests for float16 and float32 - run_clamp(1024, 128, "float16", -0.05, 0.05) - run_clamp(1024, 128, "float32", -0.06, 0.05) - run_clamp_value_range(1024, 128, "float16") - run_clamp_value_range(1024, 128, "float32") + run_clamp(1024, 128, T.float16, -0.05, 0.05) + run_clamp(1024, 128, T.float32, -0.06, 0.05) + run_clamp_value_range(1024, 128, T.float16) + run_clamp_value_range(1024, 128, T.float32) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_clear.py b/testing/python/language/test_tilelang_language_clear.py index 19ae0bbd5..af9d89631 100644 --- a/testing/python/language/test_tilelang_language_clear.py +++ b/testing/python/language/test_tilelang_language_clear.py @@ -4,7 +4,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -39,7 +39,7 @@ def main( return main -def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): program = 
matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True}) import torch diff --git a/testing/python/language/test_tilelang_language_composable_index.py b/testing/python/language/test_tilelang_language_composable_index.py index 8a586956b..7893c1f24 100644 --- a/testing/python/language/test_tilelang_language_composable_index.py +++ b/testing/python/language/test_tilelang_language_composable_index.py @@ -6,7 +6,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"): +def tilelang_composable_copy(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -25,7 +25,7 @@ def main( return main -def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_composable_copy(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, @@ -44,7 +44,7 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype def test_tilelang_copy(): run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128) run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576) - run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float") + run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_copy.py b/testing/python/language/test_tilelang_language_copy.py index c8515d5b6..29bb0f951 100644 --- a/testing/python/language/test_tilelang_language_copy.py +++ b/testing/python/language/test_tilelang_language_copy.py @@ -8,7 +8,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_copy(M, N, block_M, block_N, src_dtype="float16", dst_dtype="float16"): +def tilelang_copy(M, N, block_M, block_N, src_dtype=T.float16, dst_dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), src_dtype), @@ -24,7 +24,7 @@ def main( return main -def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy(M, N, block_M, block_N, src_dtype=dtype, dst_dtype=dtype) kernel = tilelang.compile( program, @@ -42,10 +42,10 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16") def test_tilelang_copy(): run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128) run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576) - run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float") + run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32) -def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): +def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.StridedTensor((M, N), (NN, 1), dtype), @@ -59,7 +59,7 @@ def main( return main -def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype=T.float16): if isinstance(NN, int): assert NN > N, "NN must be greater than N" program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype) @@ -84,21 
+84,21 @@ def test_tilelang_copy_with_stride(): run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128) -def tilelang_copy_bufferload(num_tokens, dtype="float16"): +def tilelang_copy_bufferload(num_tokens, dtype=T.float16): @T.prim_func def main( - indices: T.Tensor((num_tokens,), "int32"), + indices: T.Tensor((num_tokens,), T.int32), x: T.Tensor((num_tokens,), dtype), ): with T.Kernel(num_tokens, threads=32) as pid: - idx = T.alloc_local([1], "int32") + idx = T.alloc_local([1], T.int32) T.copy(indices[pid], idx[0]) x[idx[0]] = x[idx[0]] + 1 return main -def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"): +def run_tilelang_copy_bufferload(num_tokens=128, dtype=T.float16): program = tilelang_copy_bufferload(num_tokens, dtype) # test compilation only tilelang.compile( @@ -112,7 +112,7 @@ def test_tilelang_copy_bufferload(): run_tilelang_copy_bufferload(num_tokens=128) -def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"): +def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -126,7 +126,7 @@ def main( return main -def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, @@ -143,7 +143,7 @@ def test_tilelang_copy_buffer_load_with_parallel(): run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128) -def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu"): +def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu): program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype) kernel = tilelang.compile( program, @@ -159,10 +159,10 @@ def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dty @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(10, 0) def test_tilelang_copy_fp8_e8m0(): - run_tilelang_copy_fp8_e8m0(src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu") + run_tilelang_copy_fp8_e8m0(src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu) -def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float4_e2m1fn", dst_dtype="float4_e2m1fn"): +def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn): program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype) kernel = tilelang.compile( program, @@ -179,9 +179,9 @@ def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype="f @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(10, 0) def test_tilelang_copy_fp4(): - run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="float4_e2m1fn") - run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="float16") - run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="bfloat16") + run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn) + run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.float16) + run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.bfloat16) if __name__ == "__main__": diff --git 
a/testing/python/language/test_tilelang_language_cumsum.py b/testing/python/language/test_tilelang_language_cumsum.py index c563bcf2f..fecc0d2a8 100644 --- a/testing/python/language/test_tilelang_language_cumsum.py +++ b/testing/python/language/test_tilelang_language_cumsum.py @@ -2,11 +2,10 @@ import tilelang.testing import tilelang as tl import torch +import tilelang.language as T -def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"): - import tilelang.language as T - +def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): @T.prim_func def cumsum( A: T.Tensor((M, N), dtype), @@ -23,7 +22,7 @@ def cumsum( return cumsum -def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"): +def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): import tilelang.language as T @T.prim_func @@ -44,7 +43,7 @@ def cumsum( return cumsum -def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", scope="smem"): +def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32, scope="smem"): if scope == "smem": program = cumsum_smem_test(M, N, block_M, block_N, dim, reverse, dtype) elif scope == "fragment": @@ -74,7 +73,7 @@ def ref_program(A): torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3) -def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"): +def cumsum_smem_test_1d(N, block_N, reverse=False, dtype=T.float32): import tilelang.language as T @T.prim_func @@ -92,7 +91,7 @@ def cumsum( return cumsum -def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"): +def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype=T.float32): import tilelang.language as T @T.prim_func @@ -112,7 +111,7 @@ def cumsum( return cumsum -def run_cumsum_1d(N, block_N, reverse=False, dtype="float32", scope="smem"): +def run_cumsum_1d(N, block_N, reverse=False, dtype=T.float32, scope="smem"): if scope == "smem": program = cumsum_smem_test_1d(N, block_N, reverse, dtype) elif scope == "fragment": @@ -150,8 +149,8 @@ def test_cumsum_smem(): run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True) # Test different dtypes - run_cumsum(256, 256, 128, 128, dtype="float32") - run_cumsum(256, 256, 128, 128, dtype="float32") + run_cumsum(256, 256, 128, 128, dtype=T.float32) + run_cumsum(256, 256, 128, 128, dtype=T.float32) def test_cumsum_fragment(): @@ -160,8 +159,8 @@ def test_cumsum_fragment(): run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment") # Test different dtypes - run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment") - run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment") + run_cumsum(256, 256, 128, 128, dtype=T.float32, scope="fragment") + run_cumsum(256, 256, 128, 128, dtype=T.float32, scope="fragment") def test_cumsum_smem_1d(): @@ -174,7 +173,7 @@ def test_cumsum_fragment_1d(): run_cumsum_1d(1024, 128, reverse=True, scope="fragment") -def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype="float32"): +def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype=T.float32): """Test cumsum with buffer region (slice) as input.""" import tilelang.language as T @@ -198,7 +197,7 @@ def cumsum_region( return cumsum_region -def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype="float32"): +def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype=T.float32): """Run test for cumsum with region input.""" program = cumsum_region_test_1d(N, chunk_size, reverse, dtype) 
jit_kernel = tl.compile(program, out_idx=-1) @@ -224,7 +223,7 @@ def ref_program(A): torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3) -def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"): +def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): """Test cumsum with buffer region (slice) as input in 2D.""" import tilelang.language as T @@ -253,7 +252,7 @@ def cumsum_region( return cumsum_region -def run_cumsum_region_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"): +def run_cumsum_region_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32): """Run test for cumsum with 2D region input.""" program = cumsum_region_test_2d(M, N, block_M, block_N, dim, reverse, dtype) jit_kernel = tl.compile(program, out_idx=-1) diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index b0191b4d3..67115e8c2 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -303,8 +303,8 @@ def test_serial_step_neg(A: T.Tensor((10,), T.int32)): assert torch.all(res == ref), f"Expected {ref}, but got {res}" assert isinstance(T.serial(1, 10, 1), IRBuilderFrame) - assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame) - assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame) + assert isinstance(T.serial(1, 10, IntImm(T.int32, 1)), IRBuilderFrame) + assert not isinstance(T.serial(1, 10, Var("tmp", T.int32)), IRBuilderFrame) assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame) @@ -433,7 +433,7 @@ def sample_kernel( idx_out: T.Tensor[(32,), T.int32], ): with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841 - fragment = T.alloc_fragment(32, "int32") + fragment = T.alloc_fragment(32, T.int32) T.copy(idx_out, fragment) for i in T.Parallel(32): @@ -458,10 +458,10 @@ def prim_buffer_slice_step(A: T.Buffer((10,), T.int32), B: T.Buffer((5,), T.int3 def test_boolop(): - a = Var("a", "int32") - b = Var("b", "int32") - c = Var("c", "int32") - d = Var("d", "int32") + a = Var("a", T.int32) + b = Var("b", T.int32) + c = Var("c", T.int32) + d = Var("d", T.int32) @T.macro def cond(): diff --git a/testing/python/language/test_tilelang_language_get_warp_info.py b/testing/python/language/test_tilelang_language_get_warp_info.py index edbc511d0..e14cece98 100644 --- a/testing/python/language/test_tilelang_language_get_warp_info.py +++ b/testing/python/language/test_tilelang_language_get_warp_info.py @@ -24,7 +24,7 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int: @tilelang.jit(out_idx=[-1]) def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None): @T.prim_func - def laneid_kernel(A: T.Tensor((num_threads,), "int32")): + def laneid_kernel(A: T.Tensor((num_threads,), T.int32)): with T.Kernel(1, threads=num_threads) as _: tx = T.get_thread_binding() A[tx] = T.get_lane_idx(warp_size) @@ -35,7 +35,7 @@ def laneid_kernel(A: T.Tensor((num_threads,), "int32")): @tilelang.jit(out_idx=[-1]) def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None): @T.prim_func - def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")): + def warp_idx_sync_kernel(A: T.Tensor((num_threads,), T.int32)): with T.Kernel(1, threads=num_threads) as _: tx = T.get_thread_binding() A[tx] = T.get_warp_idx_sync(warp_size) @@ -46,7 +46,7 @@ def 
warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")): @tilelang.jit(out_idx=[-1]) def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None): @T.prim_func - def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")): + def warp_idx_kernel(A: T.Tensor((num_threads,), T.int32)): with T.Kernel(1, threads=num_threads) as _: tx = T.get_thread_binding() A[tx] = T.get_warp_idx(warp_size) @@ -61,7 +61,7 @@ def _get_warp_group_idx_kernel( warps_per_group: Optional[int] = None, ): @T.prim_func - def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")): + def warp_group_idx_kernel(A: T.Tensor((num_threads,), T.int32)): with T.Kernel(1, threads=num_threads) as _: tx = T.get_thread_binding() A[tx] = T.get_warp_group_idx(warp_size, warps_per_group) @@ -72,7 +72,7 @@ def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")): @tilelang.jit(out_idx=[-1]) def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64): @T.prim_func - def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")): + def shuffle_elect_kernel(A: T.Tensor((num_threads,), T.int32)): with T.Kernel(1, threads=num_threads) as _: tx = T.get_thread_binding() elected = T.shuffle_elect(thread_extent) diff --git a/testing/python/language/test_tilelang_language_if_range.py b/testing/python/language/test_tilelang_language_if_range.py index 9c9845690..c81a241ba 100644 --- a/testing/python/language/test_tilelang_language_if_range.py +++ b/testing/python/language/test_tilelang_language_if_range.py @@ -7,7 +7,7 @@ @tilelang.jit( out_idx=[1], ) -def tilelang_if_range(M, N, block_M, block_N, dtype="float16"): +def tilelang_if_range(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -27,7 +27,7 @@ def main( return main -def run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32, dtype="float16"): +def run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32, dtype=T.float16): kernel = tilelang_if_range(M, N, block_M, block_N, dtype) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) diff --git a/testing/python/language/test_tilelang_language_infinity.py b/testing/python/language/test_tilelang_language_infinity.py index 5d2518661..746afc4e0 100644 --- a/testing/python/language/test_tilelang_language_infinity.py +++ b/testing/python/language/test_tilelang_language_infinity.py @@ -22,10 +22,10 @@ def _test_infinity(dtype: str): @tilelang.testing.requires_cuda def test_infinity(): - _test_infinity("float16") - _test_infinity("bfloat16") - _test_infinity("float32") - _test_infinity("float64") + _test_infinity(T.float16) + _test_infinity(T.bfloat16) + _test_infinity(T.float32) + _test_infinity(T.float64) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_int64.py b/testing/python/language/test_tilelang_language_int64.py index 28fa2211f..d81e9dc6f 100644 --- a/testing/python/language/test_tilelang_language_int64.py +++ b/testing/python/language/test_tilelang_language_int64.py @@ -3,7 +3,7 @@ @tilelang.jit -def fill_symbolic(value: float, dtype="bfloat16"): +def fill_symbolic(value: float, dtype=T.bfloat16): n = T.symbolic("n", "int64") block_n = 512 @@ -33,7 +33,7 @@ def test_fill_symbolic(): @tilelang.jit -def fill_static(n: int, value: float, dtype="bfloat16"): +def fill_static(n: int, value: float, dtype=T.bfloat16): block_n = 512 @T.prim_func diff --git a/testing/python/language/test_tilelang_language_intrinsics_codegen.py 
b/testing/python/language/test_tilelang_language_intrinsics_codegen.py index 80318242c..b1d1e5401 100644 --- a/testing/python/language/test_tilelang_language_intrinsics_codegen.py +++ b/testing/python/language/test_tilelang_language_intrinsics_codegen.py @@ -9,8 +9,8 @@ def test_language_ldg_codegen(): @T.prim_func def main( - x: T.Tensor((N,), "float32"), - y: T.Tensor((N,), "float32"), + x: T.Tensor((N,), T.float32), + y: T.Tensor((N,), T.float32), ): with T.Kernel(N, threads=32) as pid: # Explicitly request read-only cache load for x[pid] diff --git a/testing/python/language/test_tilelang_language_lazy_jit.py b/testing/python/language/test_tilelang_language_lazy_jit.py index 31da09c54..505730965 100644 --- a/testing/python/language/test_tilelang_language_lazy_jit.py +++ b/testing/python/language/test_tilelang_language_lazy_jit.py @@ -60,8 +60,8 @@ def gemm( ) for in_dtype, out_dtype in prod: - in_dtype = in_dtype.torch() - out_dtype = out_dtype.torch() + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") C_ref = out_dtype(A @ B) @@ -97,8 +97,8 @@ def gemm_ptr( ] ) for in_dtype, out_dtype in prod: - in_dtype = in_dtype.torch() - out_dtype = out_dtype.torch() + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda") C_ref = out_dtype(A @ B) @@ -326,8 +326,8 @@ def copy6( def test_jit2_deepseek_deepgemm(): @tilelang.lazy_jit def deep_gemm( - A: T.Tensor[[int, int], T.float8_e4m3], - B: T.Tensor[[int, int], T.float8_e4m3], + A: T.Tensor[[int, int], T.float8_e4m3fn], + B: T.Tensor[[int, int], T.float8_e4m3fn], scales_a: T.Tensor[[int, int], T.float32], scales_b: T.Tensor[[int, int], T.float32], out_dtype: T.dtype = T.bfloat16, diff --git a/testing/python/language/test_tilelang_language_let.py b/testing/python/language/test_tilelang_language_let.py index a2905952b..6f94ad664 100644 --- a/testing/python/language/test_tilelang_language_let.py +++ b/testing/python/language/test_tilelang_language_let.py @@ -6,7 +6,7 @@ def test_let_vectorize_load(): @T.prim_func def main(A_ptr: T.handle): - A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) + A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): diff --git a/testing/python/language/test_tilelang_language_mask_op.py b/testing/python/language/test_tilelang_language_mask_op.py index 37b520451..8f8997291 100644 --- a/testing/python/language/test_tilelang_language_mask_op.py +++ b/testing/python/language/test_tilelang_language_mask_op.py @@ -5,7 +5,7 @@ # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): +def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -26,7 +26,7 @@ def main( return main -def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, 
"tl.disable_tma_lower": True} @@ -42,7 +42,7 @@ def test_tilelang_copy_mask_parallel(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): +def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -62,7 +62,7 @@ def main( return main -def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} @@ -78,7 +78,7 @@ def test_tilelang_copy_mask_copy(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): +def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -99,7 +99,7 @@ def main( return main -def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} @@ -115,7 +115,7 @@ def test_tilelang_copy_mask_parallel_range(): # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit -def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): +def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -135,7 +135,7 @@ def main( return main -def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): +def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16): program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) kernel = tilelang.compile( program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} diff --git a/testing/python/language/test_tilelang_language_negative_index.py b/testing/python/language/test_tilelang_language_negative_index.py index c052ccb92..feeed2c6f 100644 --- a/testing/python/language/test_tilelang_language_negative_index.py +++ b/testing/python/language/test_tilelang_language_negative_index.py @@ -1,37 +1,37 @@ from tilelang import tvm import tilelang as tl import tilelang.testing -from tvm.script import tir as T +import tilelang.language as T @T.prim_func -def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): +def negative_index_before(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)): T.func_attr({"tir.noalias": True}) B[0] = A[T.int32(-1)] @T.prim_func -def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): +def negative_index_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)): T.func_attr({"tir.noalias": True}) B[0] = A[T.int32(15)] @T.prim_func -def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): +def negative_index_loop_before(A: 
T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)): T.func_attr({"tir.noalias": True}) for i in T.serial(4): B[i] = A[-i - 1] @T.prim_func -def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): +def negative_index_loop_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)): T.func_attr({"tir.noalias": True}) for i in T.serial(4): B[i] = A[15 - i] @T.prim_func -def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): +def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)): T.func_attr({"tir.noalias": True}) for i in T.serial(16): B[i] = A[shift + i] diff --git a/testing/python/language/test_tilelang_language_parallel.py b/testing/python/language/test_tilelang_language_parallel.py index b0e85ff47..a392e70b6 100644 --- a/testing/python/language/test_tilelang_language_parallel.py +++ b/testing/python/language/test_tilelang_language_parallel.py @@ -8,7 +8,7 @@ @tilelang.jit(out_idx=[1]) -def parallel_elementwise_static(length=256, dtype="float32"): +def parallel_elementwise_static(length=256, dtype=T.float32): @T.prim_func def main( A: T.Tensor((length,), dtype), @@ -22,7 +22,7 @@ def main( @tilelang.jit(out_idx=[1]) -def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"): +def parallel_elementwise_dynamic(max_len=512, threads=256, dtype=T.float32): @T.prim_func def main( A: T.Tensor((max_len,), dtype), diff --git a/testing/python/language/test_tilelang_language_pipeline.py b/testing/python/language/test_tilelang_language_pipeline.py index 54e10550b..8136e246f 100644 --- a/testing/python/language/test_tilelang_language_pipeline.py +++ b/testing/python/language/test_tilelang_language_pipeline.py @@ -1,5 +1,6 @@ from tilelang import tvm as tvm import tilelang.testing +import tilelang.language as T def matmul( @@ -23,8 +24,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -63,9 +62,9 @@ def run_gemm( block_K = 32 trans_A = False trans_B = False - in_dtype = "float16" - out_dtype = "float16" - dtypeAccum = "float32" + in_dtype = T.float16 + out_dtype = T.float16 + dtypeAccum = T.float32 num_threads = 128 program = matmul( M, @@ -101,7 +100,7 @@ def ref_program(A, B): A = A.T if trans_B: B = B.T - if in_dtype == "float32": + if in_dtype == T.float32: # Convert float32 to tfloat32 because tfloat32 mma cannot truncate # float32 automatically, -0x1000 meas A = (A.view(torch.int32) - 0x1000).view(torch.float32) @@ -127,7 +126,7 @@ def test_pipeline_order_stage(): tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }, ) -def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"): +def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype=T.float16, accum_dtype=T.float32): block_mask_shape = (M // block_M, N // block_N, K // block_K) import tilelang.language as T diff --git a/testing/python/language/test_tilelang_language_ptr.py b/testing/python/language/test_tilelang_language_ptr.py index 0e60ddd72..85458139a 100644 --- a/testing/python/language/test_tilelang_language_ptr.py +++ b/testing/python/language/test_tilelang_language_ptr.py @@ -6,7 +6,7 @@ from tilelang.utils import map_torch_type -def matmul_test(M, N, K, block_M, block_N, 
block_K, dtype="float16", accum_dtype="float"): +def matmul_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( a_ptr: T.ptr, @@ -39,7 +39,7 @@ def main( return main -def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) jit_kernel = tl.compile(program, target="cuda", execution_backend="cython") diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py index 7ec500391..1d9bf6130 100644 --- a/testing/python/language/test_tilelang_language_reduce.py +++ b/testing/python/language/test_tilelang_language_reduce.py @@ -1,13 +1,12 @@ from tilelang import tvm as tvm import tilelang.testing import tilelang as tl +import tilelang.language as T tilelang.testing.set_random_seed() def _make_shared_reduce(M, N, dtype, reduce_cb): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -30,7 +29,7 @@ def _run_program(program, ref_program, atol=1e-2, rtol=1e-2): profiler.assert_allclose(ref_program, atol=atol, rtol=rtol) -def reduce_max_test(M, N, dtype="float16"): +def reduce_max_test(M, N, dtype=T.float16): import tilelang.language as T @T.prim_func @@ -49,7 +48,7 @@ def main( return main -def reduce_sum_test(M, N, dtype="float32"): +def reduce_sum_test(M, N, dtype=T.float32): import tilelang.language as T @T.prim_func @@ -68,27 +67,27 @@ def main( return main -def reduce_sum_ss(M, N, dtype="float32"): +def reduce_sum_ss(M, N, dtype=T.float32): return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1)) -def reduce_max_ss(M, N, dtype="float32"): +def reduce_max_ss(M, N, dtype=T.float32): return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1)) -def reduce_min_ss(M, N, dtype="float32"): +def reduce_min_ss(M, N, dtype=T.float32): return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1)) -def reduce_abssum_ss(M, N, dtype="float32"): +def reduce_abssum_ss(M, N, dtype=T.float32): return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1)) -def reduce_absmax_ss(M, N, dtype="float32"): +def reduce_absmax_ss(M, N, dtype=T.float32): return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1)) -def run_reduce_sum(M, N, dtype="float32", mode="rr"): +def run_reduce_sum(M, N, dtype=T.float32, mode="rr"): if mode == "rr": program = reduce_sum_test(M, N, dtype) elif mode == "ss": @@ -98,12 +97,12 @@ def run_reduce_sum(M, N, dtype="float32", mode="rr"): _run_program(program, lambda A: A.sum(dim=1)) -def run_shared_reduce(program_builder, ref_program, M, N, dtype="float32"): +def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32): program = program_builder(M, N, dtype) _run_program(program, ref_program) -def run_reduce_max(M, N, dtype="float16"): +def run_reduce_max(M, N, dtype=T.float16): program = reduce_max_test(M, N, dtype) _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2) @@ -119,28 +118,28 @@ def test_reduce_sum_shared(): def test_reduce_max(): - run_reduce_max(256, 256, "float16") - run_reduce_max(512, 128, "float16") - run_reduce_max(256, 256, "float32") + run_reduce_max(256, 256, T.float16) + run_reduce_max(512, 128, T.float16) + run_reduce_max(256, 
256, T.float32) def test_reduce_max_shared(): - run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32") + run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, T.float32) def test_reduce_min_shared(): - run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, "float32") + run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, T.float32) def test_reduce_abssum_shared(): - run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, "float32") + run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, T.float32) def test_reduce_absmax_shared(): - run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, "float32") + run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, T.float32) -def reduce_sum_test_clear(M, N, dtype="float32"): +def reduce_sum_test_clear(M, N, dtype=T.float32): import tilelang.language as T @T.prim_func @@ -160,7 +159,7 @@ def main( return main -def run_reduce_sum_clear(M, N, dtype="float32"): +def run_reduce_sum_clear(M, N, dtype=T.float32): program = reduce_sum_test_clear(M, N, dtype) jit_kernel = tl.compile(program, out_idx=-1) @@ -176,12 +175,12 @@ def ref_program(A): def test_reduce_sum_clear(): - run_reduce_sum_clear(256, 256, "float32") - run_reduce_sum_clear(512, 128, "float32") - run_reduce_sum_clear(128, 512, "float32") + run_reduce_sum_clear(256, 256, T.float32) + run_reduce_sum_clear(512, 128, T.float32) + run_reduce_sum_clear(128, 512, T.float32) -def reduce_max_test_clear(M, N, dtype="float16"): +def reduce_max_test_clear(M, N, dtype=T.float16): import tilelang.language as T @T.prim_func @@ -201,7 +200,7 @@ def main( return main -def run_reduce_max_clear(M, N, dtype="float16"): +def run_reduce_max_clear(M, N, dtype=T.float16): program = reduce_max_test_clear(M, N, dtype) jit_kernel = tl.compile(program, out_idx=-1) @@ -217,7 +216,7 @@ def ref_program(A): def test_reduce_max_clear(): - run_reduce_max_clear(256, 256, "float16") + run_reduce_max_clear(256, 256, T.float16) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_reshape.py b/testing/python/language/test_tilelang_language_reshape.py index 3c343309a..10c3d0ce8 100644 --- a/testing/python/language/test_tilelang_language_reshape.py +++ b/testing/python/language/test_tilelang_language_reshape.py @@ -1,13 +1,11 @@ -from tilelang import tvm as tvm import tilelang.testing import tilelang as tl +from tilelang import language as T import torch import pytest def reshape_test(N, M, dtype): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N,), dtype), @@ -42,13 +40,11 @@ def ref_program(A): def test_reshape_smem(): # Test reshape - run_reshape(1024, 32, "float32") - run_reshape(2048, 64, "float16") + run_reshape(1024, 32, T.float32) + run_reshape(2048, 64, T.float16) def reshape_test_smem_1d_2_2d(N, M, dtype): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N,), dtype), @@ -86,13 +82,11 @@ def ref_program(A): def test_reshape_smem_1d_2_2d(): - run_reshape_smem_1d_2_2d(1024, 32, "float32") - run_reshape_smem_1d_2_2d(2048, 64, "float16") + run_reshape_smem_1d_2_2d(1024, 32, T.float32) + run_reshape_smem_1d_2_2d(2048, 64, T.float16) def reshape_test_smem_2d_2_1d(N, M, dtype): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N // M, M), dtype), @@ -130,13 +124,11 @@ def ref_program(A): def test_reshape_smem_2d_2_1d(): - run_reshape_smem_2d_2_1d(1024, 32, "float32") - 
run_reshape_smem_2d_2_1d(2048, 64, "float16") + run_reshape_smem_2d_2_1d(1024, 32, T.float32) + run_reshape_smem_2d_2_1d(2048, 64, T.float16) def reshape_fragment_test(N, M, dtype): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N // M, M), dtype), @@ -175,12 +167,11 @@ def ref_program(A): def test_reshape_fragment(): - run_reshape_fragment(1024, 32, "float32") - run_reshape_fragment(2048, 64, "float16") + run_reshape_fragment(1024, 32, T.float32) + run_reshape_fragment(2048, 64, T.float16) def reshape_layout_transform_shared(N, M, dtype): - import tilelang.language as T from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout @T.prim_func @@ -222,13 +213,11 @@ def ref_program(A): def test_reshape_layout_transform_shared(): - run_reshape_layout_transform_shared(1024, 32, "float32") - run_reshape_layout_transform_shared(2048, 64, "float16") + run_reshape_layout_transform_shared(1024, 32, T.float32) + run_reshape_layout_transform_shared(2048, 64, T.float16) def reduce_after_reshape_test(N, M, dtype): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N,), dtype), @@ -267,13 +256,11 @@ def ref_program(A): def test_reduce_after_reshape(): - run_reduce_after_reshape(1024, 32, "float32") - run_reduce_after_reshape(2048, 64, "float16") + run_reduce_after_reshape(1024, 32, T.float32) + run_reduce_after_reshape(2048, 64, T.float16) def reshape_shape_mismatch_test(N, M, dtype): - import tilelang.language as T - @T.prim_func def main( A: T.Tensor((N,), dtype), @@ -288,7 +275,7 @@ def main( def test_reshape_shape_mismatch(): with pytest.raises(AssertionError): - reshape_shape_mismatch_test(1024, 32, "float32") + reshape_shape_mismatch_test(1024, 32, T.float32) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_ternary.py b/testing/python/language/test_tilelang_language_ternary.py index 632dcf7b4..20c7b5e77 100644 --- a/testing/python/language/test_tilelang_language_ternary.py +++ b/testing/python/language/test_tilelang_language_ternary.py @@ -7,7 +7,7 @@ @tilelang.jit( out_idx=[1], ) -def tilelang_ternary(M, N, block_M, block_N, dtype="float16"): +def tilelang_ternary(M, N, block_M, block_N, dtype=T.float16): @T.prim_func def main( A: T.Tensor((M, N), dtype), @@ -21,7 +21,7 @@ def main( return main -def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype="float16"): +def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype=T.float16): kernel = tilelang_ternary(M, N, block_M, block_N, dtype) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) b = kernel(a) diff --git a/testing/python/language/test_tilelang_language_tma_1d.py b/testing/python/language/test_tilelang_language_tma_1d.py index 90022b5ec..9cb79c10c 100644 --- a/testing/python/language/test_tilelang_language_tma_1d.py +++ b/testing/python/language/test_tilelang_language_tma_1d.py @@ -34,7 +34,7 @@ def run_elementwise_add(M, N): # Default config block_M, block_N = 128, 128 config = {"block_M": block_M, "block_N": block_N, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) out = kernel(a, b) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) diff --git a/testing/python/language/test_tilelang_language_unroll.py b/testing/python/language/test_tilelang_language_unroll.py index 416840a13..06367e975 100644 --- a/testing/python/language/test_tilelang_language_unroll.py +++ 
b/testing/python/language/test_tilelang_language_unroll.py @@ -6,7 +6,7 @@ def test_unroll_with_step(): @T.prim_func def main(A_ptr: T.handle): - A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) + A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): @@ -20,7 +20,7 @@ def main(A_ptr: T.handle): def test_unroll_with_unroll_factor(): @T.prim_func def main(A_ptr: T.handle): - A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) + A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16) for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): diff --git a/testing/python/language/test_tilelang_language_var_init.py b/testing/python/language/test_tilelang_language_var_init.py index d4f9062b8..36d9bf014 100644 --- a/testing/python/language/test_tilelang_language_var_init.py +++ b/testing/python/language/test_tilelang_language_var_init.py @@ -7,12 +7,12 @@ def test_var_assign() -> None: @tilelang.jit(out_idx=-1) def jit_kernel(): @T.prim_func - def test_var_assign(A: T.Tensor((2,), "int32")): + def test_var_assign(A: T.Tensor((2,), T.int32)): with T.Kernel(1) as _: - a = T.alloc_var("int32", init=1) - b = T.alloc_var("int32", init=a) # b gets value of a + a = T.alloc_var(T.int32, init=1) + b = T.alloc_var(T.int32, init=a) # b gets value of a a = 2 - d = T.alloc_var("int32", init=a) # c gets new value of a + d = T.alloc_var(T.int32, init=a) # c gets new value of a A[0] = b A[1] = d diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py index 6867079c3..75360bb19 100644 --- a/testing/python/language/test_tilelang_language_vectorize.py +++ b/testing/python/language/test_tilelang_language_vectorize.py @@ -7,8 +7,8 @@ def vectorize_test(N, M, stride_A, stride_B): @T.prim_func def main( - A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821 - B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821 + A: T.StridedTensor[(N, M), (1, stride_A), T.float32], # noqa: F821 + B: T.StridedTensor[(N, M), (1, stride_B), T.float32], # noqa: F821 ): with T.Kernel(M // 128, threads=128) as (bx): tx = T.get_thread_binding(0) @@ -60,9 +60,9 @@ def test_vectorize(): def vectorize_test_invariant_index(N, M, K): @T.prim_func def main( - A: T.Tensor[(N, M), "float32"], # noqa: F821 - B: T.Tensor[(N, M), "float32"], # noqa: F821 - C: T.Tensor[(N, M // K), "float32"], # noqa: F821 + A: T.Tensor[(N, M), T.float32], # noqa: F821 + B: T.Tensor[(N, M), T.float32], # noqa: F821 + C: T.Tensor[(N, M // K), T.float32], # noqa: F821 ): with T.Kernel(N // 128, threads=128) as (bx): tx = T.get_thread_binding(0) diff --git a/testing/python/language/test_tilelang_language_vectorized_cast.py b/testing/python/language/test_tilelang_language_vectorized_cast.py index a9ab86985..1a0a0942a 100644 --- a/testing/python/language/test_tilelang_language_vectorized_cast.py +++ b/testing/python/language/test_tilelang_language_vectorized_cast.py @@ -4,11 +4,11 @@ import tilelang.language as T str2dtype = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "float8_e4m3": torch.float8_e4m3fn, - "float8_e5m2": torch.float8_e5m2, + T.float32: torch.float32, + T.float16: torch.float16, + T.bfloat16: torch.bfloat16, + T.float8_e4m3fn: torch.float8_e4m3fn, + T.float8_e5m2: torch.float8_e5m2, } @@ 
-81,22 +81,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, @pytest.mark.parametrize( "src_dtype, dst_dtype, check_str, lanes", [ - ("float32", "float16", "__float22half2_rn", 2), - ("float32", "float16", "__float22half2_rn", 4), - ("float16", "float32", "__half22float2", 2), - ("float16", "float32", "__half22float2", 4), - ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2), - ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4), - ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2), - ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4), - ("float32", "bfloat16", "__float22bfloat162_rn", 2), - ("float32", "bfloat16", "__float22bfloat162_rn", 4), - ("bfloat16", "float32", "__bfloat1622float2", 2), - ("bfloat16", "float32", "__bfloat1622float2", 4), - ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2), - ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4), - ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2), - ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4), + (T.float32, T.float16, "__float22half2_rn", 2), + (T.float32, T.float16, "__float22half2_rn", 4), + (T.float16, T.float32, "__half22float2", 2), + (T.float16, T.float32, "__half22float2", 4), + (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 2), + (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 4), + (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 2), + (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 4), + (T.float32, T.bfloat16, "__float22bfloat162_rn", 2), + (T.float32, T.bfloat16, "__float22bfloat162_rn", 4), + (T.bfloat16, T.float32, "__bfloat1622float2", 2), + (T.bfloat16, T.float32, "__bfloat1622float2", 4), + (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2), + (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4), + (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2), + (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4), ], ) def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): diff --git a/testing/python/language/test_tilelang_language_view.py b/testing/python/language/test_tilelang_language_view.py index ff050e312..dc4c3711b 100644 --- a/testing/python/language/test_tilelang_language_view.py +++ b/testing/python/language/test_tilelang_language_view.py @@ -1,3 +1,4 @@ +import tilelang.language as T from tilelang import tvm as tvm import tilelang.testing import tilelang as tl @@ -5,8 +6,6 @@ def view_test(N, M, dtype, new_dtype=None): - import tilelang.language as T - new_shape = [N // M, M] if new_dtype: from tvm import DataType @@ -37,9 +36,7 @@ def run_view(N, M, dtype, new_dtype=None): def ref_program(A): if new_dtype: - from tilelang.utils.tensor import map_torch_type - - torch_dtype = map_torch_type(new_dtype) + torch_dtype = T.dtype(new_dtype).as_torch() return A.view(N // M, M).view(dtype=torch_dtype) return A.view(N // M, M) @@ -48,17 +45,15 @@ def ref_program(A): def test_reshape_view(): # Test view with same dtype - run_view(1024, 32, "float32") - run_view(2048, 64, "float16") + run_view(1024, 32, T.float32) + run_view(2048, 64, T.float16) # Test view with dtype conversion - run_view(1024, 32, "float32", "float16") - run_view(2048, 64, "float16", "float32") + run_view(1024, 32, T.float32, T.float16) + run_view(2048, 64, T.float16, T.float32) def view_shape_mismatch_test(N, M, dtype, new_dtype=None): - import tilelang.language as T - new_shape = [N // M, M + 1] if new_dtype: from tvm import DataType @@ -84,7 +79,7 @@ def main( def 
test_view_shape_mismatch(): with pytest.raises(AssertionError): - view_shape_mismatch_test(1024, 32, "float32") + view_shape_mismatch_test(1024, 32, T.float32) if __name__ == "__main__": diff --git a/testing/python/language/test_tilelang_language_warp_reduce.py b/testing/python/language/test_tilelang_language_warp_reduce.py index 0a0fb70bb..a8868013d 100644 --- a/testing/python/language/test_tilelang_language_warp_reduce.py +++ b/testing/python/language/test_tilelang_language_warp_reduce.py @@ -33,7 +33,7 @@ def main(x: T.Tensor((32), dtype)): def test_warp_reduce_sum(): a = torch.randn((32,), dtype=torch.float32, device="cuda") - kernel = get_kernel("sum", "float32") + kernel = get_kernel("sum", T.float32) ref = torch.full_like(a, a.sum()) kernel(a) torch.testing.assert_close(a, ref) @@ -41,7 +41,7 @@ def test_warp_reduce_sum(): def test_warp_reduce_max(): a = torch.randn((32,), dtype=torch.float32, device="cuda") - kernel = get_kernel("max", "float32") + kernel = get_kernel("max", T.float32) print(kernel.get_kernel_source()) ref = torch.full_like(a, a.max()) kernel(a) @@ -50,7 +50,7 @@ def test_warp_reduce_max(): def test_warp_reduce_min(): a = torch.randn((32,), dtype=torch.float32, device="cuda") - kernel = get_kernel("min", "float32") + kernel = get_kernel("min", T.float32) ref = torch.full_like(a, a.min()) kernel(a) torch.testing.assert_close(a, ref) @@ -58,7 +58,7 @@ def test_warp_reduce_min(): def test_warp_reduce_bitand(): a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") - kernel = get_kernel("bitand", "int32") + kernel = get_kernel("bitand", T.int32) ref_val = a[0] for i in range(1, a.shape[0]): ref_val = ref_val & a[i] @@ -69,7 +69,7 @@ def test_warp_reduce_bitand(): def test_warp_reduce_bitor(): a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") - kernel = get_kernel("bitor", "int32") + kernel = get_kernel("bitor", T.int32) ref_val = a[0] for i in range(1, a.shape[0]): ref_val = ref_val | a[i] diff --git a/testing/python/layout/test_tilelang_layout_fused_replicate.py b/testing/python/layout/test_tilelang_layout_fused_replicate.py index 6d3c26820..8aa5f6c42 100644 --- a/testing/python/layout/test_tilelang_layout_fused_replicate.py +++ b/testing/python/layout/test_tilelang_layout_fused_replicate.py @@ -14,8 +14,8 @@ def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int): @T.prim_func def main( - a: T.Buffer((B, M, N), "bfloat16"), - a_out: T.Buffer((B, M, N), "float32"), + a: T.Buffer((B, M, N), T.bfloat16), + a_out: T.Buffer((B, M, N), T.float32), ): with T.Kernel( T.ceildiv(M, BLOCK_MN), @@ -23,7 +23,7 @@ def main( B, threads=128, ) as (pid_m, pid_n, pid_b): - a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32") + a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), T.float32) offs_m = pid_m * BLOCK_MN offs_n = pid_n * BLOCK_K diff --git a/testing/python/math/test_math_bitwise_reduce.py b/testing/python/math/test_math_bitwise_reduce.py index 8d7f5a1ac..044e0ea37 100644 --- a/testing/python/math/test_math_bitwise_reduce.py +++ b/testing/python/math/test_math_bitwise_reduce.py @@ -21,15 +21,15 @@ def bitwise_reduce( ): @T.prim_func def reduce_func( - A: T.Tensor((M, N), "int32"), - B: T.Tensor((M), "int32"), - Output: T.Tensor((M), "int32"), + A: T.Tensor((M, N), T.int32), + B: T.Tensor((M), T.int32), + Output: T.Tensor((M), T.int32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, 
block_N), "int32") - A_fragment = T.alloc_fragment((block_M, block_N), "int32") - B_shared = T.alloc_shared((block_M,), "int32") - B_fragment = T.alloc_fragment((block_M), "int32") + A_shared = T.alloc_shared((block_M, block_N), T.int32) + A_fragment = T.alloc_fragment((block_M, block_N), T.int32) + B_shared = T.alloc_shared((block_M,), T.int32) + B_fragment = T.alloc_fragment((block_M), T.int32) T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(A_shared, A_fragment) T.copy(B[by * block_M], B_shared) diff --git a/testing/python/math/test_math_fast_math.py b/testing/python/math/test_math_fast_math.py index 7809983e8..3c50e95f4 100644 --- a/testing/python/math/test_math_fast_math.py +++ b/testing/python/math/test_math_fast_math.py @@ -49,7 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): check_fastmath_usage(source, mathop_name, expect_fastmath=False) -def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test single-argument mathops. T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) @@ -85,7 +85,7 @@ def main( print(f"✓ {mathop_name} compilation and execution test passed") -def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test two-argument mathops to ensure they generate non-fastmath CUDA code. """ @@ -133,7 +133,7 @@ def main( check_non_fastmath_usage(source_fastmath, mathop_name) # Test numerical correctness - torch_dtype = getattr(torch, dtype) + torch_dtype = dtype.as_torch() a = torch.randn(M, N, device="cuda", dtype=torch_dtype) b = torch.randn(M, N, device="cuda", dtype=torch_dtype) @@ -159,8 +159,8 @@ def run_abs_test(): @T.prim_func def main( - A: T.Tensor((M, N), "float32"), - B: T.Tensor((M, N), "float32"), + A: T.Tensor((M, N), T.float32), + B: T.Tensor((M, N), T.float32), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): @@ -188,7 +188,7 @@ def main( print("✓ abs numerical test passed") -def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). 
""" @@ -221,7 +221,7 @@ def main( check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) # Test numerical correctness - torch_dtype = getattr(torch, dtype) + torch_dtype = dtype.as_torch() a = torch.randn(M, N, device="cuda", dtype=torch_dtype) # Ensure positive values for functions that need them @@ -273,7 +273,7 @@ def test_mathops_generate_no_fastmath(): ] for name, func in single_arg_mathops: - run_single_arg_mathop_test(name, func, dtype="float32") + run_single_arg_mathop_test(name, func, dtype=T.float32) print(f"✓ {name} test passed") @@ -287,7 +287,7 @@ def test_two_arg_mathops_fastmath(): ] for name, func in two_arg_mathops: - run_two_arg_mathop_test(name, func, dtype="float32") + run_two_arg_mathop_test(name, func, dtype=T.float32) @tilelang.testing.requires_cuda @@ -312,7 +312,7 @@ def test_fastmath_versions(): ] for name, func in fastmath_mathops: - run_fastmath_mathop_test(name, func, dtype="float32") + run_fastmath_mathop_test(name, func, dtype=T.float32) print(f"✓ {name} test passed") diff --git a/testing/python/math/test_math_ieee_math.py b/testing/python/math/test_math_ieee_math.py index 193092ec7..5d4988002 100644 --- a/testing/python/math/test_math_ieee_math.py +++ b/testing/python/math/test_math_ieee_math.py @@ -5,7 +5,7 @@ import pytest -def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"): +def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype=T.float32): """ Test IEEE-compliant math operations with specified rounding modes. """ @@ -75,7 +75,7 @@ def main_func( print(f"✓ {mathop_name} compilation test passed") # Test numerical execution - torch_dtype = getattr(torch, dtype) + torch_dtype = dtype.as_torch() a = torch.randn(M, N, device="cuda", dtype=torch_dtype) if num_inputs >= 2: @@ -186,8 +186,8 @@ def test_ieee_frsqrt_rn_only(): @T.prim_func def main( - A: T.Tensor((128, 128), "float32"), - B: T.Tensor((128, 128), "float32"), + A: T.Tensor((128, 128), T.float32), + B: T.Tensor((128, 128), T.float32), ): with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by): for i, j in T.Parallel(32, 32): diff --git a/testing/python/metal/test_metal_codegen.py b/testing/python/metal/test_metal_codegen.py index ea088aea9..5349bbec5 100644 --- a/testing/python/metal/test_metal_codegen.py +++ b/testing/python/metal/test_metal_codegen.py @@ -6,7 +6,7 @@ @tilelang.jit(execution_backend="torch") -def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float32, accum_dtype=T.float32): @T.prim_func def gemm( A: T.Tensor((M, K), dtype), @@ -39,13 +39,13 @@ def assert_gemm( block_M, block_N, block_K, - dtype="float32", - accum_dtype="float", + dtype=T.float32, + accum_dtype=T.float32, atol=1e-8, ): jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) - torch_dtype = getattr(torch, dtype) + torch_dtype = dtype.as_torch() a, b = None, None if "int" in dtype: a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps") @@ -69,12 +69,12 @@ def test_gemm_float32(): @tilelang.testing.requires_metal def test_gemm_float16(): - assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1) + assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.float16, atol=1) @tilelang.testing.requires_metal def test_gemm_int32(): - assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1) + 
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.int32, atol=1) if __name__ == "__main__": diff --git a/testing/python/profiler/test_tilelang_profiler.py b/testing/python/profiler/test_tilelang_profiler.py index 8aa547084..09d894c59 100644 --- a/testing/python/profiler/test_tilelang_profiler.py +++ b/testing/python/profiler/test_tilelang_profiler.py @@ -3,7 +3,7 @@ @tilelang.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm( A: T.Tensor((M, K), dtype), diff --git a/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py b/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py index 7a42b23bd..083373eb7 100644 --- a/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py +++ b/testing/python/runtime/test_tilelang_runtime_dynamic_shared_memory.py @@ -9,14 +9,14 @@ @tilelang.jit def dynamic_smem_kernel(): # Symbolic length to drive dynamic shared memory allocation - length = T.symbolic("len", dtype="int32") # noqa: F821 + length = T.symbolic("len", dtype=T.int32) # noqa: F821 @T.prim_func - def main(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821 + def main(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821 # Launch a simple kernel that copies from global memory into shared memory # using a dynamically-sized allocation. No writes back to global_tensor. with T.Kernel(1, threads=32) as _: - buffer_shared = T.alloc_shared((length,), dtype="int32") # noqa: F821 + buffer_shared = T.alloc_shared((length,), dtype=T.int32) # noqa: F821 T.copy(buffer_shared, global_tensor) return main diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index de8a9f9dc..67123cb8c 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -1,3 +1,4 @@ +import tilelang.language as T from tilelang import tvm as tvm import tilelang.testing import pytest @@ -23,8 +24,6 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -112,20 +111,20 @@ def ref_program(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2, 128), - (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2, 128), - (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), - (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, 
True, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128), + (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -153,8 +152,6 @@ def matmul_rs( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) A_frag_shape = A_shared_shape - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -247,20 +244,20 @@ def ref_program(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), - (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 
128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -288,8 +285,6 @@ def matmul_sr( B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_frag_shape = B_shared_shape - import tilelang.language as T - @T.prim_func def main( A: T.Tensor(A_shape, in_dtype), @@ -381,20 +376,20 @@ def ref_program(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), - (128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128), + (128, 128, 32, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 32, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 32, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 32, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 
128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -519,22 +514,22 @@ def ref_program(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128), - (128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2, 128), - (128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2, 128), - (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), - (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float, 128, 256, 32, 2, 128), + (128, 8, 128, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 2, 128), + (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 32, 2, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128), + (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128), ], ) def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py index 6c47bb5e4..b0f4a29c9 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -2,6 +2,7 
@@ import torch import tilelang import tilelang.testing +import tilelang.language as T from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.layout import make_cutlass_metadata_layout @@ -44,14 +45,12 @@ def matmul_sp_sm90( trans_A, trans_B, ): - E_factor = 4 if in_dtype == "float32" else 8 + E_factor = 4 if in_dtype == T.float32 else 8 A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) B_shape = (K, N) if not trans_B else (N, K) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) - import tilelang.language as T - @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), @@ -104,15 +103,13 @@ def matmul_sp_sm80( trans_B, ): is_8_bit = "8" in in_dtype - metadata_dtype = "int32" if is_8_bit else "int16" + metadata_dtype = T.int32 if is_8_bit else T.int16 E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) B_shape = (K, N) if not trans_B else (N, K) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) - import tilelang.language as T - @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), @@ -312,19 +309,18 @@ def run_gemm_sp_sm80( @pytest.mark.parametrize( "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", [ - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True), - (512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True), - (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 0, 256, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 
64, 0, 128, False, False), + (512, 1024, 768, T.float8_e4m3fn, T.float16, T.float16, 64, 64, 64, 2, 128, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True), ], ) def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): @@ -337,21 +333,20 @@ def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_ @pytest.mark.parametrize( "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", [ - (512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False), - (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False), - (512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True), - (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True), - (512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True), - (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True), - (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True), - (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 32, 0, 32, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 64, 0, 32, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 1, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False), + (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 3, 128, False, False), + (512, 1024, 768, T.int8, T.int32, T.int32, 32, 32, 64, 0, 32, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 0, 32, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 128, 128, 128, 0, 128, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 1, 128, False, True), + (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True), ], ) def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index cd4123d99..9d232902c 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -7,6 +7,7 @@ import tilelang.testing import torch +import 
tilelang.language as T def matmul( @@ -31,8 +32,6 @@ def matmul( A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tilelang.language as T - @T.prim_func def main( A_sparse: T.Tensor(A_sparse_shape, in_dtype), @@ -83,7 +82,7 @@ def run_gemm_ss( num_stages=3, num_threads=128, ): - metadata_dtype = "int32" if ("8" in in_dtype) else "int16" + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 program = matmul( M, N, @@ -157,17 +156,17 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2, 128), - (512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2, 128), - (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128), - (128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128), + (128, 8, 64, False, True, T.float16, T.float16, T.float, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int32, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), ], ) def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -252,7 +251,7 @@ def run_gemm_rs( num_stages=3, num_threads=128, ): - metadata_dtype = "int32" if ("8" in in_dtype) else "int16" + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 program = matmul_rs( M, N, @@ -308,16 +307,16 @@ def _matmul(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), - (128, 8, 64, False, True, "float16", 
"float16", "float", 128, 8, 32, 0, 128), - (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), ], ) def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -402,7 +401,7 @@ def run_gemm_sr( num_stages=3, num_threads=128, ): - metadata_dtype = "int32" if ("8" in in_dtype) else "int16" + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 program = matmul_sr( M, N, @@ -458,16 +457,16 @@ def _matmul(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), - (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128), - (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), ], ) def test_gemm_sr(M, N, K, trans_A, 
trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): @@ -556,7 +555,7 @@ def run_gemm_rr( num_stages=3, num_threads=128, ): - metadata_dtype = "int32" if ("8" in in_dtype) else "int16" + metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16 program = matmul_rr( M, N, @@ -612,18 +611,18 @@ def _matmul(A, B): @pytest.mark.parametrize( "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", [ - (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), - (512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128), - (128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2, 128), - (128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2, 128), - (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), - (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), + (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128), + (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128), + (128, 8, 128, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128), + (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 64, 2, 128), + (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128), + (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128), ], ) def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): diff --git a/testing/python/transform/test_readonly_param_const_codegen.py b/testing/python/transform/test_readonly_param_const_codegen.py index d0a2bbbf1..0d255b46b 100644 --- a/testing/python/transform/test_readonly_param_const_codegen.py +++ b/testing/python/transform/test_readonly_param_const_codegen.py @@ -6,8 +6,8 @@ def _simple_add_kernel(): @T.prim_func def main( - x: T.Tensor((128,), "float32"), - y: T.Tensor((128,), "float32"), + x: T.Tensor((128,), T.float32), + y: T.Tensor((128,), T.float32), ): # One-dimensional kernel; writes y from x without modifying x with T.Kernel(128, threads=32) as pid: diff --git a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py index d3f45c5eb..cdff6fb1d 100644 --- a/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py +++ 
b/testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py @@ -16,13 +16,13 @@ def _check(original, transformed): def test_trival_pipeline(): @T.prim_func - def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")): + def before(A: T.Tensor((16, 1), T.float32), C: T.Tensor((16, 1), T.float32)): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}): with T.block(): T.reads(A[tx, i]) T.writes(C[tx, i]) - B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + B = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared") with T.block(): T.reads(A[tx, i]) T.writes(B[tx, 0]) diff --git a/testing/python/transform/test_tilelang_transform_cluster_planning.py b/testing/python/transform/test_tilelang_transform_cluster_planning.py index 2ec6321e8..296c6ce94 100644 --- a/testing/python/transform/test_tilelang_transform_cluster_planning.py +++ b/testing/python/transform/test_tilelang_transform_cluster_planning.py @@ -22,11 +22,11 @@ def _check(original, transformed): def test_cluster_planning(): @T.prim_func - def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): + def before(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)): with T.Kernel(8, 8, threads=128) as (bx, by): - A_shared = T.alloc_shared((128, 32), "float16") - B_shared = T.alloc_shared((32, 128), "float16") - C_local = T.alloc_fragment((128, 128), "float32") + A_shared = T.alloc_shared((128, 32), T.float16) + B_shared = T.alloc_shared((32, 128), T.float16) + C_local = T.alloc_fragment((128, 128), T.float32) T.clear(C_local) @@ -39,12 +39,12 @@ def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16" T.copy(C_local, C[by * 128, bx * 128]) @T.prim_func - def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): + def after(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)): T.func_attr({"clusterIdx.y": T.int32(2)}) with T.Kernel(8, 8, threads=128) as (bx, by): - A_shared = T.alloc_shared((128, 32), "float16") - B_shared = T.alloc_shared((32, 128), "float16") - C_local = T.alloc_fragment((128, 128), "float32") + A_shared = T.alloc_shared((128, 32), T.float16) + B_shared = T.alloc_shared((32, 128), T.float16) + C_local = T.alloc_fragment((128, 128), T.float32) T.clear(C_local) diff --git a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py index 339b283e0..559b2ffb4 100644 --- a/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py +++ b/testing/python/transform/test_tilelang_transform_config_index_bitwidth.py @@ -19,8 +19,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) shape = [batch, heads, seq_len, dim] block_mask_shape = [batch, heads, downsample_len, downsample_len] - dtype = "bfloat16" - accum_dtype = "float" + dtype = T.bfloat16 + accum_dtype = T.float32 block_mask_dtype = "bool" def kernel_func(block_M, block_N, num_stages, threads): diff --git a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py index 854a26172..533a62fc6 100644 --- 
a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py +++ b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -25,8 +25,8 @@ def test_lower_fence_proxy(): @T.prim_func def before(): with T.Kernel(8): - A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") for i in T.unroll(16): C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) @@ -34,16 +34,16 @@ def before(): "handle", tir.op.Op.get("tl.tl_gemm"), "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) @T.prim_func def after(): with T.Kernel(8): - A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") for i in T.unroll(16): C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) @@ -52,9 +52,9 @@ def after(): "handle", tir.op.Op.get("tl.tl_gemm"), "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) _check(before, after) @@ -64,8 +64,8 @@ def test_async_to_generic_no_double_fence(): @T.prim_func def before(): with T.Kernel(8): - A_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn") - B_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn") + A_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn") + B_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn") T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16) T.fence_proxy_async() T.call_extern("handle", "generic_op") @@ -129,7 +129,7 @@ def test_tma_store_sync_injection(): @T.prim_func def before(): with T.Kernel(8): - A_global = T.decl_buffer((128,), "float16", scope="global") + A_global = T.decl_buffer((128,), T.float16, scope="global") T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data)) mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) @@ -159,14 +159,14 @@ def test_wgmma_marked_async(): @T.prim_func def before(): with T.Kernel(1): - A_shared = T.decl_buffer((1,), "float16", scope="shared") - desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma") - desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma") - C_local = T.decl_buffer((32,), "float16", scope="local") 
+ A_shared = T.decl_buffer((1,), T.float16, scope="shared") + desc_a = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma") + desc_b = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma") + C_local = T.decl_buffer((32,), T.float16, scope="local") A_shared[0] = T.float16(0) T.warpgroup_arrive() T.ptx_wgmma_ss( - "float16", + T.float16, "m64n64k16", T.bool(True), T.bool(True), diff --git a/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py b/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py index 0cc79b92f..1885c7c4b 100644 --- a/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py +++ b/testing/python/transform/test_tilelang_transform_inject_set_max_nreg.py @@ -9,7 +9,7 @@ def test_inject_set_max_nreg(): """Test the InjectSetMaxNReg pass""" @T.prim_func - def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16")): + def before(A: T.Tensor((512, 512), T.float16), B: T.Tensor((512, 512), T.float16)): bx = T.launch_thread("blockIdx.x", 8) by = T.launch_thread("blockIdx.y", 8) v = T.launch_thread("threadIdx.x", 128) @@ -22,8 +22,8 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16" T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24 T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240 - A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") C_local = T.alloc_buffer((32,), scope="local") T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128) @@ -37,7 +37,7 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16" T.tma_load( T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), k * 32, by * 64, ) @@ -49,9 +49,9 @@ def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16" T.call_extern( "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) @@ -86,7 +86,7 @@ def test_inject_set_max_nreg_no_set_max_nreg(): """Test the InjectSetMaxNReg pass with no_set_max_nreg""" @T.prim_func - def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")): + def before_no_set_max_nreg(A: T.Tensor((512, 512), T.float16)): bx = T.launch_thread("blockIdx.x", 8) v = T.launch_thread("threadIdx.x", 128) diff --git a/testing/python/transform/test_tilelang_transform_layout_inference.py b/testing/python/transform/test_tilelang_transform_layout_inference.py index 270dd31ee..82fcd19ab 100644 --- a/testing/python/transform/test_tilelang_transform_layout_inference.py +++ 
b/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "block_M, block_N, block_K, threads, vec_load_b, dtype", [ - (64, 64, 32, 128, 8, "float16"), + (64, 64, 32, 128, 8, T.float16), ], ) def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): @@ -102,4 +102,4 @@ def main( if __name__ == "__main__": # tilelang.testing.main() - test_loop_tail_split(64, 64, 32, 128, 8, "float16") + test_loop_tail_split(64, 64, 32, 128, 8, T.float16) diff --git a/testing/python/transform/test_tilelang_transform_legalize_negative_index.py b/testing/python/transform/test_tilelang_transform_legalize_negative_index.py index c5dd065aa..26c151141 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_negative_index.py +++ b/testing/python/transform/test_tilelang_transform_legalize_negative_index.py @@ -19,15 +19,15 @@ def test_buffer_load_negative_index_legalized(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): value = A[-1] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): value = A[1023] # A[-1] becomes A[1023] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value _check(before, after) @@ -39,15 +39,15 @@ def test_buffer_load_mixed_negative_positive_indices(): """ @T.prim_func - def before(A: T.Tensor((1024, 512), "float32")): + def before(A: T.Tensor((1024, 512), T.float32)): value = A[-1, 10] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value @T.prim_func - def after(A: T.Tensor((1024, 512), "float32")): + def after(A: T.Tensor((1024, 512), T.float32)): value = A[1023, 10] # A[-1, 10] becomes A[1023, 10] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value _check(before, after) @@ -59,15 +59,15 @@ def test_buffer_load_multiple_negative_indices(): """ @T.prim_func - def before(A: T.Tensor((1024, 512, 256), "float32")): + def before(A: T.Tensor((1024, 512, 256), T.float32)): value = A[-1, -2, -3] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value @T.prim_func - def after(A: T.Tensor((1024, 512, 256), "float32")): + def after(A: T.Tensor((1024, 512, 256), T.float32)): value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253 - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value _check(before, after) @@ -79,15 +79,15 @@ def test_buffer_load_negative_index_in_expression(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): - B = T.alloc_buffer((1024,), "float32") + def before(A: T.Tensor((1024,), T.float32)): + B = T.alloc_buffer((1024,), T.float32) for i in T.serial(1, 1024): value = A[-i] B[-i] = value @T.prim_func - def after(A: T.Tensor((1024,), "float32")): - B = T.alloc_buffer((1024,), "float32") + def after(A: T.Tensor((1024,), T.float32)): + B = T.alloc_buffer((1024,), T.float32) for i in T.serial(1, 1024): value = A[1024 - i] B[1024 - i] = value @@ -101,16 +101,16 @@ def test_buffer_load_non_negative_index_unchanged(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): value = A[0] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: 
T.Tensor((1024,), T.float32)): # No changes expected for non-negative indices value = A[0] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value _check(before, after) @@ -123,18 +123,18 @@ def test_buffer_load_unknown_sign_index_warning(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): - i = T.Var("i", "int32") + def before(A: T.Tensor((1024,), T.float32)): + i = T.Var("i", T.int32) value = A[i] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value @T.prim_func - def after(A: T.Tensor((1024,), "float32")): - i = T.Var("i", "int32") + def after(A: T.Tensor((1024,), T.float32)): + i = T.Var("i", T.int32) # Unknown sign indices should remain unchanged value = A[i] - B = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((1,), T.float32) B[0] = value _check(before, after) @@ -146,18 +146,18 @@ def test_buffer_load_vector_index_negative_broadcast(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): vec = T.Broadcast(-1, 4) value = A[vec] - B = T.alloc_buffer((4,), "float32") + B = T.alloc_buffer((4,), T.float32) B[T.Ramp(0, 1, 4)] = value @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): # vec is unused and can be delimed by Simplify. vec = T.Broadcast(-1, 4) # noqa: F841 value = A[T.Broadcast(1023, 4)] - B = T.alloc_buffer((4,), "float32") + B = T.alloc_buffer((4,), T.float32) B[T.Ramp(0, 1, 4)] = value _check(before, after) @@ -169,18 +169,18 @@ def test_buffer_load_vector_index_negative_ramp(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] value = A[vec] - B = T.alloc_buffer((4,), "float32") + B = T.alloc_buffer((4,), T.float32) B[T.Ramp(0, 1, 4)] = value @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): # vec is unused and can be delimed by Simplify. 
vec = T.Ramp(-4, 1, 4) # noqa: F841 value = A[T.Ramp(1020, 1, 4)] - B = T.alloc_buffer((4,), "float32") + B = T.alloc_buffer((4,), T.float32) B[T.Ramp(0, 1, 4)] = value _check(before, after) @@ -192,17 +192,17 @@ def test_buffer_load_nested_buffer_loads(): """ @T.prim_func - def before(A: T.Tensor((1024, 512), "float32")): + def before(A: T.Tensor((1024, 512), T.float32)): inner_val = A[-1, 10] - outer_val = A[inner_val.astype("int32"), -2] - B = T.alloc_buffer((1,), "float32") + outer_val = A[inner_val.astype(T.int32), -2] + B = T.alloc_buffer((1,), T.float32) B[0] = outer_val @T.prim_func - def after(A: T.Tensor((1024, 512), "float32")): + def after(A: T.Tensor((1024, 512), T.float32)): inner_val = A[1023, 10] - outer_val = A[inner_val.astype("int32"), 510] - B = T.alloc_buffer((1,), "float32") + outer_val = A[inner_val.astype(T.int32), 510] + B = T.alloc_buffer((1,), T.float32) B[0] = outer_val _check(before, after) @@ -214,11 +214,11 @@ def test_buffer_store_negative_index(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): A[-1] = 42.0 @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): A[1023] = 42.0 _check(before, after) @@ -230,11 +230,11 @@ def test_buffer_store_mixed_negative_positive_indices(): """ @T.prim_func - def before(A: T.Tensor((1024, 512), "float32")): + def before(A: T.Tensor((1024, 512), T.float32)): A[-1, 10] = 42.0 @T.prim_func - def after(A: T.Tensor((1024, 512), "float32")): + def after(A: T.Tensor((1024, 512), T.float32)): A[1023, 10] = 42.0 _check(before, after) @@ -246,11 +246,11 @@ def test_buffer_store_multiple_negative_indices(): """ @T.prim_func - def before(A: T.Tensor((1024, 512, 256), "float32")): + def before(A: T.Tensor((1024, 512, 256), T.float32)): A[-1, -2, -3] = 42.0 @T.prim_func - def after(A: T.Tensor((1024, 512, 256), "float32")): + def after(A: T.Tensor((1024, 512, 256), T.float32)): A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253 _check(before, after) @@ -262,12 +262,12 @@ def test_buffer_store_negative_index_in_expression(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): for i in T.serial(1, 1024): A[-i] = i * 2.0 @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): for i in T.serial(1, 1024): A[1024 - i] = i * 2.0 @@ -280,13 +280,13 @@ def test_buffer_store_vector_index_negative_broadcast(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): vec = T.Broadcast(-1, 4) values = T.Broadcast(42.0, 4) A[vec] = values @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): # vec is unused and can be delimed by Simplify. vec = T.Broadcast(-1, 4) # noqa: F841 values = T.Broadcast(42.0, 4) @@ -301,13 +301,13 @@ def test_buffer_store_vector_index_negative_ramp(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32")): + def before(A: T.Tensor((1024,), T.float32)): vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0] A[vec] = values @T.prim_func - def after(A: T.Tensor((1024,), "float32")): + def after(A: T.Tensor((1024,), T.float32)): # vec is unused and can be delimed by Simplify. 
vec = T.Ramp(-4, 1, 4) # noqa: F841 values = T.Ramp(0.0, 1.0, 4) @@ -322,14 +322,14 @@ def test_buffer_store_nested_in_condition(): """ @T.prim_func - def before(A: T.Tensor((1024,), "float32"), flag: T.int32): + def before(A: T.Tensor((1024,), T.float32), flag: T.int32): if flag > 0: A[-1] = 42.0 else: A[-2] = 24.0 @T.prim_func - def after(A: T.Tensor((1024,), "float32"), flag: T.int32): + def after(A: T.Tensor((1024,), T.float32), flag: T.int32): if flag > 0: A[1023] = 42.0 else: diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index de2e61eec..4f75fa05d 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -5,7 +5,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): - dtype = "float32" + dtype = T.float32 @T.prim_func def main( @@ -41,39 +41,8 @@ def assert_vectorize_access(M: int = 64, N: int = 64): tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) -# def issue_1013_buggy_kernel(): -# # NOTE: This kernel is mainly to test some corner cases in boundary check - -# num_tokens = T.dynamic('num_tokens') -# num_threads = 128 - -# @T.prim_func -# def main(x: T.Tensor((num_tokens,), dtype="int64")): -# with T.Kernel(1, threads=num_threads) as _: -# count = T.alloc_var('int') -# thread_idx = T.get_thread_binding() -# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): -# idx = thread_idx + i * num_threads -# count += x[idx] == 2 - -# # NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe -# # and the padding value is not used. However, the current prover cannot handle this case. -# # So for now the expected kernel is a if-else statement to check the boundary. 
-# @T.prim_func -# def expected(x: T.Tensor((num_tokens,), dtype="int64")): -# with T.Kernel(1, threads=num_threads) as _: -# count = T.alloc_var('int') -# thread_idx = T.get_thread_binding() -# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): -# idx = thread_idx + i * num_threads -# count += T.Cast("int32", -# value=T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2)) - -# return main, expected - - def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): - dtype = "float32" + dtype = T.float32 @T.prim_func def main( @@ -115,7 +84,7 @@ def assert_vectorize_access_with_atmoic_add(M: int = 64, N: int = 64): def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): - dtype = "float32" + dtype = T.float32 @T.prim_func def main( @@ -152,13 +121,6 @@ def test_vectorize_access(): assert_vectorize_access(64, 64) -# def test_issue_1013(): -# func, expected = issue_1013_buggy_kernel() -# mod = tvm.IRModule({func.attrs["global_symbol"]: func}) -# transformed = tl.transform.LegalizeSafeMemoryAccess()(mod) -# tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) - - def test_vectorize_access_with_atmoic_add(): assert_vectorize_access_with_atmoic_add(64, 64) diff --git a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py index ec570d418..3cc7541cc 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py +++ b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py @@ -5,12 +5,12 @@ def vectorize_access_legalize(M: int = 64, N: int = 64): - dtype = "float32" + dtype = T.float32 vec_len = 8 @T.prim_func def main( - A: T.Tensor((M, N, vec_len), dtype="float32"), + A: T.Tensor((M, N, vec_len), dtype=T.float32), ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) @@ -21,7 +21,7 @@ def main( @T.prim_func def expected( - A: T.Tensor((M, N, vec_len), dtype="float32"), + A: T.Tensor((M, N, vec_len), dtype=T.float32), ): with T.Kernel(1, 1, threads=M) as (bx, by): A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) diff --git a/testing/python/transform/test_tilelang_transform_let_inline.py b/testing/python/transform/test_tilelang_transform_let_inline.py index 6603ecab3..e773e3fee 100644 --- a/testing/python/transform/test_tilelang_transform_let_inline.py +++ b/testing/python/transform/test_tilelang_transform_let_inline.py @@ -13,7 +13,7 @@ def _check(original, transformed): def test_let_binding(): @T.prim_func - def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")): + def before(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)): for i in range(128): for j in range(128): with T.block("compute"): @@ -22,7 +22,7 @@ def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32" B[i, j] = value @T.prim_func - def expected(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")): + def expected(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)): for i in range(128): for j in range(128): with T.block("compute"): @@ -33,14 +33,14 @@ def expected(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float3 def test_parallel_scope(): @T.prim_func - def before(A: T.Tensor((128,), "float32")): + def before(A: T.Tensor((128,), T.float32)): for i in T.Parallel(128): with 
T.block("parallel"): value = T.float32(1.0) A[i] = value @T.prim_func - def expected(A: T.Tensor((128,), "float32")): + def expected(A: T.Tensor((128,), T.float32)): for i in T.Parallel(128): with T.block("parallel"): A[i] = T.float32(1.0) diff --git a/testing/python/transform/test_tilelang_transform_lower_tile_op.py b/testing/python/transform/test_tilelang_transform_lower_tile_op.py index ac5841859..16c7cb802 100644 --- a/testing/python/transform/test_tilelang_transform_lower_tile_op.py +++ b/testing/python/transform/test_tilelang_transform_lower_tile_op.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "block_M, block_N, block_K, threads, vec_load_b, dtype", [ - (64, 64, 32, 128, 8, "float16"), + (64, 64, 32, 128, 8, T.float16), ], ) def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): diff --git a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py index 0d56ab1a8..e85fd8db8 100644 --- a/testing/python/transform/test_tilelang_transform_multi_version_buffer.py +++ b/testing/python/transform/test_tilelang_transform_multi_version_buffer.py @@ -24,7 +24,7 @@ def _check(original, transformed): M = 512 N = 512 K = 512 -dtype = "float16" +dtype = T.float16 block_M = 64 block_N = 64 block_K = 32 @@ -39,8 +39,8 @@ def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): with T.block(""): T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) T.writes() - A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.alloc_buffer((1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((1, 4, 512), T.float16, scope="shared.dyn") C_local = T.alloc_buffer((32,), scope="local") for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): for vec in T.vectorized(2): @@ -50,7 +50,7 @@ def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), 0, - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 2), k * 32, by * 64, ) @@ -58,16 +58,16 @@ def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), 0, - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 2), bx * 64, k * 32, ) T.call_extern( "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) @T.prim_func @@ -78,8 +78,8 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): with T.block(""): T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) T.writes() - A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, 
scope="shared.dyn") + B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") C_local = T.alloc_buffer((32,), scope="local") for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): for vec in T.vectorized(2): @@ -89,7 +89,7 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), 0, - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), k * 32, by * 64, ) @@ -97,16 +97,16 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), 0, - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), bx * 64, k * 32, ) T.call_extern( "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) _check(before, after) @@ -114,10 +114,10 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): def test_multi_version_buffer_with_let(): @T.prim_func - def before(scales: T.Tensor((4,), "float32")): + def before(scales: T.Tensor((4,), T.float32)): with T.block("root"): - shared = T.alloc_buffer((8,), "float32", scope="shared.dyn") - accum = T.alloc_buffer((8,), "float32", scope="local") + shared = T.alloc_buffer((8,), T.float32, scope="shared.dyn") + accum = T.alloc_buffer((8,), T.float32, scope="local") for k in T.serial(4, annotations={"num_stages": T.int32(2)}): value = scales[k] for i in T.serial(8): @@ -126,10 +126,10 @@ def before(scales: T.Tensor((4,), "float32")): accum[i] = accum[i] + shared[i] @T.prim_func - def after(scales: T.Tensor((4,), "float32")): + def after(scales: T.Tensor((4,), T.float32)): with T.block("root"): - shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn") - accum = T.alloc_buffer((8,), "float32", scope="local") + shared = T.alloc_buffer((2, 8), T.float32, scope="shared.dyn") + accum = T.alloc_buffer((8,), T.float32, scope="local") for k in T.serial(4, annotations={"num_stages": T.int32(2)}): value = scales[k] for i in T.serial(8): diff --git a/testing/python/transform/test_tilelang_transform_pipeline_planning.py b/testing/python/transform/test_tilelang_transform_pipeline_planning.py index f38d6079e..83db7f75c 100644 --- a/testing/python/transform/test_tilelang_transform_pipeline_planning.py +++ b/testing/python/transform/test_tilelang_transform_pipeline_planning.py @@ -20,11 +20,11 @@ def _check(original, transformed): def test_simple_pipeline(): @T.prim_func - def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): + def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): with T.Kernel(8, 8, threads=128) as (bx, by): - A_shared = T.alloc_shared((128, 32), 
"float32") - B_shared = T.alloc_shared((32, 128), "float32") - C_local = T.alloc_fragment((128, 128), "float32") + A_shared = T.alloc_shared((128, 32), T.float32) + B_shared = T.alloc_shared((32, 128), T.float32) + C_local = T.alloc_fragment((128, 128), T.float32) T.clear(C_local) @@ -37,11 +37,11 @@ def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32" T.copy(C_local, C[by * 128, bx * 128]) @T.prim_func - def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): + def after(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)): with T.Kernel(8, 8, threads=128) as (bx, by): - A_shared = T.alloc_shared((128, 32), "float32") - B_shared = T.alloc_shared((32, 128), "float32") - C_local = T.alloc_fragment((128, 128), "float32") + A_shared = T.alloc_shared((128, 32), T.float32) + B_shared = T.alloc_shared((32, 128), T.float32) + C_local = T.alloc_fragment((128, 128), T.float32) T.clear(C_local) diff --git a/testing/python/transform/test_tilelang_transform_simplify.py b/testing/python/transform/test_tilelang_transform_simplify.py index 657a2e8a4..3b7376820 100644 --- a/testing/python/transform/test_tilelang_transform_simplify.py +++ b/testing/python/transform/test_tilelang_transform_simplify.py @@ -22,9 +22,9 @@ def main( T.gemm(A, B, D) else: with T.block(): - A_shared = T.alloc_shared((64, 64), dtype="float32") - C_shared = T.alloc_shared((64, 64), dtype="float32") - D_shared = T.alloc_shared((64, 64), dtype="float32") + A_shared = T.alloc_shared((64, 64), dtype=T.float32) + C_shared = T.alloc_shared((64, 64), dtype=T.float32) + D_shared = T.alloc_shared((64, 64), dtype=T.float32) T.copy(A, A_shared) T.copy(C, C_shared) T.gemm(A_shared, C_shared, D_shared) @@ -40,7 +40,7 @@ def test_modify(with_B=False, with_bias=False): assert mod != mod2 -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def main( a: T.handle, diff --git a/testing/python/transform/test_tilelang_transform_warp_specialized.py b/testing/python/transform/test_tilelang_transform_warp_specialized.py index 2e101bf82..0171fab82 100644 --- a/testing/python/transform/test_tilelang_transform_warp_specialized.py +++ b/testing/python/transform/test_tilelang_transform_warp_specialized.py @@ -25,7 +25,7 @@ def _check(original, transformed): M = 512 N = 512 K = 512 -dtype = "float16" +dtype = T.float16 block_M = 64 block_N = 64 block_K = 32 @@ -40,15 +40,15 @@ def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): with T.block(""): T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) T.writes() - A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") C_local = T.alloc_buffer((32,), scope="local") for k in T.serial(16, annotations={"num_stages": T.int32(3)}): if v == 0: T.tma_load( T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), 0, - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), k * 32, by * 64, ) @@ -56,16 +56,16 @@ def before(A: T.Tensor((M, K), dtype), B: 
T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), 0, - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), bx * 64, k * 32, ) T.call_extern( "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) @T.prim_func @@ -73,8 +73,8 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): bx = T.launch_thread("blockIdx.x", 8) by = T.launch_thread("blockIdx.y", 8) v = T.launch_thread("threadIdx.x", 256) - A_shared = T.decl_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") - B_shared = T.decl_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") + A_shared = T.decl_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn") + B_shared = T.decl_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn") C_local = T.decl_buffer((32,), scope="local") T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128) T.attr([128, 128], "kWarpSpecializationScope", 0) @@ -88,7 +88,7 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2), k * 32, by * 64, ) @@ -98,7 +98,7 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.tma_load( T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), T.get_mbarrier(k % 3), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2), bx * 64, k * 32, ) @@ -110,9 +110,9 @@ def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)): T.call_extern( "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", - T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), - T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), + T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1), + T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3), ) T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) diff --git a/testing/python/webgpu/test_webgpu_codegen.py b/testing/python/webgpu/test_webgpu_codegen.py index ed1752796..b8b199e79 100644 --- a/testing/python/webgpu/test_webgpu_codegen.py +++ b/testing/python/webgpu/test_webgpu_codegen.py @@ -4,7 +4,7 @@ import tilelang.language as T -def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): +def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): 
@T.prim_func def main( A: T.Tensor((M, K), dtype), @@ -38,8 +38,8 @@ def assert_gemm_codegen( block_M, block_N, block_K, - dtype="float16", - accum_dtype="float", + dtype=T.float16, + accum_dtype=T.float32, ): func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) # Because the current pass context have been polluted by previous testing. diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 1f2a4f404..87176b209 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -141,6 +141,7 @@ def _load_tile_lang_lib(): engine, # noqa: F401 tools, # noqa: F401 ) +from .language.v2 import dtypes # noqa: F401 from .autotuner import autotune # noqa: F401 from .transform import PassConfigKey # noqa: F401 diff --git a/tilelang/engine/param.py b/tilelang/engine/param.py index bb9872e4a..fe023f83f 100644 --- a/tilelang/engine/param.py +++ b/tilelang/engine/param.py @@ -6,7 +6,7 @@ import torch from tilelang import tvm as tvm from tvm.tir import Buffer, IntImm, Var, PrimExpr -from tilelang.utils.tensor import map_torch_type +import tilelang.language as T @dataclass @@ -138,7 +138,7 @@ def torch_dtype(self) -> torch.dtype: >>> param = KernelParam.from_buffer(buffer) >>> tensor = torch.empty(shape, dtype=param.torch_dtype()) """ - return map_torch_type(str(self.dtype)) + return T.dtype(self.dtype).as_torch() @dataclass diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index 1e97bd0f2..ad2192061 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -61,9 +61,9 @@ class MatrixCoreIntrinEmitter: def __init__( self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, @@ -105,9 +105,9 @@ def __init__( self.num_elems_per_byte = num_elems_per_byte self.thread_var = thread_var - def _initialize_k_dim(self, a_dtype="float16"): + def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): - if a_dtype in ["float8_e4m3fnuz", "int8"]: + if a_dtype in ["float8_e4m3fnuz", T.int8]: self.k_dim = 32 return a_dtype = DataType(a_dtype) @@ -132,7 +132,7 @@ def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): def _initialize_mfma_prefix(self, k_dim=16): in_dtype, out_dtype = self.a_dtype, self.accum_dtype M_DIM, N_DIM = self.M_DIM, self.N_DIM - out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] + out_dtype_abbrv = {T.float16: "f16", T.float32: "f32", T.int8: "i8", T.int32: "i32"}[out_dtype] in_dtype_abbrv = { "bfloat16": "bf16", @@ -221,7 +221,7 @@ def get_ldmatrix_index_map(self, is_b=False): def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out - index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32") + index_map = IndexMap.from_func(mfma_store_index_map, index_dtype=T.int32) if not inverse: return index_map inverse_index_map = index_map.inverse([warp_size, local_size_c]) @@ -521,7 +521,7 @@ def make_mfma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = " self.block_col_warps, ) - inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ 
-670,9 +670,9 @@ def _legalize_to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> Buffe class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): def __init__( self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 28afdb291..4b41eef2a 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -60,9 +60,9 @@ class TensorCoreIntrinEmitter: def __init__( self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, @@ -108,7 +108,7 @@ def __init__( f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" ) - def _initialize_k_dim(self, a_dtype="float16"): + def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) self.k_dim = 256 // a_dtype.bits @@ -194,9 +194,9 @@ def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out if DataType(self.accum_dtype).bits == 64: - index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32") + index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype=T.int32) else: - index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") + index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32) if not inverse: return index_map inverse_index_map = index_map.inverse([warp_size, local_size_c]) @@ -649,7 +649,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A self.block_col_warps, ) - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ @@ -806,9 +806,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): def __init__( self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, @@ -839,7 +839,7 @@ def __init__( ) self._initialize_transform_kind(transform_kind_a, transform_kind_b) - def _initialize_k_dim(self, a_dtype="float16"): + def _initialize_k_dim(self, a_dtype=T.float16): self.k_dim = 256 // DataType(a_dtype).bits def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): @@ -1266,7 +1266,7 @@ def mma(self, A_local_buf, B_local_buf, C_local_buf): a_dtype_abbrv = "int4" b_dtype_abbrv = "int4" accum_dtype = self.accum_dtype - accum_dtype_abbrv = "int32" + accum_dtype_abbrv = T.int32 mma_prefix = "m16n8k32" @T.macro diff --git a/tilelang/intrinsics/mma_sm70_macro_generator.py b/tilelang/intrinsics/mma_sm70_macro_generator.py index 3186adb2a..6acc40a4c 100644 --- a/tilelang/intrinsics/mma_sm70_macro_generator.py +++ b/tilelang/intrinsics/mma_sm70_macro_generator.py @@ -46,9 +46,9 @@ class TensorCoreIntrinEmitter: def __init__( 
self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, @@ -89,7 +89,7 @@ def __init__( f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" ) - def _initialize_k_dim(self, a_dtype="float16"): + def _initialize_k_dim(self, a_dtype=T.float16): self.k_dim = 4 def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16): @@ -147,8 +147,8 @@ def get_thread_binding(self): def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out index_map = IndexMap.from_func( - mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, - index_dtype="int32", + mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == T.float32 else mma_32x8_to_shared_16x16_layout_fp16, + index_dtype=T.int32, ) if not inverse: return index_map @@ -383,7 +383,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A self.block_col_warps, ) - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward(i: int, j: int, rep: int) -> int: """ diff --git a/tilelang/intrinsics/mma_sp_macro_generator.py b/tilelang/intrinsics/mma_sp_macro_generator.py index ea7aa8992..3e375b46b 100644 --- a/tilelang/intrinsics/mma_sp_macro_generator.py +++ b/tilelang/intrinsics/mma_sp_macro_generator.py @@ -133,10 +133,10 @@ class SparseTensorCoreIntrinEmitter: def __init__( self, - a_dtype: str = "float16", - e_dtype: str = "uint8", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + e_dtype: str = T.uint8, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, e_transposed: bool = False, @@ -181,7 +181,7 @@ def __init__( f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" ) - def _initialize_k_dim(self, a_dtype="float16"): + def _initialize_k_dim(self, a_dtype=T.float16): if isinstance(a_dtype, str): a_dtype = DataType(a_dtype) # NOTE: k_dim here represents the logical shape of the MMA operation. 
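For the emitter constructors whose dtype defaults change above, a minimal construction sketch (assuming the remaining constructor arguments keep mutually consistent defaults; only keywords visible in this diff are used):

```
import tilelang.language as T
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter

# T.float16 / T.float32 are tvm.DataType objects that also behave like their
# string names, so the emitters accept either spelling for a/b/accum dtypes.
emitter = TensorCoreIntrinEmitter(
    a_dtype=T.float16,
    b_dtype=T.float16,
    accum_dtype=T.float32,
)
```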
@@ -250,7 +250,7 @@ def get_thread_binding(self): def get_store_index_map(self, inverse: bool = False) -> IndexMap: warp_size, local_size_c = self.WARP_SIZE, self.local_size_out - index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") + index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32) if not inverse: return index_map inverse_index_map = index_map.inverse([warp_size, local_size_c]) @@ -708,7 +708,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: Literal["A", "B"] = "A self.block_col_warps, ) - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 26208d6ce..923bb0e10 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -73,9 +73,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): def __init__( self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, @@ -245,7 +245,7 @@ def tcgen05mma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, mbar, cl ) # Allocate an instruction descriptor wrapper and initialize it a_dtype_abbrv = self.a_dtype_abbrv - mask_zero = T.Cast("int32", 0) + mask_zero = T.Cast(T.int32, 0) mask0 = mask1 = mask2 = mask3 = mask_zero # TCGEN05 only has one warp group diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 483b6e731..864420c77 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -83,9 +83,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): def __init__( self, - a_dtype: str = "float16", - b_dtype: str = "float16", - accum_dtype: str = "float16", + a_dtype: str = T.float16, + b_dtype: str = T.float16, + accum_dtype: str = T.float16, a_transposed: bool = False, b_transposed: bool = False, block_row_warps: int = 2, @@ -515,7 +515,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragme self.block_col_warps, ) - inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") + inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32) def forward_thread(i: int, j: int) -> int: """ diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index b26f0b8fe..e9338fa6e 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -28,6 +28,7 @@ from tvm.script.parser.tir import block_attr from tvm.tir.buffer import Buffer from tvm.tir.expr import FloatImm, IntImm +from .v2 import dtypes as _dtypes from .v2.dtypes import dtype as tl_dtype from .v2.builder import OutTensor from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer @@ -158,7 +159,7 @@ def alloc_barrier(arrive_count: int): Returns: T.Buffer: A TVM buffer object allocated as a barrier """ - return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier") + return T.alloc_buffer([arrive_count], _dtypes.uint64, scope="shared.barrier") def alloc_tmem(shape, dtype): @@ -231,7 +232,7 @@ def alloc_reducer(shape, dtype, op="sum", replication=None): def alloc_descriptor( kind: 
DescKind = "wgmma", - dtype: str = "uint64", + dtype: str = _dtypes.uint64, ): """Allocate a descriptor buffer for WGMMA and TCGEN5.MMA. @@ -248,28 +249,28 @@ def alloc_descriptor( return T.alloc_buffer([1], dtype, scope=scope) -def alloc_wgmma_desc(dtype: str = "uint64"): +def alloc_wgmma_desc(dtype: str = _dtypes.uint64): return alloc_descriptor("wgmma", dtype=dtype) -def alloc_tcgen05_smem_desc(dtype: str = "uint64"): +def alloc_tcgen05_smem_desc(dtype: str = _dtypes.uint64): return alloc_descriptor("tcgen05_smem", dtype=dtype) -def alloc_tcgen05_instruction_desc(dtype: str = "uint32"): +def alloc_tcgen05_instruction_desc(dtype: str = _dtypes.uint32): return alloc_descriptor("tcgen05_instr", dtype=dtype) # Alias: short name consistent with imports -def alloc_tcgen05_instr_desc(dtype: str = "uint32"): +def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32): return alloc_tcgen05_instruction_desc(dtype) @overload -def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... +def empty(shape: tuple[Unpack[_Shapes]], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... -def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: +def empty(*shape: Unpack[_Shapes], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: if len(shape) == 1 and isinstance(shape[0], (tuple, list)): return OutTensor(shape[0], dtype) elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 035251434..a4caefc24 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -92,7 +92,7 @@ def buffer( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], - dtype: str = "float32", + dtype: str = T.float32, data: Var = None, strides: List[PrimExpr] = None, elem_offset: PrimExpr = None, @@ -143,7 +143,7 @@ def buffer( """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape if strides is not None: - strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] else: strides = [] return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member @@ -244,7 +244,7 @@ def func_ret(ret_type: Type) -> Type: def match_buffer( param: Union[Var, BufferLoad, BufferRegion], shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None, - dtype: str = "float32", + dtype: str = T.float32, data: Var = None, strides: List[PrimExpr] = None, elem_offset: PrimExpr = None, @@ -266,11 +266,11 @@ def match_buffer( ------- Match buffer from function parameter .. code-block:: python - A = T.match_buffer(a, (128, 128), dtype="float32") + A = T.match_buffer(a, (128, 128), dtype=T.float32) Match buffer from Buffer subregion .. 
code-block:: python - A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") + A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype=T.float32) Parameters ---------- @@ -320,7 +320,7 @@ def match_buffer( raise ValueError("Shape must be specified when binding input param") shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape if strides is not None: - idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32" + idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else T.int32 strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] else: strides = [] @@ -440,7 +440,7 @@ def block_attr(attrs: Dict[str, Any]) -> None: def alloc_buffer( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], - dtype: str = "float32", + dtype: str = T.float32, data: Var = None, strides: List[PrimExpr] = None, elem_offset: PrimExpr = None, @@ -491,7 +491,7 @@ def alloc_buffer( """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape if strides is not None: - strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] else: strides = [] return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member @@ -537,7 +537,7 @@ class axis: # pylint: disable=invalid-name def spatial( dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, - dtype: str = "int32", + dtype: str = T.int32, ) -> Var: """The spatial block axis defining function. @@ -565,7 +565,7 @@ def spatial( def reduce( dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, - dtype: str = "int32", + dtype: str = T.int32, ) -> Var: """The reduced block axis defining function. @@ -593,7 +593,7 @@ def reduce( def scan( dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, - dtype: str = "int32", + dtype: str = T.int32, ) -> Var: """The scanning block axis defining function. @@ -621,7 +621,7 @@ def scan( def opaque( dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, - dtype: str = "int32", + dtype: str = T.int32, ) -> Var: """The opaque block axis defining function. @@ -646,7 +646,7 @@ def opaque( ) @staticmethod - def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: + def remap(kinds: str, bindings: List[PrimExpr], dtype: str = T.int32) -> Union[List[Var], Var]: """The block axis remapping function. 
Parameters @@ -1133,7 +1133,7 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name def decl_buffer( shape, - dtype="float32", + dtype=T.float32, data=None, strides=None, elem_offset=None, @@ -1184,7 +1184,7 @@ def decl_buffer( """ shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape if strides is not None: - strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] + strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides] else: strides = [] return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member @@ -1237,7 +1237,7 @@ def launch_thread( return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member -def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar: +def env_thread(thread_tag: str, dtype: str = T.int32) -> IterVar: """Bind a var to thread env Parameters @@ -1656,7 +1656,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: args = [] for name, i in zip(params.keys(), identity + identity): if isinstance(i, int): - args.append(Var(name, "int32")) + args.append(Var(name, T.int32)) else: args.append(Var(name, i.dtype)) res = combiner(*args) diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 20c5d1b4b..6f650470d 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -94,7 +94,7 @@ def legalize_arguments(arg: tir.Buffer | tir.Var): offset_a = A_offset[-1] offset_b = B_offset[-1] - mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, "uint32") + mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, T.uint32) C_coords = [r.min for r in C_region.region] # Convert BufferRegion to tl.region calls for arguments A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) diff --git a/tilelang/language/parser/entry.py b/tilelang/language/parser/entry.py index 5f2aaab7b..53316d8c2 100644 --- a/tilelang/language/parser/entry.py +++ b/tilelang/language/parser/entry.py @@ -157,7 +157,7 @@ class BufferProxy: def __call__( self, shape, - dtype="float32", + dtype=T.float32, data=None, strides=None, elem_offset=None, diff --git a/tilelang/language/tir/entry.py b/tilelang/language/tir/entry.py index 82ae7d70f..8d65786e4 100644 --- a/tilelang/language/tir/entry.py +++ b/tilelang/language/tir/entry.py @@ -89,12 +89,12 @@ def dynamic_capture(A, B): @T.prim_func - def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + def use1(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None: for x_value in T.serial(10): static_capture(A, B) ### Produces B[()] = A[128] @T.prim_func - def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: + def use2(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None: for x_value in T.serial(10): dynamic_capture(A, B) ### Produces B[()] = A[x_value] ``` diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 6cf784184..d622911df 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1163,7 +1163,7 @@ def ptx_tcgen05_mma_ss( desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]). Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`. Alternatively, use `variant="ws"` (or "default"). 
- - kind_dtype: instruction kind selector (e.g., "float16" for kind::f16, + - kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16, "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4). """ # Aliases precedence: if either `ws` or `warp_specialized` is provided, they override enable_ws @@ -1224,7 +1224,7 @@ def ptx_tcgen05_mma_ts( Expects 13 positional arguments: (kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3). - - kind_dtype: instruction kind selector (e.g., "float16" for kind::f16, + - kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16, "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4). """ return call_intrin( diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 645a1ad92..8e8930a1b 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -13,6 +13,7 @@ import tvm from tvm.tir import Buffer from tvm.script.ir_builder import tir, IRBuilder + from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union from collections.abc import Sequence diff --git a/tilelang/language/v2/dtypes.py b/tilelang/language/v2/dtypes.py index 6ed56b48a..c872985f9 100644 --- a/tilelang/language/v2/dtypes.py +++ b/tilelang/language/v2/dtypes.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: class dtype(Generic[_T]): - def torch(self) -> torch.dtype: ... + def as_torch(self) -> torch.dtype: ... else: dtype = tvm.DataType @@ -68,7 +68,32 @@ def torch(self) -> torch.dtype: ... torch.bfloat16: "bfloat16", } -# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} +_extended_torch_dtypes = [ + ("float8_e4m3fn",), + ("float8_e4m3fnuz",), + ("float8_e5m2",), + ("float8_e5m2fnuz",), + ("float8_e8m0fnu",), + ("float4_e2m1fnx2",), +] +for dtype_name_tuple in _extended_torch_dtypes: + dtype_name = dtype_name_tuple[0] + torch_dtype = getattr(torch, dtype_name, None) + if torch_dtype is not None: + _TORCH_DTYPE_TO_STR[torch_dtype] = dtype_name + + +_CANONICAL_TO_DISPLAY_STR = { + "double": "float64", + "float": "float32", + "int": "int32", + "long": "int64", + "short": "int16", + "uint": "uint32", + "ulong": "uint64", +} + +_STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} # _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()} @@ -76,7 +101,9 @@ def torch(self) -> torch.dtype: ... _STR_TO_TVM_DTYPE_CALL = { "bool": "Boolean", + "int4": "Int4", "int8": "Int8", + "int16": "Int16", "int32": "Int32", "int64": "Int64", "uint8": "UInt8", @@ -127,12 +154,20 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var return call(expr, is_size_var) +def __dtype_as_torch__(self: dtype) -> torch.dtype: + """Convert TileLang dtype to PyTorch dtype.""" + dtype_str = str(self) + if dtype_str in _STR_TO_TORCH_DTYPE: + return _STR_TO_TORCH_DTYPE[dtype_str] + raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. 
Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}") + + __orig_dtype_new = dtype.__new__ def __dtype_new__(cls, value: AnyDType) -> dtype: if isinstance(value, str): - return __orig_dtype_new(cls, value) + return __orig_dtype_new(cls, _CANONICAL_TO_DISPLAY_STR.get(value, value)) elif value in _DTYPE_TO_STR: return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) else: @@ -142,6 +177,7 @@ def __dtype_new__(cls, value: AnyDType) -> dtype: dtype.__call__ = __dtype_call__ dtype.__new__ = __dtype_new__ +dtype.as_torch = __dtype_as_torch__ def get_tvm_dtype(value: AnyDType) -> dtype: @@ -155,10 +191,12 @@ def get_tvm_dtype(value: AnyDType) -> dtype: class bool(dtype): ... class short(dtype): ... class int(dtype): ... + class uint(dtype): ... class long(dtype): ... class half(dtype): ... class float(dtype): ... class double(dtype): ... + class int4(dtype): ... class int8(dtype): ... class int16(dtype): ... class int32(dtype): ... @@ -320,10 +358,12 @@ class bfloat16(dtype): ... bool = dtype("bool") short = dtype("int16") int = dtype("int32") + uint = dtype("uint32") long = dtype("int64") half = dtype("float16") float = dtype("float32") double = dtype("float64") + int4 = dtype("int4") int8 = dtype("int8") int16 = dtype("int16") int32 = dtype("int32") @@ -484,10 +524,12 @@ class bfloat16(dtype): ... "bool", "short", "int", + "uint", "long", "half", "float", "double", + "int4", "int8", "int16", "int32", diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py index e68c11674..7ae836bc8 100644 --- a/tilelang/layout/gemm_sp.py +++ b/tilelang/layout/gemm_sp.py @@ -31,10 +31,20 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl block_k = 128 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2) - if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]: + if mma_dtype not in [ + T.float16, + T.bfloat16, + T.float32, + T.int8, + T.float8_e4m3, + T.float8_e4m3fn, + T.float8_e4m3fnuz, + T.float8_e5m2, + T.float8_e5m2fnuz, + ]: raise NotImplementedError(f"Unsupported dtype: {mma_dtype}") - if buffer.dtype not in ["uint8", "int8"]: + if buffer.dtype not in [T.uint8, T.int8]: raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}") bits_map = { @@ -43,7 +53,10 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl "float32": 32, "int8": 8, "float8_e4m3": 8, + "float8_e4m3fn": 8, + "float8_e4m3fnuz": 8, "float8_e5m2": 8, + "float8_e5m2fnuz": 8, } # ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117 @@ -112,10 +125,10 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): buffer: metadata buffer shape, for sm80 it should be a 16bit type """ - if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]: + if mma_dtype in [T.float16, T.bfloat16] and buffer.dtype not in [T.uint16, T.int16]: raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}") - if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]: + if mma_dtype in ["float8_e4m3", "float8_e5m2", T.int8, T.uint8] and buffer.dtype not in [T.uint32, T.int32]: raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}") m, 
k = buffer.shape @@ -134,7 +147,7 @@ def ColumnMajorInterleaved(i: int, j: int) -> int: return T.Layout(buffer.shape, ColumnMajorInterleaved) -def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", arch: str | None = None, **extra_args): +def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = T.float16, arch: str | None = None, **extra_args): if arch is None: arch = nvcc.get_target_compute_version() diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py index e0788dab4..6f1f457d1 100644 --- a/tilelang/quantize/lop3.py +++ b/tilelang/quantize/lop3.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from typing import Literal +from tilelang import language as T decode_i4_to_f16 = """ template @@ -1088,10 +1089,10 @@ def get_lop3_intrin_group( - out_dtype: Literal["float16", "int8", "int4"], - source_format: Literal["int", "uint"] = "uint", + out_dtype: Literal[T.float16, T.int8, T.int4], + source_format: Literal[T.int, T.uint] = T.uint, source_bit: int = 4, - storage_dtype: Literal["int32", "int8"] = "int8", + storage_dtype: Literal[T.int32, T.int8] = T.int8, with_scaling: bool = False, with_zeros: bool = False, zeros_mode: Literal["original", "rescale", "quantized"] = "original", @@ -1104,10 +1105,10 @@ def get_lop3_intrin_group( Parameters ---------- - in_dtype : Literal["int8"] + in_dtype : Literal[T.int8] The data type of the input. It should be "int8". - out_dtype : Literal["float16", "int8", "int4"] + out_dtype : Literal[T.float16, T.int8, T.int4] The data type of the output. It can be either "float16" or "int8" or "int4". storage_nbit : int, optional @@ -1130,18 +1131,17 @@ def get_lop3_intrin_group( Dict[str, str] A dictionary mapping the names of the intrinsics to their corresponding implementations. """ - assert out_dtype in ["float16", "int8", "int4"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ." + out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype) + assert out_dtype in [T.float16, T.int8, T.int4], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ." - dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"} + dtype_mapping = {T.float16: "f16", T.int4: "i4", T.int8: "i8", T.int32: "i32"} target_dtype = dtype_mapping[out_dtype] - if source_format not in ["int", "uint"]: - raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.") - if with_zeros and source_format == "int": + if source_format not in [T.int, T.uint]: + raise ValueError(f"Invalid source_format. 
Expected 'int' or 'uint', but got {source_format}, {type(source_format)}.") + if with_zeros and source_format == T.int: raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}") - source_symbol = "i" if source_format == "int" else "u" - import_c_map = { "i4_to_f16": decode_i4_to_f16, "i2_to_f16": decode_i2_to_f16, @@ -1176,15 +1176,15 @@ def get_lop3_intrin_group( if is_ladder_stage3: key += "_offset" - if out_dtype == "float16": + if out_dtype == T.float16: d4f = "f16" - elif out_dtype == "int8": + elif out_dtype == T.int8: d4f = "i8s" - elif out_dtype == "int4": + elif out_dtype == T.int4: d4f = "i4s" else: raise ValueError(f"Unsupported target dtype: {target_dtype}") - source_symbol = "u" if source_format == "uint" else "s" + source_symbol = "u" if source_format == T.uint else "s" func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}" if with_scaling: func_name += "_scale" diff --git a/tilelang/quantize/mxfp.py b/tilelang/quantize/mxfp.py index e5c472cb1..dd7100a62 100644 --- a/tilelang/quantize/mxfp.py +++ b/tilelang/quantize/mxfp.py @@ -1,4 +1,5 @@ from typing import Literal +from tilelang import language as T # Implementation asm for fp4 to bf16, using twiddling # Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18 @@ -49,10 +50,10 @@ def get_mxfp_intrin_group( - out_dtype: Literal["float16", "bfloat16"] = "bfloat16", - source_format: Literal["int", "uint"] = "uint", + out_dtype: Literal[T.float16, T.bfloat16] = T.bfloat16, + source_format: Literal[T.int, T.uint] = T.uint, source_bit: int = 4, - storage_dtype: Literal["int32", "int8", "uint8"] = "uint8", + storage_dtype: Literal[T.int32, T.int8, T.uint8] = T.uint8, use_twiddling: bool = False, ) -> dict[str, str]: """ @@ -65,10 +66,10 @@ def get_mxfp_intrin_group( `_twiddling`). Parameters: - out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16". + out_dtype: Target floating-point type for decoded values; either T.float16 or T.bfloat16. source_format: Integer source representation; "int" or "uint". source_bit: Bit width of the packed source format (e.g., 4). - storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8"). + storage_dtype: Underlying storage integer dtype (one of T.int32, T.int8, T.uint8). use_twiddling: When True, select the twiddling variant of the decoding intrinsic. Returns: @@ -80,11 +81,12 @@ def get_mxfp_intrin_group( AssertionError: if out_dtype, source_format, or storage_dtype are not supported. KeyError: if the constructed key does not match any available C source implementation. """ - assert out_dtype in ["float16", "bfloat16"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." - assert source_format in ["int", "uint"], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." - assert storage_dtype in ["int32", "int8", "uint8"], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." + out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype) + assert out_dtype in [T.float16, T.bfloat16], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." + assert source_format in [T.int, T.uint], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." + assert storage_dtype in [T.int32, T.int8, T.uint8], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." 
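Because `get_mxfp_intrin_group` now funnels its dtype arguments through `T.dtype` before the asserts above, object and string spellings select the same intrinsic group. A hedged call sketch (import path taken from this diff's file layout):

```
import tilelang.language as T
from tilelang.quantize.mxfp import get_mxfp_intrin_group

# Both calls normalize to the same lookup key (e.g. "fp4_to_bf16").
group_obj = get_mxfp_intrin_group(out_dtype=T.bfloat16, storage_dtype=T.uint8)
group_str = get_mxfp_intrin_group(out_dtype="bfloat16", storage_dtype="uint8")
assert group_obj.keys() == group_str.keys()
```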
- dtype_map = {"float16": "f16", "bfloat16": "bf16"} + dtype_map = {T.float16: "f16", T.bfloat16: "bf16"} key = f"fp{source_bit}_to_{dtype_map[out_dtype]}" if use_twiddling: key += "_twiddling" diff --git a/tilelang/quantize/quantization.py b/tilelang/quantize/quantization.py index db9d2349d..13552f674 100644 --- a/tilelang/quantize/quantization.py +++ b/tilelang/quantize/quantization.py @@ -22,6 +22,7 @@ # pylint: disable=invalid-name,missing-function-docstring,unused-variable """TIR computation utilities for quantization.""" +from tilelang import language as T from tilelang import tvm as tvm from tvm import tir @@ -36,7 +37,7 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa. Behavior: - - Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated). + - Validates `nbit == 4`, `dtype == T.bfloat16`, and `val.dtype == T.uint8` (AssertionError if violated). - Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`). - Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0. - Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent, @@ -49,27 +50,27 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale - val: uint8 expression containing packed fields. - pos: index of the field within `val` (0-based); used to compute the bit shift. - scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression). - - dtype: must be "bfloat16". + - dtype: must be T.bfloat16. Returns: - A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value. 
""" assert nbit == 4 - assert dtype == "bfloat16" - assert val.dtype == "uint8" - mask = tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") + assert dtype == T.bfloat16 + assert val.dtype == T.uint8 + mask = tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16) # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 - e_bf16 = e_f4 + tir.const(126, "uint16") + e_bf16 = e_f4 + tir.const(126, T.uint16) # Scale is the exponential part, within the representation of uint8 # To handle the overflow, we use the max function to limit the exponential part to 8 bits - e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) - m_f4 = f4 & tir.const(1, "uint16") - val_bf16 = tir.reinterpret("bfloat16", - ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) - | (m_f4 << tir.const(6, "uint16"))).astype("uint16")) + e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16)) + m_f4 = f4 & tir.const(1, T.uint16) + val_bf16 = tir.reinterpret(T.bfloat16, + ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) + | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16)) return val_bf16 def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): @@ -88,29 +89,29 @@ def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_eve Returns: tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits). """ - mask = tir.const((1 << 16) - 1, "uint32") + mask = tir.const((1 << 16) - 1, T.uint32) res = [] for data in [v0, v1]: - u32_val = tir.reinterpret("uint32", data) + u32_val = tir.reinterpret(T.uint32, data) if round_to_even: - rounding_bias = ((u32_val >> tir.const(16, "uint32")) - & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") + rounding_bias = ((u32_val >> tir.const(16, T.uint32)) + & tir.const(1, T.uint32)) + tir.const(0x7FFF, T.uint32) u32_val += rounding_bias - res.append((u32_val >> tir.const(16, "uint32")) & mask) - return res[0] | (res[1] << tir.const(16, "uint32")) + res.append((u32_val >> tir.const(16, T.uint32)) & mask) + return res[0] | (res[1] << tir.const(16, T.uint32)) def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): - mask = tir.const((1 << 16) - 1, "uint32") + mask = tir.const((1 << 16) - 1, T.uint32) x0 = x & mask x1 = (x >> 16) & mask - return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) + return (tir.reinterpret(T.float32, x << tir.const(16, T.uint32)) for x in [x0, x1]) def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == "uint32" - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) + assert val.dtype == T.uint32 + mask = tvm.tir.const((1 << nbit) - 1, T.uint32) + return tir.Cast(dtype, (val >> (pos * nbit).astype(T.uint32)) & mask) def _tir_packed_uint_to_uint_to_float(storage_nbit: int): @@ -119,7 +120,7 @@ def _tir_packed_uint_to_uint_to_float(storage_nbit: int): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" max_int_value = (1 << (nbit - 1)) - 1 - return ((val 
>> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const( (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) return f_convert @@ -130,74 +131,74 @@ def _tir_packed_int_to_int_to_float(storage_nbit: int): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - mask = tir.const((1 << nbit) - 1, "int32") - unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + mask = tir.const((1 << nbit) - 1, T.int32) + unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask return tir.Cast( - dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32)) return f_convert def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): - assert val.dtype == "float32" - val_u32 = tir.reinterpret("uint32", val) + assert val.dtype == T.float32 + val_u32 = tir.reinterpret(T.uint32, val) # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) # e_f32 == 120 -> e_f4 = 1 # e_f32 < 120 -> e_f4 = 0 - m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") - e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") - s = (val_u32 >> tir.const(31, "uint32")) + m_h = (val_u32 >> tir.const(22, T.uint32)) & tir.const(1, T.uint32) + e_f32 = (val_u32 >> tir.const(23, T.uint32)) & tir.const(255, T.uint32) + s = (val_u32 >> tir.const(31, T.uint32)) e_f4 = tir.Select( - e_f32 > tir.const(120, "uint32"), - tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), - tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), - tir.const(0, "uint32"))) - return (s << tir.const(3, "uint32")) | e_f4 + e_f32 > tir.const(120, T.uint32), + tir.Min(e_f32 - tir.const(120, T.uint32) + m_h, tir.const(7, T.uint32)), + tir.Select(e_f32 == tir.const(120, T.uint32), tir.const(1, T.uint32), + tir.const(0, T.uint32))) + return (s << tir.const(3, T.uint32)) | e_f4 def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): - assert val.dtype == "float16" - val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) - m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") - e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") - s = (val_u32 >> tir.const(15, "uint32")) + assert val.dtype == T.float16 + val_u32 = tir.Cast(T.uint32, tir.reinterpret(T.uint16, val)) + m_h = (val_u32 >> tir.const(9, T.uint32)) & tir.const(1, T.uint32) + e_f16 = (val_u32 >> tir.const(10, T.uint32)) & tir.const(31, T.uint32) + s = (val_u32 >> tir.const(15, T.uint32)) e_f4 = tir.Select( - e_f16 > tir.const(8, "uint32"), - tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), - tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) - return (s << tir.const(3, "uint32")) | e_f4 + e_f16 > tir.const(8, T.uint32), + tir.Min(e_f16 - tir.const(8, T.uint32) + m_h, tir.const(7, T.uint32)), + tir.Select(e_f16 == tir.const(8, T.uint32), tir.const(1, T.uint32), tir.const(0, T.uint32))) + return (s << tir.const(3, T.uint32)) | e_f4 def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "float32" - assert val.dtype == "uint32" + assert dtype == T.float32 + assert val.dtype == T.uint32 # e_f4 == 0 -> e_f32 = 0 # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | 
(1111000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask - s = f4 >> tir.const(3, "uint32") - e_f4 = f4 & tir.const(7, "uint32") - e_f32 = e_f4 | tir.const(120, "uint32") - val_f32 = tir.reinterpret("float32", - (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) - return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) + mask = tvm.tir.const((1 << nbit) - 1, T.uint32) + f4 = (val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & mask + s = f4 >> tir.const(3, T.uint32) + e_f4 = f4 & tir.const(7, T.uint32) + e_f32 = e_f4 | tir.const(120, T.uint32) + val_f32 = tir.reinterpret(T.float32, + (e_f32 | (s << tir.const(8, T.uint32))) << tir.const(23, T.uint32)) + return tir.Select(e_f4 == tir.const(0, T.uint32), tir.const(0, T.float32), val_f32) def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert nbit == 4 - assert dtype == "float16" - assert val.dtype == "uint32" + assert dtype == T.float16 + assert val.dtype == T.uint32 # e_f4 == 0 -> e_f16 = 0 # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint16") - f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask - s = f4 >> tir.const(3, "uint16") - e_f4 = f4 & tir.const(7, "uint16") - e_f16 = e_f4 | tir.const(8, "uint16") - val_f16 = tir.reinterpret("float16", - ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16")) - return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16) + mask = tvm.tir.const((1 << nbit) - 1, T.uint16) + f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask + s = f4 >> tir.const(3, T.uint16) + e_f4 = f4 & tir.const(7, T.uint16) + e_f16 = e_f4 | tir.const(8, T.uint16) + val_f16 = tir.reinterpret(T.float16, + ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16)).astype(T.uint16)) + return tir.Select(e_f4 == tir.const(0, T.uint16), tir.const(0, T.float16), val_f16) def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) @@ -210,37 +211,37 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: st s = f4 >> tir.const(3, storage_dtype) e_f4 = f4 & tir.const(7, storage_dtype) e_f16 = e_f4 | tir.const(8, storage_dtype) - val_f16 = tir.reinterpret("float16", - ((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16")) - return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16) + val_f16 = tir.reinterpret(T.float16, + ((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype(T.uint16)) + return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, T.float16), val_f16) return f_convert def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 - assert dtype == "float16" - s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") - e4 = val & tir.const(0x40, "uint16") - prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), - tir.const(0x4000, "uint16")) - e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix - return tir.reinterpret("float16", s_f16 | e_f16) + assert dtype == T.float16 + s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16) + e4 = val & tir.const(0x40, T.uint16) + prefix = tir.Select(e4 == tir.const(0, 
T.uint16), tir.const(0x2000, T.uint16), + tir.const(0x4000, T.uint16)) + e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | prefix + return tir.reinterpret(T.float16, s_f16 | e_f16) def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 - assert dtype == "float16" - s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") - e4 = val & tir.const(0x40, "uint16") - e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) - e_f16 = e_f16 ^ tir.const(0x2000, "uint16") - return tir.reinterpret("float16", s_f16 | e_f16) + assert dtype == T.float16 + s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16) + e4 = val & tir.const(0x40, T.uint16) + e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | (e4 << tir.const(8, T.uint16)) | (e4 << tir.const(7, T.uint16)) + e_f16 = e_f16 ^ tir.const(0x2000, T.uint16) + return tir.reinterpret(T.float16, s_f16 | e_f16) def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 - assert dtype == "float16" - return tir.reinterpret("float8_e5m2", val).astype("float16") + assert dtype == T.float16 + return tir.reinterpret("float8_e5m2", val).astype(T.float16) def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): @@ -249,7 +250,7 @@ def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" max_int_value = (1 << (nbit - 1)) - return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const( (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) return f_convert @@ -283,10 +284,10 @@ def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - mask = tir.const((1 << nbit) - 1, "int32") - unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + mask = tir.const((1 << nbit) - 1, T.int32) + unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask return tir.Cast( - dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32)) return f_convert diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 5e4899b5e..7d31ae46d 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -2,6 +2,7 @@ from tilelang import tvm as tvm from tvm.target import Target from tvm import tir +from tilelang import language as T from tilelang.utils.language import is_shared, is_fragment from tilelang.tileop.base import GemmWarpPolicy from tvm.ir.base import Node @@ -121,7 +122,7 @@ def policy(self) -> GemmWarpPolicy: @property def mbarptr(self) -> PrimExpr: - return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32")) + return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, T.uint32)) @property def mbar(self) -> tir.Buffer: @@ -131,7 +132,7 @@ def mbar(self) -> tir.Buffer: def C_coords(self): coords = getattr(self.gemm_node, "cCoords", None) if coords is None or len(coords) == 0: - zero = tvm.tir.const(0, "int32") + zero 
= tvm.tir.const(0, T.int32) return [zero, zero] return [coords[i] for i in range(len(coords))] diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index f93a403eb..de3e72143 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -98,7 +98,7 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype not in ["float32", "float16"]: + if accum_dtype not in [str(T.float32), str(T.float16)]: raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 8a2d250bb..b42ccd7ed 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -100,10 +100,10 @@ class PassConfigKey(str, Enum): such as `dst[i] = f(src[i])`, avoiding implicit aliasing: ``` - read = T.allocate([1], "int32", "local.var") - write = T.allocate([1], "int32", "local.var") - read_buf = T.Buffer((1,), "int32", data=read, scope="local.var") - write_buf = T.Buffer((1,), "int32", data=write, scope="local.var") + read = T.allocate([1], T.int32, "local.var") + write = T.allocate([1], T.int32, "local.var") + read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var") + write_buf = T.Buffer((1,), T.int32, data=write, scope="local.var") write_buf[0] = read_buf[0] * 2 f(write_buf[0]) ``` @@ -113,8 +113,8 @@ class PassConfigKey(str, Enum): like: ``` - read = T.allocate([1], "int32", "local.var") - read_buf = T.Buffer((1,), "int32", data=read, scope="local.var") + read = T.allocate([1], T.int32, "local.var") + read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var") read_buf[0] = read_buf[0] * 2 f(read_buf[0]) ```
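Taken together, the dtype plumbing in this diff can be exercised with a short sketch like the one below (assuming the torch mapping covers `float16` and that identical dtype strings compare equal):

```
import torch
import tilelang.language as T

# Legacy aliases are canonicalized ("float" -> float32, "uint" -> uint32).
assert T.dtype("float") == T.float32
assert T.dtype("uint") == T.uint32

# KernelParam.torch_dtype now goes through dtype.as_torch().
assert T.float16.as_torch() == torch.float16

# Newly exported aliases.
print(T.uint, T.int4)  # uint32 int4
```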