
benchmark: Enable speculative decode microbenchmarking for paged decode#2628

Merged
bkryu merged 2 commits into flashinfer-ai:main from bkryu:bench_specdec
Feb 25, 2026
Conversation


@bkryu bkryu commented Feb 24, 2026

📌 Description

No change to library code

  • Add speculative decode support to BatchDecodeWithPagedKVCacheWrapper benchmark path by handling --s_qo > 1 (multi-token query shape, packed causal mask generation, and q_len_per_req forwarding).
  • Update benchmarks/README.md with speculative decode docs.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Enabled speculative decoding in benchmark tests with conditional execution on supported backends and per-request sizing.
  • Documentation

    • Clarified speculative decoding configuration in benchmark docs and updated CSV help text explaining s_qo behavior.
  • Tests

    • Updated benchmark routines and metrics to report speculative decoding behavior and adjusted test flows to skip unsupported backends.

@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the benchmarking capabilities by introducing support for speculative decode in the paged KV cache decode attention wrapper. This allows for performance evaluation of multi-token query scenarios, which is crucial for advanced decoding strategies. The changes ensure that the benchmarking infrastructure correctly configures and passes necessary parameters, such as causal masks and query lengths, to the underlying attention kernels, while also gracefully handling backend limitations.

Highlights

  • Speculative Decode Support: Enabled speculative decode microbenchmarking for the BatchDecodeWithPagedKVCacheWrapper by handling multi-token query shapes (--s_qo > 1), generating packed causal masks, and forwarding q_len_per_req to attention kernels.
  • Backend Compatibility: Implemented checks to skip backends (FA2, FA2_TC, cuDNN) that do not currently support speculative decode during benchmarking.
  • Documentation Update: Updated the benchmarks/README.md to reflect the new speculative decode capabilities and the usage of the --s_qo parameter.


Changelog
  • benchmarks/README.md
    • Updated the description for BatchDecodeWithPagedKVCacheWrapper to mention speculative decode support when --s_qo > 1.
    • Clarified the --s_qo parameter description to distinguish between standard decode (1) and speculative decode (>1).
  • benchmarks/routines/attention.py
    • Modified the help message for the --s_qo argument to indicate its role in enabling speculative decode.
    • Added a new function generate_speculative_causal_mask to create packed causal masks for multi-token speculative decode.
    • Introduced a speculative_decode flag derived from s_qo to control speculative decode logic.
    • Added conditional logic to skip FA2, FA2_TC, and cuDNN backends if speculative decode is enabled, as they lack support.
    • Adjusted the query tensor q creation to accommodate batch_size * s_qo for multi-token queries.
    • Updated the ragged_q calculation to correctly reflect the s_qo dimension.
    • Integrated the generate_speculative_causal_mask function to create and pass speculative_mask to attention kernels when needed.
    • Modified run_backend_wrapper to pass q_len_per_req and mask arguments to the flashinfer.decode.batch_decode_with_paged_kv_cache function.
    • Updated the calculation of actual_seq_lens_q_flat to use s_qo for accurate TFLOPS computation in speculative decode scenarios.
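The packed causal mask mentioned in the changelog can be pictured with a small sketch. The following is a hypothetical, dependency-free rendition of what a helper like generate_speculative_causal_mask might compute; the real function in benchmarks/routines/attention.py takes a device argument and returns a packed uint16 tensor, so the return type and exact bit layout here are assumptions based on the review notes:

```python
def generate_speculative_causal_mask(batch_size: int, q_seq_len: int):
    """Hypothetical sketch of a packed causal mask for speculative chunks.

    Within a speculative chunk, query row i may attend to chunk
    positions 0..i, so each row packs into the bitmask (1 << (i + 1)) - 1.
    Plain Python ints stand in for the uint16 tensor the actual helper
    is described as returning.
    """
    assert 1 <= q_seq_len <= 16, "uint16 packing holds at most 16 positions"
    rows = [(1 << (i + 1)) - 1 for i in range(q_seq_len)]
    # One identical row pattern per request in the batch.
    return rows * batch_size
```

For example, with batch_size=2 and q_seq_len=3 the row bitmasks are 0b1, 0b11, 0b111 repeated per request.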
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


coderabbitai bot commented Feb 24, 2026

📝 Walkthrough

Walkthrough

Adds speculative decoding support to the benchmarks: introduces generate_speculative_causal_mask(), propagates packed causal masks and per-request q_len (s_qo) through backend paths, alters query shaping/metrics, and documents enabling speculative decode via --s_qo > 1. Unsupported backends are skipped for speculative mode.

Changes

Cohort / File(s) Summary
Documentation
benchmarks/README.md
Documented that --s_qo > 1 enables speculative decode on supported backends and clarified --s_qo in output CSV description.
Speculative decode & test harness
benchmarks/routines/attention.py
Added generate_speculative_causal_mask(batch_size, q_seq_len, device); updated testBatchDecodeWithPagedKVCacheWrapper to gate speculative_decode (s_qo>1), skip unsupported backends (fa2, fa2_tc, cudnn, auto), adjust query/ragged shapes, pass q_len_per_req and mask into backend wrappers/native backends, and reflect s_qo in metrics and sequence-length bookkeeping.
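The backend-pruning step described above can be sketched as follows. The function name and log format are illustrative, not the exact code in benchmarks/routines/attention.py; the unsupported set mirrors the backends named in this PR (fa2, fa2_tc, cudnn, plus auto per the review follow-up):

```python
def filter_backends_for_speculative_decode(backends, s_qo):
    # s_qo > 1 turns on speculative decode; backends that lack
    # support for it are logged and skipped.
    if s_qo <= 1:
        return list(backends)
    unsupported = {"fa2", "fa2_tc", "cudnn", "auto"}
    kept = []
    for backend in backends:
        if backend in unsupported:
            print(
                f"[INFO] {backend} backend does not support "
                "speculative decode. Skipping."
            )
        else:
            kept.append(backend)
    return kept
```

With s_qo=1 the list passes through unchanged; with s_qo > 1 only speculative-capable backends (e.g. trtllm-gen, trtllm-native) remain.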

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Function
    participant MaskGen as Mask Generator
    participant Backend as Backend Runner
    participant Metrics as Metrics Calculator

    Test->>Test: Determine speculative_decode (s_qo > 1)
    alt speculative_decode == true
        Test->>MaskGen: generate_speculative_causal_mask(batch_size, s_qo, device)
        MaskGen-->>Test: packed uint16 mask tensor
        loop per backend
            Test->>Backend: check support (fa2/fa2_tc/cudnn/auto)
            alt supported
                Test->>Backend: run with q_len_per_req=s_qo, mask (if native)
                Backend-->>Test: decoded outputs
            else unsupported
                Test-->>Test: log and skip backend
            end
        end
        Test->>Metrics: compute tflops/TB/s using s_qo-aware lengths
        Metrics-->>Test: performance metrics
    else
        Test->>Backend: run standard decode (no mask)
        Backend-->>Test: decoded outputs
    end
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • cyx-6
  • jiahanc
  • Anerudhan
  • nv-yunzheq
  • kahyunnam
  • yzh119

Poem

🐰 I stitched a mask in tiny rows,
Bits aligned where forward goes,
Speculative hops across the queue,
Backends dance, the tokens flew,
A little rabbit cheers for you!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely describes the main change: enabling speculative decode microbenchmarking for the paged decode benchmark path, which aligns with the core modifications in the PR.
Description check ✅ Passed The description covers the key changes (speculative decode support, masked generation, q_len_per_req forwarding) and documentation updates. Pre-commit and test checkboxes are marked complete. Minor non-critical sections are empty but the essential content is present.
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.



@gemini-code-assist bot left a comment

Code Review

This pull request enables speculative decode microbenchmarking for the paged decode path in the attention benchmark routine. It introduces a utility to generate packed causal masks for speculative chunks and updates the benchmark logic to handle multi-token queries, including backend-specific constraints and performance metric calculations. The changes are well-integrated into the existing benchmarking framework, though there are some concerns regarding the data type and layout of the generated speculative mask and the consistency of masking across different backend wrappers.

Comment on lines 596 to 598

    return backend_wrappers[backend].run(
-       q, kv_cache, k_scale=k_scale, v_scale=v_scale
+       q, kv_cache, k_scale=k_scale, v_scale=v_scale, q_len_per_req=s_qo
    )
@gemini-code-assist bot (severity: medium)

The BatchDecodeWithPagedKVCacheWrapper.run method (as seen in flashinfer/decode.py) does not currently accept a mask parameter. Consequently, when speculative_decode is enabled, the auto and trtllm-gen backends (which use this wrapper) are executing unmasked attention. This is inconsistent with the trtllm-native path (line 626) which correctly applies the causal mask. This inconsistency will lead to misleading performance results and incorrect outputs if refcheck were enabled for these paths.

@bkryu (Collaborator, Author) replied:

Fair point. Will disallow backend='auto' for speculative decoding.

@coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)

320-377: ⚠️ Potential issue | 🔴 Critical

Guard auto backend from speculative decode.

The code explicitly removes fa2, fa2_tc, and cudnn when speculative_decode=True (when s_qo > 1), but auto is never removed. During plan(), the auto backend resolves to fa2 or fa3 based on device architecture and dtype (flashinfer/decode.py:393-401), without considering the speculative decode constraint. When run() is later called with q_len_per_req=s_qo > 1, the resolved backend may not support this configuration, leading to incorrect behavior.

Add the suggested guard to filter out auto when speculative_decode=True:

Suggested fix
     if "cudnn" in backends:
         remove_cudnn = False
         if speculative_decode:
             print("[INFO] cuDNN backend does not support speculative decode. Skipping.")
             remove_cudnn = True
         ...
         if remove_cudnn:
             backends.remove("cudnn")
+
+    if speculative_decode and "auto" in backends:
+        print("[INFO] auto backend may resolve to unsupported kernels for speculative decode. Skipping.")
+        backends.remove("auto")


📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and 833797d.

📒 Files selected for processing (2)
  • benchmarks/README.md
  • benchmarks/routines/attention.py

@bkryu bkryu self-assigned this Feb 25, 2026
@bkryu bkryu added ready benchmark Pertains to performance benchmarking labels Feb 25, 2026
@coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
benchmarks/routines/attention.py (2)

127-132: ⚠️ Potential issue | 🟡 Minor

Enforce --s_qo >= 1 during argument parsing.

Non-positive values can flow into tensor shapes and q_len_per_req, causing avoidable runtime failures later.

Suggested fix
 args = parser.parse_args(line)
+if args.s_qo < 1:
+    parser.error("--s_qo must be >= 1")
 
 # Normalize backend names (handle deprecated names)
 args.backends = normalize_backends(args.backends)

334-380: ⚠️ Potential issue | 🟠 Major

Speculative backend pruning can make decode refcheck silently no-op.

With speculative decode (Lines 334-380), fa2 is removed, but decode refcheck still only assigns a reference from fa2 (Line 658). If multiple remaining backends run (e.g., trtllm-gen + trtllm-native), outputs are collected but never compared.

Suggested fix (add fallback reference selection in decode path)
     # Perform reference check
     tested_backends = list(outputs.keys())
     tested_outputs = list(outputs.values())
+    if run_refcheck and not has_reference_output and len(tested_backends) > 1:
+        for candidate in ["trtllm-gen", "trtllm-native", "cudnn", "fa2_tc", "auto"]:
+            if candidate in tested_backends:
+                has_reference_output = True
+                reference_output = outputs[candidate]
+                if args.verbose >= 1:
+                    print(
+                        f"[INFO] FA2 not available for reference. Using {candidate} as reference backend for cross-comparison."
+                    )
+                break
     if len(tested_backends) > 1:
         if run_refcheck and has_reference_output:
             ...


📥 Commits

Reviewing files that changed from the base of the PR and between 833797d and 060f67c.

📒 Files selected for processing (1)
  • benchmarks/routines/attention.py

@nv-yunzheq (Collaborator) left a comment

LGTM

@bkryu bkryu merged commit 4e03158 into flashinfer-ai:main Feb 25, 2026
21 checks passed
@bkryu bkryu deleted the bench_specdec branch February 26, 2026 01:44
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…de (flashinfer-ai#2628)


Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>