From fddeca5a557905b3f8d659eee3f0b598f6f38af4 Mon Sep 17 00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Mon, 19 Jan 2026 14:54:59 +0000 Subject: [PATCH 1/2] more more Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> style check Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> minor style change Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> more Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- benchmarks/README.md | 30 +- benchmarks/flashinfer_benchmark.py | 11 +- .../routines/flashinfer_benchmark_utils.py | 87 ++ benchmarks/routines/sampling.py | 997 ++++++++++++++++++ benchmarks/samples/sample_testlist.txt | 31 + 5 files changed, 1153 insertions(+), 3 deletions(-) create mode 100644 benchmarks/routines/sampling.py diff --git a/benchmarks/README.md b/benchmarks/README.md index b66882d38b..f37ca3b1ed 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -5,11 +5,11 @@ The aim of `flashinfer_benchmark.py` is to provide a single framework for benchm ## Overview This framework provides tools to: -- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, and Quantization API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM +- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, Quantization, and Sampling API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM - Compare performance across different configurations - Batch performance test multiple test cases -Currently supports testing attention, gemm, fused MOE, normalization, and quantization APIs: +Currently supports testing attention, gemm, fused MOE, normalization, quantization, and sampling APIs: - Attention: - `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache. 
- Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`. @@ -42,6 +42,14 @@ Currently supports testing attention, gemm, fused MOE, normalization, and quanti - `mxfp4_quantize` - Quantize tensor to MxFP4 format (Blackwell SM10.0+). - `nvfp4_quantize` - Quantize tensor to NVFP4 format with configurable scale factor layout (Blackwell SM10.0+). - `nvfp4_batched_quantize` - Batched NVFP4 quantization (Blackwell SM10.0+). +- Sampling: + - `sampling_from_probs` - Basic categorical sampling from probability distributions. + - `top_p_sampling_from_probs` - Top-p (nucleus) sampling from probabilities. + - `top_k_sampling_from_probs` - Top-k sampling from probabilities. + - `top_k_top_p_sampling_from_probs` - Combined top-k and top-p sampling from probabilities. + - `top_k_renorm_probs` - Renormalize probabilities by top-k thresholding. + - `top_p_renorm_probs` - Renormalize probabilities by top-p thresholding. + - `top_k_mask_logits` - Mask logits by top-k thresholding. ## Quick Start ### Single Test Run @@ -316,6 +324,17 @@ mpirun -np 8 python benchmarks/flashinfer_benchmark.py \ | `--sf_vec_size` | Scale factor vector size for NVFP4 quantization. Default: 16 | | `--backends` | Backend to test. Default: `cuda` | +### Sampling Flags +| Flag | Description | +|--------------------------|-------------------------------------------------------------------------------------------------------------| +| `--batch_size` | Batch size (number of sequences to sample from) | +| `--vocab_size` | Vocabulary size. Default: 128256 (Llama 3 vocab size) | +| `--input_dtype` | Input data type: `float32` (default), `float16`, or `bfloat16` | +| `--top_p` | Top-p threshold for nucleus sampling. Default: 0.9 | +| `--top_k` | Top-k threshold for top-k sampling. Default: 50 | +| `--no_deterministic` | Disable deterministic sampling. Default: deterministic is enabled | +| `--backends` | Backend to test. 
Default: `cuda` | + ## `flashinfer_benchmark.py` Routine & Backend Support Matrix The following table summarizes the support surface of each routine & backend's on various [CUDA Compute Capabilities](https://developer.nvidia.com/cuda-gpus). @@ -357,6 +376,13 @@ Legend: | **mxfp4_quantize** | | | | | | cuda | cuda | | | **nvfp4_quantize** | | | | | | cuda | cuda | | | **nvfp4_batched_quantize** | | | | | | cuda | cuda | | +| **sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_p_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_top_p_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_renorm_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_p_renorm_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_mask_logits** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | Backend Legend: - fa2: FlashAttention2 diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index fdbc54098c..79430b5bf9 100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -44,6 +44,10 @@ def run_test(args): from routines.quantization import run_quantization_test res = run_quantization_test(args) + elif args.routine in benchmark_apis["sampling"]: + from routines.sampling import run_sampling_test + + res = run_sampling_test(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -89,7 +93,8 @@ def parse_args(line=sys.argv[1:]): + list(benchmark_apis["moe"]) + list(benchmark_apis["moe_comm"]) + list(benchmark_apis["norm"]) - + list(benchmark_apis["quantization"]), + + list(benchmark_apis["quantization"]) + + list(benchmark_apis["sampling"]), ) args, _ = parser.parse_known_args(line[:]) @@ -199,6 +204,10 @@ def parse_args(line=sys.argv[1:]): from 
routines.quantization import parse_quantization_args args = parse_quantization_args(line, parser) + elif args.routine in benchmark_apis["sampling"]: + from routines.sampling import parse_sampling_args + + args = parse_sampling_args(line, parser) else: raise ValueError(f"Unsupported routine: {args.routine}") diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index b207f5cb43..74f8ef77ba 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -94,6 +94,12 @@ "do_shuffle", "sf_vec_size", ], + "sampling": [ + "vocab_size", + "top_p", + "top_k", + "deterministic", + ], "general": [ "batch_size", "hidden_size", @@ -118,6 +124,7 @@ + output_column_dict["moe_comm"] + output_column_dict["norm"] + output_column_dict["quantization"] + + output_column_dict["sampling"] + output_column_dict["general"] ) @@ -157,6 +164,15 @@ "nvfp4_quantize", "nvfp4_batched_quantize", ], + "sampling": [ + "sampling_from_probs", + "top_p_sampling_from_probs", + "top_k_sampling_from_probs", + "top_k_top_p_sampling_from_probs", + "top_k_renorm_probs", + "top_p_renorm_probs", + "top_k_mask_logits", + ], } @@ -431,6 +447,77 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cuda"], "12.0": ["cuda"], }, + # SAMPLING - supported on all architectures + "sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_p_sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_top_p_sampling_from_probs": { + "7.5": ["cuda"], + 
"8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_renorm_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_p_renorm_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_mask_logits": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, } diff --git a/benchmarks/routines/sampling.py b/benchmarks/routines/sampling.py new file mode 100644 index 0000000000..071d3884d9 --- /dev/null +++ b/benchmarks/routines/sampling.py @@ -0,0 +1,997 @@ +"""Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict + +import numpy as np +import torch + +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + +from .flashinfer_benchmark_utils import ( + dtype_str_to_torch_dtype, + filter_backends_by_compute_capability, + get_device, + is_close_stats, + print_perf_metrics, +) + + +def run_sampling_test(args): + """Run a sampling test. We expose all sampling API in this benchmark. + TopK is under sampling_topk.py- please see it for topk benchmark. 
+ + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.routine == "sampling_from_probs": + return testSamplingFromProbs(args) + if args.routine == "top_p_sampling_from_probs": + return testTopPSamplingFromProbs(args) + if args.routine == "top_k_sampling_from_probs": + return testTopKSamplingFromProbs(args) + if args.routine == "top_k_top_p_sampling_from_probs": + return testTopKTopPSamplingFromProbs(args) + if args.routine == "top_k_renorm_probs": + return testTopKRenormProbs(args) + if args.routine == "top_p_renorm_probs": + return testTopPRenormProbs(args) + if args.routine == "top_k_mask_logits": + return testTopKMaskLogits(args) + raise ValueError(f"Unsupported routine: {args.routine}") + + +def parse_sampling_args(line, parser): + """Parse command line arguments for sampling test configuration. + + Args: + line: Command line arguments + parser: ArgumentParser object already populated with shared arguments + + Returns: + Parsed argument namespace + + """ + parser.add_argument( + "--batch_size", + type=int, + required=True, + help="Batch size (number of sequences to sample from).", + ) + parser.add_argument( + "--vocab_size", + type=int, + required=False, + default=128256, + help="Vocabulary size. Default: 128256 (Llama 3 vocab size).", + ) + parser.add_argument( + "--input_dtype", + type=str, + required=False, + default="float32", + help="Data type of the input tensor. Default: float32.", + ) + parser.add_argument( + "--top_p", + type=float, + required=False, + default=0.9, + help="Top-p threshold for nucleus sampling. Default: 0.9", + ) + parser.add_argument( + "--top_k", + type=int, + required=False, + default=50, + help="Top-k threshold for top-k sampling. Default: 50", + ) + parser.add_argument( + "--no_deterministic", + action="store_true", + default=False, + help="Disable deterministic sampling. 
Default: deterministic is enabled.", + ) + parser.add_argument( + "--backends", + type=str, + required=False, + nargs="+", + default=["cuda"], + choices=["cuda"], + help="Backend to test. Default: cuda.", + ) + + args = parser.parse_args(line) + if args.verbose >= 1: + print(f"[INFO] {args = }") + return args + + +def testSamplingFromProbs(args): + """Test sampling_from_probs API. + + This test: + Sampling rng is not compatible with CUDA Graphs + in the specific implementation in flashinfer, + so we disable them for test. + 1. Generates random probability distributions and normalize them in FP32 + 2. Runs sampling_from_probs + 3. Fetch performance numbers. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + deterministic = not args.no_deterministic + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + # Sampling uses RNG which is incompatible with CUDA graph capture + is_cuda_graph_compatible = False + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + # Generate random probabilities and normalize in float32 for numerical stability + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + probs = probs.to(input_dtype) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.sampling_from_probs( + probs, + deterministic=deterministic, + ) + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for sampling + # Read: probs tensor + # Write: samples tensor (int32) + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # probs read + + batch_size * 4 # samples write (int32) + ) + # Sampling is memory-bound + problem_flops = num_elements # rough estimate + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + 
cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["deterministic"] = str(deterministic) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopPSamplingFromProbs(args): + """Test top_p_sampling_from_probs API. + + This test: + 1. Generates random probability distributions and normalize them in FP32 + 2. Runs top_p_sampling_from_probs (nucleus sampling) + 3. Fetch performance numbers. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testTopPSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_p = args.top_p + deterministic = not args.no_deterministic + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + # Sampling uses RNG which is incompatible with CUDA graph capture + is_cuda_graph_compatible = False + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + # Generate random probabilities and normalize in float32 for numerical stability + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + probs = probs.to(input_dtype) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_p = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_p_sampling_from_probs( + probs, + top_p=top_p, + deterministic=deterministic, + ) + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # probs read + + batch_size * 4 # samples write (int32) + ) + problem_flops = num_elements * 2 # sorting/filtering ops + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + 
cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["top_p"] = top_p + cur_res["deterministic"] = str(deterministic) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKSamplingFromProbs(args): + """Test top_k_sampling_from_probs API. + + This test: + 1. Generates random probability distributions + 2. Runs top_k_sampling_from_probs + 3. Measures performance metrics + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testTopKSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + deterministic = not args.no_deterministic + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + # Sampling uses RNG which is incompatible with CUDA graph capture + is_cuda_graph_compatible = False + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + # Generate random probabilities and normalize in float32 for numerical stability + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + probs = probs.to(input_dtype) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_sampling_from_probs( + probs, + top_k=top_k, + deterministic=deterministic, + ) + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # probs read + + batch_size * 4 # samples write (int32) + ) + problem_flops = num_elements * 2 # sorting/filtering ops + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + 
cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["deterministic"] = str(deterministic) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKTopPSamplingFromProbs(args): + """Test top_k_top_p_sampling_from_probs API. + + This test: + 1. Generates random probability distributions and normalize them in FP32 + 2. Runs top_k_top_p_sampling_from_probs (combined top-k and top-p) + 3. Fetch performance numbers. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testTopKTopPSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + top_p = args.top_p + deterministic = not args.no_deterministic + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + # Sampling uses RNG which is incompatible with CUDA graph capture + is_cuda_graph_compatible = False + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + # Generate random probabilities and normalize in float32 for numerical stability + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + probs = probs.to(input_dtype) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") + print(f"[VVERBOSE] {top_p = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, + top_k=top_k, + top_p=top_p, + deterministic=deterministic, + ) + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # probs read + + batch_size * 4 # samples write (int32) + ) + problem_flops = num_elements * 3 # more ops for combined filtering + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + 
cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["top_p"] = top_p + cur_res["deterministic"] = str(deterministic) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKRenormProbs(args): + """Test top_k_renorm_probs API. + + This test: + 1. Generates random probability distributions and normalize them in FP32 + 2. Runs top_k_renorm_probs (renormalize by top-k thresholding) + 3. Fetch performance numbers. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testTopKRenormProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + # Generate random probabilities and normalize in float32 for numerical stability + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + probs = probs.to(input_dtype) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_renorm_probs(probs, top_k=top_k) + raise ValueError(f"Unsupported backend: {backend}") + + # Reference implementation for refcheck + has_reference_output = False + if run_refcheck: + # PyTorch reference: keep top-k, set rest to 0, renormalize + topk_vals, topk_indices = torch.topk(probs.float(), k=top_k, dim=-1) + reference_output = torch.zeros_like(probs) + # NOTE: dont explicitly specify dtype here + # keep it the same as input. 
+ reference_output.scatter_(-1, topk_indices, topk_vals) + reference_output = reference_output / reference_output.sum(dim=-1, keepdim=True) + reference_output = reference_output.to(input_dtype) + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, probs).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats( + reference_output.float(), + tested_outputs[i].float(), + rtol=1e-2, + atol=1e-2, + ) + if num_different_elements > 0: + print( + f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}: " + f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ", + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} output mismatch with {num_different_elements} elements", + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # probs read + + num_elements * input_dtype.itemsize # renorm_probs write + ) + problem_flops = num_elements * 2 + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, 
std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopPRenormProbs(args): + """Test top_p_renorm_probs API. + + This test: + 1. Generates random probability distributions + 2. Runs top_p_renorm_probs (renormalize by top-p thresholding) + 3. Measures performance metrics + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testTopPRenormProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_p = args.top_p + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + # Generate random probabilities and normalize in float32 for numerical stability + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + probs = probs.to(input_dtype) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_p = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_p_renorm_probs(probs, top_p=top_p) + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # probs read + + num_elements + * input_dtype.itemsize # renorm_probs write (same dtype as input) + ) + problem_flops = num_elements * 2 + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["top_p"] = 
top_p + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKMaskLogits(args): + """Test top_k_mask_logits API. + + This test: + 1. Generates random logits + 2. Runs top_k_mask_logits (mask logits by top-k thresholding) + 3. Measures performance metrics + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + + """ + if args.verbose >= 1: + print("[INFO] Running testTopKMaskLogits") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", + ) + + ## Prepare input tensors + input_shape = (batch_size, vocab_size) + logits = torch.randn(input_shape, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {logits.shape = }") + print(f"[VVERBOSE] {logits.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, logits): + if backend == "cuda": + return flashinfer.sampling.top_k_mask_logits(logits, top_k=top_k) + raise ValueError(f"Unsupported backend: {backend}") + + # Reference implementation for refcheck + has_reference_output = False + if run_refcheck: + # PyTorch reference: keep top-k logits, set rest to -inf + topk_vals, topk_indices = torch.topk(logits.float(), k=top_k, dim=-1) + reference_output = torch.full_like(logits, float("-inf")) + # NOTE: dont explicitly specify dtype here + # keep it the same as input. + reference_output.scatter_(-1, topk_indices, topk_vals) + reference_output = reference_output.to(input_dtype) + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, logits).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, logits), + ) + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + # For masked logits, check: + # 1. Same positions are masked (-inf) + # 2. 
Unmasked values match the original logits + out = tested_outputs[i].float() + ref = reference_output.float() + + # Check that the same positions are masked + out_masked = torch.isinf(out) & (out < 0) + ref_masked = torch.isinf(ref) & (ref < 0) + mask_match = (out_masked == ref_masked).all() + if not mask_match: + print(f"[ERROR] Mask mismatch from backend {tested_backends[i]}") + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} mask mismatch", + ) + + # Check that unmasked values match the reference (original top-k logits) + unmasked_positions = ~out_masked + if unmasked_positions.any(): + out_unmasked = out[unmasked_positions] + ref_unmasked = ref[unmasked_positions] + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats(ref_unmasked, out_unmasked, rtol=1e-3, atol=1e-3) + if num_different_elements > 0: + print( + f"[ERROR] Unmasked values mismatch from backend {tested_backends[i]}: " + f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ", + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} unmasked values mismatch", + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # logits read + + num_elements * input_dtype.itemsize # masked_logits write + ) + problem_flops = num_elements * 2 + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops 
+ cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 03a33f33ec..c9bd7b7c76 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -175,3 +175,34 @@ --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_basic" --routine nvfp4_batched_quantize --batch_size 8 --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_large" --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype float16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_fp16" + +## Sampling +# Basic sampling from probabilities +--routine sampling_from_probs --batch_size 32 --vocab_size 128256 --input_dtype float32 -vv --generate_repro_command --case_tag "sampling_basic" +--routine sampling_from_probs --batch_size 128 --vocab_size 128256 --input_dtype float32 -vv --generate_repro_command --case_tag "sampling_large_batch" + +# Top-p (nucleus) sampling +--routine top_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_p 0.9 --input_dtype float32 -vv --generate_repro_command --case_tag "top_p_sampling_p09" +--routine top_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_p 0.95 --input_dtype float32 -vv --generate_repro_command --case_tag "top_p_sampling_p095" +--routine top_p_sampling_from_probs --batch_size 128 --vocab_size 128256 --top_p 0.9 --input_dtype float32 -vv --generate_repro_command --case_tag "top_p_sampling_large_batch" + +# Top-k sampling +--routine top_k_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 50 
--input_dtype float32 -vv --generate_repro_command --case_tag "top_k_sampling_k50" +--routine top_k_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 10 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_sampling_k10" +--routine top_k_sampling_from_probs --batch_size 128 --vocab_size 128256 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_sampling_large_batch" + +# Combined top-k and top-p sampling +--routine top_k_top_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 50 --top_p 0.9 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_top_p_sampling" +--routine top_k_top_p_sampling_from_probs --batch_size 128 --vocab_size 128256 --top_k 50 --top_p 0.95 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_top_p_sampling_large" + +# Top-k renormalize probabilities +--routine top_k_renorm_probs --batch_size 32 --vocab_size 128256 --top_k 50 --input_dtype float32 --refcheck -vv --generate_repro_command --case_tag "top_k_renorm_probs" +--routine top_k_renorm_probs --batch_size 32 --vocab_size 128256 --top_k 10 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "top_k_renorm_probs_bf16" + +# Top-p renormalize probabilities +--routine top_p_renorm_probs --batch_size 32 --vocab_size 128256 --top_p 0.9 --input_dtype float32 -vv --generate_repro_command --case_tag "top_p_renorm_probs" +--routine top_p_renorm_probs --batch_size 32 --vocab_size 128256 --top_p 0.5 --input_dtype float32 -vv --generate_repro_command --case_tag "top_p_renorm_probs_p05" + +# Top-k mask logits +--routine top_k_mask_logits --batch_size 32 --vocab_size 128256 --top_k 50 --input_dtype float32 --refcheck -vv --generate_repro_command --case_tag "top_k_mask_logits" +--routine top_k_mask_logits --batch_size 32 --vocab_size 128256 --top_k 10 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "top_k_mask_logits_bf16" From 5e5e8115da851b1bf3b9b86592b39430c69baaa2 Mon Sep 17 
00:00:00 2001 From: vincentzed <207368749+vincentzed@users.noreply.github.com> Date: Fri, 30 Jan 2026 13:47:58 +0000 Subject: [PATCH 2/2] WIP Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com> --- benchmarks/routines/sampling.py | 1175 ++++++++++++------------------- 1 file changed, 454 insertions(+), 721 deletions(-) diff --git a/benchmarks/routines/sampling.py b/benchmarks/routines/sampling.py index 071d3884d9..6c6069a0c0 100644 --- a/benchmarks/routines/sampling.py +++ b/benchmarks/routines/sampling.py @@ -30,6 +30,160 @@ ) +# ============================================================ +# Shared helpers to reduce boilerplate across sampling benchmarks +# ============================================================ + + +def _setup_sampling_benchmark(args, routine_name): + """Common setup: logging, device, backend filtering, dtype validation. + + Returns: + tuple: (device, backends, input_dtype). backends is empty list if none + are available for the current compute capability. + + """ + if args.verbose >= 1: + print(f"[INFO] Running {routine_name}") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", + ) + + backends = args.backends[:] + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return device, [], None + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. 
" + f"Supported dtypes are float32, float16, bfloat16.", + ) + + return device, backends, input_dtype + + +def _create_normalized_probs(batch_size, vocab_size, input_dtype, device): + """Generate random probability distributions, normalized in FP32 for stability. + + Returns: + tuple: (probs tensor, input_shape tuple) + + """ + input_shape = (batch_size, vocab_size) + pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) + probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) + return probs.to(input_dtype), input_shape + + +def _bench_sampling( + args, + backends, + run_backend, + input_tensor, + is_cuda_graph_compatible, + run_refcheck=False, +): + """Run timing across backends, optionally collecting outputs for refcheck. + + Returns: + tuple: (backend_times dict, outputs dict). outputs is empty if + run_refcheck is False. + + """ + backend_times = {} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, input_tensor).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor), + ) + return backend_times, outputs + + +def _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes, + problem_flops, + extra_result_fields, +): + """Calculate perf metrics and build result dicts. + + Returns: + list: Result dicts for CSV output. 
+ + """ + res = [] + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = str(std_time) + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + cur_res.update(extra_result_fields) + res.append(cur_res) + return res + + +def _check_is_close(outputs, reference_output, args, rtol=1e-2, atol=1e-2): + """Compare backend outputs to reference using is_close_stats.""" + for backend, output in outputs.items(): + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats( + reference_output.float(), + output.float(), + rtol=rtol, + atol=atol, + ) + if num_different_elements > 0: + print( + f"[ERROR] Output tensor mismatch from backend {backend}: " + f"{num_different_elements}/{num_elements} " + f"({num_different_elements_percentage:.2f}%) elements differ", + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} output mismatch " + f"with {num_different_elements} elements", + ) + + +# ============================================================ +# Public API +# ============================================================ + + def run_sampling_test(args): """Run a sampling test. We expose all sampling API in this benchmark. TopK is under sampling_topk.py- please see it for topk benchmark. @@ -125,61 +279,33 @@ def parse_sampling_args(line, parser): return args -def testSamplingFromProbs(args): - """Test sampling_from_probs API. 
+# ============================================================ +# Individual benchmark functions +# ============================================================ - This test: - Sampling rng is not compatible with CUDA Graphs - in the specific implementation in flashinfer, - so we disable them for test. - 1. Generates random probability distributions and normalize them in FP32 - 2. Runs sampling_from_probs - 3. Fetch performance numbers. - Args: - args: Parsed command line arguments containing test configuration +def testSamplingFromProbs(args): + """Test sampling_from_probs API. - Returns: - dict: List of dictionaries containing performance results + Sampling RNG is not compatible with CUDA Graphs in the specific + implementation in flashinfer, so CUDA graph capture is disabled. """ - if args.verbose >= 1: - print("[INFO] Running testSamplingFromProbs") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testSamplingFromProbs", + ) + if not backends: + return [] - ## Parse input arguments - backends = args.backends[:] batch_size = args.batch_size - vocab_size = args.vocab_size deterministic = not args.no_deterministic - res = [] - - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to test. Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", - ) - - # Sampling uses RNG which is incompatible with CUDA graph capture - is_cuda_graph_compatible = False - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) - # Generate random probabilities and normalize in float32 for numerical stability - pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) - probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) - probs = probs.to(input_dtype) + probs, input_shape = _create_normalized_probs( + batch_size, + args.vocab_size, + input_dtype, + device, + ) if args.verbose >= 2: print(f"[VVERBOSE] {probs.shape = }") @@ -193,107 +319,47 @@ def run_backend(backend, probs): ) raise ValueError(f"Unsupported backend: {backend}") - # Storage for timing results - backend_times = {backend: [] for backend in backends} - for cur_backend in backends: - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, probs), - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) - - # Memory bandwidth calculation for sampling - # Read: probs tensor - # Write: samples tensor (int32) - num_elements = np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # probs read - + batch_size * 4 # samples write (int32) - ) - # Sampling is memory-bound - problem_flops = num_elements # rough estimate - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = problem_bytes / (10**9 * median_time) - - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + backend_times, _ = _bench_sampling( + args, + backends, + run_backend, + probs, + is_cuda_graph_compatible=False, + ) - if args.output_path is not None: - cur_res = defaultdict(str) - 
cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["deterministic"] = str(deterministic) - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res + num_elements = np.prod(input_shape) + return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize + batch_size * 4, + problem_flops=num_elements, + extra_result_fields={ + "vocab_size": args.vocab_size, + "deterministic": str(deterministic), + }, + ) def testTopPSamplingFromProbs(args): - """Test top_p_sampling_from_probs API. - - This test: - 1. Generates random probability distributions and normalize them in FP32 - 2. Runs top_p_sampling_from_probs (nucleus sampling) - 3. Fetch performance numbers. - - Args: - args: Parsed command line arguments containing test configuration - - Returns: - dict: List of dictionaries containing performance results - - """ - if args.verbose >= 1: - print("[INFO] Running testTopPSamplingFromProbs") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) + """Test top_p_sampling_from_probs API (nucleus sampling).""" + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testTopPSamplingFromProbs", + ) + if not backends: + return [] - ## Parse input arguments - backends = args.backends[:] batch_size = args.batch_size - vocab_size = args.vocab_size top_p = args.top_p deterministic = not args.no_deterministic - res = [] - - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to test. 
Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are float32, float16, bfloat16.", - ) - - # Sampling uses RNG which is incompatible with CUDA graph capture - is_cuda_graph_compatible = False - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) - # Generate random probabilities and normalize in float32 for numerical stability - pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) - probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) - probs = probs.to(input_dtype) + probs, input_shape = _create_normalized_probs( + batch_size, + args.vocab_size, + input_dtype, + device, + ) if args.verbose >= 2: print(f"[VVERBOSE] {probs.shape = }") @@ -309,218 +375,106 @@ def run_backend(backend, probs): ) raise ValueError(f"Unsupported backend: {backend}") - # Storage for timing results - backend_times = {backend: [] for backend in backends} - for cur_backend in backends: - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, probs), - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) + backend_times, _ = _bench_sampling( + args, + backends, + run_backend, + probs, + is_cuda_graph_compatible=False, + ) - num_elements = np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # probs read - + batch_size * 4 # samples write (int32) - ) - problem_flops = num_elements * 2 # sorting/filtering ops - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = problem_bytes / (10**9 * median_time) + num_elements = np.prod(input_shape) 
+ return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize + batch_size * 4, + problem_flops=num_elements * 2, + extra_result_fields={ + "vocab_size": args.vocab_size, + "top_p": top_p, + "deterministic": str(deterministic), + }, + ) - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["top_p"] = top_p - cur_res["deterministic"] = str(deterministic) - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res +def testTopKSamplingFromProbs(args): + """Test top_k_sampling_from_probs API.""" + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testTopKSamplingFromProbs", + ) + if not backends: + return [] + batch_size = args.batch_size + top_k = args.top_k + deterministic = not args.no_deterministic + probs, input_shape = _create_normalized_probs( + batch_size, + args.vocab_size, + input_dtype, + device, + ) -def testTopKSamplingFromProbs(args): - """Test top_k_sampling_from_probs API. + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") - This test: - 1. Generates random probability distributions - 2. Runs top_k_sampling_from_probs - 3. 
Measures performance metrics + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_sampling_from_probs( + probs, + top_k=top_k, + deterministic=deterministic, + ) + raise ValueError(f"Unsupported backend: {backend}") - Args: - args: Parsed command line arguments containing test configuration + backend_times, _ = _bench_sampling( + args, + backends, + run_backend, + probs, + is_cuda_graph_compatible=False, + ) - Returns: - dict: List of dictionaries containing performance results + num_elements = np.prod(input_shape) + return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize + batch_size * 4, + problem_flops=num_elements * 2, + extra_result_fields={ + "vocab_size": args.vocab_size, + "top_k": top_k, + "deterministic": str(deterministic), + }, + ) - """ - if args.verbose >= 1: - print("[INFO] Running testTopKSamplingFromProbs") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) +def testTopKTopPSamplingFromProbs(args): + """Test top_k_top_p_sampling_from_probs API (combined top-k and top-p).""" + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testTopKTopPSamplingFromProbs", + ) + if not backends: + return [] - ## Parse input arguments - backends = args.backends[:] batch_size = args.batch_size - vocab_size = args.vocab_size - top_k = args.top_k - deterministic = not args.no_deterministic - res = [] - - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to test. 
Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are float32, float16, bfloat16.", - ) - - # Sampling uses RNG which is incompatible with CUDA graph capture - is_cuda_graph_compatible = False - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) - # Generate random probabilities and normalize in float32 for numerical stability - pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) - probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) - probs = probs.to(input_dtype) - - if args.verbose >= 2: - print(f"[VVERBOSE] {probs.shape = }") - print(f"[VVERBOSE] {probs.dtype = }") - print(f"[VVERBOSE] {top_k = }") - - def run_backend(backend, probs): - if backend == "cuda": - return flashinfer.sampling.top_k_sampling_from_probs( - probs, - top_k=top_k, - deterministic=deterministic, - ) - raise ValueError(f"Unsupported backend: {backend}") - - # Storage for timing results - backend_times = {backend: [] for backend in backends} - for cur_backend in backends: - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, probs), - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) - - num_elements = np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # probs read - + batch_size * 4 # samples write (int32) - ) - problem_flops = num_elements * 2 # sorting/filtering ops - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = problem_bytes / (10**9 * median_time) - - print_perf_metrics(backend, median_time, std_time, tflops, 
tb_per_sec) - - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["top_k"] = top_k - cur_res["deterministic"] = str(deterministic) - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res - - -def testTopKTopPSamplingFromProbs(args): - """Test top_k_top_p_sampling_from_probs API. - - This test: - 1. Generates random probability distributions and normalize them in FP32 - 2. Runs top_k_top_p_sampling_from_probs (combined top-k and top-p) - 3. Fetch performance numbers. - - Args: - args: Parsed command line arguments containing test configuration - - Returns: - dict: List of dictionaries containing performance results - - """ - if args.verbose >= 1: - print("[INFO] Running testTopKTopPSamplingFromProbs") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) - - ## Parse input arguments - backends = args.backends[:] - batch_size = args.batch_size - vocab_size = args.vocab_size top_k = args.top_k top_p = args.top_p deterministic = not args.no_deterministic - res = [] - - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to test. Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", - ) - - # Sampling uses RNG which is incompatible with CUDA graph capture - is_cuda_graph_compatible = False - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) - # Generate random probabilities and normalize in float32 for numerical stability - pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) - probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) - probs = probs.to(input_dtype) + probs, input_shape = _create_normalized_probs( + batch_size, + args.vocab_size, + input_dtype, + device, + ) if args.verbose >= 2: print(f"[VVERBOSE] {probs.shape = }") @@ -538,103 +492,49 @@ def run_backend(backend, probs): ) raise ValueError(f"Unsupported backend: {backend}") - # Storage for timing results - backend_times = {backend: [] for backend in backends} - for cur_backend in backends: - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, probs), - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) - - num_elements = np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # probs read - + batch_size * 4 # samples write (int32) - ) - problem_flops = num_elements * 3 # more ops for combined filtering - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = problem_bytes / (10**9 * median_time) - - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + backend_times, _ = _bench_sampling( + args, + backends, + run_backend, + probs, + is_cuda_graph_compatible=False, + ) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - 
cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["top_k"] = top_k - cur_res["top_p"] = top_p - cur_res["deterministic"] = str(deterministic) - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res + num_elements = np.prod(input_shape) + return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize + batch_size * 4, + problem_flops=num_elements * 3, + extra_result_fields={ + "vocab_size": args.vocab_size, + "top_k": top_k, + "top_p": top_p, + "deterministic": str(deterministic), + }, + ) def testTopKRenormProbs(args): - """Test top_k_renorm_probs API. - - This test: - 1. Generates random probability distributions and normalize them in FP32 - 2. Runs top_k_renorm_probs (renormalize by top-k thresholding) - 3. Fetch performance numbers. - - Args: - args: Parsed command line arguments containing test configuration - - Returns: - dict: List of dictionaries containing performance results - - """ - if args.verbose >= 1: - print("[INFO] Running testTopKRenormProbs") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) + """Test top_k_renorm_probs API (renormalize by top-k thresholding).""" + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testTopKRenormProbs", + ) + if not backends: + return [] - ## Parse input arguments - backends = args.backends[:] - batch_size = args.batch_size - vocab_size = args.vocab_size top_k = args.top_k is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - res = [] - - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to 
test. Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are float32, float16, bfloat16.", - ) - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) - # Generate random probabilities and normalize in float32 for numerical stability - pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) - probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) - probs = probs.to(input_dtype) + probs, input_shape = _create_normalized_probs( + args.batch_size, + args.vocab_size, + input_dtype, + device, + ) if args.verbose >= 2: print(f"[VVERBOSE] {probs.shape = }") @@ -646,141 +546,59 @@ def run_backend(backend, probs): return flashinfer.sampling.top_k_renorm_probs(probs, top_k=top_k) raise ValueError(f"Unsupported backend: {backend}") - # Reference implementation for refcheck - has_reference_output = False - if run_refcheck: + backend_times, outputs = _bench_sampling( + args, + backends, + run_backend, + probs, + is_cuda_graph_compatible, + run_refcheck=run_refcheck, + ) + + if run_refcheck and outputs: # PyTorch reference: keep top-k, set rest to 0, renormalize topk_vals, topk_indices = torch.topk(probs.float(), k=top_k, dim=-1) reference_output = torch.zeros_like(probs) - # NOTE: dont explicitly specify dtype here - # keep it the same as input. 
reference_output.scatter_(-1, topk_indices, topk_vals) - reference_output = reference_output / reference_output.sum(dim=-1, keepdim=True) - reference_output = reference_output.to(input_dtype) - has_reference_output = True - - # Storage for timing results and outputs - backend_times = {backend: [] for backend in backends} - outputs = {} - for cur_backend in backends: - if run_refcheck: - outputs[cur_backend] = run_backend(cur_backend, probs).detach() - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, probs), + reference_output = reference_output / reference_output.sum( + dim=-1, + keepdim=True, ) - - tested_backends = list(outputs.keys()) - tested_outputs = list(outputs.values()) - if len(tested_backends) > 0: - if run_refcheck and has_reference_output: - for i in range(len(tested_backends)): - ( - num_different_elements, - num_elements, - num_different_elements_percentage, - ) = is_close_stats( - reference_output.float(), - tested_outputs[i].float(), - rtol=1e-2, - atol=1e-2, - ) - if num_different_elements > 0: - print( - f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}: " - f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ", - ) - if not args.allow_output_mismatch: - raise AssertionError( - f"[ERROR] Backend {tested_backends[i]} output mismatch with {num_different_elements} elements", - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) - - num_elements = np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # probs read - + num_elements * input_dtype.itemsize # renorm_probs write - ) - problem_flops = num_elements * 2 - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = 
problem_bytes / (10**9 * median_time) - - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["top_k"] = top_k - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res + reference_output = reference_output.to(input_dtype) + _check_is_close(outputs, reference_output, args) + + num_elements = np.prod(input_shape) + return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize * 2, + problem_flops=num_elements * 2, + extra_result_fields={ + "vocab_size": args.vocab_size, + "top_k": top_k, + }, + ) def testTopPRenormProbs(args): - """Test top_p_renorm_probs API. - - This test: - 1. Generates random probability distributions - 2. Runs top_p_renorm_probs (renormalize by top-p thresholding) - 3. 
Measures performance metrics - - Args: - args: Parsed command line arguments containing test configuration - - Returns: - dict: List of dictionaries containing performance results - - """ - if args.verbose >= 1: - print("[INFO] Running testTopPRenormProbs") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) + """Test top_p_renorm_probs API (renormalize by top-p thresholding).""" + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testTopPRenormProbs", + ) + if not backends: + return [] - ## Parse input arguments - backends = args.backends[:] - batch_size = args.batch_size - vocab_size = args.vocab_size top_p = args.top_p is_cuda_graph_compatible = not args.no_cuda_graph - res = [] - - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to test. Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. 
Supported dtypes are float32, float16, bfloat16.", - ) - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) - # Generate random probabilities and normalize in float32 for numerical stability - pre_norm_probs = torch.rand(input_shape, dtype=torch.float32, device=device) - probs = pre_norm_probs / pre_norm_probs.sum(dim=-1, keepdim=True) - probs = probs.to(input_dtype) + probs, input_shape = _create_normalized_probs( + args.batch_size, + args.vocab_size, + input_dtype, + device, + ) if args.verbose >= 2: print(f"[VVERBOSE] {probs.shape = }") @@ -792,98 +610,44 @@ def run_backend(backend, probs): return flashinfer.sampling.top_p_renorm_probs(probs, top_p=top_p) raise ValueError(f"Unsupported backend: {backend}") - # Storage for timing results - backend_times = {backend: [] for backend in backends} - for cur_backend in backends: - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, probs), - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) - - num_elements = np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # probs read - + num_elements - * input_dtype.itemsize # renorm_probs write (same dtype as input) - ) - problem_flops = num_elements * 2 - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = problem_bytes / (10**9 * median_time) - - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + backend_times, _ = _bench_sampling( + args, + backends, + run_backend, + probs, + is_cuda_graph_compatible, + ) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - cur_res["tflops"] = tflops - 
cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["top_p"] = top_p - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res + num_elements = np.prod(input_shape) + return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize * 2, + problem_flops=num_elements * 2, + extra_result_fields={ + "vocab_size": args.vocab_size, + "top_p": top_p, + }, + ) def testTopKMaskLogits(args): - """Test top_k_mask_logits API. - - This test: - 1. Generates random logits - 2. Runs top_k_mask_logits (mask logits by top-k thresholding) - 3. Measures performance metrics - - Args: - args: Parsed command line arguments containing test configuration - - Returns: - dict: List of dictionaries containing performance results - - """ - if args.verbose >= 1: - print("[INFO] Running testTopKMaskLogits") - print(f"[INFO] FlashInfer version: {flashinfer.__version__}") - - device = get_device(args) - if args.generate_repro_command: - print( - f"[INFO] To reproduce this test case, run the following command: {args.repro_command}", - ) + """Test top_k_mask_logits API (mask logits by top-k thresholding).""" + device, backends, input_dtype = _setup_sampling_benchmark( + args, + "testTopKMaskLogits", + ) + if not backends: + return [] - ## Parse input arguments - backends = args.backends[:] - batch_size = args.batch_size - vocab_size = args.vocab_size top_k = args.top_k is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - res = [] - backends = filter_backends_by_compute_capability(backends, args.routine, device) - if len(backends) == 0: - print("[ERROR] No backends to test. 
Exiting.") - return res - - input_dtype = dtype_str_to_torch_dtype(args.input_dtype) - if input_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError( - f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are float32, float16, bfloat16.", - ) - - ## Prepare input tensors - input_shape = (batch_size, vocab_size) + # Unlike other sampling benchmarks, this operates on raw logits, not probs + input_shape = (args.batch_size, args.vocab_size) logits = torch.randn(input_shape, dtype=input_dtype, device=device) if args.verbose >= 2: @@ -896,102 +660,71 @@ def run_backend(backend, logits): return flashinfer.sampling.top_k_mask_logits(logits, top_k=top_k) raise ValueError(f"Unsupported backend: {backend}") - # Reference implementation for refcheck - has_reference_output = False - if run_refcheck: + backend_times, outputs = _bench_sampling( + args, + backends, + run_backend, + logits, + is_cuda_graph_compatible, + run_refcheck=run_refcheck, + ) + + if run_refcheck and outputs: # PyTorch reference: keep top-k logits, set rest to -inf topk_vals, topk_indices = torch.topk(logits.float(), k=top_k, dim=-1) reference_output = torch.full_like(logits, float("-inf")) - # NOTE: dont explicitly specify dtype here - # keep it the same as input. 
reference_output.scatter_(-1, topk_indices, topk_vals) reference_output = reference_output.to(input_dtype) - has_reference_output = True - # Storage for timing results and outputs - backend_times = {backend: [] for backend in backends} - outputs = {} - for cur_backend in backends: - if run_refcheck: - outputs[cur_backend] = run_backend(cur_backend, logits).detach() - backend_times[cur_backend] = bench_gpu_time( - fn=run_backend, - dry_run_iters=args.dry_run_iters, - repeat_iters=args.num_iters, - enable_cupti=args.use_cupti, - use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, logits), - ) + for backend, output in outputs.items(): + out = output.float() + ref = reference_output.float() + + # Check that the same positions are masked (-inf) + out_masked = torch.isinf(out) & (out < 0) + ref_masked = torch.isinf(ref) & (ref < 0) + mask_match = (out_masked == ref_masked).all() + if not mask_match: + print(f"[ERROR] Mask mismatch from backend {backend}") + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} mask mismatch", + ) - tested_backends = list(outputs.keys()) - tested_outputs = list(outputs.values()) - if len(tested_backends) > 0: - if run_refcheck and has_reference_output: - for i in range(len(tested_backends)): - # For masked logits, check: - # 1. Same positions are masked (-inf) - # 2. 
Unmasked values match the original logits - out = tested_outputs[i].float() - ref = reference_output.float() - - # Check that the same positions are masked - out_masked = torch.isinf(out) & (out < 0) - ref_masked = torch.isinf(ref) & (ref < 0) - mask_match = (out_masked == ref_masked).all() - if not mask_match: - print(f"[ERROR] Mask mismatch from backend {tested_backends[i]}") + # Check that unmasked values match the reference + unmasked_positions = ~out_masked + if unmasked_positions.any(): + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats( + ref[unmasked_positions], + out[unmasked_positions], + rtol=1e-3, + atol=1e-3, + ) + if num_different_elements > 0: + print( + f"[ERROR] Unmasked values mismatch from backend {backend}: " + f"{num_different_elements}/{num_elements} " + f"({num_different_elements_percentage:.2f}%) elements differ", + ) if not args.allow_output_mismatch: raise AssertionError( - f"[ERROR] Backend {tested_backends[i]} mask mismatch", + f"[ERROR] Backend {backend} unmasked values mismatch", ) - # Check that unmasked values match the reference (original top-k logits) - unmasked_positions = ~out_masked - if unmasked_positions.any(): - out_unmasked = out[unmasked_positions] - ref_unmasked = ref[unmasked_positions] - ( - num_different_elements, - num_elements, - num_different_elements_percentage, - ) = is_close_stats(ref_unmasked, out_unmasked, rtol=1e-3, atol=1e-3) - if num_different_elements > 0: - print( - f"[ERROR] Unmasked values mismatch from backend {tested_backends[i]}: " - f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ", - ) - if not args.allow_output_mismatch: - raise AssertionError( - f"[ERROR] Backend {tested_backends[i]} unmasked values mismatch", - ) - - for backend in backends: - if len(backend_times[backend]) > 0: - median_time = np.median(backend_times[backend]) - std_time = np.std(backend_times[backend]) - - num_elements = 
np.prod(input_shape) - problem_bytes = ( - num_elements * input_dtype.itemsize # logits read - + num_elements * input_dtype.itemsize # masked_logits write - ) - problem_flops = num_elements * 2 - tflops = problem_flops / (10**9 * median_time) - tb_per_sec = problem_bytes / (10**9 * median_time) - - print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = str(std_time) - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["vocab_size"] = vocab_size - cur_res["top_k"] = top_k - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) - return res + num_elements = np.prod(input_shape) + return _collect_results( + args, + backends, + backend_times, + input_dtype, + problem_bytes=num_elements * input_dtype.itemsize * 2, + problem_flops=num_elements * 2, + extra_result_fields={ + "vocab_size": args.vocab_size, + "top_k": top_k, + }, + )