Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 58 additions & 22 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,21 +418,35 @@ def __post_init__(self) -> None:

def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
raise VLLMValidationError(
f"n must be an int, but is of type {type(self.n)}",
parameter="n",
value=self.n,
)
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
raise VLLMValidationError(
f"n must be at least 1, got {self.n}.",
parameter="n",
value=self.n,
)
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError(
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
raise VLLMValidationError(
f"presence_penalty must be in [-2, 2], got {self.presence_penalty}.",
parameter="presence_penalty",
value=self.presence_penalty,
)
if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError(
f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
raise VLLMValidationError(
f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}.",
parameter="frequency_penalty",
value=self.frequency_penalty,
)
if self.repetition_penalty <= 0.0:
raise ValueError(
raise VLLMValidationError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}."
f"{self.repetition_penalty}.",
parameter="repetition_penalty",
value=self.repetition_penalty,
)
if self.temperature < 0.0:
raise VLLMValidationError(
Expand All @@ -448,29 +462,42 @@ def _verify_args(self) -> None:
)
# quietly accept -1 as disabled, but prefer 0
if self.top_k < -1:
raise ValueError(
f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
raise VLLMValidationError(
f"top_k must be 0 (disable), or at least 1, got {self.top_k}.",
parameter="top_k",
value=self.top_k,
)
if not isinstance(self.top_k, int):
raise TypeError(
f"top_k must be an integer, got {type(self.top_k).__name__}"
raise VLLMValidationError(
f"top_k must be an integer, got {type(self.top_k).__name__}",
parameter="top_k",
value=self.top_k,
)
if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
raise VLLMValidationError(
f"min_p must be in [0, 1], got {self.min_p}.",
parameter="min_p",
value=self.min_p,
)
if self.max_tokens is not None and self.max_tokens < 1:
raise VLLMValidationError(
f"max_tokens must be at least 1, got {self.max_tokens}.",
parameter="max_tokens",
value=self.max_tokens,
)
if self.min_tokens < 0:
raise ValueError(
f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
raise VLLMValidationError(
"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.",
parameter="min_tokens",
value=self.min_tokens,
)
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
raise VLLMValidationError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}."
f"max_tokens={self.max_tokens}, got {self.min_tokens}.",
parameter="min_tokens",
value=self.min_tokens,
)
if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
raise VLLMValidationError(
Expand All @@ -491,16 +518,25 @@ def _verify_args(self) -> None:
)
assert isinstance(self.stop_token_ids, list)
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
raise ValueError(
f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
raise VLLMValidationError(
"stop_token_ids must contain only integers, "
f"got {self.stop_token_ids}.",
parameter="stop_token_ids",
value=self.stop_token_ids,
)
assert isinstance(self.stop, list)
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
raise VLLMValidationError(
"stop cannot contain an empty string.",
parameter="stop",
value=self.stop,
)
if self.stop and not self.detokenize:
raise ValueError(
raise VLLMValidationError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
"Set detokenize=True to use stop.",
parameter="stop",
value=self.stop,
)

def _verify_greedy_sampling(self) -> None:
Expand Down
Loading