diff --git a/benchmark/kernels/bench_fused_temperature_softmax.py b/benchmark/kernels/bench_fused_temperature_softmax.py
new file mode 100644
index 000000000000..fc624b721ecf
--- /dev/null
+++ b/benchmark/kernels/bench_fused_temperature_softmax.py
@@ -0,0 +1,108 @@
+"""Benchmark: fused_temperature_softmax vs separate div_ + softmax vs flashinfer.sampling.softmax.
+
+Each path clones logits every iteration so timing is not skewed by in-place reuse.
+Uses torch.cuda.Event timing; default 50 warmup, 200 timed iterations.
+
+Columns tri/base and fi/base are speedup vs PyTorch baseline; tri/fi is t_flashinfer/t_triton
+(>1 means Triton is faster).
+"""
+
+import argparse
+
+import torch
+
+
+def benchmark_fn(fn, warmup=50, iters=200):
+    """Time a zero-arg callable using CUDA events; returns mean microseconds per call."""
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    for _ in range(iters):
+        fn()
+    end.record()
+    torch.cuda.synchronize()
+    # elapsed_time() is in milliseconds; convert per-iteration mean to microseconds.
+    return start.elapsed_time(end) / iters * 1000  # microseconds
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--warmup", type=int, default=50)
+    parser.add_argument("--iters", type=int, default=200)
+    args = parser.parse_args()
+
+    from flashinfer.sampling import softmax as flashinfer_softmax
+
+    from sglang.srt.layers.fused_sampling import (
+        fused_temperature_softmax,
+        fused_temperature_softmax_inplace,
+    )
+
+    configs = [
+        # (batch_size, vocab_size, dtype)
+        (1, 32000, torch.bfloat16),
+        (1, 128256, torch.bfloat16),
+        (32, 32000, torch.bfloat16),
+        (32, 128256, torch.bfloat16),
+        (128, 32000, torch.bfloat16),
+        (128, 128256, torch.bfloat16),
+        (512, 32000, torch.bfloat16),
+        (512, 128256, torch.bfloat16),
+    ]
+
+    header = (
+        f"{'bs':>5} {'vocab':>7} {'dtype':>8} "
+        f"{'baseline (us)':>14} {'triton (us)':>12} {'inplace (us)':>13} {'flashinfer (us)':>16} "
+        f"{'tri/base':>9} {'fi/base':>8} {'tri/fi':>7}"
+    )
+    print(header)
+    print("-" * len(header))
+
+    for bs, vocab, dtype in configs:
+        temps = torch.rand(bs, 1, dtype=torch.float32, device="cuda") * 1.5 + 0.1
+        temps_1d = temps.view(-1)
+        logits_src = torch.randn(bs, vocab, dtype=dtype, device="cuda")
+
+        # --- Baseline: div_ + softmax ---
+        def run_baseline(src=logits_src, t=temps):
+            l = src.clone()
+            l.div_(t)
+            l[:] = torch.softmax(l, dim=-1)
+
+        t_base = benchmark_fn(run_baseline, args.warmup, args.iters)
+
+        # --- Triton fused (out-of-place) ---
+        def run_triton(src=logits_src, t=temps):
+            fused_temperature_softmax(src.clone(), t)
+
+        t_triton = benchmark_fn(run_triton, args.warmup, args.iters)
+
+        # --- Triton fused (in-place) ---
+        def run_inplace(src=logits_src, t=temps):
+            l = src.clone()
+            fused_temperature_softmax_inplace(l, t)
+
+        t_ip = benchmark_fn(run_inplace, args.warmup, args.iters)
+
+        # --- FlashInfer (clone each iter, same as other paths) ---
+        def run_flashinfer(src=logits_src, t=temps_1d):
+            l = src.clone()
+            flashinfer_softmax(l, temperature=t)
+
+        t_fi = benchmark_fn(run_flashinfer, args.warmup, args.iters)
+
+        sp_triton = t_base / t_triton
+        sp_fi = t_base / t_fi
+        tri_vs_fi = t_fi / t_triton
+        print(
+            f"{bs:>5} {vocab:>7} {str(dtype):>8} "
+            f"{t_base:>14.1f} {t_triton:>12.1f} {t_ip:>13.1f} {t_fi:>16.1f} "
+            f"{sp_triton:>8.2f}x {sp_fi:>7.2f}x {tri_vs_fi:>6.2f}x"
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/sglang/srt/layers/fused_sampling.py b/python/sglang/srt/layers/fused_sampling.py
new file mode 100644
index 000000000000..55852c2f0f34
--- /dev/null
+++ b/python/sglang/srt/layers/fused_sampling.py
@@ -0,0 +1,371 @@
+"""Fused Triton kernels for the sampling pipeline.
+
+Fuses temperature scaling + softmax into a single kernel to reduce
+kernel launch overhead and global memory traffic during decode.
+
+Two kernel variants:
+  - Single-pass: vocab fits in one tile (1 read + 1 write). Used when
+    next_power_of_2(vocab) <= 32768.
+  - Multi-pass: 2-pass online softmax with autotune (2 reads + 1 write).
+    Used for large vocabs (e.g. 128K+).
+"""
+
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+# Largest tile the single-pass kernel will handle; beyond this we loop.
+_MAX_SINGLE_PASS_BLOCK = 32768
+
+# ---------------------------------------------------------------------------
+# Single-pass kernel: entire vocab fits in one BLOCK_SIZE tile.
+# Data stays in registers — only 1 global memory read + 1 write.
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _single_pass_temperature_softmax_kernel(
+    logits_ptr,
+    temperatures_ptr,
+    output_ptr,
+    vocab_size,
+    logits_stride,
+    output_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # One program per row; each row has its own temperature.
+    row_idx = tl.program_id(0)
+    temp = tl.load(temperatures_ptr + row_idx)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < vocab_size
+
+    x = tl.load(
+        logits_ptr + row_idx * logits_stride + offsets,
+        mask=mask,
+        other=float("-inf"),
+    )
+    x = (x / temp).to(tl.float32)
+
+    # Numerically stable softmax: subtract the row max before exponentiating.
+    x_max = tl.max(x, axis=0)
+    exp_x = tl.exp(x - x_max)
+    prob = exp_x / tl.sum(exp_x, axis=0)
+
+    tl.store(output_ptr + row_idx * output_stride + offsets, prob, mask=mask)
+
+
+@triton.jit
+def _single_pass_temperature_softmax_inplace_kernel(
+    logits_ptr,
+    temperatures_ptr,
+    vocab_size,
+    stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_idx = tl.program_id(0)
+    temp = tl.load(temperatures_ptr + row_idx)
+
+    row_start = logits_ptr + row_idx * stride
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < vocab_size
+
+    x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
+    x = (x / temp).to(tl.float32)
+
+    x_max = tl.max(x, axis=0)
+    exp_x = tl.exp(x - x_max)
+    prob = exp_x / tl.sum(exp_x, axis=0)
+
+    # Overwrites the input row; store casts back to the buffer's dtype.
+    tl.store(row_start + offsets, prob, mask=mask)
+
+
+# ---------------------------------------------------------------------------
+# Multi-pass kernel: vocab too large for one tile.
+# 2-pass online softmax with autotune over (BLOCK_SIZE, num_warps).
+# ---------------------------------------------------------------------------
+
+_MULTI_PASS_AUTOTUNE_CONFIGS = [
+    triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
+    triton.Config({"BLOCK_SIZE": 2048}, num_warps=16),
+    triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
+    triton.Config({"BLOCK_SIZE": 4096}, num_warps=16),
+    triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=4),
+    triton.Config({"BLOCK_SIZE": 8192}, num_warps=16),
+    triton.Config({"BLOCK_SIZE": 8192}, num_warps=32),
+    triton.Config({"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
+    triton.Config({"BLOCK_SIZE": 16384}, num_warps=16),
+    triton.Config({"BLOCK_SIZE": 16384}, num_warps=32),
+    triton.Config({"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
+    triton.Config({"BLOCK_SIZE": 32768}, num_warps=32),
+    triton.Config({"BLOCK_SIZE": 32768}, num_warps=32, num_stages=4),
+]
+
+
+@triton.autotune(configs=_MULTI_PASS_AUTOTUNE_CONFIGS, key=["vocab_size"])
+@triton.jit
+def _multi_pass_temperature_softmax_kernel(
+    logits_ptr,
+    temperatures_ptr,
+    output_ptr,
+    vocab_size,
+    logits_stride,
+    output_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_idx = tl.program_id(0)
+    temp = tl.load(temperatures_ptr + row_idx)
+
+    logits_row = logits_ptr + row_idx * logits_stride
+    output_row = output_ptr + row_idx * output_stride
+
+    # Pass 1: find global max (matches PyTorch's first reduction pass)
+    global_max = tl.full([], value=float("-inf"), dtype=tl.float32)
+    for start in range(0, vocab_size, BLOCK_SIZE):
+        offsets = start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < vocab_size
+        x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
+        x = (x / temp).to(tl.float32)
+        global_max = tl.maximum(global_max, tl.max(x, axis=0))
+
+    # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass)
+    sum_exp = tl.full([], value=0.0, dtype=tl.float32)
+    for start in range(0, vocab_size, BLOCK_SIZE):
+        offsets = start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < vocab_size
+        x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
+        x = (x / temp).to(tl.float32)
+        sum_exp += tl.sum(tl.exp(x - global_max), axis=0)
+
+    # Pass 3: normalize (matches PyTorch's exp(x-max)/sum)
+    for start in range(0, vocab_size, BLOCK_SIZE):
+        offsets = start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < vocab_size
+        x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
+        x = (x / temp).to(tl.float32)
+
+        prob = tl.exp(x - global_max) / sum_exp
+        tl.store(output_row + offsets, prob, mask=mask)
+
+
+@triton.jit
+def _multi_pass_temperature_softmax_inplace_kernel(
+    logits_ptr,
+    temperatures_ptr,
+    vocab_size,
+    stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row_idx = tl.program_id(0)
+    temp = tl.load(temperatures_ptr + row_idx)
+
+    row_start = logits_ptr + row_idx * stride
+
+    # Pass 1: find global max (matches PyTorch's first reduction pass)
+    global_max = tl.full([], value=float("-inf"), dtype=tl.float32)
+    for start in range(0, vocab_size, BLOCK_SIZE):
+        offsets = start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < vocab_size
+        x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
+        x = (x / temp).to(tl.float32)
+        global_max = tl.maximum(global_max, tl.max(x, axis=0))
+
+    # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass)
+    sum_exp = tl.full([], value=0.0, dtype=tl.float32)
+    for start in range(0, vocab_size, BLOCK_SIZE):
+        offsets = start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < vocab_size
+        x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
+        x = (x / temp).to(tl.float32)
+        sum_exp += tl.sum(tl.exp(x - global_max), axis=0)
+
+    # Pass 3: normalize (matches PyTorch's exp(x-max)/sum).
+    # Safe in-place: each program only rewrites its own row after both
+    # reduction passes over that row have completed.
+    for start in range(0, vocab_size, BLOCK_SIZE):
+        offsets = start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < vocab_size
+        x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
+        x = (x / temp).to(tl.float32)
+
+        prob = tl.exp(x - global_max) / sum_exp
+        tl.store(row_start + offsets, prob, mask=mask)
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+_DEFAULT_MULTI_PASS_CONFIG = {"BLOCK_SIZE": 4096, "num_warps": 16}
+
+# Populated by warmup from the out-of-place kernel's autotune result.
+_multi_pass_inplace_config: dict | None = None
+
+
+def _single_pass_num_warps(block_size: int) -> int:
+    """Heuristic warp count for the single-pass kernel, clamped to [4, 32]."""
+    return max(4, min(32, block_size // 256))
+
+
+def _get_multi_pass_inplace_config() -> dict:
+    """Return the launch config for the multi-pass in-place kernel."""
+    if _multi_pass_inplace_config is not None:
+        return _multi_pass_inplace_config
+    return _DEFAULT_MULTI_PASS_CONFIG
+
+
+def _dispatch_kernel(
+    logits: torch.Tensor,
+    temperatures_flat: torch.Tensor,
+    vocab_size: int,
+    batch_size: int,
+    output: torch.Tensor | None = None,
+) -> None:
+    """Dispatch to single-pass or multi-pass kernel. output=None means in-place."""
+    grid = (batch_size,)
+    block_size = triton.next_power_of_2(vocab_size)
+    inplace = output is None
+
+    if block_size <= _MAX_SINGLE_PASS_BLOCK:
+        if inplace:
+            _single_pass_temperature_softmax_inplace_kernel[grid](
+                logits,
+                temperatures_flat,
+                vocab_size,
+                logits.stride(0),
+                BLOCK_SIZE=block_size,
+                num_warps=_single_pass_num_warps(block_size),
+            )
+        else:
+            _single_pass_temperature_softmax_kernel[grid](
+                logits,
+                temperatures_flat,
+                output,
+                vocab_size,
+                logits.stride(0),
+                output.stride(0),
+                BLOCK_SIZE=block_size,
+                num_warps=_single_pass_num_warps(block_size),
+            )
+    else:
+        if inplace:
+            cfg = _get_multi_pass_inplace_config()
+            _multi_pass_temperature_softmax_inplace_kernel[grid](
+                logits,
+                temperatures_flat,
+                vocab_size,
+                logits.stride(0),
+                **cfg,
+            )
+        else:
+            # Autotuned variant: BLOCK_SIZE/num_warps come from the autotuner.
+            _multi_pass_temperature_softmax_kernel[grid](
+                logits,
+                temperatures_flat,
+                output,
+                vocab_size,
+                logits.stride(0),
+                output.stride(0),
+            )
+
+
+def fused_temperature_softmax(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+) -> torch.Tensor:
+    """Fused temperature scaling + softmax. Returns float32 probabilities."""
+    batch_size, vocab_size = logits.shape
+    if batch_size == 0:
+        return torch.empty(0, vocab_size, dtype=torch.float32, device=logits.device)
+
+    if not logits.is_contiguous():
+        logits = logits.contiguous()
+
+    output = torch.empty(
+        batch_size, vocab_size, dtype=torch.float32, device=logits.device
+    )
+    temperatures_flat = temperatures.contiguous().view(-1)
+    _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size, output)
+    return output
+
+
+def fused_temperature_softmax_inplace(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+) -> None:
+    """In-place fused temperature scaling + softmax. Overwrites logits with probabilities."""
+    batch_size, vocab_size = logits.shape
+    if batch_size == 0:
+        return
+
+    if not logits.is_contiguous():
+        # Kernel requires a contiguous buffer: run on a copy, then write back.
+        work = logits.contiguous()
+        fused_temperature_softmax_inplace(work, temperatures)
+        logits.copy_(work)
+        return
+
+    temperatures_flat = temperatures.contiguous().view(-1)
+    _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size)
+
+
+def warmup_fused_temperature_softmax(
+    vocab_size: int,
+    device: torch.device | int | None = None,
+    logits_dtype: torch.dtype = torch.float32,
+) -> None:
+    """Pre-compile and autotune kernels at startup so first request has no latency spike.
+
+    For multi-pass kernels the out-of-place variant is autotuned (safe — separate
+    input/output buffers), and its winning config is reused for the in-place
+    variant so that no autotune ever runs on a live logits buffer.
+
+    ``logits_dtype`` should match ``next_token_logits`` at inference (usually
+    ``model_config.dtype``) so Triton specializes the same way as in production.
+ """ + global _multi_pass_inplace_config + + if device is None: + device = torch.cuda.current_device() + + block_size = triton.next_power_of_2(vocab_size) + is_multi_pass = block_size > _MAX_SINGLE_PASS_BLOCK + label = "multi-pass autotune" if is_multi_pass else "single-pass JIT" + logger.info( + "Warming up fused_temperature_softmax (%s, vocab_size=%d, logits_dtype=%s) ...", + label, + vocab_size, + logits_dtype, + ) + + dummy_logits = torch.randn(1, vocab_size, dtype=logits_dtype, device=device) + dummy_temps = torch.ones(1, 1, dtype=torch.float32, device=device) + + # 1. Out-of-place kernel: autotune runs here (safe, separate buffers). + fused_temperature_softmax(dummy_logits, dummy_temps) + + # 2. Propagate best config to the in-place kernel (no autotune needed). + if is_multi_pass: + best = getattr(_multi_pass_temperature_softmax_kernel, "best_config", None) + if best is not None: + _multi_pass_inplace_config = { + "BLOCK_SIZE": best.kwargs["BLOCK_SIZE"], + "num_warps": best.num_warps, + } + if best.num_stages is not None: + _multi_pass_inplace_config["num_stages"] = best.num_stages + ns = _multi_pass_inplace_config.get("num_stages", "default") + logger.info( + "Multi-pass autotune result: BLOCK_SIZE=%d, num_warps=%d, num_stages=%s", + _multi_pass_inplace_config["BLOCK_SIZE"], + _multi_pass_inplace_config["num_warps"], + ns, + ) + else: + _multi_pass_inplace_config = None + logger.warning( + "Multi-pass fused softmax: autotune did not set best_config; " + "using default launch config for in-place kernel." + ) + + # 3. In-place kernel: JIT compile only (uses the config from step 2). 
+    fused_temperature_softmax_inplace(dummy_logits.clone(), dummy_temps)
+    torch.cuda.synchronize(device)
+
+    logger.info("fused_temperature_softmax warmup done (vocab_size=%d).", vocab_size)
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index e947a48cbde8..4196787820f4 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -18,6 +18,7 @@
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils.common import crash_on_warnings, get_bool_env_var, is_cuda, is_npu
 
+_use_fused_sampling = False
 if is_cuda():
     from flashinfer.sampling import (
         min_p_sampling_from_probs,
@@ -27,6 +28,15 @@
         top_k_renorm_prob,
         top_p_renorm_prob,
     )
+
+    from sglang.srt.layers.fused_sampling import fused_temperature_softmax_inplace
+
+    _use_fused_sampling = True
+
+# Batch size threshold for fused Triton kernel vs PyTorch softmax.
+# Below this threshold, PyTorch's native div+softmax is faster.
+# At and above this threshold, the fused Triton kernel wins.
+_FUSED_SAMPLING_BATCH_THRESHOLD = 128
 if is_npu():
     import torch_npu
@@ -152,11 +162,20 @@ def forward(
             logprobs = logprobs_via_logsoftmax_kernel
         else:
             # Standard path: do softmax and sample from probs.
-            logits.div_(sampling_info.temperatures)
-
-            # In-place op to save memory
-            logits[:] = torch.softmax(logits, dim=-1)
-            probs = logits
+            # Use fused Triton kernel for large batches where it excels;
+            # fall back to PyTorch for small batches where launch overhead dominates.
+            if (
+                _use_fused_sampling
+                and logits.shape[0] >= _FUSED_SAMPLING_BATCH_THRESHOLD
+            ):
+                fused_temperature_softmax_inplace(
+                    logits, sampling_info.temperatures
+                )
+                probs = logits
+            else:
+                logits.div_(sampling_info.temperatures)
+                logits[:] = torch.softmax(logits, dim=-1)
+                probs = logits
 
         batch_next_token_ids = self._sample_from_probs(
             probs, sampling_info, positions, simple_sampling_case
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index b1dee3fb6c31..80b0e959bfa0 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -2019,6 +2019,22 @@ def kernel_warmup(self):
         if self._should_run_flashinfer_autotune():
             self._flashinfer_autotune()
 
+        self._warmup_fused_sampling()
+
+    def _warmup_fused_sampling(self):
+        """Pre-compile and autotune fused sampling Triton kernels."""
+        if _is_hip:
+            return
+        from sglang.srt.layers.fused_sampling import warmup_fused_temperature_softmax
+
+        logits_warmup_dtype = (
+            torch.float32 if self.server_args.enable_fp32_lm_head else self.dtype
+        )
+        warmup_fused_temperature_softmax(
+            self.model_config.vocab_size,
+            logits_dtype=logits_warmup_dtype,
+        )
+
     def _should_run_flashinfer_autotune(self) -> bool:
         """Check if flashinfer autotune should be run."""
         if self.server_args.disable_flashinfer_autotune:
diff --git a/test/registered/sampling/test_fused_temperature_softmax.py b/test/registered/sampling/test_fused_temperature_softmax.py
new file mode 100644
index 000000000000..956d67c2cefe
--- /dev/null
+++ b/test/registered/sampling/test_fused_temperature_softmax.py
@@ -0,0 +1,268 @@
+"""Correctness tests for fused_temperature_softmax Triton kernel."""
+
+import unittest
+
+import torch
+from flashinfer.sampling import softmax as flashinfer_softmax
+
+from sglang.srt.layers.fused_sampling import (
+    fused_temperature_softmax,
+    fused_temperature_softmax_inplace,
+)
+from sglang.srt.utils import get_device
+from sglang.test.ci.ci_register import register_cuda_ci
+from sglang.test.test_utils import CustomTestCase
+
+register_cuda_ci(est_time=15, suite="stage-b-test-small-1-gpu")
+
+
+def reference_temperature_softmax(logits, temperatures):
+    """Reference implementation: div + softmax (separate kernels)."""
+    logits = logits.clone()
+    logits.div_(temperatures)
+    return torch.softmax(logits, dim=-1).float()
+
+
+class TestFusedTemperatureSoftmax(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        torch.set_default_device(get_device())
+        torch.manual_seed(42)
+
+    def _check_close(self, fused, ref, atol=1e-5, rtol=1e-5):
+        """Assert outputs are close and both are valid probability distributions."""
+        self.assertEqual(fused.shape, ref.shape)
+        # Valid probabilities: non-negative, sum to ~1
+        self.assertTrue((fused >= 0).all(), "Negative probabilities in fused output")
+        row_sums = fused.sum(dim=-1)
+        torch.testing.assert_close(
+            row_sums,
+            torch.ones_like(row_sums),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+        torch.testing.assert_close(fused, ref, atol=atol, rtol=rtol)
+
+    # --- out-of-place kernel ---
+
+    def test_basic(self):
+        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
+        temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1)
+        ref = reference_temperature_softmax(logits, temps)
+        fused = fused_temperature_softmax(logits, temps)
+        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)
+
+    def test_large_vocab(self):
+        logits = torch.randn(8, 128256, dtype=torch.bfloat16)
+        temps = torch.full((8, 1), 0.6, dtype=torch.float32)
+        ref = reference_temperature_softmax(logits, temps)
+        fused = fused_temperature_softmax(logits, temps)
+        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)
+
+    def test_batch_sizes(self):
+        for bs in [1, 2, 16, 64, 128, 512]:
+            logits = torch.randn(bs, 32000, dtype=torch.bfloat16)
+            temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1
+            ref = reference_temperature_softmax(logits, temps)
+            fused = fused_temperature_softmax(logits, temps)
+            self._check_close(fused, ref, atol=1e-4, rtol=1e-3)
+
+    def test_temperature_one(self):
+        """Temperature=1.0 should be equivalent to plain softmax."""
+        logits = torch.randn(16, 32000, dtype=torch.bfloat16)
+        temps = torch.ones(16, 1, dtype=torch.float32)
+        ref = torch.softmax(logits.float(), dim=-1)
+        fused = fused_temperature_softmax(logits, temps)
+        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)
+
+    def test_very_low_temperature(self):
+        """Very low temperature should produce near-one-hot distribution."""
+        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
+        temps = torch.full((4, 1), 0.01, dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        # Max probability should be very close to 1.0
+        max_probs = fused.max(dim=-1).values
+        self.assertTrue((max_probs > 0.99).all())
+
+    def test_very_high_temperature(self):
+        """Very high temperature should produce near-uniform distribution."""
+        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
+        temps = torch.full((4, 1), 100.0, dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        uniform = 1.0 / 1024
+        self.assertTrue(
+            (fused - uniform).abs().max() < 0.01,
+            "High temperature should produce near-uniform distribution",
+        )
+
+    def test_fp16_input(self):
+        logits = torch.randn(8, 32000, dtype=torch.float16)
+        temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1
+        ref = reference_temperature_softmax(logits, temps)
+        fused = fused_temperature_softmax(logits, temps)
+        self._check_close(fused, ref, atol=1e-3, rtol=1e-2)
+
+    def test_fp32_input(self):
+        logits = torch.randn(8, 32000, dtype=torch.float32)
+        temps = torch.rand(8, 1, dtype=torch.float32) + 0.5
+        ref = reference_temperature_softmax(logits, temps)
+        fused = fused_temperature_softmax(logits, temps)
+        self._check_close(fused, ref, atol=1e-5, rtol=1e-5)
+
+    def test_mixed_temperatures(self):
+        """Each row has a different temperature."""
+        logits = torch.randn(8, 32000, dtype=torch.bfloat16)
+        temps = torch.tensor(
+            [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32
+        ).view(-1, 1)
+        ref = reference_temperature_softmax(logits, temps)
+        fused = fused_temperature_softmax(logits, temps)
+        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)
+
+    def test_empty_batch(self):
+        logits = torch.randn(0, 32000, dtype=torch.bfloat16)
+        temps = torch.ones(0, 1, dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        self.assertEqual(fused.shape, (0, 32000))
+
+    # --- in-place kernel ---
+
+    def test_inplace_basic(self):
+        logits = torch.randn(8, 32000, dtype=torch.float32)
+        temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1
+        ref = reference_temperature_softmax(logits, temps)
+        fused_temperature_softmax_inplace(logits, temps)
+        # In-place writes back to logits in the original dtype
+        self._check_close(logits.float(), ref, atol=1e-5, rtol=1e-5)
+
+    def test_inplace_bf16(self):
+        logits = torch.randn(8, 32000, dtype=torch.bfloat16)
+        temps = torch.rand(8, 1, dtype=torch.float32) + 0.5
+        ref = reference_temperature_softmax(logits, temps)
+        fused_temperature_softmax_inplace(logits, temps)
+        self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3)
+
+    def test_inplace_large_vocab(self):
+        logits = torch.randn(4, 128256, dtype=torch.bfloat16)
+        temps = torch.full((4, 1), 0.8, dtype=torch.float32)
+        ref = reference_temperature_softmax(logits, temps)
+        fused_temperature_softmax_inplace(logits, temps)
+        self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3)
+
+    # --- exact known-value correctness ---
+
+    def test_known_uniform_logits(self):
+        """Identical logits must produce uniform distribution regardless of temperature."""
+        logits = torch.zeros(2, 5, dtype=torch.float32)
+        temps = torch.tensor([0.5, 2.0], dtype=torch.float32).view(-1, 1)
+        fused = fused_temperature_softmax(logits, temps)
+        expected = torch.full((2, 5), 0.2, dtype=torch.float32, device=fused.device)
+        torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6)
+
+    def test_known_softmax_values(self):
+        """Verify against hand-computed softmax(logits / T)."""
+        logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32)
+        temps = torch.tensor([[1.0]], dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        # softmax([1,2,3]) = exp([1,2,3]) / sum(exp([1,2,3]))
+        e = torch.exp(logits)
+        expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device)
+        torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6)
+
+    def test_known_softmax_with_temperature(self):
+        """Verify softmax([1,2,3] / 0.5) against hand computation."""
+        logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32)
+        temps = torch.tensor([[0.5]], dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        scaled = logits / 0.5
+        e = torch.exp(scaled)
+        expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device)
+        torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6)
+
+    # --- argmax preservation ---
+
+    def test_argmax_preserved(self):
+        """argmax must be invariant to temperature for finite T > 0."""
+        logits = torch.randn(64, 32000, dtype=torch.bfloat16)
+        original_argmax = logits.float().argmax(dim=-1)
+        for t_val in [0.1, 0.5, 1.0, 2.0, 10.0]:
+            temps = torch.full((64, 1), t_val, dtype=torch.float32)
+            fused = fused_temperature_softmax(logits, temps)
+            fused_argmax = fused.argmax(dim=-1)
+            self.assertTrue(
+                (original_argmax == fused_argmax).all(),
+                f"argmax changed at temperature={t_val}",
+            )
+
+    # --- numerical stability ---
+
+    def test_large_logits_no_nan(self):
+        """Extreme logit magnitudes must not produce NaN or Inf."""
+        logits = torch.tensor(
+            [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32
+        )
+        temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        self.assertFalse(torch.isnan(fused).any(), "NaN in output")
+        self.assertFalse(torch.isinf(fused).any(), "Inf in output")
+        row_sums = fused.sum(dim=-1)
+        torch.testing.assert_close(
+            row_sums,
+            torch.ones_like(row_sums),
+            atol=1e-4,
+            rtol=1e-4,
+        )
+
+    def test_large_logits_inplace_no_nan(self):
+        """In-place variant: extreme logits must not produce NaN or Inf."""
+        logits = torch.tensor(
+            [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32
+        )
+        temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32)
+        fused_temperature_softmax_inplace(logits, temps)
+        self.assertFalse(torch.isnan(logits).any(), "NaN in output")
+        self.assertFalse(torch.isinf(logits).any(), "Inf in output")
+
+    # --- comparison with flashinfer.sampling.softmax ---
+
+    def test_vs_flashinfer_basic(self):
+        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
+        temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1)
+        fused = fused_temperature_softmax(logits, temps)
+        fi = flashinfer_softmax(logits, temperature=temps.view(-1))
+        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)
+
+    def test_vs_flashinfer_large_vocab(self):
+        logits = torch.randn(8, 128256, dtype=torch.bfloat16)
+        temps = torch.full((8, 1), 0.6, dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps)
+        fi = flashinfer_softmax(logits, temperature=temps.view(-1))
+        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)
+
+    def test_vs_flashinfer_batch_sizes(self):
+        for bs in [1, 16, 64, 128, 512]:
+            logits = torch.randn(bs, 32000, dtype=torch.bfloat16)
+            temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1
+            fused = fused_temperature_softmax(logits, temps)
+            fi = flashinfer_softmax(logits, temperature=temps.view(-1))
+            self._check_close(fused, fi, atol=1e-4, rtol=1e-3)
+
+    def test_vs_flashinfer_scalar_temperature(self):
+        logits = torch.randn(16, 32000, dtype=torch.bfloat16)
+        temps_2d = torch.full((16, 1), 0.8, dtype=torch.float32)
+        fused = fused_temperature_softmax(logits, temps_2d)
+        fi = flashinfer_softmax(logits, temperature=0.8)
+        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)
+
+    def test_vs_flashinfer_mixed_temperatures(self):
+        logits = torch.randn(8, 32000, dtype=torch.bfloat16)
+        temps = torch.tensor(
+            [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32
+        ).view(-1, 1)
+        fused = fused_temperature_softmax(logits, temps)
+        fi = flashinfer_softmax(logits, temperature=temps.view(-1))
+        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    unittest.main()