Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 0 additions & 75 deletions python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py

This file was deleted.

83 changes: 43 additions & 40 deletions python/sglang/jit_kernel/benchmark/bench_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,39 @@
from flashinfer.norm import fused_add_rmsnorm as fi_fused_add_rmsnorm
from flashinfer.norm import rmsnorm as fi_rmsnorm

from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark
from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm
from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.utils import is_in_ci

register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large")
register_cuda_ci(est_time=30, suite="stage-b-kernel-benchmark-1-gpu-large")

IS_CI = is_in_ci()

DTYPE = torch.bfloat16
DEVICE = "cuda"

# JIT rmsnorm: hidden_size in {64,128,256} or (multiple of 256, <=8192)
# JIT fused_add_rmsnorm: hidden_size % 8 == 0, <=8192
# Use multiples of 256 <=8192 to satisfy both kernels
if IS_CI:
BS_LIST = [16]
HIDDEN_SIZE_LIST = [512, 2048]
else:
BS_LIST = [2**n for n in range(0, 14)]
HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192]

LINE_VALS = ["jit", "flashinfer"]
LINE_NAMES = ["SGL JIT Kernel", "FlashInfer"]
BS_LIST = get_benchmark_range(
full_range=[2**n for n in range(0, 14)],
ci_range=[16, 32],
)
HIDDEN_SIZE_LIST = get_benchmark_range(
full_range=sorted([1536, *range(1024, 8192 + 1, 1024)]),
ci_range=[512, 2048],
)

LINE_VALS = ["flashinfer", "jit"]
LINE_NAMES = ["FlashInfer", "SGL JIT Kernel"]
STYLES = [("blue", "--"), ("green", "-.")]
NUM_LAYERS = 4 # avoid L2 effect

configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST))
configs_0 = list(itertools.product(HIDDEN_SIZE_LIST + [16384], BS_LIST))
configs_1 = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST))


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["hidden_size", "batch_size"],
x_vals=configs,
x_vals=configs_0,
line_arg="provider",
line_vals=LINE_VALS,
line_names=LINE_NAMES,
Expand All @@ -50,20 +49,24 @@
)
)
def benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str):
input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)
weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)
FN_MAP = {
"jit": lambda: jit_rmsnorm(input.clone(), weight),
"flashinfer": lambda: fi_rmsnorm(input.clone(), weight, out=input.clone()),
}
fn = FN_MAP[provider]
return run_benchmark(fn)
input = torch.randn(
(NUM_LAYERS, batch_size, hidden_size), dtype=DTYPE, device=DEVICE
)
weight = torch.randn((NUM_LAYERS, hidden_size), dtype=DTYPE, device=DEVICE)
FN_MAP = {"jit": jit_rmsnorm, "flashinfer": fi_rmsnorm}

def f():
fn = FN_MAP[provider]
for i in range(NUM_LAYERS):
fn(input[i], weight[i], out=input[i])

return run_benchmark(f, scale=NUM_LAYERS)


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["hidden_size", "batch_size"],
x_vals=configs,
x_vals=configs_1,
line_arg="provider",
line_vals=LINE_VALS,
line_names=LINE_NAMES,
Expand All @@ -74,19 +77,19 @@ def benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str):
)
)
def benchmark_fused_add_rmsnorm(hidden_size: int, batch_size: int, provider: str):
input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)
residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)
weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)
FN_MAP = {
"jit": lambda: jit_fused_add_rmsnorm(
input.clone(), residual.clone(), weight, torch.finfo(DTYPE).eps
),
"flashinfer": lambda: fi_fused_add_rmsnorm(
input.clone(), residual.clone(), weight, eps=torch.finfo(DTYPE).eps
),
}
fn = FN_MAP[provider]
return run_benchmark(fn)
input = torch.randn(
(NUM_LAYERS, batch_size, hidden_size), dtype=DTYPE, device=DEVICE
)
residual = torch.randn_like(input)
weight = torch.randn((NUM_LAYERS, hidden_size), dtype=DTYPE, device=DEVICE)
FN_MAP = {"jit": jit_fused_add_rmsnorm, "flashinfer": fi_fused_add_rmsnorm}

def f():
fn = FN_MAP[provider]
for i in range(NUM_LAYERS):
fn(input[i], residual[i], weight[i])

return run_benchmark(f, scale=NUM_LAYERS)


if __name__ == "__main__":
Expand Down
98 changes: 0 additions & 98 deletions python/sglang/jit_kernel/benchmark/bench_rmsnorm.py

This file was deleted.

19 changes: 12 additions & 7 deletions python/sglang/jit_kernel/benchmark/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Common utilities for jit_kernel benchmark files."""

from typing import Callable, List, Tuple
from typing import Callable, List, Sequence, Tuple

import torch
import triton.testing
Expand All @@ -19,25 +19,30 @@ def get_benchmark_range(full_range: List, ci_range: List) -> List:


def run_benchmark(
fn: Callable, quantiles: List[float] = None
fn: Callable,
quantiles: Sequence[float] = (),
scale: float = 1.0,
) -> Tuple[float, float, float]:
"""Execute benchmark using CUDA graph and return times in microseconds.

Args:
fn: Function to benchmark
quantiles: Quantiles for timing measurements [median, min, max]
scale: Scale the result down (usually num_layers).

Returns:
Tuple of (median_us, max_us, min_us)
"""
quantiles = quantiles or DEFAULT_QUANTILES
quantiles = list(quantiles or DEFAULT_QUANTILES)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return 1000 * ms / scale, 1000 * max_ms / scale, 1000 * min_ms / scale


def run_benchmark_no_cudagraph(
fn: Callable, quantiles: List[float] = None
fn: Callable,
quantiles: Sequence[float] = (),
scale: float = 1.0,
) -> Tuple[float, float, float]:
quantiles = quantiles or DEFAULT_QUANTILES
quantiles = list(quantiles or DEFAULT_QUANTILES)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return 1000 * ms / scale, 1000 * max_ms / scale, 1000 * min_ms / scale
Loading
Loading