diff --git a/benchmarks/README.md b/benchmarks/README.md index b66882d38b..f8555fbb24 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -5,11 +5,11 @@ The aim of `flashinfer_benchmark.py` is to provide a single framework for benchm ## Overview This framework provides tools to: -- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, and Quantization API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM +- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, Quantization, Sampling, and RoPE API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM - Compare performance across different configurations - Batch performance test multiple test cases -Currently supports testing attention, gemm, fused MOE, normalization, and quantization APIs: +Currently supports testing attention, gemm, fused MOE, normalization, quantization, sampling, and RoPE APIs: - Attention: - `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache. - Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`. @@ -42,6 +42,31 @@ Currently supports testing attention, gemm, fused MOE, normalization, and quanti - `mxfp4_quantize` - Quantize tensor to MxFP4 format (Blackwell SM10.0+). - `nvfp4_quantize` - Quantize tensor to NVFP4 format with configurable scale factor layout (Blackwell SM10.0+). - `nvfp4_batched_quantize` - Batched NVFP4 quantization (Blackwell SM10.0+). +- Sampling: + - `softmax` - Softmax with optional temperature scaling. + - `sampling_from_probs` - Sample token indices from probability distributions. + - `sampling_from_logits` - Sample token indices from logits (fused softmax + sampling). + - `top_k_sampling_from_probs` - Top-K sampling from probabilities. + - `top_p_sampling_from_probs` - Top-P (nucleus) sampling from probabilities. 
+ - `top_k_top_p_sampling_from_probs` - Combined Top-K and Top-P sampling from probabilities. + - `top_k_top_p_sampling_from_logits` - Combined Top-K and Top-P sampling from logits. + - `min_p_sampling_from_probs` - Min-P sampling from probabilities. + - `top_k_renorm_probs` - Renormalize probabilities after Top-K filtering. + - `top_p_renorm_probs` - Renormalize probabilities after Top-P filtering. + - `top_k_mask_logits` - Mask logits outside Top-K values. + - `chain_speculative_sampling` - Chain speculative sampling for speculative decoding. + - `top_k` - Radix-based Top-K selection. + - `top_k_page_table_transform` - Fused Top-K with page table lookup. + - `top_k_ragged_transform` - Fused Top-K with ragged index transform. +- RoPE (Rotary Positional Embeddings): + - `apply_rope` - Apply RoPE with indptr/offsets. + - `apply_rope_pos_ids` - Apply RoPE with position IDs. + - `apply_llama31_rope` - Apply Llama 3.1 style RoPE with indptr/offsets. + - `apply_llama31_rope_pos_ids` - Apply Llama 3.1 style RoPE with position IDs. + - `apply_rope_with_cos_sin_cache` - Apply RoPE with precomputed cos/sin cache. + - `mla_rope_quantize_fp8` - MLA RoPE with FP8 quantization (SM8.9+). + - `rope_quantize_fp8` - RoPE with FP8 quantization (SM8.9+). + - `rope_quantize_fp8_append_paged_kv_cache` - RoPE with FP8 quantization and paged KV cache append (SM8.9+). ## Quick Start ### Single Test Run @@ -316,6 +341,44 @@ mpirun -np 8 python benchmarks/flashinfer_benchmark.py \ | `--sf_vec_size` | Scale factor vector size for NVFP4 quantization. Default: 16 | | `--backends` | Backend to test. 
Default: `cuda` | +### Sampling Flags +| Flag | Description | +|--------------------------|-------------------------------------------------------------------------------------------------------------| +| `--batch_size` | Batch size (number of sequences) | +| `--vocab_size` | Vocabulary size | +| `--input_dtype` | Input data type for logits: `float32` (default), `float16`, or `bfloat16` | +| `--top_k` | Top-K value for top-k sampling. Default: 50 | +| `--top_p` | Top-P threshold for top-p (nucleus) sampling. Default: 0.9 | +| `--min_p` | Min-P threshold for min-p sampling. Default: 0.1 | +| `--temperature` | Temperature for softmax. Default: 1.0 | +| `--filter_apply_order` | Order of applying top-k and top-p filters: `top_k_first` (default) or `joint` | +| `--num_speculate_tokens` | Number of speculative tokens for chain speculative sampling. Default: 5 | +| `--max_len` | Max sequence length for `top_k_page_table_transform` and `top_k_ragged_transform`. Default: 4096 | +| `--num_rows` | Number of rows for `top_k_page_table_transform` and `top_k_ragged_transform`. Defaults to batch_size | +| `--backends` | Backend to test: `cuda` (default) | + +### RoPE Flags +| Flag | Description | +|--------------------------|-------------------------------------------------------------------------------------------------------------| +| `--batch_size` | Batch size (number of sequences) | +| `--seq_len` | Sequence length (qkv_len or kv_len) | +| `--num_qo_heads` | Number of query/output heads | +| `--num_kv_heads` | Number of key/value heads | +| `--head_dim` | Head dimension | +| `--rotary_dim` | Rotary dimension (defaults to head_dim if not specified) | +| `--no_rope_dim` | Number of dimensions without RoPE (for MLA). Default: 0 | +| `--input_dtype` | Input data type: `float16` (default) or `bfloat16` | +| `--quant_dtype` | Quantized data type for FP8 routines: `fp8_e4m3` (default) or `fp8_e5m2` | +| `--rope_scale` | RoPE scaling factor. 
Default: 1.0 | +| `--rope_theta` | RoPE theta base frequency. Default: 10000.0 | +| `--interleave` | Use interleaved rotary embedding (GPT-J style) | +| `--page_size` | Page size for paged KV cache. Default: 16 | +| `--kv_layout` | KV cache layout: `NHD` (default) or `HND` | +| `--low_freq_factor` | Low frequency factor for Llama 3.1 RoPE. Default: 1.0 | +| `--high_freq_factor` | High frequency factor for Llama 3.1 RoPE. Default: 4.0 | +| `--old_context_len` | Old context length for Llama 3.1 RoPE. Default: 8192 | +| `--backends` | Backend to test: `cuda` (default) | + ## `flashinfer_benchmark.py` Routine & Backend Support Matrix The following table summarizes the support surface of each routine & backend on various [CUDA Compute Capabilities](https://developer.nvidia.com/cuda-gpus). @@ -357,6 +420,29 @@ Legend: | **mxfp4_quantize** | | | | | | cuda | cuda | | | **nvfp4_quantize** | | | | | | cuda | cuda | | | **nvfp4_batched_quantize** | | | | | | cuda | cuda | | +| **softmax** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **sampling_from_logits** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_p_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_top_p_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_top_p_sampling_from_logits** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **min_p_sampling_from_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_renorm_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_p_renorm_probs** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_mask_logits** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **chain_speculative_sampling** | cuda | cuda | cuda
| cuda | cuda | cuda | cuda | cuda | +| **top_k** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_page_table_transform** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **top_k_ragged_transform** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **apply_rope** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **apply_rope_pos_ids** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **apply_llama31_rope** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **apply_llama31_rope_pos_ids** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **apply_rope_with_cos_sin_cache** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **mla_rope_quantize_fp8** | | | | cuda | cuda | cuda | cuda | cuda | +| **rope_quantize_fp8** | | | | cuda | cuda | cuda | cuda | cuda | +| **rope_quantize_fp8_append_paged_kv_cache** | | | | cuda | cuda | cuda | cuda | cuda | Backend Legend: - fa2: FlashAttention2 diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index fdbc54098c..2b01f4cf82 100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -44,6 +44,14 @@ def run_test(args): from routines.quantization import run_quantization_test res = run_quantization_test(args) + elif args.routine in benchmark_apis["sampling"]: + from routines.sampling import run_sampling_test + + res = run_sampling_test(args) + elif args.routine in benchmark_apis["rope"]: + from routines.rope import run_rope_test + + res = run_rope_test(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -89,7 +97,9 @@ def parse_args(line=sys.argv[1:]): + list(benchmark_apis["moe"]) + list(benchmark_apis["moe_comm"]) + list(benchmark_apis["norm"]) - + list(benchmark_apis["quantization"]), + + list(benchmark_apis["quantization"]) + + list(benchmark_apis["sampling"]) + + list(benchmark_apis["rope"]), ) args, _ = parser.parse_known_args(line[:]) @@ -199,6 
+209,14 @@ def parse_args(line=sys.argv[1:]): from routines.quantization import parse_quantization_args args = parse_quantization_args(line, parser) + elif args.routine in benchmark_apis["sampling"]: + from routines.sampling import parse_sampling_args + + args = parse_sampling_args(line, parser) + elif args.routine in benchmark_apis["rope"]: + from routines.rope import parse_rope_args + + args = parse_rope_args(line, parser) else: raise ValueError(f"Unsupported routine: {args.routine}") diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index b207f5cb43..28638dedf9 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -14,11 +14,8 @@ "backend", ], "attention": [ - "page_size", "s_qo", "s_kv", - "num_qo_heads", - "num_kv_heads", "head_dim_qk", "head_dim_vo", "head_dim_ckv", @@ -30,9 +27,7 @@ "random_actual_seq_len", ], "gemm": [ - "m", "n", - "k", "group_size", "tile_size", "scale_major_mode", @@ -67,38 +62,58 @@ ], "moe_comm": [ "num_tokens", - "hidden_size", "num_experts", "top_k", "ep_size", - "input_dtype", - "quant_dtype", "max_num_tokens", ], "norm": [ "num_heads", "scale", "eps", - "enable_pdl", "use_global_scale", - "is_sf_swizzled_layout", ], "quantization": [ - "m", - "k", - "is_sf_swizzled_layout", "alignment", - "enable_pdl", "global_scale", "sf_layout", "do_shuffle", "sf_vec_size", ], + "sampling": [ + "vocab_size", + "top_k", + "top_p", + "min_p", + "temperature", + "num_speculate_tokens", + "filter_apply_order", + "max_len", + "num_rows", + ], + "rope": [ + "seq_len", + "head_dim", + "rotary_dim", + "no_rope_dim", + "rope_theta", + "rope_scale", + "interleave", + "kv_layout", + ], "general": [ "batch_size", "hidden_size", "input_dtype", "out_dtype", + "quant_dtype", + "m", + "k", + "num_qo_heads", + "num_kv_heads", + "page_size", + "enable_pdl", + "is_sf_swizzled_layout", "refcheck", "no_cuda_graph", "use_cupti", @@ 
-118,6 +133,8 @@ + output_column_dict["moe_comm"] + output_column_dict["norm"] + output_column_dict["quantization"] + + output_column_dict["sampling"] + + output_column_dict["rope"] + output_column_dict["general"] ) @@ -157,6 +174,33 @@ "nvfp4_quantize", "nvfp4_batched_quantize", ], + "sampling": [ + "softmax", + "sampling_from_probs", + "sampling_from_logits", + "top_k_sampling_from_probs", + "top_p_sampling_from_probs", + "top_k_top_p_sampling_from_probs", + "top_k_top_p_sampling_from_logits", + "min_p_sampling_from_probs", + "top_k_renorm_probs", + "top_p_renorm_probs", + "top_k_mask_logits", + "chain_speculative_sampling", + "top_k", + "top_k_page_table_transform", + "top_k_ragged_transform", + ], + "rope": [ + "apply_rope", + "apply_rope_pos_ids", + "apply_llama31_rope", + "apply_llama31_rope_pos_ids", + "apply_rope_with_cos_sin_cache", + "mla_rope_quantize_fp8", + "rope_quantize_fp8", + "rope_quantize_fp8_append_paged_kv_cache", + ], } @@ -431,6 +475,238 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cuda"], "12.0": ["cuda"], }, + # SAMPLING + "softmax": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "sampling_from_logits": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_p_sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": 
["cuda"], + }, + "top_k_top_p_sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_top_p_sampling_from_logits": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "min_p_sampling_from_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_renorm_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_p_renorm_probs": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_mask_logits": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "chain_speculative_sampling": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_page_table_transform": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "top_k_ragged_transform": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + # ROPE + "apply_rope": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], 
+ "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "apply_rope_pos_ids": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "apply_llama31_rope": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "apply_llama31_rope_pos_ids": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "apply_rope_with_cos_sin_cache": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "mla_rope_quantize_fp8": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "rope_quantize_fp8": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "rope_quantize_fp8_append_paged_kv_cache": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, } diff --git a/benchmarks/routines/rope.py b/benchmarks/routines/rope.py new file mode 100644 index 0000000000..efd06f6ab4 --- /dev/null +++ b/benchmarks/routines/rope.py @@ -0,0 +1,1710 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict + +import numpy as np +import torch + +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + +from .flashinfer_benchmark_utils import ( + dtype_str_to_torch_dtype, + get_device, + print_perf_metrics, + filter_backends_by_compute_capability, +) + + +def run_rope_test(args): + """ + Run a RoPE test. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.routine == "apply_rope": + return testApplyRope(args) + elif args.routine == "apply_rope_pos_ids": + return testApplyRopePosIds(args) + elif args.routine == "apply_llama31_rope": + return testApplyLlama31Rope(args) + elif args.routine == "apply_llama31_rope_pos_ids": + return testApplyLlama31RopePosIds(args) + elif args.routine == "apply_rope_with_cos_sin_cache": + return testApplyRopeWithCosSinCache(args) + elif args.routine == "mla_rope_quantize_fp8": + return testMlaRopeQuantizeFp8(args) + elif args.routine == "rope_quantize_fp8": + return testRopeQuantizeFp8(args) + elif args.routine == "rope_quantize_fp8_append_paged_kv_cache": + return testRopeQuantizeFp8AppendPagedKvCache(args) + else: + raise ValueError(f"Unsupported routine: {args.routine}") + + +def parse_rope_args(line, parser): + """ + Parse command line arguments for RoPE test configuration. 
+ + Args: + line: Command line arguments + parser: ArgumentParser object already populated with shared arguments + + Returns: + Parsed argument namespace + """ + parser.add_argument( + "--batch_size", + type=int, + required=True, + help="Batch size.", + ) + parser.add_argument( + "--seq_len", + type=int, + required=True, + help="Sequence length (qkv_len or kv_len).", + ) + parser.add_argument( + "--num_qo_heads", + type=int, + required=True, + help="Number of query/output heads.", + ) + parser.add_argument( + "--num_kv_heads", + type=int, + required=True, + help="Number of key/value heads.", + ) + parser.add_argument( + "--head_dim", + type=int, + required=True, + help="Head dimension.", + ) + parser.add_argument( + "--rotary_dim", + type=int, + required=False, + default=None, + help="Rotary dimension (defaults to head_dim if not specified).", + ) + parser.add_argument( + "--no_rope_dim", + type=int, + required=False, + default=0, + help="Number of dimensions without RoPE (for MLA). Default: 0.", + ) + parser.add_argument( + "--input_dtype", + type=str, + required=False, + default="float16", + choices=["float16", "bfloat16"], + help="Data type of the input tensor.", + ) + parser.add_argument( + "--quant_dtype", + type=str, + required=False, + default="fp8_e4m3", + choices=["fp8_e4m3", "fp8_e5m2"], + help="Quantized data type for FP8 routines.", + ) + parser.add_argument( + "--rope_scale", + type=float, + required=False, + default=1.0, + help="RoPE scaling factor.", + ) + parser.add_argument( + "--rope_theta", + type=float, + required=False, + default=10000.0, + help="RoPE theta base frequency.", + ) + parser.add_argument( + "--interleave", + action="store_true", + help="Use interleaved rotary embedding (GPT-J style).", + ) + parser.add_argument( + "--page_size", + type=int, + required=False, + default=16, + help="Page size for paged KV cache.", + ) + parser.add_argument( + "--kv_layout", + type=str, + required=False, + default="NHD", + choices=["NHD", "HND"], + 
help="KV cache layout.", + ) + parser.add_argument( + "--low_freq_factor", + type=float, + required=False, + default=1.0, + help="Low frequency factor for Llama 3.1 RoPE.", + ) + parser.add_argument( + "--high_freq_factor", + type=float, + required=False, + default=4.0, + help="High frequency factor for Llama 3.1 RoPE.", + ) + parser.add_argument( + "--old_context_len", + type=int, + required=False, + default=8192, + help="Old context length for Llama 3.1 RoPE.", + ) + parser.add_argument( + "--backends", + type=str, + required=False, + nargs="+", + default=["cuda"], + choices=["cuda"], + help="Kernel backends to test. Default: cuda", + ) + + args = parser.parse_args(line) + + # Default rotary_dim to head_dim if not specified + if args.rotary_dim is None: + args.rotary_dim = args.head_dim + + if args.verbose >= 1: + print(f"[INFO] {args = }") + return args + + +def testApplyRope(args): + """ + Test apply_rope API (with indptr/offsets). + + This test: + 1. Generates random Q and K tensors + 2. Runs flashinfer.rope.apply_rope with indptr/offsets + 3. 
Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testApplyRope") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + rope_scale = args.rope_scale + rope_theta = args.rope_theta + interleave = args.interleave + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + # Shape: (batch_size * seq_len, num_heads, head_dim) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=input_dtype, device=device + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # indptr for ragged tensor + indptr = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=device + ) + + # offsets (per-request position offset) + offsets = torch.zeros(batch_size, dtype=torch.int32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {q.shape = }") + print(f"[VVERBOSE] {k.shape = }") + print(f"[VVERBOSE] {indptr.shape = }") + print(f"[VVERBOSE] {offsets.shape = }") + print(f"[VVERBOSE] {rotary_dim = }") + print(f"[VVERBOSE] {rope_scale = }") + print(f"[VVERBOSE] {rope_theta = }") + print(f"[VVERBOSE] {interleave = }") + + def run_backend(backend, q, k, indptr, offsets): + if backend == "cuda": + return flashinfer.rope.apply_rope( + q, + k, + indptr=indptr, + offsets=offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, q, k, indptr, offsets), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + total_tokens = batch_size * seq_len + # Memory bandwidth calculation + # Read: q + k + # Write: q_rope + k_rope + problem_bytes = ( + total_tokens * num_qo_heads * 
head_dim * input_dtype.itemsize # q read + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k read + + total_tokens + * num_qo_heads + * head_dim + * input_dtype.itemsize # q_rope write + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k_rope write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["rotary_dim"] = rotary_dim + cur_res["rope_theta"] = rope_theta + cur_res["rope_scale"] = rope_scale + cur_res["interleave"] = interleave + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testApplyRopePosIds(args): + """ + Test apply_rope API with pos_ids. + + This test: + 1. Generates random Q and K tensors + 2. Runs flashinfer.rope.apply_rope with pos_ids + 3. 
Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testApplyRopePosIds") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + rope_scale = args.rope_scale + rope_theta = args.rope_theta + interleave = args.interleave + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + # Shape: (batch_size * seq_len, num_heads, head_dim) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=input_dtype, device=device + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # pos_ids: (batch_size * seq_len,) + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device).repeat(batch_size) + + if args.verbose >= 2: + print(f"[VVERBOSE] {q.shape = }") + print(f"[VVERBOSE] {k.shape = }") + print(f"[VVERBOSE] {pos_ids.shape = }") + print(f"[VVERBOSE] {rotary_dim = }") + print(f"[VVERBOSE] {rope_scale = }") + print(f"[VVERBOSE] {rope_theta = }") + print(f"[VVERBOSE] {interleave = }") + + def run_backend(backend, q, k, pos_ids): + if backend == "cuda": + return flashinfer.rope.apply_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, q, k, pos_ids), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + total_tokens = batch_size * seq_len + # Memory bandwidth calculation + # Read: q + k + pos_ids + # Write: q_rope + k_rope + problem_bytes = ( + total_tokens * num_qo_heads * head_dim * input_dtype.itemsize # q read + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k read + + total_tokens * 4 # pos_ids read (int32) + + total_tokens + * 
num_qo_heads + * head_dim + * input_dtype.itemsize # q_rope write + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k_rope write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["rotary_dim"] = rotary_dim + cur_res["rope_theta"] = rope_theta + cur_res["rope_scale"] = rope_scale + cur_res["interleave"] = interleave + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testApplyLlama31Rope(args): + """ + Test apply_llama31_rope API (with indptr/offsets). + + This test: + 1. Generates random Q and K tensors + 2. Runs flashinfer.rope.apply_llama31_rope with indptr/offsets + 3. 
Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testApplyLlama31Rope") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + rope_scale = args.rope_scale + rope_theta = args.rope_theta + interleave = args.interleave + low_freq_factor = args.low_freq_factor + high_freq_factor = args.high_freq_factor + old_context_len = args.old_context_len + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + # Shape: (batch_size * seq_len, num_heads, head_dim) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=input_dtype, device=device + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # indptr for ragged tensor + indptr = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=device + ) + + # offsets (per-request position offset) + offsets = torch.zeros(batch_size, dtype=torch.int32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {q.shape = }") + print(f"[VVERBOSE] {k.shape = }") + print(f"[VVERBOSE] {indptr.shape = }") + print(f"[VVERBOSE] {offsets.shape = }") + print(f"[VVERBOSE] {rotary_dim = }") + print(f"[VVERBOSE] {rope_scale = }") + print(f"[VVERBOSE] {rope_theta = }") + print(f"[VVERBOSE] {interleave = }") + print(f"[VVERBOSE] {low_freq_factor = }") + print(f"[VVERBOSE] {high_freq_factor = }") + print(f"[VVERBOSE] {old_context_len = }") + + def run_backend(backend, q, k, indptr, offsets): + if backend == "cuda": + return flashinfer.rope.apply_llama31_rope( + q, + k, + indptr=indptr, + offsets=offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, q, k, indptr, offsets), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time 
= np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + total_tokens = batch_size * seq_len + # Memory bandwidth calculation + # Read: q + k + # Write: q_rope + k_rope + problem_bytes = ( + total_tokens * num_qo_heads * head_dim * input_dtype.itemsize # q read + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k read + + total_tokens + * num_qo_heads + * head_dim + * input_dtype.itemsize # q_rope write + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k_rope write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["rotary_dim"] = rotary_dim + cur_res["rope_theta"] = rope_theta + cur_res["rope_scale"] = rope_scale + cur_res["interleave"] = interleave + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testApplyLlama31RopePosIds(args): + """ + Test apply_llama31_rope API with pos_ids. + + This test: + 1. Generates random Q and K tensors + 2. Runs flashinfer.rope.apply_llama31_rope with pos_ids + 3. 
Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testApplyLlama31RopePosIds") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + rope_scale = args.rope_scale + rope_theta = args.rope_theta + interleave = args.interleave + low_freq_factor = args.low_freq_factor + high_freq_factor = args.high_freq_factor + old_context_len = args.old_context_len + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + # Shape: (batch_size * seq_len, num_heads, head_dim) + q = torch.randn( + batch_size * seq_len, num_qo_heads, head_dim, dtype=input_dtype, device=device + ) + k = torch.randn( + batch_size * seq_len, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # pos_ids: (batch_size * seq_len,) + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device).repeat(batch_size) + + if args.verbose >= 2: + print(f"[VVERBOSE] {q.shape = }") + print(f"[VVERBOSE] {k.shape = }") + print(f"[VVERBOSE] {pos_ids.shape = }") + print(f"[VVERBOSE] {rotary_dim = }") + print(f"[VVERBOSE] {rope_scale = }") + print(f"[VVERBOSE] {rope_theta = }") + print(f"[VVERBOSE] {interleave = }") + print(f"[VVERBOSE] {low_freq_factor = }") + print(f"[VVERBOSE] {high_freq_factor = }") + print(f"[VVERBOSE] {old_context_len = }") + + def run_backend(backend, q, k, pos_ids): + if backend == "cuda": + return flashinfer.rope.apply_llama31_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, q, k, pos_ids), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + total_tokens = batch_size * seq_len + # Memory bandwidth calculation + # Read: q + k + pos_ids + # Write: q_rope 
+ k_rope + problem_bytes = ( + total_tokens * num_qo_heads * head_dim * input_dtype.itemsize # q read + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k read + + total_tokens * 4 # pos_ids read (int32) + + total_tokens + * num_qo_heads + * head_dim + * input_dtype.itemsize # q_rope write + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k_rope write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["rotary_dim"] = rotary_dim + cur_res["rope_theta"] = rope_theta + cur_res["rope_scale"] = rope_scale + cur_res["interleave"] = interleave + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testApplyRopeWithCosSinCache(args): + """ + Test apply_rope_with_cos_sin_cache API. + + This test: + 1. Generates random Q and K tensors with precomputed cos/sin cache + 2. Runs flashinfer.rope.apply_rope_with_cos_sin_cache + 3. Measures performance metrics (TB/sec) + + Note: This API uses flattened Q/K tensors and a combined cos_sin_cache. 
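In this benchmark the cache is filled with `torch.randn` placeholder data, since only bandwidth is measured. For reference, the layout described in the comment above (first half cos, second half sin along the last dimension) would be populated like this; a minimal NumPy sketch, where `build_cos_sin_cache` is a hypothetical helper using the conventional inverse-frequency formulation with base `rope_theta`:

```python
import numpy as np

def build_cos_sin_cache(max_seq_len, rotary_dim, rope_theta=10000.0):
    """Build a (max_seq_len, rotary_dim) cache: first half cos, second half sin."""
    # inverse frequency for each even rotary dimension
    inv_freq = 1.0 / (rope_theta ** (np.arange(0, rotary_dim, 2) / rotary_dim))
    # angle for every (position, frequency) pair: (max_seq_len, rotary_dim // 2)
    freqs = np.outer(np.arange(max_seq_len), inv_freq)
    return np.concatenate([np.cos(freqs), np.sin(freqs)], axis=-1)

cache = build_cos_sin_cache(max_seq_len=8, rotary_dim=16)
```

At position 0 every angle is zero, so the cos half is all ones and the sin half all zeros, which is a quick sanity check on the layout.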
+ + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testApplyRopeWithCosSinCache") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + interleave = args.interleave + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + total_tokens = batch_size * seq_len + # Shape: (total_tokens, num_heads * head_dim) - flattened for this API + q = torch.randn( + total_tokens, num_qo_heads * head_dim, dtype=input_dtype, device=device + ) + k = torch.randn( + total_tokens, num_kv_heads * head_dim, dtype=input_dtype, device=device + ) + + # Precomputed cos_sin_cache: (max_seq_len, rotary_dim) + # First half is cos, second half is sin + max_seq_len = seq_len + cos_sin_cache = torch.randn( + max_seq_len, rotary_dim, dtype=input_dtype, device=device + ) + + # positions: (total_tokens,) + positions = torch.arange(seq_len, dtype=torch.long, device=device).repeat( + batch_size + ) + + # is_neox is the inverse of interleave + is_neox = not interleave + + if args.verbose >= 2: + print(f"[VVERBOSE] {q.shape = }") + print(f"[VVERBOSE] {k.shape = }") + print(f"[VVERBOSE] {cos_sin_cache.shape = }") + print(f"[VVERBOSE] {positions.shape = }") + print(f"[VVERBOSE] 
{is_neox = }") + + def run_backend(backend, positions, q, k, cos_sin_cache): + if backend == "cuda": + return flashinfer.rope.apply_rope_with_cos_sin_cache( + positions, + q, + k, + head_dim, + cos_sin_cache, + is_neox=is_neox, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, positions, q, k, cos_sin_cache), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: q + k + cos_sin_cache + positions + # Write: q_rope + k_rope + problem_bytes = ( + total_tokens * num_qo_heads * head_dim * input_dtype.itemsize # q read + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k read + + max_seq_len * rotary_dim * input_dtype.itemsize # cos_sin_cache read + + total_tokens * 8 # positions read (int64) + + total_tokens + * num_qo_heads + * head_dim + * input_dtype.itemsize # q_rope write + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # k_rope write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + 
cur_res["rotary_dim"] = rotary_dim + cur_res["interleave"] = interleave + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testMlaRopeQuantizeFp8(args): + """ + Test mla_rope_quantize_fp8 API (for MLA attention). + + This test: + 1. Generates random pre-split Q and K tensors (rotary and non-rotary parts) + 2. Creates precomputed cos_sin_cache + 3. Runs flashinfer.rope.mla_rope_quantize_fp8 + 4. Measures performance metrics (TB/sec) + + Note: This API takes pre-split q_rope, k_rope, q_nope, k_nope tensors + and a precomputed cos_sin_cache. It is the same as rope_quantize_fp8. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testMlaRopeQuantizeFp8") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + no_rope_dim = args.no_rope_dim + rope_dim = head_dim - no_rope_dim + interleave = args.interleave + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + quant_dtype = dtype_str_to_torch_dtype(args.quant_dtype) + + ## Prepare input tensors (pre-split for this API) + total_tokens = batch_size * seq_len + max_seq_len = seq_len + + # q_rope: (total_tokens, num_qo_heads, rope_dim) + q_rope = torch.randn( + total_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + # k_rope: (total_tokens, rope_dim) for MLA (no num_kv_heads dimension) + k_rope = torch.randn(total_tokens, rope_dim, dtype=input_dtype, device=device) + # q_nope: (total_tokens, num_qo_heads, no_rope_dim) or None + q_nope = ( + torch.randn( + total_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + if no_rope_dim > 0 + else None + ) + # k_nope: (total_tokens, no_rope_dim) for MLA or None + k_nope = ( + torch.randn(total_tokens, no_rope_dim, dtype=input_dtype, device=device) + if no_rope_dim > 0 + else None + ) + + # Precomputed cos_sin_cache: (max_seq_len, rope_dim) in float32 + cos_sin_cache = torch.randn( + max_seq_len, rope_dim, dtype=torch.float32, device=device + ) + + # pos_ids: (total_tokens,) + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device).repeat(batch_size) + + # is_neox is the inverse of interleave + is_neox = not interleave + + if args.verbose >= 2: + print(f"[VVERBOSE] {q_rope.shape = }") + print(f"[VVERBOSE] {k_rope.shape = }") + print( + f"[VVERBOSE] q_nope.shape = {q_nope.shape if q_nope is not None else None}" + ) + print( + f"[VVERBOSE] k_nope.shape = {k_nope.shape if k_nope is not None else None}" + ) + print(f"[VVERBOSE] {cos_sin_cache.shape = }") + print(f"[VVERBOSE] {pos_ids.shape = }") + print(f"[VVERBOSE] {rope_dim = }") + print(f"[VVERBOSE] {no_rope_dim = }") + print(f"[VVERBOSE] {is_neox = }") + + def run_backend(backend, q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids): + if backend == "cuda": + return flashinfer.rope.mla_rope_quantize_fp8( + q_rope, + k_rope, + q_nope, + k_nope, + 
cos_sin_cache, + pos_ids, + is_neox=is_neox, + quantize_dtype=quant_dtype, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=( + cur_backend, + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + ), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for MLA + # Read: q_rope + k_rope + q_nope + k_nope + cos_sin_cache + pos_ids + # Write: q_rope_out + k_rope_out + q_nope_out + k_nope_out + nope_bytes = 0 + if no_rope_dim > 0: + nope_bytes = ( + total_tokens + * num_qo_heads + * no_rope_dim + * input_dtype.itemsize # q_nope read + + total_tokens + * no_rope_dim + * input_dtype.itemsize # k_nope read (MLA shape) + + total_tokens + * num_qo_heads + * no_rope_dim + * quant_dtype.itemsize # q_nope_out write + + total_tokens + * no_rope_dim + * quant_dtype.itemsize # k_nope_out write (MLA shape) + ) + problem_bytes = ( + total_tokens + * num_qo_heads + * rope_dim + * input_dtype.itemsize # q_rope read + + total_tokens + * rope_dim + * input_dtype.itemsize # k_rope read (MLA shape) + + nope_bytes + + max_seq_len * rope_dim * 4 # cos_sin_cache read (float32) + + total_tokens * 4 # pos_ids read (int32) + + total_tokens + * num_qo_heads + * rope_dim + * quant_dtype.itemsize # q_rope_out write + + total_tokens + * rope_dim + * quant_dtype.itemsize # k_rope_out write (MLA shape) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if 
args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["no_rope_dim"] = no_rope_dim + cur_res["interleave"] = interleave + cur_res["quant_dtype"] = args.quant_dtype + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testRopeQuantizeFp8(args): + """ + Test rope_quantize_fp8 API. + + This test: + 1. Generates random pre-split Q and K tensors (rotary and non-rotary parts) + 2. Creates precomputed cos_sin_cache + 3. Runs flashinfer.rope.rope_quantize_fp8 + 4. Measures performance metrics (TB/sec) + + Note: This API takes pre-split q_rope, k_rope, q_nope, k_nope tensors + and a precomputed cos_sin_cache. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testRopeQuantizeFp8") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + no_rope_dim = args.no_rope_dim + interleave = args.interleave + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + quant_dtype = dtype_str_to_torch_dtype(args.quant_dtype) + + ## Prepare input tensors (pre-split for this API) + total_tokens = batch_size * seq_len + max_seq_len = seq_len + + # q_rope: (total_tokens, num_qo_heads, rotary_dim) + q_rope = torch.randn( + total_tokens, num_qo_heads, rotary_dim, dtype=input_dtype, device=device + ) + # k_rope: (total_tokens, num_kv_heads, rotary_dim) + k_rope = torch.randn( + total_tokens, num_kv_heads, rotary_dim, dtype=input_dtype, device=device + ) + # q_nope: (total_tokens, num_qo_heads, no_rope_dim) or None + q_nope = ( + torch.randn( + total_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + if no_rope_dim > 0 + else None + ) + # k_nope: (total_tokens, num_kv_heads, no_rope_dim) or None + k_nope = ( + torch.randn( + total_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + if no_rope_dim > 0 + else None + ) + + # Precomputed cos_sin_cache: (max_seq_len, rotary_dim) in float32 + cos_sin_cache = torch.randn( + max_seq_len, rotary_dim, dtype=torch.float32, device=device + ) + + # pos_ids: (total_tokens,) + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device).repeat(batch_size) + + # is_neox is the inverse of interleave + is_neox = not interleave + + if args.verbose >= 2: + print(f"[VVERBOSE] {q_rope.shape = }") + print(f"[VVERBOSE] {k_rope.shape = }") + print( + f"[VVERBOSE] q_nope.shape = {q_nope.shape if q_nope is not None else None}" + ) + print( + f"[VVERBOSE] k_nope.shape = {k_nope.shape if k_nope is not None else None}" + ) + print(f"[VVERBOSE] {cos_sin_cache.shape = }") + print(f"[VVERBOSE] {pos_ids.shape = }") + print(f"[VVERBOSE] {rotary_dim = }") + print(f"[VVERBOSE] {no_rope_dim = }") + print(f"[VVERBOSE] {is_neox = }") + + def run_backend(backend, q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids): + if backend == "cuda": + return flashinfer.rope.rope_quantize_fp8( + 
q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + is_neox=is_neox, + quantize_dtype=quant_dtype, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=( + cur_backend, + q_rope, + k_rope, + q_nope, + k_nope, + cos_sin_cache, + pos_ids, + ), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: q_rope + k_rope + q_nope + k_nope + cos_sin_cache + pos_ids + # Write: q_rope_out + k_rope_out + q_nope_out + k_nope_out + nope_bytes = 0 + if no_rope_dim > 0: + nope_bytes = ( + total_tokens + * num_qo_heads + * no_rope_dim + * input_dtype.itemsize # q_nope read + + total_tokens + * num_kv_heads + * no_rope_dim + * input_dtype.itemsize # k_nope read + + total_tokens + * num_qo_heads + * no_rope_dim + * quant_dtype.itemsize # q_nope_out write + + total_tokens + * num_kv_heads + * no_rope_dim + * quant_dtype.itemsize # k_nope_out write + ) + problem_bytes = ( + total_tokens + * num_qo_heads + * rotary_dim + * input_dtype.itemsize # q_rope read + + total_tokens + * num_kv_heads + * rotary_dim + * input_dtype.itemsize # k_rope read + + nope_bytes + + max_seq_len * rotary_dim * 4 # cos_sin_cache read (float32) + + total_tokens * 4 # pos_ids read (int32) + + total_tokens + * num_qo_heads + * rotary_dim + * quant_dtype.itemsize # q_rope_out write + + total_tokens + * num_kv_heads + * rotary_dim + * quant_dtype.itemsize # k_rope_out write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + 
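The unit conversion in the `tb_per_sec` expression above is easy to misread. Pulled out on its own (assuming, as elsewhere in this framework, that `bench_gpu_time` reports times in milliseconds):

```python
def bytes_to_tb_per_sec(problem_bytes, median_time_ms):
    # bytes / (1e9 * ms) == bytes / (1e12 * s) == TB/s,
    # because the ms->s factor of 1e-3 turns 1e9 into 1e12
    return problem_bytes / (10**9 * median_time_ms)

# moving 10**12 bytes (1 TB) in 1000 ms (1 s) is 1.0 TB/s
rate = bytes_to_tb_per_sec(10**12, 1000.0)
```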
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["rotary_dim"] = rotary_dim + cur_res["no_rope_dim"] = no_rope_dim + cur_res["interleave"] = interleave + cur_res["quant_dtype"] = args.quant_dtype + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testRopeQuantizeFp8AppendPagedKvCache(args): + """ + Test rope_quantize_fp8_append_paged_kv_cache API. + + This test: + 1. Generates random pre-split Q, K, V tensors with precomputed cos_sin_cache + 2. Creates paged KV cache in FP8 + 3. Runs flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache + 4. Measures performance metrics (TB/sec) + + Note: This API takes pre-split tensors (q_rope, k_rope, q_nope, k_nope, v) + and a precomputed cos_sin_cache. The paged KV cache is a tuple of + (k_cache, v_cache) both in FP8. 
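Independent of FlashInfer, the page-table bookkeeping this test constructs below (`kv_indptr`, `kv_indices`, `batch_indices`, `positions`) can be sketched in plain Python; `paged_kv_metadata` is a hypothetical helper mirroring the ceil-division page count used in the test:

```python
def paged_kv_metadata(batch_size, seq_len, page_size):
    # ceil division: pages needed to hold seq_len tokens per request
    pages_per_req = (seq_len + page_size - 1) // page_size
    total_pages = batch_size * pages_per_req
    # pages assigned contiguously; kv_indptr[i]..kv_indptr[i+1] delimit request i
    kv_indices = list(range(total_pages))
    kv_indptr = [i * pages_per_req for i in range(batch_size + 1)]
    # per-token owning request, and per-token position within its request
    batch_indices = [b for b in range(batch_size) for _ in range(seq_len)]
    positions = [p for _ in range(batch_size) for p in range(seq_len)]
    return kv_indptr, kv_indices, batch_indices, positions

# 5 tokens with page_size 4 need 2 pages per request
indptr, indices, bidx, pos = paged_kv_metadata(batch_size=2, seq_len=5, page_size=4)
```

This mirrors the tensor construction in the test body: `kv_indptr` comes from a strided `arange`, `batch_indices` from `repeat_interleave`, and `positions` from `repeat`.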
+ + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testRopeQuantizeFp8AppendPagedKvCache") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + seq_len = args.seq_len + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + rotary_dim = args.rotary_dim + no_rope_dim = head_dim - rotary_dim # For GQA/MHA + interleave = args.interleave + page_size = args.page_size + kv_layout = args.kv_layout + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + quant_dtype = dtype_str_to_torch_dtype(args.quant_dtype) + + ## Prepare input tensors (pre-split for this API) + total_tokens = batch_size * seq_len + max_seq_len = seq_len + + # q_rope: (total_tokens, num_qo_heads, rotary_dim) + q_rope = torch.randn( + total_tokens, num_qo_heads, rotary_dim, dtype=input_dtype, device=device + ) + # k_rope: (total_tokens, num_kv_heads, rotary_dim) + k_rope = torch.randn( + total_tokens, num_kv_heads, rotary_dim, dtype=input_dtype, device=device + ) + # q_nope: (total_tokens, num_qo_heads, no_rope_dim) or None + q_nope = ( + torch.randn( + total_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + if no_rope_dim > 0 + else None + ) + # k_nope: (total_tokens, num_kv_heads, no_rope_dim) or None + k_nope = ( + torch.randn( + total_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + if no_rope_dim > 0 + else None + ) + # v: (total_tokens, num_kv_heads, head_dim) + v = torch.randn( + total_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Precomputed cos_sin_cache: (max_seq_len, rotary_dim) in float32 + cos_sin_cache = torch.randn( + max_seq_len, rotary_dim, dtype=torch.float32, device=device + ) + + # pos_ids: (total_tokens,) + pos_ids = torch.arange(seq_len, dtype=torch.int32, device=device).repeat(batch_size) + + # Paged KV cache - separate k and v caches as a tuple + # Note: FP8 tensors cannot be created with randn, use empty instead + num_pages_per_request = (seq_len + page_size - 1) // page_size + total_pages = batch_size * num_pages_per_request + if kv_layout == "NHD": + k_cache = torch.empty( + total_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.empty( + total_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.empty( + total_pages, + num_kv_heads, 
+ page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.empty( + total_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + paged_kv_cache = (k_cache, v_cache) + + # KV indices: page indices for each request + kv_indices = torch.arange(total_pages, dtype=torch.int32, device=device) + + # KV indptr: (batch_size + 1,) + kv_indptr = torch.arange( + 0, + (batch_size + 1) * num_pages_per_request, + num_pages_per_request, + dtype=torch.int32, + device=device, + ) + + # Batch indices: which request each token belongs to + batch_indices = torch.arange( + batch_size, dtype=torch.int32, device=device + ).repeat_interleave(seq_len) + + # Positions: position within each request's sequence for each token + positions = torch.arange(seq_len, dtype=torch.int32, device=device).repeat( + batch_size + ) + + # is_neox is the inverse of interleave + is_neox = not interleave + + if args.verbose >= 2: + print(f"[VVERBOSE] {q_rope.shape = }") + print(f"[VVERBOSE] {k_rope.shape = }") + print( + f"[VVERBOSE] q_nope.shape = {q_nope.shape if q_nope is not None else None}" + ) + print( + f"[VVERBOSE] k_nope.shape = {k_nope.shape if k_nope is not None else None}" + ) + print(f"[VVERBOSE] {v.shape = }") + print(f"[VVERBOSE] {cos_sin_cache.shape = }") + print(f"[VVERBOSE] {pos_ids.shape = }") + print(f"[VVERBOSE] k_cache.shape = {k_cache.shape}") + print(f"[VVERBOSE] v_cache.shape = {v_cache.shape}") + print(f"[VVERBOSE] {kv_indices.shape = }") + print(f"[VVERBOSE] {kv_indptr.shape = }") + print(f"[VVERBOSE] {batch_indices.shape = }") + print(f"[VVERBOSE] {positions.shape = }") + print(f"[VVERBOSE] {rotary_dim = }") + print(f"[VVERBOSE] {no_rope_dim = }") + print(f"[VVERBOSE] {is_neox = }") + print(f"[VVERBOSE] {page_size = }") + print(f"[VVERBOSE] {kv_layout = }") + + def run_backend( + backend, + q_rope, + k_rope, + q_nope, + k_nope, + v, + cos_sin_cache, + pos_ids, + paged_kv_cache, + kv_indices, + kv_indptr, + batch_indices, 
+ positions, + ): + if backend == "cuda": + return flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + cos_sin_cache, + pos_ids, + paged_kv_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + is_neox=is_neox, + quantize_dtype=quant_dtype, + page_size=page_size, + kv_layout=kv_layout, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=( + cur_backend, + q_rope, + k_rope, + q_nope, + k_nope, + v, + cos_sin_cache, + pos_ids, + paged_kv_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + ), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: q_rope + k_rope + q_nope + k_nope + v + cos_sin_cache + pos_ids + # Write: q_rope_out + q_nope_out + paged_kv_cache (k and v) + nope_bytes = 0 + if no_rope_dim > 0: + nope_bytes = ( + total_tokens + * num_qo_heads + * no_rope_dim + * input_dtype.itemsize # q_nope read + + total_tokens + * num_kv_heads + * no_rope_dim + * input_dtype.itemsize # k_nope read + + total_tokens + * num_qo_heads + * no_rope_dim + * quant_dtype.itemsize # q_nope_out write + ) + problem_bytes = ( + total_tokens + * num_qo_heads + * rotary_dim + * input_dtype.itemsize # q_rope read + + total_tokens + * num_kv_heads + * rotary_dim + * input_dtype.itemsize # k_rope read + + total_tokens + * num_kv_heads + * head_dim + * input_dtype.itemsize # v read + + nope_bytes + + max_seq_len * rotary_dim * 4 # cos_sin_cache read (float32) + + total_tokens * 4 # pos_ids read (int32) + + 
total_tokens + * num_qo_heads + * rotary_dim + * quant_dtype.itemsize # q_rope_out write + + total_tokens + * num_kv_heads + * head_dim + * quant_dtype.itemsize + * 2 # k, v to paged cache + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["seq_len"] = seq_len + cur_res["num_qo_heads"] = num_qo_heads + cur_res["num_kv_heads"] = num_kv_heads + cur_res["head_dim"] = head_dim + cur_res["rotary_dim"] = rotary_dim + cur_res["no_rope_dim"] = no_rope_dim + cur_res["interleave"] = interleave + cur_res["quant_dtype"] = args.quant_dtype + cur_res["page_size"] = page_size + cur_res["kv_layout"] = kv_layout + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res diff --git a/benchmarks/routines/sampling.py b/benchmarks/routines/sampling.py new file mode 100644 index 0000000000..72b4178969 --- /dev/null +++ b/benchmarks/routines/sampling.py @@ -0,0 +1,2025 @@ +""" +Copyright (c) 2026 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from collections import defaultdict + +import numpy as np +import torch + +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + +from .flashinfer_benchmark_utils import ( + dtype_str_to_torch_dtype, + get_device, + is_close_stats, + print_perf_metrics, + filter_backends_by_compute_capability, +) + + +def run_sampling_test(args): + """ + Run a sampling test. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.routine == "softmax": + return testSoftmax(args) + elif args.routine == "sampling_from_probs": + return testSamplingFromProbs(args) + elif args.routine == "sampling_from_logits": + return testSamplingFromLogits(args) + elif args.routine == "top_k_sampling_from_probs": + return testTopKSamplingFromProbs(args) + elif args.routine == "top_p_sampling_from_probs": + return testTopPSamplingFromProbs(args) + elif args.routine == "top_k_top_p_sampling_from_probs": + return testTopKTopPSamplingFromProbs(args) + elif args.routine == "top_k_top_p_sampling_from_logits": + return testTopKTopPSamplingFromLogits(args) + elif args.routine == "min_p_sampling_from_probs": + return testMinPSamplingFromProbs(args) + elif args.routine == "top_k_renorm_probs": + return testTopKRenormProbs(args) + elif args.routine == "top_p_renorm_probs": + return testTopPRenormProbs(args) + elif args.routine == "top_k_mask_logits": + return testTopKMaskLogits(args) + elif args.routine == "chain_speculative_sampling": + return testChainSpeculativeSampling(args) + elif args.routine == "top_k": + return testTopK(args) + elif args.routine == "top_k_page_table_transform": + return testTopKPageTableTransform(args) + elif args.routine == "top_k_ragged_transform": + return testTopKRaggedTransform(args) + else: + raise ValueError(f"Unsupported routine: {args.routine}") + + +def parse_sampling_args(line, parser): + """ + Parse command line arguments for sampling test 
configuration. + + Args: + line: Command line arguments + parser: ArgumentParser object already populated with shared arguments + + Returns: + Parsed argument namespace + """ + # Routines that don't use vocab_size (they use max_len instead) + routines_without_vocab_size = [ + "top_k_page_table_transform", + "top_k_ragged_transform", + ] + + # Pre-parse to check routine for conditional requirements + pre_parser = parser + pre_args, _ = pre_parser.parse_known_args(line[:]) + + parser.add_argument( + "--batch_size", + type=int, + required=True, + help="Batch size.", + ) + parser.add_argument( + "--vocab_size", + type=int, + required=(pre_args.routine not in routines_without_vocab_size), + default=None, + help="Vocabulary size.", + ) + parser.add_argument( + "--input_dtype", + type=str, + required=False, + default="float32", + choices=["float32", "float16", "bfloat16"], + help="Data type of the input tensor.", + ) + parser.add_argument( + "--top_k", + type=int, + required=False, + default=50, + help="Top-K value for top-k sampling.", + ) + parser.add_argument( + "--top_p", + type=float, + required=False, + default=0.9, + help="Top-P threshold for top-p sampling.", + ) + parser.add_argument( + "--min_p", + type=float, + required=False, + default=0.1, + help="Min-P threshold for min-p sampling.", + ) + parser.add_argument( + "--temperature", + type=float, + required=False, + default=1.0, + help="Temperature for softmax.", + ) + parser.add_argument( + "--filter_apply_order", + type=str, + required=False, + default="top_k_first", + choices=["top_k_first", "joint"], + help="Order of applying top-k and top-p filters.", + ) + parser.add_argument( + "--num_speculate_tokens", + type=int, + required=False, + default=5, + help="Number of speculative tokens for chain speculative sampling.", + ) + parser.add_argument( + "--max_len", + type=int, + required=False, + default=4096, + help="Max sequence length for top_k_page_table_transform and top_k_ragged_transform.", + ) + 
parser.add_argument( + "--num_rows", + type=int, + required=False, + default=None, + help="Number of rows for top_k_page_table_transform and top_k_ragged_transform. Defaults to batch_size.", + ) + parser.add_argument( + "--backends", + type=str, + required=False, + nargs="+", + default=["cuda"], + choices=["cuda"], + help="Kernel backends to test. Default: cuda", + ) + + args = parser.parse_args(line) + + # Default num_rows to batch_size if not specified + if args.num_rows is None: + args.num_rows = args.batch_size + + if args.verbose >= 1: + print(f"[INFO] {args = }") + return args + + +def testSoftmax(args): + """ + Test softmax API. + + This test: + 1. Generates random input logits + 2. Runs flashinfer.sampling.softmax + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testSoftmax") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + temperature = args.temperature + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + logits = torch.randn(batch_size, vocab_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {logits.shape = }") + print(f"[VVERBOSE] {logits.dtype = }") + + def run_backend(backend, logits): + if backend == "cuda": + return flashinfer.sampling.softmax(logits, temperature=temperature) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, logits).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, logits), + ) + + # Reference check: compare against PyTorch softmax + if run_refcheck and outputs: + reference_output = torch.softmax(logits.float() / temperature, dim=-1) + for backend, output in outputs.items(): + num_diff, num_total, pct_diff = is_close_stats( + reference_output, output.float(), rtol=1e-3, atol=1e-5 + ) + if num_diff > 0: + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.2f}%) elements differ from PyTorch reference" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} output mismatch with {num_diff} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for softmax + # Read: logits (input) + # Write: probs (float32 output) + problem_bytes = ( + batch_size * vocab_size * input_dtype.itemsize # input read + + batch_size * vocab_size * 4 # output write (float32) + 
) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["temperature"] = temperature + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testSamplingFromProbs(args): + """ + Test sampling_from_probs API. + + This test: + 1. Generates random input probabilities + 2. Runs flashinfer.sampling.sampling_from_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors (probs are always float32) + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.sampling_from_probs( + probs, seed=seed, offset=offset + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (float32) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * 4 # probs read (float32) + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testSamplingFromLogits(args): + """ + Test sampling_from_logits API. + + This test: + 1. Generates random input logits + 2. 
Runs flashinfer.sampling.sampling_from_logits + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testSamplingFromLogits") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + logits = torch.randn(batch_size, vocab_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {logits.shape = }") + print(f"[VVERBOSE] {logits.dtype = }") + + def run_backend(backend, logits): + if backend == "cuda": + return flashinfer.sampling.sampling_from_logits( + logits, seed=seed, offset=offset + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, logits), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = 
np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: logits (input_dtype) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * input_dtype.itemsize # logits read + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKSamplingFromProbs(args): + """ + Test top_k_sampling_from_probs API. + + This test: + 1. Generates random input probabilities + 2. Runs flashinfer.sampling.top_k_sampling_from_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_sampling_from_probs( + probs, top_k, seed=seed, offset=offset + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (float32) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * 4 # probs read (float32) + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopPSamplingFromProbs(args): + """ + Test top_p_sampling_from_probs API. + + This test: + 1. 
Generates random input probabilities + 2. Runs flashinfer.sampling.top_p_sampling_from_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopPSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_p = args.top_p + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_p = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_p_sampling_from_probs( + probs, top_p, seed=seed, offset=offset + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (float32) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * 4 # probs read (float32) + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_p"] = top_p + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKTopPSamplingFromProbs(args): + """ + Test top_k_top_p_sampling_from_probs API. + + This test: + 1. 
Generates random input probabilities + 2. Runs flashinfer.sampling.top_k_top_p_sampling_from_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKTopPSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + top_p = args.top_p + filter_apply_order = args.filter_apply_order + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") + print(f"[VVERBOSE] {top_p = }") + print(f"[VVERBOSE] {filter_apply_order = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_top_p_sampling_from_probs( + probs, + top_k, + top_p, + filter_apply_order=filter_apply_order, + seed=seed, + offset=offset, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (float32) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * 4 # probs read (float32) + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["top_p"] = top_p + cur_res["filter_apply_order"] = filter_apply_order + 
cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKTopPSamplingFromLogits(args): + """ + Test top_k_top_p_sampling_from_logits API. + + This test: + 1. Generates random input logits + 2. Runs flashinfer.sampling.top_k_top_p_sampling_from_logits + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKTopPSamplingFromLogits") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + top_p = args.top_p + filter_apply_order = args.filter_apply_order + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + logits = torch.randn(batch_size, vocab_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {logits.shape = }") + print(f"[VVERBOSE] {logits.dtype = }") + print(f"[VVERBOSE] {top_k = }") + print(f"[VVERBOSE] {top_p = }") + print(f"[VVERBOSE] {filter_apply_order = }") + + def run_backend(backend, logits): + if backend == "cuda": + return flashinfer.sampling.top_k_top_p_sampling_from_logits( + logits, + top_k, + top_p, + filter_apply_order=filter_apply_order, + seed=seed, + offset=offset, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, logits), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: logits (input_dtype) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * input_dtype.itemsize # logits read + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["top_p"] = top_p + cur_res["filter_apply_order"] 
= filter_apply_order + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testMinPSamplingFromProbs(args): + """ + Test min_p_sampling_from_probs API. + + This test: + 1. Generates random input probabilities + 2. Runs flashinfer.sampling.min_p_sampling_from_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testMinPSamplingFromProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + min_p = args.min_p + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {min_p = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.min_p_sampling_from_probs( + probs, min_p, seed=seed, offset=offset + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (float32) + # Write: samples (int32) + problem_bytes = ( + batch_size * vocab_size * 4 # probs read (float32) + + batch_size * 4 # samples write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["min_p"] = min_p + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKRenormProbs(args): + """ + Test top_k_renorm_probs API. + + This test: + 1. 
Generates random input probabilities + 2. Runs flashinfer.sampling.top_k_renorm_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKRenormProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + pre_norm_prob = torch.rand(batch_size, vocab_size, dtype=input_dtype, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_k_renorm_probs(probs, top_k) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, probs).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + 
use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + # Reference check: PyTorch implementation of top-k renormalization + # Keep top-k values, set rest to 0, then renormalize + # Note: Small mismatches can occur due to tie-breaking at the k-th boundary + if run_refcheck and outputs: + topk_vals, topk_indices = torch.topk(probs.float(), k=top_k, dim=-1) + reference_output = torch.zeros_like(probs, dtype=torch.float32) + reference_output.scatter_(-1, topk_indices, topk_vals) + reference_output = reference_output / reference_output.sum(dim=-1, keepdim=True) + reference_output = reference_output.to(input_dtype) + + for backend, output in outputs.items(): + num_diff, num_total, pct_diff = is_close_stats( + reference_output.float(), output.float(), rtol=1e-2, atol=1e-2 + ) + # Allow tiny mismatch percentage (<0.01%) due to tie-breaking at k-th boundary + if pct_diff > 0.01: + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.2f}%) elements differ from PyTorch reference" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} output mismatch with {num_diff} elements" + ) + elif num_diff > 0 and args.verbose >= 1: + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.4f}%) elements differ (within acceptable threshold)" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (input_dtype) + # Write: renorm_probs (input_dtype) + problem_bytes = ( + batch_size * vocab_size * input_dtype.itemsize # probs read + + batch_size * vocab_size * input_dtype.itemsize # renorm_probs write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not 
None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopPRenormProbs(args): + """ + Test top_p_renorm_probs API. + + This test: + 1. Generates random input probabilities + 2. Runs flashinfer.sampling.top_p_renorm_probs + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopPRenormProbs") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_p = args.top_p + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors (top_p_renorm_probs uses float32) + pre_norm_prob = torch.rand(batch_size, vocab_size, device=device) + probs = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {probs.shape = }") + print(f"[VVERBOSE] {probs.dtype = }") + print(f"[VVERBOSE] {top_p = }") + + def run_backend(backend, probs): + if backend == "cuda": + return flashinfer.sampling.top_p_renorm_probs(probs, top_p) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, probs).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, probs), + ) + + # Reference check: PyTorch implementation of top-p renormalization + # Sort probs descending, compute cumsum, threshold at top_p, renormalize + if run_refcheck and outputs: + sorted_probs, sorted_indices = torch.sort( + probs.float(), dim=-1, descending=True + ) + cumsum_probs = torch.cumsum(sorted_probs, dim=-1) + # Create mask: keep probs where cumsum <= top_p (shift by 1 to include the boundary element) + cumsum_shifted = torch.cat( + [torch.zeros_like(cumsum_probs[:, :1]), cumsum_probs[:, :-1]], dim=-1 + ) + mask = cumsum_shifted < top_p + # Keep at least one element per row + mask[:, 0] = True + # Zero out elements beyond top-p threshold + sorted_probs_masked = sorted_probs * mask.float() + # Scatter back to original positions + reference_output = torch.zeros_like(probs) + reference_output.scatter_(-1, sorted_indices, sorted_probs_masked) + # Renormalize + reference_output = reference_output / reference_output.sum(dim=-1, keepdim=True) + + 
for backend, output in outputs.items(): + num_diff, num_total, pct_diff = is_close_stats( + reference_output, output.float(), rtol=1e-2, atol=1e-3 + ) + if num_diff > 0: + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.2f}%) elements differ from PyTorch reference" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} output mismatch with {num_diff} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: probs (float32) + # Write: renorm_probs (float32) + problem_bytes = ( + batch_size * vocab_size * 4 # probs read (float32) + + batch_size * vocab_size * 4 # renorm_probs write (float32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_p"] = top_p + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKMaskLogits(args): + """ + Test top_k_mask_logits API. + + This test: + 1. Generates random input logits + 2. Runs flashinfer.sampling.top_k_mask_logits + 3. 
Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKMaskLogits") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + logits = torch.randn(batch_size, vocab_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {logits.shape = }") + print(f"[VVERBOSE] {logits.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, logits): + if backend == "cuda": + return flashinfer.sampling.top_k_mask_logits(logits, top_k) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, logits).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, logits), + ) + + # Reference check: PyTorch implementation of top-k masking + # Keep top-k logits, set rest to -inf + 
if run_refcheck and outputs: + topk_vals, topk_indices = torch.topk(logits.float(), k=top_k, dim=-1) + reference_output = torch.full_like(logits, float("-inf"), dtype=torch.float32) + reference_output.scatter_(-1, topk_indices, topk_vals) + reference_output = reference_output.to(input_dtype) + + for backend, output in outputs.items(): + out = output.float() + ref = reference_output.float() + + # Check that the same positions are masked (-inf) + out_masked = torch.isinf(out) & (out < 0) + ref_masked = torch.isinf(ref) & (ref < 0) + mask_match = (out_masked == ref_masked).all() + + if not mask_match: + num_mask_diff = (out_masked != ref_masked).sum().item() + print( + f"[REFCHECK] Backend {backend}: Mask mismatch - " + f"{num_mask_diff} positions have different masking" + ) + if not args.allow_output_mismatch: + raise AssertionError(f"[ERROR] Backend {backend} mask mismatch") + else: + # Check that unmasked values match the reference + unmasked_positions = ~out_masked + if unmasked_positions.any(): + num_diff, num_total, pct_diff = is_close_stats( + ref[unmasked_positions], + out[unmasked_positions], + rtol=1e-3, + atol=1e-5, + ) + if num_diff > 0: + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.2f}%) unmasked elements differ" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} value mismatch" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: logits (input_dtype) + # Write: masked_logits (input_dtype) + problem_bytes = ( + batch_size * vocab_size * input_dtype.itemsize # logits read + + batch_size * vocab_size * input_dtype.itemsize # masked_logits write + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) 
+ + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testChainSpeculativeSampling(args): + """ + Test chain_speculative_sampling API. + + This test: + 1. Generates random draft and target probabilities + 2. Runs flashinfer.sampling.chain_speculative_sampling + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testChainSpeculativeSampling") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + num_speculate_tokens = args.num_speculate_tokens + # Use explicit seed/offset to enable CUDA graph compatibility + seed = args.random_seed + offset = 0 + is_cuda_graph_compatible = not args.no_cuda_graph + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + ## Prepare input tensors + # Draft probs: (batch_size, num_speculate_tokens, vocab_size) + pre_norm_draft_prob = torch.rand( + batch_size, num_speculate_tokens, vocab_size, device=device + ) + draft_probs = pre_norm_draft_prob / pre_norm_draft_prob.sum(dim=-1, keepdim=True) + + # Draft token IDs: (batch_size, num_speculate_tokens) + draft_token_ids = torch.randint( + vocab_size, (batch_size, num_speculate_tokens), device=device, dtype=torch.int32 + ) + + # Target probs: (batch_size, num_speculate_tokens + 1, vocab_size) + pre_norm_target_prob = torch.rand( + batch_size, num_speculate_tokens + 1, vocab_size, device=device + ) + target_probs = pre_norm_target_prob / pre_norm_target_prob.sum(dim=-1, keepdim=True) + + if args.verbose >= 2: + print(f"[VVERBOSE] {draft_probs.shape = }") + print(f"[VVERBOSE] {draft_token_ids.shape = }") + print(f"[VVERBOSE] {target_probs.shape = }") + print(f"[VVERBOSE] {num_speculate_tokens = }") + + def run_backend(backend, draft_probs, draft_token_ids, target_probs): + if backend == "cuda": + return flashinfer.sampling.chain_speculative_sampling( + draft_probs, draft_token_ids, target_probs, seed=seed, offset=offset + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results + backend_times = {backend: [] for backend in backends} + for cur_backend in backends: + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, draft_probs, draft_token_ids, target_probs), + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + n = num_speculate_tokens + # Memory bandwidth calculation + # Read: draft_probs + draft_token_ids + target_probs + # Write: output_token_ids + accepted_num + emitted_num + 
problem_bytes = ( + batch_size * n * vocab_size * 4 # draft_probs read (float32) + + batch_size * n * 4 # draft_token_ids read (int32) + + batch_size * (n + 1) * vocab_size * 4 # target_probs read (float32) + + batch_size * (n + 1) * 4 # output_token_ids write (int32) + + batch_size * 4 # accepted_num write (int32) + + batch_size * 4 # emitted_num write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["num_speculate_tokens"] = num_speculate_tokens + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopK(args): + """ + Test top_k API (radix-based top-k selection). + + This test: + 1. Generates random input tensor + 2. Runs flashinfer.top_k + 3. 
Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopK") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + vocab_size = args.vocab_size + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + input_tensor = torch.randn(batch_size, vocab_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, input_tensor): + if backend == "cuda": + return flashinfer.top_k(input_tensor, top_k) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, input_tensor) + # top_k returns (values, indices) tuple - detach both + outputs[cur_backend] = ( + outputs[cur_backend][0].detach(), + outputs[cur_backend][1].detach(), + ) + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + 
use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor), + ) + + # Reference check: compare against PyTorch torch.topk + # Note: FlashInfer top_k returns UNSORTED results by default, so we compare + # sorted values to verify the same elements are selected + if run_refcheck and outputs: + ref_values, ref_indices = torch.topk(input_tensor.float(), k=top_k, dim=-1) + + for backend, (out_values, out_indices) in outputs.items(): + # Sort both outputs to compare (FlashInfer returns unsorted by default) + ref_sorted, _ = torch.sort(ref_values, dim=-1, descending=True) + out_sorted, _ = torch.sort(out_values.float(), dim=-1, descending=True) + + # Check sorted values match + num_diff_vals, num_total_vals, pct_diff_vals = is_close_stats( + ref_sorted, out_sorted, rtol=1e-3, atol=1e-5 + ) + + # Verify indices point to correct values in original tensor + gathered_vals = torch.gather( + input_tensor.float(), dim=-1, index=out_indices + ) + idx_vals_match = torch.allclose( + gathered_vals, out_values.float(), rtol=1e-3, atol=1e-5 + ) + + if num_diff_vals > 0 or not idx_vals_match: + if num_diff_vals > 0: + print( + f"[REFCHECK] Backend {backend}: {num_diff_vals}/{num_total_vals} " + f"({pct_diff_vals:.2f}%) sorted values differ from PyTorch reference" + ) + if not idx_vals_match: + print( + f"[REFCHECK] Backend {backend}: indices don't match their values" + ) + if not args.allow_output_mismatch: + raise AssertionError(f"[ERROR] Backend {backend} output mismatch") + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: input (input_dtype) + # Write: values (input_dtype) + indices (int64) + problem_bytes = ( + batch_size * vocab_size * input_dtype.itemsize # input read + + batch_size * top_k * input_dtype.itemsize # values write + + batch_size * top_k * 8 # indices write (int64) + ) + tflops = 0 # 
Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["vocab_size"] = vocab_size + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKPageTableTransform(args): + """ + Test top_k_page_table_transform API. + + This test: + 1. Generates random input scores and page table + 2. Runs flashinfer.top_k_page_table_transform + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKPageTableTransform") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + batch_size = args.batch_size + num_rows = args.num_rows + max_len = args.max_len + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + # Input scores: (num_rows, max_len) + input_scores = torch.randn(num_rows, max_len, dtype=input_dtype, device=device) + + # Source page table: (batch_size, max_len) + src_page_table = torch.randint( + 0, 1000, (batch_size, max_len), dtype=torch.int32, device=device + ) + + # Lengths: (num_rows,) + lengths = torch.full((num_rows,), max_len, dtype=torch.int32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_scores.shape = }") + print(f"[VVERBOSE] {input_scores.dtype = }") + print(f"[VVERBOSE] {src_page_table.shape = }") + print(f"[VVERBOSE] {lengths.shape = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, input_scores, src_page_table, lengths): + if backend == "cuda": + return flashinfer.top_k_page_table_transform( + input_scores, src_page_table, lengths, top_k + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, input_scores, src_page_table, lengths + ).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_scores, src_page_table, lengths), + ) + + # Reference check: PyTorch implementation of top-k + page table transform + # For each row i: output[i, j] = src_page_table[i, topk_indices[i, j]] + # Note: FlashInfer uses unsorted top-k internally, so we compare sorted sets per row + if run_refcheck and outputs: + # Get top-k indices + _, topk_indices = torch.topk(input_scores.float(), k=top_k, dim=-1) + # Gather from page table - row index maps to batch index (1:1 when row_to_batch is 
None) + reference_output = torch.gather( + src_page_table[:num_rows], dim=-1, index=topk_indices # gather requires int64 index + ) + + for backend, output in outputs.items(): + # Sort both outputs per row to compare sets (order may differ due to unsorted top-k) + ref_sorted, _ = torch.sort(reference_output, dim=-1) + out_sorted, _ = torch.sort(output, dim=-1) + + matches = (ref_sorted == out_sorted).all() + if not matches: + num_diff = (ref_sorted != out_sorted).sum().item() + num_total = reference_output.numel() + pct_diff = num_diff / num_total * 100.0 + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.2f}%) sorted page table entries differ from PyTorch reference" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} output mismatch with {num_diff} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: input_scores + src_page_table + # Write: output_page_table + problem_bytes = ( + num_rows * max_len * input_dtype.itemsize # input_scores read + + batch_size * max_len * 4 # src_page_table read (int32) + + num_rows * top_k * 4 # output_page_table write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["max_len"] = max_len + cur_res["num_rows"] = num_rows + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testTopKRaggedTransform(args): + """ + Test top_k_ragged_transform API. 
+ + This test: + 1. Generates random input scores + 2. Runs flashinfer.top_k_ragged_transform + 3. Measures performance metrics (TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testTopKRaggedTransform") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] + num_rows = args.num_rows + max_len = args.max_len + top_k = args.top_k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + + ## Prepare input tensors + # Input scores: (num_rows, max_len) + input_scores = torch.randn(num_rows, max_len, dtype=input_dtype, device=device) + + # Offsets: (num_rows,) + offsets = torch.arange( + 0, num_rows * max_len, max_len, dtype=torch.int32, device=device + ) + + # Lengths: (num_rows,) + lengths = torch.full((num_rows,), max_len, dtype=torch.int32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_scores.shape = }") + print(f"[VVERBOSE] {input_scores.dtype = }") + print(f"[VVERBOSE] {offsets.shape = }") + print(f"[VVERBOSE] {lengths.shape = }") + print(f"[VVERBOSE] {top_k = }") + + def run_backend(backend, input_scores, offsets, lengths): + if backend == "cuda": + return flashinfer.top_k_ragged_transform( + input_scores, offsets, lengths, top_k + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs for refcheck + backend_times = {backend: [] 
for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, input_scores, offsets, lengths + ).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_scores, offsets, lengths), + ) + + # Reference check: PyTorch implementation of top-k + ragged index transform + # For each row i: output_indices[i, j] = topk_indices[i, j] + offsets[i] + # Note: FlashInfer uses unsorted top-k internally, so we compare sorted sets per row + if run_refcheck and outputs: + # Get top-k indices + _, topk_indices = torch.topk(input_scores.float(), k=top_k, dim=-1) + # Add offsets to each row's indices + reference_output = topk_indices.int() + offsets.unsqueeze(-1) + + for backend, output in outputs.items(): + # Sort both outputs per row to compare sets (order may differ due to unsorted top-k) + ref_sorted, _ = torch.sort(reference_output, dim=-1) + out_sorted, _ = torch.sort(output, dim=-1) + + matches = (ref_sorted == out_sorted).all() + if not matches: + num_diff = (ref_sorted != out_sorted).sum().item() + num_total = reference_output.numel() + pct_diff = num_diff / num_total * 100.0 + print( + f"[REFCHECK] Backend {backend}: {num_diff}/{num_total} " + f"({pct_diff:.2f}%) sorted ragged indices differ from PyTorch reference" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {backend} output mismatch with {num_diff} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation + # Read: input_scores + # Write: output_indices + problem_bytes = ( + num_rows * max_len * input_dtype.itemsize # input_scores read + + num_rows * top_k * 4 # 
output_indices write (int32) + ) + tflops = 0 # Memory-bound operation + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["max_len"] = max_len + cur_res["num_rows"] = num_rows + cur_res["top_k"] = top_k + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 03a33f33ec..38cef5a4ad 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -1,14 +1,10 @@ ## Attention Prefill -# Paged prefill for Llama 3.1 70B --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "Llama-3.1-70B" ---routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 32 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "Llama-3.1-70B" # Ragged prefill for DeepSeep-R1 ---routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 
--allow_output_mismatch --generate_repro_command --case_tag "DeepSeek-R1" --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "DeepSeek-R1" ## Attention Decode ---routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "Llama-3.1-70B" --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag "Llama-3.1-70B" ## Attention MLA @@ -16,24 +12,20 @@ --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag "DeepSeek-R1" ## FP8 bmm ---routine bmm_fp8 --batch_size 256 --m 1 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command ## FP8 GEMM with 
groupwise scaling ---routine gemm_fp8_nt_groupwise --m 4 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command --routine gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command ## FP8 group GEMM with groupwise scaling ---routine group_gemm_fp8_nt_groupwise --m 4 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command --routine group_gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command ## FP4 GEMM # non-autotuned ---routine mm_fp4 --m 1 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command ---routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command +--routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command # autotuned ---routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command +--routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command ## MoE --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag "trtllm_moe_sample" @@ -175,3 +167,83 @@ --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype bfloat16 
--global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_basic" --routine nvfp4_batched_quantize --batch_size 8 --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_large" --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype float16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_fp16" + +## Sampling +# Basic softmax with temperature +--routine softmax --batch_size 32 --vocab_size 32000 --temperature 1.0 --input_dtype float32 -vv --generate_repro_command --case_tag "softmax_llama" +--routine softmax --batch_size 64 --vocab_size 128256 --temperature 0.8 --input_dtype float32 -vv --generate_repro_command --case_tag "softmax_llama3_temp" + +# Sampling from probabilities +--routine sampling_from_probs --batch_size 32 --vocab_size 32000 -vv --generate_repro_command --case_tag "sampling_from_probs_llama" +--routine sampling_from_probs --batch_size 64 --vocab_size 128256 -vv --generate_repro_command --case_tag "sampling_from_probs_llama3" + +# Sampling from logits (fused softmax + sampling) +--routine sampling_from_logits --batch_size 32 --vocab_size 32000 --input_dtype float32 -vv --generate_repro_command --case_tag "sampling_from_logits_llama" +--routine sampling_from_logits --batch_size 64 --vocab_size 128256 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "sampling_from_logits_llama3" + +# Top-K sampling +--routine top_k_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 -vv --generate_repro_command --case_tag "top_k_sampling_k50" +--routine top_k_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 100 -vv --generate_repro_command --case_tag "top_k_sampling_k100" + +# Top-P (nucleus) sampling +--routine top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag "top_p_sampling_p09" +--routine top_p_sampling_from_probs --batch_size 32 
--vocab_size 128256 --top_p 0.95 -vv --generate_repro_command --case_tag "top_p_sampling_p095" + +# Combined Top-K + Top-P sampling +--routine top_k_top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first -vv --generate_repro_command --case_tag "top_k_top_p_probs" +--routine top_k_top_p_sampling_from_logits --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_top_p_logits" + +# Min-P sampling +--routine min_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --min_p 0.1 -vv --generate_repro_command --case_tag "min_p_sampling_p01" +--routine min_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --min_p 0.05 -vv --generate_repro_command --case_tag "min_p_sampling_p005" + +# Top-K renormalization +--routine top_k_renorm_probs --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_renorm" + +# Top-P renormalization +--routine top_p_renorm_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag "top_p_renorm" + +# Top-K mask logits +--routine top_k_mask_logits --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_mask" + +# Chain speculative sampling (for speculative decoding) +--routine chain_speculative_sampling --batch_size 16 --vocab_size 32000 --num_speculate_tokens 5 -vv --generate_repro_command --case_tag "chain_spec_sampling_5" +--routine chain_speculative_sampling --batch_size 32 --vocab_size 128256 --num_speculate_tokens 8 -vv --generate_repro_command --case_tag "chain_spec_sampling_8" + +# Top-K selection (radix-based) +--routine top_k --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_radix" +--routine top_k --batch_size 64 --vocab_size 128256 --top_k 100 
--input_dtype bfloat16 -vv --generate_repro_command --case_tag "top_k_radix_large" + +# Top-K with page table transform +--routine top_k_page_table_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_page_table" + +# Top-K with ragged transform +--routine top_k_ragged_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag "top_k_ragged" + +## RoPE (Rotary Positional Embeddings) +# Basic RoPE with indptr/offsets +--routine apply_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_llama" +--routine apply_rope --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "apply_rope_llama70b" + +# RoPE with position IDs +--routine apply_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_pos_ids" +--routine apply_rope_pos_ids --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag "apply_rope_pos_ids_interleave" + +# Llama 3.1 style RoPE +--routine apply_llama31_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "apply_llama31_rope" +--routine apply_llama31_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command 
--case_tag "apply_llama31_rope_pos_ids" + +# RoPE with cos/sin cache +--routine apply_rope_with_cos_sin_cache --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_cos_sin_cache" +--routine apply_rope_with_cos_sin_cache --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag "apply_rope_cos_sin_cache_interleave" + +# MLA RoPE with FP8 quantization (SM8.9+ required) +--routine mla_rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim 192 --no_rope_dim 64 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "mla_rope_fp8_deepseek" + +# RoPE with FP8 quantization (SM8.9+ required) +--routine rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_llama" +--routine rope_quantize_fp8 --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_llama70b" + +# RoPE with FP8 quantization and paged KV cache append (SM8.9+ required) +--routine rope_quantize_fp8_append_paged_kv_cache --batch_size 16 --seq_len 64 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout NHD --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_paged_kv" +--routine rope_quantize_fp8_append_paged_kv_cache --batch_size 32 --seq_len 64 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout HND --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag "rope_fp8_paged_kv_hnd" diff --git a/benchmarks/samples/sample_testlist_output.csv 
b/benchmarks/samples/sample_testlist_output.csv index b07c523ecb..92d6985144 100644 --- a/benchmarks/samples/sample_testlist_output.csv +++ b/benchmarks/samples/sample_testlist_output.csv @@ -1,53 +1,112 @@ -routine,median_time,std_time,tflops,tb_per_sec,backend,page_size,batch_size,s_qo,s_kv,num_qo_heads,num_kv_heads,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,m,n,k,group_size,tile_size,scale_major_mode,out_dtype,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,hidden_size,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,input_dtype,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command -BatchPrefillWithPagedKVCacheWrapper,0.01244799979031086,0.0009464459008260536,13.963516944729905,0.3050282827732261,fa2,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchPrefillWithPagedKVCacheWrapper,0.01839040070772171,0.00021363710731210026,9.45155349045863,0.20646597430613514,cudnn,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper 
--backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchPrefillWithPagedKVCacheWrapper,0.008396799862384795,5.550615129103214e-05,20.70048814413847,0.45219512936224815,trtllm-gen,16,1,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchPrefillWithPagedKVCacheWrapper,0.4833280146121979,0.003473954933671819,250.42114497152383,0.9745931908746264,fa2,16,32,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,399,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 32 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchPrefillWithPagedKVCacheWrapper,0.3817088007926941,0.0008139816712432105,317.08871937101173,1.2340511694301386,cudnn,16,32,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,399,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine 
BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 32 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchPrefillWithPagedKVCacheWrapper,0.7442896127700807,0.00045553586925676576,162.6188955741738,0.6328829314799427,trtllm-gen,16,32,1024,1024,64,8,128,128,,,True,torch.bfloat16,torch.bfloat16,399,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 32 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchPrefillWithRaggedKVCacheWrapper,0.016127999871969223,0.00017067009388203107,26.943492277380717,1.0463492146555617,fa2,0,1,1024,1024,128,128,192,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -BatchPrefillWithRaggedKVCacheWrapper,0.012083200365304947,9.971927609146905e-05,35.962710777165306,1.3966101272685556,cutlass,0,1,1024,1024,128,128,192,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,DeepSeek-R1,True,python3 
flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -BatchPrefillWithRaggedKVCacheWrapper,0.018636800348758698,0.00019618687934467522,23.316483080151837,0.9054944885495858,cudnn,0,1,1024,1024,128,128,192,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -BatchPrefillWithRaggedKVCacheWrapper,0.49769599735736847,0.0053031528422255855,217.96787873723878,1.7256503660070768,fa2,0,16,1024,1024,128,128,192,128,,,True,torch.bfloat16,torch.bfloat16,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -BatchPrefillWithRaggedKVCacheWrapper,0.5328896045684814,0.0014588070313195434,203.57263468827725,1.6116833067056566,cutlass,0,16,1024,1024,128,128,192,128,,,True,torch.bfloat16,torch.bfloat16,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,DeepSeek-R1,True,python3 
flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -BatchPrefillWithRaggedKVCacheWrapper,0.3123199939727783,0.0005266243249355331,347.3416460473396,2.7499016924124846,cudnn,0,16,1024,1024,128,128,192,128,,,True,torch.bfloat16,torch.bfloat16,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -BatchDecodeWithPagedKVCacheWrapper,0.03481600061058998,0.00022415261777036224,0.07905882214290773,0.0108235292219457,fa2,16,1,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,84,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B 
-BatchDecodeWithPagedKVCacheWrapper,0.010452799871563912,4.634408684090133e-05,0.2633277240376533,0.03605081936229778,fa2_tc,16,1,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,84,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.010683200135827065,8.551724418800796e-05,0.25764864132510307,0.035273325895698635,cudnn,16,1,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,84,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.005734400078654289,3.59118315932623e-05,0.4799999934162147,0.06571428481293416,trtllm-gen,16,1,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,84,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 
--allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.006188799999654293,9.38182546863553e-05,0.444756980376447,0.06088934850391835,trtllm-gen-native,16,1,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,84,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.055296000093221664,7.235326722628814e-05,4.756740732721512,0.604074073055686,fa2,16,16,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.017208000272512437,9.39948911411913e-05,15.28525870726272,1.941125027372111,fa2_tc,16,16,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len 
-vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.013764800131320953,9.368704015807337e-05,19.108794424228098,2.426688341372557,cudnn,16,16,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.00963199995458126,3.5251792291711315e-05,27.30780079321905,3.467906993096757,trtllm-gen,16,16,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchDecodeWithPagedKVCacheWrapper,0.009657599776983262,6.567264399555296e-05,27.2354148105071,3.458714460254226,trtllm-gen-native,16,16,1,1024,64,8,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 
--num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -BatchMLAPagedAttentionWrapper,0.024420800060033797,0.00010761519902284579,91.55081939697266,0.9553665704090659,trtllm-gen-native,32,16,1,1024,128,,,,512,64,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1 -BatchMLAPagedAttentionWrapper,0.04095999896526337,0.0004555166043636676,54.58359909057617,0.5696000143893066,fa2,32,16,1,1024,128,,,,512,64,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1 -bmm_fp8,0.2860383987426758,0.0004968334033247884,13.13843316323707,0.02569322172234466,cudnn,,256,,,,,,,,,,,,,,1,1024,7168,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 256 --m 1 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command 
-bmm_fp8,0.28600159883499143,0.00047219507506758893,13.14012368919739,0.025696527676546826,cublas,,256,,,,,,,,,,,,,,1,1024,7168,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 256 --m 1 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command -bmm_fp8,0.2657311916351318,0.00010204157533447904,14.142473681298727,0.027656700573153076,cutlass,,256,,,,,,,,,,,,,,1,1024,7168,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 256 --m 1 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command -bmm_fp8,0.07516320049762726,0.000206792690088806,49.99915329734576,0.09814504905539344,cudnn,,64,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command -bmm_fp8,0.07495999932289124,0.0002273436657742046,50.13469074101705,0.09841110014187592,cublas,,64,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command 
-bmm_fp8,0.072297602891922,9.86900944781081e-05,51.980926526955464,0.10203513954712654,cutlass,,64,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,,,,,,,,,,,,,,,,,,,torch.float8_e4m3fn,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command -gemm_fp8_nt_groupwise,0.01966080069541931,1.0062279270694632e-06,2.986666561025716,0.37520832006189414,cutlass,,,,,,,,,,,,,,,,4,1024,7168,,128,MN,torch.bfloat16,1,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine gemm_fp8_nt_groupwise --m 4 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command -gemm_fp8_nt_groupwise,0.019865599274635316,1.2664492914798042e-06,11.823505586357996,0.37690723025708733,cutlass,,,,,,,,,,,,,,,,16,1024,7168,,128,MN,torch.bfloat16,1,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command -group_gemm_fp8_nt_groupwise,0.02232320010662079,3.682025865219179e-05,5.260917406065297,0.6609174280359654,cutlass,,,,,,,,,,,,,,,,4,1024,7168,2,128,MN,torch.bfloat16,1,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine group_gemm_fp8_nt_groupwise --m 4 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command -group_gemm_fp8_nt_groupwise,0.02252800017595291,9.332681200911203e-05,20.852363473498134,0.6647272675354804,cutlass,,,,,,,,,,,,,,,,16,1024,7168,2,128,MN,torch.bfloat16,1,,,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine group_gemm_fp8_nt_groupwise --m 16 --n 
1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command
-mm_fp4,0.012697599828243256,1.7496054581842973e-06,1.1561290478966861,0.28947581036726805,cudnn,,,,,,,,,,,,,,,,1,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 1 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-mm_fp4,0.009216000139713288,3.619193125084151e-05,1.5928888647409136,0.39883332728707516,cutlass,,,,,,,,,,,,,,,,1,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 1 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-mm_fp4,0.01085439994931221,3.680557773573178e-05,1.3524528365043527,0.3386320770530395,trtllm,,,,,,,,,,,,,,,,1,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 1 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-mm_fp4,0.012697599828243256,1.7233627387457594e-06,4.6245161915867445,0.2908064555465576,cudnn,,,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-mm_fp4,0.009216000139713288,5.744219533265305e-07,6.3715554589636545,0.4006666605926154,cutlass,,,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-mm_fp4,0.011059200018644333,0.00010158394987753556,5.309629620678304,0.33388888832599684,trtllm,,,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-mm_fp4,0.009216000139713288,3.6526144810097353e-05,6.3715554589636545,0.4006666605926154,cutlass_autotune,,,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
-mm_fp4,0.01085439994931221,0.00010171082787792661,5.409811346017411,0.34018868083389336,trtllm_autotune,,,,,,,,,,,,,,,,4,1024,7168,,,,torch.bfloat16,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,True,False,False,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
-trtllm_fp4_block_scale_moe,0.22354559898376464,0.0001550481673529622,230.55523251765356,1.817630057791967,trtllm,,,,,,,,,,,,,,,,,,,,,,,,,,1024,1024,1024,256,8,8,4,2.5,0,256,8,deepseek_v3,True,0,True,False,torch.bfloat16,torch.bfloat16,swiglu,,,,,,,False,False,False,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
-trtllm_fp4_block_scale_moe,0.22620320320129395,0.00015420901713100778,227.84649740850875,0.9027731398581356,trtllm,,,,,,,,,,,,,,,,,,,,,,,,,,1024,1024,1024,128,8,None,None,2.5,0,128,8,renormalize_naive,True,0,False,False,torch.bfloat16,torch.bfloat16,swiglu,,,,,,,False,False,False,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --routing_method renormalize_naive --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
-trtllm_fp8_block_scale_moe,0.556544017791748,0.00016468714317887162,92.60652509840739,1.45451329296815,trtllm,,,,,,,,,,,,,,,,,,,,,,,,,,1024,1024,1024,256,8,8,4,2.5,0,256,8,deepseek_v3,True,0,True,False,torch.bfloat16,torch.bfloat16,,,,,,,,False,False,False,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
-trtllm_fp8_per_tensor_scale_moe,0.12308800220489502,0.00022594891363694267,52.34020236412443,3.2989491154796857,trtllm,,,,,,,,,,,,,,,,,,,,,,,,,,1024,1024,1024,128,1,None,None,2.5,0,128,8,llama4,,,True,True,torch.bfloat16,torch.bfloat16,,,,,,,,False,False,False,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp8_per_tensor_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routed_scaling_factor 2.5 --use_routing_bias --routing_method llama4 --use_routing_scales_on_input -vv --generate_repro_command --case_tag trtllm_moe_sample
-trtllm_fp8_block_scale_moe,0.10864640474319458,0.00013707306207295686,59.29741494187403,3.7398678857383114,trtllm,,,,,,,,,,,,,,,,,,,,,,,,,,1024,1024,1024,128,1,None,None,2.5,0,128,8,renormalize,True,0,False,False,torch.bfloat16,torch.bfloat16,,,,,,,,False,False,False,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routing_method renormalize --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
-cutlass_fused_moe,0.026214399933815004,0.00010491445634120407,0.24000000060594173,0.00812988283302598,cutlass,,,,,,,,,,,,,,,,,,,,,,,,,,32,128,128,2,2,,,,,,,,False,0,,False,torch.float16,torch.float16,,base,False,1,0,1,0,False,False,False,False,42,cutlass_moe_base,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant base --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_base
-cutlass_fused_moe,0.025804799795150758,0.0001189026818394618,0.2438095257449853,0.004290674637235764,cutlass,,,,,,,,,,,,,,,,,,,,,,,,,,32,128,128,2,2,,,,,,,,False,0,,False,torch.float16,torch.float16,,fp8,False,1,0,1,0,False,False,False,False,42,cutlass_moe_fp8_scale,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant fp8 --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_fp8_scale
-cutlass_fused_moe,0.02990399897098541,0.00010532265873466246,0.21038845025724934,0.002195826720824563,cutlass,,,,,,,,,,,,,,,,,,,,,,,,,,32,128,128,2,2,,,,,,,,False,0,,False,torch.float16,torch.float16,,nvfp4,False,1,0,1,0,False,False,False,False,42,cutlass_moe_nvfp4_weights,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_weights
-cutlass_fused_moe,0.02949439883232117,0.00010231710901612176,0.21331019614156588,0.0020180102784388863,cutlass,,,,,,,,,,,,,,,,,,,,,,,,,,32,128,128,2,2,,,,,,,,False,0,,False,torch.float16,torch.float16,,nvfp4,True,1,0,1,0,False,False,False,False,42,cutlass_moe_nvfp4_weights_quantized,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_weights_quantized
-cutlass_fused_moe,0.025190401077270507,0.00010808926975429793,0.24975608688012632,0.031890242538648944,cutlass,,,,,,,,,,,,,,,,,,,,,,,,,,32,128,128,8,2,,,,,,,,False,0,,False,torch.float16,torch.float16,,base,False,2,0,4,0,False,False,False,False,42,cutlass_moe_nvfp4_ep_tp,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 --cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_ep_tp
+routine,median_time,std_time,tflops,tb_per_sec,backend,s_qo,s_kv,head_dim_qk,head_dim_vo,head_dim_ckv,head_dim_kpe,causal,q_dtype,kv_dtype,avg_actual_seq_len,random_actual_seq_len,n,group_size,tile_size,scale_major_mode,mma_sm,use_128x4_sf_layout,use_nvfp4,num_tokens,intermediate_size,num_experts,top_k,n_group,topk_group,routed_scaling_factor,local_expert_offset,local_num_experts,routing_method,use_shuffled_weight,weight_layout,use_routing_bias,use_routing_scales_on_input,weight_dtype,gated_act,cutlass_variant,quantized_input,tp_size,tp_rank,ep_size,ep_rank,num_tokens,num_experts,top_k,ep_size,max_num_tokens,num_heads,scale,eps,use_global_scale,alignment,global_scale,sf_layout,do_shuffle,sf_vec_size,vocab_size,top_k,top_p,min_p,temperature,num_speculate_tokens,filter_apply_order,max_len,num_rows,seq_len,head_dim,rotary_dim,no_rope_dim,rope_theta,rope_scale,interleave,kv_layout,batch_size,hidden_size,input_dtype,out_dtype,quant_dtype,m,k,num_qo_heads,num_kv_heads,page_size,enable_pdl,is_sf_swizzled_layout,refcheck,no_cuda_graph,use_cupti,allow_output_mismatch,random_seed,case_tag,generate_repro_command,repro_command
+BatchPrefillWithPagedKVCacheWrapper,0.011616,0.007307573782194897,14.963658402203857,0.32687603305785123,fa2,1024,1024,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchPrefillWithPagedKVCacheWrapper,0.020448,0.0007820994992113643,8.50048200312989,0.1856901408450704,cudnn,1024,1024,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchPrefillWithPagedKVCacheWrapper,0.010416,0.00012602601583271074,16.6875821812596,0.36453456221198155,trtllm-gen,1024,1024,128,128,,,True,torch.bfloat16,torch.bfloat16,103,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchPrefillWithRaggedKVCacheWrapper,0.5077125,0.0034909562580409865,213.66765797572447,1.6916055444764508,fa2,1024,1024,192,128,,,True,torch.bfloat16,torch.bfloat16,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,128,128,0,,,True,False,True,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1
+BatchPrefillWithRaggedKVCacheWrapper,0.5157455,0.003936178626395013,210.3396748977936,1.6652579227545372,cutlass,1024,1024,192,128,,,True,torch.bfloat16,torch.bfloat16,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,128,128,0,,,True,False,True,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1
+BatchPrefillWithRaggedKVCacheWrapper,0.291505,0.0014704125078811516,372.14367094904026,2.9462591722268914,cudnn,1024,1024,192,128,,,True,torch.bfloat16,torch.bfloat16,327,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,128,128,0,,,True,False,True,True,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1
+BatchDecodeWithPagedKVCacheWrapper,0.05048,0.0003465227345435733,5.210553407290016,0.6617052297939778,fa2,1,1024,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchDecodeWithPagedKVCacheWrapper,0.0218245,0.00026476689118291706,12.051993676831083,1.5305221196361884,fa2_tc,1,1024,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchDecodeWithPagedKVCacheWrapper,0.015151999999999999,0.00032976116104437367,17.35934107708553,2.204519535374868,cudnn,1,1024,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchDecodeWithPagedKVCacheWrapper,0.013504,0.00023366085584786245,19.477838862559242,2.473554502369668,trtllm-gen,1,1024,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchDecodeWithPagedKVCacheWrapper,0.013488,0.00026481675844922575,19.50094424673784,2.476488730723606,trtllm-native,1,1024,128,128,,,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,64,8,16,,,True,False,True,True,42,Llama-3.1-70B,True,python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
+BatchMLAPagedAttentionWrapper,0.024320500000000002,0.0015404488869590362,91.92838287353516,0.9593065932032647,trtllm-native,1,1024,,,512,64,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,128,128,32,,,True,False,True,False,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1
+BatchMLAPagedAttentionWrapper,0.0391525,0.00025226886891216313,57.103485107421875,0.5958959453419322,fa2,1,1024,,,512,64,False,torch.bfloat16,torch.bfloat16,501,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,16,,,,,,,128,128,32,,,True,False,True,False,42,DeepSeek-R1,True,python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1
+bmm_fp8,0.08452799999999999,0.00044238307306375754,44.45978118493281,0.08727162597009276,cudnn,,,,,,,,,,,,1024,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,64,,torch.float8_e4m3fn,torch.bfloat16,,4,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command
+bmm_fp8,0.084641,0.0004453848074044147,44.40042513675405,0.0871551139518673,cublas,,,,,,,,,,,,1024,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,64,,torch.float8_e4m3fn,torch.bfloat16,,4,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command
+bmm_fp8,0.0862085,0.00042494313998724884,43.59310722260566,0.08557040199052297,cutlass,,,,,,,,,,,,1024,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,64,,torch.float8_e4m3fn,torch.bfloat16,,4,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command
+gemm_fp8_nt_groupwise,0.016048,0.00011820765156659246,14.636155533399801,0.46656829511465603,cutlass,,,,,,,,,,,,1024,,128,MN,1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,16,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command
+group_gemm_fp8_nt_groupwise,0.0228325,0.0001976228984707995,20.574271236176504,0.6558623015438519,cutlass,,,,,,,,,,,,1024,2,128,MN,1,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,16,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine group_gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command
+mm_fp4,0.009152,0.00011045976643104036,821.2623216783217,0.7160839160839161,cudnn,,,,,,,,,,,,1024,,,,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,512,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
+mm_fp4,0.010176,9.755467982396113e-05,738.619572327044,0.6440251572327044,cutlass,,,,,,,,,,,,1024,,,,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,512,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
+mm_fp4,0.01184,0.00017707846346244992,634.8135783783783,0.5535135135135135,trtllm,,,,,,,,,,,,1024,,,,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,512,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
+mm_fp4,0.0091355,0.00014340164651154523,822.745637129878,0.7173772645175415,cudnn_autotune,,,,,,,,,,,,1024,,,,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,512,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
+mm_fp4,0.009936,0.00014641446498060083,756.4606247987117,0.6595813204508857,cutlass_autotune,,,,,,,,,,,,1024,,,,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,512,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
+mm_fp4,0.014048,0.00016289848577155853,535.0365011389522,0.46651480637813214,trtllm_autotune,,,,,,,,,,,,1024,,,,,True,True,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1,,fp8_e4m3,torch.bfloat16,,512,7168,,,,,,True,False,True,False,42,None,True,python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
+trtllm_fp4_block_scale_moe,0.131376,0.0007720556427846086,392.3061103397881,3.476436974789916,trtllm,,,,,,,,,,,,,,,,,,,1024,1024,256,8,8,4,2.5,0,256,deepseek_v3,True,0,True,False,torch.bfloat16,swiglu,,,,,,,1024,256,8,,,,,,,,,,,,,8,,,,,,,,,,,,,,,,,1024,torch.bfloat16,,,,,,,,,,False,False,True,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
+trtllm_fp4_block_scale_moe,0.09689600000000001,0.000823255195347509,531.9064517833552,2.367915455746367,trtllm,,,,,,,,,,,,,,,,,,,1024,1024,128,8,None,None,2.5,0,128,renormalize_naive,True,0,False,False,torch.bfloat16,swiglu,,,,,,,1024,128,8,,,,,,,,,,,,,8,,,,,,,,,,,,,,,,,1024,torch.bfloat16,,,,,,,,,,False,False,True,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --routing_method renormalize_naive --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
+trtllm_fp8_block_scale_moe,0.30302450000000003,0.001741246330138908,170.08396202947284,2.67140337497463,trtllm,,,,,,,,,,,,,,,,,,,1024,1024,256,8,8,4,2.5,0,256,deepseek_v3,True,0,True,False,torch.bfloat16,,,,,,,,1024,256,8,,,,,,,,,,,,,8,,,,,,,,,,,,,,,,,1024,torch.bfloat16,,,,,,,,,,False,False,True,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
+trtllm_fp8_per_tensor_scale_moe,0.092,0.00038124563239756883,70.02664069565218,4.413707130434783,trtllm,,,,,,,,,,,,,,,,,,,1024,1024,128,1,None,None,2.5,0,128,llama4,,,True,True,torch.bfloat16,,,,,,,,1024,128,1,,,,,,,,,,,,,1,,,,,,,,,,,,,,,,,1024,torch.bfloat16,,,,,,,,,,False,False,True,False,42,trtllm_moe_sample,True,python3 flashinfer_benchmark.py --routine trtllm_fp8_per_tensor_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routed_scaling_factor 2.5 --use_routing_bias --routing_method llama4 --use_routing_scales_on_input -vv --generate_repro_command --case_tag trtllm_moe_sample
+cutlass_fused_moe,0.027808,0.0002716811468533574,0.22624626006904489,0.007663981588032221,cutlass,,,,,,,,,,,,,,,,,,,32,128,2,2,,,,,,,False,0,,False,torch.float16,,base,False,1,0,1,0,32,2,2,1,,,,,,,,,,,,2,,,,,,,,,,,,,,,,,128,torch.float16,,,,,,,,,,False,False,True,False,42,cutlass_moe_base,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant base --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_base
+cutlass_fused_moe,0.028432,0.0001927351348929982,0.22128081035453012,0.00389420371412493,cutlass,,,,,,,,,,,,,,,,,,,32,128,2,2,,,,,,,False,0,,False,torch.float16,,fp8,False,1,0,1,0,32,2,2,1,,,,,,,,,,,,2,,,,,,,,,,,,,,,,,128,torch.float16,,,,,,,,,,False,False,True,False,42,cutlass_moe_fp8_scale,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant fp8 --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_fp8_scale
+cutlass_fused_moe,0.029984,0.0003416674146333277,0.20982710779082178,0.002198505869797225,cutlass,,,,,,,,,,,,,,,,,,,32,128,2,2,,,,,,,False,0,,False,torch.float16,,nvfp4,False,1,0,1,0,32,2,2,1,,,,,,,,,,,,2,,,,,,,,,,,,,,,,,128,torch.float16,,,,,,,,,,False,False,True,False,42,cutlass_moe_nvfp4_weights,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_weights
+cutlass_fused_moe,0.028319999999999998,0.0003103376584016549,0.22215593220338983,0.002327683615819209,cutlass,,,,,,,,,,,,,,,,,,,32,128,2,2,,,,,,,False,0,,False,torch.float16,,nvfp4,True,1,0,1,0,32,2,2,1,,,,,,,,,,,,2,,,,,,,,,,,,,,,,,128,torch.float16,,,,,,,,,,False,False,True,False,42,cutlass_moe_nvfp4_weights_quantized,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_weights_quantized
+cutlass_fused_moe,0.026368,0.0002925743988952024,0.23860194174757282,0.03046601941747573,cutlass,,,,,,,,,,,,,,,,,,,32,128,8,2,,,,,,,False,0,,False,torch.float16,,base,False,2,0,4,0,32,8,2,4,,,,,,,,,,,,2,,,,,,,,,,,,,,,,,128,torch.float16,,,,,,,,,,False,False,True,False,42,cutlass_moe_nvfp4_ep_tp,True,python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 --cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_ep_tp
+rmsnorm,0.002912,7.812019940806316e-05,0.22505494505494505,0.18285714285714286,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,fp8_e4m3,,,,,,,False,False,True,False,True,False,42,rmsnorm_llama_hidden,True,python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_llama_hidden
+rmsnorm,0.003424,8.368695637114948e-05,0.765607476635514,0.6172710280373832,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,,,,,,,,,,,,,,,,,,,,,,,,64,8192,torch.bfloat16,fp8_e4m3,,,,,,,False,False,True,False,True,False,42,rmsnorm_large_hidden,True,python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_large_hidden
+rmsnorm,0.00256,7.790495633926137e-05,0.256,0.2049,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,128,torch.bfloat16,fp8_e4m3,,,,,,,False,False,True,False,True,False,42,rmsnorm_3d_gqa,True,python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_3d_gqa
+rmsnorm,0.00256,8.524028781431153e-05,0.256,0.2049,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,64,,1e-06,,,,,,,,,,,,,,,,,,,,,,,,16,128,torch.float16,fp8_e4m3,,,,,,,False,False,True,False,True,False,42,rmsnorm_3d_mha,True,python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 16 --num_heads 64 --hidden_size 128 --input_dtype float16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_3d_mha
+rmsnorm,0.002912,0.00011321305382134853,0.22505494505494505,0.18285714285714286,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,fp8_e4m3,,,,,,,True,False,True,False,True,False,42,rmsnorm_pdl,True,python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --enable_pdl --refcheck -vv --generate_repro_command --case_tag rmsnorm_pdl
+rmsnorm_quant,0.002944,5.0394929198173184e-05,0.22260869565217392,0.13634782608695653,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,torch.float8_e4m3fn,,,,,,,False,False,True,False,True,False,42,rmsnorm_quant_fp8_e4m3,True,python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_fp8_e4m3
+rmsnorm_quant,0.0034714999999999998,8.928969705402746e-05,0.7551317874117818,0.45779864611839266,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1e-06,,,,,,,,,,,,,,,,,,,,,,,,64,8192,torch.bfloat16,torch.float8_e4m3fn,,,,,,,False,False,True,False,True,False,42,rmsnorm_quant_large,True,python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_large
+rmsnorm_quant,0.002912,0.00010142936239351772,0.22505494505494505,0.13784615384615384,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.float16,torch.float8_e5m2,,,,,,,False,False,True,False,True,False,42,rmsnorm_quant_fp8_e5m2,True,python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype float16 --out_dtype fp8_e5m2 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_fp8_e5m2
+fused_add_rmsnorm_quant,0.003232,0.00010254337293717882,0.24332673267326732,0.2864158415841584,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,torch.float8_e4m3fn,,,,,,,False,False,True,False,True,False,42,fused_add_rmsnorm_quant_fp8_e4m3,True,python3 flashinfer_benchmark.py --routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag fused_add_rmsnorm_quant_fp8_e4m3
+fused_add_rmsnorm_quant,0.003936,6.175634020532265e-05,0.7992195121951219,0.9365853658536585,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1e-06,,,,,,,,,,,,,,,,,,,,,,,,64,8192,torch.bfloat16,torch.float8_e4m3fn,,,,,,,False,False,True,False,True,False,42,fused_add_rmsnorm_quant_large,True,python3 flashinfer_benchmark.py --routine fused_add_rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag fused_add_rmsnorm_quant_large
+fused_add_rmsnorm_quant,0.0032,8.371249740762858e-05,0.24576,0.28928,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1e-06,,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,torch.float8_e4m3fn,,,,,,,True,False,True,False,True,False,42,fused_add_rmsnorm_quant_pdl,True,python3 flashinfer_benchmark.py --routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --enable_pdl --refcheck -vv --generate_repro_command --case_tag fused_add_rmsnorm_quant_pdl
+rmsnorm_fp4quant,0.006016499999999999,0.00015377948136500162,0.10892711709465638,0.0571867364746946,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,rmsnorm_fp4quant_nvfp4,True,python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4
+rmsnorm_fp4quant,0.006656,0.00010579418698586425,0.39384615384615385,0.2043076923076923,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,64,8192,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,rmsnorm_fp4quant_nvfp4_large,True,python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4_large
+rmsnorm_fp4quant,0.004512,0.00012588926350831783,0.1452482269503546,0.07625531914893617,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,True,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,rmsnorm_fp4quant_nvfp4_global,True,python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4_global
+rmsnorm_fp4quant,0.0063035,0.00011677332743396502,0.10396763702704846,0.05458300943920044,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,True,False,False,True,False,42,rmsnorm_fp4quant_nvfp4_swizzled,True,python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4_swizzled
+rmsnorm_fp4quant,0.005344,0.00021772911049181175,0.12263473053892215,0.06361676646706586,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,mxfp4,,,,,,,False,False,False,False,True,False,42,rmsnorm_fp4quant_mxfp4,True,python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_mxfp4
+rmsnorm_fp4quant,0.004032,0.00012955181288666794,0.16253968253968254,0.08336507936507936,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,128,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,rmsnorm_fp4quant_3d,True,python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_3d
+add_rmsnorm_fp4quant,0.005248,0.00010570372115808723,0.14985365853658536,0.16546341463414635,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_nvfp4,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4
+add_rmsnorm_fp4quant,0.005792,0.00010571013301581935,0.5431160220994475,0.5968618784530386,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,64,8192,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_nvfp4_large,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4_large
+add_rmsnorm_fp4quant,0.003616,8.856776438913255e-05,0.21748672566371682,0.24014159292035397,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,True,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_nvfp4_global,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4_global
+add_rmsnorm_fp4quant,0.005536,0.00010942230120044093,0.1420578034682081,0.15685549132947976,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,True,False,False,True,False,42,add_rmsnorm_fp4quant_nvfp4_swizzled,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4_swizzled
+add_rmsnorm_fp4quant,0.005344,0.00015592187858739457,0.1471616766467066,0.1617245508982036,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,mxfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_mxfp4,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_mxfp4
+add_rmsnorm_fp4quant,0.003968,0.0001344020667830504,0.19819354838709677,0.21683870967741936,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,128,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_3d,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_3d
+add_rmsnorm_fp4quant,0.005248,0.00013531114760679053,0.14985365853658536,0.16702439024390245,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_both_sf,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_both_sf
+add_rmsnorm_fp4quant,0.005632,0.00013822792771361366,0.5585454545454546,0.6196363636363637,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,64,8192,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_both_sf_large,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_both_sf_large
+add_rmsnorm_fp4quant,0.0037914999999999997,9.353935000843226e-05,0.20741975471449298,0.23118660160886195,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,True,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,nvfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_both_sf_global,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_both_sf_global
+add_rmsnorm_fp4quant,0.005376,9.779302974479657e-05,0.1462857142857143,0.16152380952380951,cute-dsl,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1e-06,False,,,,,,,,,,,,,,,,,,,,,,,32,4096,torch.bfloat16,mxfp4,,,,,,,False,False,False,False,True,False,42,add_rmsnorm_fp4quant_mxfp4_both_sf,True,python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_mxfp4_both_sf
+mxfp8_quantize,0.006784,7.182856906087795e-05,1.8547924528301887,1.8741132075471698,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,False,False,True,False,42,mxfp8_quantize_basic,True,python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp8_quantize_basic
+mxfp8_quantize,0.014832,0.00021055667223391944,3.3934498381877023,3.4287982740021574,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,2048,8192,,,,False,True,False,False,True,False,42,mxfp8_quantize_large,True,python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp8_quantize_large
+mxfp8_quantize,0.00672,0.00012175248115199402,1.8724571428571428,1.8919619047619047,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,,,,,,,,,,,,,,,,,,,,,None,,torch.float16,,,1024,4096,,,,False,True,False,False,True,False,42,mxfp8_quantize_fp16,True,python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype float16 -vv --generate_repro_command --case_tag mxfp8_quantize_fp16
+mxfp8_quantize,0.00656,0.00015712367386517179,1.9181268292682927,1.9381073170731706,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,False,False,False,True,False,42,mxfp8_quantize_no_swizzle,True,python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --no_sf_swizzled_layout -vv --generate_repro_command --case_tag mxfp8_quantize_no_swizzle
+mxfp8_quantize,0.014864,0.00017424996413199052,3.3861442411194833,3.4214165769644778,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,2048,8192,,,,True,True,False,False,True,False,42,mxfp8_quantize_pdl,True,python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --enable_pdl -vv --generate_repro_command --case_tag mxfp8_quantize_pdl
+mxfp8_quantize,0.006816,0.0001178738685582555,1.8460845070422536,1.8653145539906104,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,True,False,True,False,42,mxfp8_quantize_refcheck,True,python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag mxfp8_quantize_refcheck
+mxfp4_quantize,0.048032,0.0005276333280687346,0.26196935376415725,0.22103664223850766,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,False,False,True,False,42,mxfp4_quantize_basic,True,python3 flashinfer_benchmark.py --routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp4_quantize_basic
+mxfp4_quantize,0.12345600000000001,0.0011042029307252459,0.4076889580093312,0.3439875583203732,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,2048,8192,,,,False,True,False,False,True,False,42,mxfp4_quantize_large,True,python3 flashinfer_benchmark.py --routine mxfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp4_quantize_large
+mxfp4_quantize,0.048656000000000005,0.00037438014548495135,0.25860966787241035,0.21820190726734623,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,True,False,True,False,42,mxfp4_quantize_refcheck,True,python3 flashinfer_benchmark.py --routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag mxfp4_quantize_refcheck
+nvfp4_quantize,0.0079835,0.0001245252138680712,1.576114736644329,1.346265171917079,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,128x4,False,16,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,False,False,True,False,42,nvfp4_quantize_128x4,True,python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag nvfp4_quantize_128x4
+nvfp4_quantize,0.0138085,0.000238419427340419,3.6449757757902743,3.1134170981641742,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,128x4,False,16,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,2048,8192,,,,False,True,False,False,True,False,42,nvfp4_quantize_128x4_large,True,python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag nvfp4_quantize_128x4_large
+nvfp4_quantize,0.008,0.0001464365467437081,1.572864,1.3434885,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,8x4,False,16,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,False,False,True,False,42,nvfp4_quantize_8x4,True,python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 8x4 -vv --generate_repro_command --case_tag nvfp4_quantize_8x4
+nvfp4_quantize,3.7088235000000003,0.0300100512421533,0.003392696363145887,0.002897929222029573,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,128x4,True,16,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,False,True,False,False,True,False,42,nvfp4_quantize_shuffle,True,python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --do_shuffle -vv --generate_repro_command --case_tag nvfp4_quantize_shuffle
+nvfp4_quantize,0.008031,0.0001525645073039234,1.566792678371311,1.3383025775121404,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,128x4,False,16,,,,,,,,,,,,,,,,,,None,,torch.bfloat16,,,1024,4096,,,,True,True,False,False,True,False,42,nvfp4_quantize_pdl,True,python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --enable_pdl -vv --generate_repro_command --case_tag nvfp4_quantize_pdl
+nvfp4_batched_quantize,0.021888,0.0001485392727717338,2.2995087719298244,1.9641639254385965,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,,,16,,,,,,,,,,,,,,,,,,4,,torch.bfloat16,,,1024,4096,,,,False,True,False,False,True,False,42,nvfp4_batched_basic,True,python3 flashinfer_benchmark.py --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag nvfp4_batched_basic
+nvfp4_batched_quantize,0.07880000000000001,0.0006070973562782165,5.10981197969543,4.364631116751268,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,,,16,,,,,,,,,,,,,,,,,,8,,torch.bfloat16,,,2048,8192,,,,False,True,False,False,True,False,42,nvfp4_batched_large,True,python3 flashinfer_benchmark.py --routine nvfp4_batched_quantize --batch_size 8 --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag nvfp4_batched_large
+nvfp4_batched_quantize,0.01984,0.00012693900985206308,2.536877419354839,2.166916330645161,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,,,16,,,,,,,,,,,,,,,,,,4,,torch.float16,,,1024,4096,,,,False,True,False,False,True,False,42,nvfp4_batched_fp16,True,python3 flashinfer_benchmark.py --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype float16 --global_scale 1.0 -vv --generate_repro_command --case_tag nvfp4_batched_fp16
+softmax,0.012064,0.00015532180214709794,0,0.6790450928381963,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,,,1.0,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,softmax_llama,True,python3 flashinfer_benchmark.py --routine softmax --batch_size 32 --vocab_size 32000 --temperature 1.0 --input_dtype float32 -vv --generate_repro_command --case_tag softmax_llama
+softmax,0.035552,0.0003970301192045208,0,1.847071107110711,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,128256,,,,0.8,,,,,,,,,,,,,64,,float32,,,,,,,,,,False,False,True,False,42,softmax_llama3_temp,True,python3 flashinfer_benchmark.py --routine softmax --batch_size 64 --vocab_size 128256 --temperature 0.8 --input_dtype float32 -vv --generate_repro_command --case_tag softmax_llama3_temp
+sampling_from_probs,0.0135365,0.0001606484222006415,0,0.3025987515236583,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,sampling_from_probs_llama,True,python3 flashinfer_benchmark.py --routine sampling_from_probs --batch_size 32 --vocab_size 32000 -vv --generate_repro_command --case_tag sampling_from_probs_llama
+sampling_from_probs,0.042753,0.00035013663364781225,0,0.7679880242322176,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,128256,,,,,,,,,,,,,,,,,64,,float32,,,,,,,,,,False,False,True,False,42,sampling_from_probs_llama3,True,python3 flashinfer_benchmark.py --routine sampling_from_probs --batch_size 64 --vocab_size 128256 -vv --generate_repro_command --case_tag sampling_from_probs_llama3
+sampling_from_logits,0.0161925,0.00013220608491627353,0,0.2529645206113942,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,sampling_from_logits_llama,True,python3 flashinfer_benchmark.py --routine sampling_from_logits --batch_size 32 --vocab_size 32000 --input_dtype float32 -vv --generate_repro_command --case_tag sampling_from_logits_llama
+sampling_from_logits,0.07783999999999999,0.00028485495139339475,0,0.21090729701952726,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,128256,,,,,,,,,,,,,,,,,64,,bfloat16,,,,,,,,,,False,False,True,False,42,sampling_from_logits_llama3,True,python3 flashinfer_benchmark.py --routine sampling_from_logits --batch_size 64 --vocab_size 128256 --input_dtype bfloat16 -vv --generate_repro_command --case_tag sampling_from_logits_llama3
+top_k_sampling_from_probs,0.1495995,0.0001935201425060321,0,0.02738062627214663,cuda,,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,32000,50,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_sampling_k50,True,python3 flashinfer_benchmark.py --routine top_k_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 -vv --generate_repro_command --case_tag top_k_sampling_k50
+top_k_sampling_from_probs,0.498001,0.0017754358053415802,0,0.03296558842251321,cuda,,,,,,,,,,,,,,,,,,,,,,100,,,,,,,,,,,,,,,,,,,,,100,,,,,,,,,,,,128256,100,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_sampling_k100,True,python3 flashinfer_benchmark.py --routine top_k_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 100 -vv --generate_repro_command --case_tag top_k_sampling_k100
+top_p_sampling_from_probs,0.022912,0.00022344903719242592,0,0.17877653631284915,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,0.9,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_p_sampling_p09,True,python3 flashinfer_benchmark.py --routine top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag top_p_sampling_p09
+top_p_sampling_from_probs,0.071664,0.00048469280534742286,0,0.22908149140433132,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,128256,,0.95,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_p_sampling_p095,True,python3 flashinfer_benchmark.py --routine top_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_p 0.95 -vv --generate_repro_command --case_tag top_p_sampling_p095
+top_k_top_p_sampling_from_probs,0.043567999999999996,0.00013004975543571325,0,0.09401689313257439,cuda,,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,32000,50,0.9,,,,top_k_first,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_top_p_probs,True,python3 flashinfer_benchmark.py --routine top_k_top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first -vv --generate_repro_command --case_tag top_k_top_p_probs
+top_k_top_p_sampling_from_logits,0.050304,0.00016321227146129575,0,0.08142748091603054,cuda,,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,32000,50,0.9,,,,top_k_first,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_top_p_logits,True,python3 flashinfer_benchmark.py --routine top_k_top_p_sampling_from_logits --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first --input_dtype float32 -vv --generate_repro_command --case_tag top_k_top_p_logits
+min_p_sampling_from_probs,0.0127045,0.00010446553924088511,0,0.3224155220591129,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,,0.1,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,min_p_sampling_p01,True,python3 flashinfer_benchmark.py --routine min_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --min_p 0.1 -vv --generate_repro_command --case_tag min_p_sampling_p01
+min_p_sampling_from_probs,0.041808,0.00017156848001114136,0,0.392673555300421,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,128256,,,0.05,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,min_p_sampling_p005,True,python3 flashinfer_benchmark.py --routine min_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --min_p 0.05 -vv --generate_repro_command --case_tag min_p_sampling_p005
+top_k_renorm_probs,0.023552,0.00016119284516793347,0,0.34782608695652173,cuda,,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,32000,50,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_renorm,True,python3 flashinfer_benchmark.py --routine top_k_renorm_probs --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_renorm
+top_p_renorm_probs,0.08048,0.00015717873760644349,0,0.10178926441351889,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,0.9,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_p_renorm,True,python3 flashinfer_benchmark.py --routine top_p_renorm_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag top_p_renorm
+top_k_mask_logits,0.020448,0.00010783782061760889,0,0.40062597809076683,cuda,,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,32000,50,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_mask,True,python3 flashinfer_benchmark.py --routine top_k_mask_logits --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_mask
+chain_speculative_sampling,0.027168,0.00045243810626427136,0,0.8292414605418139,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,32000,,,,,5,,,,,,,,,,,,16,,float32,,,,,,,,,,False,False,True,False,42,chain_spec_sampling_5,True,python3 flashinfer_benchmark.py --routine chain_speculative_sampling --batch_size 16 --vocab_size 32000 --num_speculate_tokens 5 -vv --generate_repro_command --case_tag chain_spec_sampling_5
+chain_speculative_sampling,0.078304,0.00041787000237979534,0,3.564153657539845,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,128256,,,,,8,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,chain_spec_sampling_8,True,python3 flashinfer_benchmark.py --routine chain_speculative_sampling --batch_size 32 --vocab_size 128256 --num_speculate_tokens 8 -vv --generate_repro_command --case_tag chain_spec_sampling_8
+top_k,0.017551999999999998,0.00014782604266126072,0,0.23445761166818596,cuda,,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,,,,,,,,,,50,,,,,,,,,,,,32000,50,,,,,,,,,,,,,,,,32,,float32,,,,,,,,,,False,False,True,False,42,top_k_radix,True,python3 flashinfer_benchmark.py --routine top_k --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_radix
+top_k,0.038896,0.00019824246938198372,0,0.4237136980666392,cuda,,,,,,,,,,,,,,,,,,,,,,100,,,,,,,,,,,,,,,,,,,,,100,,,,,,,,,,,,128256,100,,,,,,,,,,,,,,,,64,,bfloat16,,,,,,,,,,False,False,True,False,42,top_k_radix_large,True,python3 flashinfer_benchmark.py --routine top_k --batch_size 64 --vocab_size 128256 --top_k 100 --input_dtype bfloat16 -vv --generate_repro_command --case_tag top_k_radix_large
+top_k_page_table_transform,0.0080005,0.00012896701128583244,0,0.06604387225798387,cuda,,,,,,,,,,,,,,,,,,,,,,64,,,,,,,,,,,,,,,,,,,,,64,,,,,,,,,,,,,64,,,,,,4096,16,,,,,,,,,16,,float32,,,,,,,,,,False,False,True,False,42,top_k_page_table,True,python3 flashinfer_benchmark.py --routine top_k_page_table_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_page_table
+top_k_ragged_transform,0.007424,0.00011858433941948468,0,0.03586206896551724,cuda,,,,,,,,,,,,,,,,,,,,,,64,,,,,,,,,,,,,,,,,,,,,64,,,,,,,,,,,,,64,,,,,,4096,16,,,,,,,,,16,,float32,,,,,,,,,,False,False,True,False,42,top_k_ragged,True,python3 flashinfer_benchmark.py --routine top_k_ragged_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_ragged
+apply_rope,0.122912,0.0006609962724722589,0,2.729955740692528,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1024,128,128,,10000.0,1.0,False,,16,,float16,,fp8_e4m3,,,32,8,16,,,False,False,True,False,42,apply_rope_llama,True,python3 flashinfer_benchmark.py --routine apply_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag apply_rope_llama
+apply_rope,0.5599689999999999,0.0022532907592131843,0,4.314380088897779,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,2048,128,128,,10000.0,1.0,False,,32,,bfloat16,,fp8_e4m3,,,64,8,16,,,False,False,True,False,42,apply_rope_llama70b,True,python3 flashinfer_benchmark.py --routine apply_rope --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag apply_rope_llama70b
+apply_rope_pos_ids,0.0809755,0.00050208063650728,0,4.144585164648567,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1024,128,128,,10000.0,1.0,False,,16,,float16,,fp8_e4m3,,,32,8,16,,,False,False,True,False,42,apply_rope_pos_ids,True,python3 flashinfer_benchmark.py --routine apply_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag apply_rope_pos_ids
+apply_rope_pos_ids,0.41246499999999997,0.0005260596924304349,0,5.857906120519317,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,2048,128,128,,10000.0,1.0,True,,32,,bfloat16,,fp8_e4m3,,,64,8,16,,,False,False,True,False,42,apply_rope_pos_ids_interleave,True,python3 flashinfer_benchmark.py --routine apply_rope_pos_ids --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag apply_rope_pos_ids_interleave
+apply_llama31_rope,0.1242245,0.0004841010695666286,0,2.701112260463918,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1024,128,128,,500000.0,1.0,False,,16,,bfloat16,,fp8_e4m3,,,32,8,16,,,False,False,True,False,42,apply_llama31_rope,True,python3 flashinfer_benchmark.py --routine apply_llama31_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag apply_llama31_rope
+apply_llama31_rope_pos_ids,0.081233,0.0005974024662375163,0,4.131447268967046,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1024,128,128,,500000.0,1.0,False,,16,,bfloat16,,fp8_e4m3,,,32,8,16,,,False,False,True,False,42,apply_llama31_rope_pos_ids,True,python3 flashinfer_benchmark.py --routine apply_llama31_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag apply_llama31_rope_pos_ids
+mla_rope_quantize_fp8,0.441457,0.0014416343873072154,0,2.759015170220429,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1024,192,,64,,,False,,16,,bfloat16,,fp8_e4m3,,,128,128,16,,,False,False,True,False,42,mla_rope_fp8_deepseek,True,python3 flashinfer_benchmark.py --routine mla_rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim 192 --no_rope_dim 64 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag mla_rope_fp8_deepseek
+rope_quantize_fp8,0.08296,0.0005573869352214454,0,3.040598649951784,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1024,128,128,0,,,False,,16,,bfloat16,,fp8_e4m3,,,32,8,16,,,False,False,True,False,42,rope_fp8_llama,True,python3 flashinfer_benchmark.py --routine rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_llama
+rope_quantize_fp8,0.5697295,0.002435763801949799,0,3.182650798317447,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,2048,128,128,0,,,False,,32,,bfloat16,,fp8_e4m3,,,64,8,16,,,False,False,True,False,42,rope_fp8_llama70b,True,python3 flashinfer_benchmark.py --routine rope_quantize_fp8 --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_llama70b
+rope_quantize_fp8_append_paged_kv_cache,0.0097445,0.00012860200879716725,0,1.9407082969880445,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,64,128,128,0,,,False,NHD,16,,bfloat16,,fp8_e4m3,,,32,8,16,,,False,False,True,False,42,rope_fp8_paged_kv,True,python3 flashinfer_benchmark.py --routine rope_quantize_fp8_append_paged_kv_cache --batch_size 16 --seq_len 64 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout NHD --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_paged_kv
+rope_quantize_fp8_append_paged_kv_cache,0.0262405,0.00016648466395837577,0,2.399173796231017,cuda,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,64,128,128,0,,,False,HND,32,,bfloat16,,fp8_e4m3,,,64,8,16,,,False,False,True,False,42,rope_fp8_paged_kv_hnd,True,python3 flashinfer_benchmark.py --routine rope_quantize_fp8_append_paged_kv_cache --batch_size 32 --seq_len 64 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout HND --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_paged_kv_hnd
diff --git a/benchmarks/samples/sample_testlist_output.txt b/benchmarks/samples/sample_testlist_output.txt
index d2c5cc4fa1..425fbe45c7 100644
--- a/benchmarks/samples/sample_testlist_output.txt
+++ b/benchmarks/samples/sample_testlist_output.txt
@@ -1,6 +1,7 @@
-[INFO] args = Namespace(routine='BatchPrefillWithPagedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='Llama-3.1-70B', generate_repro_command=True, repro_command='', backends=['fa2', 'fa3', 'cudnn', 'trtllm-gen'], page_size=16, batch_size=1, s_qo=1024, s_kv=1024, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, random_actual_seq_len=True)
+flashinfer/benchmarks$ python3 flashinfer_benchmark.py --testlist samples/sample_testlist.txt --output_path samples/sample_testlist_output.csv
+[INFO] args = Namespace(routine='BatchPrefillWithPagedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='Llama-3.1-70B', generate_repro_command=True, repro_command='', backends=['fa2', 'fa3', 'cudnn', 'trtllm-gen'], page_size=16, batch_size=1, s_qo=1024, s_kv=1024, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, random_actual_seq_len=True)
 [INFO] Running testBatchPrefillWithPagedKVCacheWrapper
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
 [WARNING] fa3 for routine BatchPrefillWithPagedKVCacheWrapper is not supported on compute capability 10.0. Skipping.
@@ -21,62 +22,12 @@
 [VVERBOSE] kv_indices.shape = torch.Size([7])
 [VVERBOSE] kv_last_page_len.shape = torch.Size([1])
 [VVERBOSE] scale = 0.08838834764831843
-[PERF] fa2 :: median time 0.012 ms; std 0.001 ms; achieved tflops 13.964 TFLOPs/sec; achieved tb_per_sec 0.305 TB/sec
-[PERF] cudnn :: median time 0.018 ms; std 0.000 ms; achieved tflops 9.452 TFLOPs/sec; achieved tb_per_sec 0.206 TB/sec
-[PERF] trtllm-gen :: median time 0.008 ms; std 0.000 ms; achieved tflops 20.700 TFLOPs/sec; achieved tb_per_sec 0.452 TB/sec
-[INFO] args = Namespace(routine='BatchPrefillWithPagedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='Llama-3.1-70B', generate_repro_command=True, repro_command='', backends=['fa2', 'fa3', 'cudnn', 'trtllm-gen'], page_size=16, batch_size=32, s_qo=1024, s_kv=1024, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, random_actual_seq_len=True)
-[INFO] Running testBatchPrefillWithPagedKVCacheWrapper
-[INFO] FlashInfer version: 0.3.1
-[VVERBOSE] gpu_name = 'NVIDIA_B200'
-[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 fa3 cudnn trtllm-gen --page_size 16 --batch_size 32 --s_qo 1024 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B
-[WARNING] fa3 for routine BatchPrefillWithPagedKVCacheWrapper is not supported on compute capability 10.0. Skipping.
-[VVERBOSE] s_qo == s_kv, making actual_seq_lens_kv the same as actual_seq_lens_q -[VERBOSE] Average actual qo seq len: 399 -[VERBOSE] Average actual kv seq len: 399 -[VVERBOSE] actual_seq_lens_q.flatten() = tensor([103, 436, 861, 271, 107, 72, 701, 21, 615, 122, 467, 215, 331, 459, - 88, 373, 100, 872, 664, 131, 662, 309, 770, 344, 492, 414, 806, 386, - 192, 956, 277, 161], dtype=torch.int32) -[VVERBOSE] actual_seq_lens_kv.flatten() = tensor([103, 436, 861, 271, 107, 72, 701, 21, 615, 122, 467, 215, 331, 459, - 88, 373, 100, 872, 664, 131, 662, 309, 770, 344, 492, 414, 806, 386, - 192, 956, 277, 161], dtype=torch.int32) -[VVERBOSE] q.shape = torch.Size([12778, 64, 128]) -[VVERBOSE] num_pages_per_seq = 64 -[VVERBOSE] total_num_pages = 2048 -[VVERBOSE] kv_cache.shape = torch.Size([2048, 2, 8, 16, 128]) -[VVERBOSE] kv_cache.stride() = (32768, 16384, 128, 1024, 1) -[VVERBOSE] block_tables.shape = torch.Size([32, 64]) -[VVERBOSE] qo_indptr.shape = torch.Size([33]) -[VVERBOSE] qo_indptr.dtype = torch.int32 -[VVERBOSE] kv_indptr.shape = torch.Size([33]) -[VVERBOSE] kv_indices.shape = torch.Size([815]) -[VVERBOSE] kv_last_page_len.shape = torch.Size([32]) -[VVERBOSE] scale = 0.08838834764831843 -[PERF] fa2 :: median time 0.483 ms; std 0.003 ms; achieved tflops 250.421 TFLOPs/sec; achieved tb_per_sec 0.975 TB/sec -[PERF] cudnn :: median time 0.382 ms; std 0.001 ms; achieved tflops 317.089 TFLOPs/sec; achieved tb_per_sec 1.234 TB/sec -[PERF] trtllm-gen :: median time 0.744 ms; std 0.000 ms; achieved tflops 162.619 TFLOPs/sec; achieved tb_per_sec 0.633 TB/sec -[INFO] args = Namespace(routine='BatchPrefillWithRaggedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='DeepSeek-R1', generate_repro_command=True, repro_command='', backends=['fa2', 'fa3', 'cutlass', 'cudnn'], page_size=0, batch_size=1, s_qo=1024, s_kv=1024, num_qo_heads=128, 
num_kv_heads=128, head_dim_qk=192, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, random_actual_seq_len=True) +[PERF] fa2 :: median time 0.012 ms; std 0.007 ms; achieved tflops 14.964 TFLOPs/sec; achieved tb_per_sec 0.327 TB/sec +[PERF] cudnn :: median time 0.020 ms; std 0.001 ms; achieved tflops 8.500 TFLOPs/sec; achieved tb_per_sec 0.186 TB/sec +[PERF] trtllm-gen :: median time 0.010 ms; std 0.000 ms; achieved tflops 16.688 TFLOPs/sec; achieved tb_per_sec 0.365 TB/sec +[INFO] args = Namespace(routine='BatchPrefillWithRaggedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='DeepSeek-R1', generate_repro_command=True, repro_command='', backends=['fa2', 'fa3', 'cutlass', 'cudnn'], page_size=0, batch_size=16, s_qo=1024, s_kv=1024, num_qo_heads=128, num_kv_heads=128, head_dim_qk=192, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, random_actual_seq_len=True) [INFO] Running testBatchPrefillWithRaggedKVCacheWrapper -[INFO] FlashInfer version: 0.3.1 -[VVERBOSE] gpu_name = 'NVIDIA_B200' -[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 1 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 -[WARNING] fa3 for routine BatchPrefillWithRaggedKVCacheWrapper is not supported on compute capability 10.0. Skipping. 
-[VVERBOSE] s_qo == s_kv, making actual_seq_lens_kv the same as actual_seq_lens_q -[VERBOSE] Average actual qo seq len: 103 -[VERBOSE] Average actual kv seq len: 103 -[VVERBOSE] actual_seq_lens_q.flatten() = tensor([103], dtype=torch.int32) -[VVERBOSE] actual_seq_lens_kv.flatten() = tensor([103], dtype=torch.int32) -[VVERBOSE] q.shape = torch.Size([103, 128, 192]) -[VVERBOSE] k.shape = torch.Size([103, 128, 192]) -[VVERBOSE] v.shape = torch.Size([103, 128, 128]) -[VVERBOSE] qo_indptr.shape = torch.Size([2]) -[VVERBOSE] kv_indptr.shape = torch.Size([2]) -[VVERBOSE] scale = 0.07216878364870323 -[PERF] fa2 :: median time 0.016 ms; std 0.000 ms; achieved tflops 26.943 TFLOPs/sec; achieved tb_per_sec 1.046 TB/sec -[PERF] cutlass :: median time 0.012 ms; std 0.000 ms; achieved tflops 35.963 TFLOPs/sec; achieved tb_per_sec 1.397 TB/sec -[PERF] cudnn :: median time 0.019 ms; std 0.000 ms; achieved tflops 23.316 TFLOPs/sec; achieved tb_per_sec 0.905 TB/sec -[INFO] args = Namespace(routine='BatchPrefillWithRaggedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='DeepSeek-R1', generate_repro_command=True, repro_command='', backends=['fa2', 'fa3', 'cutlass', 'cudnn'], page_size=0, batch_size=16, s_qo=1024, s_kv=1024, num_qo_heads=128, num_kv_heads=128, head_dim_qk=192, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=True, random_actual_seq_len=True) -[INFO] Running testBatchPrefillWithRaggedKVCacheWrapper -[INFO] FlashInfer version: 0.3.1 +[INFO] FlashInfer version: 0.6.2 [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 fa3 cutlass cudnn --batch_size 16 --s_qo 1024 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_qk 192 --head_dim_vo 128 
--random_actual_seq_len -vv --refcheck --causal --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag DeepSeek-R1 [WARNING] fa3 for routine BatchPrefillWithRaggedKVCacheWrapper is not supported on compute capability 10.0. Skipping. @@ -93,35 +44,13 @@ [VVERBOSE] qo_indptr.shape = torch.Size([17]) [VVERBOSE] kv_indptr.shape = torch.Size([17]) [VVERBOSE] scale = 0.07216878364870323 -[PERF] fa2 :: median time 0.498 ms; std 0.005 ms; achieved tflops 217.968 TFLOPs/sec; achieved tb_per_sec 1.726 TB/sec -[PERF] cutlass :: median time 0.533 ms; std 0.001 ms; achieved tflops 203.573 TFLOPs/sec; achieved tb_per_sec 1.612 TB/sec -[PERF] cudnn :: median time 0.312 ms; std 0.001 ms; achieved tflops 347.342 TFLOPs/sec; achieved tb_per_sec 2.750 TB/sec -[INFO] args = Namespace(routine='BatchDecodeWithPagedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='Llama-3.1-70B', generate_repro_command=True, repro_command='', backends=['fa2', 'fa2_tc', 'cudnn', 'trtllm-gen', 'trtllm-gen-native'], page_size=16, batch_size=1, s_qo=1, s_kv=1024, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, random_actual_seq_len=True) +[PERF] fa2 :: median time 0.508 ms; std 0.003 ms; achieved tflops 213.668 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec +[PERF] cutlass :: median time 0.516 ms; std 0.004 ms; achieved tflops 210.340 TFLOPs/sec; achieved tb_per_sec 1.665 TB/sec +[PERF] cudnn :: median time 0.292 ms; std 0.001 ms; achieved tflops 372.144 TFLOPs/sec; achieved tb_per_sec 2.946 TB/sec +[WARNING] Backend name 'trtllm-gen-native' has been renamed to 'trtllm-native' and will be removed in a future release. 
+[INFO] args = Namespace(routine='BatchDecodeWithPagedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='Llama-3.1-70B', generate_repro_command=True, repro_command='', backends=['fa2', 'fa2_tc', 'cudnn', 'trtllm-gen', 'trtllm-native'], page_size=16, batch_size=16, s_qo=1, s_kv=1024, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, random_actual_seq_len=True) [INFO] Running testBatchDecodeWithPagedKVCacheWrapper -[INFO] FlashInfer version: 0.3.1 -[VVERBOSE] gpu_name = 'NVIDIA_B200' -[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 1 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B -[VERBOSE] Average actual seq len: 84 -[VVERBOSE] actual_seq_lens_kv.flatten() = tensor([84], device='cuda:0', dtype=torch.int32) -[VVERBOSE] q.shape = torch.Size([1, 64, 128]) -[VVERBOSE] num_pages_per_seq = 64 -[VVERBOSE] total_num_pages = 64 -[VVERBOSE] kv_cache.shape = torch.Size([64, 2, 8, 16, 128]) -[VVERBOSE] kv_cache.stride() = (32768, 16384, 128, 1024, 1) -[VVERBOSE] block_tables.shape = torch.Size([1, 64]) -[VVERBOSE] kv_indptr.shape = torch.Size([2]) -[VVERBOSE] kv_indices.shape = torch.Size([6]) -[VVERBOSE] kv_last_page_len.shape = torch.Size([1]) -[VVERBOSE] scale = 0.08838834764831843 -[ERROR] Output tensor mismatch between backends fa2 and cudnn: 5063 / 8192 (61.80%) elements are different -[PERF] fa2 :: median time 0.035 ms; std 0.000 ms; achieved 
tflops 0.079 TFLOPs/sec; achieved tb_per_sec 0.011 TB/sec -[PERF] fa2_tc :: median time 0.010 ms; std 0.000 ms; achieved tflops 0.263 TFLOPs/sec; achieved tb_per_sec 0.036 TB/sec -[PERF] cudnn :: median time 0.011 ms; std 0.000 ms; achieved tflops 0.258 TFLOPs/sec; achieved tb_per_sec 0.035 TB/sec -[PERF] trtllm-gen :: median time 0.006 ms; std 0.000 ms; achieved tflops 0.480 TFLOPs/sec; achieved tb_per_sec 0.066 TB/sec -[PERF] trtllm-gen-nati:: median time 0.006 ms; std 0.000 ms; achieved tflops 0.445 TFLOPs/sec; achieved tb_per_sec 0.061 TB/sec -[INFO] args = Namespace(routine='BatchDecodeWithPagedKVCacheWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=True, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='Llama-3.1-70B', generate_repro_command=True, repro_command='', backends=['fa2', 'fa2_tc', 'cudnn', 'trtllm-gen', 'trtllm-gen-native'], page_size=16, batch_size=16, s_qo=1, s_kv=1024, num_qo_heads=64, num_kv_heads=8, head_dim_qk=128, head_dim_vo=128, head_dim_ckv=None, head_dim_kpe=None, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, random_actual_seq_len=True) -[INFO] Running testBatchDecodeWithPagedKVCacheWrapper -[INFO] FlashInfer version: 0.3.1 +[INFO] FlashInfer version: 0.6.2 [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc cudnn trtllm-gen trtllm-gen-native --page_size 16 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --allow_output_mismatch --generate_repro_command --case_tag Llama-3.1-70B [VERBOSE] Average actual seq len: 501 @@ -137,14 +66,15 @@ [VVERBOSE] kv_indices.shape = torch.Size([509]) [VVERBOSE] kv_last_page_len.shape = torch.Size([16]) [VVERBOSE] scale = 0.08838834764831843 -[PERF] 
fa2 :: median time 0.055 ms; std 0.000 ms; achieved tflops 4.757 TFLOPs/sec; achieved tb_per_sec 0.604 TB/sec -[PERF] fa2_tc :: median time 0.017 ms; std 0.000 ms; achieved tflops 15.285 TFLOPs/sec; achieved tb_per_sec 1.941 TB/sec -[PERF] cudnn :: median time 0.014 ms; std 0.000 ms; achieved tflops 19.109 TFLOPs/sec; achieved tb_per_sec 2.427 TB/sec -[PERF] trtllm-gen :: median time 0.010 ms; std 0.000 ms; achieved tflops 27.308 TFLOPs/sec; achieved tb_per_sec 3.468 TB/sec -[PERF] trtllm-gen-nati:: median time 0.010 ms; std 0.000 ms; achieved tflops 27.235 TFLOPs/sec; achieved tb_per_sec 3.459 TB/sec -[INFO] args = Namespace(routine='BatchMLAPagedAttentionWrapper', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='DeepSeek-R1', generate_repro_command=True, repro_command='', backends=['trtllm-gen-native', 'fa2', 'fa3'], page_size=32, batch_size=16, s_qo=1, s_kv=1024, num_qo_heads=128, num_kv_heads=128, head_dim_qk=None, head_dim_vo=None, head_dim_ckv=512, head_dim_kpe=64, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, random_actual_seq_len=True) +[PERF] fa2 :: median time 0.050 ms; std 0.000 ms; achieved tflops 5.211 TFLOPs/sec; achieved tb_per_sec 0.662 TB/sec +[PERF] fa2_tc :: median time 0.022 ms; std 0.000 ms; achieved tflops 12.052 TFLOPs/sec; achieved tb_per_sec 1.531 TB/sec +[PERF] cudnn :: median time 0.015 ms; std 0.000 ms; achieved tflops 17.359 TFLOPs/sec; achieved tb_per_sec 2.205 TB/sec +[PERF] trtllm-gen :: median time 0.014 ms; std 0.000 ms; achieved tflops 19.478 TFLOPs/sec; achieved tb_per_sec 2.474 TB/sec +[PERF] trtllm-native :: median time 0.013 ms; std 0.000 ms; achieved tflops 19.501 TFLOPs/sec; achieved tb_per_sec 2.476 TB/sec +[WARNING] Backend name 'trtllm-gen-native' has been renamed to 'trtllm-native' and will be removed in a future release. 
+[INFO] args = Namespace(routine='BatchMLAPagedAttentionWrapper', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='DeepSeek-R1', generate_repro_command=True, repro_command='', backends=['trtllm-native', 'fa2', 'fa3'], page_size=32, batch_size=16, s_qo=1, s_kv=1024, num_qo_heads=128, num_kv_heads=128, head_dim_qk=None, head_dim_vo=None, head_dim_ckv=512, head_dim_kpe=64, q_dtype='bfloat16', kv_dtype='bfloat16', causal=False, random_actual_seq_len=True) [INFO] Running testBatchMLAPagedAttentionWrapper -[INFO] FlashInfer version: 0.3.1 +[INFO] FlashInfer version: 0.6.2 [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine BatchMLAPagedAttentionWrapper --backends trtllm-gen-native fa2 fa3 --page_size 32 --batch_size 16 --s_qo 1 --s_kv 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim_ckv 512 --head_dim_kpe 64 --random_actual_seq_len -vv --refcheck --q_dtype bfloat16 --kv_dtype bfloat16 --generate_repro_command --case_tag DeepSeek-R1 [WARNING] fa3 for routine BatchMLAPagedAttentionWrapper is not supported on compute capability 10.0. Skipping. 
@@ -164,29 +94,13 @@ [VVERBOSE] kv_indptr.shape = torch.Size([17]) [VVERBOSE] kv_indices.shape = torch.Size([258]) [VVERBOSE] actual_seq_lens_kv.shape = torch.Size([16, 1, 1, 1]) -[VVERBOSE] sm_scale = 0.041666666666666664 +[VVERBOSE] sm_scale = 0.07216878364870323 [VVERBOSE] workspace_buffer.shape = torch.Size([134217728]) -[PERF] trtllm-gen-nati:: median time 0.024 ms; std 0.000 ms; achieved tflops 91.551 TFLOPs/sec; achieved tb_per_sec 0.955 TB/sec -[PERF] fa2 :: median time 0.041 ms; std 0.000 ms; achieved tflops 54.584 TFLOPs/sec; achieved tb_per_sec 0.570 TB/sec -[INFO] args = Namespace(routine='bmm_fp8', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=256, m=1, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cublas', 'cutlass'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) +[PERF] trtllm-native :: median time 0.024 ms; std 0.002 ms; achieved tflops 91.928 TFLOPs/sec; achieved tb_per_sec 0.959 TB/sec +[PERF] fa2 :: median time 0.039 ms; std 0.000 ms; achieved tflops 57.103 TFLOPs/sec; achieved tb_per_sec 0.596 TB/sec +[INFO] args = Namespace(routine='bmm_fp8', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=64, m=4, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cublas', 'cutlass'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) [INFO] Running testBmmFp8 -[INFO] FlashInfer version: 0.3.1 -[VVERBOSE] gpu_name = 
'NVIDIA_B200' -[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 256 --m 1 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas cutlass --refcheck -vv --generate_repro_command -[VVERBOSE] input_fp8.shape = torch.Size([256, 1, 7168]) -[VVERBOSE] input_fp8.dtype = torch.float8_e4m3fn -[VVERBOSE] mat2_fp8.shape = torch.Size([256, 7168, 1024]) -[VVERBOSE] mat2_fp8.dtype = torch.float8_e4m3fn -[VVERBOSE] input_inv_s = tensor(0.0109, device='cuda:0') -[VVERBOSE] input_inv_s.dtype = torch.float32 -[VVERBOSE] mat2_inv_s = tensor(0.0135, device='cuda:0') -[VVERBOSE] mat2_inv_s.dtype = torch.float32 -[PERF] cudnn :: median time 0.286 ms; std 0.000 ms; achieved tflops 13.138 TFLOPs/sec; achieved tb_per_sec 0.026 TB/sec -[PERF] cublas :: median time 0.286 ms; std 0.000 ms; achieved tflops 13.140 TFLOPs/sec; achieved tb_per_sec 0.026 TB/sec -[PERF] cutlass :: median time 0.266 ms; std 0.000 ms; achieved tflops 14.142 TFLOPs/sec; achieved tb_per_sec 0.028 TB/sec -[INFO] args = Namespace(routine='bmm_fp8', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=64, m=4, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cublas', 'cutlass'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) -[INFO] Running testBmmFp8 -[INFO] FlashInfer version: 0.3.1 +[INFO] FlashInfer version: 0.6.2 [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine bmm_fp8 --batch_size 64 --m 4 --n 1024 --k 7168 --input_dtype fp8_e4m3 --mat2_dtype fp8_e4m3 --out_dtype bfloat16 --backends cudnn cublas 
cutlass --refcheck -vv --generate_repro_command [VVERBOSE] input_fp8.shape = torch.Size([64, 4, 7168]) @@ -197,24 +111,12 @@ [VVERBOSE] input_inv_s.dtype = torch.float32 [VVERBOSE] mat2_inv_s = tensor(0.0131, device='cuda:0') [VVERBOSE] mat2_inv_s.dtype = torch.float32 -[PERF] cudnn :: median time 0.075 ms; std 0.000 ms; achieved tflops 49.999 TFLOPs/sec; achieved tb_per_sec 0.098 TB/sec -[PERF] cublas :: median time 0.075 ms; std 0.000 ms; achieved tflops 50.135 TFLOPs/sec; achieved tb_per_sec 0.098 TB/sec -[PERF] cutlass :: median time 0.072 ms; std 0.000 ms; achieved tflops 51.981 TFLOPs/sec; achieved tb_per_sec 0.102 TB/sec -[INFO] args = Namespace(routine='gemm_fp8_nt_groupwise', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=4, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cutlass'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) -[INFO] Running testGemmFp8NtGroupwise -[INFO] FlashInfer version: 0.3.1 -[VVERBOSE] gpu_name = 'NVIDIA_B200' -[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine gemm_fp8_nt_groupwise --m 4 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command -[VVERBOSE] a_val.shape = torch.Size([4, 7168]) -[VVERBOSE] b_val.shape = torch.Size([1024, 7168]) -[VVERBOSE] a_fp8.shape = torch.Size([4, 7168]) -[VVERBOSE] b_fp8.shape = torch.Size([1024, 7168]) -[VVERBOSE] a_scale.shape = torch.Size([56, 4]) -[VVERBOSE] b_scale.shape = torch.Size([56, 8]) -[PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 2.987 TFLOPs/sec; achieved tb_per_sec 0.375 TB/sec -[INFO] args = Namespace(routine='gemm_fp8_nt_groupwise', 
no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=16, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cutlass'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) +[PERF] cudnn :: median time 0.085 ms; std 0.000 ms; achieved tflops 44.460 TFLOPs/sec; achieved tb_per_sec 0.087 TB/sec +[PERF] cublas :: median time 0.085 ms; std 0.000 ms; achieved tflops 44.400 TFLOPs/sec; achieved tb_per_sec 0.087 TB/sec +[PERF] cutlass :: median time 0.086 ms; std 0.000 ms; achieved tflops 43.593 TFLOPs/sec; achieved tb_per_sec 0.086 TB/sec +[INFO] args = Namespace(routine='gemm_fp8_nt_groupwise', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=16, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cutlass'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) [INFO] Running testGemmFp8NtGroupwise -[INFO] FlashInfer version: 0.3.1 +[INFO] FlashInfer version: 0.6.2 [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --scale_major_mode MN --backends cutlass --refcheck -vv --generate_repro_command [VVERBOSE] a_val.shape = torch.Size([16, 7168]) @@ -223,23 +125,10 @@ [VVERBOSE] b_fp8.shape = torch.Size([1024, 7168]) [VVERBOSE] a_scale.shape = torch.Size([56, 16]) [VVERBOSE] b_scale.shape = torch.Size([56, 8]) -[PERF] 
cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 11.824 TFLOPs/sec; achieved tb_per_sec 0.377 TB/sec -[INFO] args = Namespace(routine='group_gemm_fp8_nt_groupwise', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=4, n=1024, k=7168, tile_size=128, group_size=2, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) +[PERF] cutlass :: median time 0.016 ms; std 0.000 ms; achieved tflops 14.636 TFLOPs/sec; achieved tb_per_sec 0.467 TB/sec +[INFO] args = Namespace(routine='group_gemm_fp8_nt_groupwise', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=16, n=1024, k=7168, tile_size=128, group_size=2, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) [INFO] Running testGroupGemmFp8NtGroupwise -[INFO] FlashInfer version: 0.3.1 -[VVERBOSE] gpu_name = 'NVIDIA_B200' -[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine group_gemm_fp8_nt_groupwise --m 4 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command -[VVERBOSE] a_val.shape = torch.Size([8, 7168]) -[VVERBOSE] b_val.shape = torch.Size([2, 1024, 7168]) -[VVERBOSE] a_fp8.shape = torch.Size([8, 7168]) -[VVERBOSE] b_fp8.shape = torch.Size([2, 1024, 7168]) -[VVERBOSE] a_scale.shape = torch.Size([56, 8]) -[VVERBOSE] b_scale.shape = torch.Size([2, 56, 8]) -[VVERBOSE] 
m_indptr.shape = torch.Size([3]) -[PERF] cutlass :: median time 0.022 ms; std 0.000 ms; achieved tflops 5.261 TFLOPs/sec; achieved tb_per_sec 0.661 TB/sec -[INFO] args = Namespace(routine='group_gemm_fp8_nt_groupwise', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=16, n=1024, k=7168, tile_size=128, group_size=2, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn'], use_128x4_sf_layout=False, use_nvfp4=False, autotune=False) -[INFO] Running testGroupGemmFp8NtGroupwise -[INFO] FlashInfer version: 0.3.1 +[INFO] FlashInfer version: 0.6.2 [VVERBOSE] gpu_name = 'NVIDIA_B200' [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine group_gemm_fp8_nt_groupwise --m 16 --n 1024 --k 7168 --mma_sm 1 --group_size 2 --scale_major_mode MN --refcheck -vv --generate_repro_command [VVERBOSE] a_val.shape = torch.Size([32, 7168]) @@ -249,52 +138,43 @@ [VVERBOSE] a_scale.shape = torch.Size([56, 32]) [VVERBOSE] b_scale.shape = torch.Size([2, 56, 8]) [VVERBOSE] m_indptr.shape = torch.Size([3]) -[PERF] cutlass :: median time 0.023 ms; std 0.000 ms; achieved tflops 20.852 TFLOPs/sec; achieved tb_per_sec 0.665 TB/sec -[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=1, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=False) -[INFO] Running testMmFp4 -[INFO] FlashInfer version: 
0.3.1 -[VVERBOSE] gpu_name = 'NVIDIA_B200' -[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mm_fp4 --m 1 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command -[VVERBOSE] input_fp4.shape = torch.Size([1, 3584]) -[VVERBOSE] input_fp4.dtype = torch.uint8 -[VVERBOSE] mat2_fp4.shape = torch.Size([1024, 3584]) -[VVERBOSE] mat2_fp4.dtype = torch.uint8 -[PERF] cudnn :: median time 0.013 ms; std 0.000 ms; achieved tflops 1.156 TFLOPs/sec; achieved tb_per_sec 0.289 TB/sec -[PERF] cutlass :: median time 0.009 ms; std 0.000 ms; achieved tflops 1.593 TFLOPs/sec; achieved tb_per_sec 0.399 TB/sec -[PERF] trtllm :: median time 0.011 ms; std 0.000 ms; achieved tflops 1.352 TFLOPs/sec; achieved tb_per_sec 0.339 TB/sec -[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=4, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=False) +[PERF] cutlass :: median time 0.023 ms; std 0.000 ms; achieved tflops 20.574 TFLOPs/sec; achieved tb_per_sec 0.656 TB/sec +[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=512, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 
'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=False)
 [INFO] Running testMmFp4
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
-[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
-[VVERBOSE] input_fp4.shape = torch.Size([4, 3584])
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command
+[VVERBOSE] input_fp4.shape = torch.Size([512, 3584])
 [VVERBOSE] input_fp4.dtype = torch.uint8
 [VVERBOSE] mat2_fp4.shape = torch.Size([1024, 3584])
 [VVERBOSE] mat2_fp4.dtype = torch.uint8
-[PERF] cudnn :: median time 0.013 ms; std 0.000 ms; achieved tflops 4.625 TFLOPs/sec; achieved tb_per_sec 0.291 TB/sec
-[PERF] cutlass :: median time 0.009 ms; std 0.000 ms; achieved tflops 6.372 TFLOPs/sec; achieved tb_per_sec 0.401 TB/sec
-[PERF] trtllm :: median time 0.011 ms; std 0.000 ms; achieved tflops 5.310 TFLOPs/sec; achieved tb_per_sec 0.334 TB/sec
-[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=4, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=True)
+[PERF] cudnn :: median time 0.009 ms; std 0.000 ms; achieved tflops 821.262 TFLOPs/sec; achieved tb_per_sec 0.716 TB/sec
+[PERF] cutlass :: median time 0.010 ms; std 0.000 ms; achieved tflops 738.620 TFLOPs/sec; achieved tb_per_sec 0.644 TB/sec
+[PERF] trtllm :: median time 0.012 ms; std 0.000 ms; achieved tflops 634.814 TFLOPs/sec; achieved tb_per_sec 0.554 TB/sec
+[INFO] args = Namespace(routine='mm_fp4', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=True, repro_command='', batch_size=1, m=512, n=1024, k=7168, tile_size=128, group_size=1, scale_major_mode='MN', input_dtype='fp8_e4m3', mat2_dtype='fp8_e4m3', out_dtype='bfloat16', mma_sm=1, backends=['cudnn', 'cutlass', 'trtllm'], use_128x4_sf_layout=True, use_nvfp4=True, autotune=True)
 [INFO] Running testMmFp4
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
-[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
-[INFO] cudnn backend does not support autotune
-[VVERBOSE] input_fp4.shape = torch.Size([4, 3584])
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mm_fp4 --m 512 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn cutlass trtllm --use_128x4_sf_layout --use_nvfp4 --autotune --refcheck -vv --generate_repro_command
+[VVERBOSE] input_fp4.shape = torch.Size([512, 3584])
 [VVERBOSE] input_fp4.dtype = torch.uint8
 [VVERBOSE] mat2_fp4.shape = torch.Size([1024, 3584])
 [VVERBOSE] mat2_fp4.dtype = torch.uint8
 [INFO] Autotune warmup for mm_fp4: 5 iters
-2025-09-23 00:32:18,077 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
-2025-09-23 00:32:18,224 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
+2026-02-03 15:06:53,361 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
+2026-02-03 15:06:58,625 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
 [INFO] Autotune warmup for mm_fp4: 5 iters
-2025-09-23 00:32:18,225 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
-2025-09-23 00:32:18,247 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
-[PERF] cutlass_autotun:: median time 0.009 ms; std 0.000 ms; achieved tflops 6.372 TFLOPs/sec; achieved tb_per_sec 0.401 TB/sec
-[PERF] trtllm_autotune:: median time 0.011 ms; std 0.000 ms; achieved tflops 5.410 TFLOPs/sec; achieved tb_per_sec 0.340 TB/sec
-[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
+2026-02-03 15:06:58,625 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
+2026-02-03 15:06:59,031 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
+[INFO] Autotune warmup for mm_fp4: 5 iters
+2026-02-03 15:06:59,031 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
+2026-02-03 15:06:59,105 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
+[PERF] cudnn_autotune :: median time 0.009 ms; std 0.000 ms; achieved tflops 822.746 TFLOPs/sec; achieved tb_per_sec 0.717 TB/sec
+[PERF] cutlass_autotun:: median time 0.010 ms; std 0.000 ms; achieved tflops 756.461 TFLOPs/sec; achieved tb_per_sec 0.660 TB/sec
+[PERF] trtllm_autotune:: median time 0.014 ms; std 0.000 ms; achieved tflops 535.037 TFLOPs/sec; achieved tb_per_sec 0.467 TB/sec
+[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, num_experts=256, top_k=8, input_dtype='bfloat16', intermediate_size=1024, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testTrtllmFp4BlockScaleMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
 [INFO] Configuration: tokens=1024, hidden=1024, intermediate=1024, experts=256, top_k=8
@@ -302,10 +182,11 @@
 [VVERBOSE] hidden_states.shape = torch.Size([1024, 1024])
 [VVERBOSE] gemm1_weights_fp4.shape = torch.Size([256, 2048, 512])
 [VVERBOSE] gemm2_weights_fp4.shape = torch.Size([256, 1024, 512])
-[PERF] trtllm :: median time 0.224 ms; std 0.000 ms; achieved tflops 230.555 TFLOPs/sec; achieved tb_per_sec 1.818 TB/sec
-[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=8, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize_naive', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=4, gated_act_type=0)
+[VVERBOSE] num_active_experts = 256
+[PERF] trtllm :: median time 0.131 ms; std 0.001 ms; achieved tflops 392.306 TFLOPs/sec; achieved tb_per_sec 3.476 TB/sec
+[INFO] args = Namespace(routine='trtllm_fp4_block_scale_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, num_experts=128, top_k=8, input_dtype='bfloat16', intermediate_size=1024, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize_naive', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=4, gated_act_type=0)
 [INFO] Running testTrtllmFp4BlockScaleMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 8 --routing_method renormalize_naive --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
 [INFO] Configuration: tokens=1024, hidden=1024, intermediate=1024, experts=128, top_k=8
@@ -313,10 +194,11 @@
 [VVERBOSE] hidden_states.shape = torch.Size([1024, 1024])
 [VVERBOSE] gemm1_weights_fp4.shape = torch.Size([128, 2048, 512])
 [VVERBOSE] gemm2_weights_fp4.shape = torch.Size([128, 1024, 512])
-[PERF] trtllm :: median time 0.226 ms; std 0.000 ms; achieved tflops 227.846 TFLOPs/sec; achieved tb_per_sec 0.903 TB/sec
-[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=256, top_k=8, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
+[VVERBOSE] num_active_experts = 128
+[PERF] trtllm :: median time 0.097 ms; std 0.001 ms; achieved tflops 531.906 TFLOPs/sec; achieved tb_per_sec 2.368 TB/sec
+[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, num_experts=256, top_k=8, input_dtype='bfloat16', intermediate_size=1024, n_group=8, topk_group=4, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=True, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testTrtllmFp8BlockScaleMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 256 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --routing_method deepseek_v3 --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
 [INFO] Configuration: tokens=1024, hidden=1024, intermediate=1024, experts=256, top_k=8
@@ -324,10 +206,11 @@
 [VVERBOSE] hidden_states.shape = torch.Size([1024, 1024])
 [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([256, 2048, 1024])
 [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([256, 1024, 1024])
-[PERF] trtllm :: median time 0.557 ms; std 0.000 ms; achieved tflops 92.607 TFLOPs/sec; achieved tb_per_sec 1.455 TB/sec
-[INFO] args = Namespace(routine='trtllm_fp8_per_tensor_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='llama4', use_shuffled_weight=False, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=True, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=3, gated_act_type=0)
+[VVERBOSE] num_active_experts = 256
+[PERF] trtllm :: median time 0.303 ms; std 0.002 ms; achieved tflops 170.084 TFLOPs/sec; achieved tb_per_sec 2.671 TB/sec
+[INFO] args = Namespace(routine='trtllm_fp8_per_tensor_scale_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, num_experts=128, top_k=1, input_dtype='bfloat16', intermediate_size=1024, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='llama4', use_shuffled_weight=False, weight_layout=0, use_routing_bias=True, use_routing_scales_on_input=True, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=3, gated_act_type=0)
 [INFO] Running testTrtllmFp8PerTensorScaleMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine trtllm_fp8_per_tensor_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routed_scaling_factor 2.5 --use_routing_bias --routing_method llama4 --use_routing_scales_on_input -vv --generate_repro_command --case_tag trtllm_moe_sample
 [INFO] Configuration: tokens=1024, hidden=1024, intermediate=1024, experts=128, top_k=1
@@ -335,10 +218,11 @@
 [VVERBOSE] hidden_states.shape = torch.Size([1024, 1024])
 [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([128, 2048, 1024])
 [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([128, 1024, 1024])
-[PERF] trtllm :: median time 0.123 ms; std 0.000 ms; achieved tflops 52.340 TFLOPs/sec; achieved tb_per_sec 3.299 TB/sec
-[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, intermediate_size=1024, num_experts=128, top_k=1, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='bfloat16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=1, gated_act_type=0)
+[VVERBOSE] num_active_experts = 128
+[PERF] trtllm :: median time 0.092 ms; std 0.000 ms; achieved tflops 70.027 TFLOPs/sec; achieved tb_per_sec 4.414 TB/sec
+[INFO] args = Namespace(routine='trtllm_fp8_block_scale_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='trtllm_moe_sample', generate_repro_command=True, repro_command='', num_tokens=1024, hidden_size=1024, num_experts=128, top_k=1, input_dtype='bfloat16', intermediate_size=1024, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='renormalize', use_shuffled_weight=True, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=1, gated_act_type=0)
 [INFO] Running testTrtllmFp8BlockScaleMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routing_method renormalize --use_shuffled_weight -vv --generate_repro_command --case_tag trtllm_moe_sample
 [INFO] Configuration: tokens=1024, hidden=1024, intermediate=1024, experts=128, top_k=1
@@ -346,49 +230,1002 @@
 [VVERBOSE] hidden_states.shape = torch.Size([1024, 1024])
 [VVERBOSE] gemm1_weights_fp8.shape = torch.Size([128, 2048, 1024])
 [VVERBOSE] gemm2_weights_fp8.shape = torch.Size([128, 1024, 1024])
-[PERF] trtllm :: median time 0.109 ms; std 0.000 ms; achieved tflops 59.297 TFLOPs/sec; achieved tb_per_sec 3.740 TB/sec
-[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_base', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
+[ERROR] Error running test: --routine trtllm_fp8_block_scale_moe --num_tokens 1024 --hidden_size 1024 --intermediate_size 1024 --num_experts 128 --top_k 1 --routing_method renormalize --use_shuffled_weight -vv --generate_repro_command --case_tag "trtllm_moe_sample"
+[ERROR] Error: Check failed: routing_logits.value().dtype() == dl_bfloat16 (float32 vs. bfloat16) : routing_logits must be bfloat16.
+[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_base', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, num_experts=2, top_k=2, input_dtype='float16', intermediate_size=128, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testCutlassFusedMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant base --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_base
 [VVERBOSE] x.shape = torch.Size([32, 128])
 [VVERBOSE] w31_weight.shape = torch.Size([2, 256, 128])
 [VVERBOSE] w2_weight.shape = torch.Size([2, 128, 128])
-[PERF] cutlass :: median time 0.026 ms; std 0.000 ms; achieved tflops 0.240 TFLOPs/sec; achieved tb_per_sec 0.008 TB/sec
-[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_fp8_scale', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='fp8', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
+[VVERBOSE] num_active_experts = 2
+[PERF] cutlass :: median time 0.028 ms; std 0.000 ms; achieved tflops 0.226 TFLOPs/sec; achieved tb_per_sec 0.008 TB/sec
+[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_fp8_scale', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, num_experts=2, top_k=2, input_dtype='float16', intermediate_size=128, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='fp8', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testCutlassFusedMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant fp8 --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_fp8_scale
 [VVERBOSE] x.shape = torch.Size([32, 128])
 [VVERBOSE] w31_weight.shape = torch.Size([2, 256, 128])
 [VVERBOSE] w2_weight.shape = torch.Size([2, 128, 128])
-[PERF] cutlass :: median time 0.026 ms; std 0.000 ms; achieved tflops 0.244 TFLOPs/sec; achieved tb_per_sec 0.004 TB/sec
-[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_nvfp4_weights', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='nvfp4', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
+[VVERBOSE] num_active_experts = 2
+[PERF] cutlass :: median time 0.028 ms; std 0.000 ms; achieved tflops 0.221 TFLOPs/sec; achieved tb_per_sec 0.004 TB/sec
+[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_nvfp4_weights', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, num_experts=2, top_k=2, input_dtype='float16', intermediate_size=128, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='nvfp4', quantized_input=False, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testCutlassFusedMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_weights
 [VVERBOSE] x.shape = torch.Size([32, 128])
 [VVERBOSE] w31_weight.shape = torch.Size([2, 256, 128])
 [VVERBOSE] w2_weight.shape = torch.Size([2, 128, 128])
+[VVERBOSE] num_active_experts = 2
 [PERF] cutlass :: median time 0.030 ms; std 0.000 ms; achieved tflops 0.210 TFLOPs/sec; achieved tb_per_sec 0.002 TB/sec
-[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_nvfp4_weights_quantized', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=2, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='nvfp4', quantized_input=True, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
+[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_nvfp4_weights_quantized', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, num_experts=2, top_k=2, input_dtype='float16', intermediate_size=128, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='nvfp4', quantized_input=True, tp_size=1, tp_rank=0, ep_size=1, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testCutlassFusedMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_weights_quantized
 [VVERBOSE] x.shape = torch.Size([32, 128])
 [VVERBOSE] w31_weight.shape = torch.Size([2, 256, 128])
 [VVERBOSE] w2_weight.shape = torch.Size([2, 128, 128])
-[PERF] cutlass :: median time 0.029 ms; std 0.000 ms; achieved tflops 0.213 TFLOPs/sec; achieved tb_per_sec 0.002 TB/sec
-[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_nvfp4_ep_tp', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, intermediate_size=128, num_experts=8, top_k=2, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, tile_tokens_dim=8, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, input_dtype='float16', weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=2, tp_rank=0, ep_size=4, ep_rank=0, routing_method_type=2, gated_act_type=0)
+[VVERBOSE] num_active_experts = 2
+[PERF] cutlass :: median time 0.028 ms; std 0.000 ms; achieved tflops 0.222 TFLOPs/sec; achieved tb_per_sec 0.002 TB/sec
+[INFO] args = Namespace(routine='cutlass_fused_moe', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='cutlass_moe_nvfp4_ep_tp', generate_repro_command=True, repro_command='', num_tokens=32, hidden_size=128, num_experts=8, top_k=2, input_dtype='float16', intermediate_size=128, n_group=None, topk_group=None, routed_scaling_factor=2.5, local_expert_offset=0, local_num_experts=None, routing_method='deepseek_v3', use_shuffled_weight=False, weight_layout=0, use_routing_bias=False, use_routing_scales_on_input=False, weight_dtype='bfloat16', gated_act='swiglu', autotune=False, cutlass_variant='base', quantized_input=False, tp_size=2, tp_rank=0, ep_size=4, ep_rank=0, routing_method_type=2, gated_act_type=0)
 [INFO] Running testCutlassFusedMoe
-[INFO] FlashInfer version: 0.3.1
+[INFO] FlashInfer version: 0.6.2
 [VVERBOSE] gpu_name = 'NVIDIA_B200'
 [INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 --cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag cutlass_moe_nvfp4_ep_tp
 [VVERBOSE] x.shape = torch.Size([32, 128])
 [VVERBOSE] w31_weight.shape = torch.Size([8, 256, 128])
 [VVERBOSE] w2_weight.shape = torch.Size([8, 128, 128])
-[PERF] cutlass :: median time 0.025 ms; std 0.000 ms; achieved tflops 0.250 TFLOPs/sec; achieved tb_per_sec 0.032 TB/sec
+[VVERBOSE] num_active_experts = 8
+[PERF] cutlass :: median time 0.026 ms; std 0.000 ms; achieved tflops 0.239 TFLOPs/sec; achieved tb_per_sec 0.030 TB/sec
+[INFO] args = Namespace(routine='rmsnorm', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_llama_hidden', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnorm
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_llama_hidden
+[VVERBOSE] input_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] weight.shape = torch.Size([4096])
+[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.225 TFLOPs/sec; achieved tb_per_sec 0.183 TB/sec
+[INFO] args = Namespace(routine='rmsnorm', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_large_hidden', generate_repro_command=True, repro_command='', batch_size=64, hidden_size=8192, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnorm
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_large_hidden
+[VVERBOSE] input_tensor.shape = torch.Size([64, 8192])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] weight.shape = torch.Size([8192])
+[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.766 TFLOPs/sec; achieved tb_per_sec 0.617 TB/sec
+[INFO] args = Namespace(routine='rmsnorm', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_3d_gqa', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=128, num_heads=32, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnorm
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_3d_gqa
+[VVERBOSE] input_tensor.shape = torch.Size([32, 32, 128])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] weight.shape = torch.Size([128])
+[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.256 TFLOPs/sec; achieved tb_per_sec 0.205 TB/sec
+[INFO] args = Namespace(routine='rmsnorm', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_3d_mha', generate_repro_command=True, repro_command='', batch_size=16, hidden_size=128, num_heads=64, input_dtype='float16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnorm
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 16 --num_heads 64 --hidden_size 128 --input_dtype float16 --refcheck -vv --generate_repro_command --case_tag rmsnorm_3d_mha
+[VVERBOSE] input_tensor.shape = torch.Size([16, 64, 128])
+[VVERBOSE] input_tensor.dtype = torch.float16
+[VVERBOSE] weight.shape = torch.Size([128])
+[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.256 TFLOPs/sec; achieved tb_per_sec 0.205 TB/sec
+[INFO] args = Namespace(routine='rmsnorm', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_pdl', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=True, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnorm
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --enable_pdl --refcheck -vv --generate_repro_command --case_tag rmsnorm_pdl
+[VVERBOSE] input_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] weight.shape = torch.Size([4096])
+[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.225 TFLOPs/sec; achieved tb_per_sec 0.183 TB/sec
+[INFO] args = Namespace(routine='rmsnorm_quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_quant_fp8_e4m3', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnormQuant
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_fp8_e4m3
+[VVERBOSE] input_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] weight.shape = torch.Size([4096])
+[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn
+[VVERBOSE] scale = 1.0
+[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.223 TFLOPs/sec; achieved tb_per_sec 0.136 TB/sec
+[INFO] args = Namespace(routine='rmsnorm_quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_quant_large', generate_repro_command=True, repro_command='', batch_size=64, hidden_size=8192, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False)
+[INFO] Running testRmsnormQuant
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype
fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_large +[VVERBOSE] input_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([8192]) +[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn +[VVERBOSE] scale = 1.0 +[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.755 TFLOPs/sec; achieved tb_per_sec 0.458 TB/sec +[INFO] args = Namespace(routine='rmsnorm_quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_quant_fp8_e5m2', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='float16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e5m2', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testRmsnormQuant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype float16 --out_dtype fp8_e5m2 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_fp8_e5m2 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.float16 +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_tensor.dtype = torch.float8_e5m2 +[VVERBOSE] scale = 1.0 +[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.225 TFLOPs/sec; achieved tb_per_sec 0.138 TB/sec +[INFO] args = Namespace(routine='fused_add_rmsnorm_quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='fused_add_rmsnorm_quant_fp8_e4m3', 
generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testFusedAddRmsnormQuant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag fused_add_rmsnorm_quant_fp8_e4m3 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn +[VVERBOSE] scale = 1.0 +[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.243 TFLOPs/sec; achieved tb_per_sec 0.286 TB/sec +[INFO] args = Namespace(routine='fused_add_rmsnorm_quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='fused_add_rmsnorm_quant_large', generate_repro_command=True, repro_command='', batch_size=64, hidden_size=8192, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testFusedAddRmsnormQuant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine fused_add_rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 
1.0 --refcheck -vv --generate_repro_command --case_tag fused_add_rmsnorm_quant_large +[VVERBOSE] input_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] weight.shape = torch.Size([8192]) +[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn +[VVERBOSE] scale = 1.0 +[PERF] cuda :: median time 0.004 ms; std 0.000 ms; achieved tflops 0.799 TFLOPs/sec; achieved tb_per_sec 0.937 TB/sec +[INFO] args = Namespace(routine='fused_add_rmsnorm_quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='fused_add_rmsnorm_quant_pdl', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=True, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testFusedAddRmsnormQuant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --enable_pdl --refcheck -vv --generate_repro_command --case_tag fused_add_rmsnorm_quant_pdl +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn +[VVERBOSE] scale = 1.0 +[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.246 TFLOPs/sec; achieved tb_per_sec 0.289 TB/sec +[INFO] args = Namespace(routine='rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, 
refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_fp4quant_nvfp4', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[PERF] cute-dsl :: median time 0.006 ms; std 0.000 ms; achieved tflops 0.109 TFLOPs/sec; achieved tb_per_sec 0.057 TB/sec +[INFO] args = Namespace(routine='rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_fp4quant_nvfp4_large', generate_repro_command=True, repro_command='', batch_size=64, hidden_size=8192, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant 
--batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4_large +[VVERBOSE] input_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([8192]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[PERF] cute-dsl :: median time 0.007 ms; std 0.000 ms; achieved tflops 0.394 TFLOPs/sec; achieved tb_per_sec 0.204 TB/sec +[INFO] args = Namespace(routine='rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_fp4quant_nvfp4_global', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=True, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4_global +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = True +[VVERBOSE] is_sf_swizzled_layout = False +[PERF] cute-dsl :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.145 TFLOPs/sec; achieved tb_per_sec 0.076 TB/sec +[INFO] args = Namespace(routine='rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, 
use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_fp4quant_nvfp4_swizzled', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=True, output_both_sf_layouts=False) +[INFO] Running testRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag rmsnorm_fp4quant_nvfp4_swizzled +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = True +[PERF] cute-dsl :: median time 0.006 ms; std 0.000 ms; achieved tflops 0.104 TFLOPs/sec; achieved tb_per_sec 0.055 TB/sec +[INFO] args = Namespace(routine='rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_fp4quant_mxfp4', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='mxfp4', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 
flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_mxfp4 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'mxfp4' +[VVERBOSE] block_size = 32 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[PERF] cute-dsl :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.123 TFLOPs/sec; achieved tb_per_sec 0.064 TB/sec +[INFO] args = Namespace(routine='rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rmsnorm_fp4quant_3d', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=128, num_heads=32, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag rmsnorm_fp4quant_3d +[VVERBOSE] input_tensor.shape = torch.Size([32, 32, 128]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([128]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[PERF] cute-dsl :: median time 0.004 ms; std 0.000 ms; achieved tflops 0.163 TFLOPs/sec; achieved tb_per_sec 0.083 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', 
no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_nvfp4', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = False +[PERF] cute-dsl :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.150 TFLOPs/sec; achieved tb_per_sec 0.165 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_nvfp4_large', generate_repro_command=True, repro_command='', batch_size=64, hidden_size=8192, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testAddRmsnormFp4quant +[INFO] 
FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4_large +[VVERBOSE] input_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] weight.shape = torch.Size([8192]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = False +[PERF] cute-dsl :: median time 0.006 ms; std 0.000 ms; achieved tflops 0.543 TFLOPs/sec; achieved tb_per_sec 0.597 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_nvfp4_global', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=True, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4_global +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) 
+[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = True +[VVERBOSE] is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = False +[PERF] cute-dsl :: median time 0.004 ms; std 0.000 ms; achieved tflops 0.217 TFLOPs/sec; achieved tb_per_sec 0.240 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_nvfp4_swizzled', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=True, output_both_sf_layouts=False) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_nvfp4_swizzled +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = True +[VVERBOSE] output_both_sf_layouts = False +[PERF] cute-dsl :: median time 0.006 ms; std 0.000 ms; achieved tflops 0.142 TFLOPs/sec; achieved tb_per_sec 0.157 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, 
num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_mxfp4', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='mxfp4', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_mxfp4 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'mxfp4' +[VVERBOSE] block_size = 32 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = False +[PERF] cute-dsl :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.147 TFLOPs/sec; achieved tb_per_sec 0.162 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_3d', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=128, num_heads=32, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=False) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 
flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_3d +[VVERBOSE] input_tensor.shape = torch.Size([32, 32, 128]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 32, 128]) +[VVERBOSE] weight.shape = torch.Size([128]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = False +[PERF] cute-dsl :: median time 0.004 ms; std 0.000 ms; achieved tflops 0.198 TFLOPs/sec; achieved tb_per_sec 0.217 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_both_sf', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=True) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_both_sf +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] 
is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = True +[PERF] cute-dsl :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.150 TFLOPs/sec; achieved tb_per_sec 0.167 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_both_sf_large', generate_repro_command=True, repro_command='', batch_size=64, hidden_size=8192, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=True) +[INFO] Running testAddRmsnormFp4quant +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_both_sf_large +[VVERBOSE] input_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] residual_tensor.shape = torch.Size([64, 8192]) +[VVERBOSE] weight.shape = torch.Size([8192]) +[VVERBOSE] out_dtype = 'nvfp4' +[VVERBOSE] block_size = 16 +[VVERBOSE] use_global_scale = False +[VVERBOSE] is_sf_swizzled_layout = False +[VVERBOSE] output_both_sf_layouts = True +[PERF] cute-dsl :: median time 0.006 ms; std 0.000 ms; achieved tflops 0.559 TFLOPs/sec; achieved tb_per_sec 0.620 TB/sec +[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_both_sf_global', generate_repro_command=True, 
repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='fp8_e4m3', backends=['cuda'], use_global_scale=True, is_sf_swizzled_layout=False, output_both_sf_layouts=True)
+[INFO] Running testAddRmsnormFp4quant
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_both_sf_global
+[VVERBOSE] input_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] weight.shape = torch.Size([4096])
+[VVERBOSE] out_dtype = 'nvfp4'
+[VVERBOSE] block_size = 16
+[VVERBOSE] use_global_scale = True
+[VVERBOSE] is_sf_swizzled_layout = False
+[VVERBOSE] output_both_sf_layouts = True
+[PERF] cute-dsl :: median time 0.004 ms; std 0.000 ms; achieved tflops 0.207 TFLOPs/sec; achieved tb_per_sec 0.231 TB/sec
+[INFO] args = Namespace(routine='add_rmsnorm_fp4quant', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='add_rmsnorm_fp4quant_mxfp4_both_sf', generate_repro_command=True, repro_command='', batch_size=32, hidden_size=4096, num_heads=None, input_dtype='bfloat16', eps=1e-06, enable_pdl=False, scale=1.0, out_dtype='mxfp4', backends=['cuda'], use_global_scale=False, is_sf_swizzled_layout=False, output_both_sf_layouts=True)
+[INFO] Running testAddRmsnormFp4quant
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag add_rmsnorm_fp4quant_mxfp4_both_sf
+[VVERBOSE] input_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] residual_tensor.shape = torch.Size([32, 4096])
+[VVERBOSE] weight.shape = torch.Size([4096])
+[VVERBOSE] out_dtype = 'mxfp4'
+[VVERBOSE] block_size = 32
+[VVERBOSE] use_global_scale = False
+[VVERBOSE] is_sf_swizzled_layout = False
+[VVERBOSE] output_both_sf_layouts = True
+[PERF] cute-dsl :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.146 TFLOPs/sec; achieved tb_per_sec 0.162 TB/sec
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize_basic', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp8Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp8_quantize_basic
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] is_sf_swizzled_layout = True
+[VVERBOSE] alignment = 32
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.007 ms; std 0.000 ms; achieved tflops 1.855 TFLOPs/sec; achieved tb_per_sec 1.874 TB/sec
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize_large', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp8Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp8_quantize_large
+[VVERBOSE] input_tensor.shape = torch.Size([2048, 8192])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] is_sf_swizzled_layout = True
+[VVERBOSE] alignment = 32
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.015 ms; std 0.000 ms; achieved tflops 3.393 TFLOPs/sec; achieved tb_per_sec 3.429 TB/sec
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize_fp16', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='float16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp8Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype float16 -vv --generate_repro_command --case_tag mxfp8_quantize_fp16
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.float16
+[VVERBOSE] is_sf_swizzled_layout = True
+[VVERBOSE] alignment = 32
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.007 ms; std 0.000 ms; achieved tflops 1.872 TFLOPs/sec; achieved tb_per_sec 1.892 TB/sec
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize_no_swizzle', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=False, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp8Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --no_sf_swizzled_layout -vv --generate_repro_command --case_tag mxfp8_quantize_no_swizzle
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] is_sf_swizzled_layout = False
+[VVERBOSE] alignment = 32
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.007 ms; std 0.000 ms; achieved tflops 1.918 TFLOPs/sec; achieved tb_per_sec 1.938 TB/sec
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize_pdl', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=True, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp8Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --enable_pdl -vv --generate_repro_command --case_tag mxfp8_quantize_pdl
+[VVERBOSE] input_tensor.shape = torch.Size([2048, 8192])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] is_sf_swizzled_layout = True
+[VVERBOSE] alignment = 32
+[VVERBOSE] enable_pdl = True
+[PERF] cuda :: median time 0.015 ms; std 0.000 ms; achieved tflops 3.386 TFLOPs/sec; achieved tb_per_sec 3.421 TB/sec
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize_refcheck', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp8Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag mxfp8_quantize_refcheck
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] is_sf_swizzled_layout = True
+[VVERBOSE] alignment = 32
+[VVERBOSE] enable_pdl = False
+[VVERBOSE] Backend cuda: x_q.shape = torch.Size([1024, 4096]), x_q.dtype = torch.float8_e4m3fn, sf.shape = torch.Size([131072]), sf.dtype = torch.uint8
+[VVERBOSE] Round-trip error: 0/4194304 (0.00%) elements differ
+[PERF] cuda :: median time 0.007 ms; std 0.000 ms; achieved tflops 1.846 TFLOPs/sec; achieved tb_per_sec 1.865 TB/sec
+[INFO] args = Namespace(routine='mxfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp4_quantize_basic', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp4_quantize_basic
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[PERF] cuda :: median time 0.048 ms; std 0.001 ms; achieved tflops 0.262 TFLOPs/sec; achieved tb_per_sec 0.221 TB/sec
+[INFO] args = Namespace(routine='mxfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp4_quantize_large', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag mxfp4_quantize_large
+[VVERBOSE] input_tensor.shape = torch.Size([2048, 8192])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[PERF] cuda :: median time 0.123 ms; std 0.001 ms; achieved tflops 0.408 TFLOPs/sec; achieved tb_per_sec 0.344 TB/sec
+[INFO] args = Namespace(routine='mxfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp4_quantize_refcheck', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testMxfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag mxfp4_quantize_refcheck
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] Backend cuda: x_q.shape = torch.Size([1024, 2048]), x_q.dtype = torch.uint8, sf.shape = torch.Size([1024, 128]), sf.dtype = torch.uint8
+[VVERBOSE] Round-trip error: 0/4194304 (0.00%) elements differ
+[PERF] cuda :: median time 0.049 ms; std 0.000 ms; achieved tflops 0.259 TFLOPs/sec; achieved tb_per_sec 0.218 TB/sec
+[INFO] args = Namespace(routine='nvfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_quantize_128x4', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag nvfp4_quantize_128x4
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_layout_str = '128x4'
+[VVERBOSE] do_shuffle = False
+[VVERBOSE] sf_vec_size = 16
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.008 ms; std 0.000 ms; achieved tflops 1.576 TFLOPs/sec; achieved tb_per_sec 1.346 TB/sec
+[INFO] args = Namespace(routine='nvfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_quantize_128x4_large', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag nvfp4_quantize_128x4_large
+[VVERBOSE] input_tensor.shape = torch.Size([2048, 8192])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_layout_str = '128x4'
+[VVERBOSE] do_shuffle = False
+[VVERBOSE] sf_vec_size = 16
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.014 ms; std 0.000 ms; achieved tflops 3.645 TFLOPs/sec; achieved tb_per_sec 3.113 TB/sec
+[INFO] args = Namespace(routine='nvfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_quantize_8x4', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='8x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 8x4 -vv --generate_repro_command --case_tag nvfp4_quantize_8x4
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_layout_str = '8x4'
+[VVERBOSE] do_shuffle = False
+[VVERBOSE] sf_vec_size = 16
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 0.008 ms; std 0.000 ms; achieved tflops 1.573 TFLOPs/sec; achieved tb_per_sec 1.343 TB/sec
+[INFO] args = Namespace(routine='nvfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_quantize_shuffle', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=True, sf_vec_size=16)
+[INFO] Running testNvfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --do_shuffle -vv --generate_repro_command --case_tag nvfp4_quantize_shuffle
+[WARNING] do_shuffle=True is not CUDA graph compatible. Disabling CUDA graph.
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_layout_str = '128x4'
+[VVERBOSE] do_shuffle = True
+[VVERBOSE] sf_vec_size = 16
+[VVERBOSE] enable_pdl = False
+[PERF] cuda :: median time 3.709 ms; std 0.030 ms; achieved tflops 0.003 TFLOPs/sec; achieved tb_per_sec 0.003 TB/sec
+[INFO] args = Namespace(routine='nvfp4_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_quantize_pdl', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=True, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4Quantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --enable_pdl -vv --generate_repro_command --case_tag nvfp4_quantize_pdl
+[VVERBOSE] input_tensor.shape = torch.Size([1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_layout_str = '128x4'
+[VVERBOSE] do_shuffle = False
+[VVERBOSE] sf_vec_size = 16
+[VVERBOSE] enable_pdl = True
+[PERF] cuda :: median time 0.008 ms; std 0.000 ms; achieved tflops 1.567 TFLOPs/sec; achieved tb_per_sec 1.338 TB/sec
+[INFO] args = Namespace(routine='nvfp4_batched_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_batched_basic', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=4, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4BatchedQuantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag nvfp4_batched_basic
+[VVERBOSE] input_tensor.shape = torch.Size([4, 1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_vec_size = 16
+[PERF] cuda :: median time 0.022 ms; std 0.000 ms; achieved tflops 2.300 TFLOPs/sec; achieved tb_per_sec 1.964 TB/sec
+[INFO] args = Namespace(routine='nvfp4_batched_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_batched_large', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=8, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4BatchedQuantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_batched_quantize --batch_size 8 --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag nvfp4_batched_large
+[VVERBOSE] input_tensor.shape = torch.Size([8, 2048, 8192])
+[VVERBOSE] input_tensor.dtype = torch.bfloat16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_vec_size = 16
+[PERF] cuda :: median time 0.079 ms; std 0.001 ms; achieved tflops 5.110 TFLOPs/sec; achieved tb_per_sec 4.365 TB/sec
+[INFO] args = Namespace(routine='nvfp4_batched_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='nvfp4_batched_fp16', generate_repro_command=True, repro_command='', m=1024, k=4096, input_dtype='float16', is_sf_swizzled_layout=True, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=4, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16)
+[INFO] Running testNvfp4BatchedQuantize
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype float16 --global_scale 1.0 -vv --generate_repro_command --case_tag nvfp4_batched_fp16
+[VVERBOSE] input_tensor.shape = torch.Size([4, 1024, 4096])
+[VVERBOSE] input_tensor.dtype = torch.float16
+[VVERBOSE] global_scale = 1.0
+[VVERBOSE] sf_vec_size = 16
+[PERF] cuda :: median time 0.020 ms; std 0.000 ms; achieved tflops 2.537 TFLOPs/sec; achieved tb_per_sec 2.167 TB/sec
+[INFO] args = Namespace(routine='softmax', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='softmax_llama', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testSoftmax
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine softmax --batch_size 32 --vocab_size 32000 --temperature 1.0 --input_dtype float32 -vv --generate_repro_command --case_tag softmax_llama
+[VVERBOSE] logits.shape = torch.Size([32, 32000])
+[VVERBOSE] logits.dtype = torch.float32
+[PERF] cuda :: median time 0.012 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.679 TB/sec
+[INFO] args = Namespace(routine='softmax', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='softmax_llama3_temp', generate_repro_command=True, repro_command='', batch_size=64, vocab_size=128256, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=0.8, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=64, backends=['cuda'])
+[INFO] Running testSoftmax
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine softmax --batch_size 64 --vocab_size 128256 --temperature 0.8 --input_dtype float32 -vv --generate_repro_command --case_tag softmax_llama3_temp
+[VVERBOSE] logits.shape = torch.Size([64, 128256])
+[VVERBOSE] logits.dtype = torch.float32
+[PERF] cuda :: median time 0.036 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 1.847 TB/sec
+[INFO] args = Namespace(routine='sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='sampling_from_probs_llama', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine sampling_from_probs --batch_size 32 --vocab_size 32000 -vv --generate_repro_command --case_tag sampling_from_probs_llama
+[VVERBOSE] probs.shape = torch.Size([32, 32000])
+[VVERBOSE] probs.dtype = torch.float32
+[PERF] cuda :: median time 0.014 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.303 TB/sec
+[INFO] args = Namespace(routine='sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='sampling_from_probs_llama3', generate_repro_command=True, repro_command='', batch_size=64, vocab_size=128256, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=64, backends=['cuda'])
+[INFO] Running testSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine sampling_from_probs --batch_size 64 --vocab_size 128256 -vv --generate_repro_command --case_tag sampling_from_probs_llama3
+[VVERBOSE] probs.shape = torch.Size([64, 128256])
+[VVERBOSE] probs.dtype = torch.float32
+[PERF] cuda :: median time 0.043 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.768 TB/sec
+[INFO] args = Namespace(routine='sampling_from_logits', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='sampling_from_logits_llama', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testSamplingFromLogits
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine sampling_from_logits --batch_size 32 --vocab_size 32000 --input_dtype float32 -vv --generate_repro_command --case_tag sampling_from_logits_llama
+[VVERBOSE] logits.shape = torch.Size([32, 32000])
+[VVERBOSE] logits.dtype = torch.float32
+[PERF] cuda :: median time 0.016 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.253 TB/sec
+[INFO] args = Namespace(routine='sampling_from_logits', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='sampling_from_logits_llama3', generate_repro_command=True, repro_command='', batch_size=64, vocab_size=128256, input_dtype='bfloat16', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=64, backends=['cuda'])
+[INFO] Running testSamplingFromLogits
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine sampling_from_logits --batch_size 64 --vocab_size 128256 --input_dtype bfloat16 -vv --generate_repro_command --case_tag sampling_from_logits_llama3
+[VVERBOSE] logits.shape = torch.Size([64, 128256])
+[VVERBOSE] logits.dtype = torch.bfloat16
+[PERF] cuda :: median time 0.078 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.211 TB/sec
+[INFO] args = Namespace(routine='top_k_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_sampling_k50', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testTopKSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 -vv --generate_repro_command --case_tag top_k_sampling_k50
+[VVERBOSE] probs.shape = torch.Size([32, 32000])
+[VVERBOSE] probs.dtype = torch.float32
+[VVERBOSE] top_k = 50
+[PERF] cuda :: median time 0.150 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.027 TB/sec
+[INFO] args = Namespace(routine='top_k_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_sampling_k100', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=128256, input_dtype='float32', top_k=100, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testTopKSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_k 100 -vv --generate_repro_command --case_tag top_k_sampling_k100
+[VVERBOSE] probs.shape = torch.Size([32, 128256])
+[VVERBOSE] probs.dtype = torch.float32
+[VVERBOSE] top_k = 100
+[PERF] cuda :: median time 0.498 ms; std 0.002 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.033 TB/sec
+[INFO] args = Namespace(routine='top_p_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_p_sampling_p09', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testTopPSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag top_p_sampling_p09
+[VVERBOSE] probs.shape = torch.Size([32, 32000])
+[VVERBOSE] probs.dtype = torch.float32
+[VVERBOSE] top_p = 0.9
+[PERF] cuda :: median time 0.023 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.179 TB/sec
+[INFO] args = Namespace(routine='top_p_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_p_sampling_p095', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=128256, input_dtype='float32', top_k=50, top_p=0.95, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testTopPSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --top_p 0.95 -vv --generate_repro_command --case_tag top_p_sampling_p095
+[VVERBOSE] probs.shape = torch.Size([32, 128256])
+[VVERBOSE] probs.dtype = torch.float32
+[VVERBOSE] top_p = 0.95
+[PERF] cuda :: median time 0.072 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.229 TB/sec
+[INFO] args = Namespace(routine='top_k_top_p_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_top_p_probs', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testTopKTopPSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_top_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first -vv --generate_repro_command --case_tag top_k_top_p_probs
+[VVERBOSE] probs.shape = torch.Size([32, 32000])
+[VVERBOSE] probs.dtype = torch.float32
+[VVERBOSE] top_k = 50
+[VVERBOSE] top_p = 0.9
+[VVERBOSE] filter_apply_order = 'top_k_first'
+[PERF] cuda :: median time 0.044 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.094 TB/sec
+[INFO] args = Namespace(routine='top_k_top_p_sampling_from_logits', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_top_p_logits', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testTopKTopPSamplingFromLogits
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_top_p_sampling_from_logits --batch_size 32 --vocab_size 32000 --top_k 50 --top_p 0.9 --filter_apply_order top_k_first --input_dtype float32 -vv --generate_repro_command --case_tag top_k_top_p_logits
+[VVERBOSE] logits.shape = torch.Size([32, 32000])
+[VVERBOSE] logits.dtype = torch.float32
+[VVERBOSE] top_k = 50
+[VVERBOSE] top_p = 0.9
+[VVERBOSE] filter_apply_order = 'top_k_first'
+[PERF] cuda :: median time 0.050 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.081 TB/sec
+[INFO] args = Namespace(routine='min_p_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='min_p_sampling_p01', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda'])
+[INFO] Running testMinPSamplingFromProbs
+[INFO] FlashInfer version: 0.6.2
+[VVERBOSE] gpu_name = 'NVIDIA_B200'
+[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine min_p_sampling_from_probs --batch_size 32 --vocab_size 32000 --min_p 0.1 -vv --generate_repro_command --case_tag min_p_sampling_p01
+[VVERBOSE] probs.shape = torch.Size([32, 32000])
+[VVERBOSE] probs.dtype = torch.float32
+[VVERBOSE] min_p = 0.1 +[PERF] cuda :: median time 0.013 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.322 TB/sec +[INFO] args = Namespace(routine='min_p_sampling_from_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='min_p_sampling_p005', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=128256, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.05, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda']) +[INFO] Running testMinPSamplingFromProbs +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine min_p_sampling_from_probs --batch_size 32 --vocab_size 128256 --min_p 0.05 -vv --generate_repro_command --case_tag min_p_sampling_p005 +[VVERBOSE] probs.shape = torch.Size([32, 128256]) +[VVERBOSE] probs.dtype = torch.float32 +[VVERBOSE] min_p = 0.05 +[PERF] cuda :: median time 0.042 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.393 TB/sec +[INFO] args = Namespace(routine='top_k_renorm_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_renorm', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda']) +[INFO] Running testTopKRenormProbs +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine 
top_k_renorm_probs --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_renorm +[VVERBOSE] probs.shape = torch.Size([32, 32000]) +[VVERBOSE] probs.dtype = torch.float32 +[VVERBOSE] top_k = 50 +[PERF] cuda :: median time 0.024 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.348 TB/sec +[INFO] args = Namespace(routine='top_p_renorm_probs', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_p_renorm', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda']) +[INFO] Running testTopPRenormProbs +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_p_renorm_probs --batch_size 32 --vocab_size 32000 --top_p 0.9 -vv --generate_repro_command --case_tag top_p_renorm +[VVERBOSE] probs.shape = torch.Size([32, 32000]) +[VVERBOSE] probs.dtype = torch.float32 +[VVERBOSE] top_p = 0.9 +[PERF] cuda :: median time 0.080 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.102 TB/sec +[INFO] args = Namespace(routine='top_k_mask_logits', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_mask', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda']) +[INFO] Running 
testTopKMaskLogits +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_mask_logits --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_mask +[VVERBOSE] logits.shape = torch.Size([32, 32000]) +[VVERBOSE] logits.dtype = torch.float32 +[VVERBOSE] top_k = 50 +[PERF] cuda :: median time 0.020 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.401 TB/sec +[INFO] args = Namespace(routine='chain_speculative_sampling', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='chain_spec_sampling_5', generate_repro_command=True, repro_command='', batch_size=16, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=16, backends=['cuda']) +[INFO] Running testChainSpeculativeSampling +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine chain_speculative_sampling --batch_size 16 --vocab_size 32000 --num_speculate_tokens 5 -vv --generate_repro_command --case_tag chain_spec_sampling_5 +[VVERBOSE] draft_probs.shape = torch.Size([16, 5, 32000]) +[VVERBOSE] draft_token_ids.shape = torch.Size([16, 5]) +[VVERBOSE] target_probs.shape = torch.Size([16, 6, 32000]) +[VVERBOSE] num_speculate_tokens = 5 +[PERF] cuda :: median time 0.027 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.829 TB/sec +[INFO] args = Namespace(routine='chain_speculative_sampling', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, 
output_path=None, num_iters=30, dry_run_iters=5, case_tag='chain_spec_sampling_8', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=128256, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=8, max_len=4096, num_rows=32, backends=['cuda']) +[INFO] Running testChainSpeculativeSampling +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine chain_speculative_sampling --batch_size 32 --vocab_size 128256 --num_speculate_tokens 8 -vv --generate_repro_command --case_tag chain_spec_sampling_8 +[VVERBOSE] draft_probs.shape = torch.Size([32, 8, 128256]) +[VVERBOSE] draft_token_ids.shape = torch.Size([32, 8]) +[VVERBOSE] target_probs.shape = torch.Size([32, 9, 128256]) +[VVERBOSE] num_speculate_tokens = 8 +[PERF] cuda :: median time 0.078 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 3.564 TB/sec +[INFO] args = Namespace(routine='top_k', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_radix', generate_repro_command=True, repro_command='', batch_size=32, vocab_size=32000, input_dtype='float32', top_k=50, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=32, backends=['cuda']) +[INFO] Running testTopK +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k --batch_size 32 --vocab_size 32000 --top_k 50 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_radix +[VVERBOSE] input_tensor.shape = torch.Size([32, 32000]) +[VVERBOSE] input_tensor.dtype = torch.float32 +[VVERBOSE] top_k 
= 50 +[PERF] cuda :: median time 0.018 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.234 TB/sec +[INFO] args = Namespace(routine='top_k', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_radix_large', generate_repro_command=True, repro_command='', batch_size=64, vocab_size=128256, input_dtype='bfloat16', top_k=100, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=64, backends=['cuda']) +[INFO] Running testTopK +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k --batch_size 64 --vocab_size 128256 --top_k 100 --input_dtype bfloat16 -vv --generate_repro_command --case_tag top_k_radix_large +[VVERBOSE] input_tensor.shape = torch.Size([64, 128256]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] top_k = 100 +[PERF] cuda :: median time 0.039 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.424 TB/sec +[INFO] args = Namespace(routine='top_k_page_table_transform', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_page_table', generate_repro_command=True, repro_command='', batch_size=16, vocab_size=None, input_dtype='float32', top_k=64, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=16, backends=['cuda']) +[INFO] Running testTopKPageTableTransform +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_page_table_transform 
--batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_page_table +[VVERBOSE] input_scores.shape = torch.Size([16, 4096]) +[VVERBOSE] input_scores.dtype = torch.float32 +[VVERBOSE] src_page_table.shape = torch.Size([16, 4096]) +[VVERBOSE] lengths.shape = torch.Size([16]) +[VVERBOSE] top_k = 64 +[PERF] cuda :: median time 0.008 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.066 TB/sec +[INFO] args = Namespace(routine='top_k_ragged_transform', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='top_k_ragged', generate_repro_command=True, repro_command='', batch_size=16, vocab_size=None, input_dtype='float32', top_k=64, top_p=0.9, min_p=0.1, temperature=1.0, filter_apply_order='top_k_first', num_speculate_tokens=5, max_len=4096, num_rows=16, backends=['cuda']) +[INFO] Running testTopKRaggedTransform +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine top_k_ragged_transform --batch_size 16 --num_rows 16 --max_len 4096 --top_k 64 --input_dtype float32 -vv --generate_repro_command --case_tag top_k_ragged +[VVERBOSE] input_scores.shape = torch.Size([16, 4096]) +[VVERBOSE] input_scores.dtype = torch.float32 +[VVERBOSE] offsets.shape = torch.Size([16]) +[VVERBOSE] lengths.shape = torch.Size([16]) +[VVERBOSE] top_k = 64 +[PERF] cuda :: median time 0.007 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 0.036 TB/sec +[INFO] args = Namespace(routine='apply_rope', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_rope_llama', generate_repro_command=True, 
repro_command='', batch_size=16, seq_len=1024, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='float16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyRope +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag apply_rope_llama +[VVERBOSE] q.shape = torch.Size([16384, 32, 128]) +[VVERBOSE] k.shape = torch.Size([16384, 8, 128]) +[VVERBOSE] indptr.shape = torch.Size([17]) +[VVERBOSE] offsets.shape = torch.Size([16]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] rope_scale = 1.0 +[VVERBOSE] rope_theta = 10000.0 +[VVERBOSE] interleave = False +[PERF] cuda :: median time 0.123 ms; std 0.001 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 2.730 TB/sec +[INFO] args = Namespace(routine='apply_rope', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_rope_llama70b', generate_repro_command=True, repro_command='', batch_size=32, seq_len=2048, num_qo_heads=64, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyRope +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_rope 
--batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag apply_rope_llama70b +[VVERBOSE] q.shape = torch.Size([65536, 64, 128]) +[VVERBOSE] k.shape = torch.Size([65536, 8, 128]) +[VVERBOSE] indptr.shape = torch.Size([33]) +[VVERBOSE] offsets.shape = torch.Size([32]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] rope_scale = 1.0 +[VVERBOSE] rope_theta = 10000.0 +[VVERBOSE] interleave = False +[PERF] cuda :: median time 0.560 ms; std 0.002 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 4.314 TB/sec +[INFO] args = Namespace(routine='apply_rope_pos_ids', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_rope_pos_ids', generate_repro_command=True, repro_command='', batch_size=16, seq_len=1024, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='float16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyRopePosIds +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag apply_rope_pos_ids +[VVERBOSE] q.shape = torch.Size([16384, 32, 128]) +[VVERBOSE] k.shape = torch.Size([16384, 8, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([16384]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] rope_scale = 1.0 +[VVERBOSE] rope_theta = 10000.0 +[VVERBOSE] interleave = False +[PERF] cuda :: median time 0.081 ms; std 0.001 ms; achieved tflops 0.000 TFLOPs/sec; achieved 
tb_per_sec 4.145 TB/sec +[INFO] args = Namespace(routine='apply_rope_pos_ids', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_rope_pos_ids_interleave', generate_repro_command=True, repro_command='', batch_size=32, seq_len=2048, num_qo_heads=64, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=True, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyRopePosIds +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_rope_pos_ids --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag apply_rope_pos_ids_interleave +[VVERBOSE] q.shape = torch.Size([65536, 64, 128]) +[VVERBOSE] k.shape = torch.Size([65536, 8, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([65536]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] rope_scale = 1.0 +[VVERBOSE] rope_theta = 10000.0 +[VVERBOSE] interleave = True +[PERF] cuda :: median time 0.412 ms; std 0.001 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 5.858 TB/sec +[INFO] args = Namespace(routine='apply_llama31_rope', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_llama31_rope', generate_repro_command=True, repro_command='', batch_size=16, seq_len=1024, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=500000.0, 
interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyLlama31Rope +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_llama31_rope --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag apply_llama31_rope +[VVERBOSE] q.shape = torch.Size([16384, 32, 128]) +[VVERBOSE] k.shape = torch.Size([16384, 8, 128]) +[VVERBOSE] indptr.shape = torch.Size([17]) +[VVERBOSE] offsets.shape = torch.Size([16]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] rope_scale = 1.0 +[VVERBOSE] rope_theta = 500000.0 +[VVERBOSE] interleave = False +[VVERBOSE] low_freq_factor = 1.0 +[VVERBOSE] high_freq_factor = 4.0 +[VVERBOSE] old_context_len = 8192 +[PERF] cuda :: median time 0.124 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 2.701 TB/sec +[INFO] args = Namespace(routine='apply_llama31_rope_pos_ids', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_llama31_rope_pos_ids', generate_repro_command=True, repro_command='', batch_size=16, seq_len=1024, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=500000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyLlama31RopePosIds +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following 
command: python3 flashinfer_benchmark.py --routine apply_llama31_rope_pos_ids --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --rope_theta 500000.0 --rope_scale 1.0 --low_freq_factor 1.0 --high_freq_factor 4.0 --old_context_len 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag apply_llama31_rope_pos_ids +[VVERBOSE] q.shape = torch.Size([16384, 32, 128]) +[VVERBOSE] k.shape = torch.Size([16384, 8, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([16384]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] rope_scale = 1.0 +[VVERBOSE] rope_theta = 500000.0 +[VVERBOSE] interleave = False +[VVERBOSE] low_freq_factor = 1.0 +[VVERBOSE] high_freq_factor = 4.0 +[VVERBOSE] old_context_len = 8192 +[PERF] cuda :: median time 0.081 ms; std 0.001 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 4.131 TB/sec +[INFO] args = Namespace(routine='apply_rope_with_cos_sin_cache', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_rope_cos_sin_cache', generate_repro_command=True, repro_command='', batch_size=16, seq_len=1024, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='float16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyRopeWithCosSinCache +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_rope_with_cos_sin_cache --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag apply_rope_cos_sin_cache +[VVERBOSE] q.shape = torch.Size([16384, 4096]) +[VVERBOSE] k.shape = 
torch.Size([16384, 1024]) +[VVERBOSE] cos_sin_cache.shape = torch.Size([1024, 128]) +[VVERBOSE] positions.shape = torch.Size([16384]) +[VVERBOSE] is_neox = True +[ERROR] Error running test: --routine apply_rope_with_cos_sin_cache --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype float16 -vv --generate_repro_command --case_tag "apply_rope_cos_sin_cache" +[ERROR] Error: cos_sin_cache should be float32 +[INFO] args = Namespace(routine='apply_rope_with_cos_sin_cache', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='apply_rope_cos_sin_cache_interleave', generate_repro_command=True, repro_command='', batch_size=32, seq_len=2048, num_qo_heads=64, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=True, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testApplyRopeWithCosSinCache +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine apply_rope_with_cos_sin_cache --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv --generate_repro_command --case_tag apply_rope_cos_sin_cache_interleave +[VVERBOSE] q.shape = torch.Size([65536, 8192]) +[VVERBOSE] k.shape = torch.Size([65536, 1024]) +[VVERBOSE] cos_sin_cache.shape = torch.Size([2048, 128]) +[VVERBOSE] positions.shape = torch.Size([65536]) +[VVERBOSE] is_neox = False +[ERROR] Error running test: --routine apply_rope_with_cos_sin_cache --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --interleave -vv 
--generate_repro_command --case_tag "apply_rope_cos_sin_cache_interleave" +[ERROR] Error: cos_sin_cache should be float32 +[INFO] args = Namespace(routine='mla_rope_quantize_fp8', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mla_rope_fp8_deepseek', generate_repro_command=True, repro_command='', batch_size=16, seq_len=1024, num_qo_heads=128, num_kv_heads=128, head_dim=192, rotary_dim=192, no_rope_dim=64, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testMlaRopeQuantizeFp8 +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mla_rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 128 --num_kv_heads 128 --head_dim 192 --no_rope_dim 64 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag mla_rope_fp8_deepseek +[VVERBOSE] q_rope.shape = torch.Size([16384, 128, 128]) +[VVERBOSE] k_rope.shape = torch.Size([16384, 128]) +[VVERBOSE] q_nope.shape = torch.Size([16384, 128, 64]) +[VVERBOSE] k_nope.shape = torch.Size([16384, 64]) +[VVERBOSE] cos_sin_cache.shape = torch.Size([1024, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([16384]) +[VVERBOSE] rope_dim = 128 +[VVERBOSE] no_rope_dim = 64 +[VVERBOSE] is_neox = True +[PERF] cuda :: median time 0.441 ms; std 0.001 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 2.759 TB/sec +[INFO] args = Namespace(routine='rope_quantize_fp8', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, 
case_tag='rope_fp8_llama', generate_repro_command=True, repro_command='', batch_size=16, seq_len=1024, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testRopeQuantizeFp8 +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rope_quantize_fp8 --batch_size 16 --seq_len 1024 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_llama +[VVERBOSE] q_rope.shape = torch.Size([16384, 32, 128]) +[VVERBOSE] k_rope.shape = torch.Size([16384, 8, 128]) +[VVERBOSE] q_nope.shape = None +[VVERBOSE] k_nope.shape = None +[VVERBOSE] cos_sin_cache.shape = torch.Size([1024, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([16384]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] no_rope_dim = 0 +[VVERBOSE] is_neox = True +[PERF] cuda :: median time 0.083 ms; std 0.001 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 3.041 TB/sec +[INFO] args = Namespace(routine='rope_quantize_fp8', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rope_fp8_llama70b', generate_repro_command=True, repro_command='', batch_size=32, seq_len=2048, num_qo_heads=64, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testRopeQuantizeFp8 +[INFO] FlashInfer version: 
0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rope_quantize_fp8 --batch_size 32 --seq_len 2048 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_llama70b +[VVERBOSE] q_rope.shape = torch.Size([65536, 64, 128]) +[VVERBOSE] k_rope.shape = torch.Size([65536, 8, 128]) +[VVERBOSE] q_nope.shape = None +[VVERBOSE] k_nope.shape = None +[VVERBOSE] cos_sin_cache.shape = torch.Size([2048, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([65536]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] no_rope_dim = 0 +[VVERBOSE] is_neox = True +[PERF] cuda :: median time 0.570 ms; std 0.002 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 3.183 TB/sec +[INFO] args = Namespace(routine='rope_quantize_fp8_append_paged_kv_cache', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rope_fp8_paged_kv', generate_repro_command=True, repro_command='', batch_size=16, seq_len=64, num_qo_heads=32, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='NHD', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testRopeQuantizeFp8AppendPagedKvCache +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rope_quantize_fp8_append_paged_kv_cache --batch_size 16 --seq_len 64 --num_qo_heads 32 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout NHD --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv --generate_repro_command --case_tag rope_fp8_paged_kv +[VVERBOSE] 
q_rope.shape = torch.Size([1024, 32, 128]) +[VVERBOSE] k_rope.shape = torch.Size([1024, 8, 128]) +[VVERBOSE] q_nope.shape = None +[VVERBOSE] k_nope.shape = None +[VVERBOSE] v.shape = torch.Size([1024, 8, 128]) +[VVERBOSE] cos_sin_cache.shape = torch.Size([64, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([1024]) +[VVERBOSE] k_cache.shape = torch.Size([64, 16, 8, 128]) +[VVERBOSE] v_cache.shape = torch.Size([64, 16, 8, 128]) +[VVERBOSE] kv_indices.shape = torch.Size([64]) +[VVERBOSE] kv_indptr.shape = torch.Size([17]) +[VVERBOSE] batch_indices.shape = torch.Size([1024]) +[VVERBOSE] positions.shape = torch.Size([1024]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] no_rope_dim = 0 +[VVERBOSE] is_neox = True +[VVERBOSE] page_size = 16 +[VVERBOSE] kv_layout = 'NHD' +[PERF] cuda :: median time 0.010 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 1.941 TB/sec +[INFO] args = Namespace(routine='rope_quantize_fp8_append_paged_kv_cache', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='rope_fp8_paged_kv_hnd', generate_repro_command=True, repro_command='', batch_size=32, seq_len=64, num_qo_heads=64, num_kv_heads=8, head_dim=128, rotary_dim=128, no_rope_dim=0, input_dtype='bfloat16', quant_dtype='fp8_e4m3', rope_scale=1.0, rope_theta=10000.0, interleave=False, page_size=16, kv_layout='HND', low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192, backends=['cuda']) +[INFO] Running testRopeQuantizeFp8AppendPagedKvCache +[INFO] FlashInfer version: 0.6.2 +[VVERBOSE] gpu_name = 'NVIDIA_B200' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rope_quantize_fp8_append_paged_kv_cache --batch_size 32 --seq_len 64 --num_qo_heads 64 --num_kv_heads 8 --head_dim 128 --page_size 16 --kv_layout HND --input_dtype bfloat16 --quant_dtype fp8_e4m3 -vv 
--generate_repro_command --case_tag rope_fp8_paged_kv_hnd +[VVERBOSE] q_rope.shape = torch.Size([2048, 64, 128]) +[VVERBOSE] k_rope.shape = torch.Size([2048, 8, 128]) +[VVERBOSE] q_nope.shape = None +[VVERBOSE] k_nope.shape = None +[VVERBOSE] v.shape = torch.Size([2048, 8, 128]) +[VVERBOSE] cos_sin_cache.shape = torch.Size([64, 128]) +[VVERBOSE] pos_ids.shape = torch.Size([2048]) +[VVERBOSE] k_cache.shape = torch.Size([128, 8, 16, 128]) +[VVERBOSE] v_cache.shape = torch.Size([128, 8, 16, 128]) +[VVERBOSE] kv_indices.shape = torch.Size([128]) +[VVERBOSE] kv_indptr.shape = torch.Size([33]) +[VVERBOSE] batch_indices.shape = torch.Size([2048]) +[VVERBOSE] positions.shape = torch.Size([2048]) +[VVERBOSE] rotary_dim = 128 +[VVERBOSE] no_rope_dim = 0 +[VVERBOSE] is_neox = True +[VVERBOSE] page_size = 16 +[VVERBOSE] kv_layout = 'HND' +[PERF] cuda :: median time 0.026 ms; std 0.000 ms; achieved tflops 0.000 TFLOPs/sec; achieved tb_per_sec 2.399 TB/sec