diff --git a/tests/kernels/test_flashinfer_mla_decode.py b/tests/kernels/test_flashinfer_mla_decode.py new file mode 100644 index 000000000000..2524d93b5ce4 --- /dev/null +++ b/tests/kernels/test_flashinfer_mla_decode.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F +from torch import Tensor + +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla +from vllm.platforms import current_platform + +FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="FlashInfer MLA Requires compute capability of 10 or above.", + allow_module_level=True) + + +def ref_mla( + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) +): + bs, num_heads, v_head_dim = out.shape + head_dim = query.shape[2] + + for i in range(bs): + # gather and flatten KV-cache + kv = kv_cache[ + block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, + head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + v = kv[:, :, :v_head_dim] + + q = query[i].view(num_heads, 1, head_dim) + o = F.scaled_dot_product_attention(q, + kv, + v, + scale=scale, + enable_gqa=True) + out[i] = o.view(num_heads, v_head_dim) + + return out + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("bs", [1, 2, 4, 16]) +@pytest.mark.parametrize("block_size", [32, 64]) +def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): + torch.set_default_device('cuda') + torch.manual_seed(42) + + # Deepseek R1 config + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + qk_head_dim = kv_lora_rank + qk_rope_head_dim + scale = 
(qk_nope_head_dim + qk_rope_head_dim)**-0.5 + + MAX_SEQ_LEN = 1024 + + seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) + + # Generate block tables with random but unique block IDs + # From https://github.com/flashinfer-ai/flashinfer/pull/1222 + blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size + max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4) + total_blocks_needed = sum(blocks_per_seq) + # Get random unique IDs for all blocks + all_block_ids = torch.randperm(total_blocks_needed) + + block_id = 0 + block_tables = torch.zeros( + (bs, max_num_blocks_per_seq), + dtype=torch.int32, + ) + + # Populate block tables and track block assignments + block_id = 0 + for i in range(bs): + num_blocks_needed = blocks_per_seq[i] + block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id + + num_blocks_needed] + block_id += num_blocks_needed + + kv_cache = torch.randn(block_tables.numel(), block_size, + qk_head_dim).to(dtype) + q = torch.randn(bs, num_heads, qk_head_dim).to(dtype) + + out_ref = q.new_zeros(bs, num_heads, kv_lora_rank) + ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor) + + workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=q.device, + ) + # Flashinfer MLA expects the query to be of shape + # (bs, q_len_per_request, num_heads, qk_head_dim), + # where q_len_per_request is the MTP query length (=1 without MTP) + q = q.unsqueeze(1) + + out_ans = trtllm_batch_decode_with_kv_cache_mla( + query=q, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens_tensor, + max_seq_len=max_seq_len, + bmm1_scale=scale, + ) + out_ans = out_ans.squeeze(1) + 
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ba20da4fd75f..885d8e8e3fc7 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -101,6 +101,10 @@ def copy_blocks( ) -> None: raise NotImplementedError + @staticmethod + def decode_supports_qlen_padding() -> bool: + return False + def advance_step(self, model_input: "ModelRunnerInputBase", sampled_token_ids: Optional[torch.Tensor], block_size: int, num_seqs: int, num_queries: int) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5eb9660cd1e8..146698a3a1e5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1465,6 +1465,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1", + "FLASHINFER_MLA", "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 2e026d582a6d..994a57f523e8 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -158,14 +158,13 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - previous_hidden_states: torch.Tensor, + hidden_states: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, - previous_hidden_states, inputs_embeds, - spec_step_idx) + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) return hidden_states def compute_logits( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b61b39a9274d..9a42f691c0bf 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -227,6 +227,19 @@ def get_attn_backend_cls(cls, 
selected_backend, head_size, dtype, if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here + if selected_backend == _Backend.FLASHINFER_MLA: + if use_v1 and cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") + logger.info_once( + "Using FlashInfer MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashinfer_mla.FlashInferMLABackend") + else: + logger.warning( + "FlashInfer MLA backend is only supported on V1 engine" + " and requires compute capability 10.0") if selected_backend == _Backend.CUTLASS_MLA: if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 61ce868c13b4..b9deccd42a9b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -50,6 +50,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() FLASHINFER_VLLM_V1 = enum.auto() + FLASHINFER_MLA = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA_VLLM_V1 = enum.auto() FLASHMLA_VLLM_V1 = enum.auto() diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8592d1b26dfa..091cccb15656 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -7,14 +7,14 @@ from typing import ClassVar, Optional, Union import torch + +import vllm.envs as envs from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) from flashinfer.decode import (_get_range_buf, get_seq_lens, trtllm_batch_decode_with_kv_cache) from flashinfer.prefill import trtllm_batch_context_with_kv_cache - -import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.config import VllmConfig @@ -186,7 +186,8 @@ class 
FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.PURE_DECODE_ONLY - reorder_batch_threshold: ClassVar[int] = 1 + def get_reorder_batch_threshold(self) -> int | None: + return 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -445,8 +446,11 @@ def build(self, fast_build: bool = False) -> FlashInferMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens + decode_threshold = self.get_reorder_batch_threshold() + assert decode_threshold is not None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=decode_threshold) page_size = self.kv_cache_spec.block_size max_q_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 66a8d91db89c..af7c8ee77079 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import Optional import torch @@ -83,7 +83,8 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - reorder_batch_threshold: ClassVar[int] = 1 + def get_reorder_batch_threshold(self) -> int: + return 1 def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device): @@ -111,8 +112,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - 
split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.get_reorder_batch_threshold())) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index badff67656c2..756887a55aa9 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,7 +190,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import ClassVar, Generic, Optional, TypeVar, Union +from typing import Generic, Optional, TypeVar, Union import torch @@ -349,6 +349,7 @@ class MLACommonMetadata(Generic[D]): num_reqs: int max_query_len: int + max_seq_len: int num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor @@ -379,6 +380,7 @@ def __post_init__(self): def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if its available since # it is faster than FA2. 
+ return False # TODO(review): temporary override that unconditionally disables FlashInfer prefill and makes the return below unreachable - remove before merge return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL and current_platform.is_device_capability(100)) @@ -400,7 +402,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ - reorder_batch_threshold: ClassVar[int] = 1 + + def get_reorder_batch_threshold(self) -> int | None: + return self._reorder_batch_threshold def __init__(self, kv_cache_spec: AttentionSpec, @@ -416,6 +420,11 @@ def __init__(self, self.model_config = vllm_config.model_config cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config + self.num_speculative_tokens = 0 + if vllm_config.speculative_config is not None: + self.num_speculative_tokens = \ + vllm_config.speculative_config.num_speculative_tokens + self._reorder_batch_threshold = 1 + self.num_speculative_tokens self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled self.num_heads = self.model_config.get_num_attention_heads( parallel_config) @@ -586,6 +595,7 @@ def build(self, num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len + max_seq_len = common_attn_metadata.seq_lens_cpu.max().item() # Note(simon): be careful about the CPU <> GPU memory movement in this # function.
We should avoid GPU -> CPU sync as much as possible because @@ -603,8 +613,12 @@ def build(self, num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu) + decode_threshold = self.get_reorder_batch_threshold() + assert decode_threshold is not None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=decode_threshold) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -710,6 +724,7 @@ def build(self, attn_metadata = self.metadata_cls( num_reqs=common_attn_metadata.num_reqs, max_query_len=common_attn_metadata.max_query_len, + max_seq_len=max_seq_len, num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py new file mode 100644 index 000000000000..95cab24fb320 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) + +logger = init_logger(__name__) + +FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + + +class FlashInferMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA" + + @staticmethod + def get_impl_cls() -> type["FlashInferMLAImpl"]: + return FlashInferMLAImpl + + @staticmethod + def decode_supports_qlen_padding() -> bool: + return True + + +g_fi_workspace = 
torch.empty( + FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device="cuda", +) + + +class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashInferMLA V1 with FP8 KV cache not yet supported") + + self._workspace_buffer = g_fi_workspace + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + batch_size = attn_metadata.decode.block_table.shape[0] + + # (batch_size * q_len_per_request, num_heads, head_dim_qk) + q = torch.cat([q_nope, q_pe], dim=-1) + + assert q.shape[0] % batch_size == 0, \ + f"q.shape[0] ({q.shape[0]}) must be a " \ + f"multiple of batch_size ({batch_size})" + + # (batch_size, q_len_per_request, num_heads, head_dim_qk) + q = q.reshape((batch_size, q.shape[0] // batch_size, *q.shape[1:])) + + o = 
trtllm_batch_decode_with_kv_cache_mla( query=q, kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), workspace_buffer=self._workspace_buffer, qk_nope_head_dim=self.qk_nope_head_dim, kv_lora_rank=self.kv_lora_rank, qk_rope_head_dim=self.qk_rope_head_dim, block_tables=attn_metadata.decode.block_table, seq_lens=attn_metadata.decode.seq_lens, max_seq_len=attn_metadata.max_seq_len, bmm1_scale=self.scale, ) + + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b5aecff9937f..860249c2b585 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -161,8 +161,18 @@ def _forward_decode( assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) + q = torch.cat([q_nope, q_pe], dim=-1) + + batch_size = attn_metadata.decode.block_table.shape[0] + + needs_padding = q.shape[0] % batch_size != 0 + + if needs_padding: + raise ValueError(f"q.shape[0] ({q.shape[0]}) must be divisible by batch_size ({batch_size}); query-length padding is not supported here") + else: + q = q.reshape( + (batch_size, q.shape[0] // batch_size, *q.shape[1:] + )) # (batch_size, q_len_per_request, num_heads, head_dim_qk) o, _ = flash_mla_with_kvcache( q=q, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7aeea40b25a6..13ec8fe43a9d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -169,10 +169,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention. attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ AttentionCGSupport.NEVER + # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch.
- reorder_batch_threshold: ClassVar[Optional[int]] = None + def get_reorder_batch_threshold(self) -> int | None: + return None @abstractmethod def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], @@ -524,7 +526,7 @@ def make_local_attention_virtual_batches( def split_decodes_and_prefills( common_attn_metadata: CommonAttentionMetadata, - decode_threshold: int = 1, + decode_threshold: int, ) -> tuple[int, int, int, int]: """ Assuming a reordered batch, finds the boundary between prefill and decode @@ -567,7 +569,7 @@ def split_decodes_and_prefills( def reorder_batch_to_split_decodes_and_prefills( input_batch: "InputBatch", scheduler_output: "SchedulerOutput", - decode_threshold: int = 1, + decode_threshold: int, ) -> bool: """ Reorders the batch to split into prefill and decode requests; places all diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b2380bb3dd5a..6d99dfb3b161 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -18,9 +18,7 @@ from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.utils import is_pin_memory_available -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadataBuilder from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata @@ -79,11 +77,15 @@ def __init__( dtype=self.dtype, device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. 
max_batch_size = vllm_config.scheduler_config.max_num_seqs + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, + dtype=torch.int32) self.arange = torch.arange( - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - max_batch_size + 1, + max_num_slots_for_arange, device=device, dtype=torch.int32, ) @@ -134,13 +136,16 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -199,10 +204,15 @@ def propose( ) if self.method == "deepseek_mtp": last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) + return draft_token_ids.view(-1, 1) + positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] if self.first_branching_level == 0: @@ -221,21 +231,6 @@ def propose( draft_token_ids = logits.argmax(dim=-1) - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # TODO: Currently, MTP module released by deepseek only has - # one layer. 
Adapt this code to support multiple layers once - # there's a multi-layer MTP module. - - # Currently, only FlashAttention and TreeAttention support multi-token - # eagle spec decode. This is because the code below - # makes assumptions about attn_metadata attributes available. - assert isinstance(attn_metadata, - (FlashAttentionMetadata, TreeAttentionMetadata)) - # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -244,9 +239,12 @@ def propose( input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:batch_size + 1]).clone() for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. @@ -267,27 +265,37 @@ def propose( positions) # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + common_attn_metadata.seq_lens += 1 + common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, + 1) + + common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. 
- block_numbers = clamped_positions // self.block_size - block_ids = attn_metadata.block_table.gather( + block_numbers = positions // self.block_size + block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + common_attn_metadata.slot_mapping = (block_ids * self.block_size + + positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + common_attn_metadata.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID) + + # Rebuild attention metadata + attn_metadata = self.runner.attn_metadata_builders[ + 0].build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=token_index + 1, + ) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -306,12 +314,17 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) + if self.method == "deepseek_mtp": + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) @@ -337,6 +350,40 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def prepare_inputs_deferred(self, + 
common_attn_metadata: CommonAttentionMetadata): + """ + This function is used to prepare the inputs for the spec decode. + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `last_token_indices`. + """ + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + + return spec_common_attn_metadata, token_indices + def propose_tree( self, tree_root_level: int, @@ -536,6 +583,9 @@ def prepare_inputs( device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + # num_rejected_tokens = num_rejected_tokens * 0 + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - num_rejected_tokens diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index d9d0b4bec871..1eed81c1fe91 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -54,8 +54,10 @@ def num_tokens(self) -> int: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: + 
elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] + else: + return -1 class InputBatch: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 85976fc1c825..4a57dd6f2ef7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -247,6 +247,9 @@ def __init__( self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + self.backup_next_token_ids = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -310,6 +313,14 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + self.discard_req_np = np.zeros(self.max_num_reqs) + self.backup_next_token_ids_cpu = torch.zeros( + self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.backup_next_token_ids_np = self.backup_next_token_ids_cpu.numpy() + # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it # means this layer will perform attention using the keys and values @@ -703,6 +714,25 @@ def _prepare_inputs( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + self.discard_req_np[:num_reqs] = \ + self.seq_lens_np[:num_reqs] < num_tokens_np + + # Precompute get_token_id for when there is no valid next token + self.backup_next_token_ids_np[:num_reqs] = np.array([ + self.requests[self.input_batch.req_ids[i]].get_token_id( + self.seq_lens_np[i]) for i in range(num_reqs) + ]) + + self.backup_next_token_ids[:num_reqs].copy_( + self.backup_next_token_ids_cpu[:num_reqs], non_blocking=True) + # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) @@ -1680,8 +1710,42 @@ def execute_model( scheduler_output, ) - # Get the valid generated tokens. - sampled_token_ids = sampler_output.sampled_token_ids + if not self.speculative_config: + # Speculative decoding is not enabled. 
+ spec_token_ids = None + valid_sampled_token_ids = self.get_valid_sampled_token_ids( + sampler_output.sampled_token_ids, + discard_sampled_tokens_req_indices) + else: + assert spec_decode_common_attn_metadata is not None + spec_token_ids, valid_sampled_token_ids = \ + self.propose_draft_token_ids( + scheduler_output, + sampler_output.sampled_token_ids, + sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + discard_sampled_tokens_req_indices) + + self.eplb_step() + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=spec_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits) + + def get_valid_sampled_token_ids( + self, sampled_token_ids: torch.Tensor, + discard_sampled_tokens_req_indices: list[int]) -> list[list[int]]: max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: # No spec decode tokens. @@ -1720,54 +1784,28 @@ def execute_model( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if not self.speculative_config: - # Speculative decoding is not enabled. 
- spec_token_ids = None - else: - assert spec_decode_common_attn_metadata is not None - spec_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, - ) - - self.eplb_step() - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) + return valid_sampled_token_ids def propose_draft_token_ids( - self, - scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - hidden_states: torch.Tensor, + self, scheduler_output: "SchedulerOutput", + sampled_token_ids: torch.Tensor | list[list[int]], + sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]]: + discard_sampled_tokens_req_indices: list[int] + ) -> tuple[list[list[int]], list[list[int]]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): # The input to the target model does not include draft tokens. 
hidden_states = sample_hidden_states @@ -1787,25 +1825,53 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): + assert isinstance(sampled_token_ids, torch.Tensor) assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) + assert discard_sampled_tokens_req_indices is not None + + _max_gen_len = sampled_token_ids.shape[-1] + # Get all sampled tokens from valid requests + _valid_sampled_token_ids_gpu = sampled_token_ids + + _valid_sampled_token_ids_gpu[ + discard_sampled_tokens_req_indices, :] = -1 + + # Generate a mask for all valid tokens within those requests + if _max_gen_len == 1: + _valid_mask = torch.ones_like(_valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + _valid_mask = ((_valid_sampled_token_ids_gpu != -1) & + (_valid_sampled_token_ids_gpu + < self.input_batch.vocab_size)) + + # Count valid tokens in each request + _valid_sampled_count = _valid_mask.sum(dim=1) + + _batch = _valid_sampled_token_ids_gpu.shape[0] + + # Get the rightmost valid index per row + _last_valid_indices = _valid_sampled_count - 1 + + _last_valid_indices_safe = torch.max( + _last_valid_indices, torch.zeros_like(_last_valid_indices)) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + _selected_tokens = torch.gather( + _valid_sampled_token_ids_gpu, 1, + 
_last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + next_token_ids_gpu_2 = torch.where( + _last_valid_indices != -1, _selected_tokens, + self.backup_next_token_ids[:_batch]) + + token_indices_to_sample = None + + if spec_decode_metadata is None or not self.supports_qlen_padding: + sampled_token_ids = self.get_valid_sampled_token_ids( + sampled_token_ids, discard_sampled_tokens_req_indices) if spec_decode_metadata is None: # input_ids can be None for multimodal models. @@ -1819,17 +1885,35 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + if self.supports_qlen_padding: + _num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + _num_rejected_tokens_gpu = torch.where( + _num_draft_tokens_gpu > 0, + _num_draft_tokens_gpu + 1 - _valid_sampled_count, + torch.zeros_like(_num_draft_tokens_gpu)) + + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs_deferred(common_attn_metadata) + token_indices_to_sample = \ + common_attn_metadata.query_start_loc[1:] - 1 \ + - _num_rejected_tokens_gpu + else: + # TODO(woosuk): Refactor this. 
+ num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, num_rejected_tokens_cpu) target_token_ids = self.input_ids[token_indices] # TODO(woosuk): Support M-RoPE. @@ -1848,13 +1932,19 @@ def propose_draft_token_ids( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, + next_token_ids=next_token_ids_gpu_2, + last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, ) spec_token_ids = draft_token_ids.tolist() - return spec_token_ids + + if spec_decode_metadata is not None and self.supports_qlen_padding: + sampled_token_ids = self.get_valid_sampled_token_ids( + sampled_token_ids, discard_sampled_tokens_req_indices) + + return spec_token_ids, sampled_token_ids def propose_ngram_draft_token_ids( self, @@ -2631,9 +2721,6 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) - # Calculate reorder batch threshold (if neeeded) - self.calculate_reorder_batch_threshold() - if len(self.attn_backends) > 0: return @@ -2677,7 +2764,7 @@ def calculate_reorder_batch_threshold(self) -> None: # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + attn_metadata_builder_i.get_reorder_batch_threshold()) if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: if reorder_batch_threshold_i != \ @@ -2689,6 +2776,13 
@@ def calculate_reorder_batch_threshold(self) -> None: f"{self.reorder_batch_threshold}") else: self.reorder_batch_threshold = reorder_batch_threshold_i + if self.supports_qlen_padding: + assert self.reorder_batch_threshold == \ + 1 + self.speculative_config.num_speculative_tokens, \ + "Reorder batch threshold must be 1 + num_speculative_tokens " \ + "for attention backends with qlen padding support. " \ + f"Got {self.reorder_batch_threshold} != " \ + f"{1 + self.speculative_config.num_speculative_tokens}." def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: @@ -2917,12 +3011,21 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) + self.supports_qlen_padding = False if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) + # check if qlen padding is supported + self.supports_qlen_padding = all( + attn_backend.decode_supports_qlen_padding() + for attn_backend in self.attn_backends) + + # Calculate reorder batch threshold (if needed) + self.calculate_reorder_batch_threshold() + if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches)