Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions nemo_rl/experience/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
GenerationDatumSpec,
GenerationInterface,
GenerationOutputSpec,
GuidedDecodingConfig,
)
from nemo_rl.utils.timer import Timer

Expand All @@ -61,6 +62,7 @@ def generate_responses(
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]:
"""Generate responses from policy using synchronous generation."""
# Add stop_strings to generation_input_data if present in the batch
Expand All @@ -72,7 +74,9 @@ def generate_responses(

# Always use synchronous generation
generation_outputs = policy_generation.generate(
generation_input_data, greedy=greedy
generation_input_data,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)

# Extract everything we need from the generation outputs
Expand Down Expand Up @@ -125,6 +129,7 @@ async def generate_responses_async(
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]:
"""Async version of generate_responses that properly calls generate_async."""
# Add stop_strings to generation_input_data if present in the batch
Expand All @@ -151,7 +156,9 @@ async def generate_responses_async(
tuple[int, BatchedDataDict[GenerationOutputSpec]]
] = []
async for original_idx, single_item_output in policy_generation.generate_async(
generation_input_data, greedy=greedy
generation_input_data,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
):
collected_indexed_outputs.append((original_idx, single_item_output))

Expand Down Expand Up @@ -337,6 +344,7 @@ def run_multi_turn_rollout(
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
"""Runs a multi-turn rollout loop, interacting with the environment.

Expand All @@ -348,6 +356,7 @@ def run_multi_turn_rollout(
max_rollout_turns: Maximum number of agent-environment interaction turns.
max_seq_len: Maximum sequence length allowed.
greedy: Whether to use greedy decoding.
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple containing:
Expand Down Expand Up @@ -424,6 +433,7 @@ def run_multi_turn_rollout(
tokenizer,
input_lengths=active_input_lengths,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)

# Record token usage - assistant
Expand Down Expand Up @@ -548,6 +558,7 @@ async def async_generate_response_for_sample_turn(
tokenizer: TokenizerType,
max_seq_len: int,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]:
"""Generate a response for a single sample's turn using async generation.

Expand All @@ -558,6 +569,7 @@ async def async_generate_response_for_sample_turn(
tokenizer: Tokenizer to use
max_seq_len: Maximum sequence length
greedy: Whether to use greedy decoding
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics)
Expand Down Expand Up @@ -617,6 +629,7 @@ async def run_sample_multi_turn_rollout(
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[dict, dict[str, Any]]:
"""Run a multi-turn rollout for a single sample.

Expand All @@ -632,6 +645,7 @@ async def run_sample_multi_turn_rollout(
max_seq_len: Maximum sequence length
max_rollout_turns: Maximum number of turns
greedy: Whether to use greedy decoding
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple of (final_sample_state, sample_metrics)
Expand Down Expand Up @@ -677,6 +691,7 @@ async def run_sample_multi_turn_rollout(
tokenizer,
max_seq_len,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)
current_message_log = updated_message_log

Expand Down Expand Up @@ -785,6 +800,7 @@ def run_async_multi_turn_rollout(
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
"""Run multi-turn rollouts with sample-level processing.

Expand All @@ -799,6 +815,7 @@ def run_async_multi_turn_rollout(
max_seq_len: Maximum sequence length allowed
max_rollout_turns: Maximum number of agent-environment interaction turns
greedy: Whether to use greedy decoding
guided_decoding_config: Configuration for guided decoding, None to disable guided decoding.

Returns:
Tuple containing:
Expand Down Expand Up @@ -835,6 +852,7 @@ async def run_single_sample_with_error_handling(i, sample_state):
max_seq_len=max_seq_len,
max_rollout_turns=max_rollout_turns,
greedy=greedy,
guided_decoding_config=guided_decoding_config,
)
return result
except Exception as e:
Expand Down
32 changes: 30 additions & 2 deletions nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, NotRequired, TypedDict, Union
from typing import Any, NotRequired, Optional, TypedDict, Union

import ray
import torch
Expand Down Expand Up @@ -115,6 +115,30 @@ class ColocationConfig(TypedDict):
resources: OptionalResourcesConfig


class GuidedDecodingConfig(TypedDict):
    """Configuration for guided decoding.

    `mode`: The guided decoding mode, can be one of `json`, `regex`, `choice`, `grammar`, or `json_object`.

    For the selected mode, its corresponding field must be provided:

    - `json`: the output must be a JSON object matching the provided schema
    - `regex`: the output must match the provided regex
    - `choice`: the output must be one of the provided choices
    - `grammar`: the output must be a valid grammar
    - `json_object`: the output must be some JSON object (no corresponding
      field is defined for this mode)

    This class is intentionally similar to the GuidedDecodingParams class in vLLM,
    however, we do not want to inject that dependency here.
    """

    # Which constraint kind to apply; expected values per the docstring above.
    mode: str
    # JSON schema for `json` mode, either serialized as a string or as a dict.
    json: NotRequired[Union[str, dict]]
    # Regex pattern the output must match, for `regex` mode.
    regex: NotRequired[str]
    # Closed set of allowed outputs, for `choice` mode.
    choice: NotRequired[list[str]]
    # Grammar definition the output must conform to, for `grammar` mode.
    grammar: NotRequired[str]


class GenerationConfig(TypedDict):
"""Configuration for generation."""

Expand All @@ -127,6 +151,7 @@ class GenerationConfig(TypedDict):
stop_token_ids: list[int] | None
stop_strings: list[str] | None
colocated: NotRequired[ColocationConfig]
guided_decoding: NotRequired[GuidedDecodingConfig]
# This isn't meant to be passed by the user, but is populated by nemo_rl.models.generation.__init__.configure_generation_config
_pad_token_id: NotRequired[int]

Expand Down Expand Up @@ -224,7 +249,10 @@ def init_collective(

@abstractmethod
def generate(
self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool
self,
data: BatchedDataDict["GenerationDatumSpec"],
greedy: bool,
guided_decoding_config: Optional[GuidedDecodingConfig],
) -> BatchedDataDict["GenerationOutputSpec"]:
pass

Expand Down
49 changes: 41 additions & 8 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
import os
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Optional,
Union,
)

if TYPE_CHECKING:
from vllm.sampling_params import GuidedDecodingParams

import numpy as np
import ray
from ray.util.placement_group import PlacementGroup
Expand All @@ -34,6 +38,7 @@
GenerationDatumSpec,
GenerationInterface,
GenerationOutputSpec,
GuidedDecodingConfig,
)
from nemo_rl.models.generation.vllm.config import VllmConfig

Expand Down Expand Up @@ -421,7 +426,10 @@ def init_collective(
return futures

def generate(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate a batch of data using vLLM."""
assert isinstance(data, BatchedDataDict), (
Expand All @@ -442,7 +450,10 @@ def generate(
in_sharded_axes=["data_parallel"],
replicate_on_axes=None, # just run on tp rank 0
output_is_replicated=None,
common_kwargs={"greedy": greedy},
common_kwargs={
"greedy": greedy,
"guided_decoding_config": guided_decoding_config,
},
)

# Get results from the workers, respecting tied worker groups (only one result per tied worker group)
Expand All @@ -469,7 +480,10 @@ def generate(
return combined

def generate_text(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_params: Optional["GuidedDecodingParams"] = None,
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate text responses using vLLM."""
assert isinstance(data, BatchedDataDict), (
Expand All @@ -493,7 +507,10 @@ def generate_text(
in_sharded_axes=["data_parallel"],
replicate_on_axes=None, # just run on tp rank 0
output_is_replicated=None,
common_kwargs={"greedy": greedy},
common_kwargs={
"greedy": greedy,
"guided_decoding_params": guided_decoding_params,
},
)

# Get results from the workers, respecting tied worker groups (only one result per tied worker group)
Expand All @@ -520,6 +537,7 @@ async def _async_generate_base(
method_name: str,
data_validation_fn,
greedy: bool = False,
**kwargs,
) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]:
"""Base async generation method that handles common worker management logic.

Expand Down Expand Up @@ -556,6 +574,7 @@ async def _async_generate_base(
worker_idx=leader_worker_idx,
data=data,
greedy=greedy,
**kwargs,
)

# Increment the round-robin worker group index
Expand Down Expand Up @@ -643,7 +662,10 @@ async def consume_worker_generator(worker_idx, worker_gen):
)

async def generate_text_async(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_params: Optional["GuidedDecodingParams"] = None,
) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]:
"""Generate text responses asynchronously, yielding results as they are ready.

Expand All @@ -661,12 +683,19 @@ def validate_text_data(data):
return True

async for result in self._async_generate_base(
data, "generate_text_async", validate_text_data, greedy
data,
"generate_text_async",
validate_text_data,
greedy,
guided_decoding_params=guided_decoding_params,
):
yield result

async def generate_async(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
self,
data: BatchedDataDict[GenerationDatumSpec],
greedy: bool = False,
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
) -> AsyncGenerator[tuple[int, BatchedDataDict[GenerationOutputSpec]], None]:
"""Generate responses asynchronously, yielding individual samples as they complete.

Expand All @@ -684,7 +713,11 @@ def validate_generate_data(data):
return True

async for result in self._async_generate_base(
data, "generate_async", validate_generate_data, greedy
data,
"generate_async",
validate_generate_data,
greedy,
guided_decoding_config=guided_decoding_config,
):
yield result

Expand Down
Loading
Loading