Skip to content
55 changes: 54 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from transformers import GenerationConfig, GenerationMixin
from typing import Optional

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import Sampler, _apply_quadratic_sampling
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
Expand Down Expand Up @@ -397,3 +397,56 @@ def mock_sample(probs, logprobs, sampling_metadata):
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_apply_quadratic_sampling(seed: int, device: str):
    """Quadratic sampling keeps the logits' shape, never exceeds the
    per-row maximum logit, and preserves each row's argmax."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    logits = torch.randn((batch_size, 10), device=device)
    smoothing_factors = torch.rand((batch_size, ), device=device)
    # smoothing_curve must be >= 1.0 (1.0 disables the cubic term), so
    # sample from [1, 2) rather than the invalid [0, 1).
    smoothing_curves = torch.rand((batch_size, ), device=device) + 1.0

    transformed_logits = _apply_quadratic_sampling(logits, smoothing_factors,
                                                   smoothing_curves)

    assert transformed_logits.shape == logits.shape
    # The transform is centered on the row max: no transformed logit can
    # exceed it, and the transform is monotone so the argmax is kept.
    # (NOTE: `transformed <= logits` does NOT hold elementwise — logits
    # near the max are pulled *up* toward it.)
    max_logits = logits.max(dim=-1, keepdim=True).values
    assert torch.all(transformed_logits <= max_logits)
    assert torch.equal(transformed_logits.argmax(dim=-1),
                       logits.argmax(dim=-1))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_apply_quadratic_sampling_with_inf(seed: int, device: str):
    """-inf logits (e.g. masked-out tokens) must pass through the
    quadratic transform untouched."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    logits = torch.randn((batch_size, 10), device=device)
    logits[:, 0] = float('-inf')
    smoothing_factors = torch.rand((batch_size, ), device=device)
    # smoothing_curve must be >= 1.0 (1.0 disables the cubic term), so
    # sample from [1, 2) rather than the invalid [0, 1).
    smoothing_curves = torch.rand((batch_size, ), device=device) + 1.0

    transformed_logits = _apply_quadratic_sampling(logits, smoothing_factors,
                                                   smoothing_curves)

    assert transformed_logits.shape == logits.shape
    # Finite logits stay bounded by the (finite) row maximum; the -inf
    # column is returned unchanged. (`transformed <= logits` elementwise
    # does NOT hold in general — values near the max are pulled up.)
    max_logits = logits.max(dim=-1, keepdim=True).values
    assert torch.all(transformed_logits[:, 1:] <= max_logits)
    assert torch.all(transformed_logits[:, 0] == logits[:, 0])


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_apply_quadratic_sampling_with_zero_smoothing(seed: int, device: str):
    """With a zero smoothing factor the transform must be the identity."""
    set_random_seed(seed)
    torch.set_default_device(device)
    num_seqs = random.randint(1, 256)
    vocab_size = 10
    logits = torch.randn((num_seqs, vocab_size), device=device)
    zero_factors = torch.zeros((num_seqs, ), device=device)
    zero_curves = torch.zeros((num_seqs, ), device=device)

    transformed_logits = _apply_quadratic_sampling(logits, zero_factors,
                                                   zero_curves)

    # Nothing may change: same shape, bit-for-bit equal values.
    assert transformed_logits.shape == logits.shape
    assert torch.all(transformed_logits == logits)
8 changes: 8 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class ChatCompletionRequest(BaseModel):
echo: Optional[bool] = False
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
smoothing_factor: Optional[float] = 0.0
smoothing_curve: Optional[float] = 1.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
Expand Down Expand Up @@ -116,6 +118,8 @@ def logit_bias_logits_processor(
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
smoothing_factor=self.smoothing_factor,
smoothing_curve=self.smoothing_curve,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
Expand Down Expand Up @@ -178,6 +182,8 @@ class CompletionRequest(BaseModel):
spaces_between_special_tokens: Optional[bool] = True
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
smoothing_factor: Optional[float] = 0.0
smoothing_curve: Optional[float] = 1.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
Expand Down Expand Up @@ -211,6 +217,8 @@ def logit_bias_logits_processor(
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
smoothing_factor=self.smoothing_factor,
smoothing_curve=self.smoothing_curve,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
Expand Down
45 changes: 43 additions & 2 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def forward(
logits = _apply_logits_processors(logits, sampling_metadata)

# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
(sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
do_quadratic) = SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype)

# Apply presence and frequency penalties.
Expand All @@ -104,6 +104,11 @@ def forward(
if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)

if do_quadratic:
logits = _apply_quadratic_sampling(
logits, sampling_tensors.smoothing_factors,
sampling_tensors.smoothing_curves)

# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
Expand Down Expand Up @@ -242,6 +247,42 @@ def _apply_min_p(
return logits


# torch.jit will fuse pointwise ops for better performance
@torch.jit.script
def _apply_quadratic_sampling(
Comment thread
Yard1 marked this conversation as resolved.
logits: torch.Tensor,
smoothing_factors: torch.Tensor,
smoothing_curves: torch.Tensor,
) -> torch.Tensor:
"""
Applies quadratic and cubic transformation to the logits based
on the provided smoothing factors and curves. The transformation
is centered around the maximum logit value in the batch.

Credits: @kalomaze
Adapted from
https://github.com/PygmalionAI/aphrodite-engine/blob/13d850334e2ad2cb00aba251bf91f8d20f495d98/aphrodite/modeling/layers/sampler.py#L435-L476
"""
max_logits = logits.max(dim=-1, keepdim=True).values
diff = logits - max_logits
smoothing_factors.unsqueeze_(dim=1)
smoothing_curves.unsqueeze_(dim=1)

k = (3 - smoothing_curves) / 2
s = (smoothing_curves - 1) / 2

mask = smoothing_factors > 0
mask = mask.expand(logits.shape[0], logits.shape[1])

# only transform logits when they're not -inf, otherwise
# fails at smoothing_curves==3
transformed_logits = torch.where(
(logits != float('-inf')) & mask, -(k * smoothing_factors * diff**2) +
Comment thread
Yard1 marked this conversation as resolved.
(s * smoothing_factors * diff**3) + max_logits, logits)

return transformed_logits


def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
samples: torch.Tensor,
Expand Down
38 changes: 34 additions & 4 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class SamplingTensors:
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
smoothing_factors: torch.Tensor
smoothing_curves: torch.Tensor
presence_penalties: torch.Tensor
frequency_penalties: torch.Tensor
repetition_penalties: torch.Tensor
Expand All @@ -81,12 +83,15 @@ def from_sampling_metadata(
temperatures: List[float] = []
top_ps: List[float] = []
min_ps: List[float] = []
smoothing_factors: List[float] = []
smoothing_curves: List[float] = []
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
do_penalties = False
do_top_p_top_k = False
do_min_p = False
do_quadratic = False
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
Expand All @@ -95,6 +100,8 @@ def from_sampling_metadata(
r = sampling_params.repetition_penalty
top_p = sampling_params.top_p
min_p = sampling_params.min_p
smoothing_factor = sampling_params.smoothing_factor
smoothing_curve = sampling_params.smoothing_curve
# k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size)
top_k = vocab_size if top_k == -1 else top_k
Expand All @@ -108,6 +115,9 @@ def from_sampling_metadata(
do_top_p_top_k = True
if not do_min_p and min_p > _SAMPLING_EPS:
do_min_p = True
if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
or smoothing_curve > 1.0):
do_quadratic = True
if not do_penalties and (abs(p) >= _SAMPLING_EPS
or abs(f) >= _SAMPLING_EPS
or abs(r - 1.0) >= _SAMPLING_EPS):
Expand All @@ -120,6 +130,8 @@ def from_sampling_metadata(
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
smoothing_factors += [smoothing_factor] * (prompt_len - 1)
smoothing_curves += [smoothing_curve] * (prompt_len - 1)
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
Expand All @@ -133,19 +145,25 @@ def from_sampling_metadata(
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
smoothing_factors += [smoothing_factor] * len(seq_ids)
smoothing_curves += [smoothing_curve] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)

sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
frequency_penalties, repetition_penalties, prompt_tokens,
output_tokens, vocab_size, device, dtype)
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
temperatures, top_ps, top_ks, min_ps, smoothing_factors,
smoothing_curves, presence_penalties, frequency_penalties,
repetition_penalties, prompt_tokens, output_tokens, vocab_size,
device, dtype)
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
do_quadratic)

@classmethod
def from_lists(cls, temperatures: List[float], top_ps: List[float],
top_ks: List[int], min_ps: List[float],
smoothing_factors: List[float],
smoothing_curves: List[float],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
Expand Down Expand Up @@ -185,6 +203,14 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
dtype=dtype,
pin_memory=pin_memory,
)
smoothing_factors_t = torch.tensor(smoothing_factors,
device="cpu",
dtype=dtype,
pin_memory=pin_memory)
smoothing_curves_t = torch.tensor(smoothing_curves,
device="cpu",
dtype=dtype,
pin_memory=pin_memory)
presence_penalties_t = torch.tensor(
presence_penalties,
device="cpu",
Expand Down Expand Up @@ -228,6 +254,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
top_ps=top_ps_t.to(device=device, non_blocking=True),
top_ks=top_ks_t.to(device=device, non_blocking=True),
min_ps=min_ps_t.to(device=device, non_blocking=True),
smoothing_factors=smoothing_factors_t.to(device=device,
non_blocking=True),
smoothing_curves=smoothing_curves_t.to(device=device,
non_blocking=True),
presence_penalties=presence_penalties_t.to(device=device,
non_blocking=True),
frequency_penalties=frequency_penalties_t.to(device=device,
Expand Down
15 changes: 15 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class SamplingParams:
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
smoothing_factor: Smoothing factor (float) to use in Quadratic
Sampling. Applies a quadratic transformation to the logits.
Must be in [0, inf). Set to 0 to disable.
    smoothing_curve: Smoothing curve (float) to use for Cubic sampling.
        Must be in [1, inf). Set to 1.0 to disable.
seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Expand Down Expand Up @@ -104,6 +109,8 @@ def __init__(
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
smoothing_factor: float = 0.0,
smoothing_curve: float = 1.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
Expand All @@ -128,6 +135,8 @@ def __init__(
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.smoothing_factor = smoothing_factor
self.smoothing_curve = smoothing_curve
self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
Expand Down Expand Up @@ -188,6 +197,12 @@ def _verify_args(self) -> None:
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.")
if not self.smoothing_factor >= 0:
raise ValueError(f"smoothing_factor must be non negative, got "
f"{self.smoothing_factor}.")
if not self.smoothing_curve >= 1.0:
raise ValueError(f"smoothing_curve must be greater than 1, got "
f"{self.smoothing_curve}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
Expand Down