diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py index 6d2cf850d65..70f85107be9 100644 --- a/trl/experimental/openenv/utils.py +++ b/trl/experimental/openenv/utils.py @@ -23,7 +23,7 @@ if is_vllm_available(): from vllm import SamplingParams - from vllm.sampling_params import GuidedDecodingParams + from vllm.sampling_params import StructuredOutputsParams def _build_colocate_sampling_params( @@ -33,9 +33,9 @@ def _build_colocate_sampling_params( logprobs: bool = True, ) -> SamplingParams: if trainer.structured_outputs_regex: - guided_decoding = GuidedDecodingParams(regex=trainer.structured_outputs_regex) + structured_outputs = StructuredOutputsParams(regex=trainer.structured_outputs_regex) else: - guided_decoding = None + structured_outputs = None generation_kwargs: dict[str, Any] = { "n": 1, @@ -43,7 +43,7 @@ def _build_colocate_sampling_params( "top_k": trainer.top_k, "min_p": 0.0 if trainer.min_p is None else trainer.min_p, "max_tokens": trainer.max_completion_length, - "guided_decoding": guided_decoding, + "structured_outputs": structured_outputs, } if trainer.repetition_penalty is not None: generation_kwargs["repetition_penalty"] = trainer.repetition_penalty