
benchmarks: Expand microbenchmark harness to include sampling and RoPE APIs #2484

Merged
bkryu merged 8 commits into flashinfer-ai:main from bkryu:bench_sampling_rope on Feb 5, 2026
Conversation

@bkryu (Collaborator) commented Feb 3, 2026

📌 Description

This PR expands the FlashInfer microbenchmark harness (flashinfer_benchmark.py) to include Sampling and RoPE (Rotary Positional Embeddings) APIs, addressing issue #2361.
Sampling routines added (15 APIs):

  • softmax - Softmax with optional temperature scaling
  • sampling_from_probs / sampling_from_logits - Basic categorical sampling
  • top_k_sampling_from_probs - Top-K sampling
  • top_p_sampling_from_probs - Top-P (nucleus) sampling
  • top_k_top_p_sampling_from_probs / top_k_top_p_sampling_from_logits - Combined Top-K and Top-P
  • min_p_sampling_from_probs - Min-P sampling
  • top_k_renorm_probs / top_p_renorm_probs - Probability renormalization
  • top_k_mask_logits - Logits masking
  • chain_speculative_sampling - Chain speculative sampling for speculative decoding
  • top_k / top_k_page_table_transform / top_k_ragged_transform - Radix-based Top-K selection
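These routines differ mainly in which filter is applied to the distribution before a token is drawn. As a rough illustration of the `top_k_first` semantics, here is a pure-Python sketch (illustrative only, not the FlashInfer kernels, which do this on-GPU without a full sort):

```python
def top_k_top_p_filter(probs, top_k, top_p):
    """Reference semantics of combined Top-K then Top-P filtering
    (the default top_k_first order): keep the top_k largest
    probabilities, then the smallest high-probability prefix whose
    cumulative mass reaches top_p, and renormalize. Pure-Python
    illustration only -- not the FlashInfer kernel."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept = order[:top_k]                      # Top-K cut
    nucleus, cum = [], 0.0
    for i in kept:                            # Top-P (nucleus) cut
        nucleus.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    total = sum(probs[i] for i in nucleus)
    out = [0.0] * len(probs)
    for i in nucleus:
        out[i] = probs[i] / total             # renormalize survivors
    return out

# Tokens 0 and 1 survive (0.5 + 0.25 >= 0.7) and renormalize to 2/3, 1/3:
filtered = top_k_top_p_filter([0.5, 0.25, 0.15, 0.07, 0.03], top_k=3, top_p=0.7)
```

The sampling APIs above then draw a categorical sample from such a filtered distribution; the renorm/mask variants stop at the filtering step.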

RoPE routines added (8 APIs):

  • apply_rope / apply_rope_pos_ids - Standard RoPE with indptr/offsets or position IDs
  • apply_llama31_rope / apply_llama31_rope_pos_ids - Llama 3.1 style RoPE
  • apply_rope_with_cos_sin_cache - RoPE with precomputed cos/sin cache
  • mla_rope_quantize_fp8 - MLA RoPE with FP8 quantization (SM8.9+)
  • rope_quantize_fp8 - RoPE with FP8 quantization (SM8.9+)
  • rope_quantize_fp8_append_paged_kv_cache - RoPE with FP8 quantization and paged KV cache append (SM8.9+)
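For reference, the rotation these APIs benchmark can be sketched for a single head vector in pure Python (non-interleaved layout, where element `i` pairs with `i + head_dim/2`; this is a readability sketch of the math, not the FlashInfer implementation):

```python
import math

def apply_rope_ref(x, pos, rope_theta=10000.0, rope_scale=1.0):
    """Rotate one head vector x at position pos, non-interleaved
    layout (element i pairs with i + head_dim/2). A readability
    sketch of the math these APIs benchmark, not the FlashInfer
    implementation."""
    d = len(x)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        angle = (pos / rope_scale) * rope_theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s          # rotate the pair
        out[i + half] = x[i] * s + x[i + half] * c
    return out

# Position 0 is the identity rotation, and any rotation preserves the norm.
rotated = apply_rope_ref([1.0, 2.0, 3.0, 4.0], pos=7)
```

The Llama 3.1 variants adjust `angle` per frequency band (via the low/high frequency factors), and the FP8 variants quantize the rotated output; the core per-pair rotation is the same.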

🔍 Related Issues

#2361

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added comprehensive benchmarking for Sampling (softmax, sampling/top-k/top-p/min-p, chain speculative sampling, page-table/ragged variants) and RoPE (apply_rope, Llama‑3.1 variants, cos/sin cache, FP8 quantize, paged KV cache) workflows with CLI-driven tests and multi-backend timing/metrics.
  • Documentation

    • Expanded docs with Sampling and RoPE Quick Start, option tables, examples, and updated routine/backend support matrix.
  • Chores

    • Updated sample test listings and default case configurations for existing kernel benchmarks.

@coderabbitai bot (Contributor) commented Feb 3, 2026

📝 Walkthrough

Walkthrough

Adds comprehensive benchmarking support for two new routine families: Sampling and RoPE. Introduces new modules for sampling and rope benchmarks, updates harness routing and utilities, extends README and sample test lists with new flags, cases, and backend/support mappings.

Changes

Cohort / File(s) Summary
Documentation & Config
benchmarks/README.md, benchmarks/samples/sample_testlist.txt
Added Sampling and RoPE documentation, flags, examples, and roughly 50 new sample test entries; adjusted existing mm_fp4 entries and several batch_size defaults.
Main Harness
benchmarks/flashinfer_benchmark.py
Added sampling and rope to allowed routines; lazy-imports and dispatch to run_sampling_test / run_rope_test; extended CLI routing to parse sampling/rope args.
Benchmark Utilities
benchmarks/routines/flashinfer_benchmark_utils.py
Updated output_column_dict / full_output_columns; added sampling and rope categories; extended benchmark_apis, dtype_str_to_torch_dtype mappings, and routine_cc_to_supported_backends.
Sampling Module
benchmarks/routines/sampling.py
New module: dispatcher, parse_sampling_args, run_sampling_test, and 15+ test functions (softmax, sampling_from_probs/logits, top_k/top_p/min_p variants, renorm, mask, chain_speculative, top_k transforms, etc.) with input prep, per-backend timing, and metric reporting.
RoPE Module
benchmarks/routines/rope.py
New module: dispatcher, parse_rope_args, run_rope_test, and 8+ test functions (apply_rope variants, llama3.1 variants, cos_sin_cache, MLA/FP8 quantize, paged KV cache append) with input prep, backend filtering, timing, and metric reporting.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant MainHarness as Main Harness
    participant Parser as Routine Parser
    participant Dispatcher as Sampling Dispatcher
    participant TestFunc as Test Function
    participant Backend as GPU Backend

    User->>MainHarness: flashinfer_benchmark.py --routine sampling_...
    MainHarness->>Parser: parse_sampling_args(cli_args)
    Parser-->>MainHarness: args
    MainHarness->>Dispatcher: run_sampling_test(args)
    Dispatcher->>TestFunc: select and call testSamplingFromX(args)
    TestFunc->>TestFunc: prepare tensors/logits/probs
    TestFunc->>TestFunc: filter backends by compute capability
    loop per backend
        TestFunc->>Backend: execute kernel via bench_gpu_time
        Backend-->>TestFunc: timing/stats
    end
    TestFunc->>TestFunc: compute TB/s and metrics
    TestFunc-->>MainHarness: optional result dict
```
```mermaid
sequenceDiagram
    participant User
    participant MainHarness as Main Harness
    participant Parser as Routine Parser
    participant Dispatcher as RoPE Dispatcher
    participant TestFunc as Test Function
    participant Backend as GPU Backend

    User->>MainHarness: flashinfer_benchmark.py --routine apply_rope ...
    MainHarness->>Parser: parse_rope_args(cli_args)
    Parser-->>MainHarness: args
    MainHarness->>Dispatcher: run_rope_test(args)
    Dispatcher->>TestFunc: select and call testApplyRopeX(args)
    TestFunc->>TestFunc: build Q/K, indptr/pos_ids, cos_sin_cache (if used)
    TestFunc->>TestFunc: filter backends
    loop per backend
        TestFunc->>Backend: execute rope kernel via bench_gpu_time
        Backend-->>TestFunc: timing/stats
    end
    TestFunc->>TestFunc: compute TB/s and metrics
    TestFunc-->>MainHarness: optional result dict
```

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • Anerudhan
  • yzh119
  • cyx-6
  • jiahanc
  • nv-yunzheq

Poem

🐰 I hopped through code and flagged each case,

I sampled logits, twitched my whiskers with grace.
RoPE spun its coils in a rhythmic sweep,
I timed the kernels while carrots I keep.
Hooray — the benchmark rabbit counts TB/s in peace!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 55.77%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly describes the main change: expanding the microbenchmark harness to include sampling and RoPE APIs, which aligns with the extensive additions across multiple files. |
| Description check | ✅ Passed | The PR description comprehensively covers the changes: it lists all 15 sampling APIs and 8 RoPE APIs added, explains their purposes, provides clear context for the changes (issue #2361), and includes a completed checklist. |


@gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer microbenchmark suite by integrating comprehensive benchmarking capabilities for Sampling and Rotary Positional Embeddings (RoPE) APIs. This expansion addresses issue #2361, allowing developers to measure and compare the performance of various sampling strategies and RoPE implementations, including FP8 quantized versions and Llama 3.1 specific RoPE, across different configurations and backends. The changes provide a more complete performance analysis tool for critical components in large language model inference.

Highlights

  • Expanded Microbenchmark Harness: The flashinfer_benchmark.py harness has been significantly expanded to include comprehensive benchmarking capabilities for both Sampling and Rotary Positional Embeddings (RoPE) APIs.
  • New Sampling APIs Benchmarked: Fifteen new Sampling APIs are now supported, including softmax, various sampling_from_probs and sampling_from_logits methods (basic, Top-K, Top-P, Min-P, combined Top-K/Top-P), probability renormalization (top_k_renorm_probs, top_p_renorm_probs), logits masking (top_k_mask_logits), chain speculative sampling, and radix-based Top-K selection (top_k, top_k_page_table_transform, top_k_ragged_transform).
  • New RoPE APIs Benchmarked: Eight new RoPE APIs have been integrated, covering standard RoPE (apply_rope, apply_rope_pos_ids), Llama 3.1 style RoPE (apply_llama31_rope, apply_llama31_rope_pos_ids), RoPE with precomputed cos/sin cache (apply_rope_with_cos_sin_cache), and FP8 quantized RoPE versions (mla_rope_quantize_fp8, rope_quantize_fp8, rope_quantize_fp8_append_paged_kv_cache).
  • Updated Documentation and Test Lists: The benchmarks/README.md file has been updated to reflect the new API support and includes detailed flag descriptions for Sampling and RoPE. The sample_testlist.txt and sample_testlist_output.csv files have been extended with example commands and expected results for the new benchmarks.
  • Refactored Benchmark Utilities: The flashinfer_benchmark_utils.py module was refactored to generalize argument handling and output column definitions, accommodating the new routine categories and their specific parameters more efficiently.


Changelog
  • benchmarks/README.md
    • Updated the overview to mention support for Sampling and RoPE APIs.
    • Expanded the list of supported APIs to include detailed descriptions of 15 Sampling routines and 8 RoPE routines.
    • Added new sections for 'Sampling Flags' and 'RoPE Flags' to document command-line arguments specific to these new benchmarks.
    • Extended the 'Routine & Backend Support Matrix' to include the new Sampling and RoPE APIs.
  • benchmarks/flashinfer_benchmark.py
    • Added conditional logic in run_test to dispatch to run_sampling_test and run_rope_test based on the selected routine.
    • Updated the parse_args function to include the new Sampling and RoPE APIs in the list of recognized routines.
    • Added conditional argument parsing for Sampling and RoPE routines by importing parse_sampling_args and parse_rope_args.
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Refactored output_column_dict by moving common arguments to a general category and removing them from specific routine categories (attention, gemm, moe_comm, norm, quantization).
    • Added new sampling and rope categories to output_column_dict with their respective benchmark-specific arguments.
    • Updated ALL_OUTPUT_COLUMNS to include the new sampling and rope output columns.
    • Extended ROUTINE_SUPPORT_MATRIX to define backend support for all new Sampling and RoPE APIs across various CUDA compute capabilities.
  • benchmarks/routines/rope.py
    • New file, implementing the run_rope_test function which acts as a dispatcher for various RoPE API tests.
    • Includes parse_rope_args for handling command-line arguments specific to RoPE benchmarks.
    • Contains individual test functions for apply_rope, apply_rope_pos_ids, apply_llama31_rope, apply_llama31_rope_pos_ids, apply_rope_with_cos_sin_cache, mla_rope_quantize_fp8, rope_quantize_fp8, and rope_quantize_fp8_append_paged_kv_cache. Each function sets up tensors, calls the respective FlashInfer API, and collects performance metrics.
  • benchmarks/routines/sampling.py
    • New file, implementing the run_sampling_test function which acts as a dispatcher for various Sampling API tests.
    • Includes parse_sampling_args for handling command-line arguments specific to Sampling benchmarks, with conditional argument requirements for vocab_size.
    • Contains individual test functions for softmax, sampling_from_probs, sampling_from_logits, top_k_sampling_from_probs, top_p_sampling_from_probs, top_k_top_p_sampling_from_probs, top_k_top_p_sampling_from_logits, min_p_sampling_from_probs, top_k_renorm_probs, top_p_renorm_probs, top_k_mask_logits, chain_speculative_sampling, top_k, top_k_page_table_transform, and top_k_ragged_transform. Each function sets up tensors, calls the respective FlashInfer API, and collects performance metrics.
  • benchmarks/samples/sample_testlist.txt
    • Removed some duplicate or less representative benchmark commands for existing attention and GEMM routines.
    • Added numerous new benchmark commands for the newly introduced Sampling and RoPE APIs, covering various configurations and data types.
  • benchmarks/samples/sample_testlist_output.csv
    • Updated the CSV header to reflect the new arguments and output columns introduced for Sampling and RoPE.
    • Added new rows with sample benchmark results for the new Sampling and RoPE routines.
Activity
  • The pull request was created by bkryu. No further human activity (comments, reviews, or updates) has been recorded since its creation.

@bkryu added the `benchmark` (Pertains to performance benchmarking) and `op: misc` labels on Feb 3, 2026
@bkryu (Collaborator, Author) commented Feb 3, 2026

@kahyunnam, can I get a review from you?

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request significantly expands the FlashInfer microbenchmark harness by adding comprehensive support for Sampling and RoPE APIs. The changes are well-structured, introducing new modules for the added routines and updating documentation and utility files accordingly. My review identified a few areas for improvement. There's a minor inconsistency in parameter organization within flashinfer_benchmark_utils.py. Both new benchmark files, rope.py and sampling.py, contain considerable code duplication that could be refactored for better maintainability. Additionally, rope.py has a misleading comment and a potential bug related to inconsistent parameter handling. Addressing these points will enhance the quality and reliability of the new benchmarks.

num_kv_heads = args.num_kv_heads
head_dim = args.head_dim
rotary_dim = args.rotary_dim
no_rope_dim = head_dim - rotary_dim # For GQA/MHA

high

The no_rope_dim is calculated here as head_dim - rotary_dim, which ignores the --no_rope_dim argument parsed from the command line. Other test functions in this file, like testRopeQuantizeFp8, correctly use args.no_rope_dim. This inconsistency can lead to incorrect benchmark configurations and user confusion. Please use args.no_rope_dim and consider adding an assertion like assert args.head_dim == args.rotary_dim + args.no_rope_dim to ensure consistency.

Suggested change
no_rope_dim = head_dim - rotary_dim # For GQA/MHA
no_rope_dim = args.no_rope_dim

],
"gemm": [
"m",
"n",

medium

For consistency with the refactoring that moved m and k to the general category, the n parameter should also be moved from gemm to general. This would group all fundamental GEMM dimensions (m, n, k) together in the common parameters.

Comment on lines +33 to +61
def run_rope_test(args):
"""
Run a RoPE test.

Args:
args: Parsed command line arguments containing test configuration

Returns:
dict: List of dictionaries containing performance results
"""
if args.routine == "apply_rope":
return testApplyRope(args)
elif args.routine == "apply_rope_pos_ids":
return testApplyRopePosIds(args)
elif args.routine == "apply_llama31_rope":
return testApplyLlama31Rope(args)
elif args.routine == "apply_llama31_rope_pos_ids":
return testApplyLlama31RopePosIds(args)
elif args.routine == "apply_rope_with_cos_sin_cache":
return testApplyRopeWithCosSinCache(args)
elif args.routine == "mla_rope_quantize_fp8":
return testMlaRopeQuantizeFp8(args)
elif args.routine == "rope_quantize_fp8":
return testRopeQuantizeFp8(args)
elif args.routine == "rope_quantize_fp8_append_paged_kv_cache":
return testRopeQuantizeFp8AppendPagedKvCache(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")


medium

There is significant code duplication across the test functions in this file (e.g., testApplyRope, testApplyRopePosIds, etc.). The boilerplate for benchmarking, calculating performance metrics, and storing results is repeated in each function. Consider creating a generic test runner helper function to reduce this duplication. This would improve maintainability and readability.

4. Measures performance metrics (TB/sec)

Note: This API takes pre-split q_rope, k_rope, q_nope, k_nope tensors
and a precomputed cos_sin_cache. It is the same as rope_quantize_fp8.

medium

This comment states that mla_rope_quantize_fp8 is the same as rope_quantize_fp8. This is misleading, as they handle different K-tensor shapes for MLA and GQA/MHA architectures respectively. Please clarify or remove this comment to avoid confusion.

Suggested change
and a precomputed cos_sin_cache. It is the same as rope_quantize_fp8.
and a precomputed cos_sin_cache.

Comment on lines +33 to +75
def run_sampling_test(args):
"""
Run a sampling test.

Args:
args: Parsed command line arguments containing test configuration

Returns:
dict: List of dictionaries containing performance results
"""
if args.routine == "softmax":
return testSoftmax(args)
elif args.routine == "sampling_from_probs":
return testSamplingFromProbs(args)
elif args.routine == "sampling_from_logits":
return testSamplingFromLogits(args)
elif args.routine == "top_k_sampling_from_probs":
return testTopKSamplingFromProbs(args)
elif args.routine == "top_p_sampling_from_probs":
return testTopPSamplingFromProbs(args)
elif args.routine == "top_k_top_p_sampling_from_probs":
return testTopKTopPSamplingFromProbs(args)
elif args.routine == "top_k_top_p_sampling_from_logits":
return testTopKTopPSamplingFromLogits(args)
elif args.routine == "min_p_sampling_from_probs":
return testMinPSamplingFromProbs(args)
elif args.routine == "top_k_renorm_probs":
return testTopKRenormProbs(args)
elif args.routine == "top_p_renorm_probs":
return testTopPRenormProbs(args)
elif args.routine == "top_k_mask_logits":
return testTopKMaskLogits(args)
elif args.routine == "chain_speculative_sampling":
return testChainSpeculativeSampling(args)
elif args.routine == "top_k":
return testTopK(args)
elif args.routine == "top_k_page_table_transform":
return testTopKPageTableTransform(args)
elif args.routine == "top_k_ragged_transform":
return testTopKRaggedTransform(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")


medium

Similar to rope.py, there is a large amount of duplicated code across the test functions in this file. The boilerplate for benchmarking, calculating performance metrics, and storing results is repeated. Refactoring this into a common helper function would greatly improve the maintainability of this file.
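A shared runner could absorb the repeated boilerplate along these lines; a rough sketch, with `time.perf_counter` standing in for the harness's `bench_gpu_time` GPU timing (all names here are hypothetical):

```python
import time

def bench_runner(routine_name, prepare_inputs, kernel, iters=10):
    """Factor the repeated per-test boilerplate into one helper:
    build inputs once, warm up, time the kernel, and emit a result
    row. time.perf_counter stands in for bench_gpu_time / CUDA-event
    timing in the real harness."""
    inputs = prepare_inputs()
    kernel(*inputs)                              # warmup
    start = time.perf_counter()
    for _ in range(iters):
        kernel(*inputs)
    median_ms = (time.perf_counter() - start) * 1e3 / iters
    return {"routine": routine_name, "median_time_ms": median_ms}

# Each test function then shrinks to its unique parts: input prep and
# the kernel call (a trivial CPU stand-in here).
row = bench_runner("softmax_cpu_stand_in",
                   lambda: ([0.1, 0.7, 0.2],),
                   lambda x: max(x))
```

With this shape, each `test...` function supplies only `prepare_inputs` and the kernel lambda, and metric computation lives in one place.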

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 344-381: The markdown tables under the headings "### Sampling
Flags" and "### RoPE Flags" start immediately after the headings and before the
following text, triggering MD058; add a single blank line between each heading
and its table, and ensure there is a blank line after each table (i.e., add one
empty line before the table under "### Sampling Flags" and one after it, and do
the same for the table under "### RoPE Flags") so both tables have blank lines
before and after to satisfy markdownlint.

In `@benchmarks/routines/rope.py`:
- Around line 872-946: cos_sin_cache is created with input_dtype (often float16)
but apply_rope_with_cos_sin_cache requires float32, causing a ValueError;
recreate cos_sin_cache with dtype=torch.float32 on the same device and pass that
into run_backend/apply_rope_with_cos_sin_cache, and update the memory bandwidth
calc in problem_bytes to use 4 bytes per element for the cos_sin_cache read
(max_seq_len * rotary_dim * 4) instead of input_dtype.itemsize so the reported
TB/sec matches the actual float32 cache size.

Comment on lines +344 to +381
### Sampling Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--vocab_size` | Vocabulary size |
| `--input_dtype` | Input data type for logits: `float32` (default), `float16`, or `bfloat16` |
| `--top_k` | Top-K value for top-k sampling. Default: 50 |
| `--top_p` | Top-P threshold for top-p (nucleus) sampling. Default: 0.9 |
| `--min_p` | Min-P threshold for min-p sampling. Default: 0.1 |
| `--temperature` | Temperature for softmax. Default: 1.0 |
| `--filter_apply_order` | Order of applying top-k and top-p filters: `top_k_first` (default) or `joint` |
| `--num_speculate_tokens` | Number of speculative tokens for chain speculative sampling. Default: 5 |
| `--max_len` | Max sequence length for `top_k_page_table_transform` and `top_k_ragged_transform`. Default: 4096 |
| `--num_rows` | Number of rows for `top_k_page_table_transform` and `top_k_ragged_transform`. Defaults to batch_size |
| `--backends` | Backend to test: `cuda` (default) |

### RoPE Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--seq_len` | Sequence length (qkv_len or kv_len) |
| `--num_qo_heads` | Number of query/output heads |
| `--num_kv_heads` | Number of key/value heads |
| `--head_dim` | Head dimension |
| `--rotary_dim` | Rotary dimension (defaults to head_dim if not specified) |
| `--no_rope_dim` | Number of dimensions without RoPE (for MLA). Default: 0 |
| `--input_dtype` | Input data type: `float16` (default) or `bfloat16` |
| `--quant_dtype` | Quantized data type for FP8 routines: `fp8_e4m3` (default) or `fp8_e5m2` |
| `--rope_scale` | RoPE scaling factor. Default: 1.0 |
| `--rope_theta` | RoPE theta base frequency. Default: 10000.0 |
| `--interleave` | Use interleaved rotary embedding (GPT-J style) |
| `--page_size` | Page size for paged KV cache. Default: 16 |
| `--kv_layout` | KV cache layout: `NHD` (default) or `HND` |
| `--low_freq_factor` | Low frequency factor for Llama 3.1 RoPE. Default: 1.0 |
| `--high_freq_factor` | High frequency factor for Llama 3.1 RoPE. Default: 4.0 |
| `--old_context_len` | Old context length for Llama 3.1 RoPE. Default: 8192 |
| `--backends` | Backend to test: `cuda` (default) |


⚠️ Potential issue | 🟡 Minor

Add blank lines around the new tables to satisfy markdownlint.
Lines 345 and 361 start tables immediately after headings; MD058 expects blank lines before/after tables.

📝 Proposed markdownlint-friendly spacing
### Sampling Flags
+
| Flag                     | Description                                                                                                 |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size`           | Batch size (number of sequences)                                                                           |
| `--vocab_size`           | Vocabulary size                                                                                            |
| `--input_dtype`          | Input data type for logits: `float32` (default), `float16`, or `bfloat16`                                  |
| `--top_k`                | Top-K value for top-k sampling. Default: 50                                                                |
| `--top_p`                | Top-P threshold for top-p (nucleus) sampling. Default: 0.9                                                 |
| `--min_p`                | Min-P threshold for min-p sampling. Default: 0.1                                                           |
| `--temperature`          | Temperature for softmax. Default: 1.0                                                                      |
| `--filter_apply_order`   | Order of applying top-k and top-p filters: `top_k_first` (default) or `joint`                              |
| `--num_speculate_tokens` | Number of speculative tokens for chain speculative sampling. Default: 5                                    |
| `--max_len`              | Max sequence length for `top_k_page_table_transform` and `top_k_ragged_transform`. Default: 4096           |
| `--num_rows`             | Number of rows for `top_k_page_table_transform` and `top_k_ragged_transform`. Defaults to batch_size       |
| `--backends`             | Backend to test: `cuda` (default)                                                                          |
+
### RoPE Flags
+
| Flag                     | Description                                                                                                 |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size`           | Batch size (number of sequences)                                                                           |
| `--seq_len`              | Sequence length (qkv_len or kv_len)                                                                        |
| `--num_qo_heads`         | Number of query/output heads                                                                               |
| `--num_kv_heads`         | Number of key/value heads                                                                                  |
| `--head_dim`             | Head dimension                                                                                             |
| `--rotary_dim`           | Rotary dimension (defaults to head_dim if not specified)                                                   |
| `--no_rope_dim`          | Number of dimensions without RoPE (for MLA). Default: 0                                                    |
| `--input_dtype`          | Input data type: `float16` (default) or `bfloat16`                                                         |
| `--quant_dtype`          | Quantized data type for FP8 routines: `fp8_e4m3` (default) or `fp8_e5m2`                                   |
| `--rope_scale`           | RoPE scaling factor. Default: 1.0                                                                          |
| `--rope_theta`           | RoPE theta base frequency. Default: 10000.0                                                                |
| `--interleave`           | Use interleaved rotary embedding (GPT-J style)                                                             |
| `--page_size`            | Page size for paged KV cache. Default: 16                                                                  |
| `--kv_layout`            | KV cache layout: `NHD` (default) or `HND`                                                                  |
| `--low_freq_factor`      | Low frequency factor for Llama 3.1 RoPE. Default: 1.0                                                      |
| `--high_freq_factor`     | High frequency factor for Llama 3.1 RoPE. Default: 4.0                                                     |
| `--old_context_len`      | Old context length for Llama 3.1 RoPE. Default: 8192                                                       |
| `--backends`             | Backend to test: `cuda` (default)                                                                          |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
### Sampling Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--vocab_size` | Vocabulary size |
| `--input_dtype` | Input data type for logits: `float32` (default), `float16`, or `bfloat16` |
| `--top_k` | Top-K value for top-k sampling. Default: 50 |
| `--top_p` | Top-P threshold for top-p (nucleus) sampling. Default: 0.9 |
| `--min_p` | Min-P threshold for min-p sampling. Default: 0.1 |
| `--temperature` | Temperature for softmax. Default: 1.0 |
| `--filter_apply_order` | Order of applying top-k and top-p filters: `top_k_first` (default) or `joint` |
| `--num_speculate_tokens` | Number of speculative tokens for chain speculative sampling. Default: 5 |
| `--max_len` | Max sequence length for `top_k_page_table_transform` and `top_k_ragged_transform`. Default: 4096 |
| `--num_rows` | Number of rows for `top_k_page_table_transform` and `top_k_ragged_transform`. Defaults to batch_size |
| `--backends` | Backend to test: `cuda` (default) |
### RoPE Flags
| Flag | Description |
|--------------------------|-------------------------------------------------------------------------------------------------------------|
| `--batch_size` | Batch size (number of sequences) |
| `--seq_len` | Sequence length (qkv_len or kv_len) |
| `--num_qo_heads` | Number of query/output heads |
| `--num_kv_heads` | Number of key/value heads |
| `--head_dim` | Head dimension |
| `--rotary_dim` | Rotary dimension (defaults to head_dim if not specified) |
| `--no_rope_dim` | Number of dimensions without RoPE (for MLA). Default: 0 |
| `--input_dtype` | Input data type: `float16` (default) or `bfloat16` |
| `--quant_dtype` | Quantized data type for FP8 routines: `fp8_e4m3` (default) or `fp8_e5m2` |
| `--rope_scale` | RoPE scaling factor. Default: 1.0 |
| `--rope_theta` | RoPE theta base frequency. Default: 10000.0 |
| `--interleave` | Use interleaved rotary embedding (GPT-J style) |
| `--page_size` | Page size for paged KV cache. Default: 16 |
| `--kv_layout` | KV cache layout: `NHD` (default) or `HND` |
| `--low_freq_factor` | Low frequency factor for Llama 3.1 RoPE. Default: 1.0 |
| `--high_freq_factor` | High frequency factor for Llama 3.1 RoPE. Default: 4.0 |
| `--old_context_len` | Old context length for Llama 3.1 RoPE. Default: 8192 |
| `--backends` | Backend to test: `cuda` (default) |
🧰 Tools
🪛 LanguageTool

[uncategorized] ~378-~378: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ... | | --high_freq_factor | High frequency factor for Llama 3.1 RoPE. Default: 4.0...

(EN_COMPOUND_ADJECTIVE_INTERNAL)

🪛 markdownlint-cli2 (0.20.0)

[warning] 345-345: Tables should be surrounded by blank lines

(MD058, blanks-around-tables)


[warning] 361-361: Tables should be surrounded by blank lines

(MD058, blanks-around-tables)

🤖 Prompt for AI Agents
In `@benchmarks/README.md` around lines 344 - 381, The markdown tables under the
headings "### Sampling Flags" and "### RoPE Flags" start immediately after the
headings and before the following text, triggering MD058; add a single blank
line between each heading and its table, and ensure there is a blank line after
each table (i.e., add one empty line before the table under "### Sampling Flags"
and one after it, and do the same for the table under "### RoPE Flags") so both
tables have blank lines before and after to satisfy markdownlint.
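The `--filter_apply_order` flag in the sampling table controls whether the top-k mask is applied before the top-p cutoff (`top_k_first`) or both are fused (`joint`). A minimal pure-Python sketch of the `top_k_first` order, purely illustrative (the function name is made up and this is not the FlashInfer kernel):

```python
def top_k_top_p_filter(probs, top_k, top_p):
    """Apply a top-k mask, then a top-p (nucleus) cutoff, then renormalize."""
    # Indices sorted by probability, largest first.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept = order[:top_k]                      # top-k: only the k largest survive
    mass = sum(probs[i] for i in kept)
    # top-p: keep the smallest prefix of survivors whose renormalized mass
    # reaches top_p.
    cum, nucleus = 0.0, []
    for i in kept:
        nucleus.append(i)
        cum += probs[i] / mass
        if cum >= top_p:
            break
    total = sum(probs[i] for i in nucleus)
    out = [0.0] * len(probs)
    for i in nucleus:
        out[i] = probs[i] / total
    return out
```

With `probs=[0.5, 0.3, 0.1, 0.1]`, `top_k=3`, `top_p=0.8`, the top-k step keeps the first three entries, and the nucleus step keeps only the first two, yielding `[0.625, 0.375, 0.0, 0.0]`.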

@yzh119
Collaborator

yzh119 commented Feb 4, 2026

#2374 is another relevant PR

Collaborator

@kahyunnam kahyunnam left a comment

LGTM. Will be very useful for rope, thanks for implementing

@bkryu bkryu force-pushed the bench_sampling_rope branch from 95886b6 to 73d2f0a Compare February 5, 2026 00:46
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@benchmarks/routines/rope.py`:
- Around line 99-207: Validate parsed args so rotary_dim and no_rope_dim cannot
produce negative or out-of-range head dims: after you set args.rotary_dim =
args.head_dim (when None) add checks that args.no_rope_dim >= 0, args.rotary_dim
>= 0, args.no_rope_dim <= args.head_dim, args.rotary_dim <= args.head_dim, and
that args.rotary_dim + args.no_rope_dim <= args.head_dim; if any check fails
raise argparse.ArgumentTypeError or exit with a clear error message referencing
rotary_dim, no_rope_dim, and head_dim so the CLI user sees which values are
invalid.

In `@benchmarks/routines/sampling.py`:
- Around line 1567-1590: The src_page_table is created with batch_size rows but
must align with num_rows for testTopKPageTableTransform where input_scores has
shape (num_rows, max_len) and lengths is (num_rows,); change the src_page_table
creation in benchmarks/routines/sampling.py so src_page_table is sized
(num_rows, max_len) (use num_rows instead of batch_size) and keep
dtype=torch.int32 and device=device so it matches input_scores/lengths before
calling flashinfer.top_k_page_table_transform; also verify any subsequent
memory-bandwidth or shape assumptions reference num_rows rather than batch_size.

Comment on lines +99 to +207
"--head_dim",
type=int,
required=True,
help="Head dimension.",
)
parser.add_argument(
"--rotary_dim",
type=int,
required=False,
default=None,
help="Rotary dimension (defaults to head_dim if not specified).",
)
parser.add_argument(
"--no_rope_dim",
type=int,
required=False,
default=0,
help="Number of dimensions without RoPE (for MLA). Default: 0.",
)
parser.add_argument(
"--input_dtype",
type=str,
required=False,
default="float16",
choices=["float16", "bfloat16"],
help="Data type of the input tensor.",
)
parser.add_argument(
"--quant_dtype",
type=str,
required=False,
default="fp8_e4m3",
choices=["fp8_e4m3", "fp8_e5m2"],
help="Quantized data type for FP8 routines.",
)
parser.add_argument(
"--rope_scale",
type=float,
required=False,
default=1.0,
help="RoPE scaling factor.",
)
parser.add_argument(
"--rope_theta",
type=float,
required=False,
default=10000.0,
help="RoPE theta base frequency.",
)
parser.add_argument(
"--interleave",
action="store_true",
help="Use interleaved rotary embedding (GPT-J style).",
)
parser.add_argument(
"--page_size",
type=int,
required=False,
default=16,
help="Page size for paged KV cache.",
)
parser.add_argument(
"--kv_layout",
type=str,
required=False,
default="NHD",
choices=["NHD", "HND"],
help="KV cache layout.",
)
parser.add_argument(
"--low_freq_factor",
type=float,
required=False,
default=1.0,
help="Low frequency factor for Llama 3.1 RoPE.",
)
parser.add_argument(
"--high_freq_factor",
type=float,
required=False,
default=4.0,
help="High frequency factor for Llama 3.1 RoPE.",
)
parser.add_argument(
"--old_context_len",
type=int,
required=False,
default=8192,
help="Old context length for Llama 3.1 RoPE.",
)
parser.add_argument(
"--backends",
type=str,
required=False,
nargs="+",
default=["cuda"],
choices=["cuda"],
help="Kernel backends to test. Default: cuda",
)

args = parser.parse_args(line)

# Default rotary_dim to head_dim if not specified
if args.rotary_dim is None:
args.rotary_dim = args.head_dim

if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
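The `--low_freq_factor`, `--high_freq_factor`, and `--old_context_len` flags parameterize Llama 3.1's wavelength-based frequency rescaling. A hedged sketch of that rescaling for a single inverse frequency, following the commonly published Llama 3.1 reference formulation rather than FlashInfer's actual kernel:

```python
import math

def llama31_scale_freq(freq, rope_scale, low_freq_factor=1.0,
                       high_freq_factor=4.0, old_context_len=8192):
    """Rescale one RoPE inverse frequency, Llama 3.1 style."""
    wavelen = 2 * math.pi / freq
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    if wavelen < high_freq_wavelen:   # high-frequency band: left untouched
        return freq
    if wavelen > low_freq_wavelen:    # low-frequency band: fully scaled
        return freq / rope_scale
    # Mid band: linearly interpolate between the two regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return (1 - smooth) * freq / rope_scale + smooth * freq
```

With the defaults above, any frequency whose wavelength is below 2048 tokens passes through unchanged, and wavelengths above 8192 are divided by `rope_scale`.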
Contributor

⚠️ Potential issue | 🟡 Minor

Validate rotary_dim/no_rope_dim against head_dim to avoid negative tensor shapes.
Invalid CLI values can lead to negative dims in MLA/FP8 paths and crash with cryptic tensor errors.

✅ Suggested validation
     # Default rotary_dim to head_dim if not specified
     if args.rotary_dim is None:
         args.rotary_dim = args.head_dim
+
+    if args.rotary_dim < 0 or args.rotary_dim > args.head_dim:
+        raise ValueError("--rotary_dim must be in [0, head_dim].")
+    if args.no_rope_dim < 0 or args.no_rope_dim > args.head_dim:
+        raise ValueError("--no_rope_dim must be in [0, head_dim].")

Comment on lines +1567 to +1590
    backends = args.backends[:]
    batch_size = args.batch_size
    num_rows = args.num_rows
    max_len = args.max_len
    top_k = args.top_k
    is_cuda_graph_compatible = not args.no_cuda_graph
    res = []

    backends = filter_backends_by_compute_capability(backends, args.routine, device)
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)

    ## Prepare input tensors
    # Input scores: (num_rows, max_len)
    input_scores = torch.randn(num_rows, max_len, dtype=input_dtype, device=device)

    # Source page table: (batch_size, max_len)
    src_page_table = torch.randint(
        0, 1000, (batch_size, max_len), dtype=torch.int32, device=device
    )

Contributor
⚠️ Potential issue | 🟠 Major


Align src_page_table rows with num_rows, not batch_size.

In testTopKPageTableTransform, input_scores is shaped (num_rows, max_len) and lengths is shaped (num_rows,), but src_page_table uses (batch_size, max_len). All three tensors are passed together to flashinfer.top_k_page_table_transform(). Since lengths corresponds to num_rows sequences, the page table must also have num_rows rows; if num_rows differs from batch_size, the current shape will cause a mismatch. The memory bandwidth calculation explicitly references both dimensions separately, confirming they are intended to vary independently and both should align with the sequence count (num_rows).

Suggested fix
-    src_page_table = torch.randint(
-        0, 1000, (batch_size, max_len), dtype=torch.int32, device=device
-    )
+    src_page_table = torch.randint(
+        0, 1000, (num_rows, max_len), dtype=torch.int32, device=device
+    )
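A cheap guard against this class of shape bug is asserting the shared leading dimension before calling the kernel. A sketch with plain tuples standing in for tensor shapes (the helper name is hypothetical, not part of the harness):

```python
def check_row_alignment(scores_shape, lengths_shape, page_table_shape):
    """All three inputs must agree on the number of rows (sequences)."""
    num_rows = scores_shape[0]
    if lengths_shape[0] != num_rows or page_table_shape[0] != num_rows:
        raise ValueError(
            f"row mismatch: scores={scores_shape[0]}, "
            f"lengths={lengths_shape[0]}, page_table={page_table_shape[0]}"
        )
```

Had this run before `flashinfer.top_k_page_table_transform`, the `batch_size` vs. `num_rows` mismatch would have surfaced as a clear error instead of a kernel-side failure.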

@bkryu bkryu mentioned this pull request Feb 5, 2026
5 tasks
@bkryu
Collaborator Author

bkryu commented Feb 5, 2026

@vincentzed , please check commit de91629 for newly added refchecks in sampling APIs.

cc @kahyunnam there was a suggestion to add refchecks to the sampling APIs so I added them.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/sampling.py`:
- Line 1666: The unused variable ref_indices in the reference check should be
ignored to satisfy Ruff; change the unpacking in the torch.topk call (the line
assigning ref_values, ref_indices) to use a throwaway name (e.g., replace
ref_indices with _ or _ref_indices) so only ref_values is considered and the
linter warning is removed.

    # Note: FlashInfer top_k returns UNSORTED results by default, so we compare
    # sorted values to verify the same elements are selected
    if run_refcheck and outputs:
        ref_values, ref_indices = torch.topk(input_tensor.float(), k=top_k, dim=-1)
Contributor

⚠️ Potential issue | 🟡 Minor

Drop unused ref_indices to satisfy Ruff.

Ruff flags the unused variable in the reference check; replace it with _ (or _ref_indices).

♻️ Proposed fix
-        ref_values, ref_indices = torch.topk(input_tensor.float(), k=top_k, dim=-1)
+        ref_values, _ = torch.topk(input_tensor.float(), k=top_k, dim=-1)
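The sorted-values comparison is valid because top-k selection is order-insensitive: sorting both sides makes an unsorted kernel output comparable to a sorted reference. A stdlib sketch of the same refcheck idea (the score values here are made up):

```python
import heapq

scores = [0.1, 0.9, 0.3, 0.7, 0.5, 0.2]
k = 3
kernel_out = [0.5, 0.9, 0.7]            # stand-in for an unsorted kernel result
ref_values = heapq.nlargest(k, scores)  # reference top-k, sorted descending
# Same multiset of values => same elements were selected.
assert sorted(kernel_out) == sorted(ref_values)
```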

Suggested change
        ref_values, ref_indices = torch.topk(input_tensor.float(), k=top_k, dim=-1)
        ref_values, _ = torch.topk(input_tensor.float(), k=top_k, dim=-1)
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 1666-1666: Unpacked variable ref_indices is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


@bkryu bkryu merged commit 1748eb5 into flashinfer-ai:main Feb 5, 2026
31 checks passed
@bkryu bkryu deleted the bench_sampling_rope branch February 6, 2026 17:23

Labels

benchmark (Pertains to performance benchmarking), op: misc

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants