From c96868cc904fe716ff033cae5e289fe241d3671b Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Thu, 19 Dec 2024 15:48:08 +0000 Subject: [PATCH] [TUTORIAL] Remove grouped gemm simulation from 09-persistent-matmul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As discussed in the [multi-buffering PR], the persistent matmul should be kept as an apples-to-apples performance comparison. In particular, the existing perf results makes tensor-descriptor look bad. With this updated tutorial I get results like (`K=4096, prec=fp8`): ``` ├─ 1278.215 4731.062 cublas [M=8192, N=8192, K=4096] │ └─ nan 4731.062 sm90_xmma_gemm_e4m3e4m3_e4m3f32_f32_tn_n_tilesize128x128x128_warpgroupsize1x1x1_bias_f16_execute_segment_k_off_kernel__5x_cublas ├─ 1208.855 454.774 matmul_kernel [M=8192, N=8192, K=4096] ├─ 1285.360 427.706 matmul_kernel_persistent [M=8192, N=8192, K=4096] ├─ 1330.667 413.143 matmul_kernel_descriptor_persistent [M=8192, N=8192, K=4096] └─ 1347.254 408.057 matmul_kernel_tma_persistent [M=8192, N=8192, K=4096] ``` So on H100 tensor descriptor is a 3.5% flops uplift over the plain persistent matmul vs. 4.8% for host-side TMA. For the same shapes with fp16 I see a 13% uplift from tensor descriptor vs. 13.4% from host-side TMA. [multi-buffering PR]: https://github.com/triton-lang/triton/pull/5290#discussion_r1870067182 --- python/tutorials/09-persistent-matmul.py | 59 ++++-------------------- 1 file changed, 10 insertions(+), 49 deletions(-) diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 4b4a08857d7a..eec0c6248c0f 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -50,8 +50,6 @@ def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" - if "tiles_per_update" in args: - ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}, tiles_per_update={args['tiles_per_update']:02}]" if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: @@ -376,8 +374,7 @@ def matmul_tma_persistent(a, b): @triton.jit(launch_metadata=_matmul_launch_metadata) -def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # - a_ptr, b_ptr, c_ptr, # +def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr, # M, N, K, # BLOCK_SIZE_M: tl.constexpr, # BLOCK_SIZE_N: tl.constexpr, # @@ -417,7 +414,6 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # tile_id = start_pid - NUM_SMS ki = -1 - ni = -1 pid_m = 0 pid_n = 0 @@ -427,36 +423,10 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # num_pid_in_group = GROUP_SIZE_M * num_pid_n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Create an opaque value to prevent the descriptor creation from being - # hoisted out of the loop - zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1) for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) if ki == 0: - ni += 1 - - # Simulate a grouped gemm - if ni == tiles_per_update: - a_desc = tl._experimental_make_tensor_descriptor( - a_ptr + zero, - shape=[M, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], - ) - b_desc = tl._experimental_make_tensor_descriptor( - b_ptr + zero, - shape=[N, K], - strides=[K, 1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], - ) - c_desc = tl._experimental_make_tensor_descriptor( - c_ptr + zero, - shape=[M, N], - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) - ni = 0 tile_id += NUM_SMS group_id = tile_id // num_pid_in_group @@ -482,8 +452,7 @@ def matmul_kernel_descriptor_persistent(tiles_per_update: tl.constexpr, # accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) -def matmul_descriptor_persistent(a, b, tiles_per_update): - # Autotuner does not work with TMA. Use manual config. +def matmul_descriptor_persistent(a, b): configs = { torch.float8_e4m3fn: { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, @@ -513,7 +482,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) matmul_kernel_descriptor_persistent[grid]( - tiles_per_update, # a, b, c, # M, N, K, # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # @@ -570,7 +538,7 @@ def bench_fn(reps, warmup_reps, fn, *args): fn(*args) -def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): +def bench(K, dtype, reps=1000, warmup_reps=10000): M = 8192 N = 8192 a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) @@ -586,10 +554,10 @@ def bench(K, dtype, tiles_per_update, reps=1000, warmup_reps=10000): bench_fn(reps, warmup_reps, matmul_persistent, a, b.T) if supports_tma(): bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b) - bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b, tiles_per_update) + bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b) -def validate(M, N, K, dtype, tiles_per_update): +def validate(M, N, K, dtype): a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype) b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype) b = b.T.contiguous() @@ -599,7 +567,7 @@ def validate(M, N, K, dtype, tiles_per_update): naive_result = matmul(a, b.T) persistent_result = matmul_persistent(a, b.T) tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None - descriptor_persistent_result = matmul_descriptor_persistent(a, b, tiles_per_update) if supports_tma() else None + descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None if torch_result is not None: naive_vs_torch = "✅" if torch.allclose(naive_result.to(torch.float16), torch_result.to(torch.float16), @@ -624,7 +592,7 @@ def validate(M, N, K, dtype, tiles_per_update): if tma_persistent_result is not None: print(f"TMA persistent: {naive_vs_tma_persistent} ", end="") if descriptor_persistent_result is not None: - print(f"Device TMA persistent: {naive_vs_descriptor_persistent} ", end="") + print(f"Tensor descriptor persistent: {naive_vs_descriptor_persistent} ", end="") print() @@ -644,13 +612,6 @@ def show_profile(precision, profile_name): parser.add_argument("-K", type=int, required=False, default=512) parser.add_argument("--K_range", type=int, nargs=2) parser.add_argument("--K_step", type=int, default=512) - parser.add_argument( - "--tiles_per_update", - type=int, - default=1, - help= - "Number of output tiles calculated for each update of the tma descriptor in matmul_descriptor_persistent_kernel", - ) parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16") args = parser.parse_args() @@ -666,11 +627,11 @@ def show_profile(precision, profile_name): torch.manual_seed(0) - validate(32, 32, 32, dtype, args.tiles_per_update) - validate(8192, 8192, 512, dtype, args.tiles_per_update) + validate(32, 32, 32, dtype) + validate(8192, 8192, 512, dtype) proton.start("matmul", hook="triton") for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): - bench(K, dtype, args.tiles_per_update) + bench(K, dtype) proton.finalize() show_profile(args.prec, "matmul")