
pick fa2 for BatchDecodeWithPagedKVCacheWrapper auto backend #2530

Merged
yzh119 merged 5 commits into main from fix/mingyangw/0_6_1_decode_auto_backend on Feb 17, 2026

Conversation

@saltyminty (Collaborator) commented Feb 9, 2026

📌 Description

BatchDecodeWithPagedKVCacheWrapper used to be hardcoded to select fa2, but after the refactor to use determine_attention_backend it now selects fa3 on Hopper, which is slower than fa2 for decode. This PR reverts to hardcoding fa2 for non-FP8 workloads.

Add auto backend support to flashinfer_benchmark.py for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.
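The dtype-based routing described above can be sketched as follows. This is a minimal illustration, not FlashInfer's actual code: dtypes are plain strings here, and `choose_backend_heuristic` stands in for `determine_attention_backend`.

```python
FP8_DTYPES = {"float8_e4m3fn", "float8_e5m2"}

def resolve_decode_backend(q_dtype, kv_dtype, choose_backend_heuristic):
    """Pin non-FP8 decode workloads to fa2; let FP8 paths go through the
    heuristic chooser (which may pick fa3 or trtllm-gen on Hopper)."""
    if q_dtype in FP8_DTYPES or kv_dtype in FP8_DTYPES:
        return choose_backend_heuristic(q_dtype, kv_dtype)
    # Non-FP8 decode: fa2 benchmarks faster than fa3 here, so pin it.
    return "fa2"
```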

🔍 Related Issues

#2400

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added an "auto" backend option that records and displays the actual backend as auto(resolved_backend).
    • Performance outputs now include a resolved_backend column showing which backend was used.
    • Auto routing differs for FP8 vs non-FP8 data paths when selecting the actual backend; supported wrappers limited to decode/prefill paged-KV variants.
  • Documentation

    • CLI/help text updated to document the --backends auto option and supported combinations.
  • Compatibility

    • Deprecated backend names normalized for consistent reporting.

@gemini-code-assist bot (Contributor)

Summary of Changes

Hello @saltyminty, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines the backend selection mechanism for attention operations, specifically ensuring that BatchDecodeWithPagedKVCacheWrapper leverages the optimal fa2 backend for non-FP8 scenarios to prevent performance degradation. Concurrently, it extends the benchmarking infrastructure to properly test and report on the auto backend's behavior for both decode and prefill operations, enhancing the clarity and utility of performance evaluations.

Highlights

  • Backend Selection Logic: The BatchDecodeWithPagedKVCacheWrapper now explicitly uses the fa2 backend for non-FP8 workloads when the auto backend is selected, addressing a performance regression where it was incorrectly defaulting to fa3 on Hopper architectures.
  • Benchmark Tooling Enhancement: The flashinfer_benchmark.py script has been updated to support the auto backend for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper, allowing for more flexible and accurate performance testing.
  • Benchmark Output and Validation: The benchmark results now include a resolved_backend field to indicate the actual backend chosen when auto is specified. Additionally, validation has been added to ensure the auto backend is only used with supported wrappers.


Changelog
  • benchmarks/README.md
    • Updated the description for the --backends argument to include auto and specify its supported wrappers.
  • benchmarks/routines/attention.py
    • Added 'auto' to the list of available kernel backends in argument parsing.
    • Implemented validation to restrict the 'auto' backend to BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.
    • Updated testBatchDecodeWithPagedKVCacheWrapper to include 'auto' in its supported backends list.
    • Introduced resolved_backends dictionary to track the actual backend chosen when 'auto' is used in testBatchDecodeWithPagedKVCacheWrapper.
    • Modified run_backend_wrapper in testBatchDecodeWithPagedKVCacheWrapper to include 'auto' in the list of backends that use flashinfer wrappers.
    • Adjusted performance metric printing in testBatchDecodeWithPagedKVCacheWrapper to display 'auto' with its resolved backend.
    • Added resolved_backend to the output CSV data in testBatchDecodeWithPagedKVCacheWrapper.
    • Updated testBatchPrefillWithPagedKVCacheWrapper to include 'auto' in its supported backends list.
    • Introduced resolved_backends dictionary to track the actual backend chosen when 'auto' is used in testBatchPrefillWithPagedKVCacheWrapper.
    • Modified run_backend_wrapper in testBatchPrefillWithPagedKVCacheWrapper to include 'auto' in the list of backends that use flashinfer wrappers.
    • Adjusted performance metric printing in testBatchPrefillWithPagedKVCacheWrapper to display 'auto' with its resolved backend.
    • Added resolved_backend to the output CSV data in testBatchPrefillWithPagedKVCacheWrapper.
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Added resolved_backend to the ATTENTION_COLUMNS_TO_REPORT list for benchmark output.
    • Updated _ATTENTION_SUPPORTED_BACKENDS to include 'auto' for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper across various CUDA architectures.
  • flashinfer/decode.py
    • Modified the plan method of BatchDecodeWithPagedKVCacheWrapper to hardcode the fa2 backend when auto is selected for non-FP8 query data types, otherwise it uses determine_attention_backend.

@coderabbitai bot (Contributor) commented Feb 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds an "auto" backend token across benchmark wrappers with runtime resolution and reporting, normalizes deprecated backend names, appends a resolved_backend field to perf outputs, and updates decode backend selection to route auto based on FP8 vs non‑FP8 q/kv data types.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Docs: benchmarks/README.md | Document the --backends auto option and note current wrapper support. |
| Benchmark wrappers & runtime: benchmarks/routines/attention.py | Add the auto backend choice, normalize deprecated backend names (e.g., trtllm-gen-native → trtllm-native, fa2_tc → fa2), track/display auto(resolved_backend), and record resolved_backend in results. |
| Benchmark utils / output schema: benchmarks/routines/flashinfer_benchmark_utils.py | Add resolved_backend to perf output columns; include auto in per-version backend mappings for the relevant routines. |
| Decode backend selection: flashinfer/decode.py | When backend="auto", resolve based on q/kv dtypes: FP8 → run determine_attention_backend (preserving FP8 flags; flag renamed use_fp16_qk_reductions), non-FP8 → route to fa2; return/record the resolved backend. |
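The auto(resolved_backend) reporting described above can be sketched like this. The helper names and the row layout are illustrative, not the benchmark's actual functions; only the display convention and the resolved_backend column come from the walkthrough.

```python
def backend_display_name(requested: str, resolved: str) -> str:
    """Explicit backends print as-is; 'auto' prints as auto(<resolved>)."""
    if requested == "auto":
        return f"auto({resolved})"
    return requested

def make_result_row(requested: str, resolved: str, latency_ms: float) -> dict:
    # resolved_backend becomes its own column in the perf output,
    # so CSV consumers can filter on the concrete backend that ran.
    return {
        "backend": backend_display_name(requested, resolved),
        "resolved_backend": resolved,
        "latency_ms": latency_ms,
    }
```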

Sequence Diagram

sequenceDiagram
    participant Bench as Benchmark Orchestration
    participant Resolver as Backend Resolver
    participant Decode as Decode Logic
    participant Wrapper as Wrapper Impl
    participant Output as Perf Recorder

    Bench->>Resolver: Run routine with backend="auto"
    Resolver->>Resolver: Check wrapper supports auto
    Resolver->>Decode: Query q/kv data types
    Decode-->>Resolver: Return resolved_backend (FP8 -> determine_attention_backend, else -> fa2)
    Resolver->>Wrapper: Execute using resolved_backend
    Wrapper-->>Resolver: Return metrics
    Resolver->>Output: Record metrics with backend="auto(resolved_backend)" and resolved_backend field
    Output-->>Bench: Present results

Estimated Code Review Effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

Possibly Related PRs

Suggested Reviewers

  • yzh119
  • cyx-6
  • Anerudhan
  • bkryu
  • nvmbreughe

Poem

🐰 I hop through code with a twitchy flip,
"auto" finds a path on every trip,
FP8 nudges one way, plain tensors go two,
Resolved and logged — the bench sings true,
I nibble bytes and leave a tiny clue.

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately captures the main functional change: reverting BatchDecodeWithPagedKVCacheWrapper to use fa2 in the auto backend selection path. |
| Merge Conflict Detection | ✅ Passed | No merge conflicts detected when merging into main. |
| Description check | ✅ Passed | The description explains the changes, links the related issue, and covers the main objectives: reverting to fa2 for non-FP8 workloads and adding auto backend support. |


@saltyminty (Collaborator, Author)

bot run

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request addresses a performance regression in BatchDecodeWithPagedKVCacheWrapper for non-FP8 workloads by hardcoding the fa2 backend when auto is selected. It also introduces auto backend support in the benchmark suite for both BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper. The changes are logical and well-aligned with the stated objectives. I've included a couple of minor suggestions to enhance code maintainability by reducing hardcoded values.

@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/decode.py`:
- Around line 1042-1053: The auto-backend logic incorrectly defaults to "fa2"
when only kv_data_type is FP8; update the branch in decode.py where
self._backend is set (the block using q_data_type, kv_data_type,
PosEncodingMode[pos_encoding_mode], and determine_attention_backend) so that it
calls determine_attention_backend whenever either q_data_type or kv_data_type is
an FP8 type (e.g., torch.float8_e4m3fn or torch.float8_e5m2) instead of only
when q_data_type is FP8, ensuring FP8 KV-only configurations select a valid
backend.
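The reviewer's point reduces to widening a one-line dtype condition. A sketch under the assumption that dtypes are strings and `heuristic` stands in for `determine_attention_backend` (not the real signatures):

```python
def is_fp8(dtype: str) -> bool:
    return dtype in ("float8_e4m3fn", "float8_e5m2")

# Flagged behavior: FP8-KV-only configs fall through to fa2,
# which may not support them.
def resolve_buggy(q_dtype, kv_dtype, heuristic):
    return heuristic(q_dtype, kv_dtype) if is_fp8(q_dtype) else "fa2"

# Suggested fix: any FP8 input defers to the heuristic chooser.
def resolve_fixed(q_dtype, kv_dtype, heuristic):
    if is_fp8(q_dtype) or is_fp8(kv_dtype):
        return heuristic(q_dtype, kv_dtype)
    return "fa2"
```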

@bkryu (Collaborator) left a comment


Thanks @saltyminty . Left a comment on the benchmark code.

@coderabbitai bot (Contributor) left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)

480-513: ⚠️ Potential issue | 🟡 Minor

"auto" backend bypasses head_grp_size == 5 constraint check for FA2.

The constraint check at lines 302–313 guards against head_grp_size == 5 for FA2, but only when "fa2" is explicitly in the backends list. When backend="auto", this check is skipped. Later, during .plan() (flashinfer/decode.py lines 1042–1056), "auto" resolves to "fa2" by default when non-FP8 dtypes are used, silently bypassing the constraint validation.

Either add a post-resolution check after the wrapper resolves the backend, or document this limitation with the --backends argument.

🧹 Nitpick comments (1)
benchmarks/routines/attention.py (1)

511-513: Accessing private _backend attribute is fragile.

backend_wrappers[backend]._backend relies on a private/internal attribute of the wrapper to introspect the resolved backend. This works for benchmarking purposes, but could break silently if the library renames or removes this attribute.

Consider checking if there's an official public API (e.g., a backend property) to query the resolved backend. If not, a defensive getattr(wrapper, '_backend', backend) fallback would prevent crashes.

Defensive fallback
-            resolved_backends[backend] = backend_wrappers[backend]._backend
+            resolved_backends[backend] = getattr(backend_wrappers[backend], '_backend', backend)

@coderabbitai bot (Contributor) left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)

102-114: ⚠️ Potential issue | 🟡 Minor

"auto" backend lacks runtime validation for unsupported routines.

The help text correctly states that auto is only supported for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper, but there's no runtime guard. If a user passes --backends auto with BatchPrefillWithRaggedKVCacheWrapper or BatchMLAPagedAttentionWrapper, it will crash with an unhelpful AttributeError (calling .detach() on None returned from the run_backend_wrapper fallback).

Consider adding an early validation in those routines to strip auto with a clear message, or validate in parse_attention_args against the routine.
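The early-validation idea could look like the following. The routine names come from the thread; the function name and warning text are illustrative, not the benchmark's actual code.

```python
AUTO_SUPPORTED_ROUTINES = {
    "BatchDecodeWithPagedKVCacheWrapper",
    "BatchPrefillWithPagedKVCacheWrapper",
}

def validate_backends(routine: str, backends: list) -> list:
    """Strip 'auto' with a clear message for unsupported routines,
    instead of crashing later with an opaque AttributeError."""
    if "auto" in backends and routine not in AUTO_SUPPORTED_ROUTINES:
        print(f"[WARN] 'auto' backend is not supported for {routine}; skipping it.")
        return [b for b in backends if b != "auto"]
    return backends
```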

🧹 Nitpick comments (2)
benchmarks/routines/attention.py (2)

1866-1890: Output schema inconsistency: resolved_backend field missing in Ragged and MLA routines.

The decode and paged-prefill routines now emit a resolved_backend field in their output records, but testBatchPrefillWithRaggedKVCacheWrapper (and testBatchMLAPagedAttentionWrapper) do not. If downstream tooling consumes these records with a uniform schema, this will cause KeyError or missing-column issues.

Consider adding cur_res["resolved_backend"] = backend in these routines for schema consistency, even if auto isn't supported there.

Proposed fix for Ragged (similar for MLA)
                 cur_res["backend"] = backend
+                cur_res["resolved_backend"] = backend
                 cur_res["page_size"] = 0  # No page size for ragged

471-504: _backend accesses an internal attribute and may break if FlashInfer internals change.

The code accesses backend_wrappers[backend]._backend (lines 502, 1075, 1105) to retrieve the resolved backend from the wrapper object. The underscore prefix indicates this is a private/internal attribute rather than part of the public API. While this is acceptable for benchmark tooling, the pattern is fragile and could break if FlashInfer updates its wrapper implementation.

For benchmark purposes, this is a minor concern since the code will fail loudly at runtime if the attribute disappears. A more robust approach would be to check if FlashInfer exposes a public accessor method, but the current implementation is pragmatic for non-production benchmarking code.

@saltyminty (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !304 has been created, and the CI pipeline #43647951 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[CANCELING] Pipeline #43647951: canceled

@saltyminty (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !304 has been updated with latest changes, and the CI pipeline #43649627 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #43649627: 4/20 passed

@saltyminty requested a review from bkryu on February 11, 2026 17:34
@saltyminty force-pushed the fix/mingyangw/0_6_1_decode_auto_backend branch from 52841e5 to 8f15557 on February 12, 2026 21:04
@saltyminty (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !304 has been updated with latest changes, and the CI pipeline #43914952 is currently running. I'll report back once the pipeline job completes.

@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/decode.py`:
- Around line 1042-1056: The code mutates self._backend from "auto" to a
concrete value, preventing re-evaluation on subsequent plan() calls; change to
preserve the user request (e.g., add a new attribute like
self._requested_backend) and always resolve the effective backend from that on
each plan() invocation: keep self._requested_backend (set from constructor/args)
and in plan() check if self._requested_backend == "auto" then call
determine_attention_backend(...) with PosEncodingMode[pos_encoding_mode].value,
q_data_type, kv_data_type, etc., otherwise use the explicitly requested backend,
and assign the resolved value to a temporary/effective variable (not overwrite
self._requested_backend) or to self._backend only as a cached effective_backend
that can be recomputed next call.
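The requested-vs-effective split suggested above can be sketched as a small class. This is an illustration of the pattern only, not flashinfer's wrapper: dtypes are strings, and `heuristic` stands in for `determine_attention_backend`.

```python
class WrapperSketch:
    """Keep the user's requested backend separate from the per-plan
    resolved one, so 'auto' is re-evaluated on every plan() call."""

    FP8 = {"float8_e4m3fn", "float8_e5m2"}

    def __init__(self, backend: str = "auto"):
        self._requested_backend = backend  # never mutated after init
        self._backend = backend            # effective backend, recomputed per plan()

    def plan(self, q_dtype: str, kv_dtype: str, heuristic) -> str:
        if self._requested_backend == "auto":
            if q_dtype in self.FP8 or kv_dtype in self.FP8:
                self._backend = heuristic(q_dtype, kv_dtype)
            else:
                self._backend = "fa2"
        else:
            self._backend = self._requested_backend
        return self._backend
```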
🧹 Nitpick comments (1)
benchmarks/README.md (1)

187-187: Consider adding "auto" to the backend legend for completeness.

The documentation correctly adds "auto" to the available backends list and specifies its current support scope. However, the backend legend at lines 447-460 does not include an entry for "auto", which could leave users unclear about what this backend does (e.g., automatically selects the optimal backend based on data types, as mentioned in the PR objectives).

📚 Suggested documentation enhancement

Add an entry for "auto" in the Backend Legend section (after line 460):

 - cuda: FlashInfer CUDA kernels
 - cute-dsl: FlashInfer CuTe-DSL kernels (Blackwell SM10.0+)
 - moe_a2a: MoE All-to-All communication (requires mpirun, Blackwell SM10.0+ with MNNVL)
+- auto: Automatically selects the optimal backend based on workload characteristics (e.g., FP8 vs non-FP8 data types). Currently supported for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.

@saltyminty (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !304 has been created, and the CI pipeline #43918433 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #43918433: 11/20 passed

@yzh119 (Collaborator) left a comment


LGTM

@yzh119 merged commit 55ba155 into main on Feb 17, 2026
25 of 35 checks passed
@yzh119 deleted the fix/mingyangw/0_6_1_decode_auto_backend branch on February 17, 2026 16:38