Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion tests/v1/entrypoints/llm/test_struct_output_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import json
from dataclasses import fields
from enum import Enum
from typing import TYPE_CHECKING, Any

Expand All @@ -21,7 +22,8 @@
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.sampling_params import (GuidedDecodingParams, SamplingParams,
StructuredOutputsParams)

if TYPE_CHECKING:
from vllm.config import TokenizerMode
Expand Down Expand Up @@ -89,6 +91,26 @@ def _load_json(s: str, backend: str) -> str:
return json.loads(s)


def test_guided_decoding_deprecated():
with pytest.warns(DeprecationWarning,
match="GuidedDecodingParams is deprecated.*"):
guided_decoding = GuidedDecodingParams(json_object=True)

structured_outputs = StructuredOutputsParams(json_object=True)
assert fields(guided_decoding) == fields(structured_outputs)

with pytest.warns(DeprecationWarning,
match="guided_decoding is deprecated.*"):
sp1 = SamplingParams(guided_decoding=guided_decoding)

with pytest.warns(DeprecationWarning,
match="guided_decoding is deprecated.*"):
sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding)

assert sp1 == sp2
assert sp1.structured_outputs == guided_decoding


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, speculative_config",
Expand Down
36 changes: 36 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation."""
import copy
import warnings
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
Expand Down Expand Up @@ -59,6 +60,19 @@ def __post_init__(self):
f"but multiple are specified: {self.__dict__}")


@dataclass
class GuidedDecodingParams(StructuredOutputsParams):

def __post_init__(self):
warnings.warn(
"GuidedDecodingParams is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"StructuredOutputsParams instead.",
DeprecationWarning,
stacklevel=2)
return super().__post_init__()


class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
Expand Down Expand Up @@ -179,6 +193,8 @@ class SamplingParams(
# Fields used to construct logits processors
structured_outputs: Optional[StructuredOutputsParams] = None
"""Parameters for configuring structured outputs."""
guided_decoding: Optional[GuidedDecodingParams] = None
"""Deprecated alias for structured_outputs."""
logit_bias: Optional[dict[int, float]] = None
"""If provided, the engine will construct a logits processor that applies
these logit biases."""
Expand Down Expand Up @@ -227,6 +243,7 @@ def from_optional(
ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: Optional[StructuredOutputsParams] = None,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
allowed_token_ids: Optional[list[int]] = None,
extra_args: Optional[dict[str, Any]] = None,
Expand All @@ -238,6 +255,15 @@ def from_optional(
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
if guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2)
structured_outputs = guided_decoding
guided_decoding = None

return SamplingParams(
n=1 if n is None else n,
Expand Down Expand Up @@ -334,6 +360,16 @@ def __post_init__(self) -> None:
# eos_token_id is added to this by the engine
self._all_stop_token_ids.update(self.stop_token_ids)

if self.guided_decoding is not None:
warnings.warn(
"guided_decoding is deprecated. This will be removed in "
"v0.12.0 or v1.0.0, which ever is soonest. Please use "
"structured_outputs instead.",
DeprecationWarning,
stacklevel=2)
self.structured_outputs = self.guided_decoding
self.guided_decoding = None

def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of "
Expand Down
Loading