diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 31e865f42ff3..5582d1f04c00 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -7,7 +7,7 @@ from transformers import GenerationConfig, GenerationMixin from typing import Optional -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, _apply_quadratic_sampling from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner @@ -397,3 +397,56 @@ def mock_sample(probs, logprobs, sampling_metadata): assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) del model_runner + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_apply_quadratic_sampling(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + logits = torch.randn((batch_size, 10), device=device) + smoothing_factors = torch.rand((batch_size, ), device=device) + smoothing_curves = torch.rand((batch_size, ), device=device) + + transformed_logits = _apply_quadratic_sampling(logits, smoothing_factors, + smoothing_curves) + + assert transformed_logits.shape == logits.shape + assert torch.all(transformed_logits <= logits) + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_apply_quadratic_sampling_with_inf(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + logits = torch.randn((batch_size, 10), device=device) + logits[:, 0] = float('-inf') + smoothing_factors = torch.rand((batch_size, ), device=device) + smoothing_curves = torch.rand((batch_size, ), device=device) + + transformed_logits = _apply_quadratic_sampling(logits, smoothing_factors, + smoothing_curves) + + assert 
transformed_logits.shape == logits.shape + assert torch.all(transformed_logits[:, 1:] <= logits[:, 1:]) + assert torch.all(transformed_logits[:, 0] == logits[:, 0]) + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_apply_quadratic_sampling_with_zero_smoothing(seed: int, device: str): + set_random_seed(seed) + torch.set_default_device(device) + batch_size = random.randint(1, 256) + logits = torch.randn((batch_size, 10), device=device) + smoothing_factors = torch.zeros((batch_size, ), device=device) + smoothing_curves = torch.zeros((batch_size, ), device=device) + + transformed_logits = _apply_quadratic_sampling(logits, smoothing_factors, + smoothing_curves) + + assert transformed_logits.shape == logits.shape + assert torch.all(transformed_logits == logits) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 26499b8d7a66..cb01303e04c3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -84,6 +84,8 @@ class ChatCompletionRequest(BaseModel): echo: Optional[bool] = False repetition_penalty: Optional[float] = 1.0 min_p: Optional[float] = 0.0 + smoothing_factor: Optional[float] = 0.0 + smoothing_curve: Optional[float] = 1.0 include_stop_str_in_output: Optional[bool] = False length_penalty: Optional[float] = 1.0 guided_json: Optional[Union[str, dict, BaseModel]] = None @@ -116,6 +118,8 @@ def logit_bias_logits_processor( temperature=self.temperature, top_p=self.top_p, min_p=self.min_p, + smoothing_factor=self.smoothing_factor, + smoothing_curve=self.smoothing_curve, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, @@ -178,6 +182,8 @@ class CompletionRequest(BaseModel): spaces_between_special_tokens: Optional[bool] = True repetition_penalty: Optional[float] = 1.0 min_p: Optional[float] = 0.0 + smoothing_factor: Optional[float] = 0.0 + smoothing_curve: Optional[float] = 1.0 include_stop_str_in_output: 
Optional[bool] = False length_penalty: Optional[float] = 1.0 guided_json: Optional[Union[str, dict, BaseModel]] = None @@ -211,6 +217,8 @@ def logit_bias_logits_processor( top_p=self.top_p, top_k=self.top_k, min_p=self.min_p, + smoothing_factor=self.smoothing_factor, + smoothing_curve=self.smoothing_curve, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 71655b216fb3..a0ebb6677b32 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -81,8 +81,8 @@ def forward( logits = _apply_logits_processors(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( + (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p, + do_quadratic) = SamplingTensors.from_sampling_metadata( sampling_metadata, vocab_size, logits.device, logits.dtype) # Apply presence and frequency penalties. @@ -104,6 +104,11 @@ def forward( if do_min_p: logits = _apply_min_p(logits, sampling_tensors.min_ps) + if do_quadratic: + logits = _apply_quadratic_sampling( + logits, sampling_tensors.smoothing_factors, + sampling_tensors.smoothing_curves) + # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) @@ -242,6 +247,42 @@ def _apply_min_p( return logits +# torch.jit will fuse pointwise ops for better performance +@torch.jit.script +def _apply_quadratic_sampling( + logits: torch.Tensor, + smoothing_factors: torch.Tensor, + smoothing_curves: torch.Tensor, +) -> torch.Tensor: + """ + Applies quadratic and cubic transformation to the logits based + on the provided smoothing factors and curves. The transformation + is centered around the maximum logit value in the batch. 
+ + Credits: @kalomaze + Adapted from + https://github.com/PygmalionAI/aphrodite-engine/blob/13d850334e2ad2cb00aba251bf91f8d20f495d98/aphrodite/modeling/layers/sampler.py#L435-L476 + """ + max_logits = logits.max(dim=-1, keepdim=True).values + diff = logits - max_logits + smoothing_factors.unsqueeze_(dim=1) + smoothing_curves.unsqueeze_(dim=1) + + k = (3 - smoothing_curves) / 2 + s = (smoothing_curves - 1) / 2 + + mask = smoothing_factors > 0 + mask = mask.expand(logits.shape[0], logits.shape[1]) + + # only transform logits when they're not -inf, otherwise + # fails at smoothing_curves==3 + transformed_logits = torch.where( + (logits != float('-inf')) & mask, -(k * smoothing_factors * diff**2) + + (s * smoothing_factors * diff**3) + max_logits, logits) + + return transformed_logits + + def _greedy_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], samples: torch.Tensor, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 7deb80801856..2cc3b00f5ef4 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -64,6 +64,8 @@ class SamplingTensors: top_ps: torch.Tensor top_ks: torch.Tensor min_ps: torch.Tensor + smoothing_factors: torch.Tensor + smoothing_curves: torch.Tensor presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor @@ -81,12 +83,15 @@ def from_sampling_metadata( temperatures: List[float] = [] top_ps: List[float] = [] min_ps: List[float] = [] + smoothing_factors: List[float] = [] + smoothing_curves: List[float] = [] presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] do_penalties = False do_top_p_top_k = False do_min_p = False + do_quadratic = False for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -95,6 +100,8 @@ def from_sampling_metadata( r = 
sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p + smoothing_factor = sampling_params.smoothing_factor + smoothing_curve = sampling_params.smoothing_curve # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) top_k = vocab_size if top_k == -1 else top_k @@ -108,6 +115,9 @@ def from_sampling_metadata( do_top_p_top_k = True if not do_min_p and min_p > _SAMPLING_EPS: do_min_p = True + if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS + or smoothing_curve > 1.0): + do_quadratic = True if not do_penalties and (abs(p) >= _SAMPLING_EPS or abs(f) >= _SAMPLING_EPS or abs(r - 1.0) >= _SAMPLING_EPS): @@ -120,6 +130,8 @@ def from_sampling_metadata( top_ps += [top_p] * (prompt_len - 1) top_ks += [top_k] * (prompt_len - 1) min_ps += [min_p] * (prompt_len - 1) + smoothing_factors += [smoothing_factor] * (prompt_len - 1) + smoothing_curves += [smoothing_curve] * (prompt_len - 1) presence_penalties += [0] * (prompt_len - 1) frequency_penalties += [0] * (prompt_len - 1) repetition_penalties += [1] * (prompt_len - 1) @@ -133,19 +145,25 @@ def from_sampling_metadata( top_ps += [top_p] * len(seq_ids) top_ks += [top_k] * len(seq_ids) min_ps += [min_p] * len(seq_ids) + smoothing_factors += [smoothing_factor] * len(seq_ids) + smoothing_curves += [smoothing_curve] * len(seq_ids) presence_penalties += [p] * len(seq_ids) frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) sampling_tensors = SamplingTensors.from_lists( - temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, prompt_tokens, - output_tokens, vocab_size, device, dtype) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + temperatures, top_ps, top_ks, min_ps, smoothing_factors, + smoothing_curves, presence_penalties, frequency_penalties, + repetition_penalties, prompt_tokens, output_tokens, vocab_size, + device, dtype) + return 
(sampling_tensors, do_penalties, do_top_p_top_k, do_min_p, + do_quadratic) @classmethod def from_lists(cls, temperatures: List[float], top_ps: List[float], top_ks: List[int], min_ps: List[float], + smoothing_factors: List[float], + smoothing_curves: List[float], presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], @@ -185,6 +203,14 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=dtype, pin_memory=pin_memory, ) + smoothing_factors_t = torch.tensor(smoothing_factors, + device="cpu", + dtype=dtype, + pin_memory=pin_memory) + smoothing_curves_t = torch.tensor(smoothing_curves, + device="cpu", + dtype=dtype, + pin_memory=pin_memory) presence_penalties_t = torch.tensor( presence_penalties, device="cpu", @@ -228,6 +254,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], top_ps=top_ps_t.to(device=device, non_blocking=True), top_ks=top_ks_t.to(device=device, non_blocking=True), min_ps=min_ps_t.to(device=device, non_blocking=True), + smoothing_factors=smoothing_factors_t.to(device=device, + non_blocking=True), + smoothing_curves=smoothing_curves_t.to(device=device, + non_blocking=True), presence_penalties=presence_penalties_t.to(device=device, non_blocking=True), frequency_penalties=frequency_penalties_t.to(device=device, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 8103f3c2b24b..eb235d6e7373 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -58,6 +58,11 @@ class SamplingParams: min_p: Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. + smoothing_factor: Smoothing factor (float) to use in Quadratic + Sampling. Applies a quadratic transformation to the logits. + Must be in [0, inf). Set to 0 to disable. + smoothing_curve: Smoothing curve (float) to use for Cubic sampling. + Must be in [1, inf). 
Set to 1.0 to disable. seed: Random seed to use for the generation. use_beam_search: Whether to use beam search instead of sampling. length_penalty: Float that penalizes sequences based on their length. @@ -104,6 +109,8 @@ def __init__( top_p: float = 1.0, top_k: int = -1, min_p: float = 0.0, + smoothing_factor: float = 0.0, + smoothing_curve: float = 1.0, seed: Optional[int] = None, use_beam_search: bool = False, length_penalty: float = 1.0, @@ -128,6 +135,8 @@ def __init__( self.top_p = top_p self.top_k = top_k self.min_p = min_p + self.smoothing_factor = smoothing_factor + self.smoothing_curve = smoothing_curve self.seed = seed self.use_beam_search = use_beam_search self.length_penalty = length_penalty @@ -188,6 +197,12 @@ def _verify_args(self) -> None: if not 0.0 <= self.min_p <= 1.0: raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.") + if not self.smoothing_factor >= 0: + raise ValueError(f"smoothing_factor must be non-negative, got " + f"{self.smoothing_factor}.") + if not self.smoothing_curve >= 1.0: + raise ValueError(f"smoothing_curve must be at least 1, got " + f"{self.smoothing_curve}.") if self.max_tokens is not None and self.max_tokens < 1: raise ValueError( f"max_tokens must be at least 1, got {self.max_tokens}.")