diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index b8b378542c..3f1e59c704 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -47,6 +47,7 @@ GenerationDatumSpec, GenerationInterface, GenerationOutputSpec, + GuidedDecodingConfig, ) from nemo_rl.utils.timer import Timer @@ -61,6 +62,7 @@ def generate_responses( input_lengths: torch.Tensor, include_logprobs: bool = True, greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]: """Generate responses from policy using synchronous generation.""" # Add stop_strings to generation_input_data if present in the batch @@ -72,7 +74,9 @@ def generate_responses( # Always use synchronous generation generation_outputs = policy_generation.generate( - generation_input_data, greedy=greedy + generation_input_data, + greedy=greedy, + guided_decoding_config=guided_decoding_config, ) # Extract everything we need from the generation outputs @@ -125,6 +129,7 @@ async def generate_responses_async( input_lengths: torch.Tensor, include_logprobs: bool = True, greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]: """Async version of generate_responses that properly calls generate_async.""" # Add stop_strings to generation_input_data if present in the batch @@ -151,7 +156,9 @@ async def generate_responses_async( tuple[int, BatchedDataDict[GenerationOutputSpec]] ] = [] async for original_idx, single_item_output in policy_generation.generate_async( - generation_input_data, greedy=greedy + generation_input_data, + greedy=greedy, + guided_decoding_config=guided_decoding_config, ): collected_indexed_outputs.append((original_idx, single_item_output)) @@ -337,6 +344,7 @@ def run_multi_turn_rollout( max_seq_len: int, max_rollout_turns: int = 999999, greedy: bool = 
False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]: """Runs a multi-turn rollout loop, interacting with the environment. @@ -348,6 +356,7 @@ def run_multi_turn_rollout( max_rollout_turns: Maximum number of agent-environment interaction turns. max_seq_len: Maximum sequence length allowed. greedy: Whether to use greedy decoding. + guided_decoding_config: Configuration for guided decoding, None to disable guided decoding. Returns: Tuple containing: @@ -424,6 +433,7 @@ def run_multi_turn_rollout( tokenizer, input_lengths=active_input_lengths, greedy=greedy, + guided_decoding_config=guided_decoding_config, ) # Record token usage - assistant @@ -548,6 +558,7 @@ async def async_generate_response_for_sample_turn( tokenizer: TokenizerType, max_seq_len: int, greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]: """Generate a response for a single sample's turn using async generation. @@ -558,6 +569,7 @@ async def async_generate_response_for_sample_turn( tokenizer: Tokenizer to use max_seq_len: Maximum sequence length greedy: Whether to use greedy decoding + guided_decoding_config: Configuration for guided decoding, None to disable guided decoding. Returns: Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics) @@ -617,6 +629,7 @@ async def run_sample_multi_turn_rollout( max_seq_len: int, max_rollout_turns: int = 999999, greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> tuple[dict, dict[str, Any]]: """Run a multi-turn rollout for a single sample. @@ -632,6 +645,7 @@ async def run_sample_multi_turn_rollout( max_seq_len: Maximum sequence length max_rollout_turns: Maximum number of turns greedy: Whether to use greedy decoding + guided_decoding_config: Configuration for guided decoding, None to disable guided decoding. 
Returns: Tuple of (final_sample_state, sample_metrics) @@ -677,6 +691,7 @@ async def run_sample_multi_turn_rollout( tokenizer, max_seq_len, greedy=greedy, + guided_decoding_config=guided_decoding_config, ) current_message_log = updated_message_log @@ -785,6 +800,7 @@ def run_async_multi_turn_rollout( max_seq_len: int, max_rollout_turns: int = 999999, greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]: """Run multi-turn rollouts with sample-level processing. @@ -799,6 +815,7 @@ def run_async_multi_turn_rollout( max_seq_len: Maximum sequence length allowed max_rollout_turns: Maximum number of agent-environment interaction turns greedy: Whether to use greedy decoding + guided_decoding_config: Configuration for guided decoding, None to disable guided decoding. Returns: Tuple containing: @@ -835,6 +852,7 @@ async def run_single_sample_with_error_handling(i, sample_state): max_seq_len=max_seq_len, max_rollout_turns=max_rollout_turns, greedy=greedy, + guided_decoding_config=guided_decoding_config, ) return result except Exception as e: diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index f7f58b383f..136d80188c 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, NotRequired, TypedDict, Union +from typing import Any, NotRequired, Optional, TypedDict, Union import ray import torch @@ -115,6 +115,30 @@ class ColocationConfig(TypedDict): resources: OptionalResourcesConfig +class GuidedDecodingConfig(TypedDict): + """Configuration for guided decoding. + + `mode`: The guided decoding mode, can be one of `json`, `regex`, `choice`, `grammar`, or `json_object`. 
+ + For the selected mode, its corresponding field must be provided: + + - `json`: the output must be a JSON object matching the provided schema + - `regex`: the output must match the provided regex + - `choice`: the output must be one of the provided choices + - `grammar`: the output must be a valid grammar + - `json_object`: the output must be some JSON object + + This class is intentionally similar to the GuidedDecodingParams class in vLLM, + however, we do not want to inject that dependency here. + """ + + mode: str + json: NotRequired[Union[str, dict]] + regex: NotRequired[str] + choice: NotRequired[list[str]] + grammar: NotRequired[str] + + class GenerationConfig(TypedDict): """Configuration for generation.""" @@ -127,6 +151,7 @@ class GenerationConfig(TypedDict): stop_token_ids: list[int] | None stop_strings: list[str] | None colocated: NotRequired[ColocationConfig] + guided_decoding: NotRequired[GuidedDecodingConfig] # This isn't meant to be passed by the user, but is populated by nemo_rl.models.generation.__init__.configure_generation_config _pad_token_id: NotRequired[int] @@ -224,7 +249,10 @@ def init_collective( @abstractmethod def generate( - self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool + self, + data: BatchedDataDict["GenerationDatumSpec"], + greedy: bool, + guided_decoding_config: Optional[GuidedDecodingConfig], ) -> BatchedDataDict["GenerationOutputSpec"]: pass diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 5dcc7eaf2e..ddd65aad12 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -16,12 +16,16 @@ import os from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, Optional, Union, ) +if TYPE_CHECKING: + from vllm.sampling_params import GuidedDecodingParams + import numpy as np import ray from ray.util.placement_group import PlacementGroup @@ -34,6 +38,7
@@ GenerationDatumSpec, GenerationInterface, GenerationOutputSpec, + GuidedDecodingConfig, ) from nemo_rl.models.generation.vllm.config import VllmConfig @@ -421,7 +426,10 @@ def init_collective( return futures def generate( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> BatchedDataDict[GenerationOutputSpec]: """Generate a batch of data using vLLM.""" assert isinstance(data, BatchedDataDict), ( @@ -442,7 +450,10 @@ def generate( in_sharded_axes=["data_parallel"], replicate_on_axes=None, # just run on tp rank 0 output_is_replicated=None, - common_kwargs={"greedy": greedy}, + common_kwargs={ + "greedy": greedy, + "guided_decoding_config": guided_decoding_config, + }, ) # Get results from the workers, respecting tied worker groups (only one result per tied worker group) @@ -469,7 +480,10 @@ def generate( return combined def generate_text( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_params: Optional["GuidedDecodingParams"] = None, ) -> BatchedDataDict[GenerationOutputSpec]: """Generate text responses using vLLM.""" assert isinstance(data, BatchedDataDict), ( @@ -493,7 +507,10 @@ def generate_text( in_sharded_axes=["data_parallel"], replicate_on_axes=None, # just run on tp rank 0 output_is_replicated=None, - common_kwargs={"greedy": greedy}, + common_kwargs={ + "greedy": greedy, + "guided_decoding_params": guided_decoding_params, + }, ) # Get results from the workers, respecting tied worker groups (only one result per tied worker group) @@ -520,6 +537,7 @@ async def _async_generate_base( method_name: str, data_validation_fn, greedy: bool = False, + **kwargs, ) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]: """Base async generation method that 
handles common worker management logic. @@ -556,6 +574,7 @@ async def _async_generate_base( worker_idx=leader_worker_idx, data=data, greedy=greedy, + **kwargs, ) # Increment the round-robin worker group index @@ -643,7 +662,10 @@ async def consume_worker_generator(worker_idx, worker_gen): ) async def generate_text_async( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_params: Optional["GuidedDecodingParams"] = None, ) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]: """Generate text responses asynchronously, yielding results as they are ready. @@ -661,12 +683,19 @@ def validate_text_data(data): return True async for result in self._async_generate_base( - data, "generate_text_async", validate_text_data, greedy + data, + "generate_text_async", + validate_text_data, + greedy, + guided_decoding_params=guided_decoding_params, ): yield result async def generate_async( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]: """Generate responses asynchronously, yielding individual samples as they complete. 
@@ -684,7 +713,11 @@ def validate_generate_data(data): return True async for result in self._async_generate_base( - data, "generate_async", validate_generate_data, greedy + data, + "generate_async", + validate_generate_data, + greedy, + guided_decoding_config=guided_decoding_config, ): yield result diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index a97d68e669..3ddd694585 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -16,7 +16,10 @@ import gc import os import sys -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast + +if TYPE_CHECKING: + from vllm.sampling_params import GuidedDecodingParams import ray import torch @@ -26,6 +29,7 @@ from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationOutputSpec, + GuidedDecodingConfig, verify_right_padding, ) from nemo_rl.models.generation.vllm.config import VllmConfig @@ -339,6 +343,31 @@ def is_alive(self): """Check if the worker is alive.""" return True + def _get_vllm_guided_decoding_params( + self, guided_decoding_config: Optional[GuidedDecodingConfig] + ) -> Optional["GuidedDecodingParams"]: + """Get the guided decoding parameters for vLLM.""" + from vllm.sampling_params import GuidedDecodingParams + + if guided_decoding_config is None: + return None + + match guided_decoding_config["mode"]: + case "json": + return GuidedDecodingParams(json=guided_decoding_config["json"]) + case "regex": + return GuidedDecodingParams(regex=guided_decoding_config["regex"]) + case "choice": + return GuidedDecodingParams(choice=guided_decoding_config["choice"]) + case "grammar": + return GuidedDecodingParams(grammar=guided_decoding_config["grammar"]) + case "json_object": + return GuidedDecodingParams(json_object=True) + case _: + raise ValueError( + f"Unsupported guided decoding mode: {guided_decoding_config['mode']}" + ) + def 
_merge_stop_strings(self, batch_stop_strings): stop_set: set[str] = set() @@ -358,6 +387,7 @@ def _build_sampling_params( greedy: bool, stop_strings, max_new_tokens: Optional[int] = None, + guided_decoding_params: Optional["GuidedDecodingParams"] = None, ): top_k_cfg = self.cfg["top_k"] top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1) @@ -377,6 +407,7 @@ def _build_sampling_params( stop_token_ids=self.cfg["stop_token_ids"], stop=stop_strings, include_stop_str_in_output=True, + guided_decoding=guided_decoding_params, ) def start_gpu_profiling(self) -> None: @@ -425,13 +456,17 @@ def init_collective( @wrap_with_nvtx_name("vllm_genertion_worker/generate") def generate( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> BatchedDataDict[GenerationOutputSpec]: """Generate a batch of data using vLLM generation. Args: data: BatchedDataDict containing input_ids and input_lengths tensors greedy: Whether to use greedy decoding instead of sampling + guided_decoding_config: Configuration for guided decoding, None to disable guided decoding. Returns: BatchedDataDict conforming to GenerationOutputSpec: @@ -459,6 +494,9 @@ def generate( sampling_params = self._build_sampling_params( greedy=greedy, stop_strings=stop_strings, + guided_decoding_params=self._get_vllm_guided_decoding_params( + guided_decoding_config + ), ) # verify inputs have correct padding @@ -552,13 +590,17 @@ def generate( @wrap_with_nvtx_name("vllm_genertion_worker/generate_text") def generate_text( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_params: Optional["GuidedDecodingParams"] = None, ) -> BatchedDataDict[GenerationOutputSpec]: """Generate text responses using vLLM generation. 
Args: data: BatchedDataDict containing prompts with text strings greedy: Whether to use greedy decoding instead of sampling + guided_decoding_params: Guided decoding parameters for vLLM, None to disable guided decoding. Returns: BatchedDataDict containing: @@ -597,6 +639,7 @@ def generate_text( stop_token_ids=self.cfg["stop_token_ids"], stop=stop_strings, include_stop_str_in_output=True, # returning stop strings like hf + guided_decoding=guided_decoding_params, ) # Generate outputs diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index d4e8161b44..fc0c005b4f 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -16,7 +16,10 @@ import gc import threading import uuid -from typing import Any, AsyncGenerator, Optional, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, cast + +if TYPE_CHECKING: + from vllm.sampling_params import GuidedDecodingParams import ray import torch @@ -29,6 +32,7 @@ from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationOutputSpec, + GuidedDecodingConfig, verify_right_padding, ) from nemo_rl.models.generation.vllm.utils import format_prompt_for_vllm_generation @@ -497,12 +501,14 @@ async def generate_async( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]: """Generate a batch of data using vLLM's AsyncLLMEngine, yielding results as they are ready. Args: data: BatchedDataDict with input_ids and input_lengths greedy: Whether to use greedy decoding instead of sampling + guided_decoding_config: Configuration for guided decoding, None to disable guided decoding. 
Yields: Tuple of (original_index, BatchedDataDict conforming to GenerationOutputSpec for the single sequence) @@ -594,6 +600,9 @@ async def process_single_sample(sample_idx): greedy=greedy, stop_strings=final_stop_strings_for_sample, max_new_tokens=allowed_new_tokens, + guided_decoding_params=self._get_vllm_guided_decoding_params( + guided_decoding_config + ), ) request_id = str(uuid.uuid4()) @@ -714,13 +723,17 @@ async def process_single_sample(sample_idx): raise e async def generate_text_async( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_params: Optional["GuidedDecodingParams"] = None, ) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]: """Generate text responses asynchronously, yielding results as they are ready. Args: data: BatchedDataDict containing prompts with text strings greedy: Whether to use greedy decoding instead of sampling + guided_decoding_params: Guided decoding parameters for vLLM, None to disable guided decoding. 
Yields: Tuple of (original_index, BatchedDataDict containing single text response) @@ -767,6 +780,7 @@ async def process_single_prompt(prompt_idx): stop_token_ids=self.cfg["stop_token_ids"], stop=final_stop_strings, include_stop_str_in_output=True, # returning stop strings like hf + guided_decoding=guided_decoding_params, ) request_id = str(uuid.uuid4()) diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index c1fde9bcf5..7e7e430b57 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -37,6 +37,7 @@ GenerationDatumSpec, GenerationInterface, GenerationOutputSpec, + GuidedDecodingConfig, ) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -565,9 +566,17 @@ def train( return aggregated_results def generate( - self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + self, + data: BatchedDataDict[GenerationDatumSpec], + greedy: bool = False, + guided_decoding_config: Optional[GuidedDecodingConfig] = None, ) -> BatchedDataDict[GenerationOutputSpec]: """Generate a batch of data using the policy.""" + # Guided decoding is currently only supported for vLLM backend + assert guided_decoding_config is None, ( + "Guided decoding is currently only supported for vLLM backend" + ) + # Verify input data is right-padded assert isinstance(data, BatchedDataDict), ( f"data must be a BatchedDataDict, got type: {type(data)}" diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 79e4112485..27d4200b6a 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -14,6 +14,7 @@ import json import os +import re from copy import deepcopy from pathlib import Path @@ -30,6 +31,7 @@ from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, + 
GuidedDecodingConfig, ) from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.generation.vllm.vllm_worker_async import ( @@ -2320,3 +2322,123 @@ def test_vllm_megatron_weight_update_with_packing(cluster, test_input_data): megatron_policy.shutdown() if vllm_generation: vllm_generation.shutdown() + + +def test_vllm_guided_decoding(cluster, tokenizer): + """Test vLLM generation with different guided decoding modes.""" + + try: + # construct VllmGeneration policy with more max tokens + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config["max_new_tokens"] = 16 + vllm_config["vllm_cfg"]["async_engine"] = False + vllm_config = configure_generation_config(vllm_config, tokenizer) + vllm_policy = VllmGeneration(cluster, vllm_config) + # vllm_policy.finish_generation() + + def get_test_input_data(prompt: str): + tokenized_data = tokenizer( + [prompt], + padding=True, + truncation=False, + return_tensors="pt", + padding_side="right", + ) + + input_lengths = tokenized_data["attention_mask"].sum(dim=1).to(torch.int32) + test_input_data = BatchedDataDict( + { + "input_ids": tokenized_data["input_ids"], + "input_lengths": input_lengths, + } + ) + return test_input_data + + # Test 1: Regex guided decoding + print("Testing regex guided decoding...") + regex_guided_config: GuidedDecodingConfig = { + "mode": "regex", + "regex": r"\d{3}-\d{3}-\d{4}", # Phone number pattern + } + + prompt1 = "Give me a phone number: " + input_data = get_test_input_data(prompt1) + regex_outputs = vllm_policy.generate( + input_data, greedy=True, guided_decoding_config=regex_guided_config + ) + + # Validate regex outputs + assert "output_ids" in regex_outputs, ( + "output_ids not found in regex guided generation output" + ) + assert regex_outputs["output_ids"].shape[0] == len(input_data["input_ids"]), ( + "Wrong batch size in regex guided output" + ) + + regex_generated_texts = tokenizer.batch_decode( + regex_outputs["output_ids"], skip_special_tokens=True + ) + 
output_only = regex_generated_texts[0].split(prompt1)[1] + assert re.match(regex_guided_config["regex"], output_only), ( + "Output should match the regex pattern" + ) + + # Validate log probabilities for regex guided tokens (should be logprob=0, probability=1) + assert "logprobs" in regex_outputs, ( + "logprobs not found in regex guided generation output" + ) + assert "generation_lengths" in regex_outputs, ( + "generation_lengths not found in regex guided generation output" + ) + + logprobs = regex_outputs["logprobs"] + generation_lengths = regex_outputs["generation_lengths"] + input_lengths = input_data["input_lengths"] + + # For regex \d{3}-\d{3}-\d{4}, hyphens at positions 3,7 and end token should have logprob ≈ 0 + input_len = input_lengths[0].item() + generated_logprobs = logprobs[ + 0, input_len : input_len + generation_lengths[0].item() + ] + + # Check hyphen positions (3, 7) and last token have logprob ≈ 0 + constrained_indices = [3, 7, -1] # hyphens and end token + assert all( + abs(generated_logprobs[i].item()) < 1e-3 for i in constrained_indices + ), "Regex constrained tokens should have logprob ≈ 0" + + # Test 2: Choice guided decoding + print("Testing choice guided decoding...") + choices = ["yes", "no", "maybe"] + + choice_guided_config: GuidedDecodingConfig = { + "mode": "choice", + "choice": choices, + } + + prompt2 = "Should I go to the gym today? 
(yes/no/maybe): " + input_data = get_test_input_data(prompt2) + choice_outputs = vllm_policy.generate( + input_data, greedy=True, guided_decoding_config=choice_guided_config + ) + + # Validate choice outputs + assert "output_ids" in choice_outputs, ( + "output_ids not found in choice guided generation output" + ) + assert choice_outputs["output_ids"].shape[0] == len(input_data["input_ids"]), ( + "Wrong batch size in choice guided output" + ) + + choice_generated_texts = tokenizer.batch_decode( + choice_outputs["output_ids"], skip_special_tokens=True + ) + output_only = choice_generated_texts[0].split(prompt2)[1] + assert output_only in choices, "Output should be one of the choices" + + finally: + vllm_policy.shutdown() + import gc + + gc.collect() + torch.cuda.empty_cache()