2 changes: 1 addition & 1 deletion .github/workflows/pr-test-xpu.yml
@@ -55,7 +55,7 @@ jobs:
timeout-minutes: 20
run: |
docker exec -w /root/sglang ci_sglang_xpu \
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py "
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py && python3 bench_moe_topk_softmax.py "

- name: Run E2E Bfloat16 tests
timeout-minutes: 20
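The CI step above now chains bench_moe_topk_softmax.py after bench_flash_attn.py inside the ci_sglang_xpu container, so a failure in either script fails the step. A rough local equivalent, sketched with subprocess and assuming you are already in sgl-kernel-xpu/benchmark (the container setup and paths from the workflow are not reproduced here):

    import subprocess

    # Mirror the "a && b" shell chain from the workflow: run each benchmark in
    # order and stop at the first non-zero exit code.
    for script in ("bench_flash_attn.py", "bench_moe_topk_softmax.py"):
        subprocess.run(["python3", script], check=True)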
136 changes: 90 additions & 46 deletions benchmark/bench_moe_topk_softmax.py
@@ -3,6 +3,7 @@
import torch
import triton
from sgl_kernel import topk_softmax
from utils import get_model_config, parse_args


def vllm_topk_softmax(gating_output, topk):
@@ -23,7 +24,35 @@ def vllm_topk_softmax(gating_output, topk):
return topk_weights, topk_indices


def sglang_topk_softmax(gating_output, topk):
def native_topk_softmax(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
num_tokens, num_experts = gating_output.shape

import torch.nn.functional as F

topk_weights = torch.empty(
(num_tokens, topk), device=gating_output.device, dtype=torch.float32
)
topk_indices = torch.empty(
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
)
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_indices = torch.topk(topk_weights, topk, dim=-1)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

return topk_weights, topk_indices


def sglang_topk_softmax(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
num_tokens, num_experts = gating_output.shape

topk_weights = torch.empty(
@@ -37,18 +66,18 @@ def sglang_topk_softmax(gating_output, topk):
)

topk_softmax(
topk_weights=topk_weights,
topk_ids=topk_indices,
token_expert_indices=token_expert_indices,
gating_output=gating_output,
topk_weights,
topk_indices,
gating_output,
renormalize=renormalize,
)

return topk_weights, topk_indices


def calculate_diff(num_tokens, num_experts, topk):
gating_output = torch.randn(
(num_tokens, num_experts), device="cuda", dtype=torch.float32
(num_tokens, num_experts), device="xpu", dtype=torch.float32
)
weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
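For context, the call rewritten above passes the output buffers and gating tensor positionally and forwards a renormalize flag to the kernel. A minimal sketch of that call pattern, assuming an XPU build of PyTorch and sgl_kernel and using the positional signature exactly as it appears in this diff (buffer dtypes follow the allocations shown above):

    import torch
    from sgl_kernel import topk_softmax

    num_tokens, num_experts, topk = 16, 64, 2
    gating_output = torch.randn(num_tokens, num_experts, device="xpu", dtype=torch.bfloat16)

    # Output buffers, allocated as in sglang_topk_softmax above.
    topk_weights = torch.empty(num_tokens, topk, device="xpu", dtype=torch.float32)
    topk_indices = torch.empty(num_tokens, topk, device="xpu", dtype=torch.int32)

    topk_softmax(topk_weights, topk_indices, gating_output, renormalize=False)
    # topk_weights / topk_indices now hold each token's top-k routing weights and expert ids.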
@@ -67,52 +96,67 @@ def calculate_diff(num_tokens, num_experts, topk):
)


num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
def get_benchmark(device="xpu"):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk", "dtype", "renormalize"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang", "native"],
line_names=["SGLang", "native"],
styles=[("blue", "-"), ("green", "-")],
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, dtype, renormalize, provider):

configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
gating_output = torch.randn(
(num_tokens, num_experts), device=device, dtype=dtype
)

if provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk, renormalize)
elif provider == "native":
fn = lambda: native_topk_softmax(gating_output, topk, renormalize)

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang", "vllm"],
line_names=["SGLang", "VLLM"],
styles=[("blue", "-"), ("green", "-")],
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, provider):
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

gating_output = torch.randn(
(num_tokens, num_experts), device="cuda", dtype=torch.float32
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms

return benchmark

if provider == "vllm" or provider == "vllm1":
fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk)

quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
if __name__ == "__main__":
# Run correctness test on small configs if not using a real model
args = parse_args()
params = get_model_config(args)

sweep_params = {
"num_tokens": args.num_tokens,
"num_experts": params["num_experts"] or [64],
"top_k": params["top_k"] or [2, 4],
"dtype": [torch.bfloat16],
"renormalize": [False],
}

keys = sweep_params.keys()
configs = list(itertools.product(*sweep_params.values()))
print(f"Testing {len(configs)} configurations...")
for config in configs:
num_tokens, num_experts, topk, dtype, renormalize = config
print(
f"Config: num_tokens={num_tokens}, num_experts={num_experts}, topk={topk}, dtype={dtype}, renormalize={renormalize}"
)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms
# calculate_diff(num_tokens, num_experts, topk)

global benchmark_configs
benchmark_configs = configs

if __name__ == "__main__":
configs = [
(20, 256, 4),
(20, 256, 8),
(20, 12, 4),
(20, 12, 1),
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in configs:
calculate_diff(num_tokens, num_experts, topk)
benchmark.run(print_data=True)
# Run benchmark
print("Starting performance benchmark...")
benchmark = get_benchmark()
benchmark.run(print_data=True, show_plots=False, save_path=".")
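The sweep values themselves come from utils.parse_args() and utils.get_model_config(), which are not part of this diff. A small illustration of how the itertools.product sweep above expands, using made-up stand-in values in place of those helpers:

    import itertools
    import torch

    sweep_params = {
        "num_tokens": [128, 1024],      # stand-in for args.num_tokens
        "num_experts": [64],            # stand-in for params["num_experts"]
        "top_k": [2, 4],
        "dtype": [torch.bfloat16],
        "renormalize": [False],
    }

    configs = list(itertools.product(*sweep_params.values()))
    # Each entry lines up with x_names: (num_tokens, num_experts, topk, dtype, renormalize).
    print(len(configs), configs[0])  # 4 (128, 64, 2, torch.bfloat16, False)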