diff --git a/vllm_ascend/patch/worker/patch_rejection_sampler.py b/vllm_ascend/patch/worker/patch_rejection_sampler.py new file mode 100644 index 00000000000..b8202210d30 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_rejection_sampler.py @@ -0,0 +1,7 @@ +import vllm.v1.sample.rejection_sampler as rs + +from vllm_ascend.sample.rejection_sampler import (expand_batch_to_tokens, + rejection_sample) + +rs.expand_batch_to_tokens = expand_batch_to_tokens +rs.rejection_sample = rejection_sample \ No newline at end of file diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index b7905373e35..a37d037b1ad 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -2,14 +2,9 @@ from typing import Optional import torch -import torch.nn as nn -import vllm.v1.sample.rejection_sampler as rs from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import (RejectionSampler, - apply_sampling_constraints, - generate_uniform_probs) -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.sample.rejection_sampler import generate_uniform_probs PLACEHOLDER_TOKEN_ID = -1 GREEDY_TEMPERATURE = -1 @@ -18,92 +13,6 @@ MAX_SPEC_LEN = 32 -class AscendRejectionSampler(RejectionSampler, nn.Module): - """ - The implementation strictly follows the algorithm described in - https://arxiv.org/abs/2211.17192. - However, we want to clarify the terminology used in the implementation: - accepted tokens: tokens that are accepted based on the relationship - between the "raw" draft and target probabilities. - recovered tokens: tokens that are sampled based on the adjusted probability - distribution, which is derived from both the draft and target - probabilities. - bonus tokens: - If all proposed tokens are accepted, the bonus token is added to the - end of the sequence. The bonus token is only sampled from the target - probabilities. We pass in the bonus tokens instead of sampling them - in the rejection sampler to allow for more flexibility in the - sampling process. For example, we can use top_p, top_k sampling for - bonus tokens, while spec decode does not support these sampling - strategies. - output tokens: - Tokens are finally generated with the rejection sampler. - output tokens = accepted tokens + recovered tokens + bonus tokens - """ - - def forward( - self, - metadata: SpecDecodeMetadata, - # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], - # [num_tokens, vocab_size] - target_logits: torch.Tensor, - # [batch_size, 1] - bonus_token_ids: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - ''' - Args: - metadata: - Metadata for spec decoding. - draft_probs (Optional[torch.Tensor]): - Probability distribution for the draft tokens. Shape is - [num_tokens, vocab_size]. Can be None if probabilities are - not provided, which is the case for ngram spec decode. - target_logits (torch.Tensor): - Target model's logits probability distribution. - Shape is [num_tokens, vocab_size]. Here, probabilities from - different requests are flattened into a single tensor because - this is the shape of the output logits. - NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids_tensor (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. 
We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling - process such as top_p, top_k sampling. - sampling_metadata (SamplingMetadata): - Additional metadata needed for sampling, such as temperature, - top-k/top-p parameters, or other relevant information. - Returns: - output_token_ids (torch.Tensor): - A tensor containing the final output token IDs. - ''' - assert metadata.max_spec_len <= MAX_SPEC_LEN - # [num_tokens, vocab_size] - # NOTE(woosuk): `target_logits` can be updated in place inside the - # `compute_probs` function. - target_logits = apply_sampling_constraints( - target_logits, - metadata.cu_num_draft_tokens, - sampling_metadata, - ) - target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) - - output_token_ids = rejection_sample( - metadata.draft_token_ids, - metadata.num_draft_tokens, - metadata.max_spec_len, - metadata.cu_num_draft_tokens, - draft_probs, - target_probs, - bonus_token_ids, - sampling_metadata, - ) - return output_token_ids - - def rejection_sample( # [num_tokens] draft_token_ids: torch.Tensor, @@ -777,6 +686,3 @@ def sample_recovered_tokens_kernel( tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) - - -rs.expand_batch_to_tokens = expand_batch_to_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c66914c7150..71d900376e7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -92,11 +92,12 @@ MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsTensors, ModelRunnerOutput, - PoolerOutput, + DraftTokenIds, LogprobsLists, LogprobsTensors, + ModelRunnerOutput, PoolerOutput, SamplerOutput, make_empty_encoder_model_runner_output) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer @@ -137,7 +138,6 @@ from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.logits_processor import build_logitsprocs -from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType @@ -634,7 +634,7 @@ def _set_up_drafter(self): diagonal=1).to(self.device) if get_pp_group().is_last_rank: self.drafter = self._get_drafter() - self.rejection_sampler = AscendRejectionSampler(self.sampler) + self.rejection_sampler = RejectionSampler(self.sampler) self.actual_seq_lengths_q = list( range(self.decode_token_per_req, self.max_num_tokens + 1, self.decode_token_per_req)) @@ -2032,7 +2032,7 @@ def _calc_spec_decode_metadata( draft_token_ids = draft_token_ids[target_logits_indices + 1] if self.pcp_size > 1: logits_indices = logits_indices_pcp - metadata = SpecDecodeMetadata( + return SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, @@ -2041,7 +2041,6 @@ def _calc_spec_decode_metadata( 
bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) - return metadata def apply_grammar_bitmask( self, @@ -2426,144 +2425,31 @@ def sample_tokens( grammar_output, logits) with ProfileExecuteDuration().capture_async("Sample"): - # Sample the next token and get logprobs if needed. - sampling_metadata = self.input_batch.sampling_metadata - if spec_decode_metadata is None: - if lmhead_tp_enable() and logits is not None: - logits = logits[:self.input_batch.num_reqs] - sampler_output = self.sampler( - logits=logits, - sampling_metadata=sampling_metadata, - ) - else: - if lmhead_tp_enable() and logits is not None: - logits = logits[:len(spec_decode_metadata.logits_indices)] - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. - assert logits is not None - bonus_logits = logits[ - spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[ - spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - if self.need_accepted_tokens: - self._update_states_after_model_execute(output_token_ids) - - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] - for i in discard_sampled_tokens_req_indices: - generator = self.input_batch.generators.get(int(i)) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - - # Copy some objects so they don't get modified after returning. - # This is important when using async scheduling. - req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() - - # NOTE: NPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:scheduler_output.total_num_scheduled_tokens], + sampler_output = self._sample(logits, spec_decode_metadata) + + self.input_batch.prev_sampled_token_ids = None + + ( + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) = self._bookkeeping_sync( scheduler_output, + sampler_output, + logits, + hidden_states, + scheduler_output.total_num_scheduled_tokens, + spec_decode_metadata, ) - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] - sampled_token_ids = sampler_output.sampled_token_ids - if not self.use_async_scheduling: - # Get the valid generated tokens. - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - # No spec decode tokens. It's a tensor. - valid_sampled_token_ids = sampled_token_ids.tolist() - else: - # Includes spec decode tokens. 
It's a numpy array - valid_sampled_token_ids, _ = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[int(i)].clear() - else: - valid_sampled_token_ids = [] - invalid_req_indices = discard_sampled_tokens_req_indices.tolist( - ) - invalid_req_indices_set = set(invalid_req_indices) - assert sampled_token_ids.shape[-1] == 1 - - # Cache the sampled tokens on the NPU and avoid CPU sync. - # These will be copied into input_ids in the next step - # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ - invalid_req_indices_set - self.input_batch.prev_req_id_to_index = { - req_id: i - for i, req_id in enumerate(self.input_batch.req_ids) - if i not in invalid_req_indices_set - } - # Cache the sampled tokens in the model runner, so that the scheduler - # doesn't need to send them back. - # NOTE(woosuk): As an exception, when using PP, the scheduler sends - # the sampled tokens back, because there's no direct communication - # between the first-stage worker and the last-stage worker. - for req_idx in range(num_sampled_tokens): - if self.use_async_scheduling: - sampled_ids = [-1] * 1 if \ - req_idx not in invalid_req_indices_set else None - else: - sampled_ids = valid_sampled_token_ids[req_idx] - if not sampled_ids: - continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.model_config.max_model_len, ( - "Sampled token IDs exceed the max model length. " - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.model_config.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.is_token_ids[req_idx, - start_idx:end_idx] = True - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = self.input_batch.req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) - def propose_draft_token_ids(sampled_token_ids): assert self.spec_decode_common_attn_metadata is not None self._draft_token_ids = self.propose_draft_token_ids( sampled_token_ids, - sampling_metadata, + self.input_batch.sampling_metadata, scheduler_output, spec_decode_metadata, positions, @@ -2626,11 +2512,161 @@ def propose_draft_token_ids(sampled_token_ids): self.debugger.step() return AsyncNPUModelRunnerOutput( model_runner_output=model_runner_output, - sampled_token_ids=sampled_token_ids, + sampled_token_ids=sampler_output.sampled_token_ids, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, ) + def _sample( + self, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> SamplerOutput: + # Sample the next token and get logprobs if needed. 
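+        # The spec-decode path below hands the full logits tensor to the
+        # upstream RejectionSampler (constructed with self.sampler), which
+        # samples the bonus tokens internally and returns a SamplerOutput,
+        # replacing the separate bonus_token_ids pass from the old
+        # sample_tokens body.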
+        sampling_metadata = self.input_batch.sampling_metadata
+        if spec_decode_metadata is None:
+            if lmhead_tp_enable() and logits is not None:
+                logits = logits[:self.input_batch.num_reqs]
+            return self.sampler(
+                logits=logits,
+                sampling_metadata=sampling_metadata,
+            )
+
+        if lmhead_tp_enable() and logits is not None:
+            logits = logits[:len(spec_decode_metadata.logits_indices)]
+        sampler_output = self.rejection_sampler(
+            spec_decode_metadata,
+            None,  # draft_probs
+            logits,
+            sampling_metadata,
+        )
+        if self.need_accepted_tokens:  # TODO: remove this condition
+            self._update_states_after_model_execute(
+                sampler_output.sampled_token_ids)
+        return sampler_output
+
+    def _bookkeeping_sync(
+        self,
+        scheduler_output: "SchedulerOutput",
+        sampler_output: SamplerOutput,
+        logits: torch.Tensor | None,
+        hidden_states: torch.Tensor,
+        num_scheduled_tokens: int,
+        spec_decode_metadata: SpecDecodeMetadata | None,
+    ) -> tuple[
+            LogprobsLists | None,
+            list[list[int]],
+            dict[str, LogprobsTensors | None],
+            list[str],
+            dict[str, int],
+            list[int],
+    ]:
+        # TODO check PR https://github.com/vllm-project/vllm/pull/18777
+        # num_nans_in_logits = {}
+        # if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
+        #     num_nans_in_logits = self._get_nans_in_logits(logits)
+
+        discard_sampled_tokens_req_indices = \
+            self.discard_request_indices.np[:self.num_discarded_requests]
+        for i in discard_sampled_tokens_req_indices:
+            generator = self.input_batch.generators.get(int(i))
+            if generator is not None:
+                generator.set_offset(generator.get_offset() - 4)
+
+        # Copy some objects so they don't get modified after returning.
+        # This is important when using async scheduling.
+        req_ids_output_copy = self.input_batch.req_ids.copy()
+        req_id_to_index_output_copy = \
+            self.input_batch.req_id_to_index.copy()
+
+        # NOTE: NPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        logprobs_tensors = sampler_output.logprobs_tensors
+        logprobs_lists = logprobs_tensors.tolists() \
+            if logprobs_tensors is not None else None
+
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output,
+        )
+
+        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
+        sampled_token_ids = sampler_output.sampled_token_ids
+        invalid_req_indices = []
+        if not self.use_async_scheduling:
+            # Get the valid generated tokens.
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens. It's a tensor.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens. It's a numpy array.
+                valid_sampled_token_ids, _ = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[int(i)].clear()
+        else:
+            valid_sampled_token_ids = []
+            invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
+            invalid_req_indices_set = set(invalid_req_indices)
+            assert sampled_token_ids.shape[-1] == 1
+
+            # Cache the sampled tokens on the NPU and avoid CPU sync.
+            # These will be copied into input_ids in the next step
+            # when preparing inputs.
+ self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) + if i not in invalid_req_indices_set + } + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx in range(num_sampled_tokens): + if self.use_async_scheduling: + sampled_ids = [-1] * 1 if \ + req_idx not in invalid_req_indices_set else None + else: + sampled_ids = valid_sampled_token_ids[req_idx] + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + assert end_idx <= self.model_config.max_model_len, ( + "Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{self.model_config.max_model_len}") + + self.input_batch.token_ids_cpu[req_idx, + start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + return ( + # num_nans_in_logits, + logprobs_lists, + valid_sampled_token_ids, + prompt_logprobs_dict, + req_ids_output_copy, + req_id_to_index_output_copy, + invalid_req_indices, + ) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None
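
Why the `patch_rejection_sampler.py` assignments take effect without the removed `AscendRejectionSampler` subclass: functions and methods defined in `vllm.v1.sample.rejection_sampler` resolve names through that module's global namespace at call time, so rebinding `rs.expand_batch_to_tokens` and `rs.rejection_sample` is picked up by the stock `RejectionSampler.forward` on its next call. Below is a minimal sketch of that mechanism using plain Python semantics only; `upstream`, `Sampler`, and `rejection_sample_npu` are illustrative stand-ins, not vllm or vllm-ascend APIs.

```python
import types

# Stand-in "upstream" module. Code exec'd into its __dict__ resolves
# globals there at call time, just as RejectionSampler.forward resolves
# rejection_sample inside vllm.v1.sample.rejection_sampler.
upstream = types.ModuleType("upstream")
exec(
    """
def rejection_sample(x):
    return "stock:" + x

class Sampler:
    def forward(self, x):
        # Looked up in the module globals on every call.
        return rejection_sample(x)
""",
    upstream.__dict__,
)


def rejection_sample_npu(x):
    # Stand-in for the NPU implementation in
    # vllm_ascend/sample/rejection_sampler.py.
    return "npu:" + x


sampler = upstream.Sampler()
assert sampler.forward("t") == "stock:t"

# The patch: rebind the module attribute, as patch_rejection_sampler.py
# does for rs.expand_batch_to_tokens and rs.rejection_sample.
upstream.rejection_sample = rejection_sample_npu
assert sampler.forward("t") == "npu:t"
```

The ordering caveat: a caller that bound the function earlier via `from vllm.v1.sample.rejection_sampler import rejection_sample` keeps the old reference; only lookups that go through the module attribute, including calls made from inside the module itself, see the replacement. That is presumably why the assignment moved out of the bottom of `vllm_ascend/sample/rejection_sampler.py` into a dedicated module under `vllm_ascend/patch/worker/`, which can be imported early during worker initialization, before the first sampler call.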