Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions nemo_rl/experience/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import asyncio
import copy
from typing import Any
from typing import Any, Optional

import ray
import torch
Expand All @@ -41,6 +41,7 @@
GenerationDatumSpec,
GenerationInterface,
GenerationOutputSpec,
GuidedDecodingConfig,
)

TokenizerType = PreTrainedTokenizerBase
Expand All @@ -54,6 +55,7 @@ def generate_responses(
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]:
"""Generate responses from policy using synchronous generation."""
# Add stop_strings to generation_input_data if present in the batch
Expand All @@ -65,7 +67,9 @@ def generate_responses(

# Always use synchronous generation
generation_outputs = policy_generation.generate(
generation_input_data, greedy=greedy
generation_input_data,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)

# Extract everything we need from the generation outputs
Expand Down Expand Up @@ -118,6 +122,7 @@ async def generate_responses_async(
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]:
"""Async version of generate_responses that properly calls generate_async."""
# Add stop_strings to generation_input_data if present in the batch
Expand All @@ -144,7 +149,9 @@ async def generate_responses_async(
tuple[int, BatchedDataDict[GenerationOutputSpec]]
] = []
async for original_idx, single_item_output in policy_generation.generate_async(
generation_input_data, greedy=greedy
generation_input_data,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
):
collected_indexed_outputs.append((original_idx, single_item_output))

Expand Down Expand Up @@ -321,6 +328,7 @@ def run_multi_turn_rollout(
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
Comment thread
ybgao-nvidia marked this conversation as resolved.
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
"""Runs a multi-turn rollout loop, interacting with the environment.

Expand All @@ -332,6 +340,7 @@ def run_multi_turn_rollout(
max_rollout_turns: Maximum number of agent-environment interaction turns.
max_seq_len: Maximum sequence length allowed.
greedy: Whether to use greedy decoding.
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple containing:
Expand Down Expand Up @@ -396,6 +405,7 @@ def run_multi_turn_rollout(
tokenizer,
input_lengths=active_input_lengths,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)

# Record token usage - assistant
Expand Down Expand Up @@ -516,6 +526,7 @@ async def async_generate_response_for_sample_turn(
tokenizer: TokenizerType,
max_seq_len: int,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
Comment thread
ybgao-nvidia marked this conversation as resolved.
) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]:
"""Generate a response for a single sample's turn using async generation.

Expand All @@ -526,6 +537,7 @@ async def async_generate_response_for_sample_turn(
tokenizer: Tokenizer to use
max_seq_len: Maximum sequence length
greedy: Whether to use greedy decoding
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics)
Expand Down Expand Up @@ -585,6 +597,7 @@ async def run_sample_multi_turn_rollout(
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
Comment thread
ybgao-nvidia marked this conversation as resolved.
) -> tuple[dict, dict[str, Any]]:
"""Run a multi-turn rollout for a single sample.

Expand All @@ -600,6 +613,7 @@ async def run_sample_multi_turn_rollout(
max_seq_len: Maximum sequence length
max_rollout_turns: Maximum number of turns
greedy: Whether to use greedy decoding
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple of (final_sample_state, sample_metrics)
Expand Down Expand Up @@ -643,6 +657,7 @@ async def run_sample_multi_turn_rollout(
tokenizer,
max_seq_len,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)
current_message_log = updated_message_log

Expand Down Expand Up @@ -743,6 +758,7 @@ def run_async_multi_turn_rollout(
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
Comment thread
ybgao-nvidia marked this conversation as resolved.
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
"""Run multi-turn rollouts with sample-level processing.

Expand All @@ -757,6 +773,7 @@ def run_async_multi_turn_rollout(
max_seq_len: Maximum sequence length allowed
max_rollout_turns: Maximum number of agent-environment interaction turns
greedy: Whether to use greedy decoding
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple containing:
Expand Down Expand Up @@ -793,6 +810,7 @@ async def run_single_sample_with_error_handling(i, sample_state):
max_seq_len=max_seq_len,
max_rollout_turns=max_rollout_turns,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)
return result
except Exception as e:
Expand Down
32 changes: 30 additions & 2 deletions nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, NotRequired, TypedDict, Union
from typing import Any, NotRequired, Optional, TypedDict, Union

import ray
import torch
Expand Down Expand Up @@ -109,6 +109,30 @@ class ColocationConfig(TypedDict):
resources: NotRequired[ResourcesConfig]


class GuidedDecodingConfig(TypedDict):
Comment thread
ybgao-nvidia marked this conversation as resolved.
"""Configuration for guided decoding.

`mode`: The guided decoding mode, can be one of `json`, `regex`, `choice`, `grammar`, or `json_object`.

For the selected mode, its corresponding field must be provided:

- `json`: the output must be a JSON object matching the provided schema
- `regex`: the output must match the provided regex
- `choice`: the output must be one of the provided choices
- `grammar`: the output must conform to the provided grammar
- `json_object`: the output must be some JSON object

This class is intentionally similar to the GuidedDecodingParams class in vLLM;
however, we do not want to inject that dependency here.
"""

mode: str
json: NotRequired[Union[str, dict]]
regex: NotRequired[str]
choice: NotRequired[list[str]]
grammar: NotRequired[str]


class GenerationConfig(TypedDict):
"""Configuration for generation."""

Expand All @@ -122,6 +146,7 @@ class GenerationConfig(TypedDict):
stop_strings: NotRequired[list[str]]
pad_token_id: NotRequired[int]
colocated: NotRequired[ColocationConfig]
guided_decoding: NotRequired[GuidedDecodingConfig]


class GenerationDatumSpec(TypedDict):
Expand Down Expand Up @@ -217,7 +242,10 @@ def init_collective(

@abstractmethod
def generate(
self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool
self,
data: BatchedDataDict["GenerationDatumSpec"],
greedy: bool,
guided_decoding_config: Optional[GuidedDecodingConfig],
) -> BatchedDataDict["GenerationOutputSpec"]:
pass

Expand Down
61 changes: 57 additions & 4 deletions nemo_rl/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import uuid
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
NotRequired,
Expand All @@ -29,6 +30,9 @@
cast,
)

if TYPE_CHECKING:
from vllm.sampling_params import GuidedDecodingParams

import numpy as np
import ray
import torch
Expand All @@ -49,6 +53,7 @@
GenerationDatumSpec,
GenerationInterface,
GenerationOutputSpec,
GuidedDecodingConfig,
verify_right_padding,
)
from nemo_rl.models.huggingface.common import ModelFlag
Expand Down Expand Up @@ -84,6 +89,31 @@ def __repr__(self) -> str:
"""
return f"{self.__class__.__name__}"

def _get_vllm_guided_decoding_params(
self, guided_decoding_config: Optional[GuidedDecodingConfig]
) -> Optional["GuidedDecodingParams"]:
"""Get the guided decoding parameters for vLLM."""
from vllm.sampling_params import GuidedDecodingParams

if guided_decoding_config is None:
return None

match guided_decoding_config["mode"]:
case "json":
return GuidedDecodingParams(json=guided_decoding_config["json"])
case "regex":
return GuidedDecodingParams(regex=guided_decoding_config["regex"])
case "choice":
return GuidedDecodingParams(choice=guided_decoding_config["choice"])
case "grammar":
return GuidedDecodingParams(grammar=guided_decoding_config["grammar"])
case "json_object":
return GuidedDecodingParams(json_object=True)
case _:
raise ValueError(
f"Unsupported guided decoding mode: {guided_decoding_config['mode']}"
)

@staticmethod
def configure_worker(
num_gpus: int | float, bundle_indices: Optional[tuple[int, list[int]]] = None
Expand Down Expand Up @@ -415,6 +445,7 @@ def _build_sampling_params(
greedy: bool,
stop_strings,
max_new_tokens: Optional[int] = None,
guided_decoding_params: Optional["GuidedDecodingParams"] = None, # noqa: F821
Comment thread
ybgao-nvidia marked this conversation as resolved.
Outdated
):
top_k_cfg = self.cfg["top_k"]
top_k_val = 1 if greedy else (top_k_cfg if top_k_cfg is not None else -1)
Expand All @@ -434,16 +465,21 @@ def _build_sampling_params(
stop_token_ids=self.cfg["stop_token_ids"],
stop=stop_strings,
include_stop_str_in_output=True,
guided_decoding=guided_decoding_params,
)

def generate(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
Comment thread
ybgao-nvidia marked this conversation as resolved.
Outdated
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate a batch of data using vLLM generation.

Args:
data: BatchedDataDict containing input_ids and input_lengths tensors
greedy: Whether to use greedy decoding instead of sampling
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
BatchedDataDict conforming to GenerationOutputSpec:
Expand Down Expand Up @@ -471,6 +507,9 @@ def generate(
sampling_params = self._build_sampling_params(
greedy=greedy,
stop_strings=stop_strings,
guided_decoding_params=self._get_vllm_guided_decoding_params(
guided_decoding_config
),
)

# verify inputs have correct padding
Expand Down Expand Up @@ -578,12 +617,14 @@ async def generate_async(
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
Comment thread
ybgao-nvidia marked this conversation as resolved.
Outdated
) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]:
"""Generate a batch of data using vLLM's AsyncLLMEngine, yielding results as they are ready.

Args:
data: BatchedDataDict with input_ids and input_lengths
greedy: Whether to use greedy decoding instead of sampling
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Yields:
Tuple of (original_index, BatchedDataDict conforming to GenerationOutputSpec for the single sequence)
Expand Down Expand Up @@ -680,6 +721,9 @@ async def process_single_sample(sample_idx):
greedy=greedy,
stop_strings=final_stop_strings_for_sample,
max_new_tokens=allowed_new_tokens,
guided_decoding=self._get_vllm_guided_decoding_params(
guided_decoding_config
),
)

request_id = str(uuid.uuid4())
Expand Down Expand Up @@ -800,7 +844,10 @@ async def process_single_sample(sample_idx):
raise e

def generate_text(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_params: Optional["GuidedDecodingParams"] = None, # noqa: F821
Comment thread
ybgao-nvidia marked this conversation as resolved.
Outdated
Comment thread
ybgao-nvidia marked this conversation as resolved.
Outdated
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate text responses using vLLM generation.

Expand Down Expand Up @@ -1599,7 +1646,10 @@ def init_collective(
return futures

def generate(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate a batch of data using vLLM."""
assert isinstance(data, BatchedDataDict), (
Expand All @@ -1620,7 +1670,10 @@ def generate(
in_sharded_axes=["data_parallel"],
replicate_on_axes=None, # just run on tp rank 0
output_is_replicated=None,
common_kwargs={"greedy": greedy},
common_kwargs={
"greedy": greedy,
"guided_decoding_config": guided_decoding_config,
},
)

# Get results from the workers, respecting tied worker groups (only one result per tied worker group)
Expand Down
11 changes: 10 additions & 1 deletion nemo_rl/models/policy/lm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
GenerationDatumSpec,
GenerationInterface,
GenerationOutputSpec,
GuidedDecodingConfig,
)
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy.interfaces import (
Expand Down Expand Up @@ -419,9 +420,17 @@ def train(
return aggregated_results

def generate(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate a batch of data using the policy."""
# Guided decoding is currently only supported for vLLM backend
assert guided_decoding_config is None, (
"Guided decoding is currently only supported for vLLM backend"
)

# Verify input data is right-padded
assert isinstance(data, BatchedDataDict), (
f"data must be a BatchedDataDict, got type: {type(data)}"
Expand Down
Loading