
feat: support cross-attn cuda graph for whisper #21161

Closed
mickqian wants to merge 4 commits into main from whisper-cuda-graph

Conversation


mickqian (Collaborator) commented Mar 23, 2026

Motivation

Previously, the Whisper model used a custom bmm + mask implementation of cross-attention, which made it incompatible with CUDA graph.

Modifications

  • Reworked cross-attention to use native RadixAttention path instead of the custom bmm + mask implementation.
  • Removed the Python-side Whisper encoder cache, switched encoder outputs to the native encoder-decoder cache flow managed by scheduler/attention backends.
  • Fixed CUDA graph replay padding for encoder-decoder batches by initializing replay buffers with valid encoder length metadata.
  • Narrowed CUDA graph enablement in server_args: explicitly allow Whisper only when both prefill and decode use flashinfer.
  • Fixed flashinfer decode cross-attention CUDA-graph planning to use encoder-side CPU lengths, instead of incorrectly reusing decoder seq_lens_cpu
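The replay-padding fix in the list above can be sketched as follows. This is an illustrative sketch, not the actual sglang code: the helper name and signature are invented, and the default fill value of 1500 simply reflects Whisper's max_source_positions. The key idea is that when a CUDA graph captured at a larger batch size is replayed with a smaller live batch, the padded encoder-length slots must hold valid encoder lengths rather than zeros, or the cross-attention plan baked into the graph sees nonsense metadata.

```python
# Illustrative sketch (hypothetical helper, not the sglang API): pad
# per-request length metadata up to the batch size the CUDA graph was
# captured with. Decoder seq_lens can be padded with 0, but encoder_lens
# must be padded with a valid encoder length (e.g. Whisper's
# max_source_positions = 1500) so the replayed cross-attention kernels
# receive well-formed metadata.

def pad_replay_metadata(seq_lens, encoder_lens, graph_bs,
                        seq_len_fill_value=0, encoder_len_fill_value=1500):
    """Pad per-request length metadata up to the captured graph batch size."""
    bs = len(seq_lens)
    assert bs <= graph_bs, "live batch larger than captured graph"
    pad = graph_bs - bs
    padded_seq_lens = seq_lens + [seq_len_fill_value] * pad
    padded_encoder_lens = encoder_lens + [encoder_len_fill_value] * pad
    return padded_seq_lens, padded_encoder_lens

# Two live requests replayed through a graph captured at batch size 4.
seq_lens, enc_lens = pad_replay_metadata([7, 12], [1500, 1500], 4)
```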

Accuracy Tests

# launch server
$ sglang serve \
    --model-path openai/whisper-tiny \
    --attention-backend flashinfer \
    --host 0.0.0.0 \
    --port 31018 \
    --cuda-graph-max-bs 2

# benchmark
$ python benchmark/asr/bench_sglang.py \
    --base-url http://127.0.0.1:31018 \
    --model openai/whisper-tiny \
    --api-type transcription \
    --language en \
    --concurrency 1 \
    -n 100
  • no-cg: WER 176.8765, avg latency 0.0980s, throughput 8.96 req/s
  • cuda-graph: WER 176.8765, avg latency 0.0503s, throughput 15.60 req/s


Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates CUDA graph support for Whisper's cross-attention, which is crucial for optimizing the inference speed of encoder-decoder models. The changes involve a significant refactoring of Whisper's attention mechanism to leverage a more generalized attention module, alongside critical updates to the CUDA graph runner to manage encoder-specific sequence lengths and dynamic padding. Additionally, the system's configuration now intelligently enables CUDA graph for Whisper only when the flashinfer attention backend is consistently used for both prefill and decode operations.

Highlights

  • CUDA Graph Support for Whisper Cross-Attention: Enabled CUDA graph support specifically for Whisper's cross-attention mechanism, aiming to enhance inference performance for encoder-decoder models.
  • Whisper Cross-Attention Refactoring: Refactored the WhisperAttention module to delegate cross-attention computation to the self.attn module, simplifying the attention logic and removing previous manual masking and matrix multiplication.
  • CUDA Graph Runner Enhancements: Introduced new logic within the CUDA graph runner to correctly handle encoder_lens and implement proper padding for req_pool_indices, input_ids, and positions when batch sizes or token counts vary during graph replay.
  • Conditional CUDA Graph Activation: Updated server arguments to conditionally enable CUDA graph for Whisper models, requiring the flashinfer backend for both prefill and decode attention to activate this feature.
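The conditional activation described in the last highlight reduces to a backend check. The sketch below is a standalone approximation of that gating logic; the function name and parameters are assumptions mirroring the server_args fields, not the actual sglang API.

```python
# Illustrative sketch of the gating described above (hypothetical names,
# not the sglang server_args API): CUDA graph stays enabled only for
# Whisper, and only when the effective prefill and decode attention
# backends are both flashinfer.

def should_disable_cuda_graph(architectures, attention_backend,
                              prefill_backend=None, decode_backend=None):
    """Return True if CUDA graph must stay disabled for this encoder-decoder model."""
    if "WhisperForConditionalGeneration" not in (architectures or []):
        return True  # non-Whisper encoder-decoder: no CUDA graph support
    # Per-phase backends fall back to the global attention backend.
    effective_prefill = prefill_backend or attention_backend
    effective_decode = decode_backend or attention_backend
    return effective_prefill != "flashinfer" or effective_decode != "flashinfer"

# Whisper with flashinfer for both phases: CUDA graph stays enabled.
ok = not should_disable_cuda_graph(["WhisperForConditionalGeneration"], "flashinfer")
```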


@gemini-code-assist bot left a comment


Code Review

The code changes involve modifications to the flashinfer backend, cuda graph runner, whisper model, and server arguments. In flashinfer_backend.py, the code was changed to assign paged_kernel_lens_cpu based on attention type. In cuda_graph_runner.py, a new parameter encoder_len_fill_value was added, and logic was added to fill encoder lengths. In whisper.py, the attention mechanism was refactored. In server_args.py, the logic for disabling CUDA graph for Whisper models was updated. The reviewer suggested simplifying the condition in flashinfer_backend.py to avoid redundancy, renaming encoder_len_fill_value to something more descriptive, extracting the logic for getting max source positions into a separate function, and creating a function to encapsulate the logic for disabling CUDA graph for Whisper models to improve maintainability.

Comment on lines +1062 to +1066
paged_kernel_lens_cpu = (
    encoder_lens.to(device="cpu", non_blocking=False)
    if encoder_lens is not None
    else None
)

Severity: medium

The condition if encoder_lens is not None is checked twice, which is redundant. Consider simplifying this by assigning a default value to paged_kernel_lens_cpu outside the if-else block.

paged_kernel_lens_cpu = None
if wrapper_id == 0:
    # Normal attention
    paged_kernel_lens = seq_lens
    kv_start_idx = encoder_lens
    paged_kernel_lens_cpu = seq_lens_cpu
else:
    # Cross attention
    paged_kernel_lens = encoder_lens
    kv_start_idx = torch.zeros_like(encoder_lens)
    seq_lens_sum = encoder_lens.sum().item()
    paged_kernel_lens_cpu = encoder_lens.to(device="cpu", non_blocking=False)

raw_num_token: int,
bs: int,
seq_len_fill_value: int,
encoder_len_fill_value: int,

Severity: medium

A new parameter encoder_len_fill_value is added, but it's not clear from the name what it's used for. Consider renaming it to something more descriptive, such as padded_encoder_len.

Suggested change
encoder_len_fill_value: int,
padded_encoder_len: int,

Comment on lines +622 to +629
if self.is_encoder_decoder:
    self.encoder_len_fill_value = int(
        getattr(
            self.model_runner.model_config.hf_config,
            "max_source_positions",
            1,
        )
    )

Severity: medium

Consider extracting this logic into a separate function to improve readability and maintainability. The function could be named _get_max_source_positions or similar.

    def _get_max_source_positions(self):
        return int(
            getattr(
                self.model_runner.model_config.hf_config,
                "max_source_positions",
                1,
            )
        )

    def __init__(self, model_runner: ModelRunner):
        ...
        self.encoder_len_fill_value = 0
        if self.is_encoder_decoder:
            self.encoder_len_fill_value = self._get_max_source_positions()

Comment on lines +2275 to +2282
elif (
    effective_prefill_backend != "flashinfer"
    or effective_decode_backend != "flashinfer"
):
    logger.warning(
        "Cuda graph is disabled for Whisper unless both prefill and decode attention backends are flashinfer"
    )
    self.disable_cuda_graph = True

Severity: medium

The logic to disable CUDA graph for Whisper models is duplicated here. Consider creating a function to encapsulate this logic and reuse it in both places to improve maintainability.

def _disable_cuda_graph_for_whisper(self, model_config):
    model_architectures = model_config.hf_config.architectures or []
    is_whisper = "WhisperForConditionalGeneration" in model_architectures
    effective_prefill_backend = (
        self.prefill_attention_backend or self.attention_backend
    )
    effective_decode_backend = (
        self.decode_attention_backend or self.attention_backend
    )
    if not is_whisper:
        logger.warning(
            "Cuda graph is disabled for non-Whisper encoder-decoder models"
        )
        self.disable_cuda_graph = True
    elif (
        effective_prefill_backend != "flashinfer"
        or effective_decode_backend != "flashinfer"
    ):
        logger.warning(
            "Cuda graph is disabled for Whisper unless both prefill and decode attention backends are flashinfer"
        )
        self.disable_cuda_graph = True

# Whisper currently only supports cuda graph on the native flashinfer
# cross-attention path.
if model_config.is_encoder_decoder:
    self._disable_cuda_graph_for_whisper(model_config)

attn_output = torch.bmm(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, num_heads * head_dim)
q = q.view(-1, self.attn.tp_q_head_num, self.attn.head_dim)
mickqian (Collaborator, Author) commented Mar 23, 2026

replaced hand-written cross-attn with native backend (RadixAttention)

paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
paged_kernel_lens_cpu = (
    encoder_lens.to(device="cpu", non_blocking=False)
    if encoder_lens is not None
    else None
)
mickqian (Collaborator, Author)

This should be the encoder length in an encoder-decoder architecture.
