[V1][CUDA] Full cudagraph support for FlashInfer #21367
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional
 
 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -18,10 +18,11 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
-    get_kv_cache_layout, get_per_layer_parameters,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
+    PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters,
     infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -223,21 +224,49 @@ def __post_init__(self):
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
         self.device = device
+        self.vllm_config = vllm_config
+        self.cache_config = vllm_config.cache_config
+        self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
-        self._decode_wrapper = None  # Wrapper for decode
+        self._decode_wrapper = None  # Wrapper for decode (general shape)
+
+        self.compilation_config = vllm_config.compilation_config
+        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+                                     self.kv_cache_spec.block_size)
+        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+        max_num_pages = max_num_reqs * max_num_pages_per_req
+        self.enable_cuda_graph = self.compilation_config.full_cuda_graph
+        if self.enable_cuda_graph:
+            # For full cudagraph capture, one `decode_wrapper` for each batch
+            # size is needed for FlashInfer.
+            self._decode_wrappers_cudagraph: dict[
+                int, BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(
+                max_num_reqs, self.compilation_config.max_capture_size)
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
-        self.vllm_config = vllm_config
-        self.cache_config = vllm_config.cache_config
-        self.kv_cache_spec = kv_cache_spec
+        # Preparing persistent buffers
+        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.paged_kv_indices = torch.zeros(
+            max_num_pages,  # max num pages possible
+            dtype=torch.int32,
+            device=self.device)
+        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                  dtype=torch.int32,
+                                                  device=self.device)
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
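The persistent `paged_kv_*` buffers exist because a captured CUDA graph replays fixed kernel launches on fixed device addresses, so per-step data must be copied into pre-allocated, worst-case-sized tensors rather than freshly allocated ones. A minimal standalone PyTorch sketch of that pattern (illustrative only, not code from this PR):

```python
import torch

# Standalone illustration: a captured graph reuses the same pointers on
# every replay, so new inputs are copied into the original tensors.
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up once on a side stream before capture (recommended by PyTorch).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_in * 2)  # captured once

# At replay time only the buffer contents change, never the addresses.
static_in.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
graph.replay()
print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.], device='cuda:0')
```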
@@ -259,20 +288,49 @@ def _get_prefill_wrapper(self):
                 self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
-    def _get_decode_wrapper(self):
-        if self._decode_wrapper is None:
+    def _get_decode_wrapper(self,
+                            batch_size: int,
+                            use_cudagraph: bool = False):
+        if use_cudagraph:
+            decode_wrapper = self._decode_wrappers_cudagraph.get(
+                batch_size, None)
+        else:
+            decode_wrapper = self._decode_wrapper
+
+        if decode_wrapper is None:
             num_qo_heads = (
                 self.vllm_config.model_config.get_num_attention_heads(
                     self.vllm_config.parallel_config))
             num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
                 self.vllm_config.parallel_config)
             use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                 num_qo_heads // num_kv_heads > 4)
-            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+
+            if use_cudagraph:
+                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices
+                paged_kv_last_page_len = self.paged_kv_last_page_len[:
+                                                                     batch_size]
+            else:
+                paged_kv_indptr = None
+                paged_kv_indices = None
+                paged_kv_last_page_len = None
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 get_kv_cache_layout(),
+                use_cuda_graph=use_cudagraph,
+                paged_kv_indptr_buffer=paged_kv_indptr,
+                paged_kv_indices_buffer=paged_kv_indices,
+                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                 use_tensor_cores=use_tensor_cores)
-        return self._decode_wrapper
+
+            # save the decode wrapper
+            if use_cudagraph:
+                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
+            else:
+                self._decode_wrapper = decode_wrapper
+
+        return decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
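For reference, a hedged sketch of constructing such a CUDA-graph-compatible decode wrapper outside vLLM. The keyword arguments mirror the call in this diff, while the workspace size, the "NHD" layout string, and the shapes below are illustrative assumptions:

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

batch_size, max_num_pages = 8, 1024  # assumed example sizes
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# These buffers must outlive the wrapper and keep their addresses: the
# captured graph reads paging metadata directly from them on every replay.
paged_kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
paged_kv_indices = torch.zeros(max_num_pages, dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")

decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
    workspace,
    "NHD",                      # kv-cache layout (assumed here)
    use_cuda_graph=True,
    paged_kv_indptr_buffer=paged_kv_indptr,
    paged_kv_indices_buffer=paged_kv_indices,
    paged_kv_last_page_len_buffer=paged_kv_last_page_len,
    use_tensor_cores=True,      # e.g. when num_qo_heads // num_kv_heads > 4
)
```

This is why the builder keys the cudagraph wrappers by batch size: each captured graph is pinned to a specific slice of the persistent buffers.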
@@ -350,16 +408,34 @@ def _plan(self, num_prefills: int, num_decodes: int,
             )
 
         if num_decodes > 0:
-            attn_metadata.decode_wrapper = self._get_decode_wrapper()
+            pure_decode = num_prefills == 0
+            # possible required padding for cudagraph replay
+            use_cudagraph = (self.enable_cuda_graph and pure_decode and
+                             num_decodes <= self._decode_cudagraph_max_bs)
+            if use_cudagraph:
+                num_input_tokens_decode = (
+                    self.vllm_config.pad_for_cudagraph(num_decodes))
+            else:
+                num_input_tokens_decode = num_decodes
+
+            attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                num_input_tokens_decode, use_cudagraph)
             if not FlashInferBackend.use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
                     self.cache_config.cache_dtype,
                     attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
                     attn_metadata.head_dim):
+                # TODO: Override flashinfer's plan function to avoid some
+                # host-to-device copy overhead.
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
-                    attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    # NOTE: Use the persistent buffer with padding length,
+                    # instead of the same address but chunked length buffers
+                    # in the attn_metadata. This is to be compatible with
+                    # FlashInfer's decode_wrapper when using cudagraph.
+                    self.paged_kv_indptr[:num_input_tokens_decode + 1],
+                    self.paged_kv_indices if use_cudagraph else \
+                        attn_metadata.paged_kv_indices,
+                    self.paged_kv_last_page_len[:num_input_tokens_decode],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
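The padding step rounds a pure-decode batch up to the nearest captured graph size so that a pre-captured graph can be replayed. A simplified stand-in for what a `pad_for_cudagraph`-style helper does (the capture sizes and helper name below are hypothetical, not vLLM's actual implementation):

```python
import bisect

# Assumed example capture sizes; vLLM derives them from the compilation config.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]

def pad_to_capture_size(num_decodes: int) -> int:
    # Callers (like _plan above) only take this path when num_decodes is
    # no larger than the biggest captured size.
    idx = bisect.bisect_left(CAPTURE_SIZES, num_decodes)
    return CAPTURE_SIZES[idx]

assert pad_to_capture_size(3) == 4   # 3 real decodes replay the size-4 graph
assert pad_to_capture_size(8) == 8   # exact match needs no padding
```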
@@ -378,6 +454,7 @@ def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> FlashInferMetadata:
+        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata)
@@ -421,17 +498,31 @@ def build(self,
                               device=block_table_tensor.device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
         paged_kv_indices = block_table_tensor[mask]
+        num_actual_pages = paged_kv_indices.size(0)
+        self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
+                                                       non_blocking=True)
+        self.paged_kv_indices[num_actual_pages:].fill_(-1)
 
         paged_kv_indptr = torch.cat([
             torch.zeros(1,
                         dtype=block_table_bounds.dtype,
                         device=block_table_bounds.device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])
+        self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
+                                                  non_blocking=True)
+        # make sure self.paged_kv_indptr is not decreasing
+        self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
 
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
+        self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len,
+                                                     non_blocking=True)
+        # Fill the remaining paged_kv_last_page_len with 1. This is because
+        # flashinfer treats 0 as a full page instead of empty.
+        self.paged_kv_last_page_len[num_reqs:].fill_(1)
 
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
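A toy CPU-only illustration of this tail-padding scheme (all values made up): real entries are copied to the front of the persistent buffers, and the padded tail is filled with values that keep the planner happy, i.e. -1 page ids, a non-decreasing indptr, and last_page_len = 1 because FlashInfer treats 0 as a full page:

```python
import torch

max_reqs, max_pages, num_reqs = 4, 8, 2

paged_kv_indptr = torch.zeros(max_reqs + 1, dtype=torch.int32)
paged_kv_indices = torch.zeros(max_pages, dtype=torch.int32)
paged_kv_last_page_len = torch.zeros(max_reqs, dtype=torch.int32)

# Two real requests owning 3 and 2 pages respectively.
new_indptr = torch.tensor([0, 3, 5], dtype=torch.int32)
new_indices = torch.tensor([10, 11, 12, 20, 21], dtype=torch.int32)
new_last_page_len = torch.tensor([7, 16], dtype=torch.int32)

paged_kv_indices[:new_indices.numel()].copy_(new_indices)
paged_kv_indices[new_indices.numel():].fill_(-1)      # unused page slots

paged_kv_indptr[:num_reqs + 1].copy_(new_indptr)
paged_kv_indptr[num_reqs + 1:].fill_(new_indptr[-1])  # padded requests own 0 pages

paged_kv_last_page_len[:num_reqs].copy_(new_last_page_len)
paged_kv_last_page_len[num_reqs:].fill_(1)            # 0 would mean "full page"
```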
@@ -441,9 +532,9 @@ def build(self,
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs],
+            paged_kv_indices=self.paged_kv_indices[:num_actual_pages],
+            paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs],
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -471,6 +562,26 @@ def build(self,
 
         return attn_metadata
 
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with FlashInfer.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "FlashInfer only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
Contributor:
Nit: You shouldn't need to set this. You can add it to your decode-only assert on the previous line.

Contributor (Author):
I think this is a common practice now (also see this part for FlashMLA), since the attn_metadata passed from the dummy run currently has max_query_len = num_tokens.
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
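A small restatement of the constraint behind the assert in build_for_cudagraph_capture, with hypothetical config values: decode-only capture feeds exactly one token per request, so no capture size may exceed max_num_seqs.

```python
# Hypothetical config values for illustration only.
cudagraph_capture_sizes = [1, 2, 4, 8, 16]
max_num_seqs = 16

# A captured batch larger than max_num_seqs could never occur at replay time.
assert all(size <= max_num_seqs for size in cudagraph_capture_sizes), \
    "every full-cudagraph capture size must be <= max_num_seqs"
```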
Contributor (on num_input_tokens_decode in _plan):
Nit: do we need a new name here? Can we just use num_decodes, so we don't need the else?

Contributor (Author):
I keep these as two variables for filling the padding region. See the code below.