From d0f7c8c452f572a11b60123f7a4ca13eba826491 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 2025 14:02:06 -0700 Subject: [PATCH 01/34] Multiturn integrated Signed-off-by: Sahil Jain --- .gitignore | 3 +- examples/configs/grpo_math_1B.yaml | 1 + nemo_reinforcer/algorithms/grpo.py | 167 +---- nemo_reinforcer/data/interfaces.py | 1 + .../distributed/batched_data_dict.py | 36 +- nemo_reinforcer/environments/interfaces.py | 25 +- .../environments/math_environment.py | 31 +- nemo_reinforcer/experience/rollouts.py | 420 ++++++++++++ .../models/generation/interfaces.py | 4 +- nemo_reinforcer/models/generation/vllm.py | 54 +- .../models/policy/fsdp1_policy_worker.py | 16 +- tests/unit/algorithms/test_grpo.py | 46 +- tests/unit/environments/game_interface.py | 62 ++ .../unit/environments/sliding_puzzle_game.py | 242 +++++++ .../environments/test_math_environment.py | 98 +-- tests/unit/experience/test_rollouts.py | 626 ++++++++++++++++++ tests/unit/test_envs.py | 409 ++++++++++++ 17 files changed, 2011 insertions(+), 230 deletions(-) create mode 100644 nemo_reinforcer/experience/rollouts.py create mode 100644 tests/unit/environments/game_interface.py create mode 100644 tests/unit/environments/sliding_puzzle_game.py create mode 100644 tests/unit/experience/test_rollouts.py create mode 100644 tests/unit/test_envs.py diff --git a/.gitignore b/.gitignore index 19f2edc5b8..478990ddc8 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,8 @@ dist/ *.vscode/ # Test -.coverage +coverage.json +.coverage* test_assets/ # Cache diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4cf474df01..7149965324 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -32,6 +32,7 @@ policy: generation_batch_size: 32 # Only used when generating using HF backend logprob_batch_size: 4 max_total_sequence_length: 512 + max_turns: 999999 precision: "bfloat16" fsdp_offload_enabled: false activation_checkpointing_enabled: false diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 1914b27e98..4116ae11a6 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -24,7 +24,10 @@ from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.algorithms.utils import calculate_baseline_and_std_per_prompt -from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.data.interfaces import ( DatumSpec, @@ -59,6 +62,7 @@ from nemo_reinforcer.utils.logger import Logger, LoggerConfig from nemo_reinforcer.utils.timer import Timer from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig +from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout # =============================================================================== @@ -73,6 +77,7 @@ class GRPOConfig(TypedDict): normalize_rewards: bool use_leave_one_out_baseline: bool val_period: int + val_batch_size: int val_at_start: bool checkpoint_dir: str @@ -94,7 +99,7 @@ def _default_grpo_save_state() -> GRPOSaveState: class MasterConfig(TypedDict): policy: PolicyConfig loss_fn: ClippedPGLossConfig - math_env: MathEnvConfig + env_configs: Dict[str, Any] data: DataConfig grpo: GRPOConfig logger: LoggerConfig @@ -283,120 +288,6 @@ def refit_policy_generation( 
policy.offload_after_refit() -def generate_responses( - policy_generation: GenerationInterface, - generation_input_data: BatchedDataDict[GenerationDatumSpec], - batch: BatchedDataDict[DatumSpec], - tokenizer, - input_lengths: torch.Tensor, - include_logprobs: bool = True, -) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]: - """Generate responses from policy.""" - # Generate responses - generation_outputs = policy_generation.generate(generation_input_data) - - # Extract generated tokens - generated_ids = [] - unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"] - for output_ids, input_length, total_length in zip( - generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths - ): - generated_ids.append(output_ids[input_length:total_length]) - - generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - - # Append to message log - for i, (text, input_length, total_length) in enumerate( - zip(generated_texts, input_lengths, unpadded_sequence_lengths) - ): - message = { - "role": "assistant", - "content": text, - "token_ids": generation_outputs["output_ids"][i, input_length:total_length], - } - - if include_logprobs and "logprobs" in generation_outputs: - message["generation_logprobs"] = generation_outputs["logprobs"][ - i, input_length:total_length - ] - - batch["message_log"][i].append(message) - - metrics = { - "mean_generation_length": ( - torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths) - ).item() - / len(unpadded_sequence_lengths), - "max_seqlen": torch.max(unpadded_sequence_lengths).item(), - } - - return batch, generated_ids, metrics - - -def calculate_rewards( - batch: BatchedDataDict[DatumSpec], - task_to_env: Dict[str, EnvironmentInterface], -) -> Tuple[torch.Tensor, List[LLMMessageLogType]]: - """Calculate rewards for generated responses. 
- - Args: - batch: Batch containing message_log (LLMMessageLogType) with generated responses - task_to_env: Dictionary mapping task names to their corresponding environments - - Returns: - rewards: Tensor of rewards - to_env: Simplified message logs sent to environment (LLMMessageLogType format) - """ - # Extract message logs for environment - to_env = [ - get_keys_from_message_log(batch["message_log"][i], ["role", "content"]) - for i in range(len(batch["message_log"])) - ] - task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))] - - # Group messages by task type - task_groups = {} - for i, task_name in enumerate(task_names): - if task_name not in task_groups: - task_groups[task_name] = [] - task_groups[task_name].append((i, to_env[i])) - - # Calculate rewards for each task group concurrently - futures = [] - future_to_indices = {} # Map future to its corresponding indices - for task_name, group in task_groups.items(): - if task_name not in task_to_env: - raise ValueError(f"No environment found for task type: {task_name}") - - # Extract indices and messages for this group - indices = [idx for idx, _ in group] - messages = [msg for _, msg in group] - - # Get corresponding environment info - env_info = [batch["extra_env_info"][i] for i in indices] - - # Submit task to environment and store future - future = task_to_env[task_name].step.remote(messages, env_info) - futures.append(future) - future_to_indices[future] = indices - - results = ray.get(futures) - all_rewards = [] - for future, result in zip(futures, results): - indices = future_to_indices[future] - _, _, task_rewards, _ = result - - # Store results with their original indices - for idx, reward in zip(indices, task_rewards): - all_rewards.append((idx, reward)) - - # Sort results by original index to maintain order - all_rewards.sort(key=lambda x: x[0]) - rewards = torch.tensor([reward for _, reward in all_rewards]) - - return rewards, to_env - - # =============================================================================== # Training & Validation # =============================================================================== @@ -463,7 +354,7 @@ def grpo_train( print("▶ Preparing batch...") with timer.time("data_processing"): # Repeat batch items - repeated_batch = batch.repeat_interleave( + repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave( master_config["grpo"]["num_generations_per_prompt"] ) # Convert LLMMessageLogType to FlatMessagesType for generation @@ -472,36 +363,40 @@ def grpo_train( pad_value_dict={"token_ids": tokenizer.pad_token_id}, ) input_ids = batched_flat["token_ids"] - # Create generation-specific input structure - generation_input_data = BatchedDataDict[GenerationDatumSpec]( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - } - ) + # # Create generation-specific input structure + # generation_input_data = BatchedDataDict[GenerationDatumSpec]( + # { + # "input_ids": input_ids, + # "input_lengths": input_lengths, + # } + # ) # Generate responses - this updates the LLMMessageLogType in repeated_batch - print(f"▶ Generating responses for batch of size {len(input_ids)}...") + print(f"▶ Generating responses for batch of size {repeated_batch.size}...") with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation(policy, policy_generation) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() + with timer.time("generation"): - repeated_batch, _, gen_metrics = generate_responses( - 
policy_generation, - generation_input_data, - repeated_batch, - tokenizer, - input_lengths, + repeated_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=policy_generation, + initial_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_turns=master_config["policy"]["max_turns"], + greedy=False, ) policy_generation.finish_generation() - # Calculate rewards & advantages based on the updated LLMMessageLogType - print("▶ Calculating rewards...") + # Calculate rewards & advantages + print("▶ Processing rewards...") with timer.time("reward_calculation"): - rewards, _ = calculate_rewards(repeated_batch, task_to_env) + # Extract rewards from final_batch + rewards = repeated_batch["total_reward"] print("▶ Computing advantages...") baseline, std = calculate_baseline_and_std_per_prompt( @@ -665,14 +560,14 @@ def grpo_train( metrics[k] = np.sum(v).item() else: metrics[k] = np.mean(v).item() - metrics.update(gen_metrics) + metrics.update(rollout_metrics) timing_metrics = timer.get_timing_metrics(reduction_op="sum") print(f" • Loss: {metrics['loss']:.4f}") print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print( - f" • Mean Generation Length: {gen_metrics['mean_generation_length']:.4f}" + f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}" ) print("\n⏱️ Timing:") diff --git a/nemo_reinforcer/data/interfaces.py b/nemo_reinforcer/data/interfaces.py index 6ae1152be7..33a44d555d 100644 --- a/nemo_reinforcer/data/interfaces.py +++ b/nemo_reinforcer/data/interfaces.py @@ -32,6 +32,7 @@ class DatumSpec(TypedDict): loss_multiplier: float # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid) idx: int task_name: Optional[str] = "default" + stop_strings: Optional[List[str]] = None # Optional stop strings for generation __extra__: Any # This allows additional fields of any type diff --git a/nemo_reinforcer/distributed/batched_data_dict.py b/nemo_reinforcer/distributed/batched_data_dict.py index a1711ae8c2..a325dcff49 100644 --- a/nemo_reinforcer/distributed/batched_data_dict.py +++ b/nemo_reinforcer/distributed/batched_data_dict.py @@ -13,7 +13,7 @@ # limitations under the License. from copy import deepcopy from collections import UserDict -from typing import List, Dict, Optional, Iterator, TypeVar, Any, Generic +from typing import List, Dict, Optional, Iterator, TypeVar, Any, Generic, Union from typing_extensions import Self import torch @@ -275,13 +275,37 @@ def size(self) -> int: return len(self.data[key]) return self.data[key].shape[0] - def to(self, device: torch.device) -> "BatchedDataDict": - """Move all tensors in the batch to a specific device.""" - for k in self.data: - if torch.is_tensor(self.data[k]): - self.data[k] = self.data[k].to(device) + def to(self, device: torch.device) -> Self: + """Move tensors in batched dict to device.""" + for k, v in self.data.items(): + if torch.is_tensor(v): + self.data[k] = v.to(device) return self + def select_indices( + self, indices: Union[List[int], torch.Tensor] + ) -> "BatchedDataDict": + """Selects specific rows from the batch based on indices. + + Args: + indices: A list or tensor of integer indices to select. + + Returns: + BatchedDataDict: A new BatchedDataDict containing only the selected rows. 
+ """ + selected_batch = BatchedDataDict() + for k, v in self.data.items(): + if torch.is_tensor(v): + selected_batch[k] = v[indices] + elif isinstance(v, list): + selected_batch[k] = [v[i] for i in indices] + else: + # Handle other potential types if necessary, or raise error + raise TypeError( + f"Unsupported type {type(v)} for index selection in BatchedDataDict" + ) + return selected_batch + def get_dict(self) -> dict: """Get the underlying data dictionary.""" return self.data diff --git a/nemo_reinforcer/environments/interfaces.py b/nemo_reinforcer/environments/interfaces.py index 40986f4f19..881f3467b4 100644 --- a/nemo_reinforcer/environments/interfaces.py +++ b/nemo_reinforcer/environments/interfaces.py @@ -12,27 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, NamedTuple, Optional from torch import Tensor from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.data.interfaces import LLMMessageLogType -EnvironmentReturn = Tuple[List[List[Dict[str, str]]], List[Dict], Tensor, Tensor] + +class EnvironmentReturn(NamedTuple): + """Standard return type for environment step methods.""" + + observations: List[Dict[str, str]] + metadata: List[Optional[dict]] + next_stop_strings: List[Optional[List[str]]] + rewards: Tensor + terminated: Tensor class EnvironmentInterface(abc.ABC): @abc.abstractmethod def step( self, - message_log_batch: List[List[Dict[str, str]]], - metadata: List[Dict], + message_log_batch: List[LLMMessageLogType], + metadata: List[Optional[dict]], *args, **kwargs, ) -> EnvironmentReturn: """Runs a step in the environment. Allows for asynchrony with remote servers, but it's not required (this function is a ray remote). message_log_batch: batch of OpenAI-API-like message logs that represent interactions with the LLM. + Each element is a List[Dict[str, Union[str, torch.Tensor]]]. For example, if this were a Math Environment, then the message log would be [ @@ -48,13 +58,10 @@ def step( {"role": "assistant", "content": "model response"}, ] metadata: batch of whatever the environment needs to keep track of. I.e. - math solutions, code unit tests, or agent states. + math solutions, code unit tests, or agent states. Can be None if episode terminated. Returns: - - List[Dict[str, str]]: An observation/response batch in an OpenAI-API-like message format that is the result of the step. - - List[Dict]: An updated batch of metadata. - - Tensor: A tensor of rewards. - - Tensor: A tensor of done flags. + - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminated flags. """ @abc.abstractmethod diff --git a/nemo_reinforcer/environments/math_environment.py b/nemo_reinforcer/environments/math_environment.py index 65a5cc0e27..b72c031e29 100644 --- a/nemo_reinforcer/environments/math_environment.py +++ b/nemo_reinforcer/environments/math_environment.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from itertools import tee
-from typing import Dict, List, Tuple, TypedDict
+from typing import Dict, List, Tuple, TypedDict, Optional

import ray
import torch
from math_verify import parse, verify

from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
-from nemo_reinforcer.environments.interfaces import EnvironmentInterface
+from nemo_reinforcer.environments.interfaces import (
+    EnvironmentInterface,
+    EnvironmentReturn,
+)
from nemo_reinforcer.environments.metrics import (
    calculate_pass_rate_per_prompt,
)
@@ -29,6 +32,7 @@

class MathEnvConfig(TypedDict):
    num_workers: int
+    stop_strings: Optional[List[str]] = None  # Default stop strings for this env


@ray.remote
@@ -66,7 +70,8 @@ class MathEnvironmentMetadata(TypedDict):
class MathEnvironment(EnvironmentInterface):
    DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM

-    def __init__(self, cfg: Dict):
+    def __init__(self, cfg: MathEnvConfig):
+        self.cfg = cfg
        self.num_workers = cfg["num_workers"]
        self.workers = [
            HFVerifyWorker.options(
@@ -84,7 +89,7 @@ def step(
        self,
        message_log_batch: List[List[Dict[str, str]]],
        metadata: List[MathEnvironmentMetadata],
-    ):
+    ) -> EnvironmentReturn:
        """Runs a step in the math environment.

        Args:
@@ -95,6 +100,7 @@
            EnvironmentReturn: A tuple containing:
                - List[Dict[str, str]]: Observations/responses batch
                - List[Dict]: Updated metadata
+                - List[Optional[List[str]]]: Stop strings for the next turn (one entry per sample)
                - Tensor: Rewards tensor
                - Tensor: Done flags tensor
        """
@@ -129,7 +135,12 @@
        # flatten the results
        results = [item for sublist in results for item in sublist]
        observations = [
-            {"role": "user", "content": "correct" if result else "incorrect"}
+            {
+                "role": "environment",
+                "content": "Environment: correct"
+                if result
+                else "Environment: incorrect",
+            }
            for result in results
        ]

@@ -137,7 +148,17 @@
        rewards = torch.tensor(results).cpu()
        done = torch.ones_like(rewards).cpu()

-        return observations, metadata, rewards, done
+        # One entry per sample: calculate_rewards() in rollouts.py indexes this
+        # list per sample, so it must be a list (not None) even though every
+        # math episode terminates after a single turn.
+        next_stop_strings = [None] * len(results)
+
+        return EnvironmentReturn(
+            observations=observations,
+            metadata=metadata,
+            next_stop_strings=next_stop_strings,
+            rewards=rewards,
+            terminated=done,
+        )

    def global_post_process_and_metrics(
        self, batch: BatchedDataDict
diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py
new file mode 100644
index 0000000000..d402bdff4c
--- /dev/null
+++ b/nemo_reinforcer/experience/rollouts.py
@@ -0,0 +1,420 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
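
Before the rollout module itself, a quick sketch (not part of this patch) of the BatchedDataDict.select_indices helper added above, which the rollout loop below uses each turn to narrow the batch to still-active samples; the values are illustrative only:

    import torch

    from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict

    batch = BatchedDataDict(
        {
            "token_ids": torch.arange(8).reshape(4, 2),  # tensor field: fancy-indexed
            "task_name": ["math", "math", "game", "math"],  # list field: copied per index
        }
    )
    active = batch.select_indices([0, 2])  # keep rows 0 and 2
    assert torch.equal(active["token_ids"], torch.tensor([[0, 1], [4, 5]]))
    assert active["task_name"] == ["math", "game"]
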
+ +# Generate rollouts + +import torch +from typing import List, Tuple, Dict, Optional, Any, NamedTuple +from transformers import AutoTokenizer +import ray + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.data.interfaces import ( + DatumSpec, + LLMMessageLogType, + FlatMessagesType, +) +from nemo_reinforcer.data.llm_message_utils import ( + get_keys_from_message_log, + batched_message_log_to_flat_message, +) +from nemo_reinforcer.models.generation.interfaces import ( + GenerationInterface, + GenerationDatumSpec, +) +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) + + +# Return type for calculate_rewards +class RewardsOutput(NamedTuple): + rewards: torch.Tensor + env_observations: List[Dict[str, str]] + terminateds: torch.Tensor + next_stop_strings: List[Optional[List[str]]] + metadata: List[Optional[Dict[str, Any]]] + + +def generate_responses( + policy_generation: GenerationInterface, + generation_input_data: BatchedDataDict[GenerationDatumSpec], + batch: BatchedDataDict[DatumSpec], + tokenizer: AutoTokenizer, + input_lengths: torch.Tensor, + include_logprobs: bool = True, + greedy: bool = False, +) -> Tuple[BatchedDataDict[DatumSpec], List[torch.Tensor], dict]: + """Generate responses from policy.""" + # Add stop_strings to generation_input_data if present in the batch + if "stop_strings" in batch: + generation_input_data["stop_strings"] = batch["stop_strings"] + else: + # Ensure the key exists even if it's None, matching GenerationDatumSpec + generation_input_data["stop_strings"] = [None] * len(input_lengths) + + # Generate responses + generation_outputs = policy_generation.generate( + generation_input_data, greedy=greedy + ) + + # Extract generated tokens + generated_ids = [] + unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"] + for output_ids, input_length, total_length in zip( + generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths + ): + generated_ids.append(output_ids[input_length:total_length]) + + generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # Append to message log + for i, (text, input_length, total_length) in enumerate( + zip(generated_texts, input_lengths, unpadded_sequence_lengths) + ): + message = { + "role": "assistant", + "content": text, + "token_ids": generation_outputs["output_ids"][i, input_length:total_length], + } + + if include_logprobs and "logprobs" in generation_outputs: + message["generation_logprobs"] = generation_outputs["logprobs"][ + i, input_length:total_length + ] + + batch["message_log"][i].append(message) + + metrics = { + "mean_generation_length": ( + torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths) + ).item() + / len(unpadded_sequence_lengths), + "max_seqlen": torch.max(unpadded_sequence_lengths).item(), + } + + return batch, generated_ids, metrics + + +def calculate_rewards( + batch: BatchedDataDict[DatumSpec], + task_to_env: Dict[str, EnvironmentInterface], +) -> RewardsOutput: + """Calculate rewards for generated responses and get environment feedback. + + Args: + batch: Batch containing message_log (LLMMessageLogType) with generated responses + task_to_env: Dictionary mapping task names to their corresponding environments + + Returns: + Tuple containing: + - rewards: Tensor of rewards for the last turn. + - env_observations: List of observations from the environment for the next turn. 
+ - terminateds: Tensor of booleans indicating if an episode ended naturally. + - next_stop_strings: List of stop strings for the next generation step. + - metadata: List of extracted metadata from the environment. + """ + # Extract message logs for environment (most recent interaction) + to_env = [ + get_keys_from_message_log(batch["message_log"][i], ["role", "content"]) + for i in range(len(batch["message_log"])) + ] + task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))] + + # Group messages by task type + task_groups = {} + for i, task_name in enumerate(task_names): + if task_name not in task_groups: + task_groups[task_name] = [] + task_groups[task_name].append((i, to_env[i])) + + # Calculate rewards for each task group concurrently + futures = [] + future_to_indices = {} # Map future to its corresponding indices + for task_name, group in task_groups.items(): + if task_name not in task_to_env: + raise ValueError(f"No environment found for task type: {task_name}") + + # Extract indices and messages for this group + indices = [idx for idx, _ in group] + messages = [msg for _, msg in group] + + # Get corresponding environment info + env_info = [batch["extra_env_info"][i] for i in indices] + + # Submit task to environment and store future + future = task_to_env[task_name].step.remote(messages, env_info) + futures.append(future) + future_to_indices[future] = indices + + results = ray.get(futures) + all_rewards = [] + all_env_observations = [] + all_terminateds = [] + all_next_stop_strings = [] + all_metadata = [] # Store extracted metadata + all_indices_order = [] + + for future, result in zip(futures, results): + indices = future_to_indices[future] + # Environment step returns: EnvironmentReturn + env_observations, metadata, next_stop_strings, task_rewards, terminateds = ( + result + ) + + # Store results with their original indices + for i, idx in enumerate(indices): + all_indices_order.append(idx) + all_rewards.append(task_rewards[i]) + all_env_observations.append(env_observations[i]) + all_terminateds.append(terminateds[i]) + all_next_stop_strings.append(next_stop_strings[i]) + all_metadata.append(metadata[i]) + + # Sort results by original index to maintain order + sorted_indices = sorted( + range(len(all_indices_order)), key=lambda k: all_indices_order[k] + ) + rewards = torch.tensor([all_rewards[i] for i in sorted_indices]) + env_observations = [all_env_observations[i] for i in sorted_indices] + terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices]) + next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] + metadata = [all_metadata[i] for i in sorted_indices] # Sort metadata + + # Ensure tensors are on CPU + rewards = rewards.cpu() + terminateds = terminateds.cpu() + + return RewardsOutput( + rewards=rewards, + env_observations=env_observations, + terminateds=terminateds, + next_stop_strings=next_stop_strings, + metadata=metadata, + ) + + +def run_multi_turn_rollout( + policy_generation: GenerationInterface, + initial_batch: BatchedDataDict[DatumSpec], + tokenizer: AutoTokenizer, + task_to_env: Dict[str, EnvironmentInterface], + max_seq_len: int, + max_turns: int = 999999, + greedy: bool = False, +) -> Tuple[BatchedDataDict[DatumSpec], Dict[str, Any]]: + """Runs a multi-turn rollout loop, interacting with the environment. + + Args: + policy_generation: The generation interface (policy). + initial_batch: The starting batch containing initial message logs. + tokenizer: The tokenizer. 
+ task_to_env: Dictionary mapping task names to environment instances. + max_turns: Maximum number of agent-environment interaction turns. + max_seq_len: Maximum sequence length allowed. + greedy: Whether to use greedy decoding. + + Returns: + Tuple containing: + - BatchedDataDict with the full interaction history and accumulated rewards + - Dictionary of rollout metrics + """ + current_batch = initial_batch.copy() # Work on a copy + batch_size = len(current_batch["message_log"]) + active_indices = torch.arange(batch_size) + turn_rewards = torch.zeros(batch_size, dtype=torch.float32) + total_rewards = torch.zeros(batch_size, dtype=torch.float32) + + # Initialize stop_strings from the initial batch if present + current_stop_strings = current_batch.get("stop_strings", [None] * batch_size) + + # Tracking metrics for each sample + sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_assistant_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_env_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_terminated = torch.zeros(batch_size, dtype=torch.bool) + sample_truncated = torch.zeros(batch_size, dtype=torch.bool) + sample_max_turns_reached = torch.zeros(batch_size, dtype=torch.bool) + + # Tracking per-turn metrics + total_gen_tokens_per_turn = [] + reward_per_turn = [] + active_samples_per_turn = [] + + for turn in range(max_turns): + if len(active_indices) == 0: + print(f" Turn {turn + 1}/{max_turns}: All samples finished.") + break + + print( + f" Turn {turn + 1}/{max_turns}: Processing {len(active_indices)} active samples..." + ) + + active_samples_per_turn.append(len(active_indices)) + + # Convert LLMMessageLogType to FlatMessagesType for generation + active_batch = current_batch.select_indices(active_indices) + active_stop_strings = [current_stop_strings[i] for i in active_indices.tolist()] + + active_flat_messages: FlatMessagesType + active_flat_messages, active_input_lengths = ( + batched_message_log_to_flat_message( + active_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + ) + + # Extract input_ids and lengths from the flat messages + active_input_ids = active_flat_messages["token_ids"] + + generation_input_data = BatchedDataDict[GenerationDatumSpec]( + { + "input_ids": active_input_ids, + "input_lengths": active_input_lengths, + "stop_strings": active_stop_strings, + } + ) + + # generate_responses updates active_batch["message_log"] in-place + active_batch, generated_ids, gen_metrics = generate_responses( + policy_generation, + generation_input_data, + active_batch, + tokenizer, + active_input_lengths, + greedy=greedy, + ) + print( + f" Generated responses (Avg len: {gen_metrics['mean_generation_length']:.1f})" + ) + + # Record token usage - assistant + for i, global_idx in enumerate(active_indices.tolist()): + sample_assistant_token_counts[global_idx] += len(generated_ids[i]) + sample_token_counts[global_idx] += len(generated_ids[i]) + + # Track total generated tokens this turn + total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids)) + + # Calculate rewards and get environment feedback + env_output: RewardsOutput = calculate_rewards(active_batch, task_to_env) + + turn_rewards[active_indices] = env_output.rewards + total_rewards[active_indices] += turn_rewards[active_indices] + + # Record rewards for this turn + reward_per_turn.append(env_output.rewards.mean().item()) + + print( + f" Calculated rewards (Avg: 
{turn_rewards[active_indices].mean():.3f})" + ) + + # Update message log for ALL active samples with env observation + # This must happen BEFORE filtering based on done flags + truncation_mask = torch.zeros_like(env_output.terminateds) + for i, global_idx in enumerate(active_indices.tolist()): + env_obs_content = env_output.env_observations[i]["content"] + # Tokenize the raw content from the environment + # TODO @sahilj: handle if we want these subsequent messages to have a chat template + tokenized_obs = tokenizer( + env_obs_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + + # check if new message overflows max_seq_len + if len(tokenized_obs) + active_input_lengths[i] > max_seq_len: + # truncate + tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]] + truncation_mask[i] = True + # Record truncation + sample_truncated[active_indices[i]] = True + + tokenized_env_obs_message = { + "role": env_output.env_observations[i]["role"], + "content": env_obs_content, + "token_ids": tokenized_obs, + } + current_batch["message_log"][global_idx].append(tokenized_env_obs_message) + + # Record token usage - environment + sample_env_token_counts[global_idx] += len(tokenized_obs) + sample_token_counts[global_idx] += len(tokenized_obs) + + # Increment turn count + sample_turn_counts[global_idx] += 1 + + # Determine done samples and update active set + done = env_output.terminateds | truncation_mask + active_mask = ~done + + # Identify samples that just finished this turn + newly_finished_indices_local = torch.where(done)[0] + newly_finished_indices_global = active_indices[newly_finished_indices_local] + + # Record termination status + for i, idx in enumerate(newly_finished_indices_local.tolist()): + global_idx = active_indices[idx].item() + # Record whether this sample terminated naturally + if env_output.terminateds[idx]: + sample_terminated[global_idx] = True + + print( + f" {len(newly_finished_indices_global)} samples finished this turn." 
+ f" (Terminated: {env_output.terminateds.sum()})" + ) + + # Update active indices for the next iteration + active_indices_local_next = torch.where(active_mask)[0] + active_indices = active_indices[active_indices_local_next] + continuing_indices_global = active_indices # Indices relative to original batch + # Get next stop strings and infos corresponding to the indices that are *continuing* + continuing_next_stops = [ + env_output.next_stop_strings[i] for i in active_indices_local_next.tolist() + ] + # Get metadata corresponding to continuing indices, using the correct field name + continuing_metadata = [ + env_output.metadata[i] for i in active_indices_local_next.tolist() + ] + + for i, global_idx in enumerate(continuing_indices_global.tolist()): + # Update stop strings for the next turn + current_stop_strings[global_idx] = continuing_next_stops[i] + # Update metadata (extra_env_info) using info from environment + if continuing_metadata[i] is not None: + current_batch["extra_env_info"][global_idx] = continuing_metadata[i] + + # Record samples that reached max turns + if len(active_indices) > 0: + sample_max_turns_reached[active_indices] = True + + # Add total rewards to the final batch + current_batch["total_reward"] = total_rewards + + # Calculate aggregate metrics + rollout_metrics = { + # Overall metrics + "total_turns": int(sample_turn_counts.sum().item()), + "avg_turns_per_sample": float(sample_turn_counts.float().mean().item()), + "max_turns_per_sample": int(sample_turn_counts.max().item()), + "natural_termination_rate": float(sample_terminated.float().mean().item()), + "truncation_rate": float(sample_truncated.float().mean().item()), + "max_turns_reached_rate": float(sample_max_turns_reached.float().mean().item()), + # Token usage metrics + "mean_gen_tokens_per_sample": float(sample_token_counts.float().mean().item()), + "mean_env_tokens_per_sample": float( + sample_env_token_counts.float().mean().item() + ), + } + return current_batch, rollout_metrics diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py index 468714899f..48ed8554d8 100644 --- a/nemo_reinforcer/models/generation/interfaces.py +++ b/nemo_reinforcer/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, TypedDict, Union, Tuple, List +from typing import Any, TypedDict, Union, Tuple, List, Optional import torch from transformers import AutoTokenizer @@ -139,6 +139,7 @@ class GenerationDatumSpec(TypedDict): - input_ids: Tensor of token IDs representing the input sequences (right padded) - input_lengths: Tensor containing the actual length of each sequence (without padding) + - stop_strings: Optional list of strings to stop generation (per sample) - __extra__: Additional model-specific data fields Example of a batch with 4 entries with different sequence lengths: @@ -163,6 +164,7 @@ class GenerationDatumSpec(TypedDict): input_ids: torch.Tensor input_lengths: torch.Tensor + stop_strings: Optional[List[str]] = None __extra__: Any diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index ada0bf2623..57ecde59ed 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -224,24 +224,25 @@ def generate( - generation_lengths: Lengths of each response - unpadded_sequence_lengths: Lengths of each input + generated sequence """ - # Verify input is right padded - assert isinstance(data, BatchedDataDict), ( - f"data must be a BatchedDataDict, got type: {type(data)}" - ) - assert "input_ids" in data and "input_lengths" in data, ( - f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" - ) - is_right_padded, error_msg = verify_right_padding( - data, pad_value=self.cfg["pad_token_id"] - ) - if not is_right_padded: - warnings.warn( - f"Input to vLLM worker is not properly right-padded: {error_msg}" - ) - - # Convert inputs to vLLM format input_ids = data["input_ids"] input_lengths = data["input_lengths"] + # this function requires all generations have the same stop strings, so we collect all here + batch_stop_strings = data.get("stop_strings", []) + stop_strings = set() + for sample_stop_strings in batch_stop_strings: + if sample_stop_strings: + stop_strings.update(sample_stop_strings) + + # Add default stop strings from config + if self.cfg.get("stop_strings", None): + stop_strings.update(self.cfg["stop_strings"]) + + stop_strings = list(stop_strings) + + # verify inputs have correct padding + verify_right_padding(data, pad_value=self.cfg["pad_token_id"]) + + # Convert inputs to vLLM format batch_size = input_ids.shape[0] # Original input length with padding padded_input_length = input_ids.size(1) @@ -269,7 +270,7 @@ def generate( max_tokens=self.cfg["max_new_tokens"], logprobs=0, # Return logprobs for the generated tokens stop_token_ids=self.cfg["stop_token_ids"], - stop=self.cfg["stop_strings"], + stop=stop_strings, include_stop_str_in_output=True, # returning stop strings like hf ) @@ -359,6 +360,23 @@ def generate_text( BatchedDataDict containing: - texts: List of generated text responses """ + # Extract stop_strings if provided, else use default from config + batch_stop_strings = data.get( + "stop_strings", [self.cfg.get("stop_strings")] * len(data["prompts"]) + ) + + # This function requires all generations have the same stop strings, so we collect all here + stop_strings = set() + for sample_stop_strings in batch_stop_strings: + if sample_stop_strings: + stop_strings.update(sample_stop_strings) + + # Add default stop strings from config + if self.cfg.get("stop_strings", None): + stop_strings.update(self.cfg["stop_strings"]) + + stop_strings = list(stop_strings) if len(stop_strings) > 0 else None + # Read 
generation parameters from config top_k = self.cfg["top_k"] if self.cfg["top_k"] is not None else -1 sampling_params = self.SamplingParams( @@ -367,7 +385,7 @@ def generate_text( top_k=top_k if not greedy else 1, max_tokens=self.cfg["max_new_tokens"], stop_token_ids=self.cfg["stop_token_ids"], - stop=self.cfg["stop_strings"], + stop=stop_strings, include_stop_str_in_output=True, # returning stop strings like hf ) diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index 192d51ce88..37c385d44b 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -517,6 +517,20 @@ def generate( # Set attention mask for the actual tokens (at the end for left padding) left_padded_attention_mask[i, seq_len - length :] = 1 + # this function requires all generations have the same stop strings, so we collect all here + batch_stop_strings = gen_batch.get("stop_strings", []) + stop_strings = set() + for sample_stop_strings in batch_stop_strings: + if sample_stop_strings: + stop_strings.update(sample_stop_strings) + + # Add default stop strings from config + if gen_cfg.get("stop_strings", None): + stop_strings.update(gen_cfg["stop_strings"]) + + stop_strings = list(stop_strings) if len(stop_strings) > 0 else None + print(f"Stop strings: {stop_strings}") + if isinstance( self.model, torch.distributed.fsdp.FullyShardedDataParallel ): @@ -533,7 +547,7 @@ def generate( top_k=gen_cfg["top_k"], pad_token_id=gen_cfg["pad_token_id"], eos_token_id=gen_cfg["stop_token_ids"], - stop_strings=gen_cfg["stop_strings"], + stop_strings=stop_strings, tokenizer=self.tokenizer, # needs for stop_strings return_dict_in_generate=True, output_scores=True, diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index c6491e02d6..7a492e2b01 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -16,10 +16,13 @@ import ray from typing import Dict, List, Tuple -from nemo_reinforcer.algorithms.grpo import calculate_rewards +from nemo_reinforcer.experience.rollouts import calculate_rewards from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType -from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) @ray.remote(num_cpus=0) @@ -30,9 +33,15 @@ def __init__(self, rewards: List[float]): def step( self, messages: List[LLMMessageLogType], env_info: List[dict] - ) -> Tuple[None, None, List[float], None]: + ) -> EnvironmentReturn: self._calls += 1 - return None, None, self.rewards, None + return ( + [{"role": "environment", "content": "observation"}] * len(messages), + [{}] * len(messages), + [[]] * len(messages), + self.rewards, + [True] * len(messages), + ) def get_calls(self): return self._calls @@ -117,11 +126,17 @@ def test_calculate_rewards_single_task(mock_env): batch = create_mock_batch(2, task_names, message_logs) # Calculate rewards - rewards, to_env = calculate_rewards(batch, task_to_env) + rewards, env_observations, terminateds, next_stop_strings, metadata = ( + calculate_rewards(batch, task_to_env) + ) # Verify results assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) - assert len(to_env) == 2 + assert len(env_observations) == 2 + assert len(terminateds) == 2 + assert len(next_stop_strings) == 2 + assert len(metadata) == 
2 + assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) assert ( ray.get(mock_env.get_calls.remote()) == 1 ) # Should only call once for all samples of same task @@ -146,11 +161,17 @@ def test_calculate_rewards_multiple_tasks(mock_envs): batch = create_mock_batch(4, task_names, message_logs) # Calculate rewards - rewards, to_env = calculate_rewards(batch, mock_envs) + rewards, env_observations, terminateds, next_stop_strings, metadata = ( + calculate_rewards(batch, mock_envs) + ) # Verify results assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) - assert len(to_env) == 4 + assert len(env_observations) == 4 + assert len(terminateds) == 4 + assert len(next_stop_strings) == 4 + assert len(metadata) == 4 + assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) assert ( ray.get(mock_envs["math"].get_calls.remote()) == 1 ) # One call for all math samples @@ -167,11 +188,16 @@ def test_calculate_rewards_empty_batch(mock_env): batch = create_mock_batch(0, [], []) # Calculate rewards - rewards, to_env = calculate_rewards(batch, task_to_env) + rewards, env_observations, terminateds, next_stop_strings, metadata = ( + calculate_rewards(batch, task_to_env) + ) # Verify results assert len(rewards) == 0 - assert len(to_env) == 0 + assert len(env_observations) == 0 + assert len(terminateds) == 0 + assert len(next_stop_strings) == 0 + assert len(metadata) == 0 assert ( ray.get(mock_env.get_calls.remote()) == 0 ) # Should not call environment for empty batch diff --git a/tests/unit/environments/game_interface.py b/tests/unit/environments/game_interface.py new file mode 100644 index 0000000000..77c7396a18 --- /dev/null +++ b/tests/unit/environments/game_interface.py @@ -0,0 +1,62 @@ +from typing import Dict, Any, Tuple, List, Optional, Callable + + +class GameInterface: + @staticmethod + def generate(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Generate a new game state based on configuration. + + Args: + config: Game configuration dictionary + + Returns: + A dictionary containing the complete game state + """ + raise NotImplementedError("Each game must implement generate()") + + @staticmethod + def init(game_state: Dict[str, Any]) -> str: + """ + Initialize a game and return welcome messages. + + Args: + game_state: The game state dictionary + + Returns: + String containing welcome message and instructions + """ + raise NotImplementedError("Each game must implement init()") + + @staticmethod + def step( + action: str, game_state: Dict[str, Any] + ) -> Tuple[str, float, bool, Dict[str, Any]]: + """ + Process a game action and update the state. + + Args: + action: String representing the player's action + game_state: Current game state dictionary + + Returns: + Tuple containing: + - Response message + - Reward for this action + - Boolean indicating if game is terminated + - Updated game state + """ + raise NotImplementedError("Each game must implement step()") + + @staticmethod + def render(game_state: Dict[str, Any]) -> str: + """ + Render the current game state as a string. 
+ + Args: + game_state: The game state dictionary + + Returns: + String representation of the game state + """ + raise NotImplementedError("Each game must implement render()") diff --git a/tests/unit/environments/sliding_puzzle_game.py b/tests/unit/environments/sliding_puzzle_game.py new file mode 100644 index 0000000000..584eb6b3fb --- /dev/null +++ b/tests/unit/environments/sliding_puzzle_game.py @@ -0,0 +1,242 @@ +import random +import copy +from typing import List, Tuple, Dict, Any, Optional +from .game_interface import GameInterface + + +class SlidingPuzzleGame(GameInterface): + @staticmethod + def generate(config: Dict[str, Any]) -> Dict[str, Any]: + """Generate a new Sliding Puzzle.""" + size = config.get("size", 4) # Default to 4x4 (15-puzzle) + shuffle_moves = config.get( + "shuffle_moves", 100 + ) # Number of random moves for shuffling + + # Create the solved state + grid = [[(r * size + c + 1) for c in range(size)] for r in range(size)] + # Set the bottom-right corner to 0 (empty space) + grid[size - 1][size - 1] = 0 + + # Save the solution + solution = [row[:] for row in grid] + + # Find the empty space + empty_pos = (size - 1, size - 1) + + # Shuffle the grid with valid moves + for _ in range(shuffle_moves): + # Get possible moves + moves = [] + r, c = empty_pos + for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]: # Right, Down, Left, Up + nr, nc = r + dr, c + dc + if 0 <= nr < size and 0 <= nc < size: + moves.append((nr, nc)) + + # Choose a random move + if moves: + new_r, new_c = random.choice(moves) + # Swap the empty space with the chosen tile + grid[r][c], grid[new_r][new_c] = grid[new_r][new_c], grid[r][c] + empty_pos = (new_r, new_c) + + # Create and return the game state + return { + "size": size, + "grid": grid, + "solution": solution, + "empty_pos": empty_pos, + "commands": { + "slide r c": "Slide tile at row r, column c into the empty space", + "up": "Slide tile below empty space up", + "down": "Slide tile above empty space down", + "left": "Slide tile to the right of empty space left", + "right": "Slide tile to the left of empty space right", + }, + } + + @staticmethod + def init(game_state: Dict[str, Any]) -> str: + """Initialize Sliding Puzzle game and return welcome message.""" + size = game_state["size"] + + return ( + f"\n===== SLIDING PUZZLE =====\n" + f"Arrange the {size}x{size} grid by sliding tiles into the empty space.\n" + f"- The goal is to arrange numbers from 1 to {size * size - 1} in order\n" + f"- Use 'slide r c' to slide a specific tile\n" + f"- Or use 'up', 'down', 'left', 'right' to slide in that direction" + ) + + @staticmethod + def step( + action: str, game_state: Dict[str, Any] + ) -> Tuple[str, float, bool, Dict[str, Any]]: + """Process an action in the Sliding Puzzle game.""" + size = game_state["size"] + grid = game_state["grid"] + empty_r, empty_c = game_state["empty_pos"] + + # Default return values + response = "Unknown command. Type 'help' to see available commands." + reward = -0.05 # Small penalty for invalid actions + is_terminated = False + + # Deep copy game state to avoid modifying the original + new_state = copy.deepcopy(game_state) + + move_made = False + + if action.startswith("slide "): + try: + _, r, c = action.split() + r, c = int(r) - 1, int(c) - 1 + + # Validate input + if not (0 <= r < size and 0 <= c < size): + return ( + f"Invalid position. 
Row/column must be between 1 and {size}.", + reward, + is_terminated, + new_state, + ) + + # Check if tile is adjacent to empty space + if abs(r - empty_r) + abs(c - empty_c) != 1: + return ( + "Tile must be adjacent to the empty space.", + reward, + is_terminated, + new_state, + ) + + # Slide the tile + new_state["grid"][empty_r][empty_c] = grid[r][c] + new_state["grid"][r][c] = 0 + new_state["empty_pos"] = (r, c) + + move_made = True + response = f"Slid tile {grid[r][c]} into the empty space." + + except ValueError: + return ( + "Invalid input format. Use: slide row col", + reward, + is_terminated, + new_state, + ) + + elif action in ["up", "down", "left", "right"]: + # Convert direction to row/col offset + if action == "up": + r, c = empty_r + 1, empty_c # Tile below moves up + dir_text = "up" + elif action == "down": + r, c = empty_r - 1, empty_c # Tile above moves down + dir_text = "down" + elif action == "left": + r, c = empty_r, empty_c + 1 # Tile to right moves left + dir_text = "left" + elif action == "right": + r, c = empty_r, empty_c - 1 # Tile to left moves right + dir_text = "right" + + # Check if the move is valid + if 0 <= r < size and 0 <= c < size: + # Slide the tile + new_state["grid"][empty_r][empty_c] = grid[r][c] + new_state["grid"][r][c] = 0 + new_state["empty_pos"] = (r, c) + + move_made = True + response = f"Slid tile {grid[r][c]} {dir_text}." + else: + return f"Cannot slide {dir_text}.", reward, is_terminated, new_state + + if move_made: + reward = 0 + + # Check if puzzle is solved + if new_state["grid"] == new_state["solution"]: + response = "Congratulations! You've solved the puzzle!" + reward = 1.0 # Win reward + is_terminated = True + + return response, reward, is_terminated, new_state + + @staticmethod + def render(game_state: Dict[str, Any]) -> str: + """Render the current Sliding Puzzle game state.""" + grid = game_state["grid"] + size = game_state["size"] + + output = ["\n"] + + # Create a visual representation of the grid + max_digits = len(str(size * size - 1)) + + # Top border + output.append(" " + "+" + "-" * (max_digits + 2) * size + "+") + + # Rows + for i, row in enumerate(grid): + row_str = f"{i + 1} |" + for val in row: + if val == 0: + # Empty space + row_str += " " * (max_digits + 2) + else: + # Tile with number + row_str += f" {val:>{max_digits}} " + row_str += "|" + output.append(row_str) + + # Bottom border + output.append(" " + "+" + "-" * (max_digits + 2) * size + "+") + + # Column labels + col_labels = " " + for i in range(size): + col_labels += f"{i + 1:^{max_digits + 2}}" + output.append(col_labels) + + return "\n".join(output) + + +def is_solvable(grid: List[List[int]], size: int) -> bool: + """Check if a sliding puzzle is solvable.""" + # Flatten the grid + flat = [num for row in grid for num in row if num != 0] + + # Count inversions + inversions = 0 + for i in range(len(flat)): + for j in range(i + 1, len(flat)): + if flat[i] > flat[j]: + inversions += 1 + + # Find row of the empty tile (0) from the bottom + empty_row = 0 + for i in range(size - 1, -1, -1): + for j in range(size): + if grid[i][j] == 0: + empty_row = size - i + break + + # For odd-sized grids, the puzzle is solvable if the number of inversions is even + if size % 2 == 1: + return inversions % 2 == 0 + # For even-sized grids, the puzzle is solvable if: + # (inversions odd and empty on even row from bottom) or (inversions even and empty on odd row from bottom) + else: + return (inversions % 2 == 1 and empty_row % 2 == 0) or ( + inversions % 2 == 0 and empty_row % 2 == 1 
+ ) + + +def play_sliding_puzzle(config=None): + """Wrapper function for backward compatibility.""" + from play_game import play_game + + play_game(SlidingPuzzleGame, config) diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py index 9b2eb4e21c..6fb9897e7b 100644 --- a/tests/unit/environments/test_math_environment.py +++ b/tests/unit/environments/test_math_environment.py @@ -109,74 +109,82 @@ def multiple_assistant_test_data(): def test_math_env_step_basic(math_env, basic_test_data): """Test basic functionality of MathEnvironment step with simple messages.""" - observations, updated_metadata, rewards, done = ray.get( + result = ray.get( math_env.step.remote( basic_test_data["message_log_batch"], basic_test_data["metadata"] ) ) - # Check observations - assert len(observations) == 3, "Should return observations for all 3 messages" - assert all(obs["role"] == "user" for obs in observations), ( - "All observations should be from user" + # Check observations using field access + assert len(result.observations) == 3, ( + "Should return observations for all 3 messages" ) - assert all(obs["content"] == "correct" for obs in observations), ( - "All responses should be correct" + assert all(obs["role"] == "environment" for obs in result.observations), ( + "All observations should be from environment" ) + assert all( + obs["content"] == "Environment: correct" for obs in result.observations + ), "All responses should be correct" # Check metadata - assert len(updated_metadata) == 3, "Should return metadata for all 3 messages" - assert updated_metadata == basic_test_data["metadata"], ( + assert len(result.metadata) == 3, "Should return metadata for all 3 messages" + assert result.metadata == basic_test_data["metadata"], ( "Metadata should be unchanged" ) # Check rewards and done flags - assert rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" - assert all(rewards == 1.0), "All rewards should be 1.0 for correct answers" - assert done.shape == (3,), "Done flags should be a tensor of shape (3,)" - assert all(done == 1.0), "All done flags should be 1.0" + assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert all(result.rewards == 1.0), "All rewards should be 1.0 for correct answers" + assert result.terminated.shape == (3,), ( + "Terminated flags should be a tensor of shape (3,)" + ) + assert all(result.terminated == 1.0), "All terminated flags should be 1.0" def test_math_env_step_mixed(math_env, mixed_test_data): """Test MathEnvironment step with a mix of correct and incorrect responses.""" - observations, updated_metadata, rewards, done = ray.get( + result = ray.get( math_env.step.remote( mixed_test_data["message_log_batch"], mixed_test_data["metadata"] ) ) # Check observations and rewards - assert len(observations) == 3, "Should return observations for all 3 messages" - assert observations[0]["content"] == "correct", "First response should be correct" - assert observations[1]["content"] == "incorrect", ( + assert len(result.observations) == 3, ( + "Should return observations for all 3 messages" + ) + assert result.observations[0]["content"] == "Environment: correct", ( + "First response should be correct" + ) + assert result.observations[1]["content"] == "Environment: incorrect", ( "Second response should be incorrect" ) - assert observations[2]["content"] == "correct", "Third response should be correct" + assert result.observations[2]["content"] == "Environment: correct", ( + "Third response 
should be correct" + ) - assert rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" - assert rewards[0] == 1.0, "First reward should be 1.0" - assert rewards[1] == 0.0, "Second reward should be 0.0" - assert rewards[2] == 1.0, "Third reward should be 1.0" + assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert result.rewards[0] == 1.0, "First reward should be 1.0" + assert result.rewards[1] == 0.0, "Second reward should be 0.0" + assert result.rewards[2] == 1.0, "Third reward should be 1.0" def test_math_env_step_empty(math_env): """Test MathEnvironment step with empty input.""" - observations, updated_metadata, rewards, done = ray.get( - math_env.step.remote([], []) - ) + result = ray.get(math_env.step.remote([], [])) # Check all outputs are empty - assert len(observations) == 0, "Should return empty observations list" - assert len(updated_metadata) == 0, "Should return empty metadata list" - assert rewards.shape == (0,), "Should return empty rewards tensor" - assert done.shape == (0,), "Should return empty done tensor" + assert len(result.observations) == 0, "Should return empty observations list" + assert len(result.metadata) == 0, "Should return empty metadata list" + assert result.rewards.shape == (0,), "Should return empty rewards tensor" + assert result.terminated.shape == (0,), "Should return empty terminated tensor" def test_math_env_step_multiple_assistant_messages( math_env, multiple_assistant_test_data ): """Test MathEnvironment step with multiple assistant messages in a conversation.""" - observations, updated_metadata, rewards, done = ray.get( + result = ray.get( math_env.step.remote( multiple_assistant_test_data["message_log_batch"], multiple_assistant_test_data["metadata"], @@ -184,11 +192,13 @@ def test_math_env_step_multiple_assistant_messages( ) # Check that only the last assistant message is used - assert len(observations) == 2, "Should return observations for both conversations" - assert all(obs["content"] == "correct" for obs in observations), ( - "All responses should be correct" + assert len(result.observations) == 2, ( + "Should return observations for both conversations" ) - assert all(rewards == 1.0), "All rewards should be 1.0" + assert all( + obs["content"] == "Environment: correct" for obs in result.observations + ), "All responses should be correct" + assert all(result.rewards == 1.0), "All rewards should be 1.0" @pytest.mark.parametrize("batch_size", [1, 2, 10, 25, 101]) @@ -202,16 +212,20 @@ def test_math_env_various_batches(math_env, batch_size): ] * batch_size metadata = [{"ground_truth": "3.33333333"}] * batch_size - observations, updated_metadata, rewards, done = ray.get( - math_env.step.remote(message_log_batch, metadata) - ) + result = ray.get(math_env.step.remote(message_log_batch, metadata)) # Check outputs - assert len(observations) == batch_size, ( + assert len(result.observations) == batch_size, ( f"Should return observations for all {batch_size} messages" ) - assert all(obs["content"] == "correct" for obs in observations), ( - "All responses should be correct" + assert all( + obs["content"] == "Environment: correct" for obs in result.observations + ), "All responses should be correct" + assert result.rewards.shape == (batch_size,), ( + "Rewards should be a tensor of shape (batch_size,)" + ) + assert all(result.rewards == 1.0), "All rewards should be 1.0" + assert result.terminated.shape == (batch_size,), ( + "Terminated flags should be a tensor of shape (batch_size,)" ) - assert all(rewards == 
1.0), "All rewards should be 1.0" - assert all(done == 1.0), "All done flags should be 1.0" + assert all(result.terminated == 1.0), "All terminated flags should be 1.0" diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py new file mode 100644 index 0000000000..412654fff8 --- /dev/null +++ b/tests/unit/experience/test_rollouts.py @@ -0,0 +1,626 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import ray +import torch +from typing import Dict, List, Tuple, Optional, TypedDict +from copy import deepcopy +import gc + +from transformers import AutoTokenizer, GPT2TokenizerFast + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.models.generation.interfaces import configure_generation_config +from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout +from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES + +# Import the test environment definitions +from tests.unit.test_envs import ( + MultiStepCalculatorEnv, + _MultiStepCalculatorLogic, + MultiStepCalcMetadata, + SlidingPuzzleEnv, + SlidingPuzzleMetadata, +) + +# Import the game logic for generating initial state from its new location +from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame + +from nemo_reinforcer.models.generation.vllm import VllmConfig, VllmGeneration + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="function") +def rollout_tokenizer(): + """Loads the tokenizer for the tests.""" + print(f"Loading tokenizer: {MODEL_NAME}") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + print( + f"Tokenizer loaded. 
Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})"
+    )
+    return tokenizer
+
+
+# Separate fixture for cluster setup and teardown
+@pytest.fixture(scope="function")
+def rollout_cluster():
+    cluster_instance = None
+    cluster_name = f"test-rollout-cluster-{id(cluster_instance)}"  # Unique name
+    print(f"\nCreating virtual cluster '{cluster_name}'...")
+    try:
+        # Use 1 GPU for simplicity
+        cluster_instance = RayVirtualCluster(
+            name=cluster_name,
+            bundle_ct_per_node_list=[1],
+            use_gpus=True,
+            num_gpus_per_node=1,
+            max_colocated_worker_groups=2,  # Allow policy and env
+        )
+        yield cluster_instance
+    finally:
+        print(f"\nCleaning up cluster '{cluster_name}'...")
+        if cluster_instance:
+            cluster_instance.shutdown()
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        print(f"Cluster '{cluster_name}' cleanup finished.")
+
+
+# Fixture for the multi-step calculator environment actor
+@pytest.fixture(scope="function")
+def multi_step_calculator_environment(rollout_cluster):
+    env_actor = None
+    print("Creating MultiStepCalculatorEnv actor...")
+    try:
+        env_actor = MultiStepCalculatorEnv.remote()
+        task_to_env = {"multi_step_calculator_game": env_actor}
+        yield task_to_env, env_actor
+    finally:
+        print("Cleaning up multi_step_calculator_environment...")
+        if env_actor:
+            ray.kill(env_actor)
+        print("multi_step_calculator_environment cleanup finished.")
+
+
+# Fixture for the multi-step calculator initial batch data
+@pytest.fixture(scope="function")
+def initial_multi_step_calculator_batch(rollout_tokenizer):
+    print("Creating initial multi-step calculator test batch...")
+    batch_size = 1  # Simpler to debug with one sample
+    problem = "(5 + 3) * 2"  # Example problem
+    expected_answer = 16.0
+    max_steps = 5  # Allow a few steps
+    tool_instructions = (
+        "You have a calculator tool. To use it, respond with:\n"
+        "'[operand1, operand2, operation_name]</calc>'\n"
+        "The valid 'operation_name' values are exactly: 'sum', 'diff', 'prod', 'div'.\n"
+        "Example: [5, 3, sum]</calc>\n"
+        "You will receive the result of your calculation as <result>...</result>\n"
+        "Use this result to make the next calculation if needed.\n"
+        "IMPORTANT: Only perform one calculation step (one tool call) before waiting for a result and making a new tool call.\n"
+        "IMPORTANT: Do not perform any other calculations or operations aside from the tool call and result. Doing so will result in failure.\n"
+        "To give the final answer, just output the number. Numbers inside of <result></result> don't count, so output just the final number yourself outside of those tags.\n"
+        "Example full output: [2, 4, sum]</calc>\n<result>6.0</result>\n[6, 6, diff]</calc>\n<result>0.0</result> 0\n"
+        "(note how you have to output the final 0 outside of the <result> tags)"
+        "------\n"
+        f"Solve: {problem}"
+    )
+    batch_message_logs = []
+    batch_extra_env_info = []
+    batch_loss_multipliers = []
+    batch_indices = []
+    batch_task_names = []
+
+    for i in range(batch_size):
+        # Apply chat template to the initial prompt
+        initial_prompt_content = rollout_tokenizer.apply_chat_template(
+            [{"role": "user", "content": tool_instructions}],
+            tokenize=False,
+            add_system_prompt=False,
+            add_generation_prompt=True,
+            add_special_tokens=False,
+        )
+        tokenized_prompt = rollout_tokenizer(
+            initial_prompt_content, return_tensors="pt", add_special_tokens=False
+        )["input_ids"][0]
+        message_log = [
+            {
+                "role": "user",
+                "content": initial_prompt_content,
+                "token_ids": tokenized_prompt,
+            }
+        ]
+        metadata = MultiStepCalcMetadata(
+            problem=problem,
+            expected_final_answer=expected_answer,
+            max_steps=max_steps,
+            current_step=0,
+        )
+
+        batch_message_logs.append(message_log)
+        batch_extra_env_info.append(metadata)
+        batch_loss_multipliers.append(1.0)
+        batch_indices.append(i)
+        batch_task_names.append("multi_step_calculator_game")
+
+    initial_batch_dict = {
+        "message_log": batch_message_logs,
+        "extra_env_info": batch_extra_env_info,
+        "loss_multiplier": batch_loss_multipliers,
+        "idx": batch_indices,
+        "task_name": batch_task_names,
+        "stop_strings": [["</calc>"]] * batch_size,
+    }
+    return BatchedDataDict(initial_batch_dict)
+
+
+# Keep the base config separate
+base_hf_test_config: PolicyConfig = {
+    "policy_type": "hf",
+    "model_name": MODEL_NAME,
+    "tokenizer_name": None,
+    "model_path": None,
+    "num_workers": 1,
+    "train_global_batch_size": 2,
+    "train_micro_batch_size": 1,
+    "logprob_batch_size": 2,
+    "generation_batch_size": 1,  # Smaller for simpler testing
+    "learning_rate": 5e-6,
+    "precision": "float32",
+    "activation_checkpointing_enabled": False,
+    "fsdp_offload_enabled": False,
+    "generation": {
+        "backend": "hf",
+        "max_new_tokens": 50,  # Increased for tool call format
+        "temperature": 0.01,
+        "top_p": 1.0,
+        "top_k": None,
+        "stop_token_ids": None,
+        "stop_strings": None,
+    },
+    "optimizer": {
+        "name": "torch.optim.AdamW",
+        "kwargs": {
+            "lr": 5e-6,
+            "weight_decay": 0.01,
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+        },
+    },
+    "dtensor_cfg": {"enabled": False},
+}
+
+base_vllm_test_config: VllmConfig = {
+    "backend": "vllm",
+    "model_name": MODEL_NAME,
+    "tokenizer_name": None,
+    "dtype": "bfloat16",
+    "gpu_memory_utilization": 0.6,
+    "max_new_tokens": 50,  # Increased for tool call format
+    "temperature": 0.01,  # Near-greedy
+    "top_p": 1.0,
+    "top_k": None,
+    "stop_token_ids": None,
+    "stop_strings": None,
+    "vllm_cfg": {
+        "tensor_parallel_size": 1,
+        "max_model_len": 2048,
+        "disable_log_stats": True,
+        "disable_log_requests": True,
+        "gpu_memory_utilization": 0.6,
+    },
+}
+
+
+@pytest.fixture(scope="function")
+def multi_step_setup_hf(
+    rollout_cluster,
+    rollout_tokenizer,
+    multi_step_calculator_environment,
+    initial_multi_step_calculator_batch,
+):
+    """Sets up components for multi-step calculator tests using HfPolicy."""
+    policy = None
+    task_to_env, _ = multi_step_calculator_environment
+    print("Creating HfPolicy for Multi-Step Calculator Test...")
+    try:
+        config = deepcopy(base_hf_test_config)
+        config["tokenizer_name"] = rollout_tokenizer.name_or_path
+        if "gpt2" in rollout_tokenizer.name_or_path.lower():
+            config["model_name"] = "gpt2"
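+        # configure_generation_config should propagate tokenizer-derived settings
+        # (e.g. pad/eos token ids) into the generation config before policy creation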
+        config["generation"] = configure_generation_config(
+            config["generation"], rollout_tokenizer
+        )
+        config["generation"]["stop_strings"] = None
+        policy = HfPolicy(
+            cluster=rollout_cluster,
+            config=config,
+            tokenizer=rollout_tokenizer,
+            init_reference_model=False,
+            init_optimizer=False,
+        )
+        yield (
+            policy,
+            rollout_tokenizer,
+            task_to_env,
+            initial_multi_step_calculator_batch,
+            rollout_cluster,
+        )
+    finally:
+        print("Cleaning up HfPolicy (Multi-Step Calc Test)...")
+        if policy:
+            policy.shutdown()
+        print("HfPolicy cleanup finished (Multi-Step Calc Test).")
+
+
+@pytest.fixture(scope="function")
+def multi_step_setup_vllm(
+    rollout_cluster,
+    rollout_tokenizer,
+    multi_step_calculator_environment,
+    initial_multi_step_calculator_batch,
+):
+    """Sets up components for multi-step calculator tests using VllmGeneration."""
+    vllm_generation = None
+    task_to_env, _ = multi_step_calculator_environment
+    is_eval = True
+    print("Creating VllmGeneration for Multi-Step Calculator Test...")
+    try:
+        vllm_config = deepcopy(base_vllm_test_config)
+        vllm_config["tokenizer_name"] = rollout_tokenizer.name_or_path
+        if "gpt2" in rollout_tokenizer.name_or_path.lower():
+            vllm_config["model_name"] = "gpt2"
+        vllm_config = configure_generation_config(
+            vllm_config, rollout_tokenizer, is_eval=is_eval
+        )
+        vllm_generation = VllmGeneration(rollout_cluster, vllm_config)
+        vllm_generation.finish_generation()
+        yield (
+            vllm_generation,
+            rollout_tokenizer,
+            task_to_env,
+            initial_multi_step_calculator_batch,
+            rollout_cluster,
+        )
+    finally:
+        print("Cleaning up VllmGeneration (Multi-Step Calc Test)...")
+        if vllm_generation:
+            vllm_generation.shutdown()
+        print("VllmGeneration cleanup finished (Multi-Step Calc Test).")
+
+
+def test_run_multi_step_calculator_hf(multi_step_setup_hf):
+    """Tests multi-step calculator rollout with HfPolicy."""
+    policy, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = (
+        multi_step_setup_hf
+    )
+    max_turns = (
+        initial_batch["extra_env_info"][0]["max_steps"] + 1
+    )  # Allow max steps + final answer
+    max_seq_len = 1024  # Increased for potentially longer interaction
+
+    print("\nRunning multi-step calculator rollout (HF)...")
+    policy.prepare_for_generation()
+    final_batch, rollout_metrics = run_multi_turn_rollout(
+        policy_generation=policy,
+        initial_batch=initial_batch,
+        tokenizer=rollout_tokenizer,
+        task_to_env=task_to_env,
+        max_seq_len=max_seq_len,
+        max_turns=max_turns,
+    )
+    policy.finish_generation()
+    print("Multi-step calculator rollout complete (HF).")
+
+    # --- Assertions ---
+    assert isinstance(final_batch, BatchedDataDict)
+    assert "message_log" in final_batch
+    assert "total_reward" in final_batch
+    assert len(final_batch["message_log"]) == len(initial_batch["message_log"])
+
+    sample_log = final_batch["message_log"][0]
+    expected_final_answer = initial_batch["extra_env_info"][0]["expected_final_answer"]
+    print("\nSample Interaction Log (Multi-Step Calculator - HF):")
+    tool_call_count = 0
+    final_answer_msg = None
+    for i, msg in enumerate(sample_log):
+        print(f"  {i}: Role={msg['role']}, Content='{msg['content']}'")
+        if msg["role"] == "assistant":
+            if msg["content"].strip().endswith("</calc>"):
+                tool_call_count += 1
+            else:
+                final_answer_msg = msg["content"].strip()
+
+    assert tool_call_count >= 1, "Expected at least one tool call"
+    assert final_answer_msg is not None, (
+        "Expected a final answer message from assistant"
+    )
+
+    # Check final answer correctness (allowing for different final answer formats)
+    final_answer_logic = _MultiStepCalculatorLogic()
+    extracted_final_answer = final_answer_logic._is_final_answer(final_answer_msg)
+    assert extracted_final_answer is not None, (
+        f"Could not parse final answer from: {final_answer_msg}"
+    )
+    assert abs(extracted_final_answer - expected_final_answer) < 1e-6, (
+        f"Final answer incorrect. Expected {expected_final_answer}, Got {extracted_final_answer}"
+    )
+
+    # Check total reward (should be 1.0 if correct)
+    assert torch.all(final_batch["total_reward"] == 1.0), (
+        f"Expected total reward 1.0, got {final_batch['total_reward']}"
+    )
+
+    print("\nMulti-Step Calculator HF Test assertions passed.")
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 1,
+    reason="VLLM test requires at least 1 GPU",
+)
+def test_run_multi_step_calculator_vllm(multi_step_setup_vllm):
+    """Tests multi-step calculator rollout with VllmGeneration."""
+    vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = (
+        multi_step_setup_vllm
+    )
+    max_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1
+    max_seq_len = 1024
+
+    print("\nRunning multi-step calculator rollout (VLLM)...")
+    vllm_generation.prepare_for_generation()
+    final_batch, rollout_metrics = run_multi_turn_rollout(
+        policy_generation=vllm_generation,
+        initial_batch=initial_batch,
+        tokenizer=rollout_tokenizer,
+        task_to_env=task_to_env,
+        max_seq_len=max_seq_len,
+        max_turns=max_turns,
+    )
+    vllm_generation.finish_generation()
+    print("Multi-step calculator rollout complete (VLLM).")
+
+    # --- Assertions ---
+    assert isinstance(final_batch, BatchedDataDict)
+    assert "message_log" in final_batch
+    assert "total_reward" in final_batch
+    assert len(final_batch["message_log"]) == len(initial_batch["message_log"])
+
+    sample_log = final_batch["message_log"][0]
+    expected_final_answer = initial_batch["extra_env_info"][0]["expected_final_answer"]
+    print("\nSample Interaction Log (Multi-Step Calculator - VLLM):")
+    tool_call_count = 0
+    final_answer_msg = None
+    for i, msg in enumerate(sample_log):
+        print(f"  {i}: Role={msg['role']}, Content='{msg['content']}'")
+        if msg["role"] == "assistant":
+            if msg["content"].strip().endswith("</calc>"):
+                tool_call_count += 1
+            else:
+                final_answer_msg = msg["content"].strip()
+
+    assert tool_call_count >= 1, "Expected at least one tool call"
+    assert final_answer_msg is not None, (
+        "Expected a final answer message from assistant"
+    )
+
+    final_answer_logic = _MultiStepCalculatorLogic()
+    extracted_final_answer = final_answer_logic._is_final_answer(final_answer_msg)
+    assert extracted_final_answer is not None, (
+        f"Could not parse final answer from: {final_answer_msg}"
+    )
+    assert abs(extracted_final_answer - expected_final_answer) < 1e-6, (
+        f"Final answer incorrect.
Expected {expected_final_answer}, Got {extracted_final_answer}"
+    )
+
+    assert torch.all(final_batch["total_reward"] == 1.0), (
+        f"Expected total reward 1.0, got {final_batch['total_reward']}"
+    )
+
+    print("\nMulti-Step Calculator VLLM Test assertions passed.")
+
+
+# --- Fixture for Sliding Puzzle Environment ---
+@pytest.fixture(scope="function")
+def sliding_puzzle_environment(rollout_cluster):
+    env_actor = None
+    print("Creating SlidingPuzzleEnv actor...")
+    try:
+        # Pass game config if needed, e.g., {"game_config": {"size": 3}}
+        env_actor = SlidingPuzzleEnv.remote()
+        task_to_env = {"sliding_puzzle_game": env_actor}
+        yield task_to_env, env_actor
+    finally:
+        print("Cleaning up sliding_puzzle_environment...")
+        if env_actor:
+            env_actor.shutdown.remote()
+            ray.kill(env_actor)
+        print("sliding_puzzle_environment cleanup finished.")
+
+
+# --- Fixture for Sliding Puzzle Initial Batch ---
+@pytest.fixture(scope="function")
+def initial_sliding_puzzle_batch(rollout_tokenizer):
+    print("Creating initial sliding puzzle test batch...")
+    batch_size = 1
+    game_config = {
+        "size": 2,
+        "shuffle_moves": 1,
+    }
+    max_moves = 25  # Set a limit for the test
+
+    # Generate initial game state
+    initial_game_state = SlidingPuzzleGame.generate(game_config)
+    initial_render = SlidingPuzzleGame.render(initial_game_state)
+    welcome_message = SlidingPuzzleGame.init(initial_game_state)
+
+    prompt_instructions = (
+        f"{welcome_message}\n\n"
+        f"Current Board State:\n{initial_render}\n\n"
+        f"Reach the goal state where numbers are ordered 1 through {game_config['size'] ** 2 - 1} "
+        f"with the empty space (0) at the bottom right.\n"
+        f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n"
+        f"After thinking, output your chosen action on a new line starting with 'Action:' like this:\nAction: your_action\n"
+        f"Think step-by-step before acting.
\n" + ) + + batch_message_logs = [] + batch_extra_env_info = [] + batch_loss_multipliers = [] + batch_indices = [] + batch_task_names = [] + + for i in range(batch_size): + # Apply chat template to the initial prompt + initial_prompt_content = rollout_tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_instructions}], + tokenize=False, + add_system_prompt=True, # Include system prompt for Qwen + add_generation_prompt=True, + add_special_tokens=False, + ).strip() + tokenized_prompt = rollout_tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + + metadata = SlidingPuzzleMetadata( + game_state=initial_game_state, num_moves=0, max_moves=max_moves + ) + + batch_message_logs.append(message_log) + batch_extra_env_info.append(metadata) + batch_loss_multipliers.append(1.0) + batch_indices.append(i) + batch_task_names.append("sliding_puzzle_game") + + initial_batch_dict = { + "message_log": batch_message_logs, + "extra_env_info": batch_extra_env_info, + "loss_multiplier": batch_loss_multipliers, + "idx": batch_indices, + "task_name": batch_task_names, + # No stop_strings needed initially, env provides + } + return BatchedDataDict(initial_batch_dict) + + +@pytest.fixture(scope="function") +def sliding_puzzle_setup_vllm( + rollout_cluster, + rollout_tokenizer, + sliding_puzzle_environment, + initial_sliding_puzzle_batch, +): + """Sets up components for sliding puzzle tests using VllmGeneration.""" + vllm_generation = None + task_to_env, _ = sliding_puzzle_environment + is_eval = True + print("Creating VllmGeneration for Sliding Puzzle Test...") + try: + vllm_config = deepcopy(base_vllm_test_config) + # Qwen model name is already in base config + vllm_config["tokenizer_name"] = rollout_tokenizer.name_or_path + vllm_config = configure_generation_config( + vllm_config, rollout_tokenizer, is_eval=is_eval + ) + # Ensure max_new_tokens is sufficient + vllm_config["max_new_tokens"] = 500 + vllm_generation = VllmGeneration(rollout_cluster, vllm_config) + vllm_generation.finish_generation() + yield ( + vllm_generation, + rollout_tokenizer, + task_to_env, + initial_sliding_puzzle_batch, + rollout_cluster, + ) + finally: + print("Cleaning up VllmGeneration (Sliding Puzzle Test)...") + if vllm_generation: + vllm_generation.shutdown() + print("VllmGeneration cleanup finished (Sliding Puzzle Test).") + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 1, + reason="VLLM test requires at least 1 GPU", +) +def test_run_sliding_puzzle_vllm(sliding_puzzle_setup_vllm): + """Tests sliding puzzle rollout with VllmGeneration.""" + vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( + sliding_puzzle_setup_vllm + ) + max_moves = initial_batch["extra_env_info"][0]["max_moves"] + max_turns = max_moves + 1 + max_seq_len = 2048 + + print("\nRunning sliding puzzle rollout (VLLM)...") + vllm_generation.prepare_for_generation() + final_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=vllm_generation, + initial_batch=initial_batch, + tokenizer=rollout_tokenizer, + task_to_env=task_to_env, + max_turns=max_turns, + max_seq_len=max_seq_len, + greedy=True, + ) + print(rollout_metrics) + vllm_generation.finish_generation() + print("Sliding puzzle rollout complete (VLLM).") + + # --- Assertions --- + assert isinstance(final_batch, BatchedDataDict) + assert 
"message_log" in final_batch + assert "total_reward" in final_batch + assert len(final_batch["message_log"]) == len(initial_batch["message_log"]) + + sample_log = final_batch["message_log"][0] + print("\nSample Interaction Log (Sliding Puzzle - VLLM):") + action_tag_count = 0 + for i, msg in enumerate(sample_log): + print(f" {i}: Role={msg['role']}, Content starts with: '{msg['content']}'") + if msg["role"] == "assistant" and "action:" in msg["content"].lower(): + action_tag_count += 1 + + assert action_tag_count > 0, ( + "Expected at least one assistant message with 'Action:' prefix" + ) + + print(f"Final Total Reward: {final_batch['total_reward'][0].item()}") + assert final_batch["total_reward"][0] > 0.0, ( + f"Expected final reward to be greater than 0.0 (solved), but got {final_batch['total_reward'][0]}" + ) + + last_env_message = sample_log[-1]["content"] + assert "congratulations" in last_env_message.lower(), ( + "Last message should indicate puzzle solved" + ) + + print("\nSliding Puzzle VLLM Test assertions passed.") diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py new file mode 100644 index 0000000000..ccf72c1d0c --- /dev/null +++ b/tests/unit/test_envs.py @@ -0,0 +1,409 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import ray
+import torch
+from typing import Dict, List, Tuple, Optional, TypedDict, Literal, Any
+
+from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+from nemo_reinforcer.data.interfaces import LLMMessageLogType
+from nemo_reinforcer.environments.interfaces import (
+    EnvironmentInterface,
+    EnvironmentReturn,
+)
+from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES
+from .environments.sliding_puzzle_game import SlidingPuzzleGame
+
+
+class MultiStepCalcMetadata(TypedDict):
+    problem: str
+    expected_final_answer: float
+    max_steps: int
+    current_step: int
+
+
+class SlidingPuzzleMetadata(TypedDict):
+    game_state: Dict[str, Any]  # Stores the dict returned by SlidingPuzzleGame methods
+    num_moves: int
+    max_moves: int
+
+
+class _MultiStepCalculatorLogic:
+    def __init__(self):
+        pass
+
+    def _parse_tool_call(self, text: str) -> Optional[Tuple[float, float, str]]:
+        """Parses '[opA, opB, operation]</calc>'."""
+        # Use a more distinct tool call suffix
+        tool_call_suffix = "</calc>"
+        if not text.strip().endswith(tool_call_suffix):
+            return None
+
+        content = text.strip()[: -len(tool_call_suffix)].strip()
+        if not (content.startswith("[") and content.endswith("]")):
+            return None
+        parts = content[1:-1].split(",")
+        if len(parts) != 3:
+            return None
+        try:
+            op_a = float(parts[0].strip())
+            op_b = float(parts[1].strip())
+            operation = parts[2].strip().lower()
+            return op_a, op_b, operation
+        except ValueError:
+            return None
+
+    def _calculate(self, op_a: float, op_b: float, operation: str) -> Optional[float]:
+        """Performs the calculation."""
+        # (Reusing the calculation logic)
+        if operation == "sum":
+            return op_a + op_b
+        elif operation == "diff":
+            return op_a - op_b
+        elif operation == "prod":
+            return op_a * op_b
+        elif operation == "div":
+            if abs(op_b) < 1e-6:
+                return None  # Division by zero error
+            return op_a / op_b
+        else:
+            return None  # Unknown operation
+
+    def _is_final_answer(self, text: str) -> Optional[float]:
+        """Checks if the text is just a final numerical answer."""
+        try:
+            # Allow potential formatting like <final_answer>16.0</final_answer>
+            # or just the number itself.
+            processed_text = text.strip()
+            if processed_text.startswith("<final_answer>") and processed_text.endswith(
+                "</final_answer>"
+            ):
+                processed_text = processed_text[
+                    len("<final_answer>") : -len("</final_answer>")
+                ]
+
+            return float(processed_text)
+        except ValueError:
+            return None
+
+    def process_turn(
+        self,
+        message_log: LLMMessageLogType,
+        metadata: MultiStepCalcMetadata,
+    ) -> Tuple[
+        Dict[str, str],
+        float,
+        bool,
+        Optional[List[str]],
+        Optional[MultiStepCalcMetadata],
+    ]:
+        """Processes a single turn for the multi-step calculator task."""
+        last_assistant_msg = ""
+        if message_log and message_log[-1]["role"] == "assistant":
+            last_assistant_msg = message_log[-1]["content"].strip()
+
+        current_step = metadata["current_step"]
+        max_steps = metadata["max_steps"]
+        expected_final_answer = metadata["expected_final_answer"]
+
+        turn_reward = 0.0
+        is_terminated = False
+        next_stop_strings = [
+            "</calc>"
+        ]  # Let model generate tool call or final answer freely
+        next_metadata = metadata.copy()
+        next_observation_content = ""
+
+        # Check if max steps reached
+        if current_step >= max_steps:
+            is_terminated = True
+            next_observation_content = "Maximum steps reached."
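+            # Metadata becomes None once the episode terminates, matching the
+            # EnvironmentInterface.step contract ("Can be None if episode terminated")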
+            next_metadata = None
+            return (
+                {"role": "environment", "content": next_observation_content},
+                0.0,
+                is_terminated,
+                None,
+                next_metadata,
+            )
+
+        # Check for final answer first
+        final_answer = self._is_final_answer(last_assistant_msg)
+        if final_answer is not None:
+            is_terminated = True
+            next_metadata = None  # End of episode
+            if abs(final_answer - expected_final_answer) < 1e-6:
+                turn_reward = 1.0  # Correct final answer
+                next_observation_content = (
+                    f"Correct! The final answer is {final_answer:.2f}."
+                )
+            else:
+                turn_reward = 0.0  # Incorrect final answer
+                next_observation_content = f"Incorrect final answer. Expected {expected_final_answer:.2f}, got {final_answer:.2f}."
+        else:
+            # Check for tool call
+            parsed_call = self._parse_tool_call(last_assistant_msg)
+            if parsed_call:
+                req_op_a, req_op_b, req_op = parsed_call
+                result = self._calculate(req_op_a, req_op_b, req_op)
+                if result is not None:
+                    # Tool call success, provide result
+                    next_observation_content = f"<result>{result:.5f}</result>"
+                    next_metadata["current_step"] += 1
+                    is_terminated = False
+                else:  # Calculation failed
+                    is_terminated = True
+                    next_observation_content = "Calculation failed."
+                    next_metadata = None
+            else:  # No final answer and no valid tool call
+                is_terminated = True
+                next_observation_content = (
+                    "Invalid response. Expected tool call or final answer."
+                )
+                next_metadata = None
+
+        next_observation = {"role": "environment", "content": next_observation_content}
+        return (
+            next_observation,
+            turn_reward,
+            is_terminated,
+            next_stop_strings,
+            next_metadata,
+        )
+
+
+class _SlidingPuzzleLogic:
+    def __init__(self):
+        pass  # No initialization needed as game methods are static
+
+    def _parse_action(self, text: str) -> Optional[str]:
+        """Parses the action from 'Action: ...' prefix."""
+        prefix = "Action:"
+        # Find the prefix, case-insensitive, and potentially after some thought process
+        text_lower = text.lower()
+        prefix_lower = prefix.lower()
+        start_idx = text_lower.rfind(prefix_lower)  # Find the last occurrence
+
+        if start_idx != -1:
+            # Return the part after the prefix
+            action_part = text[start_idx + len(prefix) :].strip()
+            # Take only the first line if multiple lines were generated after prefix
+            return action_part.split("\n")[0].strip()
+        return None
+
+    def process_turn(
+        self,
+        message_log: LLMMessageLogType,
+        metadata: SlidingPuzzleMetadata,
+    ) -> Tuple[
+        Dict[str, str],
+        float,
+        bool,
+        Optional[List[str]],
+        Optional[SlidingPuzzleMetadata],
+    ]:
+        """Processes a single turn for the sliding puzzle task."""
+        game_state = metadata["game_state"]
+        current_moves = metadata["num_moves"]
+        max_moves = metadata["max_moves"]
+
+        turn_reward = 0.0
+        is_terminated = False
+        next_stop_strings = None  # Let model finish its thought and action naturally
+        next_metadata = metadata.copy()
+        next_observation_content = ""
+
+        # Check if max moves reached
+        if current_moves >= max_moves:
+            is_terminated = True
+            next_observation_content = (
+                f"Maximum moves ({max_moves}) reached."
+ ) + next_metadata = None + return ( + {"role": "environment", "content": next_observation_content}, + 0.0, + is_terminated, + None, + next_metadata, + ) + + # Get last assistant message and parse action + last_assistant_msg_content = "" + if message_log and message_log[-1]["role"] == "assistant": + last_assistant_msg_content = message_log[-1]["content"].strip() + + parsed_action = self._parse_action(last_assistant_msg_content) + + if parsed_action is None: + # Handle cases where parsing failed or it wasn't assistant's turn properly + # is_terminated = True # Penalize for bad format + next_observation_content = ( + "\nInvalid response format. Try 'Action: your_move'." + ) + next_metadata = None + else: + # Execute the game step + step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step( + parsed_action, game_state + ) + + turn_reward = reward + is_terminated = game_over + next_metadata["game_state"] = next_game_state + next_metadata["num_moves"] = current_moves + 1 + + # Combine rendered board and step response for the next observation + rendered_board = SlidingPuzzleGame.render(next_game_state) + next_observation_content = f"{rendered_board}\n\n{step_response}" + + if is_terminated: + next_metadata = None # Clear metadata on termination + # next_stop_strings remains None + + return ( + {"role": "environment", "content": next_observation_content + "\n"}, + turn_reward, + is_terminated, + next_stop_strings, + next_metadata, + ) + + +@ray.remote +class MultiStepCalculatorEnv(EnvironmentInterface): + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM + """Multi-step calculator environment (Ray Actor).""" + + def __init__(self, cfg: Optional[Dict] = None): + self.logic = _MultiStepCalculatorLogic() + + def step( + self, + message_log_batch: List[LLMMessageLogType], + metadata_batch: List[MultiStepCalcMetadata], + ) -> EnvironmentReturn: + """Processes a batch of interactions using the calculator logic.""" + futures = [ + self.logic.process_turn(log, meta) + for log, meta in zip(message_log_batch, metadata_batch) + ] + results = futures + + # Unpack results and format according to EnvironmentReturn tuple + observations = [] + rewards = [] + terminateds = [] + all_stop_strings = [] # List of Lists or Nones + all_next_metadata = [] + + for obs, rew, term, stops, meta in results: + observations.append(obs) # obs is already Dict[str, str] + rewards.append(rew) + terminateds.append(term) + all_stop_strings.append(stops) + all_next_metadata.append(meta) + + # Convert to tensors where needed + rewards_tensor = torch.tensor(rewards, dtype=torch.float32) + # Done flag combines termination and truncation (truncation not used here) + done_tensor = torch.tensor(terminateds, dtype=torch.bool) + + # Return tuple matching EnvironmentReturn NamedTuple + return EnvironmentReturn( + observations=observations, + metadata=all_next_metadata, + next_stop_strings=all_stop_strings, + rewards=rewards_tensor, + terminated=done_tensor, + ) + + def shutdown(self): + pass + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + # Example: could calculate success rate based on final reward + final_rewards = batch.get( + "total_reward", torch.tensor([0.0] * len(batch["idx"])) + ) + success_rate = final_rewards.mean().item() if len(final_rewards) > 0 else 0.0 + return batch, {"success_rate": success_rate} + + +@ray.remote +class SlidingPuzzleEnv(EnvironmentInterface): + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM + """Sliding Puzzle environment (Ray 
Actor).""" + + def __init__(self, cfg: Optional[Dict] = None): + # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50} + self.game_config = cfg.get("game_config", {}) if cfg else {} + self.logic = _SlidingPuzzleLogic() + + def step( + self, + message_log_batch: List[LLMMessageLogType], + metadata_batch: List[SlidingPuzzleMetadata], + ) -> EnvironmentReturn: + """Processes a batch of sliding puzzle interactions.""" + # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy) + results = [ + self.logic.process_turn(log, meta) + for log, meta in zip(message_log_batch, metadata_batch) + ] + + # Unpack results and format according to EnvironmentReturn NamedTuple + observations = [] + rewards = [] + terminateds = [] + all_stop_strings = [] + all_next_metadata = [] + + for obs, rew, term, stops, meta in results: + observations.append(obs) + rewards.append(rew) + terminateds.append(term) + all_stop_strings.append(stops) + all_next_metadata.append(meta) + + rewards_tensor = torch.tensor(rewards, dtype=torch.float32) + terminated_tensor = torch.tensor(terminateds, dtype=torch.bool) + + return EnvironmentReturn( + observations=observations, + metadata=all_next_metadata, + next_stop_strings=all_stop_strings, + rewards=rewards_tensor, + terminated=terminated_tensor, + ) + + def shutdown(self): + pass + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + # Calculate success rate based on final reward == 1.0 + final_rewards = batch.get( + "total_reward", torch.tensor([0.0] * len(batch["idx"])) + ) + success_rate = ( + (final_rewards == 1.0).float().mean().item() + if len(final_rewards) > 0 + else 0.0 + ) + # Could also calculate average number of moves for successful episodes, etc. 
+        return batch, {"sliding_puzzle_success_rate": success_rate}

From bbbbe88b9966f4d98dd1a1f901796d7d265b5eb8 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Thu, 17 Apr 2025 14:02:34 -0700
Subject: [PATCH 02/34] Removed redundant imports

Signed-off-by: Sahil Jain
---
 tests/unit/experience/test_rollouts.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index 412654fff8..f9bfdbb2ad 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -15,21 +15,17 @@
 import pytest
 import ray
 import torch
-from typing import Dict, List, Tuple, Optional, TypedDict
 from copy import deepcopy
 import gc
 
-from transformers import AutoTokenizer, GPT2TokenizerFast
+from transformers import AutoTokenizer
 
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
 from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
-from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType
-from nemo_reinforcer.environments.interfaces import EnvironmentInterface
 from nemo_reinforcer.models.policy import PolicyConfig
 from nemo_reinforcer.models.policy.hf_policy import HfPolicy
 from nemo_reinforcer.models.generation.interfaces import configure_generation_config
 from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout
-from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES
 
 # Import the test environment definitions
 from tests.unit.test_envs import (

From 811067b589591cd00aa037c60db61ce938e060b8 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Thu, 17 Apr 2025 14:46:04 -0700
Subject: [PATCH 03/34] Fixed Math env on multiturn

Signed-off-by: Sahil Jain
---
 nemo_reinforcer/environments/math_environment.py | 2 +-
 nemo_reinforcer/experience/rollouts.py           | 7 +++++--
 tests/unit/experience/test_rollouts.py           | 5 +++++
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/nemo_reinforcer/environments/math_environment.py b/nemo_reinforcer/environments/math_environment.py
index b72c031e29..efbfcc01bc 100644
--- a/nemo_reinforcer/environments/math_environment.py
+++ b/nemo_reinforcer/environments/math_environment.py
@@ -148,7 +148,7 @@ def step(
 
         rewards = torch.tensor(results).cpu()
         done = torch.ones_like(rewards).cpu()
-        next_stop_strings = None
+        next_stop_strings = [[None]] * len(message_log_batch)
 
         return EnvironmentReturn(
             observations=observations,
diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py
index d402bdff4c..14c6ef5a1e 100644
--- a/nemo_reinforcer/experience/rollouts.py
+++ b/nemo_reinforcer/experience/rollouts.py
@@ -173,6 +173,8 @@ def calculate_rewards(
         env_observations, metadata, next_stop_strings, task_rewards, terminateds = (
             result
         )
+        if next_stop_strings is None:
+            next_stop_strings = [[None]] * len(task_rewards)
 
         # Store results with their original indices
         for i, idx in enumerate(indices):
@@ -324,7 +326,7 @@ def run_multi_turn_rollout(
 
         # Update message log for ALL active samples with env observation
        # This must happen BEFORE filtering based on done flags
-        truncation_mask = torch.zeros_like(env_output.terminateds)
+        truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool)
         for i, global_idx in enumerate(active_indices.tolist()):
             env_obs_content = env_output.env_observations[i]["content"]
             # Tokenize the raw content from the environment
@@ -356,7 +358,8 @@ def run_multi_turn_rollout(
             sample_turn_counts[global_idx] += 1
 
         # Determine done samples and update active set
-        done = env_output.terminateds | truncation_mask
+        terminateds = env_output.terminateds.bool()
+        done = terminateds | truncation_mask
         active_mask = ~done
diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index f9bfdbb2ad..04639a4275 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -302,6 +302,11 @@ def multi_step_setup_vllm(
         print("Cleaning up VllmGeneration (Multi-Step Calc Test)...")
         if vllm_generation:
             vllm_generation.shutdown()
+            # Force garbage collection to help release resources
+            import gc
+
+            gc.collect()
+            torch.cuda.empty_cache()
         print("VllmGeneration cleanup finished (Multi-Step Calc Test).")

From d1c59b2c040e7b344017c7ca946b5dd23a26f1d7 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Thu, 17 Apr 2025 15:44:10 -0700
Subject: [PATCH 04/34] Fixed nondeterministic multiturn error bug

Signed-off-by: Sahil Jain
---
 nemo_reinforcer/algorithms/grpo.py            | 19 ++++++++++---------
 nemo_reinforcer/data/llm_message_utils.py     |  7 +++++--
 nemo_reinforcer/experience/rollouts.py        | 17 -----------------
 tests/unit/environments/game_interface.py     | 14 ++++++++++++++
 .../unit/environments/sliding_puzzle_game.py  | 14 ++++++++++++++
 5 files changed, 43 insertions(+), 28 deletions(-)

diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py
index 4116ae11a6..9df248a4ba 100644
--- a/nemo_reinforcer/algorithms/grpo.py
+++ b/nemo_reinforcer/algorithms/grpo.py
@@ -637,23 +637,24 @@ def validate(
         )
 
         # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
-        val_batch, generated_ids, gen_metrics = generate_responses(
+        val_batch, gen_metrics = run_multi_turn_rollout(
             policy_generation,
             generation_input_data,
-            val_batch,
             tokenizer,
-            input_lengths,
-            include_logprobs=False,
+            val_task_to_env,
+            max_seq_len=master_config["policy"]["max_total_sequence_length"],
+            max_turns=master_config["policy"]["max_turns"],
+            greedy=False,
         )
-
-        # Calculate rewards based on the updated LLMMessageLogType
-        with timer.time("reward_calculation"):
-            rewards, to_env = calculate_rewards(val_batch, val_task_to_env)
+        rewards = val_batch["total_reward"]
 
         total_rewards.extend(rewards.tolist())
-        total_lengths.extend([len(ids) for ids in generated_ids])
+        total_lengths.extend(gen_metrics["mean_generation_length"].tolist())
 
         # Collect message logs for later display
+        to_env = get_keys_from_message_log(
+            val_batch["message_log"], ["role", "content"]
+        )
         all_message_logs.extend(to_env)
 
         # Calculate validation metrics
diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py
index 362183d978..db908f7bc9 100644
--- a/nemo_reinforcer/data/llm_message_utils.py
+++ b/nemo_reinforcer/data/llm_message_utils.py
@@ -289,8 +289,11 @@ def batched_message_log_to_flat_message(
     # Create input_lengths tensor
     input_lengths = []
     for seq in sequenced_lists:
-        seq_len = next(
-            (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0
+        # Find the maximum length among all tensors in the dictionary, default to 0 if none exist
+        # Use maximum here since there may be keys that aren't populated for all messages yet.
+        # For example, logprobs don't get populated for non-generated tokens until post-processing.
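+        # (e.g. a message whose token_ids and generation_logprobs are both length 10
+        # yields seq_len 10 regardless of which tensor the dict happens to list first)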
+ seq_len = max( + (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), default=0 ) input_lengths.append(seq_len) input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32) diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py index 14c6ef5a1e..a6ec956a0c 100644 --- a/nemo_reinforcer/experience/rollouts.py +++ b/nemo_reinforcer/experience/rollouts.py @@ -258,13 +258,8 @@ def run_multi_turn_rollout( for turn in range(max_turns): if len(active_indices) == 0: - print(f" Turn {turn + 1}/{max_turns}: All samples finished.") break - print( - f" Turn {turn + 1}/{max_turns}: Processing {len(active_indices)} active samples..." - ) - active_samples_per_turn.append(len(active_indices)) # Convert LLMMessageLogType to FlatMessagesType for generation @@ -299,9 +294,6 @@ def run_multi_turn_rollout( active_input_lengths, greedy=greedy, ) - print( - f" Generated responses (Avg len: {gen_metrics['mean_generation_length']:.1f})" - ) # Record token usage - assistant for i, global_idx in enumerate(active_indices.tolist()): @@ -320,10 +312,6 @@ def run_multi_turn_rollout( # Record rewards for this turn reward_per_turn.append(env_output.rewards.mean().item()) - print( - f" Calculated rewards (Avg: {turn_rewards[active_indices].mean():.3f})" - ) - # Update message log for ALL active samples with env observation # This must happen BEFORE filtering based on done flags truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool) @@ -373,11 +361,6 @@ def run_multi_turn_rollout( if env_output.terminateds[idx]: sample_terminated[global_idx] = True - print( - f" {len(newly_finished_indices_global)} samples finished this turn." - f" (Terminated: {env_output.terminateds.sum()})" - ) - # Update active indices for the next iteration active_indices_local_next = torch.where(active_mask)[0] active_indices = active_indices[active_indices_local_next] diff --git a/tests/unit/environments/game_interface.py b/tests/unit/environments/game_interface.py index 77c7396a18..2f0237ed23 100644 --- a/tests/unit/environments/game_interface.py +++ b/tests/unit/environments/game_interface.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Dict, Any, Tuple, List, Optional, Callable diff --git a/tests/unit/environments/sliding_puzzle_game.py b/tests/unit/environments/sliding_puzzle_game.py index 584eb6b3fb..664e4c312b 100644 --- a/tests/unit/environments/sliding_puzzle_game.py +++ b/tests/unit/environments/sliding_puzzle_game.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random import copy from typing import List, Tuple, Dict, Any, Optional From 95c3218d7952274be8517d98d22e1b05473f0706 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 2025 17:24:32 -0700 Subject: [PATCH 05/34] Fixed validation error Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/grpo.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 9df248a4ba..09d4249675 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -621,25 +621,10 @@ def validate( if batch_idx >= max_batches: break - # Convert LLMMessageLogType to FlatMessagesType for generation - batched_flat, input_lengths = batched_message_log_to_flat_message( - val_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - ) - # Extract input IDs - input_ids = batched_flat["token_ids"] - # Create generation-specific input structure - generation_input_data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - } - ) - # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) val_batch, gen_metrics = run_multi_turn_rollout( policy_generation, - generation_input_data, + val_batch, tokenizer, val_task_to_env, max_seq_len=master_config["policy"]["max_total_sequence_length"], From ceb6035bcc07b45ad966765a7da16e59cff0a913 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 2025 17:30:17 -0700 Subject: [PATCH 06/34] Fixed validation error Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/grpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 09d4249675..ef1e5b2d66 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -634,7 +634,7 @@ def validate( rewards = val_batch["total_reward"] total_rewards.extend(rewards.tolist()) - total_lengths.extend(gen_metrics["mean_generation_length"].tolist()) + total_lengths.extend(gen_metrics["mean_gen_tokens_per_sample"].tolist()) # Collect message logs for later display to_env = get_keys_from_message_log( From 14a34173f36de1a9e162d4d41f9855a80e6e664f Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 2025 17:34:59 -0700 Subject: [PATCH 07/34] Fixed validation error Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/grpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index ef1e5b2d66..1c0d8d585c 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -634,7 +634,7 @@ def validate( rewards = val_batch["total_reward"] total_rewards.extend(rewards.tolist()) - total_lengths.extend(gen_metrics["mean_gen_tokens_per_sample"].tolist()) + total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) # Collect message logs for later display to_env = get_keys_from_message_log( From d112702a2a589119138ecf6bb9790fe16de54ec8 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 
2025 17:55:23 -0700 Subject: [PATCH 08/34] <1 lp error ?? Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/loss_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 0d2e61a9ac..0286bcd8d0 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -98,6 +98,9 @@ def __call__( lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) mult_prob_error = masked_mean(torch.exp(lp_error), mask).item() + if mult_prob_error == 0.0: + # this sometimes gets 0 (everything masked/invalid). Doing this to avoid screwing up stats too much + mult_prob_error = 1.0 next_token_logits = next_token_logits.to(torch.float32) From f370c4a56e2c640d7593579aab665d80f03a0785 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 2025 18:10:20 -0700 Subject: [PATCH 09/34] debugging Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/loss_functions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 0286bcd8d0..ec9ad84c52 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -100,6 +100,10 @@ def __call__( mult_prob_error = masked_mean(torch.exp(lp_error), mask).item() if mult_prob_error == 0.0: # this sometimes gets 0 (everything masked/invalid). Doing this to avoid screwing up stats too much + print("mult_prob_error is 0") + print("mask sum", mask.sum()) + print(f"token_mask: {token_mask}") + print(f"sample_mask: {sample_mask}") mult_prob_error = 1.0 next_token_logits = next_token_logits.to(torch.float32) From f3a5001c66f11106c930b21a48862d9d89c0a897 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 17 Apr 2025 21:25:09 -0700 Subject: [PATCH 10/34] remove debugging Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/loss_functions.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index ec9ad84c52..0286bcd8d0 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -100,10 +100,6 @@ def __call__( mult_prob_error = masked_mean(torch.exp(lp_error), mask).item() if mult_prob_error == 0.0: # this sometimes gets 0 (everything masked/invalid). 
Doing this to avoid screwing up stats too much - print("mult_prob_error is 0") - print("mask sum", mask.sum()) - print(f"token_mask: {token_mask}") - print(f"sample_mask: {sample_mask}") mult_prob_error = 1.0 next_token_logits = next_token_logits.to(torch.float32) From 5c90d641e56094438caff8a7a2310cb6b2ca1542 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 18 Apr 2025 14:40:05 -0700 Subject: [PATCH 11/34] cleanup Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/grpo.py | 7 -- nemo_reinforcer/environments/interfaces.py | 20 +++++- .../environments/math_environment.py | 2 +- nemo_reinforcer/experience/rollouts.py | 68 ++++++------------- tests/unit/algorithms/test_grpo.py | 6 +- .../environments/test_math_environment.py | 10 +-- tests/unit/test_envs.py | 4 +- 7 files changed, 48 insertions(+), 69 deletions(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 1c0d8d585c..1b829fef64 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -363,13 +363,6 @@ def grpo_train( pad_value_dict={"token_ids": tokenizer.pad_token_id}, ) input_ids = batched_flat["token_ids"] - # # Create generation-specific input structure - # generation_input_data = BatchedDataDict[GenerationDatumSpec]( - # { - # "input_ids": input_ids, - # "input_lengths": input_lengths, - # } - # ) # Generate responses - this updates the LLMMessageLogType in repeated_batch print(f"▶ Generating responses for batch of size {repeated_batch.size}...") diff --git a/nemo_reinforcer/environments/interfaces.py b/nemo_reinforcer/environments/interfaces.py index 881f3467b4..46f42bc24f 100644 --- a/nemo_reinforcer/environments/interfaces.py +++ b/nemo_reinforcer/environments/interfaces.py @@ -21,13 +21,27 @@ class EnvironmentReturn(NamedTuple): - """Standard return type for environment step methods.""" + """Standard batched return type for environment step methods. + + **All elements are batched.** + observations: New observation from the environment. + It's a (batched) 'message' type, which is a dict + with keys 'role' and 'content'. + metadata: Updated metadata from the environment. + next_stop_strings: The stop strings for the next turn. + If your environment is a game or similar, + you may want to return a list of stop strings + that are valid actions for the next turn or + similar. This field lets you control this per turn. + rewards: the rewards for this turn. + terminateds: whether the episode ended this turn. + """ observations: List[Dict[str, str]] metadata: List[Optional[dict]] next_stop_strings: List[Optional[List[str]]] rewards: Tensor - terminated: Tensor + terminateds: Tensor class EnvironmentInterface(abc.ABC): @@ -61,7 +75,7 @@ def step( math solutions, code unit tests, or agent states. Can be None if episode terminated. Returns: - - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminated flags. + - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags. 
""" @abc.abstractmethod diff --git a/nemo_reinforcer/environments/math_environment.py b/nemo_reinforcer/environments/math_environment.py index efbfcc01bc..81f41eb3da 100644 --- a/nemo_reinforcer/environments/math_environment.py +++ b/nemo_reinforcer/environments/math_environment.py @@ -155,7 +155,7 @@ def step( metadata=metadata, next_stop_strings=next_stop_strings, rewards=rewards, - terminated=done, + terminateds=done, ) def global_post_process_and_metrics( diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py index a6ec956a0c..46a12ba641 100644 --- a/nemo_reinforcer/experience/rollouts.py +++ b/nemo_reinforcer/experience/rollouts.py @@ -39,15 +39,6 @@ ) -# Return type for calculate_rewards -class RewardsOutput(NamedTuple): - rewards: torch.Tensor - env_observations: List[Dict[str, str]] - terminateds: torch.Tensor - next_stop_strings: List[Optional[List[str]]] - metadata: List[Optional[Dict[str, Any]]] - - def generate_responses( policy_generation: GenerationInterface, generation_input_data: BatchedDataDict[GenerationDatumSpec], @@ -111,7 +102,7 @@ def generate_responses( def calculate_rewards( batch: BatchedDataDict[DatumSpec], task_to_env: Dict[str, EnvironmentInterface], -) -> RewardsOutput: +) -> EnvironmentReturn: """Calculate rewards for generated responses and get environment feedback. Args: @@ -119,19 +110,19 @@ def calculate_rewards( task_to_env: Dictionary mapping task names to their corresponding environments Returns: - Tuple containing: + EnvironmentReturn namedtuple containing: + - observations: List of observations from the environment for the next turn. + - metadata: List of extracted metadata from the environment. + - next_stop_strings: List of stop strings for the next generation step. - rewards: Tensor of rewards for the last turn. - - env_observations: List of observations from the environment for the next turn. - terminateds: Tensor of booleans indicating if an episode ended naturally. - - next_stop_strings: List of stop strings for the next generation step. - - metadata: List of extracted metadata from the environment. 
""" # Extract message logs for environment (most recent interaction) to_env = [ get_keys_from_message_log(batch["message_log"][i], ["role", "content"]) for i in range(len(batch["message_log"])) ] - task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))] + task_names = batch["task_name"] # Group messages by task type task_groups = {} @@ -195,16 +186,15 @@ def calculate_rewards( next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] metadata = [all_metadata[i] for i in sorted_indices] # Sort metadata - # Ensure tensors are on CPU - rewards = rewards.cpu() - terminateds = terminateds.cpu() + rewards = rewards + terminateds = terminateds - return RewardsOutput( + return EnvironmentReturn( + observations=env_observations, + metadata=metadata, + next_stop_strings=next_stop_strings, rewards=rewards, - env_observations=env_observations, terminateds=terminateds, - next_stop_strings=next_stop_strings, - metadata=metadata, ) @@ -236,7 +226,6 @@ def run_multi_turn_rollout( current_batch = initial_batch.copy() # Work on a copy batch_size = len(current_batch["message_log"]) active_indices = torch.arange(batch_size) - turn_rewards = torch.zeros(batch_size, dtype=torch.float32) total_rewards = torch.zeros(batch_size, dtype=torch.float32) # Initialize stop_strings from the initial batch if present @@ -253,7 +242,6 @@ def run_multi_turn_rollout( # Tracking per-turn metrics total_gen_tokens_per_turn = [] - reward_per_turn = [] active_samples_per_turn = [] for turn in range(max_turns): @@ -304,19 +292,15 @@ def run_multi_turn_rollout( total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids)) # Calculate rewards and get environment feedback - env_output: RewardsOutput = calculate_rewards(active_batch, task_to_env) + env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env) - turn_rewards[active_indices] = env_output.rewards - total_rewards[active_indices] += turn_rewards[active_indices] - - # Record rewards for this turn - reward_per_turn.append(env_output.rewards.mean().item()) + total_rewards[active_indices] += env_output.rewards # Update message log for ALL active samples with env observation # This must happen BEFORE filtering based on done flags truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool) for i, global_idx in enumerate(active_indices.tolist()): - env_obs_content = env_output.env_observations[i]["content"] + env_obs_content = env_output.observations[i]["content"] # Tokenize the raw content from the environment # TODO @sahilj: handle if we want these subsequent messages to have a chat template tokenized_obs = tokenizer( @@ -332,7 +316,7 @@ def run_multi_turn_rollout( sample_truncated[active_indices[i]] = True tokenized_env_obs_message = { - "role": env_output.env_observations[i]["role"], + "role": env_output.observations[i]["role"], "content": env_obs_content, "token_ids": tokenized_obs, } @@ -347,22 +331,11 @@ def run_multi_turn_rollout( # Determine done samples and update active set terminateds = env_output.terminateds.bool() - done = terminateds | truncation_mask - active_mask = ~done - - # Identify samples that just finished this turn - newly_finished_indices_local = torch.where(done)[0] - newly_finished_indices_global = active_indices[newly_finished_indices_local] - - # Record termination status - for i, idx in enumerate(newly_finished_indices_local.tolist()): - global_idx = active_indices[idx].item() - # Record whether this sample terminated naturally - if env_output.terminateds[idx]: - 
sample_terminated[global_idx] = True + done = truncation_mask | terminateds + sample_terminated[active_indices] |= done # Update active indices for the next iteration - active_indices_local_next = torch.where(active_mask)[0] + active_indices_local_next = torch.where(~done)[0] active_indices = active_indices[active_indices_local_next] continuing_indices_global = active_indices # Indices relative to original batch # Get next stop strings and infos corresponding to the indices that are *continuing* @@ -382,8 +355,7 @@ def run_multi_turn_rollout( current_batch["extra_env_info"][global_idx] = continuing_metadata[i] # Record samples that reached max turns - if len(active_indices) > 0: - sample_max_turns_reached[active_indices] = True + sample_max_turns_reached[active_indices] = True # Add total rewards to the final batch current_batch["total_reward"] = total_rewards diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 7a492e2b01..da1a21244f 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -126,7 +126,7 @@ def test_calculate_rewards_single_task(mock_env): batch = create_mock_batch(2, task_names, message_logs) # Calculate rewards - rewards, env_observations, terminateds, next_stop_strings, metadata = ( + env_observations, metadata, next_stop_strings, rewards, terminateds = ( calculate_rewards(batch, task_to_env) ) @@ -161,7 +161,7 @@ def test_calculate_rewards_multiple_tasks(mock_envs): batch = create_mock_batch(4, task_names, message_logs) # Calculate rewards - rewards, env_observations, terminateds, next_stop_strings, metadata = ( + env_observations, metadata, next_stop_strings, rewards, terminateds = ( calculate_rewards(batch, mock_envs) ) @@ -188,7 +188,7 @@ def test_calculate_rewards_empty_batch(mock_env): batch = create_mock_batch(0, [], []) # Calculate rewards - rewards, env_observations, terminateds, next_stop_strings, metadata = ( + env_observations, metadata, next_stop_strings, rewards, terminateds = ( calculate_rewards(batch, task_to_env) ) diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py index 6fb9897e7b..c26035ce15 100644 --- a/tests/unit/environments/test_math_environment.py +++ b/tests/unit/environments/test_math_environment.py @@ -135,10 +135,10 @@ def test_math_env_step_basic(math_env, basic_test_data): # Check rewards and done flags assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" assert all(result.rewards == 1.0), "All rewards should be 1.0 for correct answers" - assert result.terminated.shape == (3,), ( + assert result.terminateds.shape == (3,), ( "Terminated flags should be a tensor of shape (3,)" ) - assert all(result.terminated == 1.0), "All terminated flags should be 1.0" + assert all(result.terminateds == 1.0), "All terminated flags should be 1.0" def test_math_env_step_mixed(math_env, mixed_test_data): @@ -177,7 +177,7 @@ def test_math_env_step_empty(math_env): assert len(result.observations) == 0, "Should return empty observations list" assert len(result.metadata) == 0, "Should return empty metadata list" assert result.rewards.shape == (0,), "Should return empty rewards tensor" - assert result.terminated.shape == (0,), "Should return empty terminated tensor" + assert result.terminateds.shape == (0,), "Should return empty terminateds tensor" def test_math_env_step_multiple_assistant_messages( @@ -225,7 +225,7 @@ def test_math_env_various_batches(math_env, batch_size): "Rewards should be a tensor of 
shape (batch_size,)" ) assert all(result.rewards == 1.0), "All rewards should be 1.0" - assert result.terminated.shape == (batch_size,), ( + assert result.terminateds.shape == (batch_size,), ( "Terminated flags should be a tensor of shape (batch_size,)" ) - assert all(result.terminated == 1.0), "All terminated flags should be 1.0" + assert all(result.terminateds == 1.0), "All terminated flags should be 1.0" diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py index ccf72c1d0c..242dd1fb62 100644 --- a/tests/unit/test_envs.py +++ b/tests/unit/test_envs.py @@ -326,7 +326,7 @@ def step( metadata=all_next_metadata, next_stop_strings=all_stop_strings, rewards=rewards_tensor, - terminated=done_tensor, + terminateds=done_tensor, ) def shutdown(self): @@ -387,7 +387,7 @@ def step( metadata=all_next_metadata, next_stop_strings=all_stop_strings, rewards=rewards_tensor, - terminated=terminated_tensor, + terminateds=terminated_tensor, ) def shutdown(self): From 7a803f090f3fc9fe31b3570ebeaf4cb045eb6cb0 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 18 Apr 2025 16:25:37 -0700 Subject: [PATCH 12/34] Fix multiturn multigpu bugs Signed-off-by: Sahil Jain --- .../distributed/batched_data_dict.py | 29 +++++++++++++++---- nemo_reinforcer/experience/rollouts.py | 5 +++- nemo_reinforcer/models/generation/vllm.py | 16 ++++++++-- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/nemo_reinforcer/distributed/batched_data_dict.py b/nemo_reinforcer/distributed/batched_data_dict.py index a325dcff49..994781d4b2 100644 --- a/nemo_reinforcer/distributed/batched_data_dict.py +++ b/nemo_reinforcer/distributed/batched_data_dict.py @@ -137,7 +137,10 @@ def chunk(self, rank: int, chunks: int) -> "SlicedDataDict": return chunked_batch def shard_by_batch_size( - self, shards: int, batch_size: Optional[int] = None + self, + shards: int, + batch_size: Optional[int] = None, + allow_uneven_shards: bool = False, ) -> List["SlicedDataDict"]: """Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. @@ -150,10 +153,17 @@ def shard_by_batch_size( Args: shards (int): The number of shards to divide each batch_size chunk into. batch_size (int): The size of each initial chunk. + allow_uneven_shards (bool): Whether to allow shards to be unevenly sized. + If True, the last shard may be smaller than the others. Returns: List[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. 
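+        Example (an illustrative sketch; the sizes are made up): for a
+        BatchedDataDict `data` holding 5 rows, shard_size = ceil(5 / 2) = 3:
+
+            shard_list = data.shard_by_batch_size(2, allow_uneven_shards=True)
+            # shard_list[0] holds rows [0, 1, 2]; shard_list[1] holds rows [3, 4]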
""" + if allow_uneven_shards: + assert batch_size is None, ( + "batch_size must be None if allow_uneven_shards is True" + ) + # Get the total batch size batch_sizes = set() for val in self.data.values(): @@ -173,13 +183,18 @@ def shard_by_batch_size( assert total_batch_size % batch_size == 0, ( f"Total batch size ({total_batch_size}) is not a multiple of batch_size ({batch_size})" ) - assert batch_size % shards == 0, ( - f"Batch size ({batch_size}) is not a multiple of shards ({shards})" - ) + if not allow_uneven_shards: + assert batch_size % shards == 0, ( + f"Batch size ({batch_size}) is not a multiple of shards ({shards})" + ) num_chunks = total_batch_size // batch_size - shard_size = batch_size // shards - # Create one BatchedDataDict per shard position + # Calculate shard size, rounding up if not evenly divisible + shard_size = ( + (batch_size + shards - 1) // shards + if allow_uneven_shards + else batch_size // shards + ) aggregated_shards = [SlicedDataDict() for _ in range(shards)] # Group data by shard position across all chunks @@ -189,6 +204,8 @@ def shard_by_batch_size( chunk_start = chunk_idx * batch_size shard_start = chunk_start + shard_idx * shard_size shard_end = chunk_start + (shard_idx + 1) * shard_size + if allow_uneven_shards: + shard_end = min(shard_end, total_batch_size) indices = torch.arange(shard_start, shard_end) for k in self.data: diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py index 46a12ba641..891c5dae16 100644 --- a/nemo_reinforcer/experience/rollouts.py +++ b/nemo_reinforcer/experience/rollouts.py @@ -308,7 +308,10 @@ def run_multi_turn_rollout( )["input_ids"][0] # check if new message overflows max_seq_len - if len(tokenized_obs) + active_input_lengths[i] > max_seq_len: + if ( + len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i] + >= max_seq_len + ): # truncate tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]] truncation_mask[i] = True diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 57ecde59ed..03851f2de7 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -224,6 +224,18 @@ def generate( - generation_lengths: Lengths of each response - unpadded_sequence_lengths: Lengths of each input + generated sequence """ + # Handle empty input case + if len(data["input_ids"]) == 0: + # Return empty BatchedDataDict with all required fields + return BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": torch.zeros((0, 0), dtype=torch.long), + "logprobs": torch.zeros((0, 0), dtype=torch.float), + "generation_lengths": torch.zeros(0, dtype=torch.long), + "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long), + } + ) + input_ids = data["input_ids"] input_lengths = data["input_lengths"] # this function requires all generations have the same stop strings, so we collect all here @@ -535,10 +547,8 @@ def generate( "input_ids and input_lengths are required in data for vLLM generation" ) - batch_size = data["input_ids"].shape[0] - # Shard the data across the tied worker groups - sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=batch_size) + sharded_data = data.shard_by_batch_size(self.dp_size, allow_uneven_shards=True) future_bundle = self.worker_group.run_all_workers_multiple_data( "generate", sharded_data, From c9d1298378df586260b50a69da4d51d9dda3f981 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 18 Apr 2025 16:26:25 -0700 Subject: [PATCH 13/34] adding 
sliding puzzle training scripts

Signed-off-by: Sahil Jain
---
 examples/configs/grpo_sliding_puzzle.yaml | 107 ++++++
 examples/run_grpo_sliding_puzzle.py | 378 ++++++++++++++++++++++
 2 files changed, 485 insertions(+)
 create mode 100644 examples/configs/grpo_sliding_puzzle.yaml
 create mode 100644 examples/run_grpo_sliding_puzzle.py

diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
new file mode 100644
index 0000000000..42682d92c9
--- /dev/null
+++ b/examples/configs/grpo_sliding_puzzle.yaml
@@ -0,0 +1,107 @@
+defaults: "grpo_math_1B.yaml"
+
+# Environment setup: Map task names to their configurations
+env:
+  sliding_puzzle_game:
+    env_class: "tests.unit.test_envs.SlidingPuzzleEnv" # Path to the environment actor class
+    # Configuration passed to the SlidingPuzzleEnv constructor
+    cfg:
+      # Game generation parameters
+      game_config:
+        size: 3 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
+        shuffle_moves: 5 # Number of random moves to shuffle the solved state
+      # Gameplay parameters
+      max_moves: 50 # Maximum moves allowed per episode
+
+grpo:
+  max_num_steps: 10000
+
+checkpointing:
+  enabled: true
+  checkpoint_dir: "results/grpo-sliding-puzzle"
+  metric_name: "val_reward"
+  higher_is_better: true
+  keep_top_k: 3
+  save_period: 10
+
+policy:
+  model_name: "meta-llama/Llama-3.2-1B-Instruct"
+  tokenizer:
+    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
+  train_global_batch_size: 512
+  train_micro_batch_size: 4
+  generation_batch_size: 32 # Only used when generating using HF backend
+  logprob_batch_size: 4
+  max_total_sequence_length: 1024
+  max_turns: 50
+  precision: "bfloat16"
+  fsdp_offload_enabled: false
+  activation_checkpointing_enabled: false
+
+  dtensor_cfg:
+    enabled: false
+    cpu_offload: False
+    sequence_parallel: false
+    activation_checkpointing: false
+    tensor_parallel_size: 1
+
+  # makes the training sequence length divisible by the tensor parallel size
+  # this is useful for sequence parallel training
+  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
+  max_grad_norm: 1.0
+
+  optimizer:
+    name: "torch.optim.AdamW"
+    kwargs:
+      lr: 5.0e-6
+      weight_decay: 0.01
+      betas: [0.9, 0.999]
+      eps: 1e-8
+      # when using Dtensor, we need to set foreach
+      # and fused to False
+      foreach: False
+      fused: False
+
+  scheduler:
+    - name: "torch.optim.lr_scheduler.LinearLR"
+      kwargs:
+        start_factor: 0.1
+        end_factor: 1.0
+        total_iters: 50
+    - name: "torch.optim.lr_scheduler.ConstantLR"
+      kwargs:
+        factor: 1.0
+        total_iters: 10000000000
+    - milestones: [50]
+
+  generation:
+    backend: "vllm"
+    max_new_tokens: ${policy.max_total_sequence_length}
+    temperature: 1.0
+    top_p: 1.0
+    top_k: null
+    stop_token_ids: null
+    stop_strings: null
+    vllm_cfg:
+      tensor_parallel_size: 1
+      gpu_memory_utilization: 0.6
+      max_model_len: ${policy.max_total_sequence_length}
+
+logger:
+  log_dir: "logs" # Base directory for all logs
+  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
+  wandb_enabled: false
+  tensorboard_enabled: false
+  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
+  wandb:
+    project: "grpo-dev"
+    name: "grpo-dev-sliding_puzzle"
+  tensorboard: {}
+  gpu_monitoring:
+    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
+    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
+
+cluster:
+  gpus_per_node: 1
+  num_nodes: 1
+

diff --git
a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py new file mode 100644 index 0000000000..b6a549b354 --- /dev/null +++ b/examples/run_grpo_sliding_puzzle.py @@ -0,0 +1,378 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +import itertools # For infinite counter +from collections import defaultdict +from typing import Any, Dict, Tuple, List, Iterator # Added Iterator + +import torch # Added torch import +from omegaconf import OmegaConf +from transformers import AutoTokenizer + +# === MODIFIED: Use IterableDataset === +from torch.utils.data import Dataset, IterableDataset # Import IterableDataset + +# === Core Imports (Keep from math example) === +from nemo_reinforcer.algorithms.grpo import ( + MasterConfig, + grpo_train, + setup, +) # CRITICAL: Keep imported setup +from nemo_reinforcer.algorithms.utils import get_tokenizer + +# from nemo_reinforcer.data import DataConfig # Keep if setup needs it, maybe remove later +# from nemo_reinforcer.data.interfaces import TaskDataSpec # Remove later if not needed by setup_puzzle_data +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.models.generation.interfaces import configure_generation_config +from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides +from nemo_reinforcer.utils.logger import get_next_experiment_dir + +from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame +from tests.unit.test_envs import SlidingPuzzleEnv, SlidingPuzzleMetadata +from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec + + +def generate_puzzle_datum( + tokenizer, + game_config: Dict, + max_moves: int, + task_name: str, + idx: int, + policy_model_name: str, +) -> DatumSpec: + """Generates a single sliding puzzle datum (prompt and metadata).""" + # (Content copied from previous correct version) + initial_game_state = SlidingPuzzleGame.generate(game_config) + initial_render = SlidingPuzzleGame.render(initial_game_state) + welcome_message = SlidingPuzzleGame.init(initial_game_state) + puzzle_size = game_config.get("size", 3) + prompt_instructions = ( + f"{welcome_message}\n\n" + f"Current Board State:\n{initial_render}\n\n" + f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} " + f"with the empty space (0) at the bottom right.\n" + f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n" + f"After thinking, output your chosen action on a new line starting with 'Action:' like this:\nAction: your_action" + f"\nThink step-by-step before acting.\n" + ) + add_system_prompt = "chat" in policy_model_name.lower() + initial_prompt_content = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_instructions}], + tokenize=False, + add_system_prompt=add_system_prompt, + add_generation_prompt=True, + add_special_tokens=False, + ).strip() + tokenized_prompt = tokenizer( + 
initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log: LLMMessageLogType = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + metadata = SlidingPuzzleMetadata( + game_state=initial_game_state, num_moves=0, max_moves=max_moves + ) + datum: DatumSpec = { + "message_log": message_log, + "length": len(tokenized_prompt), + "extra_env_info": metadata, + "loss_multiplier": 1.0, + "idx": idx, + "task_name": task_name, + } + return datum + + +# === MODIFIED: Replace PreGeneratedPuzzleDataset with IterablePuzzleDataset === +class IterablePuzzleDataset(IterableDataset): + """An IterableDataset that generates sliding puzzle data indefinitely.""" + + # === MODIFIED: Removed dataset_size, generates indefinitely === + def __init__( + self, tokenizer, game_config, max_moves, task_name, policy_model_name, length + ): + super().__init__() + self.tokenizer = tokenizer + self.game_config = game_config + self.max_moves = max_moves + self.task_name = task_name + self.policy_model_name = policy_model_name + self.length = length + + def __iter__(self) -> Iterator[DatumSpec]: + print( + f"Starting new iteration of IterablePuzzleDataset (indefinite generation)." + ) + # Use itertools.count for an infinite index generator + for i in itertools.count(): + yield generate_puzzle_datum( + tokenizer=self.tokenizer, + game_config=self.game_config, + max_moves=self.max_moves, + task_name=self.task_name, + idx=i, + policy_model_name=self.policy_model_name, + ) + # This print message will never be reached in normal operation + # print(f"Finished iteration of IterablePuzzleDataset.") + + def __len__(self): + return self.length + + +# === MODIFIED: setup_puzzle_data now returns IterablePuzzleDataset === +def setup_puzzle_data( + tokenizer: AutoTokenizer, + # === MODIFIED: Accept `env_cfg` instead of `env_configs` === + env_cfg: Dict[str, Any], + policy_cfg: Dict[str, Any], + task_name: str, + length: int, +) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]: + """Sets up the iterable data generator and env map for the sliding puzzle task.""" + print("Setting up Sliding Puzzle iterable data and environment...") + # === MODIFIED: Access env config directly via task_name === + env_config = env_cfg[task_name] + + # --- Instantiate Environment Actor --- # + print(f"Instantiating environment actor for task '{task_name}'...") + module_path, class_name = env_config["env_class"].rsplit(".", 1) + try: + EnvClass = getattr(__import__(module_path, fromlist=[class_name]), class_name) + except ImportError as e: + print( + f"ERROR: Could not import environment class {env_config['env_class']}. Ensure it's in PYTHONPATH." 
+ ) + raise e + env_actor = EnvClass.options(num_gpus=0).remote(cfg=dict(env_config["cfg"])) + task_to_env = {task_name: env_actor} + print(f"Environment actor '{task_name}' created.") + + # --- Instantiate Iterable Dataset --- # + print(f"Creating IterablePuzzleDataset...") + training_dataset = IterablePuzzleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_moves=env_config["cfg"]["max_moves"], + task_name=task_name, + policy_model_name=policy_cfg.get("model_name", ""), + length=length, + ) + print("Iterable training dataset created.") + + validation_dataset = IterablePuzzleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_moves=env_config["cfg"]["max_moves"], + task_name=task_name, + policy_model_name=policy_cfg.get("model_name", ""), + length=256, + ) + val_task_to_env = task_to_env + + return training_dataset, validation_dataset, task_to_env, val_task_to_env + + +# === Argparse function (Keep as is) === +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +# === Main function (Follow math structure exactly) === +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + # Default config path + if not args.config: + # --- MODIFIED: Default config path --- + default_config_path = os.path.join( + os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml" + ) + if not os.path.exists(default_config_path): + raise FileNotFoundError( + f"Default config file not found at {default_config_path}." + ) + args.config = default_config_path + print(f"No config provided, using default: {args.config}") + + # Load base config + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + # Apply overrides + if overrides: + print(f"Applying overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) # Returns OmegaConf object + print("Applied CLI overrides.") + else: + # Ensure config is OmegaConf object even without overrides for consistency + config = OmegaConf.create(config) + + # Convert final config to dictionary for local use AFTER overrides + # Use resolve=True to handle interpolations if any remain + final_config_obj = config # Keep as OmegaConf object for setup/utils + final_config_dict = OmegaConf.to_container(config, resolve=True) + print("----- Final Configuration ----- ") + pprint.pprint(final_config_dict) + print("--------------------------------- ") + + # Configure logging directory + # Use dictionary access here + logger_cfg = final_config_dict.get("logger", {}) + if "log_dir" in logger_cfg: + try: + log_dir = get_next_experiment_dir(logger_cfg["log_dir"]) + # Update dictionary for consistency, though setup might use OmegaConf obj + final_config_dict["logger"]["log_dir"] = log_dir + # Also update OmegaConf object if setup relies on it + if isinstance(final_config_obj, OmegaConf): + OmegaConf.update( + final_config_obj, "logger.log_dir", log_dir, merge=True + ) + print(f"Logging directory set to: {log_dir}") + os.makedirs(log_dir, exist_ok=True) + except Exception as e: + print(f"WARNING: Could not configure logging directory: {e}") + else: + print( + "WARNING: 'logger.log_dir' not found in config, using default logging behavior." 
+ ) + + # Configure checkpointing directory + # Use dictionary access here + checkpoint_cfg = final_config_dict.get("checkpointing", {}) + if checkpoint_cfg.get("enabled"): + if "checkpoint_dir" in checkpoint_cfg: + print( + f"Checkpointing enabled. Directory: {checkpoint_cfg['checkpoint_dir']}" + ) + os.makedirs(checkpoint_cfg["checkpoint_dir"], exist_ok=True) + else: + print( + "WARNING: Checkpointing enabled but 'checkpointing.checkpoint_dir' not specified." + ) + + # Initialize Ray first + # Pass the dictionary config to init_ray + init_ray() + + # Setup tokenizer + # === MODIFIED: Access tokenizer config from new structure === + policy_cfg = final_config_dict["policy"] + tokenizer_cfg = policy_cfg.get( + "tokenizer", policy_cfg + ) # Use policy dict if 'tokenizer' key absent + tokenizer = get_tokenizer(tokenizer_cfg) + print("Tokenizer loaded.") + + # Configure generation config + # === MODIFIED: Access generation config from new structure === + if "generation" in policy_cfg: + policy_cfg["generation"] = configure_generation_config( + policy_cfg["generation"], tokenizer + ) + # Update the main config dict/obj if needed by setup + final_config_dict["policy"]["generation"] = policy_cfg["generation"] + if isinstance(final_config_obj, OmegaConf): + OmegaConf.update( + final_config_obj, + "policy.generation", + policy_cfg["generation"], + merge=True, + ) + print("Generation config configured.") + else: + print("WARNING: Policy generation config not found.") + + # Setup data & env map + ds_length = ( + config["grpo"]["num_prompts_per_step"] + * config["grpo"]["num_generations_per_prompt"] + * config["grpo"]["max_num_steps"] + ) + dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data( + tokenizer=tokenizer, + env_cfg=final_config_dict["env"], # Pass 'env' section + policy_cfg=policy_cfg, + task_name="sliding_puzzle_game", + length=ds_length, + ) + + # Call the IMPORTED setup function + print("Running main setup...") + # Pass the dictionary config + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, # Instantiated logger object + checkpointer, # Instantiated checkpointer object + grpo_state, # Initial state for training + master_config, # Processed MasterConfig object + # Pass final_config_dict (plain dict) to setup + ) = setup(final_config_dict, tokenizer, dataset, val_dataset) + print("Main setup complete.") + + # Call grpo_train with the components returned by setup + print("Starting GRPO training...") + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + print("GRPO training finished.") + + # Final logging message + output_dir = None + if logger is not None and hasattr(logger, "log_dir") and logger.log_dir: + output_dir = logger.log_dir + elif "logger" in final_config_dict and "log_dir" in final_config_dict["logger"]: + output_dir = final_config_dict["logger"]["log_dir"] + if not output_dir: + output_dir = final_config_dict.get( + "output_dir", "./grpo_sliding_puzzle_outputs/unknown_run" + ) + print(f"Checkpoints and logs should be in: {output_dir}") + print("Script finished successfully.") + + +if __name__ == "__main__": + main() From b9d936f1a179d7d56850009cd7ae151888bd3dd9 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 18 Apr 2025 17:02:40 -0700 Subject: [PATCH 14/34] fix many GPU bug Signed-off-by: Sahil Jain --- nemo_reinforcer/distributed/batched_data_dict.py | 3 +++ 1 file 
changed, 3 insertions(+)

diff --git a/nemo_reinforcer/distributed/batched_data_dict.py b/nemo_reinforcer/distributed/batched_data_dict.py
index 994781d4b2..966018698e 100644
--- a/nemo_reinforcer/distributed/batched_data_dict.py
+++ b/nemo_reinforcer/distributed/batched_data_dict.py
@@ -205,6 +205,9 @@ def shard_by_batch_size(
             shard_start = chunk_start + shard_idx * shard_size
             shard_end = chunk_start + (shard_idx + 1) * shard_size
             if allow_uneven_shards:
+                # Cap the end index at the total batch size for the last shard
+                # or if shard_end calculation goes beyond total_batch_size
+                shard_start = min(shard_start, total_batch_size)
                 shard_end = min(shard_end, total_batch_size)
             indices = torch.arange(shard_start, shard_end)

From 8970cd037d6935a389633006be4c84a00b5366a7 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Fri, 18 Apr 2025 18:19:22 -0700
Subject: [PATCH 15/34] Updated sliding defaults

Signed-off-by: Sahil Jain
---
 examples/configs/grpo_sliding_puzzle.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
index 42682d92c9..e7ee14cf92 100644
--- a/examples/configs/grpo_sliding_puzzle.yaml
+++ b/examples/configs/grpo_sliding_puzzle.yaml
@@ -8,10 +8,10 @@ env:
     cfg:
       # Game generation parameters
       game_config:
-        size: 3 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
-        shuffle_moves: 5 # Number of random moves to shuffle the solved state
+        size: 4 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
+        shuffle_moves: 15 # Number of random moves to shuffle the solved state
      # Gameplay parameters
-      max_moves: 50 # Maximum moves allowed per episode
+      max_moves: 70 # Maximum moves allowed per episode
@@ -32,8 +32,8 @@ policy:
   train_micro_batch_size: 4
   generation_batch_size: 32 # Only used when generating using HF backend
   logprob_batch_size: 4
-  max_total_sequence_length: 1024
-  max_turns: 50
+  max_total_sequence_length: 2048
+  max_turns: 70
   precision: "bfloat16"
   fsdp_offload_enabled: false
   activation_checkpointing_enabled: false

From df0b09f420c3fde8bcd0d110ffb3eab8f60a6e81 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Mon, 21 Apr 2025 13:21:13 -0700
Subject: [PATCH 16/34] Bugfixes to multiturn

Signed-off-by: Sahil Jain
---
 examples/configs/grpo_sliding_puzzle.yaml | 12 ++++-----
 examples/run_grpo_sliding_puzzle.py | 21 ++++++++++-----
 nemo_reinforcer/data/datasets.py | 4 +++
 nemo_reinforcer/experience/rollouts.py | 1 +
 nemo_reinforcer/models/generation/vllm.py | 1 +
 tests/unit/test_envs.py | 33 ++++++++++++++++-------
 6 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
index e7ee14cf92..d7c5cac8a3 100644
--- a/examples/configs/grpo_sliding_puzzle.yaml
+++ b/examples/configs/grpo_sliding_puzzle.yaml
@@ -8,7 +8,7 @@ env:
     cfg:
       # Game generation parameters
       game_config:
-        size: 4 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
+        size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
         shuffle_moves: 15 # Number of random moves to shuffle the solved state
       # Gameplay parameters
       max_moves: 70 # Maximum moves allowed per episode
@@ -29,17 +29,17 @@ policy:
   tokenizer:
     name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
   train_global_batch_size: 512
-  train_micro_batch_size: 4
+  train_micro_batch_size: 1
   generation_batch_size: 32 # Only used when generating using HF backend
   logprob_batch_size: 4
-  max_total_sequence_length: 2048
+  max_total_sequence_length: 3072
   max_turns: 70
   precision: "bfloat16"
   fsdp_offload_enabled: false
   activation_checkpointing_enabled: false

   dtensor_cfg:
-    enabled: false
+    enabled: true
     cpu_offload: False
     sequence_parallel: false
     activation_checkpointing: false
@@ -78,8 +78,8 @@ policy:
     backend: "vllm"
     max_new_tokens: ${policy.max_total_sequence_length}
     temperature: 1.0
-    top_p: 1.0
-    top_k: null
+    top_p: 0.999
+    top_k: 10000
     stop_token_ids: null
     stop_strings: null
     vllm_cfg:
diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py
index b6a549b354..200072ae99 100644
--- a/examples/run_grpo_sliding_puzzle.py
+++ b/examples/run_grpo_sliding_puzzle.py
@@ -18,6 +18,7 @@
 import itertools  # For infinite counter
 from collections import defaultdict
 from typing import Any, Dict, Tuple, List, Iterator  # Added Iterator
+import random

 import torch  # Added torch import
 from omegaconf import OmegaConf
@@ -56,6 +57,16 @@ def generate_puzzle_datum(
 ) -> DatumSpec:
     """Generates a single sliding puzzle datum (prompt and metadata)."""
     # (Content copied from previous correct version)
+    def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]:
+        """Generate a random config for the sliding puzzle game."""
+        shuffle_moves = random.randint(1, max_config.get("shuffle_moves"))
+        if shuffle_moves % 2 == 0:
+            shuffle_moves += 1
+        return {
+            "size": random.randint(2, max_config.get("size", 3)),
+            "shuffle_moves": shuffle_moves,
+        }
+    game_config = generate_random_config(game_config)
     initial_game_state = SlidingPuzzleGame.generate(game_config)
     initial_render = SlidingPuzzleGame.render(initial_game_state)
     welcome_message = SlidingPuzzleGame.init(initial_game_state)
     puzzle_size = game_config.get("size", 3)
     prompt_instructions = (
         f"{welcome_message}\n\n"
         f"Current Board State:\n{initial_render}\n\n"
         f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} "
         f"with the empty space (0) at the bottom right.\n"
         f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n"
-        f"After thinking, output your chosen action on a new line starting with 'Action:' like this:\nAction: your_action"
-        f"\nThink step-by-step before acting.\n"
+        f"After thinking, output your chosen action on a new line starting with '<action>' like this:\n<action>your_action</action>"
+        f"\nIf you just want to see the board, output <action>view</action>"
+        f"\nThink carefully step-by-step before acting.\n"
     )
     add_system_prompt = "chat" in policy_model_name.lower()
     initial_prompt_content = tokenizer.apply_chat_template(
@@ -97,15 +109,14 @@ def generate_puzzle_datum(
         "loss_multiplier": 1.0,
         "idx": idx,
         "task_name": task_name,
+        "stop_strings": ["</action>"],
     }
     return datum


-# === MODIFIED: Replace PreGeneratedPuzzleDataset with IterablePuzzleDataset ===
 class IterablePuzzleDataset(IterableDataset):
     """An IterableDataset that generates sliding puzzle data indefinitely."""

-    # === MODIFIED: Removed dataset_size, generates indefinitely ===
     def __init__(
         self, tokenizer, game_config, max_moves, task_name, policy_model_name, length
     ):
@@ -131,8 +142,6 @@ def __iter__(self) -> Iterator[DatumSpec]:
                 idx=i,
                 policy_model_name=self.policy_model_name,
             )
-        # This print message will never be reached in normal operation
-        # print(f"Finished iteration of IterablePuzzleDataset.")

     def __len__(self):
         return self.length
diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py
index 8d8ca78371..822ec9c370 100644
--- a/nemo_reinforcer/data/datasets.py
+++ b/nemo_reinforcer/data/datasets.py
@@ -124,6 +124,9 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
     idx = [datum_spec["idx"] for datum_spec in data_batch]
     batch_max_length = torch.ones_like(length) * length.max()

+    # Extract stop_strings if present
+    stop_strings = [datum.get("stop_strings", [None]) for datum in data_batch]
+
     output = BatchedDataDict(
         message_log=message_log,
         length=length,
@@ -132,6 +135,7 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
         task_name=task_names,
         idx=idx,
         batch_max_length=batch_max_length,
+        stop_strings=stop_strings,
     )
     return output
diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py
index 891c5dae16..e671575a0c 100644
--- a/nemo_reinforcer/experience/rollouts.py
+++ b/nemo_reinforcer/experience/rollouts.py
@@ -230,6 +230,7 @@ def run_multi_turn_rollout(

     # Initialize stop_strings from the initial batch if present
     current_stop_strings = current_batch.get("stop_strings", [None] * batch_size)
+    # print(f"current_stop_strings: {current_stop_strings}")  # Keep commented out

     # Tracking metrics for each sample
     sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32)
diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py
index 03851f2de7..ab9b1767a7 100644
--- a/nemo_reinforcer/models/generation/vllm.py
+++ b/nemo_reinforcer/models/generation/vllm.py
@@ -250,6 +250,7 @@ def generate(
             stop_strings.update(self.cfg["stop_strings"])
         stop_strings = list(stop_strings)
+        print(f"stop_strings: {stop_strings}")

         # verify inputs have correct padding
         verify_right_padding(data, pad_value=self.cfg["pad_token_id"])
diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py
index 242dd1fb62..e66149b43d 100644
--- a/tests/unit/test_envs.py
+++ b/tests/unit/test_envs.py
@@ -188,18 +188,23 @@ def __init__(self):
         pass  # No initialization needed as game methods are static

     def _parse_action(self, text: str) -> Optional[str]:
-        """Parses the action from 'Action: ...' prefix."""
-        prefix = "Action:"
+        """Parses the action from '<action>...</action>' tags."""
+        prefix = "<action>"
+        suffix = "</action>"
         # Find the prefix, case-insensitive, and potentially after some thought process
         text_lower = text.lower()
         prefix_lower = prefix.lower()
+        suffix_lower = suffix.lower()
+
         start_idx = text_lower.rfind(prefix_lower)  # Find the last occurrence
-
         if start_idx != -1:
-            # Return the part after the prefix
-            action_part = text[start_idx + len(prefix) :].strip()
-            # Take only the first line if multiple lines were generated after prefix
-            return action_part.split("\n")[0].strip()
+            # Find the end tag after the start tag
+            end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower))
+            if end_idx != -1:
+                # Extract content between tags
+                action_content = text[start_idx + len(prefix):end_idx].strip()
+                return action_content
         return None

     def process_turn(
@@ -220,7 +225,7 @@ def process_turn(

         turn_reward = 0.0
         is_terminated = False
-        next_stop_strings = None  # Let model finish its thought and action naturally
+        next_stop_strings = ["</action>"]
         next_metadata = metadata.copy()
         next_observation_content = ""

@@ -249,10 +254,16 @@ def process_turn(
         if parsed_action is None:
             # Handle cases where parsing failed or it wasn't assistant's turn properly
             # is_terminated = True # Penalize for bad format
+            rendered_board = SlidingPuzzleGame.render(game_state)
             next_observation_content = (
-                "\nInvalid response format. Try 'Action: your_move'."
+                f"\n{rendered_board}\n\nInvalid response format, no move made. Try like this: <action>your_action</action>"
+            )
             next_metadata = None
+        elif parsed_action == "view":
+            rendered_board = SlidingPuzzleGame.render(game_state)
+            next_observation_content = (
+                f"\n{rendered_board}\n\nViewing the board. No move made."
+            )
         else:
             # Execute the game step
             step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step(
@@ -266,7 +277,9 @@ def process_turn(

             # Combine rendered board and step response for the next observation
             rendered_board = SlidingPuzzleGame.render(next_game_state)
-            next_observation_content = f"{rendered_board}\n\n{step_response}"
+            # next_observation_content = f"\n{rendered_board}\n\n{step_response}"
+            next_observation_content = f"\n{step_response}\n"
+            # next_observation_content = f"\n{step_response}"

         if is_terminated:
             next_metadata = None  # Clear metadata on termination

From e075cc4c33bf60dfa31fe19ae52c14d683d63175 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Mon, 21 Apr 2025 13:52:32 -0700
Subject: [PATCH 17/34] Fixed sliding puzzle test

Signed-off-by: Sahil Jain
---
 tests/unit/experience/test_rollouts.py | 41 +++++++++++++-------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index 04639a4275..4173e7bffd 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -467,7 +467,7 @@ def initial_sliding_puzzle_batch(rollout_tokenizer):
         "size": 2,
         "shuffle_moves": 1,
     }
-    max_moves = 25  # Set a limit for the test
+    max_moves = 10  # Set a limit for the test

     # Generate initial game state
     initial_game_state = SlidingPuzzleGame.generate(game_config)
@@ -479,9 +479,10 @@ def initial_sliding_puzzle_batch(rollout_tokenizer):
         f"Current Board State:\n{initial_render}\n\n"
         f"Reach the goal state where numbers are ordered 1 through {game_config['size'] ** 2 - 1} "
         f"with the empty space (0) at the bottom right.\n"
-        f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n"
-        f"After thinking, output your chosen action on a new line starting with 'Action:' like this:\nAction: your_action"
-        f"Think step-by-step before acting. \n"
+        f"Valid actions: 'up', 'down', 'left', 'right'\n"
+        f"After thinking, output your chosen action on a new line starting with '<action>' like this:\n<action>your_action</action>"
+        f"\nIf you just want to see the board, output <action>view</action>"
+        f"\nThink carefully step-by-step before acting. If you get a 'cannot slide' error, try something different\n"
     )

     batch_message_logs = []
@@ -526,7 +527,7 @@ def initial_sliding_puzzle_batch(rollout_tokenizer):
         "loss_multiplier": batch_loss_multipliers,
         "idx": batch_indices,
         "task_name": batch_task_names,
-        # No stop_strings needed initially, env provides
+        "stop_strings": ["</action>"],
     }
     return BatchedDataDict(initial_batch_dict)
@@ -603,25 +604,23 @@ def test_run_sliding_puzzle_vllm(sliding_puzzle_setup_vllm):
     assert len(final_batch["message_log"]) == len(initial_batch["message_log"])
     sample_log = final_batch["message_log"][0]
-    print("\nSample Interaction Log (Sliding Puzzle - VLLM):")
+    print(f"Final Total Reward: {final_batch['total_reward'][0].item()}")
+
+    # Count the number of <action> tags and environment messages
     action_tag_count = 0
-    for i, msg in enumerate(sample_log):
-        print(f"  {i}: Role={msg['role']}, Content starts with: '{msg['content']}'")
-        if msg["role"] == "assistant" and "action:" in msg["content"].lower():
-            action_tag_count += 1
+    environment_message_count = 0

-    assert action_tag_count > 0, (
-        "Expected at least one assistant message with 'Action:' prefix"
-    )
+    for msg in sample_log:
+        if msg["role"] == "assistant" and "<action>" in msg["content"]:
+            action_tag_count += 1
+        elif msg["role"] == "environment":
+            environment_message_count += 1

-    print(f"Final Total Reward: {final_batch['total_reward'][0].item()}")
-    assert final_batch["total_reward"][0] > 0.0, (
-        f"Expected final reward to be greater than 0.0 (solved), but got {final_batch['total_reward'][0]}"
-    )
+    print(f"Found {action_tag_count} messages with <action> tags")
+    print(f"Found {environment_message_count} environment messages")

-    last_env_message = sample_log[-1]["content"]
-    assert "congratulations" in last_env_message.lower(), (
-        "Last message should indicate puzzle solved"
-    )
+    # Assert that we have multiple action tags and environment messages
+    assert action_tag_count > 3, "Expected more than 3 messages with <action> tags"
+    assert environment_message_count > 3, "Expected more than 3 environment messages"

     print("\nSliding Puzzle VLLM Test assertions passed.")

From d734f35bd63f89353c7f2eae8f4a6111a44931a9 Mon Sep 17 00:00:00 2001
From: Sahil Jain
Date: Mon, 21 Apr 2025 13:57:22 -0700
Subject: [PATCH 18/34] Removed WIP example

Signed-off-by: Sahil Jain
---
 examples/configs/grpo_sliding_puzzle.yaml | 107 ------
 examples/run_grpo_sliding_puzzle.py | 387 ----------------------
 2 files changed, 494 deletions(-)
 delete mode 100644 examples/configs/grpo_sliding_puzzle.yaml
 delete mode 100644 examples/run_grpo_sliding_puzzle.py

diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
deleted file mode 100644
index d7c5cac8a3..0000000000
--- a/examples/configs/grpo_sliding_puzzle.yaml
+++ /dev/null
@@ -1,107 +0,0 @@
-defaults: "grpo_math_1B.yaml"
-
-# Environment setup: Map task names to their configurations
-env:
-  sliding_puzzle_game:
-    env_class: "tests.unit.test_envs.SlidingPuzzleEnv" # Path to the environment actor class
-    # Configuration passed to the SlidingPuzzleEnv constructor
-    cfg:
-      # Game generation parameters
-      game_config:
-        size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
-        shuffle_moves: 15 # Number of random moves to shuffle the solved state
-      # Gameplay parameters
-      max_moves: 70 # Maximum moves allowed per episode
-
-grpo:
-  max_num_steps: 10000
-
-checkpointing:
-  enabled: true
-  checkpoint_dir: "results/grpo-sliding-puzzle"
-  metric_name: "val_reward"
-  higher_is_better: true
-  keep_top_k: 3
-  save_period: 10
-
-policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct" - tokenizer: - name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default - train_global_batch_size: 512 - train_micro_batch_size: 1 - generation_batch_size: 32 # Only used when generating using HF backend - logprob_batch_size: 4 - max_total_sequence_length: 3072 - max_turns: 70 - precision: "bfloat16" - fsdp_offload_enabled: false - activation_checkpointing_enabled: false - - dtensor_cfg: - enabled: true - cpu_offload: False - sequence_parallel: false - activation_checkpointing: false - tensor_parallel_size: 1 - - # makes the training sequence length divisible by the tensor parallel size - # this is useful for sequence parallel training - make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} - max_grad_norm: 1.0 - - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 5.0e-6 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1e-8 - # when using Dtensor, we need to set foreach - # and fused to False - foreach: False - fused: False - - scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 50 - - name: "torch.optim.lr_scheduler.ConstantLR" - kwargs: - factor: 1.0 - total_iters: 10000000000 - - milestones: [50] - - generation: - backend: "vllm" - max_new_tokens: ${policy.max_total_sequence_length} - temperature: 1.0 - top_p: 0.999 - top_k: 10000 - stop_token_ids: null - stop_strings: null - vllm_cfg: - tensor_parallel_size: 1 - gpu_memory_utilization: 0.6 - max_model_len: ${policy.max_total_sequence_length} - -logger: - log_dir: "logs" # Base directory for all logs - num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal - wandb_enabled: false - tensorboard_enabled: false - monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard - wandb: - project: "grpo-dev" - name: "grpo-dev-sliding_puzzle" - tensorboard: {} - gpu_monitoring: - collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) - -cluster: - gpus_per_node: 1 - num_nodes: 1 - diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py deleted file mode 100644 index 200072ae99..0000000000 --- a/examples/run_grpo_sliding_puzzle.py +++ /dev/null @@ -1,387 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-import argparse
-import os
-import pprint
-import itertools  # For infinite counter
-from collections import defaultdict
-from typing import Any, Dict, Tuple, List, Iterator  # Added Iterator
-import random
-
-import torch  # Added torch import
-from omegaconf import OmegaConf
-from transformers import AutoTokenizer
-
-# === MODIFIED: Use IterableDataset ===
-from torch.utils.data import Dataset, IterableDataset  # Import IterableDataset
-
-# === Core Imports (Keep from math example) ===
-from nemo_reinforcer.algorithms.grpo import (
-    MasterConfig,
-    grpo_train,
-    setup,
-)  # CRITICAL: Keep imported setup
-from nemo_reinforcer.algorithms.utils import get_tokenizer
-
-# from nemo_reinforcer.data import DataConfig # Keep if setup needs it, maybe remove later
-# from nemo_reinforcer.data.interfaces import TaskDataSpec # Remove later if not needed by setup_puzzle_data
-from nemo_reinforcer.distributed.virtual_cluster import init_ray
-from nemo_reinforcer.models.generation.interfaces import configure_generation_config
-from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides
-from nemo_reinforcer.utils.logger import get_next_experiment_dir
-
-from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame
-from tests.unit.test_envs import SlidingPuzzleEnv, SlidingPuzzleMetadata
-from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec
-
-
-def generate_puzzle_datum(
-    tokenizer,
-    game_config: Dict,
-    max_moves: int,
-    task_name: str,
-    idx: int,
-    policy_model_name: str,
-) -> DatumSpec:
-    """Generates a single sliding puzzle datum (prompt and metadata)."""
-    # (Content copied from previous correct version)
-    def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate a random config for the sliding puzzle game."""
-        shuffle_moves = random.randint(1, max_config.get("shuffle_moves"))
-        if shuffle_moves % 2 == 0:
-            shuffle_moves += 1
-        return {
-            "size": random.randint(2, max_config.get("size", 3)),
-            "shuffle_moves": shuffle_moves,
-        }
-    game_config = generate_random_config(game_config)
-    initial_game_state = SlidingPuzzleGame.generate(game_config)
-    initial_render = SlidingPuzzleGame.render(initial_game_state)
-    welcome_message = SlidingPuzzleGame.init(initial_game_state)
-    puzzle_size = game_config.get("size", 3)
-    prompt_instructions = (
-        f"{welcome_message}\n\n"
-        f"Current Board State:\n{initial_render}\n\n"
-        f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} "
-        f"with the empty space (0) at the bottom right.\n"
-        f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n"
-        f"After thinking, output your chosen action on a new line starting with '<action>' like this:\n<action>your_action</action>"
-        f"\nIf you just want to see the board, output <action>view</action>"
-        f"\nThink carefully step-by-step before acting.\n"
-    )
-    add_system_prompt = "chat" in policy_model_name.lower()
-    initial_prompt_content = tokenizer.apply_chat_template(
-        [{"role": "user", "content": prompt_instructions}],
-        tokenize=False,
-        add_system_prompt=add_system_prompt,
-        add_generation_prompt=True,
-        add_special_tokens=False,
-    ).strip()
-    tokenized_prompt = tokenizer(
-        initial_prompt_content, return_tensors="pt", add_special_tokens=False
-    )["input_ids"][0]
-    message_log: LLMMessageLogType = [
-        {
-            "role": "user",
-            "content": initial_prompt_content,
-            "token_ids": tokenized_prompt,
-        }
-    ]
-    metadata = SlidingPuzzleMetadata(
-        game_state=initial_game_state, num_moves=0, max_moves=max_moves
-    )
-    datum: DatumSpec = {
-        "message_log": message_log,
-        "length": len(tokenized_prompt),
-        "extra_env_info": metadata,
-        "loss_multiplier": 1.0,
-        "idx": idx,
-        "task_name": task_name,
-        "stop_strings": ["</action>"],
-    }
-    return datum
-
-
-class IterablePuzzleDataset(IterableDataset):
-    """An IterableDataset that generates sliding puzzle data indefinitely."""
-
-    def __init__(
-        self, tokenizer, game_config, max_moves, task_name, policy_model_name, length
-    ):
-        super().__init__()
-        self.tokenizer = tokenizer
-        self.game_config = game_config
-        self.max_moves = max_moves
-        self.task_name = task_name
-        self.policy_model_name = policy_model_name
-        self.length = length
-
-    def __iter__(self) -> Iterator[DatumSpec]:
-        print(
-            f"Starting new iteration of IterablePuzzleDataset (indefinite generation)."
-        )
-        # Use itertools.count for an infinite index generator
-        for i in itertools.count():
-            yield generate_puzzle_datum(
-                tokenizer=self.tokenizer,
-                game_config=self.game_config,
-                max_moves=self.max_moves,
-                task_name=self.task_name,
-                idx=i,
-                policy_model_name=self.policy_model_name,
-            )
-
-    def __len__(self):
-        return self.length
-
-
-# === MODIFIED: setup_puzzle_data now returns IterablePuzzleDataset ===
-def setup_puzzle_data(
-    tokenizer: AutoTokenizer,
-    # === MODIFIED: Accept `env_cfg` instead of `env_configs` ===
-    env_cfg: Dict[str, Any],
-    policy_cfg: Dict[str, Any],
-    task_name: str,
-    length: int,
-) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]:
-    """Sets up the iterable data generator and env map for the sliding puzzle task."""
-    print("Setting up Sliding Puzzle iterable data and environment...")
-    # === MODIFIED: Access env config directly via task_name ===
-    env_config = env_cfg[task_name]
-
-    # --- Instantiate Environment Actor --- #
-    print(f"Instantiating environment actor for task '{task_name}'...")
-    module_path, class_name = env_config["env_class"].rsplit(".", 1)
-    try:
-        EnvClass = getattr(__import__(module_path, fromlist=[class_name]), class_name)
-    except ImportError as e:
-        print(
-            f"ERROR: Could not import environment class {env_config['env_class']}. Ensure it's in PYTHONPATH."
- ) - raise e - env_actor = EnvClass.options(num_gpus=0).remote(cfg=dict(env_config["cfg"])) - task_to_env = {task_name: env_actor} - print(f"Environment actor '{task_name}' created.") - - # --- Instantiate Iterable Dataset --- # - print(f"Creating IterablePuzzleDataset...") - training_dataset = IterablePuzzleDataset( - tokenizer=tokenizer, - game_config=dict(env_config["cfg"]["game_config"]), - max_moves=env_config["cfg"]["max_moves"], - task_name=task_name, - policy_model_name=policy_cfg.get("model_name", ""), - length=length, - ) - print("Iterable training dataset created.") - - validation_dataset = IterablePuzzleDataset( - tokenizer=tokenizer, - game_config=dict(env_config["cfg"]["game_config"]), - max_moves=env_config["cfg"]["max_moves"], - task_name=task_name, - policy_model_name=policy_cfg.get("model_name", ""), - length=256, - ) - val_task_to_env = task_to_env - - return training_dataset, validation_dataset, task_to_env, val_task_to_env - - -# === Argparse function (Keep as is) === -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Run GRPO training with configuration") - parser.add_argument( - "--config", type=str, default=None, help="Path to YAML config file" - ) - args, overrides = parser.parse_known_args() - return args, overrides - - -# === Main function (Follow math structure exactly) === -def main(): - """Main entry point.""" - # Parse arguments - args, overrides = parse_args() - - # Default config path - if not args.config: - # --- MODIFIED: Default config path --- - default_config_path = os.path.join( - os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml" - ) - if not os.path.exists(default_config_path): - raise FileNotFoundError( - f"Default config file not found at {default_config_path}." - ) - args.config = default_config_path - print(f"No config provided, using default: {args.config}") - - # Load base config - config = load_config(args.config) - print(f"Loaded configuration from: {args.config}") - - # Apply overrides - if overrides: - print(f"Applying overrides: {overrides}") - config = parse_hydra_overrides(config, overrides) # Returns OmegaConf object - print("Applied CLI overrides.") - else: - # Ensure config is OmegaConf object even without overrides for consistency - config = OmegaConf.create(config) - - # Convert final config to dictionary for local use AFTER overrides - # Use resolve=True to handle interpolations if any remain - final_config_obj = config # Keep as OmegaConf object for setup/utils - final_config_dict = OmegaConf.to_container(config, resolve=True) - print("----- Final Configuration ----- ") - pprint.pprint(final_config_dict) - print("--------------------------------- ") - - # Configure logging directory - # Use dictionary access here - logger_cfg = final_config_dict.get("logger", {}) - if "log_dir" in logger_cfg: - try: - log_dir = get_next_experiment_dir(logger_cfg["log_dir"]) - # Update dictionary for consistency, though setup might use OmegaConf obj - final_config_dict["logger"]["log_dir"] = log_dir - # Also update OmegaConf object if setup relies on it - if isinstance(final_config_obj, OmegaConf): - OmegaConf.update( - final_config_obj, "logger.log_dir", log_dir, merge=True - ) - print(f"Logging directory set to: {log_dir}") - os.makedirs(log_dir, exist_ok=True) - except Exception as e: - print(f"WARNING: Could not configure logging directory: {e}") - else: - print( - "WARNING: 'logger.log_dir' not found in config, using default logging behavior." 
- ) - - # Configure checkpointing directory - # Use dictionary access here - checkpoint_cfg = final_config_dict.get("checkpointing", {}) - if checkpoint_cfg.get("enabled"): - if "checkpoint_dir" in checkpoint_cfg: - print( - f"Checkpointing enabled. Directory: {checkpoint_cfg['checkpoint_dir']}" - ) - os.makedirs(checkpoint_cfg["checkpoint_dir"], exist_ok=True) - else: - print( - "WARNING: Checkpointing enabled but 'checkpointing.checkpoint_dir' not specified." - ) - - # Initialize Ray first - # Pass the dictionary config to init_ray - init_ray() - - # Setup tokenizer - # === MODIFIED: Access tokenizer config from new structure === - policy_cfg = final_config_dict["policy"] - tokenizer_cfg = policy_cfg.get( - "tokenizer", policy_cfg - ) # Use policy dict if 'tokenizer' key absent - tokenizer = get_tokenizer(tokenizer_cfg) - print("Tokenizer loaded.") - - # Configure generation config - # === MODIFIED: Access generation config from new structure === - if "generation" in policy_cfg: - policy_cfg["generation"] = configure_generation_config( - policy_cfg["generation"], tokenizer - ) - # Update the main config dict/obj if needed by setup - final_config_dict["policy"]["generation"] = policy_cfg["generation"] - if isinstance(final_config_obj, OmegaConf): - OmegaConf.update( - final_config_obj, - "policy.generation", - policy_cfg["generation"], - merge=True, - ) - print("Generation config configured.") - else: - print("WARNING: Policy generation config not found.") - - # Setup data & env map - ds_length = ( - config["grpo"]["num_prompts_per_step"] - * config["grpo"]["num_generations_per_prompt"] - * config["grpo"]["max_num_steps"] - ) - dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data( - tokenizer=tokenizer, - env_cfg=final_config_dict["env"], # Pass 'env' section - policy_cfg=policy_cfg, - task_name="sliding_puzzle_game", - length=ds_length, - ) - - # Call the IMPORTED setup function - print("Running main setup...") - # Pass the dictionary config - ( - policy, - policy_generation, - cluster, - dataloader, - val_dataloader, - loss_fn, - logger, # Instantiated logger object - checkpointer, # Instantiated checkpointer object - grpo_state, # Initial state for training - master_config, # Processed MasterConfig object - # Pass final_config_dict (plain dict) to setup - ) = setup(final_config_dict, tokenizer, dataset, val_dataset) - print("Main setup complete.") - - # Call grpo_train with the components returned by setup - print("Starting GRPO training...") - grpo_train( - policy, - policy_generation, - dataloader, - val_dataloader, - tokenizer, - loss_fn, - task_to_env, - val_task_to_env, - logger, - checkpointer, - grpo_state, - master_config, - ) - print("GRPO training finished.") - - # Final logging message - output_dir = None - if logger is not None and hasattr(logger, "log_dir") and logger.log_dir: - output_dir = logger.log_dir - elif "logger" in final_config_dict and "log_dir" in final_config_dict["logger"]: - output_dir = final_config_dict["logger"]["log_dir"] - if not output_dir: - output_dir = final_config_dict.get( - "output_dir", "./grpo_sliding_puzzle_outputs/unknown_run" - ) - print(f"Checkpoints and logs should be in: {output_dir}") - print("Script finished successfully.") - - -if __name__ == "__main__": - main() From 694723d850b6eac787a9eb40ca4348efa99c5a44 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 21 Apr 2025 14:37:06 -0700 Subject: [PATCH 19/34] Cleanup Signed-off-by: Sahil Jain --- examples/configs/grpo_math_1B.yaml | 2 +- 
nemo_reinforcer/algorithms/grpo.py | 6 +++--- nemo_reinforcer/experience/rollouts.py | 15 ++++++--------- tests/unit/experience/test_rollouts.py | 18 +++++++++--------- 4 files changed, 19 insertions(+), 22 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 7149965324..d13325f2e5 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -32,7 +32,7 @@ policy: generation_batch_size: 32 # Only used when generating using HF backend logprob_batch_size: 4 max_total_sequence_length: 512 - max_turns: 999999 + max_rollout_turns: 1 precision: "bfloat16" fsdp_offload_enabled: false activation_checkpointing_enabled: false diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 1b829fef64..670c815e01 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -376,11 +376,11 @@ def grpo_train( with timer.time("generation"): repeated_batch, rollout_metrics = run_multi_turn_rollout( policy_generation=policy_generation, - initial_batch=repeated_batch, + input_batch=repeated_batch, tokenizer=tokenizer, task_to_env=task_to_env, max_seq_len=master_config["policy"]["max_total_sequence_length"], - max_turns=master_config["policy"]["max_turns"], + max_rollout_turns=master_config["policy"]["max_rollout_turns"], greedy=False, ) policy_generation.finish_generation() @@ -621,7 +621,7 @@ def validate( tokenizer, val_task_to_env, max_seq_len=master_config["policy"]["max_total_sequence_length"], - max_turns=master_config["policy"]["max_turns"], + max_rollout_turns=master_config["policy"]["max_rollout_turns"], greedy=False, ) rewards = val_batch["total_reward"] diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py index e671575a0c..b1633a5add 100644 --- a/nemo_reinforcer/experience/rollouts.py +++ b/nemo_reinforcer/experience/rollouts.py @@ -186,9 +186,6 @@ def calculate_rewards( next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices] metadata = [all_metadata[i] for i in sorted_indices] # Sort metadata - rewards = rewards - terminateds = terminateds - return EnvironmentReturn( observations=env_observations, metadata=metadata, @@ -200,21 +197,21 @@ def calculate_rewards( def run_multi_turn_rollout( policy_generation: GenerationInterface, - initial_batch: BatchedDataDict[DatumSpec], + input_batch: BatchedDataDict[DatumSpec], tokenizer: AutoTokenizer, task_to_env: Dict[str, EnvironmentInterface], max_seq_len: int, - max_turns: int = 999999, + max_rollout_turns: int = 999999, greedy: bool = False, ) -> Tuple[BatchedDataDict[DatumSpec], Dict[str, Any]]: """Runs a multi-turn rollout loop, interacting with the environment. Args: policy_generation: The generation interface (policy). - initial_batch: The starting batch containing initial message logs. + input_batch: The starting batch containing initial message logs. tokenizer: The tokenizer. task_to_env: Dictionary mapping task names to environment instances. - max_turns: Maximum number of agent-environment interaction turns. + max_rollout_turns: Maximum number of agent-environment interaction turns. max_seq_len: Maximum sequence length allowed. greedy: Whether to use greedy decoding. 
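
For orientation, the renamed entry point keeps the same shape; below is a minimal call-site sketch under the new names. All objects (`policy_generation`, `repeated_batch`, `tokenizer`, `task_to_env`) are assumed to be constructed by the GRPO setup code elsewhere in this series, and the literal values are illustrative only:

```python
# Sketch only: objects are assumed to come from the surrounding GRPO setup.
final_batch, rollout_metrics = run_multi_turn_rollout(
    policy_generation=policy_generation,
    input_batch=repeated_batch,      # renamed from `initial_batch`
    tokenizer=tokenizer,
    task_to_env=task_to_env,         # e.g. {"math": math_env_actor}
    max_seq_len=512,
    max_rollout_turns=1,             # single-turn tasks such as math Q&A
    greedy=False,
)
rewards = final_batch["total_reward"]  # per-sample accumulated reward
```
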
@@ -223,7 +220,7 @@ def run_multi_turn_rollout( - BatchedDataDict with the full interaction history and accumulated rewards - Dictionary of rollout metrics """ - current_batch = initial_batch.copy() # Work on a copy + current_batch = input_batch.copy() # Work on a copy batch_size = len(current_batch["message_log"]) active_indices = torch.arange(batch_size) total_rewards = torch.zeros(batch_size, dtype=torch.float32) @@ -245,7 +242,7 @@ def run_multi_turn_rollout( total_gen_tokens_per_turn = [] active_samples_per_turn = [] - for turn in range(max_turns): + for turn in range(max_rollout_turns): if len(active_indices) == 0: break diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 4173e7bffd..daeecb2bc6 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -315,7 +315,7 @@ def test_run_multi_step_calculator_hf(multi_step_setup_hf): policy, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( multi_step_setup_hf ) - max_turns = ( + max_rollout_turns = ( initial_batch["extra_env_info"][0]["max_steps"] + 1 ) # Allow max steps + final answer max_seq_len = 1024 # Increased for potentially longer interaction @@ -324,11 +324,11 @@ def test_run_multi_step_calculator_hf(multi_step_setup_hf): policy.prepare_for_generation() final_batch, rollout_metrics = run_multi_turn_rollout( policy_generation=policy, - initial_batch=initial_batch, + input_batch=initial_batch, tokenizer=rollout_tokenizer, task_to_env=task_to_env, max_seq_len=max_seq_len, - max_turns=max_turns, + max_rollout_turns=max_rollout_turns, ) policy.finish_generation() print("Multi-step calculator rollout complete (HF).") @@ -384,18 +384,18 @@ def test_run_multi_step_calculator_vllm(multi_step_setup_vllm): vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( multi_step_setup_vllm ) - max_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1 + max_rollout_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1 max_seq_len = 1024 print("\nRunning multi-step calculator rollout (VLLM)...") vllm_generation.prepare_for_generation() final_batch, rollout_metrics = run_multi_turn_rollout( policy_generation=vllm_generation, - initial_batch=initial_batch, + input_batch=initial_batch, tokenizer=rollout_tokenizer, task_to_env=task_to_env, max_seq_len=max_seq_len, - max_turns=max_turns, + max_rollout_turns=max_rollout_turns, ) vllm_generation.finish_generation() print("Multi-step calculator rollout complete (VLLM).") @@ -579,17 +579,17 @@ def test_run_sliding_puzzle_vllm(sliding_puzzle_setup_vllm): sliding_puzzle_setup_vllm ) max_moves = initial_batch["extra_env_info"][0]["max_moves"] - max_turns = max_moves + 1 + max_rollout_turns = max_moves + 1 max_seq_len = 2048 print("\nRunning sliding puzzle rollout (VLLM)...") vllm_generation.prepare_for_generation() final_batch, rollout_metrics = run_multi_turn_rollout( policy_generation=vllm_generation, - initial_batch=initial_batch, + input_batch=initial_batch, tokenizer=rollout_tokenizer, task_to_env=task_to_env, - max_turns=max_turns, + max_rollout_turns=max_rollout_turns, max_seq_len=max_seq_len, greedy=True, ) From b680aa9d40043db8aabb3432714337b13e51eb8b Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 21 Apr 2025 14:42:47 -0700 Subject: [PATCH 20/34] Cleanup Signed-off-by: Sahil Jain --- examples/configs/grpo_math_1B.yaml | 2 +- nemo_reinforcer/algorithms/grpo.py | 4 ++-- nemo_reinforcer/experience/rollouts.py | 3 ++- 3 files changed, 5 insertions(+), 4 
deletions(-)

diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index d13325f2e5..747e98abdb 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -2,6 +2,7 @@
 grpo:
   num_prompts_per_step: 32
   num_generations_per_prompt: 16
+  max_rollout_turns: 1 # for multi-turn rollouts; math environments have just one turn (answering the question)
   max_num_steps: 1000000
   normalize_rewards: true
   use_leave_one_out_baseline: true
@@ -32,7 +33,6 @@ policy:
   generation_batch_size: 32 # Only used when generating using HF backend
   logprob_batch_size: 4
   max_total_sequence_length: 512
-  max_rollout_turns: 1
   precision: "bfloat16"
   fsdp_offload_enabled: false
   activation_checkpointing_enabled: false
diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py
index 670c815e01..155b49a8df 100644
--- a/nemo_reinforcer/algorithms/grpo.py
+++ b/nemo_reinforcer/algorithms/grpo.py
@@ -380,7 +380,7 @@ def grpo_train(
                 tokenizer=tokenizer,
                 task_to_env=task_to_env,
                 max_seq_len=master_config["policy"]["max_total_sequence_length"],
-                max_rollout_turns=master_config["policy"]["max_rollout_turns"],
+                max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
                 greedy=False,
             )
             policy_generation.finish_generation()
@@ -621,7 +621,7 @@ def validate(
         tokenizer,
         val_task_to_env,
         max_seq_len=master_config["policy"]["max_total_sequence_length"],
-        max_rollout_turns=master_config["policy"]["max_rollout_turns"],
+        max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
         greedy=False,
     )
     rewards = val_batch["total_reward"]
diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py
index b1633a5add..87bedf5a02 100644
--- a/nemo_reinforcer/experience/rollouts.py
+++ b/nemo_reinforcer/experience/rollouts.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Generate rollouts
+# Generate rollouts for arbitrary environments
+# Supports multi-turn rollouts and many simultaneous environments (e.g.
you can train on math, code, multi-turn games and more at once) import torch from typing import List, Tuple, Dict, Optional, Any, NamedTuple From eea7cd7548242226d5de922c848b5a9d543f4b61 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 21 Apr 2025 14:58:31 -0700 Subject: [PATCH 21/34] math fix Signed-off-by: Sahil Jain --- nemo_reinforcer/data/datasets.py | 2 +- nemo_reinforcer/environments/math_environment.py | 2 +- nemo_reinforcer/experience/rollouts.py | 2 +- nemo_reinforcer/models/generation/vllm.py | 1 - nemo_reinforcer/models/policy/fsdp1_policy_worker.py | 1 - 5 files changed, 3 insertions(+), 5 deletions(-) diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 822ec9c370..f872b4e9f5 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -125,7 +125,7 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: batch_max_length = torch.ones_like(length) * length.max() # Extract stop_strings if present - stop_strings = [datum.get("stop_strings", [None]) for datum in data_batch] + stop_strings = [datum.get("stop_strings", None) for datum in data_batch] output = BatchedDataDict( message_log=message_log, diff --git a/nemo_reinforcer/environments/math_environment.py b/nemo_reinforcer/environments/math_environment.py index 81f41eb3da..cc61fdcb2c 100644 --- a/nemo_reinforcer/environments/math_environment.py +++ b/nemo_reinforcer/environments/math_environment.py @@ -148,7 +148,7 @@ def step( rewards = torch.tensor(results).cpu() done = torch.ones_like(rewards).cpu() - next_stop_strings = [[None]] * len(message_log_batch) + next_stop_strings = [None] * len(message_log_batch) return EnvironmentReturn( observations=observations, diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py index 87bedf5a02..6e43c6a62a 100644 --- a/nemo_reinforcer/experience/rollouts.py +++ b/nemo_reinforcer/experience/rollouts.py @@ -166,7 +166,7 @@ def calculate_rewards( result ) if next_stop_strings is None: - next_stop_strings = [[None]] * len(task_rewards) + next_stop_strings = [None] * len(task_rewards) # Store results with their original indices for i, idx in enumerate(indices): diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index ab9b1767a7..03851f2de7 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -250,7 +250,6 @@ def generate( stop_strings.update(self.cfg["stop_strings"]) stop_strings = list(stop_strings) - print(f"stop_strings: {stop_strings}") # verify inputs have correct padding verify_right_padding(data, pad_value=self.cfg["pad_token_id"]) diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index 37c385d44b..3ff7faed70 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -529,7 +529,6 @@ def generate( stop_strings.update(gen_cfg["stop_strings"]) stop_strings = list(stop_strings) if len(stop_strings) > 0 else None - print(f"Stop strings: {stop_strings}") if isinstance( self.model, torch.distributed.fsdp.FullyShardedDataParallel From 44666cec4379f1a917ba535767a63e15b4742b57 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 21 Apr 2025 16:16:43 -0700 Subject: [PATCH 22/34] added doctests to batched_data_dict Signed-off-by: Sahil Jain --- .../distributed/batched_data_dict.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff 
--git a/nemo_reinforcer/distributed/batched_data_dict.py b/nemo_reinforcer/distributed/batched_data_dict.py
index 966018698e..8e00e3b13b 100644
--- a/nemo_reinforcer/distributed/batched_data_dict.py
+++ b/nemo_reinforcer/distributed/batched_data_dict.py
@@ -158,6 +158,36 @@ def shard_by_batch_size(
 
         Returns:
             List[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards.
+
+        Examples:
+        ```{doctest}
+        >>> from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+        >>> # Create a batch of four problems, each repeated twice (as in GRPO)
+        >>> batch = BatchedDataDict({
+        ...     'problem_id': [0, 0, 1, 1, 2, 2, 3, 3],
+        ...     'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8]
+        ... })
+        >>> shards = batch.shard_by_batch_size(shards=2)
+        >>> shards
+        [{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}]
+        >>> # Now say that I'm training with a GBS of 4 and I want to take gradient steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO samples multiple generations per prompt)
+        >>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch.
+        >>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first.
+        >>> shards = batch.shard_by_batch_size(shards=2, batch_size=4)
+        >>> shards
+        [{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}]
+        >>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first.
+        >>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary.
+        >>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes.
+        >>> batch = BatchedDataDict({
+        ...     'problem_id': [0, 1, 2, 3, 4],
+        ...     'arbitrary_data': [10, 11, 12, 13, 14]
+        ... 
}) + >>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True) + >>> shards + [{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}] + >>> # This is incompatible with the batch_size argument + ``` """ if allow_uneven_shards: assert batch_size is None, ( From 32eefb3982d55f7ac0d9ffc8975cef93bd8b26c7 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 21 Apr 2025 16:31:06 -0700 Subject: [PATCH 23/34] lint Signed-off-by: Sahil Jain --- tests/unit/test_envs.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py index e66149b43d..5d410f5895 100644 --- a/tests/unit/test_envs.py +++ b/tests/unit/test_envs.py @@ -195,15 +195,15 @@ def _parse_action(self, text: str) -> Optional[str]: text_lower = text.lower() prefix_lower = prefix.lower() suffix_lower = suffix.lower() - + start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence - + if start_idx != -1: # Find the end tag after the start tag end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower)) if end_idx != -1: # Extract content between tags - action_content = text[start_idx + len(prefix):end_idx].strip() + action_content = text[start_idx + len(prefix) : end_idx].strip() return action_content return None @@ -255,15 +255,11 @@ def process_turn( # Handle cases where parsing failed or it wasn't assistant's turn properly # is_terminated = True # Penalize for bad format rendered_board = SlidingPuzzleGame.render(game_state) - next_observation_content = ( - f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action" - ) + next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action" next_metadata = None elif parsed_action == "view": rendered_board = SlidingPuzzleGame.render(game_state) - next_observation_content = ( - f"\n{rendered_board}\n\nViewing the board. No move made." - ) + next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made." 
else: # Execute the game step step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step( From f458fb0db2ec8ad173211b94f9fe55c41d4083d3 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 21 Apr 2025 23:35:47 -0700 Subject: [PATCH 24/34] unit tests Signed-off-by: Sahil Jain --- nemo_reinforcer/experience/rollouts.py | 1 - .../models/generation/test_vllm_generation.py | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py index 6e43c6a62a..f41661efaf 100644 --- a/nemo_reinforcer/experience/rollouts.py +++ b/nemo_reinforcer/experience/rollouts.py @@ -228,7 +228,6 @@ def run_multi_turn_rollout( # Initialize stop_strings from the initial batch if present current_stop_strings = current_batch.get("stop_strings", [None] * batch_size) - # print(f"current_stop_strings: {current_stop_strings}") # Keep commented out # Tracking metrics for each sample sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 04e0cd5969..3f361394a9 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -814,3 +814,33 @@ def test_vllm_generation_with_stop( vllm_generation.shutdown() if not is_eval: hf_policy.shutdown() + + +def test_vllm_non_divisible_batch_handling(policy): + """Test that VLLM generation handles non divisible input batches correctly.""" + # This test runs on 2 GPUs but has a batch size of 1. The first GPU will run a batch + # and the second will run a batch of size 0. + + # Create and run with non divisible batch + empty_batch = BatchedDataDict( + { + "input_ids": torch.zeros((1, 1), dtype=torch.long), + "input_lengths": torch.ones(1, dtype=torch.long), + } + ) + + outputs = policy.generate(empty_batch) + + # Verify output structure and dimensions + required_keys = [ + "output_ids", + "logprobs", + "generation_lengths", + "unpadded_sequence_lengths", + ] + assert all(key in outputs for key in required_keys), ( + "Missing required output fields" + ) + assert all(outputs[key].shape[0] == 1 for key in required_keys), ( + "Output tensors should have a batch dimension of 1" + ) From 6be19086f4fdd9a128fc9922acb307a9e532def7 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Tue, 22 Apr 2025 02:34:45 -0700 Subject: [PATCH 25/34] added e2e sliding puzzle game example Signed-off-by: Sahil Jain --- examples/configs/grpo_sliding_puzzle.yaml | 111 +++++++ examples/run_grpo_sliding_puzzle.py | 278 ++++++++++++++++++ .../environments/games/sliding_puzzle.py | 224 +++++++++++--- tests/unit/environments/game_interface.py | 76 ----- tests/unit/experience/test_rollouts.py | 17 +- tests/unit/test_envs.py | 180 ------------ 6 files changed, 587 insertions(+), 299 deletions(-) create mode 100644 examples/configs/grpo_sliding_puzzle.yaml create mode 100644 examples/run_grpo_sliding_puzzle.py rename tests/unit/environments/sliding_puzzle_game.py => nemo_reinforcer/environments/games/sliding_puzzle.py (51%) delete mode 100644 tests/unit/environments/game_interface.py diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml new file mode 100644 index 0000000000..b218a5d060 --- /dev/null +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -0,0 +1,111 @@ +defaults: "grpo_math_1B.yaml" + +# Environment setup: Map task names to their configurations +env: + 
sliding_puzzle_game: + env_class: "tests.unit.test_envs.SlidingPuzzleEnv" # Path to the environment actor class + # Configuration passed to the SlidingPuzzleEnv constructor + cfg: + # Game generation parameters + game_config: + size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3) + shuffle_moves: 15 # Number of random moves to shuffle the solved state + # Gameplay parameters + max_moves: 70 # Maximum moves allowed per episode + +grpo: + max_num_steps: 10000 + max_rollout_turns: 50 + +data: + add_system_prompt: false + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-sliding-puzzle" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 512 + train_micro_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + logprob_batch_size: 4 + max_total_sequence_length: 3072 + max_turns: 70 + precision: "bfloat16" + fsdp_offload_enabled: false + activation_checkpointing_enabled: false + + dtensor_cfg: + enabled: true + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 0.999 + top_k: 10000 + stop_token_ids: null + stop_strings: null + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-sliding_puzzle" + tensorboard: {} + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 + diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py new file mode 100644 index 0000000000..abd468881f --- /dev/null +++ b/examples/run_grpo_sliding_puzzle.py @@ -0,0 +1,278 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +import itertools +from typing import Any, Dict, Tuple, Iterator +import random + +from omegaconf import OmegaConf +from transformers import AutoTokenizer + +from torch.utils.data import IterableDataset + +from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_reinforcer.algorithms.utils import get_tokenizer + +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.models.generation.interfaces import configure_generation_config +from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides +from nemo_reinforcer.utils.logger import get_next_experiment_dir + +from nemo_reinforcer.environments.games.sliding_puzzle import ( + SlidingPuzzleGameLogic, + SlidingPuzzleEnv, + SlidingPuzzleConfig, + SlidingPuzzleMetadata, +) +from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +def generate_puzzle_datum( + tokenizer, + game_config: SlidingPuzzleConfig, + max_moves: int, + task_name: str, + idx: int, + add_system_prompt: bool, +) -> DatumSpec: + """Generates a single sliding puzzle datum (prompt and metadata).""" + + def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]: + """Generate a random config for the sliding puzzle game.""" + shuffle_moves = random.randint(1, max_config.get("shuffle_moves")) + if shuffle_moves % 2 == 0: + shuffle_moves += 1 + return { + "size": random.randint(2, max_config.get("size", 3)), + "shuffle_moves": shuffle_moves, + } + + game_config = generate_random_config(game_config) + initial_game_state = SlidingPuzzleGameLogic.generate(game_config) + initial_render = SlidingPuzzleGameLogic.render(initial_game_state) + welcome_message = SlidingPuzzleGameLogic.init(initial_game_state) + puzzle_size = game_config.get("size", 3) + prompt_instructions = ( + f"{welcome_message}\n\n" + f"Current Board State:\n{initial_render}\n\n" + f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} " + f"with the empty space (0) at the bottom right.\n" + f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n" + f"After thinking, output your chosen action on a new line starting with '' like this:\nyour_action" + f"\nIf you just want to see the board, output view" + f"\nThink carefully step-by-step before acting.\n" + ) + initial_prompt_content = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_instructions}], + tokenize=False, + add_system_prompt=add_system_prompt, + add_generation_prompt=True, + add_special_tokens=False, + ).strip() + tokenized_prompt = tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log: LLMMessageLogType = [ + { + "role": "user", + "content": 
initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + metadata = SlidingPuzzleMetadata( + game_state=initial_game_state, num_moves=0, max_moves=max_moves + ) + datum: DatumSpec = { + "message_log": message_log, + "length": len(tokenized_prompt), + "extra_env_info": metadata, + "loss_multiplier": 1.0, + "idx": idx, + "task_name": task_name, + "stop_strings": [""], + } + return datum + + +class IterablePuzzleDataset(IterableDataset): + """An IterableDataset that generates sliding puzzle data indefinitely.""" + + def __init__( + self, tokenizer, game_config, max_moves, task_name, add_system_prompt, length + ): + super().__init__() + self.tokenizer = tokenizer + self.game_config = game_config + self.max_moves = max_moves + self.task_name = task_name + self.add_system_prompt = add_system_prompt + self.length = length + + def __iter__(self) -> Iterator[DatumSpec]: + print(f"Starting IterablePuzzleDataset (indefinite generation).") + # Use itertools.count for an infinite index generator + for i in itertools.count(): + yield generate_puzzle_datum( + tokenizer=self.tokenizer, + game_config=self.game_config, + max_moves=self.max_moves, + task_name=self.task_name, + idx=i, + add_system_prompt=self.add_system_prompt, + ) + + def __len__(self): + return self.length + + +def setup_puzzle_data( + tokenizer: AutoTokenizer, + env_cfg: Dict[str, Any], + task_name: str, + length: int, + val_length: int, + add_system_prompt: bool, +) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]: + """Sets up the iterable data generator and env map for the sliding puzzle task.""" + print("Setting up Sliding Puzzle iterable data and environment...") + env_config = env_cfg[task_name] + + print(f"Instantiating environment for task '{task_name}'...") + env = SlidingPuzzleEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"])) + task_to_env = {task_name: env} + print(f"Environment '{task_name}' created.") + + print(f"Creating Sliding Puzzle dataset...") + training_dataset = IterablePuzzleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_moves=env_config["cfg"]["max_moves"], + task_name=task_name, + add_system_prompt=add_system_prompt, + length=length, + ) + print("Sliding Puzzle dataset created.") + + validation_dataset = IterablePuzzleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_moves=env_config["cfg"]["max_moves"], + task_name=task_name, + add_system_prompt=add_system_prompt, + length=val_length, + ) + val_task_to_env = task_to_env + + return training_dataset, validation_dataset, task_to_env, val_task_to_env + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"📊 Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"📊 Using checkpoint directory: 
{config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # setup data & env map + ds_length = ( + config["grpo"]["num_prompts_per_step"] + * config["grpo"]["num_generations_per_prompt"] + * config["grpo"]["max_num_steps"] + ) + dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data( + tokenizer=tokenizer, + env_cfg=config["env"], + task_name="sliding_puzzle_game", + length=ds_length, + val_length=config["grpo"]["max_val_samples"], + add_system_prompt=config["data"]["add_system_prompt"], + ) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/environments/sliding_puzzle_game.py b/nemo_reinforcer/environments/games/sliding_puzzle.py similarity index 51% rename from tests/unit/environments/sliding_puzzle_game.py rename to nemo_reinforcer/environments/games/sliding_puzzle.py index 664e4c312b..7ee66e985c 100644 --- a/tests/unit/environments/sliding_puzzle_game.py +++ b/nemo_reinforcer/environments/games/sliding_puzzle.py @@ -12,13 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ray +import torch +from typing import Dict, List, Tuple, Optional, TypedDict, Any import random import copy -from typing import List, Tuple, Dict, Any, Optional -from .game_interface import GameInterface +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.data.interfaces import LLMMessageLogType +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) +from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES -class SlidingPuzzleGame(GameInterface): + +class SlidingPuzzleConfig(TypedDict): + size: int + shuffle_moves: int + + +class SlidingPuzzleMetadata(TypedDict): + game_state: Dict[str, Any] # Stores the dict returned by SlidingPuzzleGame methods + num_moves: int + max_moves: int + + +class SlidingPuzzleGameLogic: @staticmethod def generate(config: Dict[str, Any]) -> Dict[str, Any]: """Generate a new Sliding Puzzle.""" @@ -94,7 +114,7 @@ def step( # Default return values response = "Unknown command. Type 'help' to see available commands." 
- reward = -0.05 # Small penalty for invalid actions + reward = 0.0 # No penalty for invalid actions is_terminated = False # Deep copy game state to avoid modifying the original @@ -218,39 +238,171 @@ def render(game_state: Dict[str, Any]) -> str: return "\n".join(output) -def is_solvable(grid: List[List[int]], size: int) -> bool: - """Check if a sliding puzzle is solvable.""" - # Flatten the grid - flat = [num for row in grid for num in row if num != 0] - - # Count inversions - inversions = 0 - for i in range(len(flat)): - for j in range(i + 1, len(flat)): - if flat[i] > flat[j]: - inversions += 1 - - # Find row of the empty tile (0) from the bottom - empty_row = 0 - for i in range(size - 1, -1, -1): - for j in range(size): - if grid[i][j] == 0: - empty_row = size - i - break - - # For odd-sized grids, the puzzle is solvable if the number of inversions is even - if size % 2 == 1: - return inversions % 2 == 0 - # For even-sized grids, the puzzle is solvable if: - # (inversions odd and empty on even row from bottom) or (inversions even and empty on odd row from bottom) - else: - return (inversions % 2 == 1 and empty_row % 2 == 0) or ( - inversions % 2 == 0 and empty_row % 2 == 1 +class SlidingPuzzleRunner: + def __init__(self): + pass # No initialization needed as game methods are static + + def _parse_action(self, text: str) -> Optional[str]: + """Parses the action from ''.""" + prefix = "" + suffix = "" + # Find the prefix, case-insensitive, and potentially after some thought process + text_lower = text.lower() + prefix_lower = prefix.lower() + suffix_lower = suffix.lower() + + start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence + + if start_idx != -1: + # Find the end tag after the start tag + end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower)) + if end_idx != -1: + # Extract content between tags + action_content = text[start_idx + len(prefix) : end_idx].strip() + return action_content + return None + + def process_turn( + self, + message_log: LLMMessageLogType, + metadata: SlidingPuzzleMetadata, + ) -> Tuple[ + Dict[str, str], + float, + bool, + Optional[List[str]], + Optional[SlidingPuzzleMetadata], + ]: + """Processes a single turn for the sliding puzzle task.""" + game_state = metadata["game_state"] + current_moves = metadata["num_moves"] + max_moves = metadata["max_moves"] + + turn_reward = 0.0 + is_terminated = False + next_stop_strings = [""] + next_metadata = metadata.copy() + next_observation_content = "" + + # Check if max moves reached + if current_moves >= max_moves: + is_terminated = True + next_observation_content = ( + f"Maximum moves ({max_moves}) reached." + ) + next_metadata = None + return ( + {"role": "environment", "content": next_observation_content}, + 0.0, + is_terminated, + None, + next_metadata, + ) + + # Get last assistant message and parse action + last_assistant_msg_content = "" + if message_log and message_log[-1]["role"] == "assistant": + last_assistant_msg_content = message_log[-1]["content"].strip() + + parsed_action = self._parse_action(last_assistant_msg_content) + + if parsed_action is None: + # Handle cases where parsing failed or it wasn't assistant's turn properly + # is_terminated = True # Penalize for bad format + rendered_board = SlidingPuzzleGameLogic.render(game_state) + next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. 
Try like this: your_action" + next_metadata = None + elif parsed_action == "view": + rendered_board = SlidingPuzzleGameLogic.render(game_state) + next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made." + else: + # Execute the game step + step_response, reward, game_over, next_game_state = ( + SlidingPuzzleGameLogic.step(parsed_action, game_state) + ) + + turn_reward = reward + is_terminated = game_over + next_metadata["game_state"] = next_game_state + next_metadata["num_moves"] = current_moves + 1 + + # Combine rendered board and step response for the next observation + rendered_board = SlidingPuzzleGameLogic.render(next_game_state) + next_observation_content = f"\n{step_response}\n" + + if is_terminated: + next_metadata = None # Clear metadata on termination + + return ( + {"role": "environment", "content": next_observation_content + "\n"}, + turn_reward, + is_terminated, + next_stop_strings, + next_metadata, ) -def play_sliding_puzzle(config=None): - """Wrapper function for backward compatibility.""" - from play_game import play_game +@ray.remote +class SlidingPuzzleEnv(EnvironmentInterface): + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM + """Sliding Puzzle environment (Ray Actor).""" - play_game(SlidingPuzzleGame, config) + def __init__(self, cfg: Optional[SlidingPuzzleConfig] = None): + # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50} + self.game_config = cfg.get("game_config", {}) if cfg else {} + self.runner = SlidingPuzzleRunner() + + def step( + self, + message_log_batch: List[LLMMessageLogType], + metadata_batch: List[SlidingPuzzleMetadata], + ) -> EnvironmentReturn: + """Processes a batch of sliding puzzle interactions.""" + # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy) + results = [ + self.runner.process_turn(log, meta) + for log, meta in zip(message_log_batch, metadata_batch) + ] + + # Unpack results and format according to EnvironmentReturn NamedTuple + observations = [] + rewards = [] + terminateds = [] + all_stop_strings = [] + all_next_metadata = [] + + for obs, rew, term, stops, meta in results: + observations.append(obs) + rewards.append(rew) + terminateds.append(term) + all_stop_strings.append(stops) + all_next_metadata.append(meta) + + rewards_tensor = torch.tensor(rewards, dtype=torch.float32) + terminated_tensor = torch.tensor(terminateds, dtype=torch.bool) + + return EnvironmentReturn( + observations=observations, + metadata=all_next_metadata, + next_stop_strings=all_stop_strings, + rewards=rewards_tensor, + terminateds=terminated_tensor, + ) + + def shutdown(self): + pass + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + # Calculate success rate based on final reward == 1.0 + final_rewards = batch.get( + "total_reward", torch.tensor([0.0] * len(batch["idx"])) + ) + success_rate = ( + (final_rewards == 1.0).float().mean().item() + if len(final_rewards) > 0 + else 0.0 + ) + # Could also calculate average number of moves for successful episodes, etc. + return batch, {"sliding_puzzle_success_rate": success_rate} diff --git a/tests/unit/environments/game_interface.py b/tests/unit/environments/game_interface.py deleted file mode 100644 index 2f0237ed23..0000000000 --- a/tests/unit/environments/game_interface.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Any, Tuple, List, Optional, Callable - - -class GameInterface: - @staticmethod - def generate(config: Dict[str, Any]) -> Dict[str, Any]: - """ - Generate a new game state based on configuration. - - Args: - config: Game configuration dictionary - - Returns: - A dictionary containing the complete game state - """ - raise NotImplementedError("Each game must implement generate()") - - @staticmethod - def init(game_state: Dict[str, Any]) -> str: - """ - Initialize a game and return welcome messages. - - Args: - game_state: The game state dictionary - - Returns: - String containing welcome message and instructions - """ - raise NotImplementedError("Each game must implement init()") - - @staticmethod - def step( - action: str, game_state: Dict[str, Any] - ) -> Tuple[str, float, bool, Dict[str, Any]]: - """ - Process a game action and update the state. - - Args: - action: String representing the player's action - game_state: Current game state dictionary - - Returns: - Tuple containing: - - Response message - - Reward for this action - - Boolean indicating if game is terminated - - Updated game state - """ - raise NotImplementedError("Each game must implement step()") - - @staticmethod - def render(game_state: Dict[str, Any]) -> str: - """ - Render the current game state as a string. 
- - Args: - game_state: The game state dictionary - - Returns: - String representation of the game state - """ - raise NotImplementedError("Each game must implement render()") diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index daeecb2bc6..be3be2b4a6 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -32,12 +32,15 @@ MultiStepCalculatorEnv, _MultiStepCalculatorLogic, MultiStepCalcMetadata, - SlidingPuzzleEnv, - SlidingPuzzleMetadata, ) # Import the game logic for generating initial state from its new location -from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame +from nemo_reinforcer.environments.games.sliding_puzzle import ( + SlidingPuzzleGameLogic, + SlidingPuzzleEnv, + SlidingPuzzleConfig, + SlidingPuzzleMetadata, +) from nemo_reinforcer.models.generation.vllm import VllmConfig, VllmGeneration @@ -463,16 +466,16 @@ def sliding_puzzle_environment(rollout_cluster): def initial_sliding_puzzle_batch(rollout_tokenizer): print("Creating initial sliding puzzle test batch...") batch_size = 1 - game_config = { + game_config: SlidingPuzzleConfig = { "size": 2, "shuffle_moves": 1, } max_moves = 10 # Set a limit for the test # Generate initial game state - initial_game_state = SlidingPuzzleGame.generate(game_config) - initial_render = SlidingPuzzleGame.render(initial_game_state) - welcome_message = SlidingPuzzleGame.init(initial_game_state) + initial_game_state = SlidingPuzzleGameLogic.generate(game_config) + initial_render = SlidingPuzzleGameLogic.render(initial_game_state) + welcome_message = SlidingPuzzleGameLogic.init(initial_game_state) prompt_instructions = ( f"{welcome_message}\n\n" diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py index 5d410f5895..5fe62f135b 100644 --- a/tests/unit/test_envs.py +++ b/tests/unit/test_envs.py @@ -23,7 +23,6 @@ EnvironmentReturn, ) from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES -from .environments.sliding_puzzle_game import SlidingPuzzleGame class MultiStepCalcMetadata(TypedDict): @@ -33,12 +32,6 @@ class MultiStepCalcMetadata(TypedDict): current_step: int -class SlidingPuzzleMetadata(TypedDict): - game_state: Dict[str, Any] # Stores the dict returned by SlidingPuzzleGame methods - num_moves: int - max_moves: int - - class _MultiStepCalculatorLogic: def __init__(self): pass @@ -183,113 +176,6 @@ def process_turn( ) -class _SlidingPuzzleLogic: - def __init__(self): - pass # No initialization needed as game methods are static - - def _parse_action(self, text: str) -> Optional[str]: - """Parses the action from ''""" - prefix = "" - suffix = "" - # Find the prefix, case-insensitive, and potentially after some thought process - text_lower = text.lower() - prefix_lower = prefix.lower() - suffix_lower = suffix.lower() - - start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence - - if start_idx != -1: - # Find the end tag after the start tag - end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower)) - if end_idx != -1: - # Extract content between tags - action_content = text[start_idx + len(prefix) : end_idx].strip() - return action_content - return None - - def process_turn( - self, - message_log: LLMMessageLogType, - metadata: SlidingPuzzleMetadata, - ) -> Tuple[ - Dict[str, str], - float, - bool, - Optional[List[str]], - Optional[SlidingPuzzleMetadata], - ]: - """Processes a single turn for the sliding puzzle task.""" - game_state = metadata["game_state"] - current_moves = 
metadata["num_moves"] - max_moves = metadata["max_moves"] - - turn_reward = 0.0 - is_terminated = False - next_stop_strings = [""] - next_metadata = metadata.copy() - next_observation_content = "" - - # Check if max moves reached - if current_moves >= max_moves: - is_terminated = True - next_observation_content = ( - f"Maximum moves ({max_moves}) reached." - ) - next_metadata = None - return ( - {"role": "environment", "content": next_observation_content}, - 0.0, - is_terminated, - None, - next_metadata, - ) - - # Get last assistant message and parse action - last_assistant_msg_content = "" - if message_log and message_log[-1]["role"] == "assistant": - last_assistant_msg_content = message_log[-1]["content"].strip() - - parsed_action = self._parse_action(last_assistant_msg_content) - - if parsed_action is None: - # Handle cases where parsing failed or it wasn't assistant's turn properly - # is_terminated = True # Penalize for bad format - rendered_board = SlidingPuzzleGame.render(game_state) - next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action" - next_metadata = None - elif parsed_action == "view": - rendered_board = SlidingPuzzleGame.render(game_state) - next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made." - else: - # Execute the game step - step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step( - parsed_action, game_state - ) - - turn_reward = reward - is_terminated = game_over - next_metadata["game_state"] = next_game_state - next_metadata["num_moves"] = current_moves + 1 - - # Combine rendered board and step response for the next observation - rendered_board = SlidingPuzzleGame.render(next_game_state) - # next_observation_content = f"\n{rendered_board}\n\n{step_response}" - next_observation_content = f"\n{step_response}\n" - # next_observation_content = f"\n{step_response}" - - if is_terminated: - next_metadata = None # Clear metadata on termination - # next_stop_strings remains None - - return ( - {"role": "environment", "content": next_observation_content + "\n"}, - turn_reward, - is_terminated, - next_stop_strings, - next_metadata, - ) - - @ray.remote class MultiStepCalculatorEnv(EnvironmentInterface): DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM @@ -350,69 +236,3 @@ def global_post_process_and_metrics( ) success_rate = final_rewards.mean().item() if len(final_rewards) > 0 else 0.0 return batch, {"success_rate": success_rate} - - -@ray.remote -class SlidingPuzzleEnv(EnvironmentInterface): - DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM - """Sliding Puzzle environment (Ray Actor).""" - - def __init__(self, cfg: Optional[Dict] = None): - # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50} - self.game_config = cfg.get("game_config", {}) if cfg else {} - self.logic = _SlidingPuzzleLogic() - - def step( - self, - message_log_batch: List[LLMMessageLogType], - metadata_batch: List[SlidingPuzzleMetadata], - ) -> EnvironmentReturn: - """Processes a batch of sliding puzzle interactions.""" - # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy) - results = [ - self.logic.process_turn(log, meta) - for log, meta in zip(message_log_batch, metadata_batch) - ] - - # Unpack results and format according to EnvironmentReturn NamedTuple - observations = [] - rewards = [] - terminateds = [] - all_stop_strings = [] - all_next_metadata = [] - - for obs, rew, term, stops, meta in results: - observations.append(obs) - 
rewards.append(rew) - terminateds.append(term) - all_stop_strings.append(stops) - all_next_metadata.append(meta) - - rewards_tensor = torch.tensor(rewards, dtype=torch.float32) - terminated_tensor = torch.tensor(terminateds, dtype=torch.bool) - - return EnvironmentReturn( - observations=observations, - metadata=all_next_metadata, - next_stop_strings=all_stop_strings, - rewards=rewards_tensor, - terminateds=terminated_tensor, - ) - - def shutdown(self): - pass - - def global_post_process_and_metrics( - self, batch: BatchedDataDict - ) -> Tuple[BatchedDataDict, dict]: - # Calculate success rate based on final reward == 1.0 - final_rewards = batch.get( - "total_reward", torch.tensor([0.0] * len(batch["idx"])) - ) - success_rate = ( - (final_rewards == 1.0).float().mean().item() - if len(final_rewards) > 0 - else 0.0 - ) - # Could also calculate average number of moves for successful episodes, etc. - return batch, {"sliding_puzzle_success_rate": success_rate} From a0f3ecb175eba9d0f461bdef04411c3cc1109167 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Tue, 22 Apr 2025 02:53:37 -0700 Subject: [PATCH 26/34] config update Signed-off-by: Sahil Jain --- examples/configs/grpo_sliding_puzzle.yaml | 46 +++++++++++++---------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index b218a5d060..fa5f2caeba 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -1,24 +1,20 @@ -defaults: "grpo_math_1B.yaml" - -# Environment setup: Map task names to their configurations -env: - sliding_puzzle_game: - env_class: "tests.unit.test_envs.SlidingPuzzleEnv" # Path to the environment actor class - # Configuration passed to the SlidingPuzzleEnv constructor - cfg: - # Game generation parameters - game_config: - size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3) - shuffle_moves: 15 # Number of random moves to shuffle the solved state - # Gameplay parameters - max_moves: 70 # Maximum moves allowed per episode - +# GRPO Algorithm Configuration grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 50 # Maximum turns allowed per rollout max_num_steps: 10000 - max_rollout_turns: 50 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + max_val_samples: 256 + val_batch_size: 256 -data: - add_system_prompt: false +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_eps_min: 0.2 + ratio_eps_max: 0.2 checkpointing: enabled: true @@ -29,7 +25,7 @@ checkpointing: save_period: 10 policy: - model_name: "meta-llama/Llama-3.2-1B-Instruct" + model_name: "Qwen/Qwen2.5-1.5B-Instruct" tokenizer: name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 @@ -37,7 +33,6 @@ policy: generation_batch_size: 32 # Only used when generating using HF backend logprob_batch_size: 4 max_total_sequence_length: 3072 - max_turns: 70 precision: "bfloat16" fsdp_offload_enabled: false activation_checkpointing_enabled: false @@ -91,6 +86,17 @@ policy: gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} +data: + add_system_prompt: false + +env: + sliding_puzzle_game: + cfg: + game_config: + size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3) + shuffle_moves: 15 # Number of random moves to shuffle the solved state + max_moves: 50 # Maximum moves allowed per episode + logger: log_dir: "logs" # Base directory for all 
logs num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal From 1e3ce54a19fd964e8dd837c8712bc0b9c3564247 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Tue, 22 Apr 2025 14:11:18 -0700 Subject: [PATCH 27/34] removed vestigial Signed-off-by: Sahil Jain --- tests/unit/environments/game_interface.py | 76 ------ .../unit/environments/sliding_puzzle_game.py | 256 ------------------ 2 files changed, 332 deletions(-) delete mode 100644 tests/unit/environments/game_interface.py delete mode 100644 tests/unit/environments/sliding_puzzle_game.py diff --git a/tests/unit/environments/game_interface.py b/tests/unit/environments/game_interface.py deleted file mode 100644 index 2f0237ed23..0000000000 --- a/tests/unit/environments/game_interface.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Any, Tuple, List, Optional, Callable - - -class GameInterface: - @staticmethod - def generate(config: Dict[str, Any]) -> Dict[str, Any]: - """ - Generate a new game state based on configuration. - - Args: - config: Game configuration dictionary - - Returns: - A dictionary containing the complete game state - """ - raise NotImplementedError("Each game must implement generate()") - - @staticmethod - def init(game_state: Dict[str, Any]) -> str: - """ - Initialize a game and return welcome messages. - - Args: - game_state: The game state dictionary - - Returns: - String containing welcome message and instructions - """ - raise NotImplementedError("Each game must implement init()") - - @staticmethod - def step( - action: str, game_state: Dict[str, Any] - ) -> Tuple[str, float, bool, Dict[str, Any]]: - """ - Process a game action and update the state. - - Args: - action: String representing the player's action - game_state: Current game state dictionary - - Returns: - Tuple containing: - - Response message - - Reward for this action - - Boolean indicating if game is terminated - - Updated game state - """ - raise NotImplementedError("Each game must implement step()") - - @staticmethod - def render(game_state: Dict[str, Any]) -> str: - """ - Render the current game state as a string. - - Args: - game_state: The game state dictionary - - Returns: - String representation of the game state - """ - raise NotImplementedError("Each game must implement render()") diff --git a/tests/unit/environments/sliding_puzzle_game.py b/tests/unit/environments/sliding_puzzle_game.py deleted file mode 100644 index 664e4c312b..0000000000 --- a/tests/unit/environments/sliding_puzzle_game.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -import copy -from typing import List, Tuple, Dict, Any, Optional -from .game_interface import GameInterface - - -class SlidingPuzzleGame(GameInterface): - @staticmethod - def generate(config: Dict[str, Any]) -> Dict[str, Any]: - """Generate a new Sliding Puzzle.""" - size = config.get("size", 4) # Default to 4x4 (15-puzzle) - shuffle_moves = config.get( - "shuffle_moves", 100 - ) # Number of random moves for shuffling - - # Create the solved state - grid = [[(r * size + c + 1) for c in range(size)] for r in range(size)] - # Set the bottom-right corner to 0 (empty space) - grid[size - 1][size - 1] = 0 - - # Save the solution - solution = [row[:] for row in grid] - - # Find the empty space - empty_pos = (size - 1, size - 1) - - # Shuffle the grid with valid moves - for _ in range(shuffle_moves): - # Get possible moves - moves = [] - r, c = empty_pos - for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]: # Right, Down, Left, Up - nr, nc = r + dr, c + dc - if 0 <= nr < size and 0 <= nc < size: - moves.append((nr, nc)) - - # Choose a random move - if moves: - new_r, new_c = random.choice(moves) - # Swap the empty space with the chosen tile - grid[r][c], grid[new_r][new_c] = grid[new_r][new_c], grid[r][c] - empty_pos = (new_r, new_c) - - # Create and return the game state - return { - "size": size, - "grid": grid, - "solution": solution, - "empty_pos": empty_pos, - "commands": { - "slide r c": "Slide tile at row r, column c into the empty space", - "up": "Slide tile below empty space up", - "down": "Slide tile above empty space down", - "left": "Slide tile to the right of empty space left", - "right": "Slide tile to the left of empty space right", - }, - } - - @staticmethod - def init(game_state: Dict[str, Any]) -> str: - """Initialize Sliding Puzzle game and return welcome message.""" - size = game_state["size"] - - return ( - f"\n===== SLIDING PUZZLE =====\n" - f"Arrange the {size}x{size} grid by sliding tiles into the empty space.\n" - f"- The goal is to arrange numbers from 1 to {size * size - 1} in order\n" - f"- Use 'slide r c' to slide a specific tile\n" - f"- Or use 'up', 'down', 'left', 'right' to slide in that direction" - ) - - @staticmethod - def step( - action: str, game_state: Dict[str, Any] - ) -> Tuple[str, float, bool, Dict[str, Any]]: - """Process an action in the Sliding Puzzle game.""" - size = game_state["size"] - grid = game_state["grid"] - empty_r, empty_c = game_state["empty_pos"] - - # Default return values - response = "Unknown command. Type 'help' to see available commands." - reward = -0.05 # Small penalty for invalid actions - is_terminated = False - - # Deep copy game state to avoid modifying the original - new_state = copy.deepcopy(game_state) - - move_made = False - - if action.startswith("slide "): - try: - _, r, c = action.split() - r, c = int(r) - 1, int(c) - 1 - - # Validate input - if not (0 <= r < size and 0 <= c < size): - return ( - f"Invalid position. 
Row/column must be between 1 and {size}.", - reward, - is_terminated, - new_state, - ) - - # Check if tile is adjacent to empty space - if abs(r - empty_r) + abs(c - empty_c) != 1: - return ( - "Tile must be adjacent to the empty space.", - reward, - is_terminated, - new_state, - ) - - # Slide the tile - new_state["grid"][empty_r][empty_c] = grid[r][c] - new_state["grid"][r][c] = 0 - new_state["empty_pos"] = (r, c) - - move_made = True - response = f"Slid tile {grid[r][c]} into the empty space." - - except ValueError: - return ( - "Invalid input format. Use: slide row col", - reward, - is_terminated, - new_state, - ) - - elif action in ["up", "down", "left", "right"]: - # Convert direction to row/col offset - if action == "up": - r, c = empty_r + 1, empty_c # Tile below moves up - dir_text = "up" - elif action == "down": - r, c = empty_r - 1, empty_c # Tile above moves down - dir_text = "down" - elif action == "left": - r, c = empty_r, empty_c + 1 # Tile to right moves left - dir_text = "left" - elif action == "right": - r, c = empty_r, empty_c - 1 # Tile to left moves right - dir_text = "right" - - # Check if the move is valid - if 0 <= r < size and 0 <= c < size: - # Slide the tile - new_state["grid"][empty_r][empty_c] = grid[r][c] - new_state["grid"][r][c] = 0 - new_state["empty_pos"] = (r, c) - - move_made = True - response = f"Slid tile {grid[r][c]} {dir_text}." - else: - return f"Cannot slide {dir_text}.", reward, is_terminated, new_state - - if move_made: - reward = 0 - - # Check if puzzle is solved - if new_state["grid"] == new_state["solution"]: - response = "Congratulations! You've solved the puzzle!" - reward = 1.0 # Win reward - is_terminated = True - - return response, reward, is_terminated, new_state - - @staticmethod - def render(game_state: Dict[str, Any]) -> str: - """Render the current Sliding Puzzle game state.""" - grid = game_state["grid"] - size = game_state["size"] - - output = ["\n"] - - # Create a visual representation of the grid - max_digits = len(str(size * size - 1)) - - # Top border - output.append(" " + "+" + "-" * (max_digits + 2) * size + "+") - - # Rows - for i, row in enumerate(grid): - row_str = f"{i + 1} |" - for val in row: - if val == 0: - # Empty space - row_str += " " * (max_digits + 2) - else: - # Tile with number - row_str += f" {val:>{max_digits}} " - row_str += "|" - output.append(row_str) - - # Bottom border - output.append(" " + "+" + "-" * (max_digits + 2) * size + "+") - - # Column labels - col_labels = " " - for i in range(size): - col_labels += f"{i + 1:^{max_digits + 2}}" - output.append(col_labels) - - return "\n".join(output) - - -def is_solvable(grid: List[List[int]], size: int) -> bool: - """Check if a sliding puzzle is solvable.""" - # Flatten the grid - flat = [num for row in grid for num in row if num != 0] - - # Count inversions - inversions = 0 - for i in range(len(flat)): - for j in range(i + 1, len(flat)): - if flat[i] > flat[j]: - inversions += 1 - - # Find row of the empty tile (0) from the bottom - empty_row = 0 - for i in range(size - 1, -1, -1): - for j in range(size): - if grid[i][j] == 0: - empty_row = size - i - break - - # For odd-sized grids, the puzzle is solvable if the number of inversions is even - if size % 2 == 1: - return inversions % 2 == 0 - # For even-sized grids, the puzzle is solvable if: - # (inversions odd and empty on even row from bottom) or (inversions even and empty on odd row from bottom) - else: - return (inversions % 2 == 1 and empty_row % 2 == 0) or ( - inversions % 2 == 0 and empty_row % 2 == 1 
- ) - - -def play_sliding_puzzle(config=None): - """Wrapper function for backward compatibility.""" - from play_game import play_game - - play_game(SlidingPuzzleGame, config) From fe73757c7a24d7cf3418c378bca7a69066fad34c Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Tue, 22 Apr 2025 17:18:27 -0700 Subject: [PATCH 28/34] Cleanup Signed-off-by: Sahil Jain --- examples/configs/grpo_sliding_puzzle.yaml | 66 ++----------------- .../environments/games/sliding_puzzle.py | 10 +-- 2 files changed, 7 insertions(+), 69 deletions(-) diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index fa5f2caeba..d981b774b0 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -1,20 +1,11 @@ # GRPO Algorithm Configuration +defaults: "grpo_math_1B.yaml" + grpo: num_prompts_per_step: 32 num_generations_per_prompt: 16 max_rollout_turns: 50 # Maximum turns allowed per rollout max_num_steps: 10000 - normalize_rewards: true - use_leave_one_out_baseline: true - val_period: 10 - val_at_start: false - max_val_samples: 256 - val_batch_size: 256 - -loss_fn: - reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 checkpointing: enabled: true @@ -25,58 +16,14 @@ checkpointing: save_period: 10 policy: - model_name: "Qwen/Qwen2.5-1.5B-Instruct" - tokenizer: - name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default - train_global_batch_size: 512 - train_micro_batch_size: 1 - generation_batch_size: 32 # Only used when generating using HF backend - logprob_batch_size: 4 max_total_sequence_length: 3072 - precision: "bfloat16" - fsdp_offload_enabled: false - activation_checkpointing_enabled: false - - dtensor_cfg: - enabled: true - cpu_offload: False - sequence_parallel: false - activation_checkpointing: false - tensor_parallel_size: 1 - - # makes the training sequence length divisible by the tensor parallel size - # this is useful for sequence parallel training - make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} - max_grad_norm: 1.0 - - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 5.0e-6 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1e-8 - # when using Dtensor, we need to set foreach - # and fused to False - foreach: False - fused: False - - scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 50 - - name: "torch.optim.lr_scheduler.ConstantLR" - kwargs: - factor: 1.0 - total_iters: 10000000000 - - milestones: [50] generation: backend: "vllm" max_new_tokens: ${policy.max_total_sequence_length} temperature: 1.0 + # Setting top_p/top_k to 0.999/10000 to strip out Qwen's special/illegal tokens + # https://github.com/NVIDIA/reinforcer/issues/237 top_p: 0.999 top_k: 10000 stop_token_ids: null @@ -110,8 +57,3 @@ logger: gpu_monitoring: collection_interval: 10 # How often to collect GPU usage metrics (in seconds) flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) - -cluster: - gpus_per_node: 1 - num_nodes: 1 - diff --git a/nemo_reinforcer/environments/games/sliding_puzzle.py b/nemo_reinforcer/environments/games/sliding_puzzle.py index 7ee66e985c..0bb595bc0c 100644 --- a/nemo_reinforcer/environments/games/sliding_puzzle.py +++ b/nemo_reinforcer/environments/games/sliding_puzzle.py @@ -82,11 +82,11 @@ def generate(config: Dict[str, Any]) -> Dict[str, Any]: "solution": solution, "empty_pos": empty_pos, "commands": { - "slide r c": 
"Slide tile at row r, column c into the empty space", "up": "Slide tile below empty space up", "down": "Slide tile above empty space down", "left": "Slide tile to the right of empty space left", "right": "Slide tile to the left of empty space right", + "view": "View the current state of the board", }, } @@ -99,8 +99,8 @@ def init(game_state: Dict[str, Any]) -> str: f"\n===== SLIDING PUZZLE =====\n" f"Arrange the {size}x{size} grid by sliding tiles into the empty space.\n" f"- The goal is to arrange numbers from 1 to {size * size - 1} in order\n" - f"- Use 'slide r c' to slide a specific tile\n" - f"- Or use 'up', 'down', 'left', 'right' to slide in that direction" + f"- Use 'up', 'down', 'left', 'right' to slide in that direction\n" + f"- Use 'view' to see the current state of the board" ) @staticmethod @@ -307,8 +307,6 @@ def process_turn( parsed_action = self._parse_action(last_assistant_msg_content) if parsed_action is None: - # Handle cases where parsing failed or it wasn't assistant's turn properly - # is_terminated = True # Penalize for bad format rendered_board = SlidingPuzzleGameLogic.render(game_state) next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action" next_metadata = None @@ -326,8 +324,6 @@ def process_turn( next_metadata["game_state"] = next_game_state next_metadata["num_moves"] = current_moves + 1 - # Combine rendered board and step response for the next observation - rendered_board = SlidingPuzzleGameLogic.render(next_game_state) next_observation_content = f"\n{step_response}\n" if is_terminated: From 67b37431cbe3d8e8e826067bfdaaf5645edfe383 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Tue, 22 Apr 2025 17:26:04 -0700 Subject: [PATCH 29/34] Update README Signed-off-by: Sahil Jain --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eef9fc1b39..04786d3199 100644 --- a/README.md +++ b/README.md @@ -34,9 +34,9 @@ What you can expect: - ✅ **HuggingFace Integration** - Works with 1-8B models (Qwen1.5, Llama) - ✅ **Distributed Training** - FSDP support and Ray-based infrastructure - ✅ **Environment Support** - Support for multi-environment training. -- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning) +- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization), SFT (Supervised Fine-Tuning), and DPO (Direct Preference Optimization) +- ✅ **Multi-Turn RL** - multi-turn generation and training for RL with tool use, games, etc. - ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state) -- ✅ **DPO Algorithm** - Direct Preference Optimization for alignment - 🔜 **Larger Model Support** - Native PyTorch support for models up to 70B parameters - 🔜 **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training - 🔜 **Environment Isolation** - Dependency isolation between components @@ -117,6 +117,12 @@ sbatch \ ray.sub ``` +We also support multi-turn generation and training (tool use, games, etc.). +Reference example for training to play a Sliding Puzzle Game: +```sh +uv run python examples/run_grpo_sliding_puzzle.py +``` + ### SFT We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). 
From 1bbd5d0baa81f65475d5356abebdbb807a7af1a9 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Wed, 23 Apr 2025 17:47:43 -0700 Subject: [PATCH 30/34] Added functional sliding puzzle test Signed-off-by: Sahil Jain --- .github/workflows/cicd-main.yml | 1 + tests/functional/grpo_multiturn.sh | 36 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100755 tests/functional/grpo_multiturn.sh diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 54d81f46ee..8b5bffc328 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -170,6 +170,7 @@ jobs: if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then uv run --no-sync bash ./tests/functional/sft.sh uv run --no-sync bash ./tests/functional/grpo.sh + uv run --no-sync bash ./tests/functional/grpo_multiturn.sh uv run --no-sync bash ./tests/functional/dpo.sh else echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }} diff --git a/tests/functional/grpo_multiturn.sh b/tests/functional/grpo_multiturn.sh new file mode 100755 index 0000000000..b61bb0ff65 --- /dev/null +++ b/tests/functional/grpo_multiturn.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetchs metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs +JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json +RUN_LOG=$LOG_DIR/$(basename $0 .sh).log +export UV_CACHE_DIR=${UV_CACHE_DIR:-$PROJECT_ROOT/uv_cache} +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $LOG_DIR +mkdir -p $LOG_DIR + +cd $PROJECT_ROOT +python -u $PROJECT_ROOT/examples/run_grpo_sliding_puzzle.py \ + cluster.gpus_per_node=2 \ + grpo.max_rollout_turns=10 \ + grpo.max_num_steps=3 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +cd $SCRIPT_DIR +python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +python check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.1' \ + From f3634315262e14ba65fb3dec31540f56a5b28a0b Mon Sep 17 00:00:00 2001 From: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:11:26 -0700 Subject: [PATCH 31/34] Update README.md Signed-off-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0ba6d1c745..bfd0ccb668 100644 --- a/README.md +++ b/README.md @@ -37,9 +37,10 @@ What you can expect: - ✅ **Environment Support** - Support for multi-environment training. - ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization), SFT (Supervised Fine-Tuning), and DPO (Direct Preference Optimization) - ✅ **Multi-Turn RL** - multi-turn generation and training for RL with tool use, games, etc. 
-- ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state) - ✅ **Large Model Support** - Native PyTorch support for models up to 32B parameters -- ✅ **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training +- ✅ **Advanced Parallelism** - FSDP2, TP, and SP for efficient training +- ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state) +- ✅ **Environment Isolation** - Dependency isolation between components - 🔜 **(Even) Larger Model Support** - Native PyTorch & Megatron - 🔜 **Improved Native Performance** - Improve training time for Native Pytorch Models From 8db7240729f402859eb57a8fd9307e49c2ee8620 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 24 Apr 2025 00:31:58 -0700 Subject: [PATCH 32/34] reduced seqlen for small functional test machine Signed-off-by: Sahil Jain --- tests/functional/grpo_multiturn.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/functional/grpo_multiturn.sh b/tests/functional/grpo_multiturn.sh index b61bb0ff65..2f260bfe07 100755 --- a/tests/functional/grpo_multiturn.sh +++ b/tests/functional/grpo_multiturn.sh @@ -21,6 +21,8 @@ python -u $PROJECT_ROOT/examples/run_grpo_sliding_puzzle.py \ cluster.gpus_per_node=2 \ grpo.max_rollout_turns=10 \ grpo.max_num_steps=3 \ + policy.max_total_sequence_length=1024 \ + policy.train_micro_batch_size=1 \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ From 703d12ec4d29f5f5804184d5b9f54097f0caa9e5 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 24 Apr 2025 12:46:03 -0700 Subject: [PATCH 33/34] updated functional Signed-off-by: Sahil Jain --- tests/functional/grpo_multiturn.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/functional/grpo_multiturn.sh b/tests/functional/grpo_multiturn.sh index 2f260bfe07..ff9befcdd7 100755 --- a/tests/functional/grpo_multiturn.sh +++ b/tests/functional/grpo_multiturn.sh @@ -23,6 +23,8 @@ python -u $PROJECT_ROOT/examples/run_grpo_sliding_puzzle.py \ grpo.max_num_steps=3 \ policy.max_total_sequence_length=1024 \ policy.train_micro_batch_size=1 \ + policy.generation.top_p=0.99 \ + policy.generation.top_k=8000 \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ From c737f108d7ae463b1a2977a7b8e415219fca0c13 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 24 Apr 2025 17:53:31 -0700 Subject: [PATCH 34/34] Qwen instruct Signed-off-by: Sahil Jain --- examples/configs/grpo_sliding_puzzle.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index d981b774b0..27ee2cae46 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -16,6 +16,7 @@ checkpointing: save_period: 10 policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" max_total_sequence_length: 3072 generation:
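A note on how the trimmed `grpo_sliding_puzzle.yaml` resolves after these patches: since PATCH 28/34 it names a base file via `defaults: "grpo_math_1B.yaml"` and carries only overrides, now including the pinned instruct model and the 3072-token context. The sketch below shows the intended base-plus-overrides semantics; the `deep_merge` helper is a hypothetical stand-in for whatever resolver the example scripts actually use.

```python
# Illustrative resolution of an overrides-only config on top of its base.
# Assumption: deep_merge approximates the repo's actual config resolver.
import yaml


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively overlay `override` onto `base` without mutating either."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


with open("examples/configs/grpo_sliding_puzzle.yaml") as f:
    overrides = yaml.safe_load(f)
with open("examples/configs/" + overrides.pop("defaults")) as f:
    base = yaml.safe_load(f)

cfg = deep_merge(base, overrides)
print(cfg["policy"]["model_name"])                 # Qwen/Qwen2.5-1.5B-Instruct
print(cfg["policy"]["max_total_sequence_length"])  # 3072, overriding the base 512
```

Keeping the sliding-puzzle config as pure overrides means fixes to the shared `grpo_math_1B.yaml` recipe (optimizer, scheduler, dtensor settings) propagate to this example automatically.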