366 changes: 366 additions & 0 deletions benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py
@@ -0,0 +1,366 @@
# python3 benchmark/kernels/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
import argparse

import torch
import triton
from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
from fbgemm_grouped_gemm import (
    grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig

from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton as sglang_grouped_gemm,
)


def get_model_config(model_name: str, tp_size: int):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    if config.architectures[0] == "DbrxForCausalLM":
        num_groups = config.ffn_config.moe_num_experts
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif config.architectures[0] == "JambaForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.intermediate_size
    elif config.architectures[0] == "Qwen2MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Qwen3MoeForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        num_groups = (
            config.n_routed_experts + 1
            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
            else config.n_routed_experts
        )
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] == "Llama4ForConditionalGeneration":
        num_groups = config.text_config.num_local_experts
        intermediate_size = config.text_config.intermediate_size
    elif config.architectures[0] in [
        "Grok1ForCausalLM",
        "Grok1ImgGen",
        "Grok1AForCausalLM",
    ]:
        num_groups = config.num_local_experts
        intermediate_size = config.moe_intermediate_size
    else:
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size
Comment on lines +49 to +51
Reviewer comment (severity: medium)

The fallback logic in get_model_config for unhandled model architectures defaults to using config.num_local_experts and config.intermediate_size. If a new, unhandled architecture doesn't conform to these attribute names, an AttributeError will occur. While this error is caught in main() and a hardcoded default configuration is used, this might lead to the benchmark running with parameters that don't match the user's intended model.

Could we enhance this fallback? For example, by printing a more specific warning when this else branch is taken for an architecture not explicitly listed in the if/elif conditions? This would alert the user that the benchmark might be using assumed (and potentially incorrect) parameters for their model.

    else:
        # Fallback for unrecognized architectures.
        print(f"Warning: Model architecture '{config.architectures[0]}' not explicitly handled. "
              f"Attempting to use default attributes 'num_local_experts' and 'intermediate_size'. "
              f"This may lead to incorrect benchmark parameters if these attributes are not applicable or missing.")
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size
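
One further option, purely as a sketch building on the suggestion above (the getattr-based lookup is not part of the PR), would be to fail fast with an explicit error when the generic attributes are missing, instead of letting a bare AttributeError propagate:

    else:
        # Hypothetical variant (not from the PR): surface missing attributes explicitly.
        arch = config.architectures[0]
        num_groups = getattr(config, "num_local_experts", None)
        intermediate_size = getattr(config, "intermediate_size", None)
        if num_groups is None or intermediate_size is None:
            raise ValueError(
                f"Unsupported architecture '{arch}': config lacks 'num_local_experts' "
                f"and/or 'intermediate_size'."
            )
        print(f"Warning: architecture '{arch}' not explicitly handled; using generic attributes.")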


    shape_configs = {
        "num_groups": num_groups,
        "hidden_size": config.hidden_size,
        "intermediate_size": intermediate_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs


def create_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    torch.manual_seed(42)

    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
    )
Comment on lines +66 to +69
Reviewer comment (severity: high)

The current calculation of m_sizes assumes batch_size is perfectly divisible by num_groups. If not, sum(m_sizes) (where m_sizes elements are batch_size // num_groups) would be less than batch_size. This means not all input tokens from x would be processed, potentially leading to incorrect or misleading benchmark results for certain configurations.

Could m_sizes be adjusted to ensure all batch_size tokens are distributed among the groups, for example, by distributing any remainder tokens? This would make the benchmark more robust for arbitrary batch_size and num_groups combinations.

    base_tokens_per_group = batch_size // num_groups
    remainder_tokens = batch_size % num_groups
    m_sizes_list = [base_tokens_per_group] * num_groups
    for i in range(remainder_tokens):
        m_sizes_list[i] += 1
    m_sizes = torch.tensor(m_sizes_list, dtype=torch.int64, device="cuda")
    # Ensure sum(m_sizes) == batch_size, critical for processing all input tokens.
    # Consider adding: assert torch.sum(m_sizes).item() == batch_size, "Sum of m_sizes must equal batch_size"
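
For reference, the same remainder distribution can be written without a Python loop; this is only a sketch of the suggestion above, not code from the PR:

    # Hypothetical vectorized form: give the first (batch_size % num_groups) groups one extra token.
    m_sizes = torch.full(
        (num_groups,), batch_size // num_groups, dtype=torch.int64, device="cuda"
    )
    m_sizes[: batch_size % num_groups] += 1
    assert int(m_sizes.sum()) == batch_size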


    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")

    base_weights = torch.randn(
        num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device="cuda"
    )

    # The FBGEMM path takes the expert weights stacked into a 2D matrix,
    # while the SGLang path keeps the 3D per-expert layout.
    w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
    w_sglang = base_weights

    c_fbgemm = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device="cuda"
    )

    # seg_indptr holds the cumulative token offsets delimiting each expert's segment.
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
    for i in range(1, num_groups + 1):
        seg_indptr[i] = seg_indptr[i - 1] + tokens_per_group
Comment on lines +87 to +89
Reviewer comment (severity: high)

Related to the m_sizes calculation: if m_sizes is updated to correctly distribute all tokens (especially when batch_size is not divisible by num_groups), the calculation for seg_indptr also needs to use these corrected m_sizes rather than the potentially inaccurate tokens_per_group.

SGLang's grouped GEMM relies on seg_indptr to define token segments for each expert. Using the actual m_sizes for its calculation is crucial for correctness.

    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device="cuda")
    # seg_indptr[0] is 0. Calculate cumulative sum of actual m_sizes.
    # This assumes 'm_sizes' has been redefined to correctly sum to batch_size.
    torch.cumsum(m_sizes, dim=0, out=seg_indptr[1:])
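
As a quick sanity check, assuming m_sizes has been rebuilt to sum to batch_size as suggested above, the resulting prefix sums can be validated directly (illustrative only):

    # seg_indptr[i] marks the start offset of group i; the last entry must cover all tokens.
    assert int(seg_indptr[-1]) == batch_size
    assert torch.equal(seg_indptr[1:] - seg_indptr[:-1], m_sizes)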


    weight_indices = torch.arange(num_groups, dtype=torch.int64, device="cuda")

    return (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    )


def create_fp8_test_data(batch_size, num_groups, hidden_size, intermediate_size):
    torch.manual_seed(42)

    tokens_per_group = batch_size // num_groups
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int64, device="cuda"
    )
Comment on lines +108 to +111
Reviewer comment (severity: high)

Similar to the create_test_data function, the m_sizes calculation here in create_fp8_test_data might not account for all tokens if batch_size is not divisible by num_groups. This could affect the FP8 benchmark path.

Would it be beneficial to apply the same robust m_sizes calculation here (distributing remainder tokens) to ensure all batch_size tokens are processed?

    base_tokens_per_group = batch_size // num_groups
    remainder_tokens = batch_size % num_groups
    m_sizes_list = [base_tokens_per_group] * num_groups
    for i in range(remainder_tokens):
        m_sizes_list[i] += 1
    m_sizes = torch.tensor(m_sizes_list, dtype=torch.int64, device="cuda")
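
Since the same remainder-aware split would then be needed in both create_test_data and create_fp8_test_data, it could also be factored into a small shared helper; the function below is purely illustrative and its name is hypothetical:

    def split_tokens_evenly(batch_size: int, num_groups: int) -> torch.Tensor:
        # Distribute batch_size tokens across num_groups, spreading any remainder
        # over the first groups so the sizes always sum to batch_size.
        m_sizes = torch.full(
            (num_groups,), batch_size // num_groups, dtype=torch.int64, device="cuda"
        )
        m_sizes[: batch_size % num_groups] += 1
        return m_sizes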


    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device="cuda"
    )

    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)

    x_scale = torch.randn(batch_size, dtype=torch.float32, device="cuda").abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device="cuda").abs() + 1e-4

    return x_fp8, w_fp8, m_sizes, x_scale, w_scale


def get_benchmark_config(use_fp8_w8a8=False):
    if use_fp8_w8a8:
        return {
            "line_vals": ["fbgemm_grouped_gemm_fp8", "sglang_grouped_gemm"],
            "line_names": ["FBGEMM Grouped GEMM FP8", "SGLang Grouped GEMM FP8"],
            "styles": [("blue", "-"), ("red", "-")],
        }
    else:
        return {
            "line_vals": ["fbgemm_grouped_gemm", "sglang_grouped_gemm"],
            "line_names": ["FBGEMM Grouped GEMM BF16", "SGLang Grouped GEMM BF16"],
            "styles": [("blue", "-"), ("green", "-")],
        }


def run_benchmark(
    model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
    config = get_benchmark_config(use_fp8_w8a8)

    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Time (ms)",
        plot_name="grouped-gemm-performance",
        args={},
    )

    @triton.testing.perf_report(benchmark_config)
    def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
        print(f"Benchmarking {provider} with batch_size={batch_size}")
        torch.cuda.manual_seed_all(0)

        num_groups = model_config["num_groups"]
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_grouped_gemm_fp8":
            try:
                test_data = create_fp8_test_data(
                    batch_size, num_groups, hidden_size, intermediate_size
                )
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = test_data

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
                    )

            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return float("inf"), float("inf"), float("inf")
        else:
            test_data = create_test_data(
                batch_size, num_groups, hidden_size, intermediate_size
            )
            (
                x,
                w_fbgemm,
                w_sglang,
                c_fbgemm,
                c_sglang,
                m_sizes,
                seg_indptr,
                weight_indices,
            ) = test_data

            if provider == "fbgemm_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
                        x, w_fbgemm, m_sizes, use_fast_accum=True
                    )

            else:

                def run_func():
                    return sglang_grouped_gemm(
                        x,
                        w_sglang,
                        c_sglang,
                        num_groups,
                        weight_column_major=True,
                        seg_indptr=seg_indptr,
                        weight_indices=weight_indices,
                        c_dtype=c_sglang.dtype,
                    )

        for _ in range(10):
            try:
                run_func()
            except Exception as e:
                print(f"Error during warmup for {provider}: {e}")
                return float("inf"), float("inf"), float("inf")

        torch.cuda.synchronize()

        try:
            quantiles = [0.5, 0.2, 0.8]
            ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)
            return ms, min_ms, max_ms
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return float("inf"), float("inf"), float("inf")

    dynamic_benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=save_path,
        model_config=model_config,
        use_fp8_w8a8=use_fp8_w8a8,
    )


def verify_correctness(model_config, use_fp8_w8a8):
    print("Verifying correctness...")
    batch_size = 128
    num_groups = model_config["num_groups"]
    hidden_size = model_config["hidden_size"]
    intermediate_size = model_config["intermediate_size"]

    test_data = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)
    (x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes, seg_indptr, weight_indices) = (
        test_data
    )

    try:
        result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)

        result_sglang = sglang_grouped_gemm(
            x,
            w_sglang,
            c_sglang,
            num_groups,
            weight_column_major=True,
            seg_indptr=seg_indptr,
            weight_indices=weight_indices,
            c_dtype=c_sglang.dtype,
        )

        if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
            print("✓ BF16 Correctness verification passed!")
        else:
            max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
            print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
            return False

        if use_fp8_w8a8:
            try:
                fp8_data = create_fp8_test_data(
                    batch_size, num_groups, hidden_size, intermediate_size
                )
                x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale = fp8_data

                result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
                    x_fp8, w_fp8, m_sizes_fp8, x_scale, w_scale, use_fast_accum=True
                )

                assert result_fp8.shape == (batch_size, intermediate_size)
                print("✓ FP8 functionality test passed!")
            except Exception as e:
                print(f"FP8 test failed (possibly unsupported): {e}")
                return False

        return True

    except Exception as e:
        print(f"✗ Error during correctness verification: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark FBGEMM vs SGLang Grouped GEMM"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name to get configuration from",
    )
    parser.add_argument(
        "--tp-size", type=int, default=1, help="Tensor parallelism size"
    )
    parser.add_argument(
        "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./benchmark_grouped_gemm/",
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--verify-correctness",
        action="store_true",
        help="Verify correctness before benchmarking",
    )

    args = parser.parse_args()

    try:
        model_config = get_model_config(args.model, args.tp_size)
    except Exception as e:
        print(f"Failed to get model config: {e}")
        print("Using default configuration...")
        model_config = {
            "num_groups": 8,
            "hidden_size": 4096,
            "intermediate_size": 14336,
            "dtype": torch.bfloat16,
        }

    print("Running benchmark with:")
    print(f" num_groups: {model_config['num_groups']}")
    print(f" hidden_size: {model_config['hidden_size']}")
    print(f" intermediate_size: {model_config['intermediate_size']}")
    print(f" use_fp8_w8a8: {args.use_fp8_w8a8}")

    if args.verify_correctness:
        if not verify_correctness(model_config, args.use_fp8_w8a8):
            print("Correctness verification failed. Exiting...")
            return

    try:
        run_benchmark(
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            save_path=args.save_path,
        )
    except Exception as e:
        print(f"Benchmark failed: {e}")


if __name__ == "__main__":
    main()