diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 042e953866cf..c703d6aae9f9 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -98,7 +98,7 @@ def test_without_spec_decoding( @single_gpu_only @large_gpu_mark(min_gb=16) -def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch): +def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, spec decoding model length. @@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch) ) +def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test ngram_gpu speculative decoding with different configurations. + + This test runs one fixed ngram_gpu config (plus a no-spec baseline) across: + - Preemption on and off + - The "mp" and "uni" executors + - Async scheduling on and off + - Prefill chunking on and off + """ + + # Single ngram_gpu config shared by every spec-decoding case below + ngram_gpu_config = { + "method": "ngram_gpu", + "num_speculative_tokens": 3, + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + } + + # Test configurations covering various scenarios + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", False, ngram_gpu_config, False), + (True, "mp", False, ngram_gpu_config, True), + (False, "mp", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, False), + (True, "uni", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, True), + ] + + # Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight + # and ngram_gpu doesn't require a specific draft model + run_tests(monkeypatch, MODEL, test_configs, [{}]) + + 
@dynamo_config.patch(cache_size_limit=16) def run_tests( monkeypatch: pytest.MonkeyPatch, @@ -282,11 +318,12 @@ def run_test( else dict(gpu_memory_utilization=0.9) ) spec_mml = (spec_config or {}).get("max_model_len") + spec_method = (spec_config or {}).get("method", "none") test_config = ( f"executor={executor}, preemption={test_preemption}, " f"async_sched={async_scheduling}, " f"chunk_prefill={test_prefill_chunking}, " - f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" + f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}" ) print("-" * 80) print(f"---- TESTING {test_str}: {test_config}") @@ -294,7 +331,7 @@ def run_test( with VllmRunner( model, - max_model_len=512, + max_model_len=4096, enable_chunked_prefill=test_prefill_chunking, # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4066dfe9e34d..3988070ca759 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness( cleanup_dist_env_and_memory() +@pytest.mark.parametrize("async_scheduling", [True], ids=["async"]) +@single_gpu_only +@large_gpu_mark(min_gb=20) +def test_ngram_gpu_default_with_async_scheduling( + async_scheduling: bool, +): + """ + Test ngram_gpu speculative decoding (num_speculative_tokens=2) correctness + with async scheduling enabled, validated via GSM8K accuracy. + Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%). 
+ """ + qwen3_model = "Qwen/Qwen3-8B" + spec_llm = LLM( + model=qwen3_model, + speculative_config={ + "method": "ngram_gpu", + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + "num_speculative_tokens": 2, + }, + max_model_len=4096, + async_scheduling=async_scheduling, + ) + evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8) + del spec_llm + cleanup_dist_env_and_memory() + + @single_gpu_only @large_gpu_mark(min_gb=20) def test_suffix_decoding_acceptance( diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 9d37a5331c96..2bf53a7fad74 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -907,6 +907,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE. disable_cache = not is_compile_cache_enabled(self.inductor_config) + # TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors. + is_ngram_gpu_enabled = ( + vllm_config.speculative_config is not None + and vllm_config.speculative_config.use_ngram_gpu() + ) + disable_cache = disable_cache or is_ngram_gpu_enabled + if disable_cache: logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") else: diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a950ba531ad2..27b5188eb52d 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -47,6 +47,7 @@ "step3p5_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes] +NgramGPUTypes = Literal["ngram_gpu"] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -54,6 +55,7 @@ "draft_model", "suffix", EagleModelTypes, + NgramGPUTypes, ] @@ -364,6 +366,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "ngram_gpu": + self.model = "ngram_gpu" elif self.method == "suffix": self.model 
= "suffix" elif self.method == "extract_hidden_states": @@ -374,8 +378,9 @@ def __post_init__(self): ) if self.method in ("ngram", "[ngram]"): - # Unified to "ngram" internally self.method = "ngram" + + if self.method in ("ngram", "ngram_gpu"): # Set default values if not provided if self.prompt_lookup_min is None and self.prompt_lookup_max is None: # TODO(woosuk): Tune these values. They are arbitrarily chosen. @@ -832,6 +837,9 @@ def uses_draft_model(self) -> bool: def uses_extract_hidden_states(self) -> bool: return self.method == "extract_hidden_states" + def use_ngram_gpu(self) -> bool: + return self.method == "ngram_gpu" + def __repr__(self) -> str: method = self.method model = ( diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 34c668362d40..6b98d9107dfc 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -41,7 +41,7 @@ from .parallel import ParallelConfig from .profiler import ProfilerConfig from .scheduler import SchedulerConfig -from .speculative import EagleModelTypes, SpeculativeConfig +from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig from .structured_outputs import StructuredOutputsConfig from .utils import SupportsHash, config, replace from .weight_transfer import WeightTransferConfig @@ -698,11 +698,13 @@ def __post_init__(self): if self.speculative_config is not None: if ( self.speculative_config.method not in get_args(EagleModelTypes) + and self.speculative_config.method not in get_args(NgramGPUTypes) and self.speculative_config.method != "draft_model" ): raise ValueError( "Currently, async scheduling is only supported " - "with EAGLE/MTP/Draft Model kind of speculative decoding." 
+ "with EAGLE/MTP/Draft Model/NGram GPU kind of " + "speculative decoding" ) if self.speculative_config.disable_padded_drafter_batch: raise ValueError( @@ -720,6 +722,7 @@ def __post_init__(self): if ( self.speculative_config is not None and self.speculative_config.method not in get_args(EagleModelTypes) + and self.speculative_config.method not in get_args(NgramGPUTypes) ): logger.warning_once( "Async scheduling not supported with %s-based " diff --git a/vllm/tool_parsers/hermes_tool_parser.py b/vllm/tool_parsers/hermes_tool_parser.py index b9b1dcda6f68..5bde5b2c07ab 100644 --- a/vllm/tool_parsers/hermes_tool_parser.py +++ b/vllm/tool_parsers/hermes_tool_parser.py @@ -385,6 +385,7 @@ def extract_tool_calls_streaming( prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( "arguments" ) + assert current_tool_call is not None cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -489,6 +490,7 @@ def extract_tool_calls_streaming( # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration + assert isinstance(current_tool_call, dict) if self.current_tool_id == len(self.prev_tool_call_arr) - 1: self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py new file mode 100644 index 000000000000..3ff84180463d --- /dev/null +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -0,0 +1,660 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GPU-accelerated N-gram proposer using fully async PyTorch tensor operations. + +This version uses a fully vectorized approach with unfold and argmax for +finding the first match across all sequences in parallel. 
+""" + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CompilationConfig, + CompilationMode, + CUDAGraphMode, + VllmConfig, +) +from vllm.forward_context import set_forward_context +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + + +@support_torch_compile() +class NgramGPUKernel(nn.Module): + """GPU-accelerated N-gram proposer using fully async tensor operations.""" + + def __init__( + self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda" + ): + super().__init__() + + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.device = device + + def _find_first_and_extract_all_n_parallel( + self, + token_ids: torch.Tensor, + seq_lengths: torch.Tensor, + min_ngram_len: int, + max_ngram_len: int, + num_draft_tokens: int, + ) -> torch.Tensor: + """ + Find suffix n-gram matches and extract following tokens. + Searches for the earliest prior occurrence of the trailing n-gram, + tries multiple lengths, and picks the longest valid match. 
+ + Args: + token_ids: Token IDs for each sequence + seq_lengths: Actual length of each sequence (excluding padding) + min_ngram_len: Minimum n-gram size to search for (e.g., 2) + max_ngram_len: Maximum n-gram size to search for (e.g., 5) + num_draft_tokens: Number of tokens to extract after match (k) + + Returns: + Draft token predictions; -1 means invalid/no match. + """ + batch_size = token_ids.shape[0] + max_seq_len = token_ids.shape[1] + device = token_ids.device + num_ngram_sizes = max_ngram_len - min_ngram_len + 1 + + # All n-gram sizes to try. + ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device) + batch_indices = torch.arange(batch_size, device=device) + + # Earliest match per (sequence, ngram_len); -1 means no match. + first_match_positions = torch.full( + (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device + ) + + for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)): + # Sliding windows of size ngram_len; unfold is O(1) view. + search_windows = token_ids.unfold(1, ngram_len, 1) + num_windows = search_windows.shape[1] + + # Trailing suffix (last ngram_len tokens) for each sequence. + suffix_starts = seq_lengths - ngram_len + suffix_indices = suffix_starts.unsqueeze(1) + torch.arange( + ngram_len, device=device + ) + suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0)) + + # Window matches for each sequence. + matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1) + + # Match must leave room for at least one draft token. + max_valid_suffix_start = seq_lengths - ngram_len - 1 + window_positions = torch.arange(num_windows, device=device) + valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1) + final_matches = matches & valid_mask + + # Find earliest match (argmax=0 when empty; verify with has_match). 
+ first_match_idx = torch.argmax(final_matches.int(), dim=1) + has_match = final_matches[batch_indices, first_match_idx] + + # Store valid match positions (window index = position). + first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1) + + # Select the longest n-gram with a match. + best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1) + best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx # Flip back + + # Match position for the best n-gram. + best_match_pos = first_match_positions[batch_indices, best_ngram_idx] + + # Avoid data-dependent branching. + has_any_match = best_match_pos >= 0 + + # Length of the best matching n-gram. + best_ngram_lengths = ngram_lengths[best_ngram_idx] + + # Start position right after the matched suffix. + draft_start = torch.where( + has_any_match, + best_match_pos + best_ngram_lengths, + torch.zeros_like(best_match_pos), + ) + tokens_available = seq_lengths - draft_start + + # Gather indices for draft tokens. + draft_indices = draft_start.unsqueeze(1) + torch.arange( + num_draft_tokens, device=device + ) + draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1) + + # Extract draft tokens; gather always runs. + draft_tokens = torch.gather(token_ids, 1, draft_indices) + + # Mask positions beyond available tokens. + position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0) + valid_positions = position_indices < tokens_available.unsqueeze(1) + + draft_tokens = torch.where( + valid_positions, + draft_tokens, + torch.full_like(draft_tokens, -1), + ) + + # If no match, mask all positions. + draft_tokens = torch.where( + has_any_match.unsqueeze(1), + draft_tokens, + torch.full_like(draft_tokens, -1), + ) + + return draft_tokens + + def forward( + self, + num_tokens_no_spec: torch.Tensor, + token_ids_gpu: torch.Tensor, + combined_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for N-gram proposal using GPU tensor operations. 
+ + Args: + num_tokens_no_spec: Number of tokens for each sequence [batch_size] + token_ids_gpu: Token IDs [batch_size, max_len] + combined_mask: Whether each sequence is valid for spec decode [batch_size] + + Returns: + draft_tokens: [batch_size, k] on GPU + num_valid_draft_tokens: [batch_size] integer tensor on GPU, count of + leading valid (non -1) tokens per request. + """ + + device = token_ids_gpu.device + + # Infer batch size to preserve dynamic shape. + actual_batch_size = token_ids_gpu.shape[0] + + # Allocate in forward so torch.compile can optimize (rebound below). + # NOTE(patchy): Do NOT pre-allocate this as a buffer; + # it breaks torch.compile. + draft_tokens = torch.full( + (actual_batch_size, self.k), -1, dtype=torch.int32, device=device + ) + + results = self._find_first_and_extract_all_n_parallel( + token_ids_gpu, + num_tokens_no_spec, + min_ngram_len=self.min_n, + max_ngram_len=self.max_n, + num_draft_tokens=self.k, + ) + + draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1) + + # Count leading contiguous valid (non -1) tokens per request. 
+ is_valid = draft_tokens != -1 # [batch, k] + cum_valid = is_valid.int().cumsum(dim=1) # [batch, k] + positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0) + num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1) + + return draft_tokens, num_valid_draft_tokens + + def load_model(self, *args, **kwargs): + """No model to load for N-gram proposer.""" + pass + + +class NgramProposerGPU: + def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + compilation_config = CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["none"], + splitting_ops=[], + compile_sizes=[], + inductor_compile_config={ + "enable_auto_functionalized_v2": False, + "max_autotune": True, + "aggressive_fusion": True, + "triton.autotune_pointwise": True, + "coordinate_descent_tuning": True, + "use_mixed_mm": False, + }, + cudagraph_mode=CUDAGraphMode.NONE, + ) + model_config = vllm_config.model_config + speculative_config = vllm_config.speculative_config + scheduler_config = vllm_config.scheduler_config + + self.vllm_config = VllmConfig( + compilation_config=compilation_config, + model_config=model_config, + speculative_config=speculative_config, + scheduler_config=scheduler_config, + ) + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.device = device + + self.kernel = NgramGPUKernel( + vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device + ) + self.kernel.to(device) + self.kernel.eval() + + self._dummy_run() + + def _dummy_run(self): + token_ids, 
num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( + batch_size=self.max_num_seqs, + max_seq_len=self.max_model_len, + pattern_len=self.k, + device=self.device, + ) + + combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n) + + for _ in range(3): + with set_forward_context(None, self.vllm_config): + _, _ = self.kernel(num_tokens, token_ids, combined_mask) + + def _generate_dummy_data( + self, + batch_size: int, + max_seq_len: int, + pattern_len: int, + device: str = "cuda", + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generate dummy inputs for `_dummy_run`: zero token IDs, random lengths. + + Args: + batch_size: Number of sequences in the batch + max_seq_len: Maximum sequence length (exclusive upper length bound) + pattern_len: Inclusive lower bound for the random sequence lengths + device: Device to place tensors on + + Returns: + token_ids: [batch_size, max_seq_len] all-zero int32 tensor + num_tokens: [batch_size] int32 tensor of random lengths + sampled_flags: [batch_size] bool tensor (all True) + valid_mask: [batch_size] bool tensor (all True) + """ + token_ids = torch.zeros( + batch_size, + max_seq_len, + dtype=torch.int32, + device=device, + ) + + num_tokens = torch.randint( + pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device + ) + + sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + return token_ids, num_tokens, sampled_flags, valid_mask + + def propose( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] + token_ids_gpu: torch.Tensor, # [batch_size, max_len] + valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1] + valid_sampled_tokens_count: torch.Tensor, # [batch_size] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Propose draft tokens using GPU-accelerated n-gram matching. + + Scatter sampled tokens into `token_ids_gpu`, compute temporary + updated lengths, then run the kernel. 
+ + Args: + num_tokens_no_spec: Number of tokens per sequence (read-only) + token_ids_gpu: Token IDs tensor (modified in-place with new tokens) + valid_sampled_token_ids_gpu: Newly sampled tokens to scatter + valid_sampled_tokens_count: Count of valid tokens per sequence + + Returns: + draft_tokens: Proposed draft token IDs [batch_size, k] + num_valid_draft_tokens: Count of leading valid draft tokens + per request [batch_size] + """ + assert token_ids_gpu.device == self.device + assert num_tokens_no_spec.device == self.device + + batch_size = num_tokens_no_spec.shape[0] + max_seq_len = token_ids_gpu.shape[1] + max_new_tokens = valid_sampled_token_ids_gpu.shape[1] # num_spec_tokens + 1 + + # Scatter newly sampled tokens into token_ids_gpu. + offsets = torch.arange(max_new_tokens, device=self.device) + write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0) + valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze( + 1 + ) + in_bounds = write_positions < max_seq_len + scatter_mask = ( + valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds + ) + + write_positions_long = write_positions.clamp(max=max_seq_len - 1).long() + existing_values = token_ids_gpu.gather(1, write_positions_long) + + tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype) + tokens_to_scatter = torch.where( + scatter_mask, + tokens_cast, + existing_values, + ) + token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter) + + num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count + + # Compute validity masks. 
+ sampled_flags = valid_sampled_tokens_count > 0 + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device) + + with set_forward_context(None, self.vllm_config): + combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n) + + with record_function_or_nullcontext("ngram_proposer_gpu: kernel"): + draft_tokens, num_valid_draft_tokens = self.kernel( + num_tokens_tmp, + token_ids_gpu, + combined_mask, + ) + + return draft_tokens, num_valid_draft_tokens + + def update_token_ids_ngram( + self, + sampled_token_ids: torch.Tensor | list[list[int]], + gpu_input_batch: InputBatch, + token_ids_gpu: torch.Tensor, + num_tokens_no_spec: torch.Tensor, + discard_request_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepare speculative decoding inputs on device: + compute next token ids and valid counts, honoring discarded requests + and rejected tokens, without CPU-GPU sync. + """ + num_reqs = gpu_input_batch.num_reqs + + if isinstance(sampled_token_ids, list): + # When disable_padded_drafter_batch=True, sampled_token_ids is + # an irregular list[list[int]] where sublists may have different + # lengths (including empty lists for discarded requests). + # Pad all sublists to the same length with -1 before converting + # to tensor. + max_len = max( + (len(sublist) for sublist in sampled_token_ids), + default=0, + ) + # Ensure at least length 1 for tensor creation + max_len = max(max_len, 1) + padded_list = [ + sublist + [-1] * (max_len - len(sublist)) + for sublist in sampled_token_ids + ] + sampled_token_ids = torch.tensor( + padded_list, dtype=torch.int32, device=self.device + ) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for ngram_gpu" + ) + + # Backup last valid token before speculative tokens. 
+ backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long() + backup_next_token_ids = torch.gather( + token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1) + ).squeeze(1) + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + # Invalidate sampled tokens for discarded requests. + discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1) + valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1) + + # Mask valid tokens within each request. + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) + + # Count valid tokens per request. + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Rightmost valid index per row. + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Last valid token from each row; undefined if none. + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) + + # Use last token if valid; otherwise fallback to backup. + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + backup_next_token_ids, + ) + + return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu + + def load_model(self, *args, **kwargs): + self.kernel.load_model(*args, **kwargs) + + +def update_scheduler_for_invalid_drafts( + num_valid_draft_tokens_event: torch.cuda.Event, + num_valid_draft_tokens_cpu: torch.Tensor, + scheduler_output: "SchedulerOutput", + req_id_to_index: dict[str, int], +) -> None: + """Trim invalid speculative slots using per-request valid draft counts. + + Args: + num_valid_draft_tokens_event: Event for async D2H completion. + num_valid_draft_tokens_cpu: CPU buffer of valid draft counts. + scheduler_output: Scheduler metadata to update in-place. + req_id_to_index: Request-id to batch-index mapping. 
+ """ + req_data = scheduler_output.scheduled_cached_reqs + num_valid_draft_tokens_event.synchronize() + + for req_id in req_data.req_ids: + req_index = req_id_to_index.get(req_id) + if req_index is None: + continue + + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id) + if spec_token_ids is None: + continue + + scheduled_k = len(spec_token_ids) + + valid_k = int(num_valid_draft_tokens_cpu[req_index].item()) + valid_k = max(0, min(valid_k, scheduled_k)) + + tokens_to_trim = scheduled_k - valid_k + scheduler_output.total_num_scheduled_tokens -= tokens_to_trim + scheduler_output.num_scheduled_tokens[req_id] -= tokens_to_trim + + if valid_k == 0: + scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) + else: + scheduler_output.scheduled_spec_decode_tokens[req_id] = spec_token_ids[ + :valid_k + ] + + +def update_ngram_gpu_tensors_incremental( + input_batch: InputBatch, + token_ids_gpu_tensor: torch.Tensor, + num_tokens_no_spec_gpu: torch.Tensor, + new_reqs: list[CachedRequestState], + device: torch.device, + _pinned_idx_buf: torch.Tensor, + _pinned_val_buf: torch.Tensor, +) -> None: + """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu + for ngram GPU proposer. + """ + prev_req_id_to_index = input_batch.prev_req_id_to_index + curr_req_id_to_index = input_batch.req_id_to_index + + if not curr_req_id_to_index: + return + + active_indices = list(curr_req_id_to_index.values()) + n_active = len(active_indices) + + # Use resident pinned buffers to avoid per-call allocation. + active_idx_cpu = _pinned_idx_buf[:n_active] + active_idx_cpu.copy_(torch.as_tensor(active_indices, dtype=torch.long)) + + active_idx_gpu = active_idx_cpu.to(device=device, non_blocking=True) + + new_req_ids = {req.req_id for req in new_reqs} + + # First run, no previous state. 
+ if prev_req_id_to_index is None: + for idx in active_indices: + num_tokens = input_batch.num_tokens_no_spec[idx] + if num_tokens > 0: + token_ids_gpu_tensor[idx, :num_tokens].copy_( + input_batch.token_ids_cpu_tensor[idx, :num_tokens], + non_blocking=True, + ) + + _sync_num_tokens( + input_batch, + num_tokens_no_spec_gpu, + active_idx_cpu, + active_idx_gpu, + n_active, + device, + _pinned_val_buf, + ) + return + + # Detect index changes for reorder. + reorder_src: list[int] = [] + reorder_dst: list[int] = [] + + for req_id, curr_idx in curr_req_id_to_index.items(): + if req_id in new_req_ids: + continue + prev_idx = prev_req_id_to_index.get(req_id) + if prev_idx is not None and prev_idx != curr_idx: + reorder_src.append(prev_idx) + reorder_dst.append(curr_idx) + + if reorder_src: + src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device) + dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device) + + temp_token_ids = token_ids_gpu_tensor[src_tensor].clone() + temp_num_tokens = num_tokens_no_spec_gpu[src_tensor].clone() + + token_ids_gpu_tensor[dst_tensor] = temp_token_ids + num_tokens_no_spec_gpu[dst_tensor] = temp_num_tokens + + # Full copy for new/resumed requests. + for req_state in new_reqs: + new_req_idx = curr_req_id_to_index.get(req_state.req_id) + if new_req_idx is None: + continue + + num_tokens = input_batch.num_tokens_no_spec[new_req_idx] + if num_tokens > 0: + token_ids_gpu_tensor[new_req_idx, :num_tokens].copy_( + input_batch.token_ids_cpu_tensor[new_req_idx, :num_tokens], + non_blocking=True, + ) + + # Always batch-sync sequence lengths from CPU for ALL active requests. 
+ _sync_num_tokens( + input_batch, + num_tokens_no_spec_gpu, + active_idx_cpu, + active_idx_gpu, + n_active, + device, + _pinned_val_buf, + ) + + +def _sync_num_tokens( + input_batch: InputBatch, + num_tokens_no_spec_gpu: torch.Tensor, + active_idx_cpu: torch.Tensor, + active_idx_gpu: torch.Tensor, + n_active: int, + device: torch.device, + _pinned_val_buf: torch.Tensor, +) -> None: + """Batch-sync GPU sequence lengths from CPU source of truth. + + Inputs: + input_batch: Batch container with CPU length tensor. + num_tokens_no_spec_gpu: Destination GPU length tensor. + active_idx_cpu: Active request indices on CPU. + active_idx_gpu: Active request indices on GPU. + n_active: Number of active requests. + device: Target CUDA device. + _pinned_val_buf: Resident pinned int32 staging buffer. + Outputs: + None (updates num_tokens_no_spec_gpu in-place). + """ + src_cpu = input_batch.num_tokens_no_spec_cpu_tensor + vals = _pinned_val_buf[:n_active] + vals.copy_(src_cpu.index_select(0, active_idx_cpu)) + + num_tokens_no_spec_gpu.index_copy_( + 0, + active_idx_gpu, + vals.to(device=device, non_blocking=True), + ) + + +def copy_num_valid_draft_tokens( + num_valid_draft_tokens_cpu: torch.Tensor, + num_valid_draft_tokens_copy_stream: torch.cuda.Stream, + num_valid_draft_tokens_event: torch.cuda.Event, + num_valid_draft_tokens: torch.Tensor | None, + batch_size: int, +) -> None: + """ + Async D2H copy of per-request valid draft counts. 
+ """ + if num_valid_draft_tokens is None: + return + + num_reqs_to_copy = min(batch_size, num_valid_draft_tokens.shape[0]) + if num_reqs_to_copy <= 0: + return + + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(num_valid_draft_tokens_copy_stream): + num_valid_draft_tokens_copy_stream.wait_stream(default_stream) + num_valid_draft_tokens_cpu[:num_reqs_to_copy].copy_( + num_valid_draft_tokens[:num_reqs_to_copy], non_blocking=True + ) + num_valid_draft_tokens_event.record() diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c70970fdc06e..579c9b7a5acc 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -127,7 +127,13 @@ def __init__( # allocation if max_model_len is big. # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) self.req_prompt_embeds: dict[int, torch.Tensor] = {} - self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec_cpu_tensor = torch.zeros( + (max_num_reqs,), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy() self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( (max_num_reqs,), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index abeb10735129..8f2418e6d2fa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,7 +10,7 @@ from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import reduce from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast @@ -162,6 +162,12 @@ from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.medusa import MedusaProposer from 
vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.ngram_proposer_gpu import ( + NgramProposerGPU, + copy_num_valid_draft_tokens, + update_ngram_gpu_tensors_incremental, + update_scheduler_for_invalid_drafts, +) from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -422,7 +428,7 @@ def __init__( # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches + # TODO: Support overlapping micro-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( self.parallel_config.distributed_executor_backend == "external_launcher" @@ -491,6 +497,7 @@ def __init__( if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( NgramProposer # noqa: F823 + | NgramProposerGPU | SuffixDecodingProposer | EagleProposer | DraftModelProposer @@ -507,6 +514,23 @@ def __init__( device=self.device, runner=self, ) + elif self.speculative_config.use_ngram_gpu(): + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) + self.num_tokens_no_spec_gpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=device + ) + self.token_ids_gpu_tensor = torch.zeros( + self.max_num_reqs, + self.max_model_len, + dtype=torch.int32, + device=device, + ) + self._ngram_pinned_idx_buf = torch.zeros( + self.max_num_reqs, dtype=torch.long, pin_memory=True + ) + self._ngram_pinned_val_buf = torch.zeros( + self.max_num_reqs, dtype=torch.int32, pin_memory=True + ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -562,7 +586,7 @@ def __init__( ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - # We need to use the encoder length for encoder-decoer + # We need to use the 
encoder length for encoder-decoder # because of KV cache for cross-attention. max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, @@ -708,6 +732,21 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. + self._num_valid_draft_tokens: torch.Tensor | None = None + self._num_valid_draft_tokens_cpu: torch.Tensor | None = None + self._num_valid_draft_tokens_event: torch.cuda.Event | None = None + self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + self._num_valid_draft_tokens_cpu = torch.empty( + self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory + ) + self._num_valid_draft_tokens_event = torch.cuda.Event() + self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream() + self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() self.sampled_token_ids_pinned_cpu = torch.empty( @@ -979,6 +1018,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) + is_ngram_gpu = ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ) + if is_ngram_gpu: + ngram_gpu_new_reqs: list[CachedRequestState] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: @@ -1041,12 +1087,31 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self._init_xdrope_positions(req_state) reqs_to_add.append(req_state) + # Track new requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) # Update the states of the running/resumed requests. 
is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + # Save scheduler-allocated spec lengths before trimming so + # prev_num_draft_len keeps the optimistic count for rejection correction. + original_num_spec_per_req: dict[str, int] = {} + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + for req_id, toks in scheduled_spec_tokens.items(): + original_num_spec_per_req[req_id] = len(toks) + update_scheduler_for_invalid_drafts( + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens_cpu, + scheduler_output, + self.input_batch.req_id_to_index, + ) + # Wait until valid_sampled_tokens_count is copied to cpu, # then use it to update actual num_computed_tokens of each request. valid_sampled_token_count = self._get_valid_sampled_token_count() @@ -1063,13 +1128,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # prev_num_draft_len is used in async scheduling mode with # spec decode. it indicates if need to update num_computed_tokens # of the request. for example: - # fist step: num_computed_tokens = 0, spec_tokens = [], + # first step: num_computed_tokens = 0, spec_tokens = [], # prev_num_draft_len = 0. # second step: num_computed_tokens = 100(prompt length), # spec_tokens = [a,b], prev_num_draft_len = 0. # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], # prev_num_draft_len = 2. - # num_computed_tokens in first step and second step does't contain + # num_computed_tokens in first step and second step doesn't contain # the spec tokens length, but in third step it contains the # spec tokens length. we only need to update num_computed_tokens # when prev_num_draft_len > 0. 
@@ -1083,6 +1148,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens -= num_rejected req_state.output_token_ids.extend([-1] * num_accepted) + if is_ngram_gpu and num_accepted > 0 and req_index is not None: + self.input_batch.num_tokens_no_spec[req_index] += num_accepted + # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -1143,6 +1211,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] reqs_to_add.append(req_state) + # Track resumed requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) continue # Update the persistent batch. @@ -1163,6 +1234,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add spec_token_ids to token_ids_cpu. self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + # Restore scheduler-side draft count after ngram trimming. + if original_num_spec_per_req: + orig = original_num_spec_per_req.get(req_id, 0) + if orig != req_state.prev_num_draft_len: + req_state.prev_num_draft_len = orig # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -1177,6 +1253,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. 
self.input_batch.refresh_metadata() + # Incrementally update ngram_gpu tensors after batch is stable + if is_ngram_gpu: + update_ngram_gpu_tensors_incremental( + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + ngram_gpu_new_reqs, + self.device, + _pinned_idx_buf=self._ngram_pinned_idx_buf, + _pinned_val_buf=self._ngram_pinned_val_buf, + ) + def _update_states_after_model_execute( self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: @@ -3399,6 +3487,23 @@ def execute_model( else: logger.error("RoutedExpertsCapturer not initialized.") + # If ngram_gpu is used, we need to copy the scheduler_output so that + # modifications made here do not affect the scheduler_output in the engine core process. + # Using replace is much faster than deepcopy. + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy() + spec_decode_tokens_copy = ( + scheduler_output.scheduled_spec_decode_tokens.copy() + ) + scheduler_output = replace( + scheduler_output, + num_scheduled_tokens=num_scheduled_tokens_copy, + scheduled_spec_decode_tokens=spec_decode_tokens_copy, + ) + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): get_kv_transfer_group().handle_preemptions( scheduler_output.preempted_req_ids @@ -3812,6 +3917,32 @@ def propose_draft_token_ids(sampled_token_ids): self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + elif ( + spec_config.use_ngram_gpu() + and not spec_config.disable_padded_drafter_batch + ): + assert isinstance(self.drafter, NgramProposerGPU) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: +
propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) # Since we couldn't run the drafter, # just use zeros for the draft tokens. self._draft_token_ids = torch.zeros( @@ -4051,6 +4182,52 @@ def propose_draft_token_ids( self.input_batch.token_ids_cpu, slot_mappings=slot_mappings, ) + if isinstance(self.drafter, NgramProposer): + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when ngram is used." + ) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + ) + elif spec_config.use_ngram_gpu(): + assert isinstance(self.drafter, NgramProposerGPU) + ( + next_token_ids, + valid_sampled_tokens_count, + valid_sampled_token_ids_gpu, + ) = self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + + draft_token_ids, num_valid_draft_tokens = self.drafter.propose( + self.num_tokens_no_spec_gpu[:batch_size], + self.token_ids_gpu_tensor[:batch_size], + valid_sampled_token_ids_gpu, + valid_sampled_tokens_count, + ) + + # Cache valid draft counts for scheduler-side trimming. + self._num_valid_draft_tokens = num_valid_draft_tokens + + # Async D2H copy on a dedicated stream. 
+ copy_num_valid_draft_tokens( + self._num_valid_draft_tokens_cpu, + self._num_valid_draft_tokens_copy_stream, + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens, + self.input_batch.num_reqs, + ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer)