From dd0832385cc6a73075dfffc52a7eea76751dae4d Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Fri, 12 Dec 2025 16:25:02 +0800 Subject: [PATCH 01/13] extract _sample Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 89 ++++++++++++++------------- 1 file changed, 46 insertions(+), 43 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c6c881ebc30..88cbd4b259f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1310,7 +1310,7 @@ def _calc_spec_decode_metadata( draft_token_ids = draft_token_ids[target_logits_indices + 1] if self.pcp_size > 1: logits_indices = logits_indices_pcp - metadata = SpecDecodeMetadata( + return SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, @@ -1319,7 +1319,6 @@ def _calc_spec_decode_metadata( bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) - return metadata def propose_draft_token_ids( self, @@ -1672,46 +1671,8 @@ def sample_tokens( grammar_output, logits) with ProfileExecuteDuration().capture_async("Sample"): - # Sample the next token and get logprobs if needed. - sampling_metadata = self.input_batch.sampling_metadata - if spec_decode_metadata is None: - if lmhead_tp_enable() and logits is not None: - logits = logits[:self.input_batch.num_reqs] - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - else: - if lmhead_tp_enable() and logits is not None: - logits = logits[:len(spec_decode_metadata.logits_indices)] - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[ - spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[ - spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - if self.need_accepted_tokens: - self._update_states_after_model_execute(output_token_ids) + sampler_output = self._sample(logits, spec_decode_metadata) + discard_sampled_tokens_req_indices = \ self.discard_request_indices.np[:self.num_discarded_requests] for i in discard_sampled_tokens_req_indices: @@ -1810,7 +1771,7 @@ def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( sampled_token_ids, - sampling_metadata, + self.input_batch.sampling_metadata, scheduler_output, spec_decode_metadata, positions, @@ -1878,6 +1839,48 @@ def propose_draft_token_ids(sampled_token_ids): async_output_copy_stream=self.async_output_copy_stream, ) + def _sample(self, logits, spec_decode_metadata): + # Sample the next token and get logprobs if needed. 
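+        # Two cases below: without spec-decode metadata the sampler runs
+        # directly on the (possibly lmhead-TP trimmed) logits; with spec
+        # decode, bonus tokens are sampled first and the rejection sampler
+        # then produces the final sampled_token_ids.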
+ sampling_metadata = self.input_batch.sampling_metadata + if spec_decode_metadata is None: + if lmhead_tp_enable() and logits is not None: + logits = logits[:self.input_batch.num_reqs] + return self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + else: + if lmhead_tp_enable() and logits is not None: + logits = logits[:len(spec_decode_metadata.logits_indices)] + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[ + spec_decode_metadata.bonus_logits_indices] + sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. + target_logits = logits[ + spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids + if self.need_accepted_tokens: + self._update_states_after_model_execute(output_token_ids) + def _build_dummy_attn_metadata( self, with_prefill: bool, From ba3360b527bb960f01f71313dcb963be7ba0e4ed Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Fri, 12 Dec 2025 17:17:53 +0800 Subject: [PATCH 02/13] extract _bookkeeping_sync Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 222 +++++++++++++++----------- 1 file changed, 128 insertions(+), 94 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 88cbd4b259f..9aa97092f01 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -69,7 +69,7 @@ MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - ModelRunnerOutput, + LogprobsLists, LogprobsTensors, ModelRunnerOutput,SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -1672,100 +1672,21 @@ def sample_tokens( with ProfileExecuteDuration().capture_async("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] - for i in discard_sampled_tokens_req_indices: - generator = self.input_batch.generators.get(int(i)) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - - # Copy some objects so they don't get modified after returning. - # This is important when using async scheduling. - req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() - - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. 
- prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:scheduler_output.total_num_scheduled_tokens], - scheduler_output, + ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync(scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, ) - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] - sampled_token_ids = sampler_output.sampled_token_ids - - if not self.use_async_scheduling: - # Get the valid generated tokens. - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. It's a tensor. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. It's a numpy array - valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() - else: - valid_sampled_token_ids = [] - invalid_req_indices = discard_sampled_tokens_req_indices.tolist( - ) - invalid_req_indices_set = set(invalid_req_indices) - if self.num_spec_tokens <= 0: - assert sampled_token_ids.shape[-1] == 1 - # Cache the sampled tokens on the NPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. - self.input_batch.prev_sampled_token_ids = sampled_token_ids - - - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set - self.input_batch.prev_req_id_to_index = { - req_id: i - for i, req_id in enumerate(self.input_batch.req_ids) - if i not in invalid_req_indices_set - } - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. - for req_idx in range(num_sampled_tokens): - if self.use_async_scheduling: - sampled_ids = [-1] * 1 if \ - req_idx not in invalid_req_indices_set else None - else: - sampled_ids = valid_sampled_token_ids[req_idx] - if not sampled_ids: - continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.model_config.max_model_len, ( - "Sampled token IDs exceed the max model length. 
" - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.model_config.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.is_token_ids[req_idx, - start_idx:end_idx] = True - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None @@ -1834,7 +1755,7 @@ def propose_draft_token_ids(sampled_token_ids): self.debugger.step() return AsyncGPUModelRunnerOutput( model_runner_output=model_runner_output, - sampled_token_ids=sampled_token_ids, + sampled_token_ids=sampler_output.sampled_token_ids, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, ) @@ -1881,6 +1802,119 @@ def _sample(self, logits, spec_decode_metadata): if self.need_accepted_tokens: self._update_states_after_model_execute(output_token_ids) + def _bookkeeping_sync( + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: torch.Tensor | None, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, + ): + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] + for i in discard_sampled_tokens_req_indices: + generator = self.input_batch.generators.get(int(i)) + if generator is not None: + generator.set_offset(generator.get_offset() - 4) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:scheduler_output.total_num_scheduled_tokens], + scheduler_output, + ) + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + invalid_req_indices = [] + + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. It's a tensor. + valid_sampled_token_ids = sampled_token_ids.tolist() + else: + # Includes spec decode tokens. It's a numpy array + valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + ) + # Mask out the sampled tokens that should not be sampled. + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() + else: + valid_sampled_token_ids = [] + invalid_req_indices = discard_sampled_tokens_req_indices.tolist( + ) + invalid_req_indices_set = set(invalid_req_indices) + if self.num_spec_tokens <= 0: + assert sampled_token_ids.shape[-1] == 1 + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. 
+ self.input_batch.prev_sampled_token_ids = sampled_token_ids + + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, + start_idx:end_idx] = True + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + return ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + + def _build_dummy_attn_metadata( self, with_prefill: bool, From e95e7fb37ff0e34d3d000ae3274e3ecdcf65440b Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Fri, 12 Dec 2025 17:41:39 +0800 Subject: [PATCH 03/13] refactor rejection sampler Signed-off-by: realliujiaxu --- vllm_ascend/sample/rejection_sampler.py | 85 +------------------------ vllm_ascend/worker/model_runner_v1.py | 79 +++++++++++------------ 2 files changed, 39 insertions(+), 125 deletions(-) diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index c1ef10db5a9..3d2c2a71faf 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -22,89 +22,8 @@ class AscendRejectionSampler(RejectionSampler, nn.Module): - """ - The implementation strictly follows the algorithm described in - https://arxiv.org/abs/2211.17192. - However, we want to clarify the terminology used in the implementation: - accepted tokens: tokens that are accepted based on the relationship - between the "raw" draft and target probabilities. - recovered tokens: tokens that are sampled based on the adjusted probability - distribution, which is derived from both the draft and target - probabilities. - bonus tokens: - If all proposed tokens are accepted, the bonus token is added to the - end of the sequence. The bonus token is only sampled from the target - probabilities. We pass in the bonus tokens instead of sampling them - in the rejection sampler to allow for more flexibility in the - sampling process. For example, we can use top_p, top_k sampling for - bonus tokens, while spec decode does not support these sampling - strategies. - output tokens: - Tokens are finally generated with the rejection sampler. 
- output tokens = accepted tokens + recovered tokens + bonus tokens - """ - - def forward( - self, - metadata: SpecDecodeMetadata, - # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], - # [num_tokens, vocab_size] - target_logits: torch.Tensor, - # [batch_size, 1] - bonus_token_ids: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - ''' - Args: - metadata: - Metadata for spec decoding. - draft_probs (Optional[torch.Tensor]): - Probability distribution for the draft tokens. Shape is - [num_tokens, vocab_size]. Can be None if probabilities are - not provided, which is the case for ngram spec decode. - target_logits (torch.Tensor): - Target model's logits probability distribution. - Shape is [num_tokens, vocab_size]. Here, probabilities from - different requests are flattened into a single tensor because - this is the shape of the output logits. - NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids_tensor (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling - process such as top_p, top_k sampling. - sampling_metadata (SamplingMetadata): - Additional metadata needed for sampling, such as temperature, - top-k/top-p parameters, or other relevant information. - Returns: - output_token_ids (torch.Tensor): - A tensor containing the final output token IDs. - ''' - assert metadata.max_spec_len <= MAX_SPEC_LEN - # [num_tokens, vocab_size] - # NOTE(woosuk): `target_logits` can be updated in place inside the - # `compute_probs` function. - target_logits = apply_sampling_constraints( - target_logits, - metadata.cu_num_draft_tokens, - sampling_metadata, - ) - target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) + pass - output_token_ids = rejection_sample( - metadata.draft_token_ids, - metadata.num_draft_tokens, - metadata.max_spec_len, - metadata.cu_num_draft_tokens, - draft_probs, - target_probs, - bonus_token_ids, - sampling_metadata, - ) - return output_token_ids def apply_sampling_constraints( @@ -846,4 +765,6 @@ def sample_recovered_tokens_kernel( orig_prob) +rs.apply_sampling_constraints = apply_sampling_constraints +rs.rejection_sample = rejection_sample rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9aa97092f01..8c3e65b30f7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1672,6 +1672,7 @@ def sample_tokens( with ProfileExecuteDuration().capture_async("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) + ( logprobs_lists, valid_sampled_token_ids, @@ -1760,6 +1761,7 @@ def propose_draft_token_ids(sampled_token_ids): async_output_copy_stream=self.async_output_copy_stream, ) + # overwrite _sample for lmhead_tp_enable and need_accepted_tokens def _sample(self, logits, spec_decode_metadata): # Sample the next token and get logprobs if needed. 
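+        # In this revision the rejection sampler is expected to return a
+        # complete SamplerOutput, so the separate bonus-token sampling and
+        # manual rejection step below are no longer needed.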
sampling_metadata = self.input_batch.sampling_metadata @@ -1770,37 +1772,18 @@ def _sample(self, logits, spec_decode_metadata): logits=logits, sampling_metadata=sampling_metadata, ) - else: - if lmhead_tp_enable() and logits is not None: - logits = logits[:len(spec_decode_metadata.logits_indices)] - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[ - spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[ - spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - if self.need_accepted_tokens: - self._update_states_after_model_execute(output_token_ids) + + if lmhead_tp_enable() and logits is not None: + logits = logits[:len(spec_decode_metadata.logits_indices)] + sampler_output = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + sampling_metadata, + ) + if self.need_accepted_tokens: # TODO remove this if + self._update_states_after_model_execute(sampler_output.sampled_token_ids) + return sampler_output def _bookkeeping_sync( self, @@ -1824,18 +1807,6 @@ def _bookkeeping_sync( req_id_to_index_output_copy = \ self.input_batch.req_id_to_index.copy() - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:scheduler_output.total_num_scheduled_tokens], - scheduler_output, - ) - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] @@ -1879,6 +1850,10 @@ def _bookkeeping_sync( # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. + logprobs_tensors = sampler_output.logprobs_tensors + cu_num_accepted_tokens = ( + [0] if spec_decode_metadata and logprobs_tensors else None + ) for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: sampled_ids = [-1] * 1 if \ @@ -1905,6 +1880,24 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) + if cu_num_accepted_tokens is not None: + cu_num_accepted_tokens.append( + cu_num_accepted_tokens[-1] + len(sampled_ids) + ) + + + # NOTE: NPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. + logprobs_tensors = sampler_output.logprobs_tensors + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None + + # Compute prompt logprobs if needed. 
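+        # Prompt logprobs are computed from the hidden states of the
+        # scheduled tokens rather than from the sampled logits.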
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:scheduler_output.total_num_scheduled_tokens], + scheduler_output, + ) + return ( logprobs_lists, valid_sampled_token_ids, From a54e1950273e561ed361f458560c5612a15b27b8 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Fri, 12 Dec 2025 20:38:59 +0800 Subject: [PATCH 04/13] fix lint Signed-off-by: realliujiaxu --- vllm_ascend/sample/rejection_sampler.py | 2 -- vllm_ascend/worker/model_runner_v1.py | 31 +++++++++++-------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 3d2c2a71faf..ec27ed5fc0e 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -10,7 +10,6 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.rejection_sampler import (RejectionSampler, generate_uniform_probs) -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type @@ -25,7 +24,6 @@ class AscendRejectionSampler(RejectionSampler, nn.Module): pass - def apply_sampling_constraints( logits: torch.Tensor, # [num_tokens, vocab_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8c3e65b30f7..f6a3cf55e90 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -69,7 +69,7 @@ MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - LogprobsLists, LogprobsTensors, ModelRunnerOutput,SamplerOutput, + ModelRunnerOutput, SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -1680,7 +1680,8 @@ def sample_tokens( req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync(scheduler_output, + ) = self._bookkeeping_sync( + scheduler_output, sampler_output, logits, hidden_states, @@ -1688,7 +1689,6 @@ def sample_tokens( spec_decode_metadata, ) - def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( @@ -1781,8 +1781,9 @@ def _sample(self, logits, spec_decode_metadata): logits, sampling_metadata, ) - if self.need_accepted_tokens: # TODO remove this if - self._update_states_after_model_execute(sampler_output.sampled_token_ids) + if self.need_accepted_tokens: # TODO remove this if + self._update_states_after_model_execute( + sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( @@ -1828,8 +1829,7 @@ def _bookkeeping_sync( valid_sampled_token_ids[int(i)].clear() else: valid_sampled_token_ids = [] - invalid_req_indices = discard_sampled_tokens_req_indices.tolist( - ) + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices_set = set(invalid_req_indices) if self.num_spec_tokens <= 0: assert sampled_token_ids.shape[-1] == 1 @@ -1851,9 +1851,8 @@ def _bookkeeping_sync( # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. 
logprobs_tensors = sampler_output.logprobs_tensors - cu_num_accepted_tokens = ( - [0] if spec_decode_metadata and logprobs_tensors else None - ) + cu_num_accepted_tokens = ([0] if spec_decode_metadata + and logprobs_tensors else None) for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: sampled_ids = [-1] * 1 if \ @@ -1871,9 +1870,8 @@ def _bookkeeping_sync( f"{self.model_config.max_model_len}") self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.is_token_ids[req_idx, - start_idx:end_idx] = True + start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx req_id = self.input_batch.req_ids[req_idx] @@ -1881,10 +1879,8 @@ def _bookkeeping_sync( req_state.output_token_ids.extend(sampled_ids) if cu_num_accepted_tokens is not None: - cu_num_accepted_tokens.append( - cu_num_accepted_tokens[-1] + len(sampled_ids) - ) - + cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + + len(sampled_ids)) # NOTE: NPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. @@ -1907,7 +1903,6 @@ def _bookkeeping_sync( invalid_req_indices, ) - def _build_dummy_attn_metadata( self, with_prefill: bool, From c8a92e33af98dddbf84fe9f0f56bdd040a02be8a Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Fri, 12 Dec 2025 21:01:06 +0800 Subject: [PATCH 05/13] delete AscendRejectionSampler Signed-off-by: realliujiaxu --- vllm_ascend/patch/worker/__init__.py | 1 + .../patch/worker/patch_rejection_sampler.py | 11 +++++++++++ vllm_ascend/sample/rejection_sampler.py | 14 +------------- vllm_ascend/worker/model_runner_v1.py | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_rejection_sampler.py diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 45e37a5d67a..07a77f7dedf 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_qwen3_vl # noqa import vllm_ascend.patch.worker.patch_rope # noqa import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa +import vllm_ascend.patch.worker.patch_rejection_sampler # noqa diff --git a/vllm_ascend/patch/worker/patch_rejection_sampler.py b/vllm_ascend/patch/worker/patch_rejection_sampler.py new file mode 100644 index 00000000000..f94fee6051b --- /dev/null +++ b/vllm_ascend/patch/worker/patch_rejection_sampler.py @@ -0,0 +1,11 @@ +import vllm.v1.sample.rejection_sampler as rs + +from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints, + expand_batch_to_tokens, + rejection_sample) + +# TODO: delete this patch after apply_sampling_constraints and rejection_sample +# are extracted to as class func of RejectionSampler +rs.apply_sampling_constraints = apply_sampling_constraints +rs.rejection_sample = rejection_sample +rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index ec27ed5fc0e..4ab03a77d8d 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -2,14 +2,11 @@ from typing import Optional import torch -import torch.nn as nn import torch_npu -import vllm.v1.sample.rejection_sampler as rs from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.sample.metadata import SamplingMetadata from 
vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p -from vllm.v1.sample.rejection_sampler import (RejectionSampler, - generate_uniform_probs) +from vllm.v1.sample.rejection_sampler import generate_uniform_probs from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type @@ -20,10 +17,6 @@ MAX_SPEC_LEN = 32 -class AscendRejectionSampler(RejectionSampler, nn.Module): - pass - - def apply_sampling_constraints( logits: torch.Tensor, # [num_tokens, vocab_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] @@ -761,8 +754,3 @@ def sample_recovered_tokens_kernel( tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) - - -rs.apply_sampling_constraints = apply_sampling_constraints -rs.rejection_sample = rejection_sample -rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f6a3cf55e90..8e57578d8ab 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -72,6 +72,7 @@ ModelRunnerOutput, SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer @@ -112,7 +113,6 @@ from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.logits_processor import build_logitsprocs -from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer @@ -413,7 +413,7 @@ def _set_up_drafter(self): ) if get_pp_group().is_last_rank: self.drafter = self._get_drafter() - self.rejection_sampler = AscendRejectionSampler(self.sampler) + self.rejection_sampler = RejectionSampler(self.sampler) self.actual_seq_lengths_q = list( range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req)) From 876ee5ccf90a910c22cf5f8032e1a3f193c1c547 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Fri, 12 Dec 2025 20:14:57 +0800 Subject: [PATCH 06/13] fix eagle proposer Signed-off-by: realliujiaxu --- vllm_ascend/spec_decode/eagle_proposer.py | 55 ++++++++++++----------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 27a7f7179d3..47dc823db75 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -169,7 +169,7 @@ def generate_token_ids(self, eagle_attn_metadata = attn_metadata[self.attn_layer_name] if spec_decode_metadata is None: # input_ids can be None for multimodal models. 
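+            # The runner buffers are now accessed through their .gpu/.cpu/.np
+            # views (presumably vLLM's CpuGpuBuffer wrapper), so the device
+            # copy has to be indexed explicitly here.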
- target_token_ids = self.runner.input_ids[:num_scheduled_tokens] + target_token_ids = self.runner.input_ids.gpu[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] if self.name == SpecDcodeType.EAGLE3: target_hidden_states = torch.cat( @@ -192,7 +192,7 @@ def generate_token_ids(self, ) cu_num_tokens, token_indices =\ self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens) - target_token_ids = self.runner.input_ids[token_indices] + target_token_ids = self.runner.input_ids.gpu[token_indices] target_positions = positions[token_indices] if self.name == SpecDcodeType.EAGLE3: target_hidden_states = torch.cat( @@ -245,7 +245,7 @@ def _get_eagle_atten_dict( num_scheduled_tokens) # Get positions. - positions_np = self.runner.positions_np[:total_num_scheduled_tokens] + positions_np = self.runner.positions.np[:total_num_scheduled_tokens] np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) @@ -270,7 +270,7 @@ def _get_eagle_atten_dict( self.runner.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), - out=self.runner.input_ids_cpu[:total_num_scheduled_tokens]) + out=self.runner.input_ids.cpu[:total_num_scheduled_tokens]) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. @@ -299,40 +299,41 @@ def _get_eagle_atten_dict( np.add( block_numbers * block_size, block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + out=block_table.slot_mapping.np[:total_num_scheduled_tokens]) # Prepare the attention metadata. - self.runner.query_start_loc_np[0] = 0 - self.runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + self.runner.query_start_loc.np[0] = 0 + self.runner.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens - self.runner.seq_lens_np[:num_reqs] = ( + self.runner.seq_lens.np[:num_reqs] = ( self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) # Copy the tensors to the NPU. - self.runner.input_ids[:total_num_scheduled_tokens].copy_( - self.runner.input_ids_cpu[:total_num_scheduled_tokens], + self.runner.input_ids.gpu[:total_num_scheduled_tokens].copy_( + self.runner.input_ids.cpu[:total_num_scheduled_tokens], non_blocking=True) if self.runner.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.runner.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.runner. - mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + self.runner.mrope_positions.gpu[:, :total_num_scheduled_tokens] \ + .copy_( + self.runner. + mrope_positions.cpu[:, :total_num_scheduled_tokens], + non_blocking=True) else: # Common case (1D positions) - self.runner.positions[:total_num_scheduled_tokens].copy_( - self.runner.positions_cpu[:total_num_scheduled_tokens], + self.runner.positions.gpu[:total_num_scheduled_tokens].copy_( + self.runner.positions.cpu[:total_num_scheduled_tokens], non_blocking=True) - self.runner.query_start_loc[:num_reqs + 1].copy_( - self.runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) - self.runner.seq_lens[:num_reqs].copy_( - self.runner.seq_lens_cpu[:num_reqs], non_blocking=True) + self.runner.query_start_loc.gpu[:num_reqs + 1].copy_( + self.runner.query_start_loc.cpu[:num_reqs + 1], non_blocking=True) + self.runner.seq_lens.gpu[:num_reqs].copy_( + self.runner.seq_lens.cpu[:num_reqs], non_blocking=True) # Fill unused with -1. 
Needed for reshape_and_cache - self.runner.seq_lens[num_reqs:].fill_(0) - self.runner.query_start_loc[num_reqs + 1:].fill_(-1) + self.runner.seq_lens.gpu[num_reqs:].fill_(0) + self.runner.query_start_loc.gpu[num_reqs + 1:].fill_(-1) attn_metadata = {} # Prepare the attention metadata for each KV cache group and make layers @@ -340,10 +341,10 @@ def _get_eagle_atten_dict( for kv_cache_group_id, kv_cache_group_spec in enumerate( self.runner.kv_cache_config.kv_cache_groups): common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc[:num_reqs + 1], - query_start_loc_cpu=self.runner.query_start_loc_cpu[:num_reqs + + query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs + 1], - seq_lens_cpu=self.runner.seq_lens_cpu, + seq_lens_cpu=self.runner.seq_lens.cpu, num_reqs=num_reqs, max_query_len=max_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens, @@ -351,8 +352,8 @@ def _get_eagle_atten_dict( block_table_tensor=self.runner.input_batch.block_table[0]. get_device_tensor(), slot_mapping=self.runner.input_batch.block_table[0]. - slot_mapping, - positions=self.runner.positions, + slot_mapping.gpu, + positions=self.runner.positions.gpu, attn_mask=self.runner.attn_mask, spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, From 5cba4c767ea8eba98a8bda88d0341c383b224f25 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Sat, 13 Dec 2025 11:00:39 +0800 Subject: [PATCH 07/13] delete _bookkeeping_sync Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 233 +++++++++++++------------- 1 file changed, 117 insertions(+), 116 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8e57578d8ab..100d5ef1cbf 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1674,6 +1674,7 @@ def sample_tokens( sampler_output = self._sample(logits, spec_decode_metadata) ( + num_nans_in_logits, logprobs_lists, valid_sampled_token_ids, prompt_logprobs_dict, @@ -1786,122 +1787,122 @@ def _sample(self, logits, spec_decode_metadata): sampler_output.sampled_token_ids) return sampler_output - def _bookkeeping_sync( - self, - scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, - logits: torch.Tensor | None, - hidden_states: torch.Tensor, - num_scheduled_tokens: int, - spec_decode_metadata: SpecDecodeMetadata | None, - ): - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] - for i in discard_sampled_tokens_req_indices: - generator = self.input_batch.generators.get(int(i)) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - - # Copy some objects so they don't get modified after returning. - # This is important when using async scheduling. - req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() - - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] - sampled_token_ids = sampler_output.sampled_token_ids - invalid_req_indices = [] - - if not self.use_async_scheduling: - # Get the valid generated tokens. - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. It's a tensor. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. 
It's a numpy array - valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() - else: - valid_sampled_token_ids = [] - invalid_req_indices = discard_sampled_tokens_req_indices.tolist() - invalid_req_indices_set = set(invalid_req_indices) - if self.num_spec_tokens <= 0: - assert sampled_token_ids.shape[-1] == 1 - # Cache the sampled tokens on the NPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. - self.input_batch.prev_sampled_token_ids = sampled_token_ids - - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set - self.input_batch.prev_req_id_to_index = { - req_id: i - for i, req_id in enumerate(self.input_batch.req_ids) - if i not in invalid_req_indices_set - } - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. - logprobs_tensors = sampler_output.logprobs_tensors - cu_num_accepted_tokens = ([0] if spec_decode_metadata - and logprobs_tensors else None) - for req_idx in range(num_sampled_tokens): - if self.use_async_scheduling: - sampled_ids = [-1] * 1 if \ - req_idx not in invalid_req_indices_set else None - else: - sampled_ids = valid_sampled_token_ids[req_idx] - if not sampled_ids: - continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.model_config.max_model_len, ( - "Sampled token IDs exceed the max model length. " - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.model_config.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) - - if cu_num_accepted_tokens is not None: - cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + - len(sampled_ids)) - - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. 
- prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:scheduler_output.total_num_scheduled_tokens], - scheduler_output, - ) - - return ( - logprobs_lists, - valid_sampled_token_ids, - prompt_logprobs_dict, - req_ids_output_copy, - req_id_to_index_output_copy, - invalid_req_indices, - ) + # def _bookkeeping_sync( + # self, + # scheduler_output: "SchedulerOutput", + # sampler_output: SamplerOutput, + # logits: torch.Tensor | None, + # hidden_states: torch.Tensor, + # num_scheduled_tokens: int, + # spec_decode_metadata: SpecDecodeMetadata | None, + # ): + # discard_sampled_tokens_req_indices = \ + # self.discard_request_indices.np[:self.num_discarded_requests] + # for i in discard_sampled_tokens_req_indices: + # generator = self.input_batch.generators.get(int(i)) + # if generator is not None: + # generator.set_offset(generator.get_offset() - 4) + # + # # Copy some objects so they don't get modified after returning. + # # This is important when using async scheduling. + # req_ids_output_copy = self.input_batch.req_ids.copy() + # req_id_to_index_output_copy = \ + # self.input_batch.req_id_to_index.copy() + # + # num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + # sampled_token_ids = sampler_output.sampled_token_ids + # invalid_req_indices = [] + # + # if not self.use_async_scheduling: + # # Get the valid generated tokens. + # max_gen_len = sampled_token_ids.shape[-1] + # if max_gen_len == 1: + # # No spec decode tokens. It's a tensor. + # valid_sampled_token_ids = sampled_token_ids.tolist() + # else: + # # Includes spec decode tokens. It's a numpy array + # valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( + # sampled_token_ids, + # self.input_batch.vocab_size, + # ) + # # Mask out the sampled tokens that should not be sampled. + # for i in discard_sampled_tokens_req_indices: + # valid_sampled_token_ids[int(i)].clear() + # else: + # valid_sampled_token_ids = [] + # invalid_req_indices = discard_sampled_tokens_req_indices.tolist() + # invalid_req_indices_set = set(invalid_req_indices) + # if self.num_spec_tokens <= 0: + # assert sampled_token_ids.shape[-1] == 1 + # # Cache the sampled tokens on the NPU and avoid CPU sync. + # # These will be copied into input_ids in the next step + # # when preparing inputs. + # self.input_batch.prev_sampled_token_ids = sampled_token_ids + # + # self.input_batch.prev_sampled_token_ids_invalid_indices = \ + # invalid_req_indices_set + # self.input_batch.prev_req_id_to_index = { + # req_id: i + # for i, req_id in enumerate(self.input_batch.req_ids) + # if i not in invalid_req_indices_set + # } + # # Cache the sampled tokens in the model runner, so that the scheduler + # # doesn't need to send them back. + # # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # # the sampled tokens back, because there's no direct communication + # # between the first-stage worker and the last-stage worker. 
+ # logprobs_tensors = sampler_output.logprobs_tensors + # cu_num_accepted_tokens = ([0] if spec_decode_metadata + # and logprobs_tensors else None) + # for req_idx in range(num_sampled_tokens): + # if self.use_async_scheduling: + # sampled_ids = [-1] * 1 if \ + # req_idx not in invalid_req_indices_set else None + # else: + # sampled_ids = valid_sampled_token_ids[req_idx] + # if not sampled_ids: + # continue + # + # start_idx = self.input_batch.num_tokens_no_spec[req_idx] + # end_idx = start_idx + len(sampled_ids) + # assert end_idx <= self.model_config.max_model_len, ( + # "Sampled token IDs exceed the max model length. " + # f"Total number of tokens: {end_idx} > max_model_len: " + # f"{self.model_config.max_model_len}") + # + # self.input_batch.token_ids_cpu[req_idx, + # start_idx:end_idx] = sampled_ids + # self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True + # self.input_batch.num_tokens_no_spec[req_idx] = end_idx + # self.input_batch.num_tokens[req_idx] = end_idx + # req_id = self.input_batch.req_ids[req_idx] + # req_state = self.requests[req_id] + # req_state.output_token_ids.extend(sampled_ids) + # + # if cu_num_accepted_tokens is not None: + # cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + + # len(sampled_ids)) + # + # # NOTE: NPU -> CPU Sync happens here. + # # Move as many CPU operations as possible before this sync point. + # logprobs_tensors = sampler_output.logprobs_tensors + # logprobs_lists = logprobs_tensors.tolists() \ + # if logprobs_tensors is not None else None + # + # # Compute prompt logprobs if needed. + # prompt_logprobs_dict = self._get_prompt_logprobs_dict( + # hidden_states[:scheduler_output.total_num_scheduled_tokens], + # scheduler_output, + # ) + # + # return ( + # logprobs_lists, + # valid_sampled_token_ids, + # prompt_logprobs_dict, + # req_ids_output_copy, + # req_id_to_index_output_copy, + # invalid_req_indices, + # ) def _build_dummy_attn_metadata( self, From 44b458277a0b2eea9c55b6532a6b4a6bb7f96c84 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Sat, 13 Dec 2025 15:04:53 +0800 Subject: [PATCH 08/13] update bookkeeping Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 277 +++++++++++++------------- 1 file changed, 143 insertions(+), 134 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 100d5ef1cbf..0d0cf41cc8c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -69,7 +69,7 @@ MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - ModelRunnerOutput, SamplerOutput, + ModelRunnerOutput, SamplerOutput,LogprobsLists, LogprobsTensors, make_empty_encoder_model_runner_output) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler @@ -1673,23 +1673,6 @@ def sample_tokens( with ProfileExecuteDuration().capture_async("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - ( - num_nans_in_logits, - logprobs_lists, - valid_sampled_token_ids, - prompt_logprobs_dict, - req_ids_output_copy, - req_id_to_index_output_copy, - invalid_req_indices, - ) = self._bookkeeping_sync( - scheduler_output, - sampler_output, - logits, - hidden_states, - scheduler_output.total_num_scheduled_tokens, - spec_decode_metadata, - ) - def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None self._draft_token_ids = 
self.propose_draft_token_ids( @@ -1704,6 +1687,22 @@ def propose_draft_token_ids(sampled_token_ids): aux_hidden_states, ) + ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, + ) + with ProfileExecuteDuration().capture_async("Draft"): if self.speculative_config: use_padded_batch_for_eagle = self.speculative_config and \ @@ -1787,122 +1786,132 @@ def _sample(self, logits, spec_decode_metadata): sampler_output.sampled_token_ids) return sampler_output - # def _bookkeeping_sync( - # self, - # scheduler_output: "SchedulerOutput", - # sampler_output: SamplerOutput, - # logits: torch.Tensor | None, - # hidden_states: torch.Tensor, - # num_scheduled_tokens: int, - # spec_decode_metadata: SpecDecodeMetadata | None, - # ): - # discard_sampled_tokens_req_indices = \ - # self.discard_request_indices.np[:self.num_discarded_requests] - # for i in discard_sampled_tokens_req_indices: - # generator = self.input_batch.generators.get(int(i)) - # if generator is not None: - # generator.set_offset(generator.get_offset() - 4) - # - # # Copy some objects so they don't get modified after returning. - # # This is important when using async scheduling. - # req_ids_output_copy = self.input_batch.req_ids.copy() - # req_id_to_index_output_copy = \ - # self.input_batch.req_id_to_index.copy() - # - # num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] - # sampled_token_ids = sampler_output.sampled_token_ids - # invalid_req_indices = [] - # - # if not self.use_async_scheduling: - # # Get the valid generated tokens. - # max_gen_len = sampled_token_ids.shape[-1] - # if max_gen_len == 1: - # # No spec decode tokens. It's a tensor. - # valid_sampled_token_ids = sampled_token_ids.tolist() - # else: - # # Includes spec decode tokens. It's a numpy array - # valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( - # sampled_token_ids, - # self.input_batch.vocab_size, - # ) - # # Mask out the sampled tokens that should not be sampled. - # for i in discard_sampled_tokens_req_indices: - # valid_sampled_token_ids[int(i)].clear() - # else: - # valid_sampled_token_ids = [] - # invalid_req_indices = discard_sampled_tokens_req_indices.tolist() - # invalid_req_indices_set = set(invalid_req_indices) - # if self.num_spec_tokens <= 0: - # assert sampled_token_ids.shape[-1] == 1 - # # Cache the sampled tokens on the NPU and avoid CPU sync. - # # These will be copied into input_ids in the next step - # # when preparing inputs. - # self.input_batch.prev_sampled_token_ids = sampled_token_ids - # - # self.input_batch.prev_sampled_token_ids_invalid_indices = \ - # invalid_req_indices_set - # self.input_batch.prev_req_id_to_index = { - # req_id: i - # for i, req_id in enumerate(self.input_batch.req_ids) - # if i not in invalid_req_indices_set - # } - # # Cache the sampled tokens in the model runner, so that the scheduler - # # doesn't need to send them back. - # # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # # the sampled tokens back, because there's no direct communication - # # between the first-stage worker and the last-stage worker. 
- # logprobs_tensors = sampler_output.logprobs_tensors - # cu_num_accepted_tokens = ([0] if spec_decode_metadata - # and logprobs_tensors else None) - # for req_idx in range(num_sampled_tokens): - # if self.use_async_scheduling: - # sampled_ids = [-1] * 1 if \ - # req_idx not in invalid_req_indices_set else None - # else: - # sampled_ids = valid_sampled_token_ids[req_idx] - # if not sampled_ids: - # continue - # - # start_idx = self.input_batch.num_tokens_no_spec[req_idx] - # end_idx = start_idx + len(sampled_ids) - # assert end_idx <= self.model_config.max_model_len, ( - # "Sampled token IDs exceed the max model length. " - # f"Total number of tokens: {end_idx} > max_model_len: " - # f"{self.model_config.max_model_len}") - # - # self.input_batch.token_ids_cpu[req_idx, - # start_idx:end_idx] = sampled_ids - # self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True - # self.input_batch.num_tokens_no_spec[req_idx] = end_idx - # self.input_batch.num_tokens[req_idx] = end_idx - # req_id = self.input_batch.req_ids[req_idx] - # req_state = self.requests[req_id] - # req_state.output_token_ids.extend(sampled_ids) - # - # if cu_num_accepted_tokens is not None: - # cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + - # len(sampled_ids)) - # - # # NOTE: NPU -> CPU Sync happens here. - # # Move as many CPU operations as possible before this sync point. - # logprobs_tensors = sampler_output.logprobs_tensors - # logprobs_lists = logprobs_tensors.tolists() \ - # if logprobs_tensors is not None else None - # - # # Compute prompt logprobs if needed. - # prompt_logprobs_dict = self._get_prompt_logprobs_dict( - # hidden_states[:scheduler_output.total_num_scheduled_tokens], - # scheduler_output, - # ) - # - # return ( - # logprobs_lists, - # valid_sampled_token_ids, - # prompt_logprobs_dict, - # req_ids_output_copy, - # req_id_to_index_output_copy, - # invalid_req_indices, - # ) + # TODO: remove this func after eagle_proposer is refactored and + # _bookkeeping_sync is moved after propose_draft_token_ids + def _bookkeeping_sync( + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: torch.Tensor | None, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> tuple[ + LogprobsLists | None, + list[list[int]], + dict[str, LogprobsTensors | None], + list[str], + dict[str, int], + list[int], + ]: + num_reqs = self.input_batch.num_reqs + discard_sampled_tokens_req_indices = np.nonzero( + self.discard_request_mask.np[:num_reqs] + )[0] + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + + num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] + sampled_token_ids = sampler_output.sampled_token_ids + logprobs_tensors = sampler_output.logprobs_tensors + invalid_req_indices = [] + cu_num_tokens: list[int] | None = None + if not self.use_async_scheduling: + # Get the valid generated tokens. + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + # No spec decode tokens. + valid_sampled_token_ids = self._to_list(sampled_token_ids) + # Mask out the sampled tokens that should not be sampled. 
+ for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[int(i)].clear() + else: + # Includes spec decode tokens. + valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + sampled_token_ids, + self.input_batch.vocab_size, + discard_sampled_tokens_req_indices, + return_cu_num_tokens=logprobs_tensors is not None, + ) + else: + valid_sampled_token_ids = [] + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() + invalid_req_indices_set = set(invalid_req_indices) + + # Cache the sampled tokens on the GPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. + # With spec decoding, this is done in propose_draft_token_ids(). + if self.num_spec_tokens <= 0: + assert sampled_token_ids.shape[-1] == 1 + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + req_ids = self.input_batch.req_ids + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + + num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0 + + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + num_sampled_ids + assert end_idx <= self.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.max_model_len}" + ) + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + + req_id = req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + logprobs_lists = ( + logprobs_tensors.tolists(cu_num_tokens) + if not self.use_async_scheduling and logprobs_tensors is not None + else None + ) + + # Compute prompt logprobs if needed. 
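+        # As before, prompt logprobs are derived from the first
+        # num_scheduled_tokens hidden states; sampled-token logprobs were
+        # already converted above (skipped under async scheduling).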
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + + return ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) def _build_dummy_attn_metadata( self, From ad659f15b36fcc90e693e5ed347dc87b7d8df7a2 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Sat, 13 Dec 2025 15:34:36 +0800 Subject: [PATCH 09/13] fix async scheduling Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 0d0cf41cc8c..5242ed6fcd3 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1757,8 +1757,10 @@ def propose_draft_token_ids(sampled_token_ids): return AsyncGPUModelRunnerOutput( model_runner_output=model_runner_output, sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, ) # overwrite _sample for lmhead_tp_enable and need_accepted_tokens From 8882134958379557df75f85d97643922393270dd Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Sat, 13 Dec 2025 15:44:06 +0800 Subject: [PATCH 10/13] fix lint Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5242ed6fcd3..455a43bf0cd 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -69,7 +69,8 @@ MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - ModelRunnerOutput, SamplerOutput,LogprobsLists, LogprobsTensors, + LogprobsLists, LogprobsTensors, ModelRunnerOutput, + SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler @@ -1799,17 +1800,16 @@ def _bookkeeping_sync( num_scheduled_tokens: int, spec_decode_metadata: SpecDecodeMetadata | None, ) -> tuple[ - LogprobsLists | None, - list[list[int]], - dict[str, LogprobsTensors | None], - list[str], - dict[str, int], - list[int], + LogprobsLists | None, + list[list[int]], + dict[str, LogprobsTensors | None], + list[str], + dict[str, int], + list[int], ]: num_reqs = self.input_batch.num_reqs discard_sampled_tokens_req_indices = np.nonzero( - self.discard_request_mask.np[:num_reqs] - )[0] + self.discard_request_mask.np[:num_reqs])[0] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -1868,7 +1868,9 @@ def _bookkeeping_sync( req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None + sampled_ids = [ + -1 + ] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] @@ -1882,10 +1884,10 @@ def _bookkeeping_sync( assert end_idx <= self.max_model_len, ( "Sampled token IDs exceed the max model length. 
" f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.max_model_len}" - ) + f"{self.max_model_len}") - self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -1894,11 +1896,9 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - logprobs_lists = ( - logprobs_tensors.tolists(cu_num_tokens) - if not self.use_async_scheduling and logprobs_tensors is not None - else None - ) + logprobs_lists = (logprobs_tensors.tolists(cu_num_tokens) + if not self.use_async_scheduling + and logprobs_tensors is not None else None) # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( From 0b9ad8b0e5410cc6c5afd866def5de8051cfe195 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Mon, 15 Dec 2025 13:56:15 +0800 Subject: [PATCH 11/13] fix bookkeeping Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f8eac00fc6a..b8303b3dc52 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1835,9 +1835,9 @@ def _bookkeeping_sync( dict[str, int], list[int], ]: - num_reqs = self.input_batch.num_reqs - discard_sampled_tokens_req_indices = np.nonzero( - self.discard_request_mask.np[:num_reqs])[0] + # TODO: implement PR 28597 from vllm + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -1875,13 +1875,15 @@ def _bookkeeping_sync( invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices_set = set(invalid_req_indices) - # Cache the sampled tokens on the GPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. - # With spec decoding, this is done in propose_draft_token_ids(). if self.num_spec_tokens <= 0: assert sampled_token_ids.shape[-1] == 1 + # Cache the sampled tokens on the NPU and avoid CPU sync. + # These will be copied into input_ids in the next step + # when preparing inputs. self.input_batch.prev_sampled_token_ids = sampled_token_ids + + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) From 5c61b3e9a2c1654f9af5328360587a7cd44b5c1b Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Mon, 15 Dec 2025 14:00:15 +0800 Subject: [PATCH 12/13] rm unused `prev_sampled_token_ids_invalid_indices` Signed-off-by: realliujiaxu --- vllm_ascend/worker/model_runner_v1.py | 2 -- vllm_ascend/worker/npu_input_batch.py | 1 - 2 files changed, 3 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b8303b3dc52..7dd2f8539be 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1882,8 +1882,6 @@ def _bookkeeping_sync( # when preparing inputs. 
self.input_batch.prev_sampled_token_ids = sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index d9db156640f..70d5ab2290f 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -314,7 +314,6 @@ def __init__( # Cached reference to the GPU tensor of previously sampled tokens self.prev_sampled_token_ids: torch.Tensor | None = None - self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None self.prev_req_id_to_index: dict[str, int] | None = None # These are used to update output_token_ids with real sampled # ids from prior step, if required by current sampling params From ee2ea15818ce64f48b06d4581ac29e29d5d95efe Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Mon, 15 Dec 2025 21:22:41 +0800 Subject: [PATCH 13/13] add comment and e2e UT Signed-off-by: realliujiaxu --- .../spec_decode_v1/test_v1_spec_decode.py | 54 +++++++++++++++++++ vllm_ascend/patch/__init__.py | 20 ++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py index 3d7c5453918..f207c64dde9 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import math import os import random from typing import Any @@ -239,3 +240,56 @@ def test_suffix_acceptance( # Heuristic: expect at least 80% acceptance rate at the end. assert last_accept_rate > 0.60 + + +@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"]) +def test_eagle_logprobs( + model_name: str, + use_eagle3: bool, +): + prompt = {"role": "user", "content": "Hello world " * 10} + sampling_params = SamplingParams(temperature=0, + logprobs=1, + max_tokens=10, + ignore_eos=False) + + ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False) + ref_outputs = ref_llm.chat([prompt], sampling_params) + ref_logprobs = [] + for output in ref_outputs[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + ref_logprobs.append(logprobs[token_id]) + del ref_llm + + spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() + with VllmRunner( + model_name, + max_num_seqs=1, + max_num_batched_tokens=2048, + gpu_memory_utilization=0.6, + speculative_config={ + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, + "num_speculative_tokens": 2, + "max_model_len": 128, + }, + max_model_len=128, + enforce_eager=False, + ) as runner: + spec_outputs = runner.model.chat([prompt], sampling_params) + + # Collect logprobs outputs from spec decode LLM. 
+    spec_logprobs = []
+    for output in spec_outputs[0].outputs:
+        for logprobs in output.logprobs:
+            for token_id in logprobs:
+                spec_logprobs.append(logprobs[token_id])
+
+    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
+        assert math.isclose(ref_logprob.logprob,
+                            spec_logprob.logprob,
+                            rel_tol=5e-2,
+                            abs_tol=1e-1)
+        assert ref_logprob.rank == spec_logprob.rank
+        assert ref_logprob.decoded_token == spec_logprob.decoded_token
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 08b1c7a4ce9..83b90fceeff 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -228,7 +228,7 @@
 #    Future Plan:
 #       Remove this patch when the bug is fixed.
 #
-# ** File: worker/patch_qwen3_next_mtp.py**
+# ** 11. File: worker/patch_qwen3_next_mtp.py**
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.v1.worker.utils.bind_kv_cache`
 #      Why:
@@ -241,7 +241,7 @@
 #    Future Plan:
 #       Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
 #
-# ** File: worker/patch_module.py**
+# ** 12. File: worker/patch_module.py**
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
 #      Why:
@@ -253,3 +253,19 @@
 #    Future Plan:
 #       Remove this patch when bool is supported in 'torch.argsort' func of npu.
 #
+# ** 13. File: worker/patch_rejection_sampler.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.rejection_sampler`
+#      Why:
+#         - some functions in `rejection_sampler` are not supported or are slow on npu.
+#      How:
+#         - add npu_top_k_top_p to the 'apply_sampling_constraints' func
+#         - add custom triton kernels to `expand_batch_to_tokens` and `rejection_sample`
+#      Related PR (if no, explain why):
+#         https://github.com/vllm-project/vllm/pull/874
+#         https://github.com/vllm-project/vllm/pull/4849
+#      Future Plan:
+#         1. Make these functions class methods of RejectionSampler, create an AscendRejectionSampler
+#            to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
+#         2. Make these functions custom ops, then remove AscendRejectionSampler.
+#
\ No newline at end of file
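
The "How" section of the new comment block amounts to rebinding module-level helpers in `vllm.v1.sample.rejection_sampler` with NPU-friendly versions at patch-apply time. Below is a minimal sketch of that rebinding pattern, not the actual patch: `worker/patch_rejection_sampler.py` is not part of this series, the `apply_patch` entry point and the placeholder replacement body are assumptions, and only the module path and the `expand_batch_to_tokens` name are taken from the comment above.

import vllm.v1.sample.rejection_sampler as rejection_sampler

# Keep a handle to the upstream helper before rebinding the module attribute.
_orig_expand_batch_to_tokens = rejection_sampler.expand_batch_to_tokens


def _npu_expand_batch_to_tokens(*args, **kwargs):
    # Placeholder body: the real patch would dispatch to an NPU-efficient
    # custom Triton kernel here; this sketch simply delegates to the original.
    return _orig_expand_batch_to_tokens(*args, **kwargs)


def apply_patch() -> None:
    # Rebinding the module-level name is enough for callers inside
    # rejection_sampler.py, because Python resolves module globals at call
    # time; step 1 of the "Future Plan" would move this logic into an
    # AscendRejectionSampler subclass instead.
    rejection_sampler.expand_batch_to_tokens = _npu_expand_batch_to_tokens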