diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index e3bee5e7f6..876c43021a 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -170,6 +170,7 @@ jobs: if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then uv run --no-sync bash ./tests/functional/sft.sh uv run --no-sync bash ./tests/functional/grpo.sh + uv run --no-sync bash ./tests/functional/grpo_multiturn.sh uv run --no-sync bash ./tests/functional/dpo.sh else echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }} diff --git a/README.md b/README.md index 3381fef9f7..bfd0ccb668 100644 --- a/README.md +++ b/README.md @@ -32,17 +32,17 @@ What you can expect: ✅ _Available now_ | 🔜 _Coming in v0.3_ - ✅ **Fast Generation** - vLLM backend for optimized inference -- ✅ **HuggingFace Integration** - Works with 1-8B models (Qwen1.5, Llama) +- ✅ **HuggingFace Integration** - Works with 1-32B models (Qwen2.5, Llama) - ✅ **Distributed Training** - FSDP support and Ray-based infrastructure - ✅ **Environment Support** - Support for multi-environment training. -- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning) +- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization), SFT (Supervised Fine-Tuning), and DPO (Direct Preference Optimization) +- ✅ **Multi-Turn RL** - multi-turn generation and training for RL with tool use, games, etc. 
+- ✅ **Large Model Support** - Native PyTorch support for models up to 32B parameters +- ✅ **Advanced Parallelism** - FSDP2, TP, and SP for efficient training - ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state) - -- ✅ **DPO Algorithm** - Direct Preference Optimization for alignment -- ✅ **Larger Model Support** - Native PyTorch support for models up to 32B parameters -- ✅ **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training - ✅ **Environment Isolation** - Dependency isolation between components +- 🔜 **(Even) Larger Model Support** - Native PyTorch & Megatron - 🔜 **Improved Native Performance** - Improve training time for Native Pytorch Models - 🔜 **Megatron Policy** - Support advanced parallelism in training with Megatron Core - 🔜 **Megatron Inference** - Support Megatron Inference for day-0 support for new megatron models @@ -145,6 +145,12 @@ sbatch \ ray.sub ``` +We also support multi-turn generation and training (tool use, games, etc.). +Reference example for training to play a Sliding Puzzle Game: +```sh +uv run python examples/run_grpo_sliding_puzzle.py +``` + ### SFT We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). 
diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml new file mode 100644 index 0000000000..27ee2cae46 --- /dev/null +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -0,0 +1,60 @@ +# GRPO Algorithm Configuration +defaults: "grpo_math_1B.yaml" + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 50 # Maximum turns allowed per rollout + max_num_steps: 10000 + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-sliding-puzzle" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + +policy: + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + max_total_sequence_length: 3072 + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + # Setting top_p/top_k to 0.999/10000 to strip out Qwen's special/illegal tokens + # https://github.com/NVIDIA/reinforcer/issues/237 + top_p: 0.999 + top_k: 10000 + stop_token_ids: null + stop_strings: null + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + +data: + add_system_prompt: false + +env: + sliding_puzzle_game: + cfg: + game_config: + size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3) + shuffle_moves: 15 # Number of random moves to shuffle the solved state + max_moves: 50 # Maximum moves allowed per episode + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-sliding_puzzle" + tensorboard: {} + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) diff --git 
a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py new file mode 100644 index 0000000000..abd468881f --- /dev/null +++ b/examples/run_grpo_sliding_puzzle.py @@ -0,0 +1,278 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +import itertools +from typing import Any, Dict, Tuple, Iterator +import random + +from omegaconf import OmegaConf +from transformers import AutoTokenizer + +from torch.utils.data import IterableDataset + +from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_reinforcer.algorithms.utils import get_tokenizer + +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.models.generation.interfaces import configure_generation_config +from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides +from nemo_reinforcer.utils.logger import get_next_experiment_dir + +from nemo_reinforcer.environments.games.sliding_puzzle import ( + SlidingPuzzleGameLogic, + SlidingPuzzleEnv, + SlidingPuzzleConfig, + SlidingPuzzleMetadata, +) +from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, 
overrides = parser.parse_known_args() + return args, overrides + + +def generate_puzzle_datum( + tokenizer, + game_config: SlidingPuzzleConfig, + max_moves: int, + task_name: str, + idx: int, + add_system_prompt: bool, +) -> DatumSpec: + """Generates a single sliding puzzle datum (prompt and metadata).""" + + def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]: + """Generate a random config for the sliding puzzle game.""" + shuffle_moves = random.randint(1, max_config.get("shuffle_moves")) + if shuffle_moves % 2 == 0: + shuffle_moves += 1 + return { + "size": random.randint(2, max_config.get("size", 3)), + "shuffle_moves": shuffle_moves, + } + + game_config = generate_random_config(game_config) + initial_game_state = SlidingPuzzleGameLogic.generate(game_config) + initial_render = SlidingPuzzleGameLogic.render(initial_game_state) + welcome_message = SlidingPuzzleGameLogic.init(initial_game_state) + puzzle_size = game_config.get("size", 3) + prompt_instructions = ( + f"{welcome_message}\n\n" + f"Current Board State:\n{initial_render}\n\n" + f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} " + f"with the empty space (0) at the bottom right.\n" + f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n" + f"After thinking, output your chosen action on a new line starting with '' like this:\nyour_action" + f"\nIf you just want to see the board, output view" + f"\nThink carefully step-by-step before acting.\n" + ) + initial_prompt_content = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_instructions}], + tokenize=False, + add_system_prompt=add_system_prompt, + add_generation_prompt=True, + add_special_tokens=False, + ).strip() + tokenized_prompt = tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log: LLMMessageLogType = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": 
tokenized_prompt, + } + ] + metadata = SlidingPuzzleMetadata( + game_state=initial_game_state, num_moves=0, max_moves=max_moves + ) + datum: DatumSpec = { + "message_log": message_log, + "length": len(tokenized_prompt), + "extra_env_info": metadata, + "loss_multiplier": 1.0, + "idx": idx, + "task_name": task_name, + "stop_strings": [""], + } + return datum + + +class IterablePuzzleDataset(IterableDataset): + """An IterableDataset that generates sliding puzzle data indefinitely.""" + + def __init__( + self, tokenizer, game_config, max_moves, task_name, add_system_prompt, length + ): + super().__init__() + self.tokenizer = tokenizer + self.game_config = game_config + self.max_moves = max_moves + self.task_name = task_name + self.add_system_prompt = add_system_prompt + self.length = length + + def __iter__(self) -> Iterator[DatumSpec]: + print(f"Starting IterablePuzzleDataset (indefinite generation).") + # Use itertools.count for an infinite index generator + for i in itertools.count(): + yield generate_puzzle_datum( + tokenizer=self.tokenizer, + game_config=self.game_config, + max_moves=self.max_moves, + task_name=self.task_name, + idx=i, + add_system_prompt=self.add_system_prompt, + ) + + def __len__(self): + return self.length + + +def setup_puzzle_data( + tokenizer: AutoTokenizer, + env_cfg: Dict[str, Any], + task_name: str, + length: int, + val_length: int, + add_system_prompt: bool, +) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]: + """Sets up the iterable data generator and env map for the sliding puzzle task.""" + print("Setting up Sliding Puzzle iterable data and environment...") + env_config = env_cfg[task_name] + + print(f"Instantiating environment for task '{task_name}'...") + env = SlidingPuzzleEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"])) + task_to_env = {task_name: env} + print(f"Environment '{task_name}' created.") + + print(f"Creating Sliding Puzzle dataset...") + training_dataset = IterablePuzzleDataset( + 
tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_moves=env_config["cfg"]["max_moves"], + task_name=task_name, + add_system_prompt=add_system_prompt, + length=length, + ) + print("Sliding Puzzle dataset created.") + + validation_dataset = IterablePuzzleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_moves=env_config["cfg"]["max_moves"], + task_name=task_name, + add_system_prompt=add_system_prompt, + length=val_length, + ) + val_task_to_env = task_to_env + + return training_dataset, validation_dataset, task_to_env, val_task_to_env + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"📊 Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # setup data & env map + ds_length = ( + config["grpo"]["num_prompts_per_step"] + * config["grpo"]["num_generations_per_prompt"] + * config["grpo"]["max_num_steps"] + ) + dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data( 
+ tokenizer=tokenizer, + env_cfg=config["env"], + task_name="sliding_puzzle_game", + length=ds_length, + val_length=config["grpo"]["max_val_samples"], + add_system_prompt=config["data"]["add_system_prompt"], + ) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/environments/sliding_puzzle_game.py b/nemo_reinforcer/environments/games/sliding_puzzle.py similarity index 51% rename from tests/unit/environments/sliding_puzzle_game.py rename to nemo_reinforcer/environments/games/sliding_puzzle.py index 664e4c312b..0bb595bc0c 100644 --- a/tests/unit/environments/sliding_puzzle_game.py +++ b/nemo_reinforcer/environments/games/sliding_puzzle.py @@ -12,13 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import ray +import torch +from typing import Dict, List, Tuple, Optional, TypedDict, Any import random import copy -from typing import List, Tuple, Dict, Any, Optional -from .game_interface import GameInterface +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.data.interfaces import LLMMessageLogType +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) +from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES -class SlidingPuzzleGame(GameInterface): + +class SlidingPuzzleConfig(TypedDict): + size: int + shuffle_moves: int + + +class SlidingPuzzleMetadata(TypedDict): + game_state: Dict[str, Any] # Stores the dict returned by SlidingPuzzleGame methods + num_moves: int + max_moves: int + + +class SlidingPuzzleGameLogic: @staticmethod def generate(config: Dict[str, Any]) -> Dict[str, Any]: """Generate a new Sliding Puzzle.""" @@ -62,11 +82,11 @@ def generate(config: Dict[str, Any]) -> Dict[str, Any]: "solution": solution, "empty_pos": empty_pos, "commands": { - "slide r c": "Slide tile at row r, column c into the empty space", "up": "Slide tile below empty space up", "down": "Slide tile above empty space down", "left": "Slide tile to the right of empty space left", "right": "Slide tile to the left of empty space right", + "view": "View the current state of the board", }, } @@ -79,8 +99,8 @@ def init(game_state: Dict[str, Any]) -> str: f"\n===== SLIDING PUZZLE =====\n" f"Arrange the {size}x{size} grid by sliding tiles into the empty space.\n" f"- The goal is to arrange numbers from 1 to {size * size - 1} in order\n" - f"- Use 'slide r c' to slide a specific tile\n" - f"- Or use 'up', 'down', 'left', 'right' to slide in that direction" + f"- Use 'up', 'down', 'left', 'right' to slide in that direction\n" + f"- Use 'view' to see the current state of the board" ) @staticmethod @@ -94,7 +114,7 @@ def step( # Default return values response = "Unknown command. 
Type 'help' to see available commands." - reward = -0.05 # Small penalty for invalid actions + reward = 0.0 # No penalty for invalid actions is_terminated = False # Deep copy game state to avoid modifying the original @@ -218,39 +238,167 @@ def render(game_state: Dict[str, Any]) -> str: return "\n".join(output) -def is_solvable(grid: List[List[int]], size: int) -> bool: - """Check if a sliding puzzle is solvable.""" - # Flatten the grid - flat = [num for row in grid for num in row if num != 0] - - # Count inversions - inversions = 0 - for i in range(len(flat)): - for j in range(i + 1, len(flat)): - if flat[i] > flat[j]: - inversions += 1 - - # Find row of the empty tile (0) from the bottom - empty_row = 0 - for i in range(size - 1, -1, -1): - for j in range(size): - if grid[i][j] == 0: - empty_row = size - i - break - - # For odd-sized grids, the puzzle is solvable if the number of inversions is even - if size % 2 == 1: - return inversions % 2 == 0 - # For even-sized grids, the puzzle is solvable if: - # (inversions odd and empty on even row from bottom) or (inversions even and empty on odd row from bottom) - else: - return (inversions % 2 == 1 and empty_row % 2 == 0) or ( - inversions % 2 == 0 and empty_row % 2 == 1 +class SlidingPuzzleRunner: + def __init__(self): + pass # No initialization needed as game methods are static + + def _parse_action(self, text: str) -> Optional[str]: + """Parses the action from ''.""" + prefix = "" + suffix = "" + # Find the prefix, case-insensitive, and potentially after some thought process + text_lower = text.lower() + prefix_lower = prefix.lower() + suffix_lower = suffix.lower() + + start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence + + if start_idx != -1: + # Find the end tag after the start tag + end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower)) + if end_idx != -1: + # Extract content between tags + action_content = text[start_idx + len(prefix) : end_idx].strip() + return action_content + 
return None + + def process_turn( + self, + message_log: LLMMessageLogType, + metadata: SlidingPuzzleMetadata, + ) -> Tuple[ + Dict[str, str], + float, + bool, + Optional[List[str]], + Optional[SlidingPuzzleMetadata], + ]: + """Processes a single turn for the sliding puzzle task.""" + game_state = metadata["game_state"] + current_moves = metadata["num_moves"] + max_moves = metadata["max_moves"] + + turn_reward = 0.0 + is_terminated = False + next_stop_strings = [""] + next_metadata = metadata.copy() + next_observation_content = "" + + # Check if max moves reached + if current_moves >= max_moves: + is_terminated = True + next_observation_content = ( + f"Maximum moves ({max_moves}) reached." + ) + next_metadata = None + return ( + {"role": "environment", "content": next_observation_content}, + 0.0, + is_terminated, + None, + next_metadata, + ) + + # Get last assistant message and parse action + last_assistant_msg_content = "" + if message_log and message_log[-1]["role"] == "assistant": + last_assistant_msg_content = message_log[-1]["content"].strip() + + parsed_action = self._parse_action(last_assistant_msg_content) + + if parsed_action is None: + rendered_board = SlidingPuzzleGameLogic.render(game_state) + next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action" + next_metadata = None + elif parsed_action == "view": + rendered_board = SlidingPuzzleGameLogic.render(game_state) + next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made." 
+ else: + # Execute the game step + step_response, reward, game_over, next_game_state = ( + SlidingPuzzleGameLogic.step(parsed_action, game_state) + ) + + turn_reward = reward + is_terminated = game_over + next_metadata["game_state"] = next_game_state + next_metadata["num_moves"] = current_moves + 1 + + next_observation_content = f"\n{step_response}\n" + + if is_terminated: + next_metadata = None # Clear metadata on termination + + return ( + {"role": "environment", "content": next_observation_content + "\n"}, + turn_reward, + is_terminated, + next_stop_strings, + next_metadata, ) -def play_sliding_puzzle(config=None): - """Wrapper function for backward compatibility.""" - from play_game import play_game +@ray.remote +class SlidingPuzzleEnv(EnvironmentInterface): + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM + """Sliding Puzzle environment (Ray Actor).""" - play_game(SlidingPuzzleGame, config) + def __init__(self, cfg: Optional[SlidingPuzzleConfig] = None): + # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50} + self.game_config = cfg.get("game_config", {}) if cfg else {} + self.runner = SlidingPuzzleRunner() + + def step( + self, + message_log_batch: List[LLMMessageLogType], + metadata_batch: List[SlidingPuzzleMetadata], + ) -> EnvironmentReturn: + """Processes a batch of sliding puzzle interactions.""" + # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy) + results = [ + self.runner.process_turn(log, meta) + for log, meta in zip(message_log_batch, metadata_batch) + ] + + # Unpack results and format according to EnvironmentReturn NamedTuple + observations = [] + rewards = [] + terminateds = [] + all_stop_strings = [] + all_next_metadata = [] + + for obs, rew, term, stops, meta in results: + observations.append(obs) + rewards.append(rew) + terminateds.append(term) + all_stop_strings.append(stops) + all_next_metadata.append(meta) + + rewards_tensor = torch.tensor(rewards, dtype=torch.float32) + 
terminated_tensor = torch.tensor(terminateds, dtype=torch.bool) + + return EnvironmentReturn( + observations=observations, + metadata=all_next_metadata, + next_stop_strings=all_stop_strings, + rewards=rewards_tensor, + terminateds=terminated_tensor, + ) + + def shutdown(self): + pass + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + # Calculate success rate based on final reward == 1.0 + final_rewards = batch.get( + "total_reward", torch.tensor([0.0] * len(batch["idx"])) + ) + success_rate = ( + (final_rewards == 1.0).float().mean().item() + if len(final_rewards) > 0 + else 0.0 + ) + # Could also calculate average number of moves for successful episodes, etc. + return batch, {"sliding_puzzle_success_rate": success_rate} diff --git a/tests/functional/grpo_multiturn.sh b/tests/functional/grpo_multiturn.sh new file mode 100755 index 0000000000..ff9befcdd7 --- /dev/null +++ b/tests/functional/grpo_multiturn.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs +JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json +RUN_LOG=$LOG_DIR/$(basename $0 .sh).log +export UV_CACHE_DIR=${UV_CACHE_DIR:-$PROJECT_ROOT/uv_cache} +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $LOG_DIR +mkdir -p $LOG_DIR + +cd $PROJECT_ROOT +python -u $PROJECT_ROOT/examples/run_grpo_sliding_puzzle.py \ + cluster.gpus_per_node=2 \ + grpo.max_rollout_turns=10 \ + grpo.max_num_steps=3 \ + policy.max_total_sequence_length=1024 \ + policy.train_micro_batch_size=1 \ + policy.generation.top_p=0.99 \ + policy.generation.top_k=8000 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +cd $SCRIPT_DIR +python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +python check_metrics.py $JSON_METRICS \ + 'max(data["train/token_mult_prob_error"]) < 1.1' \ + diff --git a/tests/unit/environments/game_interface.py b/tests/unit/environments/game_interface.py deleted file mode 100644 index 2f0237ed23..0000000000 --- a/tests/unit/environments/game_interface.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Dict, Any, Tuple, List, Optional, Callable - - -class GameInterface: - @staticmethod - def generate(config: Dict[str, Any]) -> Dict[str, Any]: - """ - Generate a new game state based on configuration. - - Args: - config: Game configuration dictionary - - Returns: - A dictionary containing the complete game state - """ - raise NotImplementedError("Each game must implement generate()") - - @staticmethod - def init(game_state: Dict[str, Any]) -> str: - """ - Initialize a game and return welcome messages. - - Args: - game_state: The game state dictionary - - Returns: - String containing welcome message and instructions - """ - raise NotImplementedError("Each game must implement init()") - - @staticmethod - def step( - action: str, game_state: Dict[str, Any] - ) -> Tuple[str, float, bool, Dict[str, Any]]: - """ - Process a game action and update the state. - - Args: - action: String representing the player's action - game_state: Current game state dictionary - - Returns: - Tuple containing: - - Response message - - Reward for this action - - Boolean indicating if game is terminated - - Updated game state - """ - raise NotImplementedError("Each game must implement step()") - - @staticmethod - def render(game_state: Dict[str, Any]) -> str: - """ - Render the current game state as a string. 
- - Args: - game_state: The game state dictionary - - Returns: - String representation of the game state - """ - raise NotImplementedError("Each game must implement render()") diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index daeecb2bc6..a1e72fa6a7 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -32,13 +32,15 @@ MultiStepCalculatorEnv, _MultiStepCalculatorLogic, MultiStepCalcMetadata, +) + +from nemo_reinforcer.environments.games.sliding_puzzle import ( + SlidingPuzzleGameLogic, SlidingPuzzleEnv, + SlidingPuzzleConfig, SlidingPuzzleMetadata, ) -# Import the game logic for generating initial state from its new location -from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame - from nemo_reinforcer.models.generation.vllm import VllmConfig, VllmGeneration MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" @@ -463,16 +465,16 @@ def sliding_puzzle_environment(rollout_cluster): def initial_sliding_puzzle_batch(rollout_tokenizer): print("Creating initial sliding puzzle test batch...") batch_size = 1 - game_config = { + game_config: SlidingPuzzleConfig = { "size": 2, "shuffle_moves": 1, } max_moves = 10 # Set a limit for the test # Generate initial game state - initial_game_state = SlidingPuzzleGame.generate(game_config) - initial_render = SlidingPuzzleGame.render(initial_game_state) - welcome_message = SlidingPuzzleGame.init(initial_game_state) + initial_game_state = SlidingPuzzleGameLogic.generate(game_config) + initial_render = SlidingPuzzleGameLogic.render(initial_game_state) + welcome_message = SlidingPuzzleGameLogic.init(initial_game_state) prompt_instructions = ( f"{welcome_message}\n\n" diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py index 5d410f5895..5fe62f135b 100644 --- a/tests/unit/test_envs.py +++ b/tests/unit/test_envs.py @@ -23,7 +23,6 @@ EnvironmentReturn, ) from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES -from 
.environments.sliding_puzzle_game import SlidingPuzzleGame class MultiStepCalcMetadata(TypedDict): @@ -33,12 +32,6 @@ class MultiStepCalcMetadata(TypedDict): current_step: int -class SlidingPuzzleMetadata(TypedDict): - game_state: Dict[str, Any] # Stores the dict returned by SlidingPuzzleGame methods - num_moves: int - max_moves: int - - class _MultiStepCalculatorLogic: def __init__(self): pass @@ -183,113 +176,6 @@ def process_turn( ) -class _SlidingPuzzleLogic: - def __init__(self): - pass # No initialization needed as game methods are static - - def _parse_action(self, text: str) -> Optional[str]: - """Parses the action from ''""" - prefix = "" - suffix = "" - # Find the prefix, case-insensitive, and potentially after some thought process - text_lower = text.lower() - prefix_lower = prefix.lower() - suffix_lower = suffix.lower() - - start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence - - if start_idx != -1: - # Find the end tag after the start tag - end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower)) - if end_idx != -1: - # Extract content between tags - action_content = text[start_idx + len(prefix) : end_idx].strip() - return action_content - return None - - def process_turn( - self, - message_log: LLMMessageLogType, - metadata: SlidingPuzzleMetadata, - ) -> Tuple[ - Dict[str, str], - float, - bool, - Optional[List[str]], - Optional[SlidingPuzzleMetadata], - ]: - """Processes a single turn for the sliding puzzle task.""" - game_state = metadata["game_state"] - current_moves = metadata["num_moves"] - max_moves = metadata["max_moves"] - - turn_reward = 0.0 - is_terminated = False - next_stop_strings = [""] - next_metadata = metadata.copy() - next_observation_content = "" - - # Check if max moves reached - if current_moves >= max_moves: - is_terminated = True - next_observation_content = ( - f"Maximum moves ({max_moves}) reached." 
- ) - next_metadata = None - return ( - {"role": "environment", "content": next_observation_content}, - 0.0, - is_terminated, - None, - next_metadata, - ) - - # Get last assistant message and parse action - last_assistant_msg_content = "" - if message_log and message_log[-1]["role"] == "assistant": - last_assistant_msg_content = message_log[-1]["content"].strip() - - parsed_action = self._parse_action(last_assistant_msg_content) - - if parsed_action is None: - # Handle cases where parsing failed or it wasn't assistant's turn properly - # is_terminated = True # Penalize for bad format - rendered_board = SlidingPuzzleGame.render(game_state) - next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action" - next_metadata = None - elif parsed_action == "view": - rendered_board = SlidingPuzzleGame.render(game_state) - next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made." - else: - # Execute the game step - step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step( - parsed_action, game_state - ) - - turn_reward = reward - is_terminated = game_over - next_metadata["game_state"] = next_game_state - next_metadata["num_moves"] = current_moves + 1 - - # Combine rendered board and step response for the next observation - rendered_board = SlidingPuzzleGame.render(next_game_state) - # next_observation_content = f"\n{rendered_board}\n\n{step_response}" - next_observation_content = f"\n{step_response}\n" - # next_observation_content = f"\n{step_response}" - - if is_terminated: - next_metadata = None # Clear metadata on termination - # next_stop_strings remains None - - return ( - {"role": "environment", "content": next_observation_content + "\n"}, - turn_reward, - is_terminated, - next_stop_strings, - next_metadata, - ) - - @ray.remote class MultiStepCalculatorEnv(EnvironmentInterface): DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM @@ -350,69 +236,3 @@ def 
global_post_process_and_metrics( ) success_rate = final_rewards.mean().item() if len(final_rewards) > 0 else 0.0 return batch, {"success_rate": success_rate} - - -@ray.remote -class SlidingPuzzleEnv(EnvironmentInterface): - DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM - """Sliding Puzzle environment (Ray Actor).""" - - def __init__(self, cfg: Optional[Dict] = None): - # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50} - self.game_config = cfg.get("game_config", {}) if cfg else {} - self.logic = _SlidingPuzzleLogic() - - def step( - self, - message_log_batch: List[LLMMessageLogType], - metadata_batch: List[SlidingPuzzleMetadata], - ) -> EnvironmentReturn: - """Processes a batch of sliding puzzle interactions.""" - # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy) - results = [ - self.logic.process_turn(log, meta) - for log, meta in zip(message_log_batch, metadata_batch) - ] - - # Unpack results and format according to EnvironmentReturn NamedTuple - observations = [] - rewards = [] - terminateds = [] - all_stop_strings = [] - all_next_metadata = [] - - for obs, rew, term, stops, meta in results: - observations.append(obs) - rewards.append(rew) - terminateds.append(term) - all_stop_strings.append(stops) - all_next_metadata.append(meta) - - rewards_tensor = torch.tensor(rewards, dtype=torch.float32) - terminated_tensor = torch.tensor(terminateds, dtype=torch.bool) - - return EnvironmentReturn( - observations=observations, - metadata=all_next_metadata, - next_stop_strings=all_stop_strings, - rewards=rewards_tensor, - terminateds=terminated_tensor, - ) - - def shutdown(self): - pass - - def global_post_process_and_metrics( - self, batch: BatchedDataDict - ) -> Tuple[BatchedDataDict, dict]: - # Calculate success rate based on final reward == 1.0 - final_rewards = batch.get( - "total_reward", torch.tensor([0.0] * len(batch["idx"])) - ) - success_rate = ( - (final_rewards == 1.0).float().mean().item() 
- if len(final_rewards) > 0 - else 0.0 - ) - # Could also calculate average number of moves for successful episodes, etc. - return batch, {"sliding_puzzle_success_rate": success_rate}