From 26c8a2fa820405cb0bb3fdb769577fdf5003c4c3 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Fri, 28 Nov 2025 21:07:17 +0000 Subject: [PATCH 01/10] WIP Signed-off-by: Benjamin Chislett --- vllm/v1/core/sched/scheduler.py | 7 ++++++ vllm/v1/engine/core.py | 33 +++++++++++++++++++++++++- vllm/v1/engine/processor.py | 9 +++---- vllm/v1/executor/abstract.py | 8 +++---- vllm/v1/executor/multiproc_executor.py | 4 ++-- vllm/v1/executor/ray_executor.py | 3 ++- vllm/v1/executor/uniproc_executor.py | 8 ++++--- vllm/v1/structured_output/__init__.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 20 ++++++++++++++-- vllm/v1/worker/gpu_worker.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 3 ++- vllm/v1/worker/tpu_worker.py | 4 ++-- vllm/v1/worker/worker_base.py | 2 +- 13 files changed, 84 insertions(+), 24 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e3ec8440a932..4b339763913e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1244,8 +1244,13 @@ def update_draft_token_ids( request = self.requests.get(req_id) if request is None or request.is_finished(): # The request may have been finished. Skip. + logger.info(f"Ignoring draft token ids for finished request {req_id}") continue + logger.info( + f"Updating draft token ids for request {req_id}: {spec_token_ids}") + prev_spec_token_ids = request.spec_token_ids + # Add newly generated spec token ids to the request. if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request @@ -1254,6 +1259,8 @@ def update_draft_token_ids( ) else: request.spec_token_ids = spec_token_ids + logger.info( + f"Updated draft token ids for request {req_id}: {request.spec_token_ids} (from {prev_spec_token_ids})") def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e3a5f51a8fc5..a7222433087d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -448,12 +448,43 @@ def step_with_batch_queue( # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. if deferred_scheduler_output: + # Make sure we have the draft token ids + # self.model_executor.take_draft_token_ids() + draft_token_ids = self.model_executor.take_draft_token_ids() + num_reject_spec_tokens = {} + assert draft_token_ids is not None, "Draft token ids must be available" + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + num_reject_spec_tokens[req_id] = 0 + # Add newly generated spec token ids to the request. + request = self.scheduler.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. 
+ assert request not in deferred_scheduler_output.scheduled_spec_decode_tokens + continue + orig_num_draft_tokens = len(spec_token_ids) + if self.scheduler.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids + ) + logger.info( + f"Updated draft token ids for request {req_id}: {spec_token_ids} (from {deferred_scheduler_output.scheduled_spec_decode_tokens[req_id]})") + num_reject_spec_tokens[req_id] = len(spec_token_ids) + while len(spec_token_ids) < orig_num_draft_tokens: + spec_token_ids.append(-2) + # Now, we delegate a clear meaning to token -2: + # this is a padding token that should be ignored. + deferred_scheduler_output.scheduled_spec_decode_tokens[req_id] = spec_token_ids + # We now have the tokens needed to compute the bitmask for the # deferred request. Get the bitmask and call sample tokens. grammar_output = self.scheduler.get_grammar_bitmask( deferred_scheduler_output ) - future = self.model_executor.sample_tokens(grammar_output, non_block=True) + future = self.model_executor.sample_tokens(grammar_output, num_reject_spec_tokens, non_block=True) batch_queue.appendleft((future, deferred_scheduler_output)) return engine_core_outputs, model_executed diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index af4f0e410e25..db1771dcf6ce 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -161,10 +161,11 @@ def _validate_supported_sampling_params( or params.structured_outputs ) ): - raise ValueError( - "async scheduling with spec decoding doesn't yet support " - "penalties, bad words or structured outputs in sampling parameters." - ) + logger.warning_once("There be dragons here!") + # raise ValueError( + # "async scheduling with spec decoding doesn't yet support " + # "penalties, bad words or structured outputs in sampling parameters." 
+ # ) def _validate_params( self, diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index db8303fcec50..c21ab59965a6 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -207,21 +207,21 @@ def execute_model( @overload def sample_tokens( - self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False + self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: Literal[False] = False ) -> ModelRunnerOutput: pass @overload def sample_tokens( - self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True + self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: Literal[True] = True ) -> Future[ModelRunnerOutput]: pass def sample_tokens( - self, grammar_output: GrammarOutput | None, non_block: bool = False + self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: output = self.collective_rpc( # type: ignore[call-overload] - "sample_tokens", args=(grammar_output,), non_block=non_block + "sample_tokens", args=(grammar_output, num_reject_spec_tokens), non_block=non_block ) return output[0] diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 7e8ebe25c460..944e11faee20 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -266,11 +266,11 @@ def execute_model( # type: ignore[override] ) def sample_tokens( # type: ignore[override] - self, grammar_output: GrammarOutput | None, non_block: bool = False + self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: return self.collective_rpc( "sample_tokens", - args=(grammar_output,), + args=(grammar_output, num_reject_spec_tokens), unique_reply_rank=self.output_rank, non_block=non_block, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 406eafcd339b..ff0424670385 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -412,6 +412,7 @@ def execute_model( # type: ignore[override] def sample_tokens( # type: ignore[override] self, grammar_output: "GrammarOutput | None", + num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False, ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: """Execute the model on the Ray workers. 
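The abstract-executor hunk above keeps the pair of `@overload` declarations so that the static return type of `sample_tokens` tracks the `non_block` flag: a plain result when blocking, a `Future` otherwise. A minimal sketch of that typing pattern outside of vLLM — `double` and the thread pool here are invented for illustration, not part of the patch:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Literal, overload

_pool = ThreadPoolExecutor(max_workers=1)

@overload
def double(x: int, non_block: Literal[False] = False) -> int: ...
@overload
def double(x: int, non_block: Literal[True] = True) -> Future[int]: ...

def double(x: int, non_block: bool = False) -> int | Future[int]:
    # Blocking callers get the value directly; non-blocking callers get a Future.
    if non_block:
        return _pool.submit(lambda: x * 2)
    return x * 2

print(double(3))                           # 6
print(double(3, non_block=True).result())  # 6
```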
@@ -432,7 +433,7 @@ def sample_tokens( # type: ignore[override] self.scheduler_output = None - return self._execute_dag(scheduler_output, grammar_output, non_block) + return self._execute_dag(scheduler_output, grammar_output, num_reject_spec_tokens, non_block) def _execute_dag( self, diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 095d3d1dac21..00245130de26 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -102,17 +102,19 @@ def execute_model( # type: ignore[override] ) def sample_tokens( # type: ignore[override] - self, grammar_output: GrammarOutput | None, non_block: bool = False + self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: return self.collective_rpc( "sample_tokens", - args=(grammar_output,), + args=(grammar_output, num_reject_spec_tokens), non_block=non_block, single_value=True, ) def take_draft_token_ids(self) -> DraftTokenIds | None: - return self.collective_rpc("take_draft_token_ids", single_value=True) + values = self.collective_rpc("take_draft_token_ids", single_value=True) + logger.info(f"EXECUTOR TAKE DRAFT TOKEN IDS RETURNED VALUES: {values}") + return values def check_health(self) -> None: # UniProcExecutor will always be healthy as long as diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 029129cf1a47..152224c7ece4 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -254,6 +254,8 @@ def grammar_bitmask( state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) for i, token in enumerate(req_tokens + [None]): + if token == -2: + apply_bitmask = False self._fill_bitmasks( [ ( @@ -263,7 +265,6 @@ def grammar_bitmask( ) ] ) - if ( apply_bitmask and token is not None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6bff83658b45..35efa0f3e0b7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1180,6 +1180,7 @@ def _prepare_input_ids( ) # Scatter the draft tokens after the sampled tokens are scattered. + self._prev_draft_token_ids = self._draft_token_ids if self._draft_token_ids is None or not spec_flattened_indices: return @@ -1194,7 +1195,7 @@ def _prepare_input_ids( # because input_ids dtype is torch.int32, # so convert draft_token_ids to torch.int32 here. draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) - self._draft_token_ids = None + # self._draft_token_ids = None self.input_ids.gpu.scatter_( dim=0, @@ -3036,10 +3037,12 @@ def execute_model( @torch.inference_mode def sample_tokens( - self, grammar_output: "GrammarOutput | None" + self, grammar_output: "GrammarOutput | None", num_reject_spec_tokens: dict[str, int] | None = None ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: kv_connector_output = self.kv_connector_output self.kv_connector_output = None + logger.info(f"GPU MR RUNNING SAMPLE WITH INCOMING DRAFT TOKEN IDS {self._draft_token_ids}") + self._draft_token_ids = None if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. 
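The gpu_model_runner hunk that follows writes `-1` into the trailing positions of `sampled_token_ids` for draft tokens the grammar rejected, so downstream bookkeeping treats them as rejected. A standalone sketch of that masking idea, assuming a `[num_reqs, 1 + num_spec_tokens]` token tensor and a per-request rejected count — the names and shapes here are illustrative, not vLLM's actual API:

```python
import torch

def mask_rejected_drafts(
    sampled_token_ids: torch.Tensor,  # [num_reqs, 1 + num_spec_tokens]
    num_rejected: torch.Tensor,       # [num_reqs] rejected draft count per request
) -> torch.Tensor:
    # Positions at or beyond (row_length - num_rejected[i]) are invalidated with -1.
    num_cols = sampled_token_ids.shape[1]
    col = torch.arange(num_cols).unsqueeze(0)       # [1, num_cols]
    start = (num_cols - num_rejected).unsqueeze(1)  # [num_reqs, 1]
    return sampled_token_ids.masked_fill(col >= start, -1)

tokens = torch.tensor([[11, 12, 13, 14],
                       [21, 22, 23, 24]])
print(mask_rejected_drafts(tokens, torch.tensor([2, 0])))
# tensor([[ 11,  12,  -1,  -1],
#         [ 21,  22,  23,  24]])
```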
@@ -3078,6 +3081,17 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + if num_reject_spec_tokens is not None: + for req_id, batch_index in self.input_batch.req_id_to_index.items(): + if req_id in num_reject_spec_tokens: + num_reject = num_reject_spec_tokens[req_id] + if num_reject > 0: + num_total = sampler_output.sampled_token_ids.shape[1] + num_maybe_accepted = num_total - num_reject + sampler_output.sampled_token_ids[ + batch_index, num_maybe_accepted: + ] = -1 # Invalidate rejected tokens + self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -3209,6 +3223,8 @@ def propose_draft_token_ids(sampled_token_ids): return async_output def take_draft_token_ids(self) -> DraftTokenIds | None: + logger.info("TAKE DRAFT TOKEN IDS CALLED ON GPU MODEL RUNNER") + logger.info(f"MR HAS DRAFT TOKEN IDS: {self._draft_token_ids}") if self._draft_token_ids is None: return None req_ids = self.input_batch.req_ids diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d0c6091ce2a6..ce2b337bfb36 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -525,9 +525,9 @@ def annotate_profile(self, scheduler_output): @torch.inference_mode() def sample_tokens( - self, grammar_output: "GrammarOutput | None" + self, grammar_output: "GrammarOutput | None", num_reject_spec_tokens: dict[str, int] | None = None ) -> ModelRunnerOutput | AsyncModelRunnerOutput: - return self.model_runner.sample_tokens(grammar_output) + return self.model_runner.sample_tokens(grammar_output, num_reject_spec_tokens) @torch.inference_mode() def execute_model( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 9c1fbfd24149..c28bd2a53002 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1131,11 +1131,12 @@ def execute_model( @torch.no_grad() def sample_tokens( - self, grammar_output: "GrammarOutput | None" + self, grammar_output: "GrammarOutput | None", num_reject_spec_tokens: dict[str, int] | None = None ) -> ModelRunnerOutput: if self.scheduler_output is None: # Nothing to do (PP non-final rank case), output isn't used. 
return None # type: ignore[return-value] + assert num_reject_spec_tokens is None scheduler_output = self.scheduler_output mm_embed_inputs = self.mm_embed_inputs self.scheduler_output = None diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ce18ca6c3716..25b7bc5b991b 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -250,8 +250,8 @@ def determine_available_memory(self) -> int: tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size return int(tpu_kv_cache_bytes) - def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput: - return self.model_runner.sample_tokens(grammar_output) + def sample_tokens(self, grammar_output: "GrammarOutput", num_reject_spec_tokens: dict[str, int] | None = None) -> ModelRunnerOutput: + return self.model_runner.sample_tokens(grammar_output, num_reject_spec_tokens) def execute_model( self, scheduler_output: "SchedulerOutput" diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 57e7037e946e..fd38200494f8 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -134,7 +134,7 @@ def execute_model( raise NotImplementedError def sample_tokens( - self, grammar_output: GrammarOutput + self, grammar_output: GrammarOutput, num_reject_spec_tokens: dict[str, int] | None = None ) -> ModelRunnerOutput | AsyncModelRunnerOutput: """Should be called immediately after execute_model iff it returned None.""" raise NotImplementedError From 7766f2066d1a7a112f525eafb73b007ffa65429d Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 1 Dec 2025 22:48:57 +0000 Subject: [PATCH 02/10] WIP - functional Signed-off-by: Benjamin Chislett --- vllm/v1/core/sched/scheduler.py | 6 -- vllm/v1/engine/core.py | 6 +- vllm/v1/executor/uniproc_executor.py | 1 - vllm/v1/structured_output/__init__.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 83 +++++++++++++++++++++------ 5 files changed, 68 insertions(+), 32 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4b339763913e..0858769e1be6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1244,12 +1244,8 @@ def update_draft_token_ids( request = self.requests.get(req_id) if request is None or request.is_finished(): # The request may have been finished. Skip. - logger.info(f"Ignoring draft token ids for finished request {req_id}") continue - logger.info( - f"Updating draft token ids for request {req_id}: {spec_token_ids}") - prev_spec_token_ids = request.spec_token_ids # Add newly generated spec token ids to the request. if self.structured_output_manager.should_advance(request): @@ -1259,8 +1255,6 @@ def update_draft_token_ids( ) else: request.spec_token_ids = spec_token_ids - logger.info( - f"Updated draft token ids for request {req_id}: {request.spec_token_ids} (from {prev_spec_token_ids})") def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a7222433087d..36fe5074f72d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -457,8 +457,6 @@ def step_with_batch_queue( draft_token_ids.req_ids, draft_token_ids.draft_token_ids, ): - num_reject_spec_tokens[req_id] = 0 - # Add newly generated spec token ids to the request. request = self.scheduler.requests.get(req_id) if request is None or request.is_finished(): # The request may have been finished. Skip. 
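The engine-core hunks around this point take the drafter's proposed tokens, keep only the prefix the grammar accepts, and pad the rejected tail with a sentinel (-2 at this stage of the series; later patches switch to -1) so the per-request draft length stays fixed. A toy version of that validate-then-pad step; `accepts_even_ids_only` is a stand-in for `grammar.validate_tokens`, which is not reproduced here:

```python
def validate_and_pad(draft_tokens: list[int], validate_prefix, pad_token: int = -2):
    """Return (padded_tokens, num_rejected), preserving the original length."""
    valid = validate_prefix(draft_tokens)
    num_rejected = len(draft_tokens) - len(valid)
    return valid + [pad_token] * num_rejected, num_rejected

def accepts_even_ids_only(tokens: list[int]) -> list[int]:
    # Toy "grammar": the longest prefix made only of even token ids.
    valid: list[int] = []
    for tok in tokens:
        if tok % 2:
            break
        valid.append(tok)
    return valid

print(validate_and_pad([2, 4, 7, 8], accepts_even_ids_only))
# ([2, 4, -2, -2], 2)
```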
@@ -470,9 +468,7 @@ def step_with_batch_queue( spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids ) - logger.info( - f"Updated draft token ids for request {req_id}: {spec_token_ids} (from {deferred_scheduler_output.scheduled_spec_decode_tokens[req_id]})") - num_reject_spec_tokens[req_id] = len(spec_token_ids) + num_reject_spec_tokens[req_id] = orig_num_draft_tokens - len(spec_token_ids) while len(spec_token_ids) < orig_num_draft_tokens: spec_token_ids.append(-2) # Now, we delegate a clear meaning to token -2: diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 00245130de26..82c3e42dd549 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -113,7 +113,6 @@ def sample_tokens( # type: ignore[override] def take_draft_token_ids(self) -> DraftTokenIds | None: values = self.collective_rpc("take_draft_token_ids", single_value=True) - logger.info(f"EXECUTOR TAKE DRAFT TOKEN IDS RETURNED VALUES: {values}") return values def check_health(self) -> None: diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 152224c7ece4..4e874462e92a 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -254,8 +254,6 @@ def grammar_bitmask( state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) for i, token in enumerate(req_tokens + [None]): - if token == -2: - apply_bitmask = False self._fill_bitmasks( [ ( @@ -265,6 +263,8 @@ def grammar_bitmask( ) ] ) + if token == -2: + apply_bitmask = False if ( apply_bitmask and token is not None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 35efa0f3e0b7..aaa033e24d1e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -574,6 +574,7 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), @@ -582,6 +583,10 @@ def __init__( pin_memory=self.pin_memory, ) + self.invalid_spec_tokens_mask = self._make_buffer( + (self.max_num_reqs, 1 + self.num_spec_tokens), dtype=torch.bool + ) + # Pre-allocated tensor for copying valid sampled token counts to CPU, # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None @@ -596,6 +601,20 @@ def __init__( pin_memory=self.pin_memory, ) + # We also copy the drafted tokens to the CPU asynchronously, + # in case we need them for structured outputs. + self.draft_token_ids_event: torch.Event | None = None + self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None + if self.use_async_scheduling: + self.draft_token_ids_event = torch.Event() + self.draft_token_ids_copy_stream = torch.cuda.Stream() + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + # Ephemeral state transferred between execute_model() and sample_tokens(). 
self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None @@ -3041,8 +3060,9 @@ def sample_tokens( ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: kv_connector_output = self.kv_connector_output self.kv_connector_output = None - logger.info(f"GPU MR RUNNING SAMPLE WITH INCOMING DRAFT TOKEN IDS {self._draft_token_ids}") + # logger.info(f"GPU MR RUNNING SAMPLE WITH INCOMING DRAFT TOKEN IDS {self._draft_token_ids}") self._draft_token_ids = None + self._draft_token_req_ids = None if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. @@ -3082,15 +3102,21 @@ def sample_tokens( sampler_output = self._sample(logits, spec_decode_metadata) if num_reject_spec_tokens is not None: - for req_id, batch_index in self.input_batch.req_id_to_index.items(): - if req_id in num_reject_spec_tokens: - num_reject = num_reject_spec_tokens[req_id] - if num_reject > 0: - num_total = sampler_output.sampled_token_ids.shape[1] - num_maybe_accepted = num_total - num_reject - sampler_output.sampled_token_ids[ - batch_index, num_maybe_accepted: - ] = -1 # Invalidate rejected tokens + num_reqs, num_sampled_toks = sampler_output.sampled_token_ids.shape + num_invalid_spec_tokens_cpu = torch.zeros( + (num_reqs,), dtype=torch.int32 + ) + for req_id, num_invalid_toks in num_reject_spec_tokens.items(): + req_index = self.input_batch.req_id_to_index[req_id] + num_invalid_spec_tokens_cpu[req_index] = num_invalid_toks + col_indices = torch.arange(num_sampled_toks, dtype=torch.int32) + mask_start_indices = num_sampled_toks - num_invalid_spec_tokens_cpu + mask = col_indices.unsqueeze(0) >= mask_start_indices.unsqueeze(1) + self.invalid_spec_tokens_mask.cpu[:num_reqs, :].copy_(mask) + self.invalid_spec_tokens_mask.copy_to_gpu(num_reqs) + sampler_output.sampled_token_ids.masked_fill_( + self.invalid_spec_tokens_mask.gpu[:num_reqs, :], -1 + ) self.input_batch.prev_sampled_token_ids = None @@ -3107,6 +3133,8 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, ) + self._copy_draft_token_ids_to_cpu(self._draft_token_ids) + self._draft_token_req_ids = deepcopy(self.input_batch.req_ids) spec_config = self.speculative_config use_padded_batch_for_eagle = ( @@ -3223,18 +3251,37 @@ def propose_draft_token_ids(sampled_token_ids): return async_output def take_draft_token_ids(self) -> DraftTokenIds | None: - logger.info("TAKE DRAFT TOKEN IDS CALLED ON GPU MODEL RUNNER") - logger.info(f"MR HAS DRAFT TOKEN IDS: {self._draft_token_ids}") if self._draft_token_ids is None: return None - req_ids = self.input_batch.req_ids - if isinstance(self._draft_token_ids, torch.Tensor): - draft_token_ids = self._draft_token_ids.tolist() - else: - draft_token_ids = self._draft_token_ids - self._draft_token_ids = None + req_ids = self._draft_token_req_ids + draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids)) return DraftTokenIds(req_ids, draft_token_ids) + def _copy_draft_token_ids_to_cpu( + self, draft_token_ids: torch.Tensor | list[list[int]] + ) -> None: + if isinstance(draft_token_ids, list): + return + if self.draft_token_ids_event is None: + return + # For async scheduling, trigger async copy of draft token ids to cpu. 
+ default_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.draft_token_ids_copy_stream): + assert self.draft_token_ids_copy_stream is not None + self.draft_token_ids_copy_stream.wait_stream(default_stream) + self.draft_token_ids_cpu[: draft_token_ids.shape[0]].copy_( + draft_token_ids, non_blocking=True + ) + self.draft_token_ids_event.record() + + def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]: + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids + if self.draft_token_ids_event is None: + return [] + self.draft_token_ids_event.synchronize() + return self.draft_token_ids_cpu[0:num_reqs].tolist() + def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor ) -> None: From 94fc4a41701ab120ee42e60fc6699a2238361ae3 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 2 Dec 2025 19:31:57 +0000 Subject: [PATCH 03/10] revert changes to processor.py Signed-off-by: Benjamin Chislett --- vllm/v1/engine/processor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index db1771dcf6ce..af4f0e410e25 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -161,11 +161,10 @@ def _validate_supported_sampling_params( or params.structured_outputs ) ): - logger.warning_once("There be dragons here!") - # raise ValueError( - # "async scheduling with spec decoding doesn't yet support " - # "penalties, bad words or structured outputs in sampling parameters." - # ) + raise ValueError( + "async scheduling with spec decoding doesn't yet support " + "penalties, bad words or structured outputs in sampling parameters." + ) def _validate_params( self, From 5caf84e7ea6970497482034f5e259e0495af2420 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 2 Dec 2025 19:33:15 +0000 Subject: [PATCH 04/10] update input_processor.py Signed-off-by: Benjamin Chislett --- vllm/v1/engine/input_processor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index e6a94f4e3de5..e2db675214f1 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -157,12 +157,11 @@ def _validate_supported_sampling_params( or params.presence_penalty != 0.0 or params.repetition_penalty != 1.0 or params.bad_words_token_ids - or params.structured_outputs ) ): raise ValueError( "async scheduling with spec decoding doesn't yet support " - "penalties, bad words or structured outputs in sampling parameters." + "penalties or bad words in sampling parameters." 
) def _validate_params( From bc24d530a9faaa62bd606a08a3bbf3261588e94b Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 2 Dec 2025 23:11:02 +0000 Subject: [PATCH 05/10] refactor some more Signed-off-by: Benjamin Chislett --- vllm/v1/core/sched/interface.py | 16 ++++++- vllm/v1/core/sched/output.py | 2 + vllm/v1/core/sched/scheduler.py | 23 +++++++--- vllm/v1/engine/core.py | 59 +++++++++++++------------- vllm/v1/executor/abstract.py | 8 ++-- vllm/v1/executor/multiproc_executor.py | 4 +- vllm/v1/executor/ray_executor.py | 3 +- vllm/v1/executor/uniproc_executor.py | 4 +- vllm/v1/structured_output/__init__.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 26 ++++++++---- vllm/v1/worker/gpu_worker.py | 5 ++- vllm/v1/worker/tpu_model_runner.py | 3 +- vllm/v1/worker/tpu_worker.py | 4 +- vllm/v1/worker/worker_base.py | 2 +- 14 files changed, 100 insertions(+), 62 deletions(-) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c2f503ef2354..8b7f59ab0df7 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -88,8 +88,22 @@ def update_from_output( def update_draft_token_ids( self, draft_token_ids: "DraftTokenIds", + update_requests: bool = True, + update_scheduler_output: "SchedulerOutput | None" = None, + pad_filtered_draft_tokens: bool = False, ) -> None: - """Update the draft token ids for the scheduled requests.""" + """Update requests with newly generated draft token ids, applying + structured output grammar validation if needed. + + Args: + draft_token_ids: The input draft token ids for each request. + update_requests: If True, update the scheduler's `self.requests` state + with the draft token ids. + update_scheduler_output: If provided, update the given scheduler_output + with the corresponding draft token ids. + pad_filtered_draft_tokens: If True, pad the draft token ids so that + the length does not change after filtering invalid draft tokens. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b69fa87ebddc..89171cb440d0 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -228,3 +228,5 @@ class GrammarOutput: structured_output_request_ids: list[str] # Bitmask ordered as structured_output_request_ids. grammar_bitmask: "npt.NDArray[np.int32]" + # Number of invalid tokens per structured output request. + num_invalid_tokens_per_req: list[int] | None = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4e84cb3129b9..14459d96f75d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1277,6 +1277,9 @@ def _free_encoder_inputs(self, request: Request) -> None: def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, + update_requests: bool = True, + update_scheduler_output: SchedulerOutput | None = None, + pad_filtered_draft_tokens: bool = False, ) -> None: for req_id, spec_token_ids in zip( draft_token_ids.req_ids, @@ -1287,15 +1290,25 @@ def update_draft_token_ids( # The request may have been finished. Skip. continue - - # Add newly generated spec token ids to the request. + # Filter out spec tokens which do not adhere to the grammar. 
+ orig_num_spec_tokens = len(spec_token_ids) if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids + assert metadata.grammar is not None + spec_token_ids = metadata.grammar.validate_tokens( + spec_token_ids, ) - else: + if pad_filtered_draft_tokens: + num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids) + spec_token_ids.extend([-1] * num_invalid_tokens) + + # Update the scheduler state. + if update_requests: request.spec_token_ids = spec_token_ids + if update_scheduler_output is not None: + update_scheduler_output.scheduled_spec_decode_tokens[req_id] = ( + spec_token_ids + ) def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1e243667ee03..f7c8c857b7c5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -448,39 +448,40 @@ def step_with_batch_queue( # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. if deferred_scheduler_output: - # Make sure we have the draft token ids - # self.model_executor.take_draft_token_ids() + # If we are doing speculative decoding with structured output, + # we need to get the draft token ids from the prior step before + # we can compute the grammar bitmask for the deferred request. draft_token_ids = self.model_executor.take_draft_token_ids() - num_reject_spec_tokens = {} - assert draft_token_ids is not None, "Draft token ids must be available" - for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, - ): - request = self.scheduler.requests.get(req_id) - if request is None or request.is_finished(): - # The request may have been finished. Skip. - assert request not in deferred_scheduler_output.scheduled_spec_decode_tokens - continue - orig_num_draft_tokens = len(spec_token_ids) - if self.scheduler.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids - ) - num_reject_spec_tokens[req_id] = orig_num_draft_tokens - len(spec_token_ids) - while len(spec_token_ids) < orig_num_draft_tokens: - spec_token_ids.append(-2) - # Now, we delegate a clear meaning to token -2: - # this is a padding token that should be ignored. - deferred_scheduler_output.scheduled_spec_decode_tokens[req_id] = spec_token_ids - - # We now have the tokens needed to compute the bitmask for the - # deferred request. Get the bitmask and call sample tokens. + num_invalid_spec_tokens = None + if draft_token_ids is not None: + # Update the draft token ids on the scheduler output + # to filter out the invalid spec tokens, which will be padded with -1 + # and ignored by the grammar bitmask computation. + self.scheduler.update_draft_token_ids( + draft_token_ids, + update_requests=False, + update_scheduler_output=deferred_scheduler_output, + pad_filtered_draft_tokens=True, + ) + scheduled_spec_tokens = ( + deferred_scheduler_output.scheduled_spec_decode_tokens + ) + num_invalid_spec_tokens = { + req_id: sum(token_id == -1 for token_id in spec_token_ids) + for req_id, spec_token_ids in scheduled_spec_tokens.items() + } + + # Compute the grammar bitmask using the draft tokens (if any), + # and then unblock the model executor. 
grammar_output = self.scheduler.get_grammar_bitmask( deferred_scheduler_output ) - future = self.model_executor.sample_tokens(grammar_output, num_reject_spec_tokens, non_block=True) + if num_invalid_spec_tokens and grammar_output is not None: + grammar_output.num_invalid_tokens_per_req = [ + num_invalid_spec_tokens.get(req_id, 0) + for req_id in grammar_output.structured_output_request_ids + ] + future = self.model_executor.sample_tokens(grammar_output, non_block=True) batch_queue.appendleft((future, deferred_scheduler_output)) return engine_core_outputs, model_executed diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index c21ab59965a6..db8303fcec50 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -207,21 +207,21 @@ def execute_model( @overload def sample_tokens( - self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: Literal[False] = False + self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False ) -> ModelRunnerOutput: pass @overload def sample_tokens( - self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: Literal[True] = True + self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True ) -> Future[ModelRunnerOutput]: pass def sample_tokens( - self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False + self, grammar_output: GrammarOutput | None, non_block: bool = False ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: output = self.collective_rpc( # type: ignore[call-overload] - "sample_tokens", args=(grammar_output, num_reject_spec_tokens), non_block=non_block + "sample_tokens", args=(grammar_output,), non_block=non_block ) return output[0] diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 944e11faee20..7e8ebe25c460 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -266,11 +266,11 @@ def execute_model( # type: ignore[override] ) def sample_tokens( # type: ignore[override] - self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False + self, grammar_output: GrammarOutput | None, non_block: bool = False ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: return self.collective_rpc( "sample_tokens", - args=(grammar_output, num_reject_spec_tokens), + args=(grammar_output,), unique_reply_rank=self.output_rank, non_block=non_block, timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index ff0424670385..406eafcd339b 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -412,7 +412,6 @@ def execute_model( # type: ignore[override] def sample_tokens( # type: ignore[override] self, grammar_output: "GrammarOutput | None", - num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False, ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: """Execute the model on the Ray workers. 
@@ -433,7 +432,7 @@ def sample_tokens( # type: ignore[override] self.scheduler_output = None - return self._execute_dag(scheduler_output, grammar_output, num_reject_spec_tokens, non_block) + return self._execute_dag(scheduler_output, grammar_output, non_block) def _execute_dag( self, diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 82c3e42dd549..69a35fbcf1e2 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -102,11 +102,11 @@ def execute_model( # type: ignore[override] ) def sample_tokens( # type: ignore[override] - self, grammar_output: GrammarOutput | None, num_reject_spec_tokens: dict[str, int] | None = None, non_block: bool = False + self, grammar_output: GrammarOutput | None, non_block: bool = False ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: return self.collective_rpc( "sample_tokens", - args=(grammar_output, num_reject_spec_tokens), + args=(grammar_output,), non_block=non_block, single_value=True, ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index bcbaad8420b6..0b7715ff44e0 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -263,7 +263,8 @@ def grammar_bitmask( ) ] ) - if token == -2: + if token == -1: + # Stop advancing the grammar once we hit a padding token apply_bitmask = False if ( apply_bitmask diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bdb37722bbc4..bcdd1c8667da 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3096,11 +3096,11 @@ def execute_model( @torch.inference_mode def sample_tokens( - self, grammar_output: "GrammarOutput | None", num_reject_spec_tokens: dict[str, int] | None = None + self, + grammar_output: "GrammarOutput | None", ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: kv_connector_output = self.kv_connector_output self.kv_connector_output = None - # logger.info(f"GPU MR RUNNING SAMPLE WITH INCOMING DRAFT TOKEN IDS {self._draft_token_ids}") self._draft_token_ids = None self._draft_token_req_ids = None @@ -3141,12 +3141,17 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) - if num_reject_spec_tokens is not None: + if ( + grammar_output is not None + and grammar_output.num_invalid_tokens_per_req is not None + and self.use_async_scheduling + ): + num_invalid_tokens_per_req = grammar_output.num_invalid_tokens_per_req num_reqs, num_sampled_toks = sampler_output.sampled_token_ids.shape - num_invalid_spec_tokens_cpu = torch.zeros( - (num_reqs,), dtype=torch.int32 - ) - for req_id, num_invalid_toks in num_reject_spec_tokens.items(): + num_invalid_spec_tokens_cpu = torch.zeros((num_reqs,), dtype=torch.int32) + for req_id, num_invalid_toks in zip( + grammar_output.structured_output_request_ids, num_invalid_tokens_per_req + ): req_index = self.input_batch.req_id_to_index[req_id] num_invalid_spec_tokens_cpu[req_index] = num_invalid_toks col_indices = torch.arange(num_sampled_toks, dtype=torch.int32) @@ -3293,6 +3298,7 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: if self._draft_token_ids is None: return None req_ids = self._draft_token_req_ids + assert req_ids is not None draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids)) return DraftTokenIds(req_ids, draft_token_ids) @@ -3312,12 +3318,14 @@ def _copy_draft_token_ids_to_cpu( 
draft_token_ids, non_blocking=True ) self.draft_token_ids_event.record() - + def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]: if isinstance(self._draft_token_ids, list): return self._draft_token_ids if self.draft_token_ids_event is None: - return [] + if self._draft_token_ids is None: + return [] + return self._draft_token_ids.tolist() self.draft_token_ids_event.synchronize() return self.draft_token_ids_cpu[0:num_reqs].tolist() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b214bdf7c83a..57883625adad 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -535,9 +535,10 @@ def annotate_profile(self, scheduler_output): @torch.inference_mode() def sample_tokens( - self, grammar_output: "GrammarOutput | None", num_reject_spec_tokens: dict[str, int] | None = None + self, + grammar_output: "GrammarOutput | None", ) -> ModelRunnerOutput | AsyncModelRunnerOutput: - return self.model_runner.sample_tokens(grammar_output, num_reject_spec_tokens) + return self.model_runner.sample_tokens(grammar_output) @torch.inference_mode() def execute_model( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0c2225788a80..f3dd9aa96d2a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1131,12 +1131,11 @@ def execute_model( @torch.no_grad() def sample_tokens( - self, grammar_output: "GrammarOutput | None", num_reject_spec_tokens: dict[str, int] | None = None + self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput: if self.scheduler_output is None: # Nothing to do (PP non-final rank case), output isn't used. return None # type: ignore[return-value] - assert num_reject_spec_tokens is None scheduler_output = self.scheduler_output mm_embed_inputs = self.mm_embed_inputs self.scheduler_output = None diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 25b7bc5b991b..ce18ca6c3716 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -250,8 +250,8 @@ def determine_available_memory(self) -> int: tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size return int(tpu_kv_cache_bytes) - def sample_tokens(self, grammar_output: "GrammarOutput", num_reject_spec_tokens: dict[str, int] | None = None) -> ModelRunnerOutput: - return self.model_runner.sample_tokens(grammar_output, num_reject_spec_tokens) + def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput: + return self.model_runner.sample_tokens(grammar_output) def execute_model( self, scheduler_output: "SchedulerOutput" diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index fd38200494f8..57e7037e946e 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -134,7 +134,7 @@ def execute_model( raise NotImplementedError def sample_tokens( - self, grammar_output: GrammarOutput, num_reject_spec_tokens: dict[str, int] | None = None + self, grammar_output: GrammarOutput ) -> ModelRunnerOutput | AsyncModelRunnerOutput: """Should be called immediately after execute_model iff it returned None.""" raise NotImplementedError From f98d33953b9546a3400078c106b6aa6a54811d99 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 4 Dec 2025 18:28:39 +0000 Subject: [PATCH 06/10] tidy Signed-off-by: Benjamin Chislett --- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/executor/uniproc_executor.py | 3 +-- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 3 files changed, 4 insertions(+), 5 
deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 14459d96f75d..54117d653551 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1294,7 +1294,7 @@ def update_draft_token_ids( orig_num_spec_tokens = len(spec_token_ids) if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - assert metadata.grammar is not None + assert metadata is not None and metadata.grammar is not None spec_token_ids = metadata.grammar.validate_tokens( spec_token_ids, ) diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 69a35fbcf1e2..095d3d1dac21 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -112,8 +112,7 @@ def sample_tokens( # type: ignore[override] ) def take_draft_token_ids(self) -> DraftTokenIds | None: - values = self.collective_rpc("take_draft_token_ids", single_value=True) - return values + return self.collective_rpc("take_draft_token_ids", single_value=True) def check_health(self) -> None: # UniProcExecutor will always be healthy as long as diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bcdd1c8667da..44f06621a48b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1263,7 +1263,6 @@ def _prepare_input_ids( # because input_ids dtype is torch.int32, # so convert draft_token_ids to torch.int32 here. draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) - # self._draft_token_ids = None self.input_ids.gpu.scatter_( dim=0, @@ -3141,6 +3140,7 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + # Mask out invalid spec tokens for async scheduling + structured outputs. 
if ( grammar_output is not None and grammar_output.num_invalid_tokens_per_req is not None @@ -3179,7 +3179,7 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, ) self._copy_draft_token_ids_to_cpu(self._draft_token_ids) - self._draft_token_req_ids = deepcopy(self.input_batch.req_ids) + self._draft_token_req_ids = self.input_batch.req_ids.copy() spec_config = self.speculative_config use_padded_batch_for_eagle = ( From a0870a00e9ff4bdff65879b77f3a182447981fee Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Fri, 19 Dec 2025 22:48:32 +0000 Subject: [PATCH 07/10] small patch Signed-off-by: Benjamin Chislett --- vllm/v1/worker/gpu_model_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 30780fd2eafa..abcaa2a403a5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3443,6 +3443,9 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: if not self.num_spec_tokens: return None req_ids = self._draft_token_req_ids + if req_ids is None: + req_ids = self.input_batch.req_ids + return DraftTokenIds(req_ids, [[] for _ in req_ids]) assert req_ids is not None draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids)) if draft_token_ids is None: From 20ea36bf534d68605948e04051bbdfb771477734 Mon Sep 17 00:00:00 2001 From: njhill Date: Thu, 1 Jan 2026 10:57:54 -0800 Subject: [PATCH 08/10] rework/simplify, fix remaining issues - revert some unrelated changes - separate update_draft_tokens_in_output() method - avoid propagating invalid_spec_tokens separately - change fits_in_drafter=False case to propose zeros - fix chunked prefill case by trimming draft tokens - fix acceptance rate metric to take grammar-rejected tokens into account - use separate-stream draft token cpu copy for non-async sched case too - only copy draft tokens to cpu when needed Signed-off-by: njhill --- tests/v1/e2e/test_async_scheduling.py | 14 ++- vllm/v1/core/sched/async_scheduler.py | 3 + vllm/v1/core/sched/interface.py | 26 ++--- vllm/v1/core/sched/output.py | 9 +- vllm/v1/core/sched/scheduler.py | 76 ++++++++++----- vllm/v1/engine/core.py | 36 ++----- vllm/v1/worker/gpu_model_runner.py | 133 ++++++++++---------------- vllm/v1/worker/gpu_worker.py | 3 +- 8 files changed, 148 insertions(+), 152 deletions(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 6447a33838d7..ccfe420f11bc 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -30,8 +30,9 @@ default_params = dict( temperature=0.0, # greedy - max_tokens=23, - min_tokens=18, + max_tokens=30, + # spec decoding currently doesn't support min_tokens + # min_tokens=28, ) @@ -86,7 +87,7 @@ def test_without_spec_decoding( run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): +def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, spec decoding model length. @@ -100,9 +101,16 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): # Set small draft model len to force doesn't-fit-in-drafter case. 
spec_config_short = spec_config | {"max_model_len": 50} + struct_outputs = StructuredOutputsParams(json=sample_json_schema) + test_sampling_params = [ dict(), dict(logprobs=2), + dict(structured_outputs=struct_outputs), + dict( + structured_outputs=struct_outputs, + logprobs=2, + ), ] # test_preemption, executor, async_scheduling, diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index a2e1b71e142b..3c66a23208ec 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -12,10 +12,12 @@ class AsyncScheduler(Scheduler): def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: super()._update_after_schedule(scheduler_output) + has_structured_output_requests = False pending_structured_output_tokens = False spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] + has_structured_output_requests |= request.use_structured_output pending_structured_output_tokens |= ( request.use_structured_output and request.num_output_placeholders > 0 ) @@ -33,6 +35,7 @@ def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None: # We will update the actual spec token ids in the worker process. request.spec_token_ids = [-1] * self.num_spec_tokens + scheduler_output.has_structured_output_requests = has_structured_output_requests scheduler_output.pending_structured_output_tokens = ( pending_structured_output_tokens ) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index a4fa5f922090..66071f6c4c1a 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -85,24 +85,26 @@ def update_from_output( raise NotImplementedError @abstractmethod - def update_draft_token_ids( - self, - draft_token_ids: "DraftTokenIds", - update_requests: bool = True, - update_scheduler_output: "SchedulerOutput | None" = None, - pad_filtered_draft_tokens: bool = False, - ) -> None: + def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None: """Update requests with newly generated draft token ids, applying structured output grammar validation if needed. Args: draft_token_ids: The input draft token ids for each request. - update_requests: If True, update the scheduler's `self.requests` state - with the draft token ids. - update_scheduler_output: If provided, update the given scheduler_output + """ + raise NotImplementedError + + @abstractmethod + def update_draft_token_ids_in_output( + self, draft_token_ids: "DraftTokenIds", scheduler_output: "SchedulerOutput" + ) -> None: + """Update scheduler output with newly generated draft token ids, applying + structured output grammar validation if needed. + + Args: + draft_token_ids: The input draft token ids for each request. + scheduler_output: Update the given scheduler_output with the corresponding draft token ids. - pad_filtered_draft_tokens: If True, pad the draft token ids so that - the length does not change after filtering invalid draft tokens. """ raise NotImplementedError diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 89171cb440d0..d2a23a115334 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -197,10 +197,17 @@ class SchedulerOutput: # Only used for v2 model runner. preempted_req_ids: set[str] | None = None + # Whether any of the scheduled requests use structured output. + # Set only in async scheduling case. 
+ has_structured_output_requests: bool = False + # Whether the scheduled requests have all the output tokens they # need to perform grammar bitmask computation. pending_structured_output_tokens: bool = False + # Used for adjusting acceptance rate calculation. + num_invalid_spec_tokens: dict[str, int] | None = None + # KV Cache Connector metadata. kv_connector_metadata: KVConnectorMetadata | None = None @@ -228,5 +235,3 @@ class GrammarOutput: structured_output_request_ids: list[str] # Bitmask ordered as structured_output_request_ids. grammar_bitmask: "npt.NDArray[np.int32]" - # Number of invalid tokens per structured output request. - num_invalid_tokens_per_req: list[int] | None = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 11e3d46949b2..419ab1415394 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1130,6 +1130,8 @@ def update_from_output( spec_decoding_stats, num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted, + num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens, + request_id=req_id, ) stopped = False @@ -1168,7 +1170,13 @@ def update_from_output( struct_output_request = request.structured_output_request assert struct_output_request is not None assert struct_output_request.grammar is not None - struct_output_request.grammar.accept_tokens(req_id, new_token_ids) + ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids) + if not ok: + logger.warning( + "Unexpected: grammar rejected tokens %s for request %s.", + new_token_ids, + req_id, + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] @@ -1317,13 +1325,28 @@ def _free_encoder_inputs(self, request: Request) -> None: # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input(request, input_id) - def update_draft_token_ids( - self, - draft_token_ids: DraftTokenIds, - update_requests: bool = True, - update_scheduler_output: SchedulerOutput | None = None, - pad_filtered_draft_tokens: bool = False, + def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) + request.spec_token_ids = spec_token_ids + + def update_draft_token_ids_in_output( + self, draft_token_ids: DraftTokenIds, scheduler_output: SchedulerOutput ) -> None: + num_invalid_spec_tokens: dict[str, int] = {} + + sched_spec_tokens = scheduler_output.scheduled_spec_decode_tokens for req_id, spec_token_ids in zip( draft_token_ids.req_ids, draft_token_ids.draft_token_ids, @@ -1333,25 +1356,28 @@ def update_draft_token_ids( # The request may have been finished. Skip. continue + placeholder_spec_tokens = sched_spec_tokens.get(req_id) + if not placeholder_spec_tokens: + continue + + orig_num_spec_tokens = len(placeholder_spec_tokens) + # Trim drafts to scheduled number of spec tokens + # (needed for chunked prefill case for example). + del spec_token_ids[orig_num_spec_tokens:] # Filter out spec tokens which do not adhere to the grammar. 
- orig_num_spec_tokens = len(spec_token_ids) if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request assert metadata is not None and metadata.grammar is not None - spec_token_ids = metadata.grammar.validate_tokens( - spec_token_ids, - ) - if pad_filtered_draft_tokens: - num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids) - spec_token_ids.extend([-1] * num_invalid_tokens) - - # Update the scheduler state. - if update_requests: - request.spec_token_ids = spec_token_ids - if update_scheduler_output is not None: - update_scheduler_output.scheduled_spec_decode_tokens[req_id] = ( - spec_token_ids - ) + spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) + # Pad to original number of spec tokens. + num_invalid_tokens = orig_num_spec_tokens - len(spec_token_ids) + if num_invalid_tokens: + spec_token_ids.extend([-1] * num_invalid_tokens) + num_invalid_spec_tokens[req_id] = num_invalid_tokens + + sched_spec_tokens[req_id] = spec_token_ids + + scheduler_output.num_invalid_spec_tokens = num_invalid_spec_tokens def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" @@ -1530,11 +1556,15 @@ def make_spec_decoding_stats( spec_decoding_stats: SpecDecodingStats | None, num_draft_tokens: int, num_accepted_tokens: int, + num_invalid_spec_tokens: dict[str, int] | None, + request_id: str, ) -> SpecDecodingStats | None: - if not self.log_stats: + if not self.log_stats or not num_draft_tokens: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + if num_invalid_spec_tokens: + num_draft_tokens -= num_invalid_spec_tokens.get(request_id, 0) spec_decoding_stats.observe_draft( num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens ) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0c92b7e0f370..dcf76da6a09f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -469,36 +469,20 @@ def step_with_batch_queue( # If we are doing speculative decoding with structured output, # we need to get the draft token ids from the prior step before # we can compute the grammar bitmask for the deferred request. - draft_token_ids = self.model_executor.take_draft_token_ids() - num_invalid_spec_tokens = None - if draft_token_ids is not None: - # Update the draft token ids on the scheduler output - # to filter out the invalid spec tokens, which will be padded with -1 - # and ignored by the grammar bitmask computation. - self.scheduler.update_draft_token_ids( - draft_token_ids, - update_requests=False, - update_scheduler_output=deferred_scheduler_output, - pad_filtered_draft_tokens=True, - ) - scheduled_spec_tokens = ( - deferred_scheduler_output.scheduled_spec_decode_tokens + if self.use_spec_decode: + draft_token_ids = self.model_executor.take_draft_token_ids() + assert draft_token_ids is not None + # Update the draft token ids in the scheduler output to + # filter out the invalid spec tokens, which will be padded + # with -1 and skipped by the grammar bitmask computation. + self.scheduler.update_draft_token_ids_in_output( + draft_token_ids, deferred_scheduler_output ) - num_invalid_spec_tokens = { - req_id: sum(token_id == -1 for token_id in spec_token_ids) - for req_id, spec_token_ids in scheduled_spec_tokens.items() - } - - # Compute the grammar bitmask using the draft tokens (if any), - # and then unblock the model executor. + # We now have the tokens needed to compute the bitmask for the + # deferred request. 
Get the bitmask and call sample tokens. grammar_output = self.scheduler.get_grammar_bitmask( deferred_scheduler_output ) - if num_invalid_spec_tokens and grammar_output is not None: - grammar_output.num_invalid_tokens_per_req = [ - num_invalid_spec_tokens.get(req_id, 0) - for req_id in grammar_output.structured_output_request_ids - ] future = self.model_executor.sample_tokens(grammar_output, non_block=True) batch_queue.appendleft((future, deferred_scheduler_output, exec_future)) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a52b50455a1e..034b6c66bfb8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -635,37 +635,34 @@ def __init__( pin_memory=self.pin_memory, ) - self.invalid_spec_tokens_mask = self._make_buffer( - (self.max_num_reqs, 1 + self.num_spec_tokens), dtype=torch.bool - ) - # Pre-allocated tensor for copying valid sampled token counts to CPU, # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None - if self.use_async_scheduling and self.num_spec_tokens: - self.valid_sampled_token_count_event = torch.Event() - self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() - self.valid_sampled_token_count_cpu = torch.empty( - self.max_num_reqs, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) - # We also copy the drafted tokens to the CPU asynchronously, # in case we need them for structured outputs. self.draft_token_ids_event: torch.Event | None = None self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None - if self.use_async_scheduling: + self.valid_sampled_token_count_cpu: torch.Tensor | None = None + self.draft_token_ids_cpu: torch.Tensor | None = None + if self.num_spec_tokens: self.draft_token_ids_event = torch.Event() self.draft_token_ids_copy_stream = torch.cuda.Stream() - self.draft_token_ids_cpu = torch.empty( - (self.max_num_reqs, self.num_spec_tokens), - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + if self.use_async_scheduling: + self.valid_sampled_token_count_event = torch.Event() + self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None @@ -1055,15 +1052,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.spec_token_ids[req_index].clear() self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) - # there are no draft tokens with async scheduling, - # we clear the spec_decoding info in scheduler_output and - # use normal sampling but rejection_sampling. if self.use_async_scheduling: req_state.prev_num_draft_len = num_spec_tokens - if num_spec_tokens and self._draft_token_ids is None: - scheduler_output.total_num_scheduled_tokens -= num_spec_tokens - scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens - scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. 
for request in reqs_to_add: @@ -1296,7 +1286,6 @@ def _prepare_input_ids( ) # Scatter the draft tokens after the sampled tokens are scattered. - self._prev_draft_token_ids = self._draft_token_ids if self._draft_token_ids is None or not spec_flattened_indices: return @@ -3119,20 +3108,6 @@ def execute_model( "after execute_model() returns None." ) - # self._draft_token_ids is None when `input_fits_in_drafter=False` - # and there is no draft tokens scheduled. so it need to update the - # spec_decoding info in scheduler_output with async_scheduling. - # use deepcopy to avoid the modification has influence on the - # scheduler_output in engine core process. - # TODO(Ronald1995): deepcopy is expensive when there is a large - # number of requests, optimize it later. - if ( - self.use_async_scheduling - and self.num_spec_tokens - and self._draft_token_ids is None - ): - scheduler_output = deepcopy(scheduler_output) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with ( record_function_or_nullcontext("gpu_model_runner: preprocess"), @@ -3375,8 +3350,7 @@ def execute_model( @torch.inference_mode def sample_tokens( - self, - grammar_output: "GrammarOutput | None", + self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: kv_connector_output = self.kv_connector_output self.kv_connector_output = None @@ -3421,29 +3395,6 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) - # Mask out invalid spec tokens for async scheduling + structured outputs. - if ( - grammar_output is not None - and grammar_output.num_invalid_tokens_per_req is not None - and self.use_async_scheduling - ): - num_invalid_tokens_per_req = grammar_output.num_invalid_tokens_per_req - num_reqs, num_sampled_toks = sampler_output.sampled_token_ids.shape - num_invalid_spec_tokens_cpu = torch.zeros((num_reqs,), dtype=torch.int32) - for req_id, num_invalid_toks in zip( - grammar_output.structured_output_request_ids, num_invalid_tokens_per_req - ): - req_index = self.input_batch.req_id_to_index[req_id] - num_invalid_spec_tokens_cpu[req_index] = num_invalid_toks - col_indices = torch.arange(num_sampled_toks, dtype=torch.int32) - mask_start_indices = num_sampled_toks - num_invalid_spec_tokens_cpu - mask = col_indices.unsqueeze(0) >= mask_start_indices.unsqueeze(1) - self.invalid_spec_tokens_mask.cpu[:num_reqs, :].copy_(mask) - self.invalid_spec_tokens_mask.copy_to_gpu(num_reqs) - sampler_output.sampled_token_ids.masked_fill_( - self.invalid_spec_tokens_mask.gpu[:num_reqs, :], -1 - ) - self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -3459,8 +3410,12 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, ) - self._copy_draft_token_ids_to_cpu(self._draft_token_ids) - self._draft_token_req_ids = self.input_batch.req_ids.copy() + struct_output = scheduler_output.has_structured_output_requests + # Draft tokens don't need to be copied to the CPU if async + # scheduling is in use and there are no structured output reqs. 
+ if struct_output or not self.use_async_scheduling: + self._copy_draft_token_ids_to_cpu(self._draft_token_ids) + self._draft_token_req_ids = self.input_batch.req_ids.copy() spec_config = self.speculative_config use_padded_batch_for_eagle = ( @@ -3506,6 +3461,20 @@ def propose_draft_token_ids(sampled_token_ids): next_token_ids, valid_sampled_tokens_count ) + # Since we couldn't run the drafter, just use zeros for the + # draft tokens. + num_reqs = len(self.input_batch.req_ids) + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(num_reqs, self.num_spec_tokens) + struct_output = scheduler_output.has_structured_output_requests + if struct_output or not self.use_async_scheduling: + self._draft_token_req_ids = self.input_batch.req_ids.copy() + assert self.draft_token_ids_cpu is not None + assert self.draft_token_ids_event is not None + self.draft_token_ids_cpu[:num_reqs] = 0 + self.draft_token_ids_event.record() + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, @@ -3579,26 +3548,23 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: if not self.num_spec_tokens: return None req_ids = self._draft_token_req_ids - if req_ids is None: - req_ids = self.input_batch.req_ids - return DraftTokenIds(req_ids, [[] for _ in req_ids]) - assert req_ids is not None + if not req_ids: + return None draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids)) if draft_token_ids is None: - return DraftTokenIds(req_ids, [[] for _ in req_ids]) + return None return DraftTokenIds(req_ids, draft_token_ids) def _copy_draft_token_ids_to_cpu( self, draft_token_ids: torch.Tensor | list[list[int]] ) -> None: - if isinstance(draft_token_ids, list): - return - if self.draft_token_ids_event is None: + if isinstance(draft_token_ids, list) or self.draft_token_ids_event is None: return - # For async scheduling, trigger async copy of draft token ids to cpu. + # Trigger async copy of draft token ids to cpu. 
default_stream = torch.cuda.current_stream() with torch.cuda.stream(self.draft_token_ids_copy_stream): assert self.draft_token_ids_copy_stream is not None + assert self.draft_token_ids_cpu is not None self.draft_token_ids_copy_stream.wait_stream(default_stream) self.draft_token_ids_cpu[: draft_token_ids.shape[0]].copy_( draft_token_ids, non_blocking=True @@ -3608,12 +3574,10 @@ def _copy_draft_token_ids_to_cpu( def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]] | None: if isinstance(self._draft_token_ids, list): return self._draft_token_ids - if self.draft_token_ids_event is None: - if self._draft_token_ids is None: - return None - return self._draft_token_ids.tolist() + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_cpu is not None self.draft_token_ids_event.synchronize() - return self.draft_token_ids_cpu[0:num_reqs].tolist() + return self.draft_token_ids_cpu[:num_reqs].tolist() def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor @@ -3628,6 +3592,7 @@ def _copy_valid_sampled_token_count( self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore counts = valid_sampled_tokens_count counts_cpu = self.valid_sampled_token_count_cpu + assert counts_cpu is not None counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) self.valid_sampled_token_count_event.record() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c9ee3c31ca0a..c8441c09b2f9 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -566,8 +566,7 @@ def annotate_profile(self, scheduler_output): @torch.inference_mode() def sample_tokens( - self, - grammar_output: "GrammarOutput | None", + self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput: return self.model_runner.sample_tokens(grammar_output) From 6d81ab7317218ce8b4849abe99127170f5fff4d9 Mon Sep 17 00:00:00 2001 From: njhill Date: Sat, 3 Jan 2026 23:19:01 -0800 Subject: [PATCH 09/10] group draft_token_ids_cpu setting logic in same method Signed-off-by: njhill --- vllm/v1/worker/gpu_model_runner.py | 64 +++++++++++++++--------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 034b6c66bfb8..f3ca54916397 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3410,12 +3410,7 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, ) - struct_output = scheduler_output.has_structured_output_requests - # Draft tokens don't need to be copied to the CPU if async - # scheduling is in use and there are no structured output reqs. - if struct_output or not self.use_async_scheduling: - self._copy_draft_token_ids_to_cpu(self._draft_token_ids) - self._draft_token_req_ids = self.input_batch.req_ids.copy() + self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config use_padded_batch_for_eagle = ( @@ -3460,20 +3455,12 @@ def propose_draft_token_ids(sampled_token_ids): self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) - - # Since we couldn't run the drafter, just use zeros for the - # draft tokens. - num_reqs = len(self.input_batch.req_ids) + # Since we couldn't run the drafter, + # just use zeros for the draft tokens. 
self._draft_token_ids = torch.zeros( 1, device=self.device, dtype=torch.int32 - ).expand(num_reqs, self.num_spec_tokens) - struct_output = scheduler_output.has_structured_output_requests - if struct_output or not self.use_async_scheduling: - self._draft_token_req_ids = self.input_batch.req_ids.copy() - assert self.draft_token_ids_cpu is not None - assert self.draft_token_ids_event is not None - self.draft_token_ids_cpu[:num_reqs] = 0 - self.draft_token_ids_event.record() + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( @@ -3545,33 +3532,44 @@ def propose_draft_token_ids(sampled_token_ids): return async_output def take_draft_token_ids(self) -> DraftTokenIds | None: - if not self.num_spec_tokens: + if not self.num_spec_tokens or not self._draft_token_req_ids: return None req_ids = self._draft_token_req_ids - if not req_ids: - return None draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids)) - if draft_token_ids is None: - return None return DraftTokenIds(req_ids, draft_token_ids) def _copy_draft_token_ids_to_cpu( - self, draft_token_ids: torch.Tensor | list[list[int]] + self, scheduler_output: "SchedulerOutput", zeros_only: bool = False ) -> None: - if isinstance(draft_token_ids, list) or self.draft_token_ids_event is None: + struct_output = scheduler_output.has_structured_output_requests + if self.use_async_scheduling and not struct_output: + # Draft tokens don't need to be copied to the CPU if async + # scheduling is in use and there are no structured output reqs. + return + # We must also set the corresponding request ids. + self._draft_token_req_ids = self.input_batch.req_ids.copy() + + draft_token_ids: torch.Tensor = self._draft_token_ids + if not torch.is_tensor(draft_token_ids): return - # Trigger async copy of draft token ids to cpu. + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_copy_stream is not None + assert self.draft_token_ids_cpu is not None default_stream = torch.cuda.current_stream() + num_reqs = draft_token_ids.shape[0] with torch.cuda.stream(self.draft_token_ids_copy_stream): - assert self.draft_token_ids_copy_stream is not None - assert self.draft_token_ids_cpu is not None - self.draft_token_ids_copy_stream.wait_stream(default_stream) - self.draft_token_ids_cpu[: draft_token_ids.shape[0]].copy_( - draft_token_ids, non_blocking=True - ) + if not zeros_only: + # Trigger async copy of draft token ids to cpu. + self.draft_token_ids_copy_stream.wait_stream(default_stream) + self.draft_token_ids_cpu[:num_reqs].copy_( + draft_token_ids, non_blocking=True + ) + else: + # No copy needed, just zero-out cpu tensor. 
+ self.draft_token_ids_cpu[:num_reqs] = 0 self.draft_token_ids_event.record() - def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]] | None: + def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]: if isinstance(self._draft_token_ids, list): return self._draft_token_ids assert self.draft_token_ids_event is not None From a7c039216ba2dda0c8c0ab6181fbacb966858fb2 Mon Sep 17 00:00:00 2001 From: njhill Date: Mon, 5 Jan 2026 17:09:48 -0800 Subject: [PATCH 10/10] precommit and docs fix Signed-off-by: njhill --- vllm/v1/core/sched/interface.py | 2 +- vllm/v1/core/sched/scheduler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 66071f6c4c1a..92d8d929287b 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -104,7 +104,7 @@ def update_draft_token_ids_in_output( Args: draft_token_ids: The input draft token ids for each request. scheduler_output: Update the given scheduler_output - with the corresponding draft token ids. + with the corresponding draft token ids. """ raise NotImplementedError diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 419ab1415394..b3ea24dac823 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1338,7 +1338,7 @@ def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None: # Add newly generated spec token ids to the request. if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) + spec_token_ids = metadata.grammar.validate_tokens(spec_token_ids) # type: ignore[union-attr] request.spec_token_ids = spec_token_ids def update_draft_token_ids_in_output(
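
The scheduler-side change above reduces to one invariant: after grammar validation trims a request's draft tokens, the list is padded back to its originally scheduled length with -1, and the pad count is recorded per request so the bitmask step can skip those positions and make_spec_decoding_stats can subtract them from num_draft_tokens. A minimal standalone sketch of that invariant, with a toy validate callback standing in for metadata.grammar.validate_tokens (the real validator returns the longest grammar-valid prefix):

from collections.abc import Callable


def pad_validated_draft_tokens(
    spec_token_ids: list[int],
    validate: Callable[[list[int]], list[int]],
    pad_id: int = -1,
) -> tuple[list[int], int]:
    """Grammar-validate draft tokens, then pad back to the originally
    scheduled length so downstream shapes stay fixed.

    Returns the padded list and the number of invalid (padded) positions.
    """
    orig_len = len(spec_token_ids)
    valid = list(validate(spec_token_ids))
    num_invalid = orig_len - len(valid)
    if num_invalid:
        valid.extend([pad_id] * num_invalid)
    return valid, num_invalid


# Toy validator that accepts at most the first two drafted tokens.
tokens, num_invalid = pad_validated_draft_tokens(
    [3, 5, 4, 9], validate=lambda toks: toks[:2]
)
assert tokens == [3, 5, -1, -1] and num_invalid == 2

Subtracting num_invalid from the draft count before observe_draft keeps the reported acceptance rate from being diluted by padding tokens that were never actually proposed to the verifier.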
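Putting the engine-core pieces together, the deferred-sampling path now follows a fixed order: fetch the previous step's drafts, filter and pad them on the scheduler output, build the grammar bitmask, then launch sampling without blocking. The sketch below is a simplified paraphrase with hypothetical parameter names, not the real EngineCore method; only the calls shown in the diff (take_draft_token_ids, update_draft_token_ids_in_output, get_grammar_bitmask, sample_tokens) come from the source.

def complete_deferred_step(
    executor, scheduler, deferred_output, batch_queue, exec_future, use_spec_decode
):
    if use_spec_decode:
        # Drafts were produced by the previous model step; collect them now.
        draft_token_ids = executor.take_draft_token_ids()
        assert draft_token_ids is not None
        # Grammar-filter each request's drafts and pad with -1 directly in
        # the deferred scheduler output so the bitmask sees a fixed length.
        scheduler.update_draft_token_ids_in_output(draft_token_ids, deferred_output)
    # With the drafts settled, the bitmask can be computed and sampling
    # kicked off asynchronously.
    grammar_output = scheduler.get_grammar_bitmask(deferred_output)
    future = executor.sample_tokens(grammar_output, non_block=True)
    batch_queue.appendleft((future, deferred_output, exec_future))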
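On the model-runner side, the draft token ids reach the CPU through a standard CUDA overlap pattern: a device-to-host copy into a pinned buffer on a dedicated stream, an event recorded after the copy, and a synchronize only at the point of use in take_draft_token_ids. A self-contained sketch of that pattern follows; it assumes a CUDA device, and the class and buffer names are illustrative rather than taken from the diff.

import torch


class AsyncDraftTokenCopier:
    """Copy a GPU tensor of draft token ids to pinned CPU memory without
    stalling the default stream; block only when the values are read."""

    def __init__(self, max_rows: int, cols: int):
        self.copy_stream = torch.cuda.Stream()
        self.ready = torch.cuda.Event()
        self.cpu_buf = torch.empty(
            (max_rows, cols), dtype=torch.int64, device="cpu", pin_memory=True
        )

    def copy(self, drafts_gpu: torch.Tensor) -> None:
        default_stream = torch.cuda.current_stream()
        num_rows = drafts_gpu.shape[0]
        with torch.cuda.stream(self.copy_stream):
            # Wait for the producer kernels on the default stream, then
            # issue the async D2H copy and mark completion with an event.
            self.copy_stream.wait_stream(default_stream)
            self.cpu_buf[:num_rows].copy_(drafts_gpu, non_blocking=True)
            self.ready.record()

    def read(self, num_rows: int) -> list[list[int]]:
        # Mirrors _get_draft_token_ids_cpu: block here, not at copy time.
        self.ready.synchronize()
        return self.cpu_buf[:num_rows].tolist()


copier = AsyncDraftTokenCopier(max_rows=8, cols=4)
drafts = torch.randint(0, 100, (3, 4), dtype=torch.int64, device="cuda")
copier.copy(drafts)    # returns immediately; copy overlaps other GPU work
print(copier.read(3))  # synchronizes on the event before reading

The zeros_only path in the diff is the degenerate case of the same mechanism: when the drafter could not run, the pinned buffer is zero-filled on the host and the event is recorded without issuing a device copy.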