"""Benchmark: fused_temperature_softmax vs separate div_ + softmax vs flashinfer.sampling.softmax.

Each path clones logits every iteration so timing is not skewed by in-place reuse.
Uses torch.cuda.Event timing; default 50 warmup, 200 timed iterations.

Columns tri/base and fi/base are speedup vs PyTorch baseline; tri/fi is t_flashinfer/t_triton
(>1 means Triton is faster).
"""

import argparse

import torch


def benchmark_fn(fn, warmup=50, iters=200):
    """Time a zero-arg callable using CUDA events; returns mean latency in microseconds."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1000  # microseconds


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int, default=50)
    parser.add_argument("--iters", type=int, default=200)
    args = parser.parse_args()

    # Deferred imports: flashinfer/sglang are only needed when actually benchmarking.
    from flashinfer.sampling import softmax as flashinfer_softmax

    from sglang.srt.layers.fused_sampling import (
        fused_temperature_softmax,
        fused_temperature_softmax_inplace,
    )

    configs = [
        # (batch_size, vocab_size, dtype)
        (1, 32000, torch.bfloat16),
        (1, 128256, torch.bfloat16),
        (32, 32000, torch.bfloat16),
        (32, 128256, torch.bfloat16),
        (128, 32000, torch.bfloat16),
        (128, 128256, torch.bfloat16),
        (512, 32000, torch.bfloat16),
        (512, 128256, torch.bfloat16),
    ]

    header = (
        f"{'bs':>5} {'vocab':>7} {'dtype':>8} "
        f"{'baseline (us)':>14} {'triton (us)':>12} {'inplace (us)':>13} {'flashinfer (us)':>16} "
        f"{'tri/base':>9} {'fi/base':>8} {'tri/fi':>7}"
    )
    print(header)
    print("-" * len(header))

    for bs, vocab, dtype in configs:
        temps = torch.rand(bs, 1, dtype=torch.float32, device="cuda") * 1.5 + 0.1
        temps_1d = temps.view(-1)
        logits_src = torch.randn(bs, vocab, dtype=dtype, device="cuda")

        # --- Baseline: div_ + softmax ---
        def run_baseline(src=logits_src, t=temps):
            l = src.clone()
            l.div_(t)
            l[:] = torch.softmax(l, dim=-1)

        t_base = benchmark_fn(run_baseline, args.warmup, args.iters)

        # --- Triton fused (out-of-place) ---
        def run_triton(src=logits_src, t=temps):
            fused_temperature_softmax(src.clone(), t)

        t_triton = benchmark_fn(run_triton, args.warmup, args.iters)

        # --- Triton fused (in-place) ---
        def run_inplace(src=logits_src, t=temps):
            l = src.clone()
            fused_temperature_softmax_inplace(l, t)

        t_ip = benchmark_fn(run_inplace, args.warmup, args.iters)

        # --- FlashInfer (clone each iter, same as other paths) ---
        def run_flashinfer(src=logits_src, t=temps_1d):
            l = src.clone()
            flashinfer_softmax(l, temperature=t)

        t_fi = benchmark_fn(run_flashinfer, args.warmup, args.iters)

        sp_triton = t_base / t_triton
        sp_fi = t_base / t_fi
        tri_vs_fi = t_fi / t_triton
        print(
            f"{bs:>5} {vocab:>7} {str(dtype):>8} "
            f"{t_base:>14.1f} {t_triton:>12.1f} {t_ip:>13.1f} {t_fi:>16.1f} "
            f"{sp_triton:>8.2f}x {sp_fi:>7.2f}x {tri_vs_fi:>6.2f}x"
        )


if __name__ == "__main__":
    main()
"""Fused Triton kernels for the sampling pipeline.

Fuses temperature scaling + softmax into a single kernel to reduce
kernel launch overhead and global memory traffic during decode.

Two kernel variants:
  - Single-pass: vocab fits in one tile (1 read + 1 write). Used when
    next_power_of_2(vocab) <= 32768.
  - Multi-pass: 2-pass online softmax with autotune (2 reads + 1 write).
    Used for large vocabs (e.g. 128K+).
"""

import logging

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)

# Largest power-of-two tile handled by the single-pass kernels.
_MAX_SINGLE_PASS_BLOCK = 32768

# ---------------------------------------------------------------------------
# Single-pass kernel: entire vocab fits in one BLOCK_SIZE tile.
# Data stays in registers — only 1 global memory read + 1 write.
# ---------------------------------------------------------------------------


@triton.jit
def _single_pass_temperature_softmax_kernel(
    logits_ptr,
    temperatures_ptr,
    output_ptr,
    vocab_size,
    logits_stride,
    output_stride,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    temp = tl.load(temperatures_ptr + row_idx)

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < vocab_size

    x = tl.load(
        logits_ptr + row_idx * logits_stride + offsets,
        mask=mask,
        other=float("-inf"),
    )
    x = (x / temp).to(tl.float32)

    x_max = tl.max(x, axis=0)
    exp_x = tl.exp(x - x_max)
    prob = exp_x / tl.sum(exp_x, axis=0)

    tl.store(output_ptr + row_idx * output_stride + offsets, prob, mask=mask)


@triton.jit
def _single_pass_temperature_softmax_inplace_kernel(
    logits_ptr,
    temperatures_ptr,
    vocab_size,
    stride,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    temp = tl.load(temperatures_ptr + row_idx)

    row_start = logits_ptr + row_idx * stride
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < vocab_size

    x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
    x = (x / temp).to(tl.float32)

    x_max = tl.max(x, axis=0)
    exp_x = tl.exp(x - x_max)
    prob = exp_x / tl.sum(exp_x, axis=0)

    tl.store(row_start + offsets, prob, mask=mask)


# ---------------------------------------------------------------------------
# Multi-pass kernel: vocab too large for one tile.
# 2-pass online softmax with autotune over (BLOCK_SIZE, num_warps).
# ---------------------------------------------------------------------------

_MULTI_PASS_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    triton.Config({"BLOCK_SIZE": 2048}, num_warps=16),
    triton.Config({"BLOCK_SIZE": 4096}, num_warps=8),
    triton.Config({"BLOCK_SIZE": 4096}, num_warps=16),
    triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=4),
    triton.Config({"BLOCK_SIZE": 8192}, num_warps=16),
    triton.Config({"BLOCK_SIZE": 8192}, num_warps=32),
    triton.Config({"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
    triton.Config({"BLOCK_SIZE": 16384}, num_warps=16),
    triton.Config({"BLOCK_SIZE": 16384}, num_warps=32),
    triton.Config({"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
    triton.Config({"BLOCK_SIZE": 32768}, num_warps=32),
    triton.Config({"BLOCK_SIZE": 32768}, num_warps=32, num_stages=4),
]


@triton.autotune(configs=_MULTI_PASS_AUTOTUNE_CONFIGS, key=["vocab_size"])
@triton.jit
def _multi_pass_temperature_softmax_kernel(
    logits_ptr,
    temperatures_ptr,
    output_ptr,
    vocab_size,
    logits_stride,
    output_stride,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    temp = tl.load(temperatures_ptr + row_idx)

    logits_row = logits_ptr + row_idx * logits_stride
    output_row = output_ptr + row_idx * output_stride

    # Pass 1: find global max (matches PyTorch's first reduction pass)
    global_max = tl.full([], value=float("-inf"), dtype=tl.float32)
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
        x = (x / temp).to(tl.float32)
        global_max = tl.maximum(global_max, tl.max(x, axis=0))

    # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass)
    sum_exp = tl.full([], value=0.0, dtype=tl.float32)
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
        x = (x / temp).to(tl.float32)
        sum_exp += tl.sum(tl.exp(x - global_max), axis=0)

    # Pass 3: normalize (matches PyTorch's exp(x-max)/sum)
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        x = tl.load(logits_row + offsets, mask=mask, other=float("-inf"))
        x = (x / temp).to(tl.float32)

        prob = tl.exp(x - global_max) / sum_exp
        tl.store(output_row + offsets, prob, mask=mask)


@triton.jit
def _multi_pass_temperature_softmax_inplace_kernel(
    logits_ptr,
    temperatures_ptr,
    vocab_size,
    stride,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    temp = tl.load(temperatures_ptr + row_idx)

    row_start = logits_ptr + row_idx * stride

    # Pass 1: find global max (matches PyTorch's first reduction pass)
    global_max = tl.full([], value=float("-inf"), dtype=tl.float32)
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
        x = (x / temp).to(tl.float32)
        global_max = tl.maximum(global_max, tl.max(x, axis=0))

    # Pass 2: compute sum of exp(x - max) (matches PyTorch's second pass)
    sum_exp = tl.full([], value=0.0, dtype=tl.float32)
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
        x = (x / temp).to(tl.float32)
        sum_exp += tl.sum(tl.exp(x - global_max), axis=0)

    # Pass 3: normalize (matches PyTorch's exp(x-max)/sum)
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        x = tl.load(row_start + offsets, mask=mask, other=float("-inf"))
        x = (x / temp).to(tl.float32)

        prob = tl.exp(x - global_max) / sum_exp
        tl.store(row_start + offsets, prob, mask=mask)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

_DEFAULT_MULTI_PASS_CONFIG = {"BLOCK_SIZE": 4096, "num_warps": 16}

# Populated by warmup from the out-of-place kernel's autotune result.
_multi_pass_inplace_config: dict | None = None


def _single_pass_num_warps(block_size: int) -> int:
    """Heuristic warp count for the single-pass kernel, clamped to [4, 32]."""
    return max(4, min(32, block_size // 256))


def _get_multi_pass_inplace_config() -> dict:
    """Return the launch config for the multi-pass in-place kernel."""
    if _multi_pass_inplace_config is not None:
        return _multi_pass_inplace_config
    return _DEFAULT_MULTI_PASS_CONFIG


def _dispatch_kernel(
    logits: torch.Tensor,
    temperatures_flat: torch.Tensor,
    vocab_size: int,
    batch_size: int,
    output: torch.Tensor = None,
) -> None:
    """Dispatch to single-pass or multi-pass kernel. output=None means in-place."""
    grid = (batch_size,)
    block_size = triton.next_power_of_2(vocab_size)
    inplace = output is None

    if block_size <= _MAX_SINGLE_PASS_BLOCK:
        if inplace:
            _single_pass_temperature_softmax_inplace_kernel[grid](
                logits,
                temperatures_flat,
                vocab_size,
                logits.stride(0),
                BLOCK_SIZE=block_size,
                num_warps=_single_pass_num_warps(block_size),
            )
        else:
            _single_pass_temperature_softmax_kernel[grid](
                logits,
                temperatures_flat,
                output,
                vocab_size,
                logits.stride(0),
                output.stride(0),
                BLOCK_SIZE=block_size,
                num_warps=_single_pass_num_warps(block_size),
            )
    else:
        if inplace:
            # In-place kernel is never autotuned (autotune would corrupt the
            # live buffer); it reuses the config found by the out-of-place run.
            cfg = _get_multi_pass_inplace_config()
            _multi_pass_temperature_softmax_inplace_kernel[grid](
                logits,
                temperatures_flat,
                vocab_size,
                logits.stride(0),
                **cfg,
            )
        else:
            _multi_pass_temperature_softmax_kernel[grid](
                logits,
                temperatures_flat,
                output,
                vocab_size,
                logits.stride(0),
                output.stride(0),
            )


def fused_temperature_softmax(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
) -> torch.Tensor:
    """Fused temperature scaling + softmax. Returns float32 probabilities."""
    batch_size, vocab_size = logits.shape
    if batch_size == 0:
        return torch.empty(0, vocab_size, dtype=torch.float32, device=logits.device)

    if not logits.is_contiguous():
        logits = logits.contiguous()

    output = torch.empty(
        batch_size, vocab_size, dtype=torch.float32, device=logits.device
    )
    temperatures_flat = temperatures.contiguous().view(-1)
    _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size, output)
    return output


def fused_temperature_softmax_inplace(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
) -> None:
    """In-place fused temperature scaling + softmax. Overwrites logits with probabilities."""
    batch_size, vocab_size = logits.shape
    if batch_size == 0:
        return

    if not logits.is_contiguous():
        # Work on a contiguous copy, then write the result back.
        work = logits.contiguous()
        fused_temperature_softmax_inplace(work, temperatures)
        logits.copy_(work)
        return

    temperatures_flat = temperatures.contiguous().view(-1)
    _dispatch_kernel(logits, temperatures_flat, vocab_size, batch_size)


def warmup_fused_temperature_softmax(
    vocab_size: int,
    device: torch.device | int | None = None,
    logits_dtype: torch.dtype = torch.float32,
) -> None:
    """Pre-compile and autotune kernels at startup so first request has no latency spike.

    For multi-pass kernels the out-of-place variant is autotuned (safe — separate
    input/output buffers), and its winning config is reused for the in-place
    variant so that no autotune ever runs on a live logits buffer.

    ``logits_dtype`` should match ``next_token_logits`` at inference (usually
    ``model_config.dtype``) so Triton specializes the same way as in production.
    """
    global _multi_pass_inplace_config

    if device is None:
        device = torch.cuda.current_device()

    block_size = triton.next_power_of_2(vocab_size)
    is_multi_pass = block_size > _MAX_SINGLE_PASS_BLOCK
    label = "multi-pass autotune" if is_multi_pass else "single-pass JIT"
    logger.info(
        "Warming up fused_temperature_softmax (%s, vocab_size=%d, logits_dtype=%s) ...",
        label,
        vocab_size,
        logits_dtype,
    )

    dummy_logits = torch.randn(1, vocab_size, dtype=logits_dtype, device=device)
    dummy_temps = torch.ones(1, 1, dtype=torch.float32, device=device)

    # 1. Out-of-place kernel: autotune runs here (safe, separate buffers).
    fused_temperature_softmax(dummy_logits, dummy_temps)

    # 2. Propagate best config to the in-place kernel (no autotune needed).
    if is_multi_pass:
        best = getattr(_multi_pass_temperature_softmax_kernel, "best_config", None)
        if best is not None:
            _multi_pass_inplace_config = {
                "BLOCK_SIZE": best.kwargs["BLOCK_SIZE"],
                "num_warps": best.num_warps,
            }
            if best.num_stages is not None:
                _multi_pass_inplace_config["num_stages"] = best.num_stages
            ns = _multi_pass_inplace_config.get("num_stages", "default")
            logger.info(
                "Multi-pass autotune result: BLOCK_SIZE=%d, num_warps=%d, num_stages=%s",
                _multi_pass_inplace_config["BLOCK_SIZE"],
                _multi_pass_inplace_config["num_warps"],
                ns,
            )
        else:
            _multi_pass_inplace_config = None
            logger.warning(
                "Multi-pass fused softmax: autotune did not set best_config; "
                "using default launch config for in-place kernel."
            )

    # 3. In-place kernel: JIT compile only (uses the config from step 2).
    fused_temperature_softmax_inplace(dummy_logits.clone(), dummy_temps)
    torch.cuda.synchronize(device)

    logger.info("fused_temperature_softmax warmup done (vocab_size=%d).", vocab_size)
"""Correctness tests for fused_temperature_softmax Triton kernel."""

import unittest

import torch
from flashinfer.sampling import softmax as flashinfer_softmax

from sglang.srt.layers.fused_sampling import (
    fused_temperature_softmax,
    fused_temperature_softmax_inplace,
)
from sglang.srt.utils import get_device
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import CustomTestCase

register_cuda_ci(est_time=15, suite="stage-b-test-small-1-gpu")


def reference_temperature_softmax(logits, temperatures):
    """Reference implementation: div + softmax (separate kernels)."""
    logits = logits.clone()
    logits.div_(temperatures)
    return torch.softmax(logits, dim=-1).float()


class TestFusedTemperatureSoftmax(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        torch.set_default_device(get_device())
        torch.manual_seed(42)

    def _check_close(self, fused, ref, atol=1e-5, rtol=1e-5):
        """Assert outputs are close and both are valid probability distributions."""
        self.assertEqual(fused.shape, ref.shape)
        # Valid probabilities: non-negative, sum to ~1
        self.assertTrue((fused >= 0).all(), "Negative probabilities in fused output")
        row_sums = fused.sum(dim=-1)
        torch.testing.assert_close(
            row_sums,
            torch.ones_like(row_sums),
            atol=1e-4,
            rtol=1e-4,
        )
        torch.testing.assert_close(fused, ref, atol=atol, rtol=rtol)

    # --- out-of-place kernel ---

    def test_basic(self):
        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
        temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1)
        ref = reference_temperature_softmax(logits, temps)
        fused = fused_temperature_softmax(logits, temps)
        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)

    def test_large_vocab(self):
        logits = torch.randn(8, 128256, dtype=torch.bfloat16)
        temps = torch.full((8, 1), 0.6, dtype=torch.float32)
        ref = reference_temperature_softmax(logits, temps)
        fused = fused_temperature_softmax(logits, temps)
        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)

    def test_batch_sizes(self):
        for bs in [1, 2, 16, 64, 128, 512]:
            logits = torch.randn(bs, 32000, dtype=torch.bfloat16)
            temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1
            ref = reference_temperature_softmax(logits, temps)
            fused = fused_temperature_softmax(logits, temps)
            self._check_close(fused, ref, atol=1e-4, rtol=1e-3)

    def test_temperature_one(self):
        """Temperature=1.0 should be equivalent to plain softmax."""
        logits = torch.randn(16, 32000, dtype=torch.bfloat16)
        temps = torch.ones(16, 1, dtype=torch.float32)
        ref = torch.softmax(logits.float(), dim=-1)
        fused = fused_temperature_softmax(logits, temps)
        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)

    def test_very_low_temperature(self):
        """Very low temperature should produce near-one-hot distribution."""
        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
        temps = torch.full((4, 1), 0.01, dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        # Max probability should be very close to 1.0
        max_probs = fused.max(dim=-1).values
        self.assertTrue((max_probs > 0.99).all())

    def test_very_high_temperature(self):
        """Very high temperature should produce near-uniform distribution."""
        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
        temps = torch.full((4, 1), 100.0, dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        uniform = 1.0 / 1024
        self.assertTrue(
            (fused - uniform).abs().max() < 0.01,
            "High temperature should produce near-uniform distribution",
        )

    def test_fp16_input(self):
        logits = torch.randn(8, 32000, dtype=torch.float16)
        temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1
        ref = reference_temperature_softmax(logits, temps)
        fused = fused_temperature_softmax(logits, temps)
        self._check_close(fused, ref, atol=1e-3, rtol=1e-2)

    def test_fp32_input(self):
        logits = torch.randn(8, 32000, dtype=torch.float32)
        temps = torch.rand(8, 1, dtype=torch.float32) + 0.5
        ref = reference_temperature_softmax(logits, temps)
        fused = fused_temperature_softmax(logits, temps)
        self._check_close(fused, ref, atol=1e-5, rtol=1e-5)

    def test_mixed_temperatures(self):
        """Each row has a different temperature."""
        logits = torch.randn(8, 32000, dtype=torch.bfloat16)
        temps = torch.tensor(
            [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32
        ).view(-1, 1)
        ref = reference_temperature_softmax(logits, temps)
        fused = fused_temperature_softmax(logits, temps)
        self._check_close(fused, ref, atol=1e-4, rtol=1e-3)

    def test_empty_batch(self):
        logits = torch.randn(0, 32000, dtype=torch.bfloat16)
        temps = torch.ones(0, 1, dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        self.assertEqual(fused.shape, (0, 32000))

    # --- in-place kernel ---

    def test_inplace_basic(self):
        logits = torch.randn(8, 32000, dtype=torch.float32)
        temps = torch.rand(8, 1, dtype=torch.float32) * 1.5 + 0.1
        ref = reference_temperature_softmax(logits, temps)
        fused_temperature_softmax_inplace(logits, temps)
        # In-place writes back to logits in the original dtype
        self._check_close(logits.float(), ref, atol=1e-5, rtol=1e-5)

    def test_inplace_bf16(self):
        logits = torch.randn(8, 32000, dtype=torch.bfloat16)
        temps = torch.rand(8, 1, dtype=torch.float32) + 0.5
        ref = reference_temperature_softmax(logits, temps)
        fused_temperature_softmax_inplace(logits, temps)
        self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3)

    def test_inplace_large_vocab(self):
        logits = torch.randn(4, 128256, dtype=torch.bfloat16)
        temps = torch.full((4, 1), 0.8, dtype=torch.float32)
        ref = reference_temperature_softmax(logits, temps)
        fused_temperature_softmax_inplace(logits, temps)
        self._check_close(logits.float(), ref, atol=2e-3, rtol=2e-3)

    # --- exact known-value correctness ---

    def test_known_uniform_logits(self):
        """Identical logits must produce uniform distribution regardless of temperature."""
        logits = torch.zeros(2, 5, dtype=torch.float32)
        temps = torch.tensor([0.5, 2.0], dtype=torch.float32).view(-1, 1)
        fused = fused_temperature_softmax(logits, temps)
        expected = torch.full((2, 5), 0.2, dtype=torch.float32, device=fused.device)
        torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6)

    def test_known_softmax_values(self):
        """Verify against hand-computed softmax(logits / T)."""
        logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32)
        temps = torch.tensor([[1.0]], dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        # softmax([1,2,3]) = exp([1,2,3]) / sum(exp([1,2,3]))
        e = torch.exp(logits)
        expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device)
        torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6)

    def test_known_softmax_with_temperature(self):
        """Verify softmax([1,2,3] / 0.5) against hand computation."""
        logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32)
        temps = torch.tensor([[0.5]], dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        scaled = logits / 0.5
        e = torch.exp(scaled)
        expected = (e / e.sum(dim=-1, keepdim=True)).to(fused.device)
        torch.testing.assert_close(fused, expected, atol=1e-6, rtol=1e-6)

    # --- argmax preservation ---

    def test_argmax_preserved(self):
        """argmax must be invariant to temperature for finite T > 0."""
        logits = torch.randn(64, 32000, dtype=torch.bfloat16)
        original_argmax = logits.float().argmax(dim=-1)
        for t_val in [0.1, 0.5, 1.0, 2.0, 10.0]:
            temps = torch.full((64, 1), t_val, dtype=torch.float32)
            fused = fused_temperature_softmax(logits, temps)
            fused_argmax = fused.argmax(dim=-1)
            self.assertTrue(
                (original_argmax == fused_argmax).all(),
                f"argmax changed at temperature={t_val}",
            )

    # --- numerical stability ---

    def test_large_logits_no_nan(self):
        """Extreme logit magnitudes must not produce NaN or Inf."""
        logits = torch.tensor(
            [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32
        )
        temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        self.assertFalse(torch.isnan(fused).any(), "NaN in output")
        self.assertFalse(torch.isinf(fused).any(), "Inf in output")
        row_sums = fused.sum(dim=-1)
        torch.testing.assert_close(
            row_sums,
            torch.ones_like(row_sums),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_large_logits_inplace_no_nan(self):
        """In-place variant: extreme logits must not produce NaN or Inf."""
        logits = torch.tensor(
            [[1e6, -1e6, 0.0], [1e4, 1e4 + 1, 1e4 - 1]], dtype=torch.float32
        )
        temps = torch.tensor([[1.0], [0.01]], dtype=torch.float32)
        fused_temperature_softmax_inplace(logits, temps)
        self.assertFalse(torch.isnan(logits).any(), "NaN in output")
        self.assertFalse(torch.isinf(logits).any(), "Inf in output")

    # --- comparison with flashinfer.sampling.softmax ---

    def test_vs_flashinfer_basic(self):
        logits = torch.randn(4, 1024, dtype=torch.bfloat16)
        temps = torch.tensor([0.7, 1.0, 1.5, 2.0], dtype=torch.float32).view(-1, 1)
        fused = fused_temperature_softmax(logits, temps)
        fi = flashinfer_softmax(logits, temperature=temps.view(-1))
        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)

    def test_vs_flashinfer_large_vocab(self):
        logits = torch.randn(8, 128256, dtype=torch.bfloat16)
        temps = torch.full((8, 1), 0.6, dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps)
        fi = flashinfer_softmax(logits, temperature=temps.view(-1))
        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)

    def test_vs_flashinfer_batch_sizes(self):
        for bs in [1, 16, 64, 128, 512]:
            logits = torch.randn(bs, 32000, dtype=torch.bfloat16)
            temps = torch.rand(bs, 1, dtype=torch.float32) * 1.5 + 0.1
            fused = fused_temperature_softmax(logits, temps)
            fi = flashinfer_softmax(logits, temperature=temps.view(-1))
            self._check_close(fused, fi, atol=1e-4, rtol=1e-3)

    def test_vs_flashinfer_scalar_temperature(self):
        logits = torch.randn(16, 32000, dtype=torch.bfloat16)
        temps_2d = torch.full((16, 1), 0.8, dtype=torch.float32)
        fused = fused_temperature_softmax(logits, temps_2d)
        fi = flashinfer_softmax(logits, temperature=0.8)
        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)

    def test_vs_flashinfer_mixed_temperatures(self):
        logits = torch.randn(8, 32000, dtype=torch.bfloat16)
        temps = torch.tensor(
            [0.1, 0.5, 0.7, 1.0, 1.2, 1.5, 2.0, 5.0], dtype=torch.float32
        ).view(-1, 1)
        fused = fused_temperature_softmax(logits, temps)
        fi = flashinfer_softmax(logits, temperature=temps.view(-1))
        self._check_close(fused, fi, atol=1e-4, rtol=1e-3)


if __name__ == "__main__":
    unittest.main()