diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index e3bee5e7f6..876c43021a 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -170,6 +170,7 @@ jobs:
if [[ "${{ needs.pre-flight.outputs.test_level }}" =~ ^(L1|L2)$ ]]; then
uv run --no-sync bash ./tests/functional/sft.sh
uv run --no-sync bash ./tests/functional/grpo.sh
+ uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
uv run --no-sync bash ./tests/functional/dpo.sh
else
echo Skipping functional tests for level ${{ needs.pre-flight.outputs.test_level }}
diff --git a/README.md b/README.md
index 3381fef9f7..bfd0ccb668 100644
--- a/README.md
+++ b/README.md
@@ -32,17 +32,17 @@ What you can expect:
✅ _Available now_ | 🔜 _Coming in v0.3_
- ✅ **Fast Generation** - vLLM backend for optimized inference
-- ✅ **HuggingFace Integration** - Works with 1-8B models (Qwen1.5, Llama)
+- ✅ **HuggingFace Integration** - Works with 1-32B models (Qwen2.5, Llama)
- ✅ **Distributed Training** - FSDP support and Ray-based infrastructure
- ✅ **Environment Support** - Support for multi-environment training.
-- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning)
+- ✅ **Learning Algorithms** - GRPO (Group Relative Policy Optimization), SFT (Supervised Fine-Tuning), and DPO (Direct Preference Optimization)
+- ✅ **Multi-Turn RL** - multi-turn generation and training for RL with tool use, games, etc.
+- ✅ **Large Model Support** - Native PyTorch support for models up to 32B parameters
+- ✅ **Advanced Parallelism** - FSDP2, TP, and SP for efficient training
- ✅ **Worker Isolation** - Process isolation between RL Actors (no worries about global state)
-
-- ✅ **DPO Algorithm** - Direct Preference Optimization for alignment
-- ✅ **Larger Model Support** - Native PyTorch support for models up to 32B parameters
-- ✅ **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training
- ✅ **Environment Isolation** - Dependency isolation between components
+- 🔜 **(Even) Larger Model Support** - Native PyTorch & Megatron
- 🔜 **Improved Native Performance** - Improve training time for Native Pytorch Models
- 🔜 **Megatron Policy** - Support advanced parallelism in training with Megatron Core
- 🔜 **Megatron Inference** - Support Megatron Inference for day-0 support for new megatron models
@@ -145,6 +145,12 @@ sbatch \
ray.sub
```
+We also support multi-turn generation and training (tool use, games, etc.).
+Reference example for training to play a Sliding Puzzle Game:
+```sh
+uv run python examples/run_grpo_sliding_puzzle.py
+```
+
### SFT
We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/).
diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
new file mode 100644
index 0000000000..27ee2cae46
--- /dev/null
+++ b/examples/configs/grpo_sliding_puzzle.yaml
@@ -0,0 +1,60 @@
+# GRPO Algorithm Configuration
+defaults: "grpo_math_1B.yaml"
+
+grpo:
+ num_prompts_per_step: 32
+ num_generations_per_prompt: 16
+ max_rollout_turns: 50 # Maximum turns allowed per rollout
+ max_num_steps: 10000
+
+checkpointing:
+ enabled: true
+ checkpoint_dir: "results/grpo-sliding-puzzle"
+ metric_name: "val_reward"
+ higher_is_better: true
+ keep_top_k: 3
+ save_period: 10
+
+policy:
+ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+ max_total_sequence_length: 3072
+
+ generation:
+ backend: "vllm"
+ max_new_tokens: ${policy.max_total_sequence_length}
+ temperature: 1.0
+ # Setting top_p/top_k to 0.999/10000 to strip out Qwen's special/illegal tokens
+ # https://github.com/NVIDIA/reinforcer/issues/237
+ top_p: 0.999
+ top_k: 10000
+ stop_token_ids: null
+ stop_strings: null
+ vllm_cfg:
+ tensor_parallel_size: 1
+ gpu_memory_utilization: 0.6
+ max_model_len: ${policy.max_total_sequence_length}
+
+data:
+ add_system_prompt: false
+
+env:
+ sliding_puzzle_game:
+ cfg:
+ game_config:
+ size: 5 # Size of the puzzle (e.g., 2 for 2x2, 3 for 3x3)
+ shuffle_moves: 15 # Number of random moves to shuffle the solved state
+ max_moves: 50 # Maximum moves allowed per episode
+
+logger:
+ log_dir: "logs" # Base directory for all logs
+ num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
+ wandb_enabled: false
+ tensorboard_enabled: false
+ monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
+ wandb:
+ project: "grpo-dev"
+ name: "grpo-dev-sliding_puzzle"
+ tensorboard: {}
+ gpu_monitoring:
+ collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
+ flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py
new file mode 100644
index 0000000000..abd468881f
--- /dev/null
+++ b/examples/run_grpo_sliding_puzzle.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import pprint
+import itertools
+from typing import Any, Dict, Tuple, Iterator
+import random
+
+from omegaconf import OmegaConf
+from transformers import AutoTokenizer
+
+from torch.utils.data import IterableDataset
+
+from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup
+from nemo_reinforcer.algorithms.utils import get_tokenizer
+
+from nemo_reinforcer.distributed.virtual_cluster import init_ray
+from nemo_reinforcer.models.generation.interfaces import configure_generation_config
+from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides
+from nemo_reinforcer.utils.logger import get_next_experiment_dir
+
+from nemo_reinforcer.environments.games.sliding_puzzle import (
+ SlidingPuzzleGameLogic,
+ SlidingPuzzleEnv,
+ SlidingPuzzleConfig,
+ SlidingPuzzleMetadata,
+)
+from nemo_reinforcer.data.interfaces import LLMMessageLogType, DatumSpec
+
+
+def parse_args():
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(description="Run GRPO training with configuration")
+ parser.add_argument(
+ "--config", type=str, default=None, help="Path to YAML config file"
+ )
+ args, overrides = parser.parse_known_args()
+ return args, overrides
+
+
+def generate_puzzle_datum(
+ tokenizer,
+ game_config: SlidingPuzzleConfig,
+ max_moves: int,
+ task_name: str,
+ idx: int,
+ add_system_prompt: bool,
+) -> DatumSpec:
+ """Generates a single sliding puzzle datum (prompt and metadata)."""
+
+ def generate_random_config(max_config: Dict[str, Any]) -> Dict[str, Any]:
+ """Generate a random config for the sliding puzzle game."""
+ shuffle_moves = random.randint(1, max_config.get("shuffle_moves"))
+ if shuffle_moves % 2 == 0:
+ shuffle_moves += 1
+ return {
+ "size": random.randint(2, max_config.get("size", 3)),
+ "shuffle_moves": shuffle_moves,
+ }
+
+ game_config = generate_random_config(game_config)
+ initial_game_state = SlidingPuzzleGameLogic.generate(game_config)
+ initial_render = SlidingPuzzleGameLogic.render(initial_game_state)
+ welcome_message = SlidingPuzzleGameLogic.init(initial_game_state)
+ puzzle_size = game_config.get("size", 3)
+ prompt_instructions = (
+ f"{welcome_message}\n\n"
+ f"Current Board State:\n{initial_render}\n\n"
+ f"Reach the goal state where numbers are ordered 1 through {puzzle_size**2 - 1} "
+ f"with the empty space (0) at the bottom right.\n"
+ f"Valid actions: 'up', 'down', 'left', 'right', or 'slide row col' (e.g., 'slide 1 2').\n"
+ f"After thinking, output your chosen action on a new line starting with '<action>' like this:\n<action>your_action</action>"
+ f"\nIf you just want to see the board, output <action>view</action>"
+ f"\nThink carefully step-by-step before acting.\n"
+ )
+ initial_prompt_content = tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt_instructions}],
+ tokenize=False,
+ add_system_prompt=add_system_prompt,
+ add_generation_prompt=True,
+ add_special_tokens=False,
+ ).strip()
+ tokenized_prompt = tokenizer(
+ initial_prompt_content, return_tensors="pt", add_special_tokens=False
+ )["input_ids"][0]
+ message_log: LLMMessageLogType = [
+ {
+ "role": "user",
+ "content": initial_prompt_content,
+ "token_ids": tokenized_prompt,
+ }
+ ]
+ metadata = SlidingPuzzleMetadata(
+ game_state=initial_game_state, num_moves=0, max_moves=max_moves
+ )
+ datum: DatumSpec = {
+ "message_log": message_log,
+ "length": len(tokenized_prompt),
+ "extra_env_info": metadata,
+ "loss_multiplier": 1.0,
+ "idx": idx,
+ "task_name": task_name,
+ "stop_strings": ["</action>"],
+ }
+ return datum
+
+
+class IterablePuzzleDataset(IterableDataset):
+ """An IterableDataset that generates sliding puzzle data indefinitely."""
+
+ def __init__(
+ self, tokenizer, game_config, max_moves, task_name, add_system_prompt, length
+ ):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.game_config = game_config
+ self.max_moves = max_moves
+ self.task_name = task_name
+ self.add_system_prompt = add_system_prompt
+ self.length = length
+
+ def __iter__(self) -> Iterator[DatumSpec]:
+ print(f"Starting IterablePuzzleDataset (indefinite generation).")
+ # Use itertools.count for an infinite index generator
+ for i in itertools.count():
+ yield generate_puzzle_datum(
+ tokenizer=self.tokenizer,
+ game_config=self.game_config,
+ max_moves=self.max_moves,
+ task_name=self.task_name,
+ idx=i,
+ add_system_prompt=self.add_system_prompt,
+ )
+
+ def __len__(self):
+ return self.length
+
+
+def setup_puzzle_data(
+ tokenizer: AutoTokenizer,
+ env_cfg: Dict[str, Any],
+ task_name: str,
+ length: int,
+ val_length: int,
+ add_system_prompt: bool,
+) -> Tuple[IterableDataset, IterableDataset | None, Dict, Dict]:
+ """Sets up the iterable data generator and env map for the sliding puzzle task."""
+ print("Setting up Sliding Puzzle iterable data and environment...")
+ env_config = env_cfg[task_name]
+
+ print(f"Instantiating environment for task '{task_name}'...")
+ env = SlidingPuzzleEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"]))
+ task_to_env = {task_name: env}
+ print(f"Environment '{task_name}' created.")
+
+ print(f"Creating Sliding Puzzle dataset...")
+ training_dataset = IterablePuzzleDataset(
+ tokenizer=tokenizer,
+ game_config=dict(env_config["cfg"]["game_config"]),
+ max_moves=env_config["cfg"]["max_moves"],
+ task_name=task_name,
+ add_system_prompt=add_system_prompt,
+ length=length,
+ )
+ print("Sliding Puzzle dataset created.")
+
+ validation_dataset = IterablePuzzleDataset(
+ tokenizer=tokenizer,
+ game_config=dict(env_config["cfg"]["game_config"]),
+ max_moves=env_config["cfg"]["max_moves"],
+ task_name=task_name,
+ add_system_prompt=add_system_prompt,
+ length=val_length,
+ )
+ val_task_to_env = task_to_env
+
+ return training_dataset, validation_dataset, task_to_env, val_task_to_env
+
+
+def main():
+ """Main entry point."""
+ # Parse arguments
+ args, overrides = parse_args()
+
+ if not args.config:
+ args.config = os.path.join(
+ os.path.dirname(__file__), "configs", "grpo_sliding_puzzle.yaml"
+ )
+
+ config = load_config(args.config)
+ print(f"Loaded configuration from: {args.config}")
+
+ if overrides:
+ print(f"Overrides: {overrides}")
+ config = parse_hydra_overrides(config, overrides)
+
+ config: MasterConfig = OmegaConf.to_container(config, resolve=True)
+ print("Applied CLI overrides")
+
+ # Print config
+ print("Final config:")
+ pprint.pprint(config)
+
+ # Get the next experiment directory with incremented ID
+ config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
+ print(f"📊 Using log directory: {config['logger']['log_dir']}")
+ if config["checkpointing"]["enabled"]:
+ print(
+ f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
+ )
+
+ init_ray()
+
+ # setup tokenizer
+ tokenizer = get_tokenizer(config["policy"]["tokenizer"])
+ config["policy"]["generation"] = configure_generation_config(
+ config["policy"]["generation"], tokenizer
+ )
+
+ # setup data & env map
+ ds_length = (
+ config["grpo"]["num_prompts_per_step"]
+ * config["grpo"]["num_generations_per_prompt"]
+ * config["grpo"]["max_num_steps"]
+ )
+ dataset, val_dataset, task_to_env, val_task_to_env = setup_puzzle_data(
+ tokenizer=tokenizer,
+ env_cfg=config["env"],
+ task_name="sliding_puzzle_game",
+ length=ds_length,
+ val_length=config["grpo"]["max_val_samples"],
+ add_system_prompt=config["data"]["add_system_prompt"],
+ )
+
+ (
+ policy,
+ policy_generation,
+ cluster,
+ dataloader,
+ val_dataloader,
+ loss_fn,
+ logger,
+ checkpointer,
+ grpo_state,
+ master_config,
+ ) = setup(config, tokenizer, dataset, val_dataset)
+
+ grpo_train(
+ policy,
+ policy_generation,
+ dataloader,
+ val_dataloader,
+ tokenizer,
+ loss_fn,
+ task_to_env,
+ val_task_to_env,
+ logger,
+ checkpointer,
+ grpo_state,
+ master_config,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/unit/environments/sliding_puzzle_game.py b/nemo_reinforcer/environments/games/sliding_puzzle.py
similarity index 51%
rename from tests/unit/environments/sliding_puzzle_game.py
rename to nemo_reinforcer/environments/games/sliding_puzzle.py
index 664e4c312b..0bb595bc0c 100644
--- a/tests/unit/environments/sliding_puzzle_game.py
+++ b/nemo_reinforcer/environments/games/sliding_puzzle.py
@@ -12,13 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ray
+import torch
+from typing import Dict, List, Tuple, Optional, TypedDict, Any
import random
import copy
-from typing import List, Tuple, Dict, Any, Optional
-from .game_interface import GameInterface
+from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+from nemo_reinforcer.data.interfaces import LLMMessageLogType
+from nemo_reinforcer.environments.interfaces import (
+ EnvironmentInterface,
+ EnvironmentReturn,
+)
+from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES
-class SlidingPuzzleGame(GameInterface):
+
+class SlidingPuzzleConfig(TypedDict):
+ size: int
+ shuffle_moves: int
+
+
+class SlidingPuzzleMetadata(TypedDict):
+ game_state: Dict[str, Any] # Stores the dict returned by SlidingPuzzleGameLogic methods
+ num_moves: int
+ max_moves: int
+
+
+class SlidingPuzzleGameLogic:
@staticmethod
def generate(config: Dict[str, Any]) -> Dict[str, Any]:
"""Generate a new Sliding Puzzle."""
@@ -62,11 +82,11 @@ def generate(config: Dict[str, Any]) -> Dict[str, Any]:
"solution": solution,
"empty_pos": empty_pos,
"commands": {
- "slide r c": "Slide tile at row r, column c into the empty space",
"up": "Slide tile below empty space up",
"down": "Slide tile above empty space down",
"left": "Slide tile to the right of empty space left",
"right": "Slide tile to the left of empty space right",
+ "view": "View the current state of the board",
},
}
@@ -79,8 +99,8 @@ def init(game_state: Dict[str, Any]) -> str:
f"\n===== SLIDING PUZZLE =====\n"
f"Arrange the {size}x{size} grid by sliding tiles into the empty space.\n"
f"- The goal is to arrange numbers from 1 to {size * size - 1} in order\n"
- f"- Use 'slide r c' to slide a specific tile\n"
- f"- Or use 'up', 'down', 'left', 'right' to slide in that direction"
+ f"- Use 'up', 'down', 'left', 'right' to slide in that direction\n"
+ f"- Use 'view' to see the current state of the board"
)
@staticmethod
@@ -94,7 +114,7 @@ def step(
# Default return values
response = "Unknown command. Type 'help' to see available commands."
- reward = -0.05 # Small penalty for invalid actions
+ reward = 0.0 # No penalty for invalid actions
is_terminated = False
# Deep copy game state to avoid modifying the original
@@ -218,39 +238,167 @@ def render(game_state: Dict[str, Any]) -> str:
return "\n".join(output)
-def is_solvable(grid: List[List[int]], size: int) -> bool:
- """Check if a sliding puzzle is solvable."""
- # Flatten the grid
- flat = [num for row in grid for num in row if num != 0]
-
- # Count inversions
- inversions = 0
- for i in range(len(flat)):
- for j in range(i + 1, len(flat)):
- if flat[i] > flat[j]:
- inversions += 1
-
- # Find row of the empty tile (0) from the bottom
- empty_row = 0
- for i in range(size - 1, -1, -1):
- for j in range(size):
- if grid[i][j] == 0:
- empty_row = size - i
- break
-
- # For odd-sized grids, the puzzle is solvable if the number of inversions is even
- if size % 2 == 1:
- return inversions % 2 == 0
- # For even-sized grids, the puzzle is solvable if:
- # (inversions odd and empty on even row from bottom) or (inversions even and empty on odd row from bottom)
- else:
- return (inversions % 2 == 1 and empty_row % 2 == 0) or (
- inversions % 2 == 0 and empty_row % 2 == 1
+class SlidingPuzzleRunner:
+ def __init__(self):
+ pass # No initialization needed as game methods are static
+
+ def _parse_action(self, text: str) -> Optional[str]:
+ """Parses the action from '<action>...</action>' tags."""
+ prefix = "<action>"
+ suffix = "</action>"
+ # Find the prefix, case-insensitive, and potentially after some thought process
+ text_lower = text.lower()
+ prefix_lower = prefix.lower()
+ suffix_lower = suffix.lower()
+
+ start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence
+
+ if start_idx != -1:
+ # Find the end tag after the start tag
+ end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower))
+ if end_idx != -1:
+ # Extract content between tags
+ action_content = text[start_idx + len(prefix) : end_idx].strip()
+ return action_content
+ return None
+
+ def process_turn(
+ self,
+ message_log: LLMMessageLogType,
+ metadata: SlidingPuzzleMetadata,
+ ) -> Tuple[
+ Dict[str, str],
+ float,
+ bool,
+ Optional[List[str]],
+ Optional[SlidingPuzzleMetadata],
+ ]:
+ """Processes a single turn for the sliding puzzle task."""
+ game_state = metadata["game_state"]
+ current_moves = metadata["num_moves"]
+ max_moves = metadata["max_moves"]
+
+ turn_reward = 0.0
+ is_terminated = False
+ next_stop_strings = ["</action>"]
+ next_metadata = metadata.copy()
+ next_observation_content = ""
+
+ # Check if max moves reached
+ if current_moves >= max_moves:
+ is_terminated = True
+ next_observation_content = (
+ f"Maximum moves ({max_moves}) reached."
+ )
+ next_metadata = None
+ return (
+ {"role": "environment", "content": next_observation_content},
+ 0.0,
+ is_terminated,
+ None,
+ next_metadata,
+ )
+
+ # Get last assistant message and parse action
+ last_assistant_msg_content = ""
+ if message_log and message_log[-1]["role"] == "assistant":
+ last_assistant_msg_content = message_log[-1]["content"].strip()
+
+ parsed_action = self._parse_action(last_assistant_msg_content)
+
+ if parsed_action is None:
+ rendered_board = SlidingPuzzleGameLogic.render(game_state)
+ next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: <action>your_action</action>"
+ next_metadata = None
+ elif parsed_action == "view":
+ rendered_board = SlidingPuzzleGameLogic.render(game_state)
+ next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made."
+ else:
+ # Execute the game step
+ step_response, reward, game_over, next_game_state = (
+ SlidingPuzzleGameLogic.step(parsed_action, game_state)
+ )
+
+ turn_reward = reward
+ is_terminated = game_over
+ next_metadata["game_state"] = next_game_state
+ next_metadata["num_moves"] = current_moves + 1
+
+ next_observation_content = f"\n{step_response}\n"
+
+ if is_terminated:
+ next_metadata = None # Clear metadata on termination
+
+ return (
+ {"role": "environment", "content": next_observation_content + "\n"},
+ turn_reward,
+ is_terminated,
+ next_stop_strings,
+ next_metadata,
)
-def play_sliding_puzzle(config=None):
- """Wrapper function for backward compatibility."""
- from play_game import play_game
+@ray.remote
+class SlidingPuzzleEnv(EnvironmentInterface):
+ DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM
+ """Sliding Puzzle environment (Ray Actor)."""
- play_game(SlidingPuzzleGame, config)
+ def __init__(self, cfg: Optional[SlidingPuzzleConfig] = None):
+ # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50}
+ self.game_config = cfg.get("game_config", {}) if cfg else {}
+ self.runner = SlidingPuzzleRunner()
+
+ def step(
+ self,
+ message_log_batch: List[LLMMessageLogType],
+ metadata_batch: List[SlidingPuzzleMetadata],
+ ) -> EnvironmentReturn:
+ """Processes a batch of sliding puzzle interactions."""
+ # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy)
+ results = [
+ self.runner.process_turn(log, meta)
+ for log, meta in zip(message_log_batch, metadata_batch)
+ ]
+
+ # Unpack results and format according to EnvironmentReturn NamedTuple
+ observations = []
+ rewards = []
+ terminateds = []
+ all_stop_strings = []
+ all_next_metadata = []
+
+ for obs, rew, term, stops, meta in results:
+ observations.append(obs)
+ rewards.append(rew)
+ terminateds.append(term)
+ all_stop_strings.append(stops)
+ all_next_metadata.append(meta)
+
+ rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
+ terminated_tensor = torch.tensor(terminateds, dtype=torch.bool)
+
+ return EnvironmentReturn(
+ observations=observations,
+ metadata=all_next_metadata,
+ next_stop_strings=all_stop_strings,
+ rewards=rewards_tensor,
+ terminateds=terminated_tensor,
+ )
+
+ def shutdown(self):
+ pass
+
+ def global_post_process_and_metrics(
+ self, batch: BatchedDataDict
+ ) -> Tuple[BatchedDataDict, dict]:
+ # Calculate success rate based on final reward == 1.0
+ final_rewards = batch.get(
+ "total_reward", torch.tensor([0.0] * len(batch["idx"]))
+ )
+ success_rate = (
+ (final_rewards == 1.0).float().mean().item()
+ if len(final_rewards) > 0
+ else 0.0
+ )
+ # Could also calculate average number of moves for successful episodes, etc.
+ return batch, {"sliding_puzzle_success_rate": success_rate}
diff --git a/tests/functional/grpo_multiturn.sh b/tests/functional/grpo_multiturn.sh
new file mode 100755
index 0000000000..ff9befcdd7
--- /dev/null
+++ b/tests/functional/grpo_multiturn.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs
+JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json
+RUN_LOG=$LOG_DIR/$(basename $0 .sh).log
+export UV_CACHE_DIR=${UV_CACHE_DIR:-$PROJECT_ROOT/uv_cache}
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $LOG_DIR
+mkdir -p $LOG_DIR
+
+cd $PROJECT_ROOT
+python -u $PROJECT_ROOT/examples/run_grpo_sliding_puzzle.py \
+ cluster.gpus_per_node=2 \
+ grpo.max_rollout_turns=10 \
+ grpo.max_num_steps=3 \
+ policy.max_total_sequence_length=1024 \
+ policy.train_micro_batch_size=1 \
+ policy.generation.top_p=0.99 \
+ policy.generation.top_k=8000 \
+ logger.tensorboard_enabled=true \
+ logger.log_dir=$LOG_DIR \
+ logger.wandb_enabled=false \
+ checkpointing.enabled=false \
+ $@ \
+ 2>&1 | tee $RUN_LOG
+
+cd $SCRIPT_DIR
+python json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+python check_metrics.py $JSON_METRICS \
+ 'max(data["train/token_mult_prob_error"]) < 1.1' \
+
diff --git a/tests/unit/environments/game_interface.py b/tests/unit/environments/game_interface.py
deleted file mode 100644
index 2f0237ed23..0000000000
--- a/tests/unit/environments/game_interface.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Dict, Any, Tuple, List, Optional, Callable
-
-
-class GameInterface:
- @staticmethod
- def generate(config: Dict[str, Any]) -> Dict[str, Any]:
- """
- Generate a new game state based on configuration.
-
- Args:
- config: Game configuration dictionary
-
- Returns:
- A dictionary containing the complete game state
- """
- raise NotImplementedError("Each game must implement generate()")
-
- @staticmethod
- def init(game_state: Dict[str, Any]) -> str:
- """
- Initialize a game and return welcome messages.
-
- Args:
- game_state: The game state dictionary
-
- Returns:
- String containing welcome message and instructions
- """
- raise NotImplementedError("Each game must implement init()")
-
- @staticmethod
- def step(
- action: str, game_state: Dict[str, Any]
- ) -> Tuple[str, float, bool, Dict[str, Any]]:
- """
- Process a game action and update the state.
-
- Args:
- action: String representing the player's action
- game_state: Current game state dictionary
-
- Returns:
- Tuple containing:
- - Response message
- - Reward for this action
- - Boolean indicating if game is terminated
- - Updated game state
- """
- raise NotImplementedError("Each game must implement step()")
-
- @staticmethod
- def render(game_state: Dict[str, Any]) -> str:
- """
- Render the current game state as a string.
-
- Args:
- game_state: The game state dictionary
-
- Returns:
- String representation of the game state
- """
- raise NotImplementedError("Each game must implement render()")
diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py
index daeecb2bc6..a1e72fa6a7 100644
--- a/tests/unit/experience/test_rollouts.py
+++ b/tests/unit/experience/test_rollouts.py
@@ -32,13 +32,15 @@
MultiStepCalculatorEnv,
_MultiStepCalculatorLogic,
MultiStepCalcMetadata,
+)
+
+from nemo_reinforcer.environments.games.sliding_puzzle import (
+ SlidingPuzzleGameLogic,
SlidingPuzzleEnv,
+ SlidingPuzzleConfig,
SlidingPuzzleMetadata,
)
-# Import the game logic for generating initial state from its new location
-from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame
-
from nemo_reinforcer.models.generation.vllm import VllmConfig, VllmGeneration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
@@ -463,16 +465,16 @@ def sliding_puzzle_environment(rollout_cluster):
def initial_sliding_puzzle_batch(rollout_tokenizer):
print("Creating initial sliding puzzle test batch...")
batch_size = 1
- game_config = {
+ game_config: SlidingPuzzleConfig = {
"size": 2,
"shuffle_moves": 1,
}
max_moves = 10 # Set a limit for the test
# Generate initial game state
- initial_game_state = SlidingPuzzleGame.generate(game_config)
- initial_render = SlidingPuzzleGame.render(initial_game_state)
- welcome_message = SlidingPuzzleGame.init(initial_game_state)
+ initial_game_state = SlidingPuzzleGameLogic.generate(game_config)
+ initial_render = SlidingPuzzleGameLogic.render(initial_game_state)
+ welcome_message = SlidingPuzzleGameLogic.init(initial_game_state)
prompt_instructions = (
f"{welcome_message}\n\n"
diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py
index 5d410f5895..5fe62f135b 100644
--- a/tests/unit/test_envs.py
+++ b/tests/unit/test_envs.py
@@ -23,7 +23,6 @@
EnvironmentReturn,
)
from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES
-from .environments.sliding_puzzle_game import SlidingPuzzleGame
class MultiStepCalcMetadata(TypedDict):
@@ -33,12 +32,6 @@ class MultiStepCalcMetadata(TypedDict):
current_step: int
-class SlidingPuzzleMetadata(TypedDict):
- game_state: Dict[str, Any] # Stores the dict returned by SlidingPuzzleGame methods
- num_moves: int
- max_moves: int
-
-
class _MultiStepCalculatorLogic:
def __init__(self):
pass
@@ -183,113 +176,6 @@ def process_turn(
)
-class _SlidingPuzzleLogic:
- def __init__(self):
- pass # No initialization needed as game methods are static
-
- def _parse_action(self, text: str) -> Optional[str]:
- """Parses the action from '<action>...</action>'"""
- prefix = "<action>"
- suffix = "</action>"
- # Find the prefix, case-insensitive, and potentially after some thought process
- text_lower = text.lower()
- prefix_lower = prefix.lower()
- suffix_lower = suffix.lower()
-
- start_idx = text_lower.rfind(prefix_lower) # Find the last occurrence
-
- if start_idx != -1:
- # Find the end tag after the start tag
- end_idx = text_lower.find(suffix_lower, start_idx + len(prefix_lower))
- if end_idx != -1:
- # Extract content between tags
- action_content = text[start_idx + len(prefix) : end_idx].strip()
- return action_content
- return None
-
- def process_turn(
- self,
- message_log: LLMMessageLogType,
- metadata: SlidingPuzzleMetadata,
- ) -> Tuple[
- Dict[str, str],
- float,
- bool,
- Optional[List[str]],
- Optional[SlidingPuzzleMetadata],
- ]:
- """Processes a single turn for the sliding puzzle task."""
- game_state = metadata["game_state"]
- current_moves = metadata["num_moves"]
- max_moves = metadata["max_moves"]
-
- turn_reward = 0.0
- is_terminated = False
- next_stop_strings = ["</action>"]
- next_metadata = metadata.copy()
- next_observation_content = ""
-
- # Check if max moves reached
- if current_moves >= max_moves:
- is_terminated = True
- next_observation_content = (
- f"Maximum moves ({max_moves}) reached."
- )
- next_metadata = None
- return (
- {"role": "environment", "content": next_observation_content},
- 0.0,
- is_terminated,
- None,
- next_metadata,
- )
-
- # Get last assistant message and parse action
- last_assistant_msg_content = ""
- if message_log and message_log[-1]["role"] == "assistant":
- last_assistant_msg_content = message_log[-1]["content"].strip()
-
- parsed_action = self._parse_action(last_assistant_msg_content)
-
- if parsed_action is None:
- # Handle cases where parsing failed or it wasn't assistant's turn properly
- # is_terminated = True # Penalize for bad format
- rendered_board = SlidingPuzzleGame.render(game_state)
- next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: <action>your_action</action>"
- next_metadata = None
- elif parsed_action == "view":
- rendered_board = SlidingPuzzleGame.render(game_state)
- next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made."
- else:
- # Execute the game step
- step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step(
- parsed_action, game_state
- )
-
- turn_reward = reward
- is_terminated = game_over
- next_metadata["game_state"] = next_game_state
- next_metadata["num_moves"] = current_moves + 1
-
- # Combine rendered board and step response for the next observation
- rendered_board = SlidingPuzzleGame.render(next_game_state)
- # next_observation_content = f"\n{rendered_board}\n\n{step_response}"
- next_observation_content = f"\n{step_response}\n"
- # next_observation_content = f"\n{step_response}"
-
- if is_terminated:
- next_metadata = None # Clear metadata on termination
- # next_stop_strings remains None
-
- return (
- {"role": "environment", "content": next_observation_content + "\n"},
- turn_reward,
- is_terminated,
- next_stop_strings,
- next_metadata,
- )
-
-
@ray.remote
class MultiStepCalculatorEnv(EnvironmentInterface):
DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM
@@ -350,69 +236,3 @@ def global_post_process_and_metrics(
)
success_rate = final_rewards.mean().item() if len(final_rewards) > 0 else 0.0
return batch, {"success_rate": success_rate}
-
-
-@ray.remote
-class SlidingPuzzleEnv(EnvironmentInterface):
- DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM
- """Sliding Puzzle environment (Ray Actor)."""
-
- def __init__(self, cfg: Optional[Dict] = None):
- # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50}
- self.game_config = cfg.get("game_config", {}) if cfg else {}
- self.logic = _SlidingPuzzleLogic()
-
- def step(
- self,
- message_log_batch: List[LLMMessageLogType],
- metadata_batch: List[SlidingPuzzleMetadata],
- ) -> EnvironmentReturn:
- """Processes a batch of sliding puzzle interactions."""
- # Since logic is synchronous, process sequentially (can parallelize if logic becomes heavy)
- results = [
- self.logic.process_turn(log, meta)
- for log, meta in zip(message_log_batch, metadata_batch)
- ]
-
- # Unpack results and format according to EnvironmentReturn NamedTuple
- observations = []
- rewards = []
- terminateds = []
- all_stop_strings = []
- all_next_metadata = []
-
- for obs, rew, term, stops, meta in results:
- observations.append(obs)
- rewards.append(rew)
- terminateds.append(term)
- all_stop_strings.append(stops)
- all_next_metadata.append(meta)
-
- rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
- terminated_tensor = torch.tensor(terminateds, dtype=torch.bool)
-
- return EnvironmentReturn(
- observations=observations,
- metadata=all_next_metadata,
- next_stop_strings=all_stop_strings,
- rewards=rewards_tensor,
- terminateds=terminated_tensor,
- )
-
- def shutdown(self):
- pass
-
- def global_post_process_and_metrics(
- self, batch: BatchedDataDict
- ) -> Tuple[BatchedDataDict, dict]:
- # Calculate success rate based on final reward == 1.0
- final_rewards = batch.get(
- "total_reward", torch.tensor([0.0] * len(batch["idx"]))
- )
- success_rate = (
- (final_rewards == 1.0).float().mean().item()
- if len(final_rewards) > 0
- else 0.0
- )
- # Could also calculate average number of moves for successful episodes, etc.
- return batch, {"sliding_puzzle_success_rate": success_rate}