diff --git a/python/tutorials/10-warp-specialized-matmul.py b/python/tutorials/10-warp-specialized-matmul.py index ed51de5809..b4d5fa9b88 100644 --- a/python/tutorials/10-warp-specialized-matmul.py +++ b/python/tutorials/10-warp-specialized-matmul.py @@ -105,14 +105,29 @@ def get_tma_descriptor_kernel_param(self, name): "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 2, }, num_stages=2, num_warps=4, num_consumer_groups=2, num_buffers_warp_spec=3, ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=3, + num_warps=4, + num_consumer_groups=0, # disable warp specialization + num_buffers_warp_spec=3, + ), ], key=["M", "N", "K"], + use_cuda_graph=True, ) @triton.jit def matmul_persistent_tma_ws_cooperative_kernel( @@ -126,6 +141,7 @@ def matmul_persistent_tma_ws_cooperative_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # + NUM_CONSUMER_GROUPS: tl.constexpr, ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -171,7 +187,7 @@ def matmul_persistent_tma_ws_cooperative_kernel( c = accumulator.to(tl.float16) - with tl.async_task([1, 2]): + with tl.async_task([1, NUM_CONSUMER_GROUPS]): tl._experimental_descriptor_store(c_ptr, c, [offs_am, offs_bn]) @@ -203,7 +219,7 @@ def grid(META): a.data_ptr(), M, K, - META["BLOCK_SIZE_M"] // 2, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], META["BLOCK_SIZE_K"], a.element_size(), ) @@ -222,7 +238,7 @@ def grid(META): c.data_ptr(), M, N, - META["BLOCK_SIZE_M"] // 2, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], META["BLOCK_SIZE_N"], c.element_size(), )