Merged

39 commits
d0f7c8c
Multiturn integrated
SahilJain314 Apr 17, 2025
bbbbe88
Removed redundant imports
SahilJain314 Apr 17, 2025
811067b
Fixed Math env on multiturn
SahilJain314 Apr 17, 2025
d1c59b2
Fixed nondeterministic multiturn error bug
SahilJain314 Apr 17, 2025
95c3218
Fixed validation error
SahilJain314 Apr 18, 2025
ceb6035
Fixed validation error
SahilJain314 Apr 18, 2025
14a3417
Fixed validation error
SahilJain314 Apr 18, 2025
d112702
<1 lp error ??
SahilJain314 Apr 18, 2025
f370c4a
debugging
SahilJain314 Apr 18, 2025
f3a5001
remove debugging
SahilJain314 Apr 18, 2025
5c90d64
cleanup
SahilJain314 Apr 18, 2025
7a803f0
Fix multiturn multigpu bugs
SahilJain314 Apr 18, 2025
c9d1298
adding sliding puzzle training scripts
SahilJain314 Apr 18, 2025
b9d936f
fix many GPU bug
SahilJain314 Apr 19, 2025
8970cd0
Updated sliding defaults
SahilJain314 Apr 19, 2025
df0b09f
Bugfixes to multiturn
SahilJain314 Apr 21, 2025
e075cc4
Fixed sliding puzzle test
SahilJain314 Apr 21, 2025
d734f35
Removed WIP example
SahilJain314 Apr 21, 2025
694723d
Cleanup
SahilJain314 Apr 21, 2025
b680aa9
Cleanup
SahilJain314 Apr 21, 2025
eea7cd7
math fix
SahilJain314 Apr 21, 2025
44666ce
added doctests to batched_data_dict
SahilJain314 Apr 21, 2025
32eefb3
lint
SahilJain314 Apr 21, 2025
f458fb0
unit tests
SahilJain314 Apr 22, 2025
6be1908
added e2e sliding puzzle game example
SahilJain314 Apr 22, 2025
a0f3ecb
config update
SahilJain314 Apr 22, 2025
be37464
Merged main
SahilJain314 Apr 22, 2025
1e3ce54
removed vestigial
SahilJain314 Apr 22, 2025
fe73757
Cleanup
SahilJain314 Apr 23, 2025
67b3743
Update README
SahilJain314 Apr 23, 2025
1bbd5d0
Added functional sliding puzzle test
SahilJain314 Apr 24, 2025
b1809ae
Merge branch 'main' into sahilj/multiturn_example
SahilJain314 Apr 24, 2025
14cd248
Merge branch 'main' into sahilj/multiturn_example
SahilJain314 Apr 24, 2025
f363431
Update README.md
SahilJain314 Apr 24, 2025
8db7240
reduced seqlen for small functional test machine
SahilJain314 Apr 24, 2025
677c40a
Merge branch 'sahilj/multiturn_example' of github.com:NVIDIA/reinforc…
SahilJain314 Apr 24, 2025
703d12e
updated functional
SahilJain314 Apr 24, 2025
c737f10
Qwen instruct
SahilJain314 Apr 25, 2025
a0e0cf6
Merge branch 'main' into sahilj/multiturn_example
SahilJain314 Apr 25, 2025
1 change: 1 addition & 0 deletions .github/workflows/cicd-main.yml
@@ -170,6 +170,7 @@ jobs:
if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then
uv run --no-sync bash ./tests/functional/sft.sh
uv run --no-sync bash ./tests/functional/grpo.sh
uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
uv run --no-sync bash ./tests/functional/dpo.sh
else
echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }}
18 changes: 12 additions & 6 deletions README.md
@@ -32,17 +32,17 @@ What you can expect:
✅ _Available now_ | 🔜 _Coming in v0.3_

- ✅ **Fast Generation** - vLLM backend for optimized inference
- ✅ **HuggingFace Integration** - Works with 1-8B models (Qwen1.5, Llama)
- ✅ **HuggingFace Integration** - Works with 1-32B models (Qwen2.5, Llama)
- ✅ **Distributed Training** - FSDP support and Ray-based infrastructure
- ✅ **Environment Support** - Support for multi-environment training.
- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning)
- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization), SFT (Supervised Fine-Tuning), and DPO (Direct Preference Optimization)
- ✅ **Multi-Turn RL** - multi-turn generation and training for RL with tool use, games, etc.
- ✅ **Large Model Support** - Native PyTorch support for models up to 32B parameters
- ✅ **Advanced Parallelism** - FSDP2, TP, and SP for efficient training
- ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state)

- ✅ **DPO Algorithm** - Direct Preference Optimization for alignment
- ✅ **Larger Model Support** - Native PyTorch support for models up to 32B parameters
- ✅ **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training
- ✅ **Environment Isolation** - Dependency isolation between components

- 🔜 **(Even) Larger Model Support** - Native PyTorch & Megatron
- 🔜 **Improved Native Performance** - Improve training time for Native Pytorch Models
- 🔜 **Megatron Policy** - Support advanced parallelism in training with Megatron Core
- 🔜 **Megatron Inference** - Support Megatron Inference for day-0 support for new megatron models
@@ -145,6 +145,12 @@ sbatch \
ray.sub
```

We also support multi-turn generation and training (tool use, games, etc.).
A reference example that trains a model to play a sliding puzzle game:
```sh
uv run python examples/run_grpo_sliding_puzzle.py
```

### SFT

We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/).
60 changes: 60 additions & 0 deletions examples/configs/grpo_sliding_puzzle.yaml
@@ -0,0 +1,60 @@
# GRPO Algorithm Configuration
defaults: "grpo_math_1B.yaml"

grpo:
num_prompts_per_step: 32
num_generations_per_prompt: 16
max_rollout_turns: 50 # Maximum turns allowed per rollout
max_num_steps: 10000

checkpointing:
enabled: true
checkpoint_dir: "results/grpo-sliding-puzzle"
metric_name: "val_reward"
higher_is_better: true
keep_top_k: 3
save_period: 10

policy:
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
max_total_sequence_length: 3072

generation:
backend: "vllm"
max_new_tokens: ${policy.max_total_sequence_length}
temperature: 1.0
# Setting top_p/top_k to 0.999/10000 to strip out Qwen's special/illegal tokens
# https://github.com/NVIDIA/reinforcer/issues/237
top_p: 0.999
top_k: 10000
stop_token_ids: null
stop_strings: null
vllm_cfg:
tensor_parallel_size: 1
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}

data:
add_system_prompt: false

env:
sliding_puzzle_game:
cfg:
game_config:
size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
shuffle_moves: 15 # Number of random moves to shuffle the solved state
max_moves: 50 # Maximum moves allowed per episode

logger:
log_dir: "logs" # Base directory for all logs
num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
wandb_enabled: false
tensorboard_enabled: false
monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
wandb:
project: "grpo-dev"
name: "grpo-dev-sliding_puzzle"
tensorboard: {}
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
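The `defaults: "grpo_math_1B.yaml"` line means this file inherits the full GRPO math config and only the keys listed here are overridden. Conceptually this is a recursive dictionary merge where the child config wins on conflicts; the sketch below illustrates that merge semantics (the exact behavior of the repo's `load_config` is an assumption here, and the key names are illustrative):

```python
def deep_merge(base, override):
    """Recursively merge override into base; override wins on conflicts."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)  # recurse into nested sections
        else:
            out[key] = value  # leaf values from the child config take precedence
    return out

# Illustrative values, not the actual contents of grpo_math_1B.yaml:
base = {"grpo": {"num_prompts_per_step": 64, "max_num_steps": 100, "max_rollout_turns": 1}}
child = {"grpo": {"num_prompts_per_step": 32, "max_rollout_turns": 50}}
merged = deep_merge(base, child)
# merged["grpo"] keeps max_num_steps from the base and takes the other two keys from the child
```

The same principle applies to CLI overrides parsed by `parse_hydra_overrides` in the example script: they are one more merge layer applied on top of the loaded YAML.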
278 changes: 278 additions & 0 deletions examples/run_grpo_sliding_puzzle.py
@@ -0,0 +1,278 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import pprint
import itertools
from typing import Any, Dict, Tuple, Iterator
import random

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from torch.utils.data import IterableDataset

from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_reinforcer.algorithms.utils import get_tokenizer

from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.models.generation.interfaces import configure_generation_config
from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides
from nemo_reinforcer.utils.logger import get_next_experiment_dir

from nemo_reinforcer.environments.games.sliding_puzzle import (
SlidingPuzzleGameLogic,
SlidingPuzzleEnv,
SlidingPuzzleConfig,
SlidingPuzzleMetadata,
)
from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec


def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Run GRPO training with configuration")
parser.add_argument(
"--config", type=str, default=None, help="Path to YAML config file"
)
args, overrides = parser.parse_known_args()
return args, overrides


def generate_puzzle_datum(
tokenizer,
game_config: SlidingPuzzleConfig,
max_moves: int,
task_name: str,
idx: int,
add_system_prompt: bool,
) -> DatumSpec:
"""Generates a single sliding puzzle datum (prompt and metadata)."""

def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]:
"""Generate a random config for the sliding puzzle game."""
shuffle_moves = random.randint(1, max_config.get("shuffle_moves"))
if shuffle_moves % 2 == 0:
shuffle_moves += 1
return {
"size": random.randint(2, max_config.get("size", 3)),
"shuffle_moves": shuffle_moves,
}

game_config = generate_random_config(game_config)
initial_game_state = SlidingPuzzleGameLogic.generate(game_config)
initial_render = SlidingPuzzleGameLogic.render(initial_game_state)
welcome_message = SlidingPuzzleGameLogic.init(initial_game_state)
puzzle_size = game_config.get("size", 3)
prompt_instructions = (
f"{welcome_message}\n\n"
f"Current Board State:\n{initial_render}\n\n"
f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} "
f"with the empty space (0) at the bottom right.\n"
f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n"
f"After thinking, output your chosen action on a new line starting with '<action></action>' like this:\n<action>your_action</action>"
f"\nIf you just want to see the board, output <action>view</action>"
f"\nThink carefully step-by-step before acting.\n"
)
initial_prompt_content = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_instructions}],
tokenize=False,
add_system_prompt=add_system_prompt,
add_generation_prompt=True,
add_special_tokens=False,
).strip()
tokenized_prompt = tokenizer(
initial_prompt_content, return_tensors="pt", add_special_tokens=False
)["input_ids"][0]
message_log: LLMMessageLogType = [
{
"role": "user",
"content": initial_prompt_content,
"token_ids": tokenized_prompt,
}
]
metadata = SlidingPuzzleMetadata(
game_state=initial_game_state, num_moves=0, max_moves=max_moves
)
datum: DatumSpec = {
"message_log": message_log,
"length": len(tokenized_prompt),
"extra_env_info": metadata,
"loss_multiplier": 1.0,
"idx": idx,
"task_name": task_name,
"stop_strings": ["</action>"],
}
return datum


class IterablePuzzleDataset(IterableDataset):
"""An IterableDataset that generates sliding puzzle data indefinitely."""

def __init__(
self, tokenizer, game_config, max_moves, task_name, add_system_prompt, length
):
super().__init__()
self.tokenizer = tokenizer
self.game_config = game_config
self.max_moves = max_moves
self.task_name = task_name
self.add_system_prompt = add_system_prompt
self.length = length

def __iter__(self) -> Iterator[DatumSpec]:
print(f"Starting IterablePuzzleDataset (indefinite generation).")
# Use itertools.count for an infinite index generator
for i in itertools.count():
yield generate_puzzle_datum(
tokenizer=self.tokenizer,
game_config=self.game_config,
max_moves=self.max_moves,
task_name=self.task_name,
idx=i,
add_system_prompt=self.add_system_prompt,
)

def __len__(self):
return self.length


def setup_puzzle_data(
tokenizer: AutoTokenizer,
env_cfg: Dict[str, Any],
task_name: str,
length: int,
val_length: int,
add_system_prompt: bool,
) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]:
"""Sets up the iterable data generator and env map for the sliding puzzle task."""
print("Setting up Sliding Puzzle iterable data and environment...")
env_config = env_cfg[task_name]

print(f"Instantiating environment for task '{task_name}'...")
env = SlidingPuzzleEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"]))
task_to_env = {task_name: env}
print(f"Environment '{task_name}' created.")

print(f"Creating Sliding Puzzle dataset...")
training_dataset = IterablePuzzleDataset(
tokenizer=tokenizer,
game_config=dict(env_config["cfg"]["game_config"]),
max_moves=env_config["cfg"]["max_moves"],
task_name=task_name,
add_system_prompt=add_system_prompt,
length=length,
)
print("Sliding Puzzle dataset created.")

validation_dataset = IterablePuzzleDataset(
tokenizer=tokenizer,
game_config=dict(env_config["cfg"]["game_config"]),
max_moves=env_config["cfg"]["max_moves"],
task_name=task_name,
add_system_prompt=add_system_prompt,
length=val_length,
)
val_task_to_env = task_to_env

return training_dataset, validation_dataset, task_to_env, val_task_to_env


def main():
"""Main entry point."""
# Parse arguments
args, overrides = parse_args()

if not args.config:
args.config = os.path.join(
os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml"
)

config = load_config(args.config)
print(f"Loaded configuration from: {args.config}")

if overrides:
print(f"Overrides: {overrides}")
config = parse_hydra_overrides(config, overrides)

config: MasterConfig = OmegaConf.to_container(config, resolve=True)
print("Applied CLI overrides")

# Print config
print("Final config:")
pprint.pprint(config)

# Get the next experiment directory with incremented ID
config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
print(f"📊 Using log directory: {config['logger']['log_dir']}")
if config["checkpointing"]["enabled"]:
print(
f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
)

init_ray()

# setup tokenizer
tokenizer = get_tokenizer(config["policy"]["tokenizer"])
config["policy"]["generation"] = configure_generation_config(
config["policy"]["generation"], tokenizer
)

# setup data & env map
ds_length = (
config["grpo"]["num_prompts_per_step"]
* config["grpo"]["num_generations_per_prompt"]
* config["grpo"]["max_num_steps"]
)
dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data(
tokenizer=tokenizer,
env_cfg=config["env"],
task_name="sliding_puzzle_game",
length=ds_length,
val_length=config["grpo"]["max_val_samples"],
add_system_prompt=config["data"]["add_system_prompt"],
)

(
policy,
policy_generation,
cluster,
dataloader,
val_dataloader,
loss_fn,
logger,
checkpointer,
grpo_state,
master_config,
) = setup(config, tokenizer, dataset, val_dataset)

grpo_train(
policy,
policy_generation,
dataloader,
val_dataloader,
tokenizer,
loss_fn,
task_to_env,
val_task_to_env,
logger,
checkpointer,
grpo_state,
master_config,
)


if __name__ == "__main__":
main()
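For reference, the core game mechanics that `SlidingPuzzleGameLogic` provides to this script can be sketched in a few lines. The exact move semantics of the real environment are an assumption here (this sketch takes "down" to slide the tile above the blank downward, and treats off-grid moves as no-ops); it is meant only to make the prompt in `generate_puzzle_datum` concrete:

```python
import random

def solved_board(size):
    """Goal state: tiles 1..size^2-1 in order, blank (0) at bottom right."""
    tiles = list(range(1, size * size)) + [0]
    return [tiles[r * size:(r + 1) * size] for r in range(size)]

def find_blank(board):
    """Return (row, col) of the blank tile (0)."""
    for r, row in enumerate(board):
        for c, value in enumerate(row):
            if value == 0:
                return r, c

def apply_move(board, move):
    """Slide a neighboring tile into the blank; no-op if that tile is off-grid."""
    size = len(board)
    r, c = find_blank(board)
    # Offset of the tile that moves, relative to the blank (assumed semantics).
    dr, dc = {"up": (1, 0), "down": (-1, 0), "left": (0, 1), "right": (0, -1)}[move]
    nr, nc = r + dr, c + dc
    if 0 <= nr < size and 0 <= nc < size:
        board[r][c], board[nr][nc] = board[nr][nc], board[r][c]
    return board

def shuffle(board, n_moves, rng=random):
    """Scramble by applying random legal moves, as shuffle_moves does in the config."""
    for _ in range(n_moves):
        apply_move(board, rng.choice(["up", "down", "left", "right"]))
    return board
```

This also shows why `generate_random_config` in the script forces `shuffle_moves` to be odd: every effective move swaps the blank with one tile, so a position reached by an odd number of moves can never already be the solved state.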