Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e79bd13
[Kernel] Fuse temperature scaling + softmax into single Triton kernel…
Godmook Mar 13, 2026
7a5928a
Add AutoTune
Godmook Mar 13, 2026
dcd11ad
Add Dispatcher
Godmook Mar 13, 2026
97a84bd
Fix Errors
Godmook Mar 26, 2026
8d73f49
Add Autotune and Refactor Minor changes
Godmook Mar 30, 2026
131ce9d
Fix Lint and Merge Bench
Godmook Mar 30, 2026
e333bc5
Fix AMD CI Crash
Godmook Mar 30, 2026
4138803
Merge branch 'main' into fused_sampling
Godmook Mar 31, 2026
fa120a8
Change Dtype
Godmook Mar 31, 2026
cc84c88
Remove Triton Autotune at In-Place Kernel
Godmook Mar 31, 2026
f356ce2
Fix Sampler Issues
Godmook Apr 1, 2026
bf92dc1
Remove WarmUp for Test Notebook CI
Godmook Apr 1, 2026
066dbb0
Remove WarmUp for Test Notebook CI2
Godmook Apr 1, 2026
d5eeaca
Rollback inplace kernel
Godmook Apr 1, 2026
e27f2b8
Merge branch 'main' into fused_sampling
Godmook Apr 1, 2026
08161e1
sampler: use OOP fused softmax for grammar batches (fix structured ou…
Godmook Apr 1, 2026
f30d3d8
Change to 3-Pass Kernel
Godmook Apr 1, 2026
e4e8954
Add Hybrid Logic
Godmook Apr 1, 2026
a06721d
Add TestCoverage More detailed
Godmook Apr 5, 2026
72b214c
Merge upstream/main into fused_sampling
Godmook Apr 5, 2026
97518a4
Tolerance Update
Godmook Apr 5, 2026
3c8d733
Tolerance Updates
Godmook Apr 5, 2026
8148e92
Sample Size 100
Godmook Apr 5, 2026
932069b
Modify Future Flaky Tests
Godmook Apr 5, 2026
9c560d5
Remove DocString
Godmook Apr 5, 2026
661a1de
Fix Lint
Godmook Apr 5, 2026
c709eb4
Modify DocString
Godmook Apr 5, 2026
aaf61aa
Try online softmax
Godmook Apr 6, 2026
e12fbb4
Modify Threshold and rename SUITE Name
Godmook Apr 6, 2026
afc7193
Refactoring tiny formula
Godmook Apr 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions benchmark/kernels/bench_fused_temperature_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Benchmark: fused_temperature_softmax vs separate div_ + softmax vs flashinfer.sampling.softmax.

Each path clones logits every iteration so timing is not skewed by in-place reuse.
Uses torch.cuda.Event timing; default 50 warmup, 200 timed iterations.

Columns tri/base and fi/base are speedup vs PyTorch baseline; tri/fi is t_flashinfer/t_triton
(>1 means Triton is faster).
"""

import argparse

import torch


def benchmark_fn(fn, warmup=50, iters=200):
    """Measure the mean GPU latency of a zero-arg callable in microseconds.

    Executes ``warmup`` untimed calls first so compilation/caching costs are
    excluded, then times ``iters`` back-to-back calls with CUDA events.

    Args:
        fn: Zero-argument callable launching the GPU work to be timed.
        warmup: Number of untimed warm-up invocations.
        iters: Number of timed invocations averaged over.

    Returns:
        Average latency per call, in microseconds.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    tick = torch.cuda.Event(enable_timing=True)
    tock = torch.cuda.Event(enable_timing=True)
    tick.record()
    for _ in range(iters):
        fn()
    tock.record()
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds; scale to per-iteration microseconds.
    elapsed_ms = tick.elapsed_time(tock)
    return elapsed_ms / iters * 1000


def main():
    """Sweep (batch, vocab, dtype) configs and print one timing row per config.

    Compares four softmax-with-temperature paths: the PyTorch baseline
    (div_ + softmax), the fused Triton kernel (out-of-place and in-place
    variants), and FlashInfer's softmax. Every path clones the source
    logits per iteration so all of them pay the same copy cost.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int, default=50)
    parser.add_argument("--iters", type=int, default=200)
    args = parser.parse_args()

    # Imported lazily so argument parsing / --help works without these deps.
    from flashinfer.sampling import softmax as flashinfer_softmax

    from sglang.srt.layers.fused_sampling import (
        fused_temperature_softmax,
        fused_temperature_softmax_inplace,
    )

    configs = [
        # (batch_size, vocab_size, dtype)
        (1, 32000, torch.bfloat16),
        (1, 128256, torch.bfloat16),
        (32, 32000, torch.bfloat16),
        (32, 128256, torch.bfloat16),
        (128, 32000, torch.bfloat16),
        (128, 128256, torch.bfloat16),
        (512, 32000, torch.bfloat16),
        (512, 128256, torch.bfloat16),
    ]

    header = (
        f"{'bs':>5} {'vocab':>7} {'dtype':>8} "
        f"{'baseline (us)':>14} {'triton (us)':>12} {'inplace (us)':>13} {'flashinfer (us)':>16} "
        f"{'tri/base':>9} {'fi/base':>8} {'tri/fi':>7}"
    )
    print(header)
    print("-" * len(header))

    for batch, vocab_size, dt in configs:
        # Temperatures in [0.1, 1.6) to avoid division by ~0.
        temperature = torch.rand(batch, 1, dtype=torch.float32, device="cuda") * 1.5 + 0.1
        temperature_flat = temperature.view(-1)  # FlashInfer expects a 1-D tensor.
        base_logits = torch.randn(batch, vocab_size, dtype=dt, device="cuda")

        # --- Baseline: div_ + softmax ---
        def baseline_step():
            work = base_logits.clone()
            work.div_(temperature)
            work[:] = torch.softmax(work, dim=-1)

        # --- Triton fused (out-of-place) ---
        def triton_step():
            fused_temperature_softmax(base_logits.clone(), temperature)

        # --- Triton fused (in-place) ---
        def inplace_step():
            work = base_logits.clone()
            fused_temperature_softmax_inplace(work, temperature)

        # --- FlashInfer (clone each iter, same as other paths) ---
        def flashinfer_step():
            work = base_logits.clone()
            flashinfer_softmax(work, temperature=temperature_flat)

        t_base = benchmark_fn(baseline_step, args.warmup, args.iters)
        t_triton = benchmark_fn(triton_step, args.warmup, args.iters)
        t_ip = benchmark_fn(inplace_step, args.warmup, args.iters)
        t_fi = benchmark_fn(flashinfer_step, args.warmup, args.iters)

        print(
            f"{batch:>5} {vocab_size:>7} {str(dt):>8} "
            f"{t_base:>14.1f} {t_triton:>12.1f} {t_ip:>13.1f} {t_fi:>16.1f} "
            f"{t_base / t_triton:>8.2f}x {t_base / t_fi:>7.2f}x {t_fi / t_triton:>6.2f}x"
        )


if __name__ == "__main__":
main()
Loading
Loading