benchmarks: Expand microbenchmark harness to include sampling and RoPe APIs#2484
bkryu merged 8 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds comprehensive benchmarking support for two new routine families: Sampling and RoPE. Introduces new modules for sampling and RoPE benchmarks, updates harness routing and utilities, and extends the README and sample test lists with new flags, cases, and backend/support mappings.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant MainHarness as Main Harness
    participant Parser as Routine Parser
    participant Dispatcher as Sampling Dispatcher
    participant TestFunc as Test Function
    participant Backend as GPU Backend
    User->>MainHarness: flashinfer_benchmark.py --routine sampling_...
    MainHarness->>Parser: parse_sampling_args(cli_args)
    Parser-->>MainHarness: args
    MainHarness->>Dispatcher: run_sampling_test(args)
    Dispatcher->>TestFunc: select and call testSamplingFromX(args)
    TestFunc->>TestFunc: prepare tensors/logits/probs
    TestFunc->>TestFunc: filter backends by compute capability
    loop per backend
        TestFunc->>Backend: execute kernel via bench_gpu_time
        Backend-->>TestFunc: timing/stats
    end
    TestFunc->>TestFunc: compute TB/s and metrics
    TestFunc-->>MainHarness: optional result dict
```

```mermaid
sequenceDiagram
    participant User
    participant MainHarness as Main Harness
    participant Parser as Routine Parser
    participant Dispatcher as RoPE Dispatcher
    participant TestFunc as Test Function
    participant Backend as GPU Backend
    User->>MainHarness: flashinfer_benchmark.py --routine apply_rope ...
    MainHarness->>Parser: parse_rope_args(cli_args)
    Parser-->>MainHarness: args
    MainHarness->>Dispatcher: run_rope_test(args)
    Dispatcher->>TestFunc: build Q/K, indptr/pos_ids, cos_sin_cache (if used)
    TestFunc->>TestFunc: filter backends
    loop per backend
        TestFunc->>Backend: execute rope kernel via bench_gpu_time
        Backend-->>TestFunc: timing/stats
    end
    TestFunc->>TestFunc: compute TB/s and metrics
    TestFunc-->>MainHarness: optional result dict
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances the FlashInfer microbenchmark suite by integrating comprehensive benchmarking capabilities for the Sampling and Rotary Positional Embedding (RoPE) APIs. The expansion addresses issue #2361, allowing developers to measure and compare the performance of various sampling strategies and RoPE implementations, including FP8-quantized versions and Llama 3.1-specific RoPE, across different configurations and backends. The changes provide a more complete performance-analysis tool for critical components in large language model inference.
Activity
@kahyunnam, can I get a review from you?
Code Review
This pull request significantly expands the FlashInfer microbenchmark harness by adding comprehensive support for Sampling and RoPE APIs. The changes are well-structured, introducing new modules for the added routines and updating documentation and utility files accordingly. My review identified a few areas for improvement. There's a minor inconsistency in parameter organization within flashinfer_benchmark_utils.py. Both new benchmark files, rope.py and sampling.py, contain considerable code duplication that could be refactored for better maintainability. Additionally, rope.py has a misleading comment and a potential bug related to inconsistent parameter handling. Addressing these points will enhance the quality and reliability of the new benchmarks.
```python
    num_kv_heads = args.num_kv_heads
    head_dim = args.head_dim
    rotary_dim = args.rotary_dim
    no_rope_dim = head_dim - rotary_dim  # For GQA/MHA
```

The `no_rope_dim` is calculated here as `head_dim - rotary_dim`, which ignores the `--no_rope_dim` argument parsed from the command line. Other test functions in this file, like `testRopeQuantizeFp8`, correctly use `args.no_rope_dim`. This inconsistency can lead to incorrect benchmark configurations and user confusion. Please use `args.no_rope_dim` and consider adding an assertion like `assert args.head_dim == args.rotary_dim + args.no_rope_dim` to ensure consistency.

```diff
-    no_rope_dim = head_dim - rotary_dim  # For GQA/MHA
+    no_rope_dim = args.no_rope_dim
```
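The suggested consistency check can be sketched as a tiny standalone helper. `SimpleNamespace` stands in for the parsed argparse namespace, and `resolve_no_rope_dim` is an illustrative name, not a function in the benchmark:

```python
from types import SimpleNamespace

def resolve_no_rope_dim(args):
    # Trust the CLI value, but fail fast if the three dims disagree,
    # as the review suggests.
    assert args.head_dim == args.rotary_dim + args.no_rope_dim, (
        f"head_dim ({args.head_dim}) must equal "
        f"rotary_dim ({args.rotary_dim}) + no_rope_dim ({args.no_rope_dim})"
    )
    return args.no_rope_dim

# Consistent MLA-style split: 192 = 64 rotary + 128 no-RoPE
args = SimpleNamespace(head_dim=192, rotary_dim=64, no_rope_dim=128)
print(resolve_no_rope_dim(args))  # prints 128
```

An inconsistent combination (e.g. `head_dim=128, rotary_dim=64, no_rope_dim=0`) fails loudly instead of silently recomputing the value.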
```python
    ],
    "gemm": [
        "m",
        "n",
```
```python
def run_rope_test(args):
    """
    Run a RoPE test.

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        dict: List of dictionaries containing performance results
    """
    if args.routine == "apply_rope":
        return testApplyRope(args)
    elif args.routine == "apply_rope_pos_ids":
        return testApplyRopePosIds(args)
    elif args.routine == "apply_llama31_rope":
        return testApplyLlama31Rope(args)
    elif args.routine == "apply_llama31_rope_pos_ids":
        return testApplyLlama31RopePosIds(args)
    elif args.routine == "apply_rope_with_cos_sin_cache":
        return testApplyRopeWithCosSinCache(args)
    elif args.routine == "mla_rope_quantize_fp8":
        return testMlaRopeQuantizeFp8(args)
    elif args.routine == "rope_quantize_fp8":
        return testRopeQuantizeFp8(args)
    elif args.routine == "rope_quantize_fp8_append_paged_kv_cache":
        return testRopeQuantizeFp8AppendPagedKvCache(args)
    else:
        raise ValueError(f"Unsupported routine: {args.routine}")
```

There is significant code duplication across the test functions in this file (e.g., `testApplyRope`, `testApplyRopePosIds`, etc.). The boilerplate for benchmarking, calculating performance metrics, and storing results is repeated in each function. Consider creating a generic test runner helper function to reduce this duplication. This would improve maintainability and readability.
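One possible shape for that refactor, sketched with stand-ins: `bench_gpu_time` below is a fake timer (the real harness times CUDA kernels), and `run_benchmark` is a hypothetical helper name, not an existing function:

```python
import statistics

def bench_gpu_time(fn):
    # Stub timer: a real harness would use CUDA events. Here we just call
    # fn once and return fake per-iteration times in milliseconds.
    fn()
    return [1.0, 1.1, 0.9]

def run_benchmark(routine_name, kernel_fn, problem_bytes, backends):
    # Shared boilerplate: time the kernel per backend, then derive the
    # median latency and achieved memory bandwidth (TB/s).
    results = []
    for backend in backends:
        times_ms = bench_gpu_time(kernel_fn)
        median_ms = statistics.median(times_ms)
        results.append({
            "routine": routine_name,
            "backend": backend,
            "median_ms": median_ms,
            "tb_per_sec": problem_bytes / (median_ms * 1e-3) / 1e12,
        })
    return results

res = run_benchmark("apply_rope", lambda: None,
                    problem_bytes=1 << 30, backends=["cuda"])
print(res[0]["backend"], res[0]["median_ms"])  # cuda 1.0
```

Each `testApplyRopeX` function would then only build its tensors and closure, and hand both to the shared runner.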
```python
    4. Measures performance metrics (TB/sec)

    Note: This API takes pre-split q_rope, k_rope, q_nope, k_nope tensors
    and a precomputed cos_sin_cache. It is the same as rope_quantize_fp8.
```

This comment states that `mla_rope_quantize_fp8` is the same as `rope_quantize_fp8`. This is misleading, as they handle different K-tensor shapes for MLA and GQA/MHA architectures respectively. Please clarify or remove this comment to avoid confusion.

```diff
-    and a precomputed cos_sin_cache. It is the same as rope_quantize_fp8.
+    and a precomputed cos_sin_cache.
```
```python
def run_sampling_test(args):
    """
    Run a sampling test.

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        dict: List of dictionaries containing performance results
    """
    if args.routine == "softmax":
        return testSoftmax(args)
    elif args.routine == "sampling_from_probs":
        return testSamplingFromProbs(args)
    elif args.routine == "sampling_from_logits":
        return testSamplingFromLogits(args)
    elif args.routine == "top_k_sampling_from_probs":
        return testTopKSamplingFromProbs(args)
    elif args.routine == "top_p_sampling_from_probs":
        return testTopPSamplingFromProbs(args)
    elif args.routine == "top_k_top_p_sampling_from_probs":
        return testTopKTopPSamplingFromProbs(args)
    elif args.routine == "top_k_top_p_sampling_from_logits":
        return testTopKTopPSamplingFromLogits(args)
    elif args.routine == "min_p_sampling_from_probs":
        return testMinPSamplingFromProbs(args)
    elif args.routine == "top_k_renorm_probs":
        return testTopKRenormProbs(args)
    elif args.routine == "top_p_renorm_probs":
        return testTopPRenormProbs(args)
    elif args.routine == "top_k_mask_logits":
        return testTopKMaskLogits(args)
    elif args.routine == "chain_speculative_sampling":
        return testChainSpeculativeSampling(args)
    elif args.routine == "top_k":
        return testTopK(args)
    elif args.routine == "top_k_page_table_transform":
        return testTopKPageTableTransform(args)
    elif args.routine == "top_k_ragged_transform":
        return testTopKRaggedTransform(args)
    else:
        raise ValueError(f"Unsupported routine: {args.routine}")
```

Similar to rope.py, there is a large amount of duplicated code across the test functions in this file. The boilerplate for benchmarking, calculating performance metrics, and storing results is repeated. Refactoring this into a common helper function would greatly improve the maintainability of this file.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 344-381: The markdown tables under the headings "### Sampling
Flags" and "### RoPE Flags" start immediately after the headings and before the
following text, triggering MD058; add a single blank line between each heading
and its table, and ensure there is a blank line after each table (i.e., add one
empty line before the table under "### Sampling Flags" and one after it, and do
the same for the table under "### RoPE Flags") so both tables have blank lines
before and after to satisfy markdownlint.
In `@benchmarks/routines/rope.py`:
- Around line 872-946: cos_sin_cache is created with input_dtype (often float16)
but apply_rope_with_cos_sin_cache requires float32, causing a ValueError;
recreate cos_sin_cache with dtype=torch.float32 on the same device and pass that
into run_backend/apply_rope_with_cos_sin_cache, and update the memory bandwidth
calc in problem_bytes to use 4 bytes per element for the cos_sin_cache read
(max_seq_len * rotary_dim * 4) instead of input_dtype.itemsize so the reported
TB/sec matches the actual float32 cache size.
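The bandwidth side of that fix can be illustrated with plain arithmetic. `rope_problem_bytes` is a hypothetical helper, and its term list is a simplified subset of the real accounting:

```python
def rope_problem_bytes(batch_size, seq_len, num_qo_heads, head_dim,
                       max_seq_len, rotary_dim, input_itemsize):
    # Q read + Q write at the benchmark's input dtype (simplified: the
    # real accounting also covers K and the quantized outputs).
    q_bytes = 2 * batch_size * seq_len * num_qo_heads * head_dim * input_itemsize
    # The cos_sin_cache is read as float32, so it contributes 4 bytes per
    # element regardless of input_dtype.itemsize.
    cache_bytes = max_seq_len * rotary_dim * 4
    return q_bytes + cache_bytes

total = rope_problem_bytes(1, 1024, 32, 128,
                           max_seq_len=4096, rotary_dim=128, input_itemsize=2)
print(total)  # 18874368
```

With a float16 input (`itemsize=2`), counting the cache at 2 bytes/element would under-report the traffic by `max_seq_len * rotary_dim * 2` bytes, skewing the TB/s figure.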
### Sampling Flags

| Flag | Description |
|--------------------------|--------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--vocab_size` | Vocabulary size |
| `--input_dtype` | Input data type for logits: `float32` (default), `float16`, or `bfloat16` |
| `--top_k` | Top-K value for top-k sampling. Default: 50 |
| `--top_p` | Top-P threshold for top-p (nucleus) sampling. Default: 0.9 |
| `--min_p` | Min-P threshold for min-p sampling. Default: 0.1 |
| `--temperature` | Temperature for softmax. Default: 1.0 |
| `--filter_apply_order` | Order of applying top-k and top-p filters: `top_k_first` (default) or `joint` |
| `--num_speculate_tokens` | Number of speculative tokens for chain speculative sampling. Default: 5 |
| `--max_len` | Max sequence length for `top_k_page_table_transform` and `top_k_ragged_transform`. Default: 4096 |
| `--num_rows` | Number of rows for `top_k_page_table_transform` and `top_k_ragged_transform`. Defaults to batch_size |
| `--backends` | Backend to test: `cuda` (default) |

### RoPE Flags

| Flag | Description |
|--------------------------|--------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--seq_len` | Sequence length (qkv_len or kv_len) |
| `--num_qo_heads` | Number of query/output heads |
| `--num_kv_heads` | Number of key/value heads |
| `--head_dim` | Head dimension |
| `--rotary_dim` | Rotary dimension (defaults to head_dim if not specified) |
| `--no_rope_dim` | Number of dimensions without RoPE (for MLA). Default: 0 |
| `--input_dtype` | Input data type: `float16` (default) or `bfloat16` |
| `--quant_dtype` | Quantized data type for FP8 routines: `fp8_e4m3` (default) or `fp8_e5m2` |
| `--rope_scale` | RoPE scaling factor. Default: 1.0 |
| `--rope_theta` | RoPE theta base frequency. Default: 10000.0 |
| `--interleave` | Use interleaved rotary embedding (GPT-J style) |
| `--page_size` | Page size for paged KV cache. Default: 16 |
| `--kv_layout` | KV cache layout: `NHD` (default) or `HND` |
| `--low_freq_factor` | Low frequency factor for Llama 3.1 RoPE. Default: 1.0 |
| `--high_freq_factor` | High frequency factor for Llama 3.1 RoPE. Default: 4.0 |
| `--old_context_len` | Old context length for Llama 3.1 RoPE. Default: 8192 |
| `--backends` | Backend to test: `cuda` (default) |

Add blank lines around the new tables to satisfy markdownlint. Lines 345 and 361 start tables immediately after headings; MD058 expects blank lines before/after tables.
#2374 is another relevant PR
kahyunnam left a comment

LGTM. Will be very useful for rope, thanks for implementing.

Force-pushed from 95886b6 to 73d2f0a.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@benchmarks/routines/rope.py`:
- Around line 99-207: Validate parsed args so rotary_dim and no_rope_dim cannot
produce negative or out-of-range head dims: after you set args.rotary_dim =
args.head_dim (when None) add checks that args.no_rope_dim >= 0, args.rotary_dim
>= 0, args.no_rope_dim <= args.head_dim, args.rotary_dim <= args.head_dim, and
that args.rotary_dim + args.no_rope_dim <= args.head_dim; if any check fails
raise argparse.ArgumentTypeError or exit with a clear error message referencing
rotary_dim, no_rope_dim, and head_dim so the CLI user sees which values are
invalid.
In `@benchmarks/routines/sampling.py`:
- Around line 1567-1590: The src_page_table is created with batch_size rows but
must align with num_rows for testTopKPageTableTransform where input_scores has
shape (num_rows, max_len) and lengths is (num_rows,); change the src_page_table
creation in benchmarks/routines/sampling.py so src_page_table is sized
(num_rows, max_len) (use num_rows instead of batch_size) and keep
dtype=torch.int32 and device=device so it matches input_scores/lengths before
calling flashinfer.top_k_page_table_transform; also verify any subsequent
memory-bandwidth or shape assumptions reference num_rows rather than batch_size.
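The shape fix can be checked without a GPU; below, nested lists stand in for the torch tensors and `make_transform_inputs` is an illustrative helper, not a benchmark function:

```python
def make_transform_inputs(num_rows, max_len):
    # All three inputs to top_k_page_table_transform share the num_rows
    # leading dimension; sizing src_page_table by batch_size breaks that
    # alignment whenever num_rows != batch_size.
    input_scores = [[0.0] * max_len for _ in range(num_rows)]
    lengths = [max_len] * num_rows
    src_page_table = [[0] * max_len for _ in range(num_rows)]  # num_rows, not batch_size
    return input_scores, lengths, src_page_table

scores, lengths, table = make_transform_inputs(num_rows=8, max_len=16)
print(len(scores), len(lengths), len(table))  # 8 8 8
```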
| "--head_dim", | ||
| type=int, | ||
| required=True, | ||
| help="Head dimension.", | ||
| ) | ||
| parser.add_argument( | ||
| "--rotary_dim", | ||
| type=int, | ||
| required=False, | ||
| default=None, | ||
| help="Rotary dimension (defaults to head_dim if not specified).", | ||
| ) | ||
| parser.add_argument( | ||
| "--no_rope_dim", | ||
| type=int, | ||
| required=False, | ||
| default=0, | ||
| help="Number of dimensions without RoPE (for MLA). Default: 0.", | ||
| ) | ||
| parser.add_argument( | ||
| "--input_dtype", | ||
| type=str, | ||
| required=False, | ||
| default="float16", | ||
| choices=["float16", "bfloat16"], | ||
| help="Data type of the input tensor.", | ||
| ) | ||
| parser.add_argument( | ||
| "--quant_dtype", | ||
| type=str, | ||
| required=False, | ||
| default="fp8_e4m3", | ||
| choices=["fp8_e4m3", "fp8_e5m2"], | ||
| help="Quantized data type for FP8 routines.", | ||
| ) | ||
| parser.add_argument( | ||
| "--rope_scale", | ||
| type=float, | ||
| required=False, | ||
| default=1.0, | ||
| help="RoPE scaling factor.", | ||
| ) | ||
| parser.add_argument( | ||
| "--rope_theta", | ||
| type=float, | ||
| required=False, | ||
| default=10000.0, | ||
| help="RoPE theta base frequency.", | ||
| ) | ||
| parser.add_argument( | ||
| "--interleave", | ||
| action="store_true", | ||
| help="Use interleaved rotary embedding (GPT-J style).", | ||
| ) | ||
| parser.add_argument( | ||
| "--page_size", | ||
| type=int, | ||
| required=False, | ||
| default=16, | ||
| help="Page size for paged KV cache.", | ||
| ) | ||
| parser.add_argument( | ||
| "--kv_layout", | ||
| type=str, | ||
| required=False, | ||
| default="NHD", | ||
| choices=["NHD", "HND"], | ||
| help="KV cache layout.", | ||
| ) | ||
| parser.add_argument( | ||
| "--low_freq_factor", | ||
| type=float, | ||
| required=False, | ||
| default=1.0, | ||
| help="Low frequency factor for Llama 3.1 RoPE.", | ||
| ) | ||
| parser.add_argument( | ||
| "--high_freq_factor", | ||
| type=float, | ||
| required=False, | ||
| default=4.0, | ||
| help="High frequency factor for Llama 3.1 RoPE.", | ||
| ) | ||
| parser.add_argument( | ||
| "--old_context_len", | ||
| type=int, | ||
| required=False, | ||
| default=8192, | ||
| help="Old context length for Llama 3.1 RoPE.", | ||
| ) | ||
| parser.add_argument( | ||
| "--backends", | ||
| type=str, | ||
| required=False, | ||
| nargs="+", | ||
| default=["cuda"], | ||
| choices=["cuda"], | ||
| help="Kernel backends to test. Default: cuda", | ||
| ) | ||
|
|
||
| args = parser.parse_args(line) | ||
|
|
||
| # Default rotary_dim to head_dim if not specified | ||
| if args.rotary_dim is None: | ||
| args.rotary_dim = args.head_dim | ||
|
|
||
| if args.verbose >= 1: | ||
| print(f"[INFO] {args = }") | ||
| return args |
Validate rotary_dim/no_rope_dim against head_dim to avoid negative tensor shapes.
Invalid CLI values can lead to negative dims in MLA/FP8 paths and crash with cryptic tensor errors.
✅ Suggested validation

```diff
     # Default rotary_dim to head_dim if not specified
     if args.rotary_dim is None:
         args.rotary_dim = args.head_dim
+
+    if args.rotary_dim < 0 or args.rotary_dim > args.head_dim:
+        raise ValueError("--rotary_dim must be in [0, head_dim].")
+    if args.no_rope_dim < 0 or args.no_rope_dim > args.head_dim:
+        raise ValueError("--no_rope_dim must be in [0, head_dim].")
```

🤖 Prompt for AI Agents
In `@benchmarks/routines/rope.py` around lines 99 - 207, Validate parsed args so
rotary_dim and no_rope_dim cannot produce negative or out-of-range head dims:
after you set args.rotary_dim = args.head_dim (when None) add checks that
args.no_rope_dim >= 0, args.rotary_dim >= 0, args.no_rope_dim <= args.head_dim,
args.rotary_dim <= args.head_dim, and that args.rotary_dim + args.no_rope_dim <=
args.head_dim; if any check fails raise argparse.ArgumentTypeError or exit with
a clear error message referencing rotary_dim, no_rope_dim, and head_dim so the
CLI user sees which values are invalid.
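A runnable version of that validation, with argument names taken from the diff above (`validate_rope_dims` is an illustrative wrapper, not the benchmark's actual function):

```python
from types import SimpleNamespace

def validate_rope_dims(args):
    # Default rotary_dim to head_dim, then reject out-of-range combinations
    # before any tensors are allocated.
    if args.rotary_dim is None:
        args.rotary_dim = args.head_dim
    if not 0 <= args.rotary_dim <= args.head_dim:
        raise ValueError("--rotary_dim must be in [0, head_dim].")
    if not 0 <= args.no_rope_dim <= args.head_dim:
        raise ValueError("--no_rope_dim must be in [0, head_dim].")
    if args.rotary_dim + args.no_rope_dim > args.head_dim:
        raise ValueError("rotary_dim + no_rope_dim must not exceed head_dim.")
    return args

ok = validate_rope_dims(SimpleNamespace(head_dim=128, rotary_dim=None, no_rope_dim=0))
print(ok.rotary_dim)  # prints 128
```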
```python
    backends = args.backends[:]
    batch_size = args.batch_size
    num_rows = args.num_rows
    max_len = args.max_len
    top_k = args.top_k
    is_cuda_graph_compatible = not args.no_cuda_graph
    res = []

    backends = filter_backends_by_compute_capability(backends, args.routine, device)
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)

    ## Prepare input tensors
    # Input scores: (num_rows, max_len)
    input_scores = torch.randn(num_rows, max_len, dtype=input_dtype, device=device)

    # Source page table: (batch_size, max_len)
    src_page_table = torch.randint(
        0, 1000, (batch_size, max_len), dtype=torch.int32, device=device
    )
```
Align src_page_table rows with num_rows, not batch_size.
In testTopKPageTableTransform, input_scores is shaped (num_rows, max_len) and lengths is shaped (num_rows,), but src_page_table uses (batch_size, max_len). All three tensors are passed together to flashinfer.top_k_page_table_transform(). Since lengths corresponds to num_rows sequences, the page table must also have num_rows rows; if num_rows differs from batch_size, the current shape will cause a mismatch. The memory bandwidth calculation explicitly references both dimensions separately, confirming they are intended to vary independently and both should align with the sequence count (num_rows).
Suggested fix:

```diff
- src_page_table = torch.randint(
-     0, 1000, (batch_size, max_len), dtype=torch.int32, device=device
- )
+ src_page_table = torch.randint(
+     0, 1000, (num_rows, max_len), dtype=torch.int32, device=device
+ )
```

🤖 Prompt for AI Agents
In `@benchmarks/routines/sampling.py` around lines 1567-1590: the src_page_table is created with batch_size rows but must align with num_rows for testTopKPageTableTransform, where input_scores has shape (num_rows, max_len) and lengths is (num_rows,). Change the src_page_table creation so it is sized (num_rows, max_len) (use num_rows instead of batch_size) and keep dtype=torch.int32 and device=device so it matches input_scores/lengths before calling flashinfer.top_k_page_table_transform; also verify any subsequent memory-bandwidth or shape assumptions reference num_rows rather than batch_size.
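A minimal CPU-runnable sketch of the corrected tensor construction. The helper name and the final shape assertion are illustrative; the real benchmark passes these tensors on to `flashinfer.top_k_page_table_transform`, which is not exercised here:

```python
import torch

def make_page_table_inputs(num_rows, max_len, device="cpu"):
    # All three tensors are indexed per sequence, so they must share the
    # num_rows leading dimension.
    input_scores = torch.randn(num_rows, max_len, device=device)
    lengths = torch.randint(
        1, max_len + 1, (num_rows,), dtype=torch.int32, device=device
    )
    # Fixed: sized by num_rows, not batch_size.
    src_page_table = torch.randint(
        0, 1000, (num_rows, max_len), dtype=torch.int32, device=device
    )
    return input_scores, lengths, src_page_table

scores, lengths, table = make_page_table_inputs(num_rows=6, max_len=8)
assert scores.shape[0] == lengths.shape[0] == table.shape[0] == 6
```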
@vincentzed, please check commit de91629 for the newly added refchecks in the sampling APIs. cc @kahyunnam: there was a suggestion to add refchecks to the sampling APIs, so I added them.
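As a sketch of what such a refcheck can look like for one of these routines, here is a naive torch reference for top-p renormalization that a kernel output could be compared against with `allclose`. Function names are illustrative, and tie-breaking at the nucleus boundary may differ from the fused kernel:

```python
import torch

def top_p_renorm_reference(probs, top_p):
    # Keep the smallest descending-sorted prefix whose mass reaches top_p
    # (the top token is always kept), zero the rest, then renormalize.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    exclusive_cumsum = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    kept = sorted_probs * (exclusive_cumsum < top_p)
    renorm = kept / kept.sum(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter(-1, sorted_idx, renorm)

probs = torch.softmax(torch.randn(4, 32), dim=-1)
ref = top_p_renorm_reference(probs, top_p=0.9)
# A refcheck would compare the kernel output against `ref` with torch.allclose.
assert torch.allclose(ref.sum(dim=-1), torch.ones(4), atol=1e-5)
```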
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/sampling.py`:
- Line 1666: The unused variable ref_indices in the reference check should be
ignored to satisfy Ruff; change the unpacking in the torch.topk call (the line
assigning ref_values, ref_indices) to use a throwaway name (e.g., replace
ref_indices with _ or _ref_indices) so only ref_values is considered and the
linter warning is removed.
```python
# Note: FlashInfer top_k returns UNSORTED results by default, so we compare
# sorted values to verify the same elements are selected
if run_refcheck and outputs:
    ref_values, ref_indices = torch.topk(input_tensor.float(), k=top_k, dim=-1)
```
Drop unused `ref_indices` to satisfy Ruff.
Ruff flags the unused variable in the reference check; replace it with `_` (or `_ref_indices`).

♻️ Proposed fix:

```diff
- ref_values, ref_indices = torch.topk(input_tensor.float(), k=top_k, dim=-1)
+ ref_values, _ = torch.topk(input_tensor.float(), k=top_k, dim=-1)
```
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 1666: Unpacked variable `ref_indices` is never used; prefix it with an underscore or any other dummy variable pattern (RUF059).
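The comparison pattern this refcheck relies on can be sketched standalone. Names are illustrative, and the shuffle stands in for a kernel that returns top-k values in arbitrary order:

```python
import torch

def topk_values_match(input_tensor, kernel_values, top_k, atol=1e-3):
    # Reference values via torch.topk; the indices are unused, hence `_`.
    ref_values, _ = torch.topk(input_tensor.float(), k=top_k, dim=-1)
    # Sort both sides so an unsorted kernel output still compares equal.
    ref_sorted, _ = torch.sort(ref_values, dim=-1)
    out_sorted, _ = torch.sort(kernel_values.float(), dim=-1)
    return torch.allclose(ref_sorted, out_sorted, atol=atol)

x = torch.randn(2, 16)
vals, _ = torch.topk(x, k=4, dim=-1)
shuffled = vals[:, torch.randperm(4)]  # simulate unsorted kernel output
assert topk_values_match(x, shuffled, top_k=4)
```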
📌 Description
This PR expands the FlashInfer microbenchmark harness (`flashinfer_benchmark.py`) to include Sampling and RoPE (Rotary Positional Embeddings) APIs, addressing issue #2361.

Sampling routines added (15 APIs):
- `softmax` - Softmax with optional temperature scaling
- `sampling_from_probs` / `sampling_from_logits` - Basic categorical sampling
- `top_k_sampling_from_probs` - Top-K sampling
- `top_p_sampling_from_probs` - Top-P (nucleus) sampling
- `top_k_top_p_sampling_from_probs` / `top_k_top_p_sampling_from_logits` - Combined Top-K and Top-P
- `min_p_sampling_from_probs` - Min-P sampling
- `top_k_renorm_probs` / `top_p_renorm_probs` - Probability renormalization
- `top_k_mask_logits` - Logits masking
- `chain_speculative_sampling` - Chain speculative sampling for speculative decoding
- `top_k` / `top_k_page_table_transform` / `top_k_ragged_transform` - Radix-based Top-K selection

RoPE routines added (8 APIs):
- `apply_rope` / `apply_rope_pos_ids` - Standard RoPE with indptr/offsets or position IDs
- `apply_llama31_rope` / `apply_llama31_rope_pos_ids` - Llama 3.1 style RoPE
- `apply_rope_with_cos_sin_cache` - RoPE with precomputed cos/sin cache
- `mla_rope_quantize_fp8` - MLA RoPE with FP8 quantization (SM8.9+)
- `rope_quantize_fp8` - RoPE with FP8 quantization (SM8.9+)
- `rope_quantize_fp8_append_paged_kv_cache` - RoPE with FP8 quantization and paged KV cache append (SM8.9+)

🔍 Related Issues
#2361
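The `apply_rope` variants listed above all share the same core rotary transform; a naive interleaved-pair reference in plain PyTorch follows. It is illustrative only: FlashInfer's kernels support additional layouts, dtypes, and the Llama 3.1 frequency scaling.

```python
import torch

def apply_rope_reference(x, pos_ids, rope_theta=10000.0):
    # x: (num_tokens, num_heads, head_dim) with an even head_dim.
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    )
    angles = pos_ids.float()[:, None] * inv_freq[None, :]  # (num_tokens, head_dim/2)
    cos = angles.cos()[:, None, :]  # broadcast over heads
    sin = angles.sin()[:, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]  # interleaved (even, odd) feature pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(5, 8, 64)
q_rot = apply_rope_reference(q, pos_ids=torch.arange(5))
# Rotation is norm-preserving, and position 0 is the identity.
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
assert torch.allclose(q_rot[0], q[0], atol=1e-6)
```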
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] Installed the hooks with `pre-commit install`.
- [x] Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests added or updated as needed (`unittest`, etc.).

Reviewer Notes