benchmark: Enable speculative decode microbenchmarking for paged decode #2628

bkryu merged 2 commits into flashinfer-ai:main from
Conversation
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the benchmarking capabilities by introducing support for speculative decode in the paged KV cache decode attention wrapper. This allows performance evaluation of multi-token query scenarios, which is crucial for advanced decoding strategies. The changes ensure that the benchmarking infrastructure correctly configures and passes the necessary parameters, such as causal masks and query lengths, to the underlying attention kernels, while gracefully handling backend limitations.
Activity
📝 Walkthrough

Adds speculative decoding support to the benchmarks: introduces generate_speculative_causal_mask(), propagates packed causal masks and per-request q_len (s_qo) through backend paths, alters query shaping and metrics, and documents enabling speculative decode via `--s_qo`.
Sequence Diagram(s)sequenceDiagram
participant Test as Test Function
participant MaskGen as Mask Generator
participant Backend as Backend Runner
participant Metrics as Metrics Calculator
Test->>Test: Determine speculative_decode (s_qo > 1)
alt speculative_decode == true
Test->>MaskGen: generate_speculative_causal_mask(batch_size, s_qo, device)
MaskGen-->>Test: packed uint16 mask tensor
loop per backend
Test->>Backend: check support (fa2/fa2_tc/cudnn/auto)
alt supported
Test->>Backend: run with q_len_per_req=s_qo, mask (if native)
Backend-->>Test: decoded outputs
else unsupported
Test-->>Test: log and skip backend
end
end
Test->>Metrics: compute tflops/TB/s using s_qo-aware lengths
Metrics-->>Test: performance metrics
else
Test->>Backend: run standard decode (no mask)
Backend-->>Test: decoded outputs
end
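The mask generator in the diagram returns one packed row per speculative query position. The real helper lives in benchmarks/routines/attention.py; as a rough illustration of the packing idea only (plain Python, function name and return shape are assumptions, not the actual implementation): row i of an s_qo × s_qo causal mask allows key positions 0..i, so its bits pack to (1 << (i + 1)) - 1.

```python
def packed_causal_rows(batch_size, s_qo):
    # Each query row i may attend to speculative key positions 0..i,
    # so the packed bitmask for row i is (1 << (i + 1)) - 1.
    # A uint16 per row (as the walkthrough mentions) would cap s_qo at 16.
    row_bits = [(1 << (i + 1)) - 1 for i in range(s_qo)]
    return [list(row_bits) for _ in range(batch_size)]

# For s_qo = 4 the per-row masks are 0b1, 0b11, 0b111, 0b1111.
```

The same rows are replicated for every request in the batch, since each speculative chunk uses an identical intra-chunk causal pattern.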
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 checks passed.
Code Review
This pull request enables speculative decode microbenchmarking for the paged decode path in the attention benchmark routine. It introduces a utility to generate packed causal masks for speculative chunks and updates the benchmark logic to handle multi-token queries, including backend-specific constraints and performance metric calculations. The changes are well-integrated into the existing benchmarking framework, though there are some concerns regarding the data type and layout of the generated speculative mask and the consistency of masking across different backend wrappers.
```diff
 return backend_wrappers[backend].run(
-    q, kv_cache, k_scale=k_scale, v_scale=v_scale
+    q, kv_cache, k_scale=k_scale, v_scale=v_scale, q_len_per_req=s_qo
 )
```
The BatchDecodeWithPagedKVCacheWrapper.run method (as seen in flashinfer/decode.py) does not currently accept a mask parameter. Consequently, when speculative_decode is enabled, the auto and trtllm-gen backends (which use this wrapper) are executing unmasked attention. This is inconsistent with the trtllm-native path (line 626) which correctly applies the causal mask. This inconsistency will lead to misleading performance results and incorrect outputs if refcheck were enabled for these paths.
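To make the inconsistency concrete: with a causal mask, query row i of a speculative chunk should attend to the kv_len cached tokens plus only the first i + 1 speculative positions, whereas an unmasked run lets every row see all kv_len + s_qo positions. A small sketch (plain Python, function name is hypothetical):

```python
def allowed_key_counts(kv_len, s_qo, causal=True):
    # Number of key positions each of the s_qo query rows may attend to.
    if causal:
        return [kv_len + i + 1 for i in range(s_qo)]
    return [kv_len + s_qo] * s_qo

# With 128 cached tokens and s_qo = 4, only the last query row
# should legitimately see all 132 key positions.
```

Rows that attend to future speculative tokens produce different outputs than the masked path, which is why an unmasked backend would fail a refcheck against trtllm-native.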
Fair point. Will disallow backend='auto' for speculative decoding.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)
320-377: ⚠️ Potential issue | 🔴 Critical

Guard the `auto` backend from speculative decode.

The code explicitly removes `fa2`, `fa2_tc`, and `cudnn` when `speculative_decode=True` (when `s_qo > 1`), but `auto` is never removed. During `plan()`, the `auto` backend resolves to `fa2` or `fa3` based on device architecture and dtype (flashinfer/decode.py:393-401), without considering the speculative decode constraint. When `run()` is later called with `q_len_per_req=s_qo > 1`, the resolved backend may not support this configuration, leading to incorrect behavior.

Add the suggested guard to filter out `auto` when `speculative_decode=True`:

Suggested fix

```diff
     if "cudnn" in backends:
         remove_cudnn = False
         if speculative_decode:
             print("[INFO] cuDNN backend does not support speculative decode. Skipping.")
             remove_cudnn = True
         ...
         if remove_cudnn:
             backends.remove("cudnn")
+
+    if speculative_decode and "auto" in backends:
+        print("[INFO] auto backend may resolve to unsupported kernels for speculative decode. Skipping.")
+        backends.remove("auto")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 320 - 377, The auto backend is not being filtered when speculative_decode is true, so add a guard that removes "auto" from backends when speculative_decode (s_qo > 1) to prevent plan() from resolving to unsupported backends later; locate the backend filtering block that uses speculative_decode and backends (same area handling "fa2", "fa2_tc", "cudnn") and add a conditional: if speculative_decode and "auto" in backends then print an informative message and backends.remove("auto").
⚠️ Outside diff range comments (2)
benchmarks/routines/attention.py (2)
127-132: ⚠️ Potential issue | 🟡 Minor

Enforce `--s_qo >= 1` during argument parsing.

Non-positive values can flow into tensor shapes and `q_len_per_req`, causing avoidable runtime failures later.

Suggested fix

```diff
 args = parser.parse_args(line)
+if args.s_qo < 1:
+    parser.error("--s_qo must be >= 1")
 # Normalize backend names (handle deprecated names)
 args.backends = normalize_backends(args.backends)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 127 - 132, The --s_qo argument may be parsed as non-positive and later used for tensor shapes and computing q_len_per_req; enforce s_qo >= 1 during parsing by validating the parsed value and rejecting (parser.error) or converting invalid values. Locate the add_argument call for "--s_qo" in benchmarks/routines/attention.py and add a validation step (either via a small positive_int converter used as type or an explicit check after parse_args referencing args.s_qo) that raises an argparse error if s_qo < 1, ensuring downstream uses (q_len_per_req and any tensor-shape calculations) only see values >= 1.
334-380: ⚠️ Potential issue | 🟠 Major

Speculative backend pruning can make decode refcheck silently no-op.

With speculative decode (lines 334-380), `fa2` is removed, but decode refcheck still only assigns a reference from `fa2` (line 658). If multiple remaining backends run (e.g., `trtllm-gen` + `trtllm-native`), outputs are collected but never compared.

Suggested fix (add fallback reference selection in decode path)

```diff
 # Perform reference check
 tested_backends = list(outputs.keys())
 tested_outputs = list(outputs.values())
+if run_refcheck and not has_reference_output and len(tested_backends) > 1:
+    for candidate in ["trtllm-gen", "trtllm-native", "cudnn", "fa2_tc", "auto"]:
+        if candidate in tested_backends:
+            has_reference_output = True
+            reference_output = outputs[candidate]
+            if args.verbose >= 1:
+                print(
+                    f"[INFO] FA2 not available for reference. Using {candidate} as reference backend for cross-comparison."
+                )
+            break
 if len(tested_backends) > 1:
     if run_refcheck and has_reference_output:
         ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 334 - 380, Speculative pruning can remove "fa2" but later decode refcheck still assumes outputs["fa2"] (around the decode ref assignment at line ~658), causing comparisons to be skipped; update the decode reference selection to fall back when "fa2" is absent by checking if "fa2" in outputs and otherwise selecting a valid reference from the remaining backends (e.g., outputs[next(iter(outputs))] or by a small priority list), then use that chosen reference for the refcheck so outputs from multiple backends get compared even when "fa2" was removed; ensure this logic lives with the existing decode/refcheck code and uses the existing backends/outputs variables and speculative_decode flag.
benchmark: Enable speculative decode microbenchmarking for paged decode (flashinfer-ai#2628)
📌 Description
No change to library code
- Add speculative decode support to the `BatchDecodeWithPagedKVCacheWrapper` benchmark path by handling `--s_qo > 1` (multi-token query shape, packed causal mask generation, and `q_len_per_req` forwarding).
- Update `benchmarks/README.md` with speculative decode docs.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features

- Enabled speculative decoding in benchmark tests with conditional execution on supported backends and per-request sizing.

Documentation

- Clarified speculative decoding configuration in benchmark docs and updated CSV help text explaining s_qo behavior.

Tests

- Updated benchmark routines and metrics to report speculative decoding behavior and adjusted test flows to skip unsupported backends.