120 changes: 75 additions & 45 deletions benchmark/bench_moe_topk_softmax.py
@@ -3,6 +3,7 @@
import torch
import triton
from sgl_kernel import topk_softmax
from utils import HAS_VLLM, parse_args


def vllm_topk_softmax(gating_output, topk):
@@ -23,6 +24,22 @@ def vllm_topk_softmax(gating_output, topk):
return topk_weights, topk_indices


def native_topk_softmax(gating_output, topk):
import torch.nn.functional as F

# Plain PyTorch reference: softmax over all experts, then top-k selection.
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_indices = torch.topk(topk_weights, topk, dim=-1)
return topk_weights, topk_indices


def sglang_topk_softmax(gating_output, topk):
num_tokens, num_experts = gating_output.shape

@@ -37,18 +54,18 @@ def sglang_topk_softmax(gating_output, topk):
)

topk_softmax(
topk_weights=topk_weights,
topk_ids=topk_indices,
token_expert_indices=token_expert_indices,
gating_output=gating_output,
topk_weights,
topk_indices,
gating_output,
renormalize=False,
)

return topk_weights, topk_indices


def calculate_diff(num_tokens, num_experts, topk):
gating_output = torch.randn(
(num_tokens, num_experts), device="cuda", dtype=torch.float32
(num_tokens, num_experts), device="xpu", dtype=torch.float32
)
weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
@@ -67,52 +84,65 @@ def calculate_diff(num_tokens, num_experts, topk):
)


num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
def get_benchmark(device="xpu"):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk", "dtype"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang", "native"],
line_names=["SGLang", "native"],
styles=[("blue", "-"), ("green", "-")],
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, dtype, provider):

configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
gating_output = torch.randn(
(num_tokens, num_experts), device=device, dtype=dtype
)

if HAS_VLLM and (provider == "vllm" or provider == "vllm1"):
fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk)
elif provider == "native":
fn = lambda: native_topk_softmax(gating_output, topk)

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang", "vllm"],
line_names=["SGLang", "VLLM"],
styles=[("blue", "-"), ("green", "-")],
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, provider):
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

gating_output = torch.randn(
(num_tokens, num_experts), device="cuda", dtype=torch.float32
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

if provider == "vllm" or provider == "vllm1":
fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk)
return benchmark

quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
# Run correctness test on small configs if not using a real model
args = parse_args()
sweep_params = {
"num_tokens": [1, 32, 128, 512],
"num_experts": args.num_experts or [64],
"top_k": args.top_k or [2, 4],
"dtype": [torch.float16, torch.bfloat16],
}
keys = sweep_params.keys()
configs = list(itertools.product(*sweep_params.values()))
print(f"Testing {len(configs)} configurations...")
for config in configs:
num_tokens, num_experts, topk, dtype = config
print(
f"Config: num_tokens={num_tokens}, num_experts={num_experts}, topk={topk}, dtype={dtype}"
)

# calculate_diff(num_tokens, num_experts, topk)

if __name__ == "__main__":
configs = [
(20, 256, 4),
(20, 256, 8),
(20, 12, 4),
(20, 12, 1),
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in configs:
calculate_diff(num_tokens, num_experts, topk)
benchmark.run(print_data=True)
# `configs` is read as a module-level global by the perf_report decorator
# inside get_benchmark(), so it must be built before get_benchmark() is called.

# Run benchmark
print("Starting performance benchmark...")
benchmark = get_benchmark()
benchmark.run(print_data=True, show_plots=False, save_path=".")
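
For reference, a minimal sketch of how the sweep dictionary in the `__main__` block above expands into the per-run configurations consumed by `triton.testing.perf_report` (the sweep values here are illustrative, not taken from the benchmark):

import itertools

import torch

# Hypothetical sweep mirroring the structure used above; itertools.product
# preserves the dict's insertion order, so every config tuple unpacks as
# (num_tokens, num_experts, topk, dtype) inside benchmark().
sweep_params = {
    "num_tokens": [1, 32],
    "num_experts": [64],
    "top_k": [2, 4],
    "dtype": [torch.float16, torch.bfloat16],
}
configs = list(itertools.product(*sweep_params.values()))
print(len(configs))  # 2 * 1 * 2 * 2 = 8 configurations
for num_tokens, num_experts, topk, dtype in configs:
    print(num_tokens, num_experts, topk, dtype)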
226 changes: 226 additions & 0 deletions benchmark/utils.py
@@ -0,0 +1,226 @@
# utils.py
# Flexible config loader: supports
# 1. Hugging Face model config (--model-name)
# 2. Manual override via CLI args (e.g., --num-experts)
# 3. Safe fallback defaults

import argparse

from transformers import AutoConfig

# bench_moe_topk_softmax.py imports HAS_VLLM from this module; a simple import
# probe is assumed here to define it (True when the vllm package is installed).
try:
import vllm  # noqa: F401
HAS_VLLM = True
except ImportError:
HAS_VLLM = False


def get_model_config(args):
"""
Get model config with priority:
1. CLI args override (e.g., --num-experts)
2. Hugging Face config (if --model-name given)
3. Hardcoded defaults (last resort)

Args:
args: Parsed command-line arguments

Returns:
dict: Standardized model config
"""
config_dict = {}

# Step 1: Load from Hugging Face model (if provided)
if args.model_name:
print(f"📡 Loading config from Hugging Face: {args.model_name}")
try:
hf_config = AutoConfig.from_pretrained(args.model_name)
except Exception as e:
raise ValueError(f"Failed to load {args.model_name}: {e}")

# Extract with fallbacks
config_dict.update(
{
"num_experts": getattr(hf_config, "moe_num_experts", None)
or getattr(hf_config, "num_experts", None)
or getattr(hf_config, "num_local_experts", None),
"top_k": getattr(hf_config, "moe_top_k", None)
or getattr(hf_config, "top_k", None)
or getattr(hf_config, "num_experts_per_tok", None),
"num_layers": getattr(hf_config, "num_hidden_layers", None)
or getattr(hf_config, "num_layers", None),
"hidden_size": getattr(hf_config, "hidden_size", None)
or getattr(hf_config, "d_model", None),
"ffn_hidden_size": getattr(hf_config, "intermediate_size", None)
or getattr(hf_config, "ffn_dim", None),
"num_heads": getattr(hf_config, "num_attention_heads", None),
"num_kv_heads": getattr(hf_config, "num_key_value_heads", None)
or getattr(hf_config, "num_attention_heads", None),
"head_dim": getattr(hf_config, "head_dim", None)
or (
getattr(hf_config, "hidden_size", None)
// getattr(hf_config, "num_attention_heads", 1)
if getattr(hf_config, "hidden_size", None)
and getattr(hf_config, "num_attention_heads", None)
else None
),
"vocab_size": getattr(hf_config, "vocab_size", None),
"max_seq_len": getattr(hf_config, "max_position_embeddings", None)
or getattr(hf_config, "n_positions", 32768),
"norm_eps": getattr(hf_config, "rms_norm_eps", None)
or getattr(hf_config, "layer_norm_eps", 1e-6),
"architectures": getattr(hf_config, "architectures", ["Unknown"]),
"dtype": getattr(hf_config, "torch_dtype", "float16"),
}
)
else:
print("🔧 No --model-name provided. Using CLI args or defaults.")

# Step 2: CLI args override everything
cli_overrides = {
"num_experts": args.num_experts,
"top_k": args.top_k,
"num_layers": args.num_layers,
"hidden_size": args.hidden_size,
"ffn_hidden_size": args.ffn_hidden_size,
"num_heads": args.num_heads,
"num_kv_heads": args.num_kv_heads,
"head_dim": args.head_dim,
"vocab_size": args.vocab_size,
"max_seq_len": args.max_seq_len,
"norm_eps": args.norm_eps,
}

for k, v in cli_overrides.items():
if v is not None:
config_dict[k] = v
print(f"⚙️ Overriding {k} = {v} (from CLI)")

# Step 3: Fill missing with safe defaults
defaults = {
"num_experts": 64,
"top_k": 2,
"num_layers": 32,
"hidden_size": 4096,
"ffn_hidden_size": 11008,
"num_heads": 32,
"num_kv_heads": 8,
"head_dim": 128,
"vocab_size": 32000,
"max_seq_len": 32768,
"norm_eps": 1e-6,
"architectures": ["LlamaForCausalLM"],
"dtype": "float16",
}

for k, v in defaults.items():
if k not in config_dict or config_dict[k] is None:
config_dict[k] = v
# Only log defaults when the user supplied neither a model nor overrides.
if not args.model_name and all(
getattr(args, field) is None
for field in ["num_experts", "top_k", "num_layers"]
):
print(f"💡 Using default {k} = {v}")

# Add model name
config_dict["model_name"] = args.model_name

return config_dict


def parse_args():
"""Parse all possible model and benchmark arguments (support list values)."""
parser = argparse.ArgumentParser(
description="Flexible benchmark with model config support"
)

# Model source
parser.add_argument(
"--model-name",
type=str,
default=None,
help="Hugging Face model name (e.g., deepseek-ai/DeepSeek-R1). If not set, use CLI args.",
)

# MoE parameters (support list)
parser.add_argument(
"--num-experts",
type=int,
default=None,
nargs="*",
help="Number of experts (can provide multiple values for sweep)",
)
parser.add_argument(
"--top-k",
type=int,
default=None,
nargs="*",
help="Top-k experts per token (multiple values allowed)",
)

# Transformer parameters (support list)
parser.add_argument(
"--num-layers",
type=int,
default=None,
nargs="*",
help="Number of transformer layers (multiple values)",
)
parser.add_argument(
"--hidden-size", type=int, default=None, nargs="*", help="Hidden size (d_model)"
)
parser.add_argument(
"--ffn-hidden-size",
type=int,
default=None,
nargs="*",
help="FFN/intermediate size",
)
parser.add_argument(
"--num-heads",
type=int,
default=None,
nargs="*",
help="Number of attention heads",
)
parser.add_argument(
"--num-kv-heads",
type=int,
default=None,
nargs="*",
help="Number of KV heads (for GQA)",
)
parser.add_argument(
"--head-dim",
type=int,
default=None,
nargs="*",
help="Dimension per attention head",
)
parser.add_argument(
"--vocab-size", type=int, default=None, nargs="*", help="Vocabulary size"
)
parser.add_argument(
"--max-seq-len",
type=int,
default=None,
nargs="*",
help="Maximum sequence length",
)
parser.add_argument(
"--norm-eps",
type=float,
default=None,
nargs="*",
help="Normalization epsilon (rms_norm_eps)",
)

# Benchmark settings
parser.add_argument(
"--device", type=str, default="xpu", help="Device (default: xpu)"
)
parser.add_argument(
"--dtype",
type=str,
default="torch.float32",
choices=["torch.float32", "torch.float16", "torch.bfloat16"],
help="Data type",
)

return parser.parse_args()
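
A minimal usage sketch for the two helpers above; the driver script name is hypothetical, and the precedence follows the get_model_config docstring (CLI overrides, then the Hugging Face config, then the hardcoded defaults):

# hypothetical_driver.py -- illustration only, not part of this PR
from utils import get_model_config, parse_args

if __name__ == "__main__":
    # e.g. `python hypothetical_driver.py --model-name deepseek-ai/DeepSeek-R1 --top-k 4`
    args = parse_args()
    config = get_model_config(args)

    # With neither --model-name nor overrides, the hardcoded defaults apply
    # (num_experts=64, top_k=2, ...); CLI flags win over the HF config values.
    print(config["model_name"], config["num_experts"], config["top_k"])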