diff --git a/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py b/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py new file mode 100644 index 00000000000..4926bb21457 --- /dev/null +++ b/benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py @@ -0,0 +1,366 @@ +# python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8 +import argparse + +import torch +import triton +from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm +from fbgemm_grouped_gemm import ( + grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise, +) +from transformers import AutoConfig + +from sglang.srt.layers.moe.ep_moe.kernels import ( + grouped_gemm_triton as sglang_grouped_gemm, +) + + +def get_model_config(model_name: str, tp_size: int): + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + if config.architectures[0] == "DbrxForCausalLM": + num_groups = config.ffn_config.moe_num_experts + intermediate_size = config.ffn_config.ffn_hidden_size + elif config.architectures[0] == "JambaForCausalLM": + num_groups = config.num_experts + intermediate_size = config.intermediate_size + elif config.architectures[0] == "Qwen2MoeForCausalLM": + num_groups = config.num_experts + intermediate_size = config.moe_intermediate_size + elif config.architectures[0] == "Qwen3MoeForCausalLM": + num_groups = config.num_experts + intermediate_size = config.moe_intermediate_size + elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + num_groups = ( + config.n_routed_experts + 1 + if config.architectures[0] in ["DeepseekV3ForCausalLM"] + else config.n_routed_experts + ) + intermediate_size = config.moe_intermediate_size + elif config.architectures[0] == "Llama4ForConditionalGeneration": + num_groups = config.text_config.num_local_experts + intermediate_size = config.text_config.intermediate_size + elif config.architectures[0] in [ + "Grok1ForCausalLM", + "Grok1ImgGen", + "Grok1AForCausalLM", + ]: + num_groups = config.num_local_experts + intermediate_size = config.moe_intermediate_size + else: + num_groups = config.num_local_experts + intermediate_size = config.intermediate_size + + shape_configs = { + "num_groups": num_groups, + "hidden_size": config.hidden_size, + "intermediate_size": intermediate_size, + "dtype": config.torch_dtype, + } + print(f"{shape_configs=}") + return shape_configs + + +def create_test_data(batch_size, num_groups, hidden_size, intermediate_size): + torch.manual_seed(42) + + tokens_per_group = batch_size // num_groups + m_sizes = torch.full( + (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda" + ) + + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda") + + base_weights = torch.randn( + num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda" + ) + + w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size) + w_sglang = base_weights + + c_fbgemm = torch.empty( + batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda" + ) + c_sglang = torch.empty( + batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda" + ) + + seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda") + for i in range(1, num_groups + 1): + seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group + + weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda") + + return ( + x, + w_fbgemm, + w_sglang, + c_fbgemm, + c_sglang, + m_sizes, + seg_indptr, + weight_indices, + ) + + +def 
create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size): + torch.manual_seed(42) + + tokens_per_group = batch_size // num_groups + m_sizes = torch.full( + (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda" + ) + + x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda") + w_fp16 = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda" + ) + + x_fp8 = x_fp16.to(torch.float8_e4m3fn) + w_fp8 = w_fp16.to(torch.float8_e4m3fn) + + x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4 + w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4 + + return x_fp8, w_fp8, m_sizes, x_scale, w_scale + + +def get_benchmark_config(use_fp8_w8a8=False): + if use_fp8_w8a8: + return { + "line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"], + "line_names": ["FBGEMM Grouped GEMM FP8", "SGLang Grouped GEMM FP8"], + "styles": [("blue", "-"), ("red", "-")], + } + else: + return { + "line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"], + "line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"], + "styles": [("blue", "-"), ("green", "-")], + } + + +def run_benchmark( + model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/" +): + config = get_benchmark_config(use_fp8_w8a8) + + benchmark_config = triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + line_arg="provider", + line_vals=config["line_vals"], + line_names=config["line_names"], + styles=config["styles"], + ylabel="Time (ms)", + plot_name="grouped-gemm-performance", + args={}, + ) + + @triton.testing.perf_report(benchmark_config) + def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False): + print(f"Benchmarking {provider} with batch_size={batch_size}") + torch.cuda.manual_seed_all(0) + + num_groups = model_config["num_groups"] + hidden_size = model_config["hidden_size"] + intermediate_size = model_config["intermediate_size"] + + if provider == "fbgemm_grouped_gemm_fp8": + try: + test_data = create_fp8_test_data( + batch_size, num_groups, hidden_size, intermediate_size + ) + x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data + + def run_func(): + return fbgemm_grouped_gemm_fp8_rowwise( + x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True + ) + + except Exception as e: + print(f"FP8 not supported, skipping: {e}") + return float("inf"), float("inf"), float("inf") + else: + test_data = create_test_data( + batch_size, num_groups, hidden_size, intermediate_size + ) + ( + x, + w_fbgemm, + w_sglang, + c_fbgemm, + c_sglang, + m_sizes, + seg_indptr, + weight_indices, + ) = test_data + + if provider == "fbgemm_grouped_gemm": + + def run_func(): + return fbgemm_grouped_gemm( + x, w_fbgemm, m_sizes, use_fast_accum=True + ) + + else: + + def run_func(): + return sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, + ) + + for _ in range(10): + try: + run_func() + except Exception as e: + print(f"Error during warmup for {provider}: {e}") + return float("inf"), float("inf"), float("inf") + + torch.cuda.synchronize() + + try: + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles) + return ms, min_ms, max_ms + except Exception as e: + print(f"Error during benchmarking for {provider}: {e}") + return 
float("inf"), float("inf"), float("inf") + + dynamic_benchmark.run( + show_plots=True, + print_data=True, + save_path=save_path, + model_config=model_config, + use_fp8_w8a8=use_fp8_w8a8, + ) + + +def verify_correctness(model_config, use_fp8_w8a8): + print("Verifying correctness...") + batch_size = 128 + num_groups = model_config["num_groups"] + hidden_size = model_config["hidden_size"] + intermediate_size = model_config["intermediate_size"] + + test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size) + (x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = ( + test_data + ) + + try: + result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True) + + result_sglang = sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, + ) + + if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3): + print("✓ BF16 Correctness verification passed!") + else: + max_diff = torch.max(torch.abs(result_fbgemm - result_sglang)) + print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}") + return False + + if use_fp8_w8a8: + try: + fp8_data = create_fp8_test_data( + batch_size, num_groups, hidden_size, intermediate_size + ) + x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data + + result_fp8 = fbgemm_grouped_gemm_fp8_rowwise( + x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True + ) + + assert result_fp8.shape == (batch_size, intermediate_size) + print("✓ FP8 functionality test passed!") + except Exception as e: + print(f"FP8 test failed (possibly unsupported): {e}") + return False + + return True + + except Exception as e: + print(f"✗ Error during correctness verification: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark FBGEMM vs SGLang Grouped GEMM" + ) + parser.add_argument( + "--model", + type=str, + default="mistralai/Mixtral-8x7B-Instruct-v0.1", + help="Model name to get configuration from", + ) + parser.add_argument( + "--tp-size", type=int, default=1, help="Tensor parallelism size" + ) + parser.add_argument( + "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark" + ) + parser.add_argument( + "--save-path", + type=str, + default="./benchmark_grouped_gemm/", + help="Path to save benchmark results", + ) + parser.add_argument( + "--verify-correctness", + action="store_true", + help="Verify correctness before benchmarking", + ) + + args = parser.parse_args() + + try: + model_config = get_model_config(args.model, args.tp_size) + except Exception as e: + print(f"Failed to get model config: {e}") + print("Using default configuration...") + model_config = { + "num_groups": 8, + "hidden_size": 4096, + "intermediate_size": 14336, + "dtype": torch.bfloat16, + } + + print("Running benchmark with:") + print(f" num_groups: {model_config['num_groups']}") + print(f" hidden_size: {model_config['hidden_size']}") + print(f" intermediate_size: {model_config['intermediate_size']}") + print(f" use_fp8_w8a8: {args.use_fp8_w8a8}") + + if args.verify_correctness: + if not verify_correctness(model_config, args.use_fp8_w8a8): + print("Correctness verification failed. 
Exiting...") + return + + try: + run_benchmark( + model_config=model_config, + use_fp8_w8a8=args.use_fp8_w8a8, + save_path=args.save_path, + ) + except Exception as e: + print(f"Benchmark failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/benchmark/fbgemm/fbgemm_grouped_gemm.py b/benchmark/fbgemm/fbgemm_grouped_gemm.py new file mode 100644 index 00000000000..7d8568947f5 --- /dev/null +++ b/benchmark/fbgemm/fbgemm_grouped_gemm.py @@ -0,0 +1,1294 @@ +# Copy from https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import functools +import inspect +import sys +import warnings +from typing import Optional + +import torch +import triton # @manual +import triton.language as tl # @manual +from triton.runtime import driver # @manual + + +def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype: + """ + Maps torch dtype to triton dtype. + + Args: + dtype (torch.dtype): input dtype. + + Returns: + tl.dtype: triton dtype. + """ + if dtype == torch.float16: + return tl.float16 + elif dtype == torch.bfloat16: + return tl.bfloat16 + elif dtype == torch.float32: + return tl.float32 + elif dtype == torch.int32: + return tl.int32 + elif dtype == torch.float8_e4m3fn and torch.version.hip is None: + return tl.float8e4nv + else: + raise ValueError(f"Unsupported dtype {dtype}") + + +# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print( + "TMA benchmarks will be running with experimental grid constant TMA descriptor.", + file=sys.stderr, + ) +else: + print( + "TMA benchmarks will be running without grid constant TMA descriptor.", + file=sys.stderr, + ) + + +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_1d_tma_descriptor + ) + self.fill_2d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_2d_tma_descriptor + ) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 + ) + else: + self.cuda_descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 + ) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def 
fill_2d_tma_descriptor( + self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size + ): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + +_NV_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + num_ctas=num_ctas, + ) + for block_size_m in [64, 128] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] + for num_ctas in [1] +] + +_HAS_WS_SUPPORT = None + + +def _check_ws_support(): + if not hasattr(tl, "async_task"): + return False + config_signature = inspect.signature(triton.Config).parameters + if ( + "num_consumer_groups" not in config_signature + or "num_buffers_warp_spec" not in config_signature + ): + return False + if not HAS_TMA_DESC: + return False + return True + + +def _set_ws_support(): + global _HAS_WS_SUPPORT + if _HAS_WS_SUPPORT is None: + _HAS_WS_SUPPORT = _check_ws_support() + + +_set_ws_support() + +if _HAS_WS_SUPPORT: + _NV_WS_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "NUM_CONSUMER_GROUPS": max(1, num_consumer_groups), + "USE_TMA_LOAD_ON_SCALES": use_tma_load_on_scales, + "USE_TMA_STORE": use_tma_store, + }, + num_stages=num_stages, + num_warps=num_warps, + num_ctas=num_ctas, + num_consumer_groups=num_consumer_groups, + num_buffers_warp_spec=num_stages, + ) + for block_size_m in [64, 128, 256] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [2, 3, 4] + for num_warps in [4, 8, 16] + # TODO(shikaili): Resolve LLVM error. + for num_ctas in [1] + for num_consumer_groups in [0, 2] + for use_tma_load_on_scales in [True, False] + # TODO(shikaili): Resolve compatibility with ws. 
+ for use_tma_store in [False] + ] +else: + _NV_WS_CONFIGS = _NV_CONFIGS + + +_AMD_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "waves_per_eu": waves_per_cu, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m in [32, 64, 128] + for block_size_n in [32, 64, 128, 256] + for block_size_k in [128, 256] + for num_stages in [1, 2] + for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)] + for matrix_instr_nonkdim in [16] +] + + +def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs): + device = torch.cuda.current_device() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + if dtsize is None: + dtsize = named_args["c_ptr"].element_size() + if dtype is None: + dtype = named_args["c_ptr"].dtype + + pruned_configs = [] + for config in configs: + kw = config.kwargs + ( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + num_consumer_groups, + use_tma_load_on_scales, + ) = ( + kw["BLOCK_SIZE_M"], + kw["BLOCK_SIZE_N"], + kw["BLOCK_SIZE_K"], + config.num_stages, + config.num_warps, + config.num_consumer_groups, + kw.get("USE_TMA_LOAD_ON_SCALES", False), + ) + G, M, N, K = ( + named_args["G"], + named_args["M_BUCKET"], + named_args["N"], + named_args["K"], + ) + + # 1. make sure we have enough smem + max_shared_memory = driver.active.utils.get_device_properties(device)[ + "max_shared_mem" + ] + if torch.version.hip: + required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize + else: + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + use_warp_specialization = num_consumer_groups >= 1 + + M_PER_GROUP = M // G + MIN_M_TILES = 32 if torch.version.hip else 64 + # 2. make sure we don't load M tiles that are too big + if ( + not use_warp_specialization + and BLOCK_M > MIN_M_TILES + and BLOCK_M > (M_PER_GROUP * 2) + ): + continue + # 3. make sure we don't load N tiles that are too small + if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2): + continue + + num_sm = driver.active.utils.get_device_properties(device)[ + "multiprocessor_count" + ] + N_TILES = N // BLOCK_N + MIN_N_TILES = 32 if torch.version.hip else 64 + # 4. make sure we don't load N tiles that are too big + if ( + not use_warp_specialization + and BLOCK_N > MIN_N_TILES + and M * N_TILES < num_sm + ): + continue + # 5. make sure we don't load N tiles that are too small + if BLOCK_N < 128 and M * N_TILES > 2 * num_sm: + continue + + # 6. make sure K can be evenly divided + if K % BLOCK_K != 0: + continue + + # 7. 
make sure we can partition for ws + if use_warp_specialization: + if num_warps != 4: + continue + + # "tritongpu-warp-spec-data-partition" + m_slice = BLOCK_M // num_consumer_groups + n_slice = BLOCK_N // num_consumer_groups + if m_slice < 64 and n_slice < 256: + continue + + if dtsize >= 2: + if use_tma_load_on_scales: + continue + pruned_configs.append(config) + + return pruned_configs + + +@triton.autotune( + configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, + restore_value=["c_ptr"], # restore for scatter_add fusion +) +@triton.jit +def _fbgemm_grouped_gemm( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + scatter_add_indices, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET, + N: tl.constexpr, + K: tl.constexpr, + NUM_SMS: tl.constexpr, + FUSE_SCATTER_ADD: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + USE_FAST_ACCUM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_CONSUMER_GROUPS: tl.constexpr, +) -> None: + tl.static_assert( + not (FUSE_SCATTER_ADD and USE_TMA_STORE), + "Cannot fuse scatter add with TMA store!", + ) + + tidx = tl.program_id(0) + + dtype: tl.dtype = c_ptr.dtype.element_ty + TMA_SIZE: tl.constexpr = tl.constexpr(128) + if USE_TMA_STORE: + c_desc_ptr = workspace + tidx * TMA_SIZE + else: + c_desc_ptr = None + + M_end_offset = 0 + M_end_offset = M_end_offset.to(tl.int64) + iterated_tiles = 0 + iterated_tiles = iterated_tiles.to(tl.int64) + for g in tl.range(G): + # Move across groups + m_size = tl.load(m_sizes + g) + + if m_size > 0: + M_start_offset = M_end_offset + M_end_offset = M_start_offset + m_size + N_start_offset = g.to(tl.int64) * N + n_size = N + + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + + if USE_TMA_STORE: + # pyre-ignore + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start_offset * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_ptr.dtype.element_ty, + ) + # pyre-ignore + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. 
+ tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + tl.static_assert(K % BLOCK_SIZE_K == 0) + if USE_TMA_LOAD: + m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + for k_offset in range(0, K, BLOCK_SIZE_K): + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + else: + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = ( + a_desc_ptr + + (M_start_offset + offs_am[:, None]) * K + + offs_k[None, :] + ) + b_ptrs = ( + b_desc_ptr + + (N_start_offset + offs_bn[:, None]) * K + + offs_k[None, :] + ) + for k_offset in range(0, K, BLOCK_SIZE_K): + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + accumulator += tl.dot(a, b.T) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + if USE_TMA_STORE: + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_ptr.dtype.element_ty), + [m_offset, n_offset], + ) + elif FUSE_SCATTER_ADD: + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None] and offs_bn[None, :] < n_size, + sem="relaxed", + ) + else: + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, + ) + tidx += NUM_SMS + + iterated_tiles += num_tiles + + +# TODO(shikaili): Too much code duplication. Need to refactor. 
+@triton.autotune( + configs=_NV_WS_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, + restore_value=["c_ptr"], # restore for scatter_add fusion +) +@triton.jit +def _fbgemm_grouped_gemm_ws( + a_desc_ptr, + b_desc_ptr, + c_ptr, + workspace, + scatter_add_indices, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + NUM_SMS: tl.constexpr, + FUSE_SCATTER_ADD: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_FAST_ACCUM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_CONSUMER_GROUPS: tl.constexpr, + USE_TMA_LOAD_ON_SCALES: tl.constexpr, + USE_TMA_STORE: tl.constexpr, +) -> None: + tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!") + tl.static_assert(not USE_TMA_LOAD_ON_SCALES, "Not supported!") + tl.static_assert( + not (FUSE_SCATTER_ADD and USE_TMA_STORE), + "Cannot fuse scatter add with TMA store!", + ) + + tidx = tl.program_id(0) + + dtype: tl.dtype = c_ptr.dtype.element_ty + TMA_SIZE: tl.constexpr = tl.constexpr(128) + if USE_TMA_STORE: + c_desc_ptr = workspace + tidx * TMA_SIZE + else: + c_desc_ptr = None + + M_end_offset = 0 + M_end_offset = M_end_offset.to(tl.int64) + iterated_tiles = 0 + iterated_tiles = iterated_tiles.to(tl.int64) + for g in tl.range(G): + # Move across groups + m_size = tl.load(m_sizes + g, cache_modifier=".ca") + + if m_size > 0: + M_start_offset = M_end_offset + M_end_offset = M_start_offset + m_size + N_start_offset = g.to(tl.int64) * N + + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + tl.static_assert(N % BLOCK_SIZE_N == 0) + NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N + num_tiles = num_m_tiles * NUM_N_TILES + + if USE_TMA_STORE: + with tl.async_task([0]): + # pyre-ignore + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start_offset * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, N], + element_ty=c_ptr.dtype.element_ty, + ) + # pyre-ignore + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # Move across tiles + next_iterated_tiles = iterated_tiles + num_tiles + if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles): + for i in range(tidx, next_iterated_tiles, NUM_SMS): + gidx = i - iterated_tiles + # Split M first and N second. 
+ tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 + ) + tl.static_assert(K % BLOCK_SIZE_K == 0) + + m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + for k_offset in range(0, K, BLOCK_SIZE_K): + with tl.async_task([0]): + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if USE_TMA_STORE: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + tl._experimental_descriptor_store( + c_desc_ptr, + accumulator.to(c_ptr.dtype.element_ty), + [m_offset, n_offset], + ) + elif FUSE_SCATTER_ADD: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None], + sem="relaxed", + ) + else: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size, + cache_modifier=".cs", + ) + tidx += NUM_SMS + + iterated_tiles += num_tiles + + +TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv + + +# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument. 
+@triton.autotune( + configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={ + "early_config_prune": functools.partial( + early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1 + ) + }, + restore_value=["c_ptr"], # restore for scatter_add fusion +) +@triton.jit +def _fbgemm_grouped_gemm_fp8_rowwise( + a_desc_ptr, + a_scale_ptr, + b_desc_ptr, + b_scale_ptr, + b_scale_desc_ptr, + c_ptr, + workspace, + scatter_add_indices, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET, + N: tl.constexpr, + K: tl.constexpr, + NUM_SMS: tl.constexpr, + FUSE_SCATTER_ADD: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_TMA_STORE: tl.constexpr, + USE_FAST_ACCUM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_CONSUMER_GROUPS: tl.constexpr, +) -> None: + tl.static_assert( + not (FUSE_SCATTER_ADD and USE_TMA_STORE), + "Cannot fuse scatter add with TMA store!", + ) + + tidx = tl.program_id(0) + + dtype = TT_FP8_DTYPE + TMA_SIZE: tl.constexpr = tl.constexpr(128) + if USE_TMA_STORE: + c_desc_ptr = workspace + tidx * TMA_SIZE + else: + c_desc_ptr = None + + M_end_offset = 0 + M_end_offset = M_end_offset.to(tl.int64) + iterated_tiles = 0 + iterated_tiles = iterated_tiles.to(tl.int64) + for g in tl.range(G): + # Move across groups + m_size = tl.load(m_sizes + g) + + if m_size > 0: + M_start_offset = M_end_offset + M_end_offset = M_start_offset + m_size + N_start_offset = g.to(tl.int64) * N + n_size = N + + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + + if USE_TMA_STORE: + # pyre-ignore + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start_offset * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, n_size], + element_ty=c_ptr.dtype.element_ty, + ) + # pyre-ignore + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. 
+ tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + tl.static_assert(K % BLOCK_SIZE_K == 0) + if USE_TMA_LOAD: + m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + for k_offset in range(0, K, BLOCK_SIZE_K): + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + else: + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = ( + a_desc_ptr + + (M_start_offset + offs_am[:, None]) * K + + offs_k[None, :] + ) + b_ptrs = ( + b_desc_ptr + + (N_start_offset + offs_bn[:, None]) * K + + offs_k[None, :] + ) + for k_offset in range(0, K, BLOCK_SIZE_K): + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + accumulator += tl.dot(a, b.T) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + ) + b_scale = tl.load( + b_scale_ptr + N_start_offset + offs_bn[None, :], + mask=offs_bn[None, :] < n_size, + ) + c = accumulator.to(tl.float32) * a_scale * b_scale + + if USE_TMA_STORE: + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + tl._experimental_descriptor_store( + c_desc_ptr, + c.to(c_ptr.dtype.element_ty), + [m_offset, n_offset], + ) + elif FUSE_SCATTER_ADD: + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c.to(c_ptr.dtype.element_ty), + mask=mask[:, None] and offs_bn[None, :] < n_size, + sem="relaxed", + ) + else: + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size, + ) + tidx += NUM_SMS + + iterated_tiles += num_tiles + + +# TODO(shikaili): Too much code duplication. Need to refactor. 
+@triton.autotune( + configs=_NV_WS_CONFIGS, + key=["G", "M_BUCKET", "N", "K"], + prune_configs_by={ + "early_config_prune": functools.partial( + early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1 + ) + }, + restore_value=["c_ptr"], # restore for scatter_add fusion +) +@triton.jit +def _fbgemm_grouped_gemm_fp8_rowwise_ws( + a_desc_ptr, + a_scale_ptr, + b_desc_ptr, + b_scale_ptr, + b_scale_desc_ptr, + c_ptr, + workspace, + scatter_add_indices, + m_sizes, + # problem sizes + G: tl.constexpr, + M_BUCKET: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + NUM_SMS: tl.constexpr, + FUSE_SCATTER_ADD: tl.constexpr, + USE_TMA_LOAD: tl.constexpr, + USE_FAST_ACCUM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_CONSUMER_GROUPS: tl.constexpr, + USE_TMA_LOAD_ON_SCALES: tl.constexpr, + USE_TMA_STORE: tl.constexpr, +) -> None: + tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialziation!") + tl.static_assert( + not (FUSE_SCATTER_ADD and USE_TMA_STORE), + "Cannot fuse scatter add with TMA store!", + ) + + tidx = tl.program_id(0) + + dtype = TT_FP8_DTYPE + TMA_SIZE: tl.constexpr = tl.constexpr(128) + if USE_TMA_STORE: + c_desc_ptr = workspace + tidx * TMA_SIZE + else: + c_desc_ptr = None + + M_end_offset = 0 + M_end_offset = M_end_offset.to(tl.int64) + iterated_tiles = 0 + iterated_tiles = iterated_tiles.to(tl.int64) + for g in tl.range(G): + # Move across groups + m_size = tl.load(m_sizes + g, cache_modifier=".ca") + + if m_size > 0: + M_start_offset = M_end_offset + M_end_offset = M_start_offset + m_size + N_start_offset = g.to(tl.int64) * N + + num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M) + tl.static_assert(N % BLOCK_SIZE_N == 0) + NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N + num_tiles = num_m_tiles * NUM_N_TILES + + if USE_TMA_STORE: + with tl.async_task([0]): + # pyre-ignore + tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=c_desc_ptr, + global_address=c_ptr + M_start_offset * N, + load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N], + global_size=[m_size, N], + element_ty=c_ptr.dtype.element_ty, + ) + # pyre-ignore + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr) + + # Move across tiles + next_iterated_tiles = iterated_tiles + num_tiles + if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles): + for i in range(tidx, next_iterated_tiles, NUM_SMS): + gidx = i - iterated_tiles + # Split M first and N second. 
+ tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros( + (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 + ) + tl.static_assert(K % BLOCK_SIZE_K == 0) + + m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + for k_offset in range(0, K, BLOCK_SIZE_K): + with tl.async_task([0]): + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if USE_TMA_LOAD_ON_SCALES: + with tl.async_task([0]): + b_scale = tl._experimental_descriptor_load( + b_scale_desc_ptr, + [n_offset], + [BLOCK_SIZE_N], + tl.float32, + ) + + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + cache_modifier=".ca", + ) + c = accumulator.to(tl.float32) * a_scale * b_scale[None, :] + else: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + cache_modifier=".ca", + ) + b_scale = tl.load( + b_scale_ptr + N_start_offset + offs_bn[None, :], + cache_modifier=".ca", + ) + c = accumulator.to(tl.float32) * a_scale * b_scale + + if USE_TMA_STORE: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + tl._experimental_descriptor_store( + c_desc_ptr, + c.to(c_ptr.dtype.element_ty), + [m_offset, n_offset], + ) + elif FUSE_SCATTER_ADD: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None], + sem="relaxed", + ) + else: + with tl.async_task([1, NUM_CONSUMER_GROUPS]): + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( + 0, BLOCK_SIZE_M + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( + 0, BLOCK_SIZE_N + ) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size, + cache_modifier=".cs", + ) + tidx += NUM_SMS + + iterated_tiles += num_tiles + + +warnings.simplefilter("once") + + +def _grouped_gemm( + *, + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + x_scale: Optional[torch.Tensor], + w_scale: Optional[torch.Tensor], + use_fast_accum: bool, + use_warp_specialization: bool, + output_tensor: Optional[torch.Tensor], + scatter_add_indices: Optional[torch.Tensor], +) -> torch.Tensor: + + USE_TMA_LOAD = not torch.version.hip + USE_TMA_STORE = False + + if USE_TMA_LOAD and not HAS_TMA_DESC: + USE_TMA_LOAD = False + warnings.warn("TMA load is disabled as there is no TMA descriptor support!") + + if USE_TMA_STORE 
and not HAS_TMA_DESC: + USE_TMA_STORE = False + warnings.warn("TMA store is disabled as there is no TMA descriptor support!") + + # TODO(shikaili): Check the readniess of WS on ROCm side in Meta's Triton. + if use_warp_specialization and torch.version.hip: + warnings.warn("Warp specialization is disabled as it is not supported on ROCm.") + use_warp_specialization = False + + if use_warp_specialization and not _HAS_WS_SUPPORT: + warnings.warn( + "Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs." + ) + use_warp_specialization = False + + if use_warp_specialization: + assert HAS_TMA_DESC + USE_TMA_STORE = True # Tuning decision + + G = m_sizes.shape[0] + + assert x.is_contiguous() + assert w.is_contiguous() + assert m_sizes.is_contiguous() + + M, K = x.shape + N = w.shape[0] // G + assert K == w.shape[1] + + if output_tensor is None: + FUSE_SCATTER_ADD = False + assert scatter_add_indices is None + y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16) + else: + FUSE_SCATTER_ADD = True + assert scatter_add_indices is not None + assert scatter_add_indices.is_contiguous() + assert scatter_add_indices.shape == (M,) + y = output_tensor + if M == 0 or N == 0: + return y + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + desc_helper = None + desc_x = x + desc_w = w + desc_ws = w_scale + workspace = None + + if USE_TMA_LOAD: + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("x") + desc_helper.init_tma_descriptor("w") + desc_x = desc_helper.get_tma_descriptor_kernel_param("x") + desc_w = desc_helper.get_tma_descriptor_kernel_param("w") + if use_warp_specialization and w_scale is not None: + desc_helper.init_tma_descriptor("ws") + desc_ws = desc_helper.get_tma_descriptor_kernel_param("ws") + + if USE_TMA_STORE: + workspace = torch.empty( + NUM_SMS * TmaAutoTuneHelper.TMA_SIZE, + device=x.device, + dtype=torch.uint8, + ) + + def grid(META): + if USE_TMA_LOAD: + nonlocal desc_helper # noqa: F824 + desc_helper.fill_2d_tma_descriptor( + "x", + x.data_ptr(), + M, + K, + META["BLOCK_SIZE_M"] // META["NUM_CONSUMER_GROUPS"], + META["BLOCK_SIZE_K"], + x.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "w", + w.data_ptr(), + N * G, + K, + META["BLOCK_SIZE_N"], + META["BLOCK_SIZE_K"], + w.element_size(), + ) + + if META.get("USE_TMA_LOAD_ON_SCALES", False): + desc_helper.fill_1d_tma_descriptor( + "ws", + w_scale.data_ptr(), + N * G, + META["BLOCK_SIZE_N"], + w_scale.element_size(), + ) + + return (NUM_SMS,) + + M_BUCKET_CAP = 16384 + M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP) + if x_scale is not None and w_scale is not None: + assert x_scale.is_contiguous() + assert w_scale.is_contiguous() + fn = ( + _fbgemm_grouped_gemm_fp8_rowwise_ws + if use_warp_specialization + else _fbgemm_grouped_gemm_fp8_rowwise + ) + args = ( + desc_x, + x_scale, + desc_w, + w_scale, + desc_ws, + y, + workspace, + scatter_add_indices, + m_sizes, + G, + M_BUCKET, + N, + K, + NUM_SMS, + FUSE_SCATTER_ADD, + USE_TMA_LOAD, + ) + if use_warp_specialization: + args += (use_fast_accum,) + else: + args += (USE_TMA_STORE, use_fast_accum) + fn[grid](*args) + else: + assert x_scale is None + assert w_scale is None + fn = ( + _fbgemm_grouped_gemm_ws if use_warp_specialization else _fbgemm_grouped_gemm + ) + args = ( + desc_x, + desc_w, + y, + workspace, + scatter_add_indices, + 
m_sizes, + G, + M_BUCKET, + N, + K, + NUM_SMS, + FUSE_SCATTER_ADD, + USE_TMA_LOAD, + ) + if use_warp_specialization: + args += (use_fast_accum,) + else: + args += (USE_TMA_STORE, use_fast_accum) + fn[grid](*args) + + return y + + +def grouped_gemm( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + use_fast_accum: bool = True, + *, + _use_warp_specialization: bool = True, + _output_tensor: Optional[torch.Tensor] = None, + _scatter_add_indices: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return _grouped_gemm( + x=x, + w=w, + m_sizes=m_sizes, + x_scale=None, + w_scale=None, + use_fast_accum=use_fast_accum, + use_warp_specialization=_use_warp_specialization, + output_tensor=_output_tensor, + scatter_add_indices=_scatter_add_indices, + ) + + +def grouped_gemm_fp8_rowwise( + x: torch.Tensor, + w: torch.Tensor, + m_sizes: torch.Tensor, + x_scale: torch.Tensor, + w_scale: torch.Tensor, + use_fast_accum: bool = True, + *, + _use_warp_specialization: bool = True, + _output_tensor: Optional[torch.Tensor] = None, + _scatter_add_indices: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return _grouped_gemm( + x=x, + w=w, + m_sizes=m_sizes, + x_scale=x_scale, + w_scale=w_scale, + use_fast_accum=use_fast_accum, + use_warp_specialization=_use_warp_specialization, + output_tensor=_output_tensor, + scatter_add_indices=_scatter_add_indices, + ) diff --git a/benchmark/fbgemm/test_grouped_gemm.py b/benchmark/fbgemm/test_grouped_gemm.py new file mode 100644 index 00000000000..92d1b213e13 --- /dev/null +++ b/benchmark/fbgemm/test_grouped_gemm.py @@ -0,0 +1,323 @@ +import os +import sys + +import pytest +import torch + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +try: + from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm + from fbgemm_grouped_gemm import ( + grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise, + ) + + FBGEMM_AVAILABLE = True + print("✓ Successfully imported FBGEMM grouped GEMM") +except ImportError as e: + print(f"✗ Failed to import FBGEMM grouped GEMM: {e}") + FBGEMM_AVAILABLE = False + +try: + from sglang.srt.layers.moe.ep_moe.kernels import ( + grouped_gemm_triton as sglang_grouped_gemm, + ) + + SGLANG_AVAILABLE = True + print("✓ Successfully imported SGLang grouped GEMM") +except ImportError as e: + print(f"✗ Failed to import SGLang grouped GEMM: {e}") + SGLANG_AVAILABLE = False + + +def create_uniform_groups(batch_size, num_groups, device): + tokens_per_group = batch_size // num_groups + return torch.full((num_groups,), tokens_per_group, dtype=torch.int64, device=device) + + +def create_non_uniform_groups(batch_size, num_groups, device): + remaining = batch_size + m_sizes = [] + + for i in range(num_groups - 1): + if remaining <= 1: + size = 1 + else: + max_size = remaining - (num_groups - i - 1) + 1 + size = torch.randint(1, max_size, (1,)).item() + m_sizes.append(size) + remaining -= size + + m_sizes.append(remaining) + return torch.tensor(m_sizes, dtype=torch.int64, device=device) + + +def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device): + batch_size = x.shape[0] + + c_sglang = torch.empty( + batch_size, intermediate_size, dtype=torch.bfloat16, device=device + ) + + seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device) + current_pos = 0 + for i, size in enumerate(m_sizes): + current_pos += size + seg_indptr[i + 1] = current_pos + + weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device) + w_sglang = w.view(num_groups, intermediate_size, -1) + + 
return c_sglang, seg_indptr, weight_indices, w_sglang + + +def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device): + torch.manual_seed(42) + + x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device) + w_fp16 = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device + ) + + x_fp8 = x_fp16.to(torch.float8_e4m3fn) + w_fp8 = w_fp16.to(torch.float8_e4m3fn) + + x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4 + w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4 + + return x_fp8, w_fp8, x_scale, w_scale + + +@pytest.fixture +def device(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + return torch.device("cuda") + + +@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") +@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("num_groups", [2, 4, 8]) +@pytest.mark.parametrize("hidden_size", [512, 1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device): + if batch_size % num_groups != 0: + pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}") + + torch.manual_seed(42) + + m_sizes = create_uniform_groups(batch_size, num_groups, device) + + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) + w = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device + ) + + result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) + + c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( + x, w, m_sizes, num_groups, intermediate_size, device + ) + + result_sglang = sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, + ) + + assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3) + + +@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") +@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") +@pytest.mark.parametrize("batch_size", [63, 100, 127]) +@pytest.mark.parametrize("num_groups", [3, 5, 7]) +@pytest.mark.parametrize("hidden_size", [512, 1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +def test_non_uniform_groups( + batch_size, num_groups, hidden_size, intermediate_size, device +): + torch.manual_seed(42) + + m_sizes = create_non_uniform_groups(batch_size, num_groups, device) + + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) + w = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device + ) + + result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) + + c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( + x, w, m_sizes, num_groups, intermediate_size, device + ) + + result_sglang = sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, + ) + + assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3) + + +@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") +@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") 
+@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)]) +@pytest.mark.parametrize("hidden_size", [768, 2048, 4096]) +@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192]) +def test_large_dimensions( + batch_size, num_groups, hidden_size, intermediate_size, device +): + torch.manual_seed(42) + + m_sizes = create_uniform_groups(batch_size, num_groups, device) + + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) + w = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device + ) + + result_fbgemm = fbgemm_grouped_gemm(x, w, m_sizes, use_fast_accum=True) + + c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( + x, w, m_sizes, num_groups, intermediate_size, device + ) + + result_sglang = sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, + ) + + assert torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3) + + +@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") +@pytest.mark.parametrize("batch_size", [32, 64]) +@pytest.mark.parametrize("num_groups", [2, 4]) +@pytest.mark.parametrize("hidden_size", [512, 1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +def test_fp8_uniform_groups( + batch_size, num_groups, hidden_size, intermediate_size, device +): + if batch_size % num_groups != 0: + pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}") + + torch.manual_seed(42) + + m_sizes = create_uniform_groups(batch_size, num_groups, device) + x_fp8, w_fp8, x_scale, w_scale = create_fp8_data( + batch_size, num_groups, hidden_size, intermediate_size, device + ) + + try: + result_fp8 = fbgemm_grouped_gemm_fp8_rowwise( + x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True + ) + assert result_fp8.shape == (batch_size, intermediate_size) + assert result_fp8.dtype == torch.bfloat16 + except Exception as e: + pytest.skip(f"FP8 test failed (possibly unsupported): {e}") + + +@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") +@pytest.mark.parametrize("batch_size", [63, 100]) +@pytest.mark.parametrize("num_groups", [3, 5]) +@pytest.mark.parametrize("hidden_size", [512, 1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +def test_fp8_non_uniform_groups( + batch_size, num_groups, hidden_size, intermediate_size, device +): + torch.manual_seed(42) + + m_sizes = create_non_uniform_groups(batch_size, num_groups, device) + x_fp8, w_fp8, x_scale, w_scale = create_fp8_data( + batch_size, num_groups, hidden_size, intermediate_size, device + ) + + try: + result_fp8 = fbgemm_grouped_gemm_fp8_rowwise( + x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True + ) + assert result_fp8.shape == (batch_size, intermediate_size) + assert result_fp8.dtype == torch.bfloat16 + except Exception as e: + pytest.skip(f"FP8 test failed (possibly unsupported): {e}") + + +@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available") +def test_fbgemm_only_uniform(device): + torch.manual_seed(42) + + batch_size, num_groups = 64, 4 + hidden_size, intermediate_size = 512, 1024 + + m_sizes = create_uniform_groups(batch_size, num_groups, device) + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) + w = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device + ) + + result = fbgemm_grouped_gemm(x, 
w, m_sizes, use_fast_accum=True) + + assert result.shape == (batch_size, intermediate_size) + assert result.dtype == torch.bfloat16 + + +@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available") +def test_sglang_only_uniform(device): + torch.manual_seed(42) + + batch_size, num_groups = 64, 4 + hidden_size, intermediate_size = 512, 1024 + + m_sizes = create_uniform_groups(batch_size, num_groups, device) + x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) + w = torch.randn( + num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device=device + ) + + c_sglang, seg_indptr, weight_indices, w_sglang = create_sglang_inputs( + x, w, m_sizes, num_groups, intermediate_size, device + ) + + result = sglang_grouped_gemm( + x, + w_sglang, + c_sglang, + num_groups, + weight_column_major=True, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + c_dtype=c_sglang.dtype, + ) + + assert result.shape == (batch_size, intermediate_size) + assert result.dtype == torch.bfloat16 + + +def test_imports(): + assert ( + FBGEMM_AVAILABLE or SGLANG_AVAILABLE + ), "Neither FBGEMM nor SGLang is available" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
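
Usage note (reviewer sketch, not part of the diff): the snippet below shows a minimal way to drive the BF16 path added here directly, without the benchmark harness. It mirrors test_fbgemm_only_uniform above and adds a per-group torch.matmul reference of my own for sanity checking; shapes are illustrative, and it assumes a CUDA GPU plus a Triton build that fbgemm_grouped_gemm.py supports. Note also that the run command in the benchmark's header comment points at benchmark/kernels/fbgemm/, while this diff adds the script under benchmark/fbgemm/, so the latter is the path to invoke.

# Minimal BF16 usage sketch (assumes CUDA is available and that
# benchmark/fbgemm/ is on sys.path so fbgemm_grouped_gemm.py is importable).
import torch

from fbgemm_grouped_gemm import grouped_gemm

num_groups, hidden_size, intermediate_size = 4, 512, 1024
tokens_per_group = 16
batch_size = num_groups * tokens_per_group

x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
# Weights are stacked row-wise: group g owns rows
# [g * intermediate_size, (g + 1) * intermediate_size) of w.
w = torch.randn(
    num_groups * intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
)
m_sizes = torch.full(
    (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
)

y = grouped_gemm(x, w, m_sizes, use_fast_accum=True)
assert y.shape == (batch_size, intermediate_size)

# Per-group reference (my own check, not from the diff): y[rows_g] = x[rows_g] @ w_g.T
y_ref = torch.empty_like(y)
row = 0
for g in range(num_groups):
    rows = slice(row, row + tokens_per_group)
    w_g = w[g * intermediate_size : (g + 1) * intermediate_size]
    y_ref[rows] = x[rows] @ w_g.T
    row += tokens_per_group
print("max |y - y_ref|:", (y - y_ref).abs().max().item())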
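
A similarly hedged sketch of the FP8 rowwise path follows. Unlike create_fp8_test_data / create_fp8_data above, it allocates one w_scale entry per stacked weight row (num_groups * intermediate_size values), which is how the kernel indexes b_scale_ptr (N_start_offset + offs_bn); treat that scale shape as an assumption to confirm against upstream FBGEMM. It also assumes a GPU and PyTorch build with float8_e4m3fn support.

# FP8 rowwise usage sketch (assumption-laden, not part of the diff).
import torch

from fbgemm_grouped_gemm import grouped_gemm_fp8_rowwise

num_groups, hidden_size, intermediate_size = 4, 512, 1024
tokens_per_group = 16
batch_size = num_groups * tokens_per_group

m_sizes = torch.full(
    (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
)
x_fp8 = torch.randn(
    batch_size, hidden_size, dtype=torch.float16, device="cuda"
).to(torch.float8_e4m3fn)
w_fp8 = torch.randn(
    num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
).to(torch.float8_e4m3fn)

# Rowwise scales: one per activation row and (assumed) one per stacked weight row.
x_scale = torch.rand(batch_size, dtype=torch.float32, device="cuda") + 1e-4
w_scale = (
    torch.rand(num_groups * intermediate_size, dtype=torch.float32, device="cuda")
    + 1e-4
)

y = grouped_gemm_fp8_rowwise(
    x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
)
assert y.shape == (batch_size, intermediate_size) and y.dtype == torch.bfloat16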