-
Notifications
You must be signed in to change notification settings - Fork 306
feat: E2E multi-turn RL example with a sliding puzzle game #242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
d0f7c8c
Multiturn integrated
SahilJain314 bbbbe88
Removed redundant imports
SahilJain314 811067b
Fixed Math env on mutliturn
SahilJain314 d1c59b2
Fixed nondetermistic multiturn error bug
SahilJain314 95c3218
Fixed validation error
SahilJain314 ceb6035
Fixed validation error
SahilJain314 14a3417
Fixed validation error
SahilJain314 d112702
<1 lp error ??
SahilJain314 f370c4a
debugging
SahilJain314 f3a5001
remove debugging
SahilJain314 5c90d64
cleanup
SahilJain314 7a803f0
Fix multiturn multigpu bugs
SahilJain314 c9d1298
adding sliding puzzle trianing scripts
SahilJain314 b9d936f
fix many GPU bug
SahilJain314 8970cd0
:wUpdated sliding defaults
SahilJain314 df0b09f
Bugfixes to multiturn
SahilJain314 e075cc4
Fixed sliding puzzle test
SahilJain314 d734f35
Removed WIP example
SahilJain314 694723d
Cleanup
SahilJain314 b680aa9
Cleanup
SahilJain314 eea7cd7
math fix
SahilJain314 44666ce
added doctests to batched_data_dict
SahilJain314 32eefb3
lint
SahilJain314 f458fb0
unit tests
SahilJain314 6be1908
added e2e sliding puzzle game example
SahilJain314 a0f3ecb
config update
SahilJain314 be37464
Merged main
SahilJain314 1e3ce54
removed vestigial
SahilJain314 fe73757
Cleanup
SahilJain314 67b3743
Update README
SahilJain314 1bbd5d0
Added functional sliding puzzle test
SahilJain314 b1809ae
Merge branch 'main' into sahilj/multiturn_example
SahilJain314 14cd248
Merge branch 'main' into sahilj/multiturn_example
SahilJain314 f363431
Update README.md
SahilJain314 8db7240
reduced seqlen for small functional test machine
SahilJain314 677c40a
Merge branch 'sahilj/multiturn_example' of github.com:NVIDIA/reinforc…
SahilJain314 703d12e
updated functional
SahilJain314 c737f10
Qwen instruct
SahilJain314 a0e0cf6
Merge branch 'main' into sahilj/multiturn_example
SahilJain314 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # GRPO Algorithm Configuration | ||
| defaults: "grpo_math_1B.yaml" | ||
|
|
||
| grpo: | ||
| num_prompts_per_step: 32 | ||
| num_generations_per_prompt: 16 | ||
| max_rollout_turns: 50 # Maximum turns allowed per rollout | ||
| max_num_steps: 10000 | ||
|
|
||
| checkpointing: | ||
| enabled: true | ||
| checkpoint_dir: "results/grpo-sliding-puzzle" | ||
| metric_name: "val_reward" | ||
| higher_is_better: true | ||
| keep_top_k: 3 | ||
| save_period: 10 | ||
|
|
||
| policy: | ||
| model_name: "Qwen/Qwen2.5-1.5B-Instruct" | ||
| max_total_sequence_length: 3072 | ||
|
|
||
| generation: | ||
| backend: "vllm" | ||
| max_new_tokens: ${policy.max_total_sequence_length} | ||
| temperature: 1.0 | ||
| # Setting top_p/top_k to 0.999/10000 to strip out Qwen's special/illegal tokens | ||
| # https://github.com/NVIDIA/reinforcer/issues/237 | ||
| top_p: 0.999 | ||
| top_k: 10000 | ||
SahilJain314 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| stop_token_ids: null | ||
| stop_strings: null | ||
| vllm_cfg: | ||
| tensor_parallel_size: 1 | ||
| gpu_memory_utilization: 0.6 | ||
| max_model_len: ${policy.max_total_sequence_length} | ||
|
|
||
| data: | ||
| add_system_prompt: false | ||
|
|
||
| env: | ||
| sliding_puzzle_game: | ||
| cfg: | ||
| game_config: | ||
| size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3) | ||
| shuffle_moves: 15 # Number of random moves to shuffle the solved state | ||
| max_moves: 50 # Maximum moves allowed per episode | ||
|
|
||
| logger: | ||
| log_dir: "logs" # Base directory for all logs | ||
| num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal | ||
| wandb_enabled: false | ||
| tensorboard_enabled: false | ||
| monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard | ||
| wandb: | ||
| project: "grpo-dev" | ||
| name: "grpo-dev-sliding_puzzle" | ||
| tensorboard: {} | ||
| gpu_monitoring: | ||
| collection_interval: 10 # How often to collect GPU usage metrics (in seconds) | ||
| flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,278 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import argparse | ||
| import os | ||
| import pprint | ||
| import itertools | ||
| from typing import Any, Dict, Tuple, Iterator | ||
| import random | ||
|
|
||
| from omegaconf import OmegaConf | ||
| from transformers import AutoTokenizer | ||
|
|
||
| from torch.utils.data import IterableDataset | ||
|
|
||
| from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup | ||
| from nemo_reinforcer.algorithms.utils import get_tokenizer | ||
|
|
||
| from nemo_reinforcer.distributed.virtual_cluster import init_ray | ||
| from nemo_reinforcer.models.generation.interfaces import configure_generation_config | ||
| from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides | ||
| from nemo_reinforcer.utils.logger import get_next_experiment_dir | ||
|
|
||
| from nemo_reinforcer.environments.games.sliding_puzzle import ( | ||
| SlidingPuzzleGameLogic, | ||
| SlidingPuzzleEnv, | ||
| SlidingPuzzleConfig, | ||
| SlidingPuzzleMetadata, | ||
| ) | ||
| from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec | ||
|
|
||
|
|
||
| def parse_args(): | ||
| """Parse command line arguments.""" | ||
| parser = argparse.ArgumentParser(description="Run GRPO training with configuration") | ||
| parser.add_argument( | ||
| "--config", type=str, default=None, help="Path to YAML config file" | ||
| ) | ||
| args, overrides = parser.parse_known_args() | ||
| return args, overrides | ||
|
|
||
|
|
||
| def generate_puzzle_datum( | ||
| tokenizer, | ||
| game_config: SlidingPuzzleConfig, | ||
| max_moves: int, | ||
| task_name: str, | ||
| idx: int, | ||
| add_system_prompt: bool, | ||
| ) -> DatumSpec: | ||
| """Generates a single sliding puzzle datum (prompt and metadata).""" | ||
|
|
||
| def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]: | ||
| """Generate a random config for the sliding puzzle game.""" | ||
| shuffle_moves = random.randint(1, max_config.get("shuffle_moves")) | ||
| if shuffle_moves % 2 == 0: | ||
| shuffle_moves += 1 | ||
| return { | ||
| "size": random.randint(2, max_config.get("size", 3)), | ||
| "shuffle_moves": shuffle_moves, | ||
| } | ||
|
|
||
| game_config = generate_random_config(game_config) | ||
| initial_game_state = SlidingPuzzleGameLogic.generate(game_config) | ||
| initial_render = SlidingPuzzleGameLogic.render(initial_game_state) | ||
| welcome_message = SlidingPuzzleGameLogic.init(initial_game_state) | ||
| puzzle_size = game_config.get("size", 3) | ||
| prompt_instructions = ( | ||
| f"{welcome_message}\n\n" | ||
| f"Current Board State:\n{initial_render}\n\n" | ||
| f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} " | ||
| f"with the empty space (0) at the bottom right.\n" | ||
| f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n" | ||
| f"After thinking, output your chosen action on a new line starting with '<action></action>' like this:\n<action>your_action</action>" | ||
| f"\nIf you just want to see the board, output <action>view</action>" | ||
| f"\nThink carefully step-by-step before acting.\n" | ||
| ) | ||
| initial_prompt_content = tokenizer.apply_chat_template( | ||
| [{"role": "user", "content": prompt_instructions}], | ||
| tokenize=False, | ||
| add_system_prompt=add_system_prompt, | ||
| add_generation_prompt=True, | ||
| add_special_tokens=False, | ||
| ).strip() | ||
| tokenized_prompt = tokenizer( | ||
| initial_prompt_content, return_tensors="pt", add_special_tokens=False | ||
| )["input_ids"][0] | ||
| message_log: LLMMessageLogType = [ | ||
| { | ||
| "role": "user", | ||
| "content": initial_prompt_content, | ||
| "token_ids": tokenized_prompt, | ||
| } | ||
| ] | ||
| metadata = SlidingPuzzleMetadata( | ||
| game_state=initial_game_state, num_moves=0, max_moves=max_moves | ||
| ) | ||
| datum: DatumSpec = { | ||
| "message_log": message_log, | ||
| "length": len(tokenized_prompt), | ||
| "extra_env_info": metadata, | ||
| "loss_multiplier": 1.0, | ||
| "idx": idx, | ||
| "task_name": task_name, | ||
| "stop_strings": ["</action>"], | ||
| } | ||
| return datum | ||
|
|
||
|
|
||
| class IterablePuzzleDataset(IterableDataset): | ||
| """An IterableDataset that generates sliding puzzle data indefinitely.""" | ||
|
|
||
| def __init__( | ||
| self, tokenizer, game_config, max_moves, task_name, add_system_prompt, length | ||
| ): | ||
| super().__init__() | ||
| self.tokenizer = tokenizer | ||
| self.game_config = game_config | ||
| self.max_moves = max_moves | ||
| self.task_name = task_name | ||
| self.add_system_prompt = add_system_prompt | ||
| self.length = length | ||
|
|
||
| def __iter__(self) -> Iterator[DatumSpec]: | ||
| print(f"Starting IterablePuzzleDataset (indefinite generation).") | ||
| # Use itertools.count for an infinite index generator | ||
| for i in itertools.count(): | ||
| yield generate_puzzle_datum( | ||
| tokenizer=self.tokenizer, | ||
| game_config=self.game_config, | ||
| max_moves=self.max_moves, | ||
| task_name=self.task_name, | ||
| idx=i, | ||
| add_system_prompt=self.add_system_prompt, | ||
| ) | ||
|
|
||
| def __len__(self): | ||
| return self.length | ||
|
|
||
|
|
||
| def setup_puzzle_data( | ||
| tokenizer: AutoTokenizer, | ||
| env_cfg: Dict[str, Any], | ||
| task_name: str, | ||
| length: int, | ||
| val_length: int, | ||
| add_system_prompt: bool, | ||
| ) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]: | ||
| """Sets up the iterable data generator and env map for the sliding puzzle task.""" | ||
| print("Setting up Sliding Puzzle iterable data and environment...") | ||
| env_config = env_cfg[task_name] | ||
|
|
||
| print(f"Instantiating environment for task '{task_name}'...") | ||
| env = SlidingPuzzleEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"])) | ||
| task_to_env = {task_name: env} | ||
| print(f"Environment '{task_name}' created.") | ||
|
|
||
| print(f"Creating Sliding Puzzle dataset...") | ||
| training_dataset = IterablePuzzleDataset( | ||
| tokenizer=tokenizer, | ||
| game_config=dict(env_config["cfg"]["game_config"]), | ||
| max_moves=env_config["cfg"]["max_moves"], | ||
| task_name=task_name, | ||
| add_system_prompt=add_system_prompt, | ||
| length=length, | ||
| ) | ||
| print("Sliding Puzzle dataset created.") | ||
|
|
||
| validation_dataset = IterablePuzzleDataset( | ||
| tokenizer=tokenizer, | ||
| game_config=dict(env_config["cfg"]["game_config"]), | ||
| max_moves=env_config["cfg"]["max_moves"], | ||
| task_name=task_name, | ||
| add_system_prompt=add_system_prompt, | ||
| length=val_length, | ||
| ) | ||
| val_task_to_env = task_to_env | ||
|
|
||
| return training_dataset, validation_dataset, task_to_env, val_task_to_env | ||
|
|
||
|
|
||
| def main(): | ||
| """Main entry point.""" | ||
| # Parse arguments | ||
| args, overrides = parse_args() | ||
|
|
||
| if not args.config: | ||
| args.config = os.path.join( | ||
| os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml" | ||
| ) | ||
|
|
||
| config = load_config(args.config) | ||
| print(f"Loaded configuration from: {args.config}") | ||
|
|
||
| if overrides: | ||
| print(f"Overrides: {overrides}") | ||
| config = parse_hydra_overrides(config, overrides) | ||
|
|
||
| config: MasterConfig = OmegaConf.to_container(config, resolve=True) | ||
| print("Applied CLI overrides") | ||
|
|
||
| # Print config | ||
| print("Final config:") | ||
| pprint.pprint(config) | ||
|
|
||
| # Get the next experiment directory with incremented ID | ||
| config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) | ||
| print(f"📊 Using log directory: {config['logger']['log_dir']}") | ||
| if config["checkpointing"]["enabled"]: | ||
| print( | ||
| f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" | ||
| ) | ||
|
|
||
| init_ray() | ||
|
|
||
| # setup tokenizer | ||
| tokenizer = get_tokenizer(config["policy"]["tokenizer"]) | ||
| config["policy"]["generation"] = configure_generation_config( | ||
| config["policy"]["generation"], tokenizer | ||
| ) | ||
|
|
||
| # setup data & env map | ||
| ds_length = ( | ||
| config["grpo"]["num_prompts_per_step"] | ||
| * config["grpo"]["num_generations_per_prompt"] | ||
| * config["grpo"]["max_num_steps"] | ||
| ) | ||
| dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data( | ||
| tokenizer=tokenizer, | ||
| env_cfg=config["env"], | ||
| task_name="sliding_puzzle_game", | ||
| length=ds_length, | ||
| val_length=config["grpo"]["max_val_samples"], | ||
| add_system_prompt=config["data"]["add_system_prompt"], | ||
| ) | ||
|
|
||
| ( | ||
| policy, | ||
| policy_generation, | ||
| cluster, | ||
| dataloader, | ||
| val_dataloader, | ||
| loss_fn, | ||
| logger, | ||
| checkpointer, | ||
| grpo_state, | ||
| master_config, | ||
| ) = setup(config, tokenizer, dataset, val_dataset) | ||
|
|
||
| grpo_train( | ||
| policy, | ||
| policy_generation, | ||
| dataloader, | ||
| val_dataloader, | ||
| tokenizer, | ||
| loss_fn, | ||
| task_to_env, | ||
| val_task_to_env, | ||
| logger, | ||
| checkpointer, | ||
| grpo_state, | ||
| master_config, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.