# [Core] Refactor _prepare_model_input_tensors - take 2 #6164
**`vllm/attention/__init__.py`**

```diff
@@ -1,12 +1,14 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadata)
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder)
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend

 __all__ = [
+    "Attention",
     "AttentionBackend",
     "AttentionMetadata",
-    "Attention",
+    "AttentionMetadataBuilder",
     "get_attn_backend",
 ]
```
**`vllm/attention/backends/flash_attn.py`**

```diff
@@ -1,13 +1,24 @@
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionType)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
+from vllm.sequence import SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad
+
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import (GPUModelRunnerBase,
+                                          ModelInputForGPUBuilder)


 class FlashAttentionBackend(AttentionBackend):
```
```diff
@@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return FlashAttentionMetadata

+    @staticmethod
+    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
+        return FlashAttentionMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
```
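The new `get_builder_cls()` hook mirrors the existing `get_impl_cls()`/`get_metadata_cls()` factories: the model runner asks the backend for its builder class, instantiates one per batch, feeds it sequence groups, and finally materializes the metadata. The sketch below is illustrative rather than the PR's actual call sites (which live in the model runner); `input_builder`, `runner`, and the loop variables are hypothetical stand-ins.

```python
# Hypothetical driver loop showing how a FlashAttentionMetadataBuilder is
# meant to be used; not code from this PR.
builder_cls = FlashAttentionBackend.get_builder_cls()
builder = builder_cls(input_builder)  # input_builder: ModelInputForGPUBuilder

# One call per scheduled sequence group: the builder accumulates context
# lengths, block tables, and slot mappings in plain Python lists.
for seq_group_metadata in seq_group_metadata_list:
    builder.add_seq_group(seq_group_metadata, token_lens, seq_lens,
                          curr_seq_lens, query_lens, context_lens,
                          curr_sliding_window_blocks, prefix_cache_hit,
                          chunked_prefill_enabled)

# build() turns the accumulated lists into on-device tensors; a
# cuda_graph_pad_size of -1 means no captured CUDA graph is used.
attn_metadata = builder.build(runner, seq_lens, query_lens,
                              cuda_graph_pad_size=-1,
                              batch_size=len(seq_lens))
```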
```diff
@@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
         return self._cached_decode_metadata


+class FlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlashAttentionMetadata]):
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.slot_mapping: List[int] = []
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.block_tables: List[List[int]] = []
+        self.curr_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+        self.use_v2_block_manager = (
+            input_builder.scheduler_config.use_v2_block_manager)
+
+    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
+                      token_lens: List[int], seq_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
+                      context_lens: List[int],
+                      curr_sliding_window_blocks: List[int],
+                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        2. block table.
+        3. slot mapping.
+        """
+        is_prompt = seq_group_metadata.is_prompt
+        block_tables = seq_group_metadata.block_tables
+
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
+                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
+                 curr_seq_lens, query_lens, context_lens,
+                 curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.curr_seq_lens.append(curr_seq_len)
+
+            # Compute block table.
+            # TODO(sang): Combine chunked prefill and prefix caching by
+            # only allowing multiple of block_size chunk size.
+            # NOTE: This only works for oooooooxxx style attention.
+            block_table = []
+            if prefix_cache_hit:
+                # NOTE(woosuk): For flash-attn, the block table should
+                # include the entries for the incoming prefill tokens.
+                block_table = block_tables[seq_id]
+            elif ((chunked_prefill_enabled or not is_prompt)
+                  and block_tables is not None):
+                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+            self.block_tables.append(block_table)
+
+            # Compute slot mapping.
+            is_profile_run = is_block_tables_empty(block_tables)
+            start_idx = compute_slot_mapping_start_idx(
```
**Collaborator:** Why is `compute_slot_mapping_start_idx` kept as a separate call from `compute_slot_mapping`?

**Author (Collaborator):** Yeah, the only reason is that if we combined these two, the argument list would be uglier:

```python
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len,
                     context_len, self.block_size,
                     seq_group_metadata.block_tables, is_prompt, query_len,
                     context_len, self.sliding_window,
                     self.use_v2_block_manager)
```

I don't have a preference though, so if you prefer to combine them I can do that. Please let me know.

**Collaborator:** I prefer to combine them if you don't mind! Agreed there are too many args, but I think it is better than having two calls that depend on each other (IMO it is harder to use). But it is a nit!

**Author (Collaborator):** Sure. I'll merge this one and do it in the next PR.
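Under the follow-up agreed above, the merged helper could plausibly look like the sketch below. The name `compute_slot_mapping_merged` is hypothetical; the two inner calls are exactly the ones that appear in this diff.

```python
from vllm.attention.backends.utils import (compute_slot_mapping,
                                           compute_slot_mapping_start_idx)


def compute_slot_mapping_merged(is_profile_run, slot_mapping, seq_id, seq_len,
                                context_len, block_size, block_tables,
                                is_prompt, query_len, sliding_window,
                                use_v2_block_manager):
    # Hypothetical merged helper: derive the start index internally so
    # callers no longer have to chain two dependent calls.
    start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                               context_len, sliding_window,
                                               use_v2_block_manager)
    compute_slot_mapping(is_profile_run, slot_mapping, seq_id, seq_len,
                         context_len, start_idx, block_size, block_tables)
```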
The hunk continues past the comment thread:

```diff
+                is_prompt, query_len, context_len, self.sliding_window,
+                self.use_v2_block_manager)
+            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                 seq_len, context_len, start_idx,
+                                 self.block_size,
+                                 seq_group_metadata.block_tables)
+
+    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
+              cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors."""
+        device = runner.device
+        use_captured_graph = cuda_graph_pad_size != -1
+
+        logits_soft_cap = getattr(runner.model_config.hf_config,
+                                  "attn_logit_softcapping", None)
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "Please use Flashinfer backend for models with logits_soft_cap"
+                " (i.e., Gemma-2). Otherwise, the output might be wrong."
+                " Set Flashinfer backend by "
+                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+            self.block_tables.extend([] * cuda_graph_pad_size)
+            num_decode_tokens = batch_size + cuda_graph_pad_size
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = runner.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(self.block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.tensor(input_block_tables, device=device)
+        else:
+            max_block_table_len = max(
+                len(block_table) for block_table in self.block_tables)
+            block_tables = make_tensor_with_pad(
+                self.block_tables,
+                max_len=max_block_table_len,
+                pad=0,
+                dtype=torch.int,
+                device=device,
+            )
+        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+
+        context_lens_tensor = torch.tensor(self.context_lens,
+                                           dtype=torch.int,
+                                           device=device)
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=device)
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=device)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        slot_mapping_tensor = torch.tensor(self.slot_mapping,
+                                           dtype=torch.long,
+                                           device=device)
+
+        return FlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            seq_start_loc=seq_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
+        )
+
+
 class FlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
```
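A note on the tensors built above: `query_start_loc` and `seq_start_loc` are exclusive prefix sums of the per-sequence lengths with a leading zero, so `[start_loc[i], start_loc[i + 1])` is sequence `i`'s range in the flattened token batch. A self-contained illustration with made-up lengths:

```python
import torch

# Made-up batch: one chunked prefill of 3 query tokens plus two decodes.
query_lens_tensor = torch.tensor([3, 1, 1], dtype=torch.long)

# Same construction as FlashAttentionMetadataBuilder.build(): a zeros tensor
# one element longer than the input, filled by cumsum starting at index 1.
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                              dtype=torch.int32)
torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

print(query_start_loc)  # tensor([0, 3, 4, 5], dtype=torch.int32)
# Tokens of sequence i occupy query_start_loc[i]:query_start_loc[i + 1]
# in the flattened query tensor.
```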