feat: support cross-attn cuda graph for whisper #21161
Conversation
Summary of Changes
This pull request integrates CUDA graph support for Whisper's cross-attention, which is crucial for optimizing the inference speed of encoder-decoder models. The changes involve a significant refactoring of Whisper's attention mechanism to leverage a more generalized attention module, alongside critical updates to the CUDA graph runner to manage encoder-specific sequence lengths and dynamic padding. Additionally, the configuration now enables CUDA graph for Whisper only when both the prefill and decode attention backends are flashinfer.
Code Review
The code changes involve modifications to the flashinfer backend, cuda graph runner, whisper model, and server arguments. In flashinfer_backend.py, the code was changed to assign paged_kernel_lens_cpu based on attention type. In cuda_graph_runner.py, a new parameter encoder_len_fill_value was added, and logic was added to fill encoder lengths. In whisper.py, the attention mechanism was refactored. In server_args.py, the logic for disabling CUDA graph for Whisper models was updated. The reviewer suggested simplifying the condition in flashinfer_backend.py to avoid redundancy, renaming encoder_len_fill_value to something more descriptive, extracting the logic for getting max source positions into a separate function, and creating a function to encapsulate the logic for disabling CUDA graph for Whisper models to improve maintainability.
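To make the padding behavior the summary describes concrete, here is a minimal, dependency-free sketch. The function name `pad_encoder_lens` and its signature are illustrative, not sglang's actual API: during CUDA graph capture, per-request encoder lengths must be padded out to the captured batch size with a constant fill value so that tensor shapes stay static across replays.

```python
def pad_encoder_lens(encoder_lens, graph_bs, encoder_len_fill_value):
    """Pad per-request encoder lengths up to the captured graph batch size.

    Hypothetical sketch: CUDA graph replay requires static shapes, so unused
    batch slots are filled with a constant value. For Whisper that fill value
    mirrors hf_config.max_source_positions.
    """
    if len(encoder_lens) > graph_bs:
        raise ValueError("real batch exceeds captured graph batch size")
    return encoder_lens + [encoder_len_fill_value] * (graph_bs - len(encoder_lens))

# Two real requests padded into a graph captured for batch size 4:
print(pad_encoder_lens([1500, 1500], 4, 1500))  # → [1500, 1500, 1500, 1500]
```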
```python
paged_kernel_lens_cpu = (
    encoder_lens.to(device="cpu", non_blocking=False)
    if encoder_lens is not None
    else None
)
```
The condition if encoder_lens is not None is checked twice, which is redundant. Consider simplifying this by assigning a default value to paged_kernel_lens_cpu outside the if-else block.
```python
paged_kernel_lens_cpu = None
if wrapper_id == 0:
    # Normal attention
    paged_kernel_lens = seq_lens
    kv_start_idx = encoder_lens
    paged_kernel_lens_cpu = seq_lens_cpu
else:
    # Cross attention
    paged_kernel_lens = encoder_lens
    kv_start_idx = torch.zeros_like(encoder_lens)
    seq_lens_sum = encoder_lens.sum().item()
    paged_kernel_lens_cpu = encoder_lens.to(device="cpu", non_blocking=False)
```

```python
    raw_num_token: int,
    bs: int,
    seq_len_fill_value: int,
    encoder_len_fill_value: int,
```
```python
if self.is_encoder_decoder:
    self.encoder_len_fill_value = int(
        getattr(
            self.model_runner.model_config.hf_config,
            "max_source_positions",
            1,
        )
    )
```
Consider extracting this logic into a separate function to improve readability and maintainability. The function could be named _get_max_source_positions or similar.
```python
def _get_max_source_positions(self):
    return int(
        getattr(
            self.model_runner.model_config.hf_config,
            "max_source_positions",
            1,
        )
    )

def __init__(self, model_runner: ModelRunner):
    ...
    self.encoder_len_fill_value = 0
    if self.is_encoder_decoder:
        self.encoder_len_fill_value = self._get_max_source_positions()
```

```python
elif (
    effective_prefill_backend != "flashinfer"
    or effective_decode_backend != "flashinfer"
):
    logger.warning(
        "Cuda graph is disabled for Whisper unless both prefill and decode attention backends are flashinfer"
    )
    self.disable_cuda_graph = True
```
The logic to disable CUDA graph for Whisper models is duplicated here. Consider creating a function to encapsulate this logic and reuse it in both places to improve maintainability.
```python
def _disable_cuda_graph_for_whisper(self, model_config):
    model_architectures = model_config.hf_config.architectures or []
    is_whisper = "WhisperForConditionalGeneration" in model_architectures
    effective_prefill_backend = (
        self.prefill_attention_backend or self.attention_backend
    )
    effective_decode_backend = (
        self.decode_attention_backend or self.attention_backend
    )
    if not is_whisper:
        logger.warning(
            "Cuda graph is disabled for non-Whisper encoder-decoder models"
        )
        self.disable_cuda_graph = True
    elif (
        effective_prefill_backend != "flashinfer"
        or effective_decode_backend != "flashinfer"
    ):
        logger.warning(
            "Cuda graph is disabled for Whisper unless both prefill and decode attention backends are flashinfer"
        )
        self.disable_cuda_graph = True

# Whisper currently only supports cuda graph on the native flashinfer
# cross-attention path.
if model_config.is_encoder_decoder:
    self._disable_cuda_graph_for_whisper(model_config)
```

```python
attn_output = torch.bmm(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, num_heads * head_dim)
q = q.view(-1, self.attn.tp_q_head_num, self.attn.head_dim)
```
replaced hand-written cross-attn with native backend (RadixAttention)
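For context, the removed path computed cross-attention by hand with an explicit mask over padded encoder positions. A toy, dependency-free sketch of that general pattern for a single query (purely illustrative; this is not the actual whisper.py code, which used batched `torch.bmm`):

```python
import math

def masked_cross_attn(q, k, v, enc_len):
    """Toy single-query cross-attention with an explicit padding mask.

    Scores the query against every encoder slot, masks out padded positions
    (index >= enc_len) with -inf before the softmax, then mixes the values.
    """
    d = len(q)
    scores = [
        sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d) for key in k
    ]
    scores = [s if i < enc_len else float("-inf") for i, s in enumerate(scores)]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    return [
        sum(w * val[j] for w, val in zip(weights, v)) for j in range(len(v[0]))
    ]
```

Because the mask shape depends on per-request encoder lengths, a hand-rolled loop like this resists CUDA graph capture; routing through the backend's paged attention kernels makes the shapes capture-friendly.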
```python
paged_kernel_lens = encoder_lens
kv_start_idx = torch.zeros_like(encoder_lens)
seq_lens_sum = encoder_lens.sum().item()
paged_kernel_lens_cpu = (
```
This should be the encoder length in an encoder-decoder architecture.
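The fix boils down to selecting the right CPU-side lengths per planning wrapper. A hypothetical reduction of that decision (the function name is illustrative, not the actual sglang code):

```python
def pick_paged_kernel_lens_cpu(wrapper_id, seq_lens_cpu, encoder_lens_cpu):
    """Choose CPU-side lengths for flashinfer decode plan().

    Wrapper 0 plans decoder self-attention over decoder sequence lengths;
    wrapper 1 plans cross-attention, whose KV cache is keyed by encoder
    positions, so it must use encoder-side lengths rather than reusing the
    decoder's seq_lens_cpu.
    """
    return seq_lens_cpu if wrapper_id == 0 else encoder_lens_cpu
```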
Motivation
Previously, the Whisper model used a custom `bmm + mask` implementation of cross-attention, which made it incompatible with CUDA graph.
Modifications
- Refactored Whisper's cross-attention to use the `RadixAttention` path instead of the custom `bmm + mask` implementation.
- `server_args`: explicitly allow CUDA graph for Whisper only when both prefill and decode use `flashinfer`.
- Fixed `flashinfer` decode cross-attention CUDA-graph planning to use encoder-side CPU lengths, instead of incorrectly reusing the decoder `seq_lens_cpu`.
Accuracy Tests
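The `server_args` gate described above can be summarized as a small predicate. A sketch assuming the backend fields seen in the diff (the standalone function form is illustrative; in the PR this lives as a method on the server args object):

```python
def cuda_graph_allowed(architectures, prefill_backend, decode_backend, attention_backend):
    """Return True only for Whisper with flashinfer on both phases.

    Mirrors the server_args logic: per-phase backends fall back to the
    global attention_backend when unset, and both effective backends must
    resolve to flashinfer for CUDA graph to stay enabled.
    """
    is_whisper = "WhisperForConditionalGeneration" in (architectures or [])
    effective_prefill = prefill_backend or attention_backend
    effective_decode = decode_backend or attention_backend
    return is_whisper and effective_prefill == "flashinfer" == effective_decode
```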
Benchmarking and Profiling
Checklist
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`