diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 7300357a1676..f603dc96d321 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -49,7 +49,6 @@
 )
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
 from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
@@ -139,7 +138,12 @@ def __init__(
             dtype=self.dtype,
             device=self.device,
         )
-        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
+        self.sampler = Sampler(
+            max_num_reqs=self.max_num_reqs,
+            vocab_size=self.vocab_size,
+            device=self.device,
+            logprobs_mode=self.model_config.logprobs_mode,
+        )
 
         # CUDA graphs.
         self.cudagraph_manager = CudaGraphManager(
@@ -310,12 +314,14 @@ def _dummy_sampler_run(
         hidden_states: torch.Tensor,
     ) -> None:
         num_reqs = hidden_states.shape[0]
-        sampling_metadata = SamplingMetadata.make_dummy(
-            num_reqs=num_reqs,
-            device=self.device,
-        )
         logits = self.model.compute_logits(hidden_states)
-        self.sampler(logits, sampling_metadata)
+        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
+        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
+        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
+        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
+        # top_k, top_p, and logprobs, using less GPU memory than what is possible
+        # during actual execution.
+        self.sampler(logits, idx_mapping, idx_mapping_np, pos)
 
     @torch.inference_mode()
     def profile_run(self) -> None:
@@ -401,9 +407,10 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             assert new_req_data.prefill_token_ids is not None
             assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
+            prompt_len = len(new_req_data.prompt_token_ids)
             self.req_states.add_request(
                 req_id=req_id,
-                prompt_len=len(new_req_data.prompt_token_ids),
+                prompt_len=prompt_len,
                 prefill_token_ids=new_req_data.prefill_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 sampling_params=new_req_data.sampling_params,
@@ -423,6 +430,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             self.block_tables.append_block_ids(
                 req_index, new_req_data.block_ids, overwrite=True
             )
+            self.sampler.add_request(
+                req_index, prompt_len, new_req_data.sampling_params
+            )
 
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -436,6 +446,11 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
 
         self.req_states.apply_staged_writes()
         self.block_tables.apply_staged_writes()
+        self.sampler.apply_staged_writes(
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.prefill_len.np,
+            self.req_states.prompt_len,
+        )
         if self.uses_mrope:
             self.mrope_states.apply_staged_writes()
 
@@ -612,10 +627,10 @@ def sample(
         self,
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         grammar_output: GrammarOutput | None,
     ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
+        sample_pos = input_batch.positions[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
             # Apply grammar bitmask to the logits in-place.
@@ -627,7 +642,12 @@ def sample(
             )
 
         # Sample tokens and compute logprobs (if needed).
-        sampler_output = self.sampler(logits, sampling_metadata)
+        sampler_output = self.sampler(
+            logits,
+            input_batch.expanded_idx_mapping,
+            input_batch.idx_mapping_np,
+            sample_pos,
+        )
 
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
@@ -766,7 +786,7 @@ def postprocess(
             input_batch.idx_mapping,
             self.req_states.num_computed_tokens.gpu,
             self.req_states.last_sampled_tokens,
-            self.req_states.output_bin_counts,
+            self.sampler.penalties_state.output_bin_counts,
             sampled_tokens,
             num_sampled,
             num_rejected,
@@ -786,7 +806,6 @@ def postprocess(
     def propose_draft(
         self,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         last_hidden_states: torch.Tensor,
         aux_hidden_states: list[torch.Tensor] | None,
         num_sampled: torch.Tensor,
@@ -801,13 +820,14 @@ def propose_draft(
         ]
         draft_tokens = self.speculator.propose(
             input_batch,
-            sampling_metadata,
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
             num_rejected,
             last_sampled_tokens,
             next_prefill_tokens,
+            self.sampler.sampling_states.temperature.gpu,
+            self.sampler.sampling_states.seeds.gpu,
         )
         return draft_tokens
 
@@ -893,12 +913,6 @@ def execute_model(
                 scheduler_output,
                 num_tokens_after_padding,
             )
-
-            pos = input_batch.positions[input_batch.logits_indices]
-            sampling_metadata = self.req_states.make_sampling_metadata(
-                input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
-            )
-
         if self.lora_config:
             # Activate LoRA adapters.
             lora_inputs = self.req_states.make_lora_inputs(
@@ -917,7 +931,6 @@ def execute_model(
                 device=self.device,
             )
             self.prepare_dummy_attn_metadata(input_batch)
-            sampling_metadata = None
 
         # Run model.
         if cudagraph_mode == CUDAGraphMode.FULL:
@@ -946,7 +959,7 @@ def execute_model(
                 positions=positions,
             )
 
-        self.execute_model_state = hidden_states, input_batch, sampling_metadata
+        self.execute_model_state = hidden_states, input_batch
         return None
 
     @torch.inference_mode()
@@ -955,12 +968,11 @@ def sample_tokens(
         grammar_output: GrammarOutput | None,
     ) -> AsyncOutput | ModelRunnerOutput:
         assert self.execute_model_state is not None
-        hidden_states, input_batch, sampling_metadata = self.execute_model_state
+        hidden_states, input_batch = self.execute_model_state
         self.execute_model_state = None  # type: ignore
 
-        assert sampling_metadata is not None
         sampler_output, num_sampled, num_rejected = self.sample(
-            hidden_states, input_batch, sampling_metadata, grammar_output
+            hidden_states, input_batch, grammar_output
         )
         prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
 
@@ -992,7 +1004,6 @@ def sample_tokens(
         if self.do_spec_decode:
             draft_tokens = self.propose_draft(
                 input_batch,
-                sampling_metadata,
                 hidden_states,
                 None,  # aux_hidden_states
                 num_sampled,
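The runner-side change above amounts to a three-phase protocol: per-request sampling parameters are staged into the `Sampler` when a request is admitted (`add_request`), flushed to device-visible buffers once per step (`apply_staged_writes`), and only then consumed at sampling time with nothing but an index mapping and positions. A minimal, runnable stand-in of that flow; all names here are hypothetical, and plain CPU tensors replace vLLM's UVA-backed buffers:

```python
# Illustrative sketch only, not the vLLM API: stage on CPU, flush once per
# step, consume on the device side with just an index mapping.
import numpy as np
import torch

class TinySampler:
    def __init__(self, max_num_reqs: int):
        self.temperature = np.ones(max_num_reqs, dtype=np.float32)  # staged (CPU)
        self.temperature_gpu = torch.ones(max_num_reqs)  # stand-in for a UVA buffer

    def add_request(self, req_idx: int, temperature: float) -> None:
        # Stage the write on the CPU only; no device traffic yet.
        self.temperature[req_idx] = temperature

    def apply_staged_writes(self) -> None:
        # One bulk CPU->device copy per step instead of one copy per request.
        self.temperature_gpu.copy_(torch.from_numpy(self.temperature))

    def __call__(self, logits: torch.Tensor, idx_mapping: torch.Tensor) -> torch.Tensor:
        # Greedy argmax stands in for the real Gumbel-based sampling.
        scaled = logits / self.temperature_gpu[idx_mapping].unsqueeze(-1)
        return scaled.argmax(dim=-1)

sampler = TinySampler(max_num_reqs=4)
sampler.add_request(0, temperature=0.5)
sampler.apply_staged_writes()
tokens = sampler(torch.randn(1, 32), torch.tensor([0]))
```

Batching the flush keeps per-request host-to-device traffic out of the per-token hot path, which is the point of routing everything through `apply_staged_writes`.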
diff --git a/vllm/v1/worker/gpu/sample/metadata.py b/vllm/v1/worker/gpu/sample/metadata.py
deleted file mode 100644
index 27167fd20c5e..000000000000
--- a/vllm/v1/worker/gpu/sample/metadata.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
-
-import torch
-
-
-@dataclass
-class SamplingMetadata:
-    idx_mapping: torch.Tensor
-
-    temperature: torch.Tensor
-
-    top_p: torch.Tensor | None
-    top_k: torch.Tensor | None
-    min_p: torch.Tensor | None
-
-    # For penalties
-    repetition_penalty: torch.Tensor
-    frequency_penalty: torch.Tensor
-    presence_penalty: torch.Tensor
-    prompt_bin_mask: torch.Tensor
-    output_bin_counts: torch.Tensor
-
-    seeds: torch.Tensor
-    pos: torch.Tensor
-
-    # None means no logprobs, 0 means sampled token logprobs only
-    max_num_logprobs: int | None
-
-    @classmethod
-    def make_dummy(
-        cls,
-        num_reqs: int,
-        device: torch.device,
-    ) -> "SamplingMetadata":
-        assert num_reqs > 0
-        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
-
-        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
-        temperature[0] = 0.5
-        # TODO(woosuk): Use top-p and top-k for dummy sampler.
-        # Currently, they are disabled because of memory usage.
-        # top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
-        # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
-        top_p = None
-        top_k = None
-        min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device)
-        # NOTE(woosuk): We must set penalties to their default values to make sure
-        # the penalties kernel does not touch the placeholder bin_counts tensors.
-        repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
-        frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
-        presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
-
-        # NOTE(woosuk): These are placeholder tensors to avoid None checks in the
-        # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
-        # specialization and re-compilation at runtime.
-        prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
-        output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
-
-        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
-        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
-        max_num_logprobs = 20
-
-        return cls(
-            idx_mapping=idx_mapping,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=min_p,
-            repetition_penalty=repetition_penalty,
-            frequency_penalty=frequency_penalty,
-            presence_penalty=presence_penalty,
-            prompt_bin_mask=prompt_bin_mask,
-            output_bin_counts=output_bin_counts,
-            seeds=seeds,
-            pos=pos,
-            max_num_logprobs=max_num_logprobs,
-        )
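For orientation while reading the rest of the patch: the fields of the deleted `SamplingMetadata` do not disappear, they move into persistent per-request state objects that are written once at admission instead of being rebuilt every step. The mapping below is assembled from the other files in this diff:

```python
# Where each deleted SamplingMetadata field now lives (per this patch):
#   idx_mapping, pos              -> passed directly to Sampler.__call__()
#   temperature, top_p, top_k,
#   min_p, seeds                  -> SamplingStates   (sample/states.py, new)
#   repetition_penalty,
#   frequency_penalty,
#   presence_penalty,
#   prompt_bin_mask,
#   output_bin_counts             -> PenaltiesState   (sample/penalties.py)
#   max_num_logprobs              -> SamplingStates.num_logprobs
#                                    + SamplingStates.max_num_logprobs()
```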
diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py
index 26b0346b29d8..6226ff15e7de 100644
--- a/vllm/v1/worker/gpu/sample/penalties.py
+++ b/vllm/v1/worker/gpu/sample/penalties.py
@@ -1,9 +1,87 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
 import torch
 
+from vllm.sampling_params import SamplingParams
 from vllm.triton_utils import tl, triton
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
+from vllm.utils.math_utils import cdiv
+from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
+
+
+class PenaltiesState:
+    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
+        self.max_num_reqs = max_num_reqs
+        self.vocab_size = vocab_size
+        self.device = device
+
+        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
+        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
+        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
+
+        # Initialize repetition penalty manually because 0 is an invalid value for it.
+        self.repetition_penalty.np.fill(1.0)
+        self.repetition_penalty.copy_to_uva()
+
+        # Statistics for penalties.
+        self.prompt_bin_mask = torch.zeros(
+            self.max_num_reqs,
+            cdiv(self.vocab_size, 32),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
+        # GBs of GPU memory. Optimize the memory usage.
+        self.output_bin_counts = torch.zeros(
+            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
+        )
+
+        self._penalties_reqs: list[int] = []
+
+    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
+        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
+        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
+        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
+        if use_penalty(sampling_params):
+            self._penalties_reqs.append(req_idx)
+
+    def apply_staged_writes(
+        self,
+        prefill_token_ids: torch.Tensor,
+        prefill_lens: np.ndarray,
+        prompt_lens: np.ndarray,
+    ) -> None:
+        # TODO(woosuk): Optimize this.
+        for req_idx in self._penalties_reqs:
+            bincount(
+                prefill_token_ids[req_idx],
+                int(prefill_lens[req_idx]),
+                int(prompt_lens[req_idx]),
+                self.prompt_bin_mask[req_idx],
+                self.output_bin_counts[req_idx],
+            )
+        self._penalties_reqs.clear()
+
+        self.repetition_penalty.copy_to_uva()
+        self.frequency_penalty.copy_to_uva()
+        self.presence_penalty.copy_to_uva()
+
+    def apply_penalties_and_temperature(
+        self,
+        logits: torch.Tensor,
+        idx_mapping: torch.Tensor,
+        temperature: torch.Tensor,
+    ) -> None:
+        apply_penalties_and_temperature(
+            logits,
+            idx_mapping,
+            temperature,
+            self.repetition_penalty.gpu,
+            self.frequency_penalty.gpu,
+            self.presence_penalty.gpu,
+            self.prompt_bin_mask,
+            self.output_bin_counts,
+        )
 
 
 @triton.jit
@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel(
 
 def apply_penalties_and_temperature(
     logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
+    idx_mapping: torch.Tensor,
+    temperature: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    frequency_penalty: torch.Tensor,
+    presence_penalty: torch.Tensor,
+    prompt_bin_mask: torch.Tensor,
+    output_bin_counts: torch.Tensor,
 ) -> None:
     num_reqs, vocab_size = logits.shape
     BLOCK_SIZE = 8192
@@ -92,15 +176,15 @@ def apply_penalties_and_temperature(
     _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
         logits,
         logits.stride(0),
-        sampling_metadata.idx_mapping,
-        sampling_metadata.repetition_penalty,
-        sampling_metadata.frequency_penalty,
-        sampling_metadata.presence_penalty,
-        sampling_metadata.temperature,
-        sampling_metadata.prompt_bin_mask,
-        sampling_metadata.prompt_bin_mask.stride(0),
-        sampling_metadata.output_bin_counts,
-        sampling_metadata.output_bin_counts.stride(0),
+        idx_mapping,
+        repetition_penalty,
+        frequency_penalty,
+        presence_penalty,
+        temperature,
+        prompt_bin_mask,
+        prompt_bin_mask.stride(0),
+        output_bin_counts,
+        output_bin_counts.stride(0),
         vocab_size,
         BLOCK_SIZE=BLOCK_SIZE,
     )
@@ -153,3 +237,11 @@ def bincount(
         output_bin_counts,
         BLOCK_SIZE=BLOCK_SIZE,
     )
+
+
+def use_penalty(sampling_params: SamplingParams) -> bool:
+    return (
+        sampling_params.repetition_penalty != 1.0
+        or sampling_params.frequency_penalty != 0.0
+        or sampling_params.presence_penalty != 0.0
+    )
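The `cdiv(self.vocab_size, 32)` allocation suggests `prompt_bin_mask` packs one presence bit per vocabulary token into 32-bit words (hence the name "mask"), while `output_bin_counts` keeps full per-token counts for the frequency penalty. The real updates happen inside the Triton `bincount` kernel; the NumPy sketch below only illustrates the assumed bit layout, using uint32 for shift safety where the real buffer is int32 with the same bit pattern:

```python
# Assumed layout of one row of prompt_bin_mask: bit (t % 32) of word (t // 32)
# marks vocabulary token t as present in the prompt.
import numpy as np

def set_token(mask_row: np.ndarray, token_id: int) -> None:
    word, bit = divmod(token_id, 32)
    mask_row[word] = np.uint32(int(mask_row[word]) | (1 << bit))

def has_token(mask_row: np.ndarray, token_id: int) -> bool:
    word, bit = divmod(token_id, 32)
    return bool((int(mask_row[word]) >> bit) & 1)

vocab_size = 1000
mask_row = np.zeros((vocab_size + 31) // 32, dtype=np.uint32)  # cdiv(vocab, 32)
for t in (3, 42, 999):
    set_token(mask_row, t)
assert has_token(mask_row, 42) and not has_token(mask_row, 43)
```

A presence bit is all the repetition and presence penalties need for prompt tokens, which is why the prompt side gets a 32x smaller packed mask while only the output side pays for full counts.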
diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py
index 6ed849ec8a1d..a9df2b48c51c 100644
--- a/vllm/v1/worker/gpu/sample/sampler.py
+++ b/vllm/v1/worker/gpu/sample/sampler.py
@@ -1,23 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
 import torch
 
 import vllm.envs as envs
 from vllm.config.model import LogprobsMode
+from vllm.sampling_params import SamplingParams
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.worker.gpu.metrics.logits import get_num_nans
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
+from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
 from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.sample.min_p import apply_min_p
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
-from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
+from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
+from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
 
 
 class Sampler:
     def __init__(
         self,
+        max_num_reqs: int,
+        vocab_size: int,
+        device: torch.device,
         logprobs_mode: LogprobsMode = "raw_logprobs",
     ):
         if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
@@ -25,26 +31,54 @@ def __init__(
         self.logprobs_mode = logprobs_mode
         self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.
 
+        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
+        self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
+        self.logit_bias_state = LogitBiasState(max_num_reqs, device)
+
+    def add_request(
+        self,
+        req_idx: int,
+        prompt_len: int,
+        sampling_params: SamplingParams,
+    ) -> None:
+        self.sampling_states.add_request(req_idx, sampling_params)
+        self.penalties_state.add_request(req_idx, sampling_params)
+        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
+
+    def apply_staged_writes(
+        self,
+        prefill_token_ids: torch.Tensor,
+        prefill_lens: np.ndarray,
+        prompt_lens: np.ndarray,
+    ) -> None:
+        self.sampling_states.apply_staged_writes()
+        self.penalties_state.apply_staged_writes(
+            prefill_token_ids, prefill_lens, prompt_lens
+        )
+        self.logit_bias_state.apply_staged_writes()
+
     def __call__(
         self,
         logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
+        idx_mapping: torch.Tensor,
+        idx_mapping_np: np.ndarray,
+        pos: torch.Tensor,
     ) -> SamplerOutput:
         # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
         # that num_nans is computed before applying penalties and temperature.
         num_nans = get_num_nans(logits) if self.compute_nans else None
-        sampled, processed_logits = self.sample(logits, sampling_metadata)
-        if sampling_metadata.max_num_logprobs is not None:
+        sampled, processed_logits = self.sample(
+            logits, idx_mapping, idx_mapping_np, pos
+        )
+
+        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
+        if max_num_logprobs != NO_LOGPROBS:
             logits = (
                 processed_logits
                 if self.logprobs_mode == "processed_logprobs"
                 else logits
             )
-            logprobs_tensors = compute_topk_logprobs(
-                logits,
-                sampling_metadata.max_num_logprobs,
-                sampled,
-            )
+            logprobs_tensors = compute_topk_logprobs(logits, max_num_logprobs, sampled)
         else:
             logprobs_tensors = None
 
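`max_num_logprobs` is now resolved on the CPU from a NumPy array indexed by the scheduled requests, with `NO_LOGPROBS == -1` as the "not requested" sentinel, so a single `np.max` over the batch decides whether `compute_topk_logprobs` runs at all. A small self-contained illustration of that gating:

```python
# Because every entry defaults to -1, np.max over the scheduled requests is
# NO_LOGPROBS exactly when no request in the batch wants logprobs.
import numpy as np

NO_LOGPROBS = -1
num_logprobs = np.full(8, NO_LOGPROBS, dtype=np.int32)
num_logprobs[3] = 5  # one request asks for top-5 logprobs

idx_mapping_np = np.array([0, 1, 2], dtype=np.int32)  # batch without request 3
assert int(np.max(num_logprobs[idx_mapping_np])) == NO_LOGPROBS  # skip entirely

idx_mapping_np = np.array([1, 3], dtype=np.int32)  # batch with request 3
assert int(np.max(num_logprobs[idx_mapping_np])) == 5  # compute top-5 for the batch
```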
@@ -62,27 +96,41 @@ def __call__(
     def sample(
         self,
         logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
+        idx_mapping: torch.Tensor,
+        idx_mapping_np: np.ndarray,
+        pos: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Copy logits to a new FP32 tensor.
         logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
 
+        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
+        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
+
         # Apply penalties and temperature in place.
-        apply_penalties_and_temperature(logits, sampling_metadata)
-        # Apply min_p in place.
-        if sampling_metadata.min_p is not None:
-            apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p)
-        # Apply top_k and/or top_p. This might return a new tensor.
-        logits = apply_top_k_top_p(
-            logits, sampling_metadata.top_k, sampling_metadata.top_p
+        self.penalties_state.apply_penalties_and_temperature(
+            logits, idx_mapping, self.sampling_states.temperature.gpu
         )
+        # Apply min_p in place if any request has a non-zero min_p.
+        do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
+        if do_min_p:
+            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
+
+        # Apply top_k and/or top_p. This might return a new tensor.
+        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
+        top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
+        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
+        top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
+        if do_top_k or do_top_p:
+            logits = apply_top_k_top_p(logits, top_k, top_p)
+
+        # Sample the next token.
         sampled = gumbel_sample(
             logits,
-            sampling_metadata.idx_mapping,
-            sampling_metadata.temperature,
-            sampling_metadata.seeds,
-            sampling_metadata.pos,
+            idx_mapping,
+            self.sampling_states.temperature.gpu,
+            self.sampling_states.seeds.gpu,
+            pos,
             apply_temperature=False,
         )
         return sampled, logits
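The same CPU-side gating pattern covers min-p, top-k, and top-p: because `top_k` defaults to `vocab_size` and `top_p` to 1.0, a batch of all-default requests skips both the GPU-side gathers and `apply_top_k_top_p` entirely, at the cost of one NumPy scan per step. Re-created here with illustrative sizes:

```python
# When every scheduled request uses the defaults, do_top_k/do_top_p are both
# False and the sampler never touches the GPU tensors for this step.
import numpy as np

vocab_size = 32_000
top_k = np.full(8, vocab_size, dtype=np.int32)  # default: no truncation
top_p = np.ones(8, dtype=np.float32)            # default: keep full mass

idx = np.array([0, 4, 7], dtype=np.int32)       # this step's scheduled requests
do_top_k = bool(np.any(top_k[idx] != vocab_size))
do_top_p = bool(np.any(top_p[idx] != 1.0))
assert not (do_top_k or do_top_p)               # kernel + gathers skipped

top_k[4] = 50  # one request opts in; the GPU path now runs for the batch
assert bool(np.any(top_k[idx] != vocab_size))
```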
diff --git a/vllm/v1/worker/gpu/sample/states.py b/vllm/v1/worker/gpu/sample/states.py
new file mode 100644
index 000000000000..f2a2279224a5
--- /dev/null
+++ b/vllm/v1/worker/gpu/sample/states.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
+import torch
+
+from vllm.sampling_params import SamplingParams
+from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
+
+NO_LOGPROBS = -1
+_NP_INT64_MIN = np.iinfo(np.int64).min
+_NP_INT64_MAX = np.iinfo(np.int64).max
+
+
+class SamplingStates:
+    def __init__(self, max_num_reqs: int, vocab_size: int):
+        self.max_num_reqs = max_num_reqs
+        self.vocab_size = vocab_size
+
+        self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
+        self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
+        self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
+        self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
+        self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
+
+        # Initialize top_k and top_p manually because 0 is an invalid value for them.
+        self.top_k.np.fill(self.vocab_size)
+        self.top_k.copy_to_uva()
+        self.top_p.np.fill(1.0)
+        self.top_p.copy_to_uva()
+
+        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
+        # -1 means no logprobs are requested.
+        self.num_logprobs.fill(NO_LOGPROBS)
+
+    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
+        self.temperature.np[req_idx] = sampling_params.temperature
+        self.top_p.np[req_idx] = sampling_params.top_p
+        if 0 < sampling_params.top_k < self.vocab_size:
+            top_k = sampling_params.top_k
+        else:
+            top_k = self.vocab_size
+        self.top_k.np[req_idx] = top_k
+        self.min_p.np[req_idx] = sampling_params.min_p
+
+        if sampling_params.seed is not None:
+            seed = sampling_params.seed
+        else:
+            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
+        self.seeds.np[req_idx] = seed
+
+        if sampling_params.logprobs is not None:
+            num_logprobs = sampling_params.logprobs
+        else:
+            num_logprobs = NO_LOGPROBS
+        self.num_logprobs[req_idx] = num_logprobs
+
+    def apply_staged_writes(self) -> None:
+        self.temperature.copy_to_uva()
+        self.top_p.copy_to_uva()
+        self.top_k.copy_to_uva()
+        self.min_p.copy_to_uva()
+        self.seeds.copy_to_uva()
+
+    def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
+        return np.any(self.min_p.np[idx_mapping_np] != 0.0)
+
+    def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
+        return np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
+
+    def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
+        return np.any(self.top_p.np[idx_mapping_np] != 1.0)
+
+    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
+        return int(np.max(self.num_logprobs[idx_mapping_np]))
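`UvaBackedTensor` itself lives in `buffer_utils` and is not shown in this patch; from its use here it appears to pair a CPU-writable `.np` view with a device-visible `.gpu` view, with `copy_to_uva()` publishing staged CPU writes. A runnable NumPy stand-in of that assumed contract, driven the way `SamplingStates` drives it:

```python
# Hypothetical re-creation of the apparent UvaBackedTensor contract; the real
# class presumably uses pinned, unified-virtual-addressing memory so the GPU
# can read the published values without an explicit H2D copy per tensor.
import numpy as np

class FakeUvaBackedTensor:
    def __init__(self, n: int, dtype):
        self.np = np.zeros(n, dtype=dtype)   # staged, CPU-only writes
        self.gpu = np.zeros(n, dtype=dtype)  # what device kernels would read

    def copy_to_uva(self):
        self.gpu[:] = self.np                # publish once per step
        return self.gpu

seeds = FakeUvaBackedTensor(4, np.int64)
seeds.np[2] = 1234           # add_request: staged write
assert seeds.gpu[2] == 0     # not visible to the "device" yet
seeds.copy_to_uva()          # apply_staged_writes: publish
assert seeds.gpu[2] == 1234
```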
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py
index ed9260120207..176b0d28097a 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle.py
@@ -17,7 +17,6 @@
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
 
 logger = init_logger(__name__)
@@ -188,7 +187,6 @@ def capture_model(self) -> None:
     def propose(
         self,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         # [num_tokens, hidden_size]
         last_hidden_states: torch.Tensor,
         # num_layers x [num_tokens, hidden_size]
@@ -201,6 +199,10 @@ def propose(
         last_sampled: torch.Tensor,
         # [num_reqs]
         next_prefill_tokens: torch.Tensor,
+        # [max_num_reqs]
+        temperature: torch.Tensor,
+        # [max_num_reqs]
+        seeds: torch.Tensor,
     ) -> torch.Tensor:
         # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
         # number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -246,8 +248,8 @@ def propose(
         # affect the output distribution after rejection sampling.
         idx_mapping = self.idx_mapping[:num_reqs]
         idx_mapping.copy_(input_batch.idx_mapping)
-        self.temperature.copy_(sampling_metadata.temperature)
-        self.seeds.copy_(sampling_metadata.seeds)
+        self.temperature.copy_(temperature)
+        self.seeds.copy_(seeds)
         # Gather the values and copy them to the pre-allocated buffers.
         pos = self.input_buffers.positions[:num_reqs]
         torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
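With `SamplingMetadata` gone, `propose` now receives the sampler-owned `temperature` and `seeds` buffers directly. For background, `gumbel_sample` (used both here and in the sampler) builds on the Gumbel-max trick: adding Gumbel(0, 1) noise to the logits and taking the argmax draws from the softmax distribution, and deriving the noise from a per-request `(seed, pos)` pair makes each step's draw reproducible, which matters when the draft and target paths must agree. A hedged sketch of the idea, not vLLM's kernel:

```python
# Gumbel-max sketch: argmax(logits + Gumbel noise) ~ softmax(logits).
# Deriving the generator seed from (seed, pos) is an illustrative stand-in
# for however the real kernel mixes the two.
import torch

def gumbel_argmax(logits: torch.Tensor, seed: int, pos: int) -> torch.Tensor:
    gen = torch.Generator().manual_seed(hash((seed, pos)) & 0x7FFFFFFF)
    u = torch.rand(logits.shape, generator=gen).clamp_min(1e-10)
    gumbel = -torch.log(-torch.log(u))  # Gumbel(0, 1) via inverse transform
    return torch.argmax(logits + gumbel, dim=-1)

logits = torch.randn(1, 16)
a = gumbel_argmax(logits, seed=42, pos=7)
b = gumbel_argmax(logits, seed=42, pos=7)
assert torch.equal(a, b)  # same (seed, pos) -> same token
```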
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index abfc88405c96..f11b03ab61a7 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -7,14 +7,9 @@
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils.math_utils import cdiv
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu.sample.penalties import bincount
 
-_NP_INT64_MIN = np.iinfo(np.int64).min
-_NP_INT64_MAX = np.iinfo(np.int64).max
 
 NO_LORA_ID = 0
 
@@ -81,38 +76,8 @@ def __init__(
         self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.lora_ids.fill(NO_LORA_ID)
 
-        # Sampling parameters.
-        self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
-        self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
-        self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
-        self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
-        self.repetition_penalty = UvaBackedTensor(
-            self.max_num_reqs, dtype=torch.float32
-        )
-        self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
-        self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
-        self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64)
-
-        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
-        # -1 means no logprobs are requested.
-        self.num_logprobs.fill(-1)
         self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
 
-        # Statistics for penalties.
-        self.prompt_bin_mask = torch.zeros(
-            self.max_num_reqs,
-            cdiv(self.vocab_size, 32),
-            dtype=torch.int32,
-            device=self.device,
-        )
-        # TODO(woosuk): This tensor is rarely used but can be extremely large.
-        # Optimize the memory usage.
-        self.output_bin_counts = torch.zeros(
-            self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
-        )
-
-        self._penalties_reqs: list[int] = []
-
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
@@ -147,33 +112,6 @@ def add_request(
         else:
             self.lora_ids[req_idx] = NO_LORA_ID
 
-        self.temperature.np[req_idx] = sampling_params.temperature
-        self.top_p.np[req_idx] = sampling_params.top_p
-        if 0 < sampling_params.top_k < self.vocab_size:
-            top_k = sampling_params.top_k
-        else:
-            top_k = self.vocab_size
-        self.top_k.np[req_idx] = top_k
-        self.min_p.np[req_idx] = sampling_params.min_p
-        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
-        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
-        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
-
-        if use_penalty(sampling_params):
-            self._penalties_reqs.append(req_idx)
-
-        if sampling_params.seed is not None:
-            seed = sampling_params.seed
-        else:
-            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
-        self.seeds.np[req_idx] = seed
-
-        if sampling_params.logprobs is not None:
-            num_logprobs = sampling_params.logprobs
-        else:
-            num_logprobs = -1
-        self.num_logprobs[req_idx] = num_logprobs
-
         # For now, only support prompt logprobs for the prompt tokens.
         needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
         self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
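The per-request seeding logic removed from `RequestStates.add_request` above is reproduced verbatim in `SamplingStates.add_request`: explicit seeds are honored, otherwise a random int64 is drawn. Note that `np.random.randint`'s upper bound is exclusive, so `_NP_INT64_MAX` itself is never produced. As a standalone function:

```python
# The same seed-selection logic, extracted for illustration.
import numpy as np

_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max

def pick_seed(user_seed: int | None) -> int:
    if user_seed is not None:
        return user_seed  # reproducible sampling for this request
    return int(np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX))

assert pick_seed(1234) == 1234
assert _NP_INT64_MIN <= pick_seed(None) < _NP_INT64_MAX  # high bound exclusive
```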
@@ -183,17 +121,6 @@ def apply_staged_writes(self) -> None:
         self.prefill_token_ids.apply_write()
         self.num_computed_tokens.apply_write()
 
-        # TODO(woosuk): Optimize this.
-        for req_idx in self._penalties_reqs:
-            bincount(
-                self.prefill_token_ids.gpu[req_idx],
-                int(self.prefill_len.np[req_idx]),
-                int(self.prompt_len[req_idx]),
-                self.prompt_bin_mask[req_idx],
-                self.output_bin_counts[req_idx],
-            )
-        self._penalties_reqs.clear()
-
     def remove_request(self, req_id: str) -> None:
         self.extra_data.pop(req_id, None)
         req_idx = self.req_id_to_index.pop(req_id, None)
@@ -203,53 +130,6 @@ def remove_request(self, req_id: str) -> None:
         self.index_to_req_id.pop(req_idx, None)
         self.free_indices.append(req_idx)
 
-    def make_sampling_metadata(
-        self,
-        idx_mapping: torch.Tensor,
-        idx_mapping_np: np.ndarray,
-        pos: torch.Tensor,
-    ) -> SamplingMetadata:
-        temperature = self.temperature.copy_to_uva()
-
-        top_p = self.top_p.np[idx_mapping_np]
-        no_top_p = np.all(top_p == 1.0)
-        top_p = self.top_p.copy_to_uva()[idx_mapping] if not no_top_p else None
-
-        top_k = self.top_k.np[idx_mapping_np]
-        no_top_k = np.all(top_k == self.vocab_size)
-        top_k = self.top_k.copy_to_uva()[idx_mapping] if not no_top_k else None
-
-        min_p = self.min_p.np[idx_mapping_np]
-        no_min_p = np.all(min_p == 0.0)
-        min_p = self.min_p.copy_to_uva() if not no_min_p else None
-
-        rep_penalty = self.repetition_penalty.copy_to_uva()
-        freq_penalty = self.frequency_penalty.copy_to_uva()
-        pres_penalty = self.presence_penalty.copy_to_uva()
-
-        seeds = self.seeds.copy_to_uva()
-
-        num_logprobs = self.num_logprobs[idx_mapping_np]
-        max_num_logprobs: int | None = int(np.max(num_logprobs))
-        if max_num_logprobs == -1:
-            max_num_logprobs = None
-
-        return SamplingMetadata(
-            idx_mapping=idx_mapping,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            min_p=min_p,
-            repetition_penalty=rep_penalty,
-            frequency_penalty=freq_penalty,
-            presence_penalty=pres_penalty,
-            prompt_bin_mask=self.prompt_bin_mask,
-            output_bin_counts=self.output_bin_counts,
-            seeds=seeds,
-            pos=pos,
-            max_num_logprobs=max_num_logprobs,
-        )
-
     def make_lora_inputs(
         self,
         req_ids: list[str],
@@ -272,11 +152,3 @@ def make_lora_inputs(
 class ExtraData:
     lora_request: LoRARequest | None
     in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
-
-
-def use_penalty(sampling_params: SamplingParams) -> bool:
-    return (
-        sampling_params.repetition_penalty != 1.0
-        or sampling_params.frequency_penalty != 0.0
-        or sampling_params.presence_penalty != 0.0
-    )
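Back-of-envelope for the `output_bin_counts` TODO carried over into `PenaltiesState`: the int32 histogram is `max_num_reqs x vocab_size`, so it dominates the packed prompt mask by a factor of 32. With purely illustrative sizes:

```python
# Memory cost of the two penalty-statistics tensors at example sizes.
max_num_reqs = 2048
vocab_size = 256_000
histogram_bytes = max_num_reqs * vocab_size * 4               # int32 counts
bitmask_bytes = max_num_reqs * ((vocab_size + 31) // 32) * 4  # packed bits

print(f"output_bin_counts: {histogram_bytes / 2**30:.2f} GiB")  # ~1.95 GiB
print(f"prompt_bin_mask:   {bitmask_bytes / 2**20:.1f} MiB")    # ~62.5 MiB
```

The imbalance explains both the TODO and why the prompt side was packed to bits: full counts are only needed for output tokens, where the frequency penalty scales with occurrence count rather than mere presence.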