diff --git a/.gitignore b/.gitignore index 19f2edc5b8..478990ddc8 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,8 @@ dist/ *.vscode/ # Test -.coverage +coverage.json +.coverage* test_assets/ # Cache diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4cf474df01..747e98abdb 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -2,6 +2,7 @@ grpo: num_prompts_per_step: 32 num_generations_per_prompt: 16 + max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) max_num_steps: 1000000 normalize_rewards: true use_leave_one_out_baseline: true diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 1914b27e98..155b49a8df 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -24,7 +24,10 @@ from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.algorithms.utils import calculate_baseline_and_std_per_prompt -from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.data.interfaces import ( DatumSpec, @@ -59,6 +62,7 @@ from nemo_reinforcer.utils.logger import Logger, LoggerConfig from nemo_reinforcer.utils.timer import Timer from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig +from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout # =============================================================================== @@ -73,6 +77,7 @@ class GRPOConfig(TypedDict): normalize_rewards: bool use_leave_one_out_baseline: bool val_period: int + val_batch_size: int val_at_start: bool checkpoint_dir: str @@ -94,7 +99,7 @@ def _default_grpo_save_state() -> GRPOSaveState: class 
MasterConfig(TypedDict): policy: PolicyConfig loss_fn: ClippedPGLossConfig - math_env: MathEnvConfig + env_configs: Dict[str, Any] data: DataConfig grpo: GRPOConfig logger: LoggerConfig @@ -283,120 +288,6 @@ def refit_policy_generation( policy.offload_after_refit() -def generate_responses( - policy_generation: GenerationInterface, - generation_input_data: BatchedDataDict[GenerationDatumSpec], - batch: BatchedDataDict[DatumSpec], - tokenizer, - input_lengths: torch.Tensor, - include_logprobs: bool = True, -) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]: - """Generate responses from policy.""" - # Generate responses - generation_outputs = policy_generation.generate(generation_input_data) - - # Extract generated tokens - generated_ids = [] - unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"] - for output_ids, input_length, total_length in zip( - generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths - ): - generated_ids.append(output_ids[input_length:total_length]) - - generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - - # Append to message log - for i, (text, input_length, total_length) in enumerate( - zip(generated_texts, input_lengths, unpadded_sequence_lengths) - ): - message = { - "role": "assistant", - "content": text, - "token_ids": generation_outputs["output_ids"][i, input_length:total_length], - } - - if include_logprobs and "logprobs" in generation_outputs: - message["generation_logprobs"] = generation_outputs["logprobs"][ - i, input_length:total_length - ] - - batch["message_log"][i].append(message) - - metrics = { - "mean_generation_length": ( - torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths) - ).item() - / len(unpadded_sequence_lengths), - "max_seqlen": torch.max(unpadded_sequence_lengths).item(), - } - - return batch, generated_ids, metrics - - -def calculate_rewards( - batch: BatchedDataDict[DatumSpec], - task_to_env: Dict[str, 
EnvironmentInterface], -) -> Tuple[torch.Tensor, List[LLMMessageLogType]]: - """Calculate rewards for generated responses. - - Args: - batch: Batch containing message_log (LLMMessageLogType) with generated responses - task_to_env: Dictionary mapping task names to their corresponding environments - - Returns: - rewards: Tensor of rewards - to_env: Simplified message logs sent to environment (LLMMessageLogType format) - """ - # Extract message logs for environment - to_env = [ - get_keys_from_message_log(batch["message_log"][i], ["role", "content"]) - for i in range(len(batch["message_log"])) - ] - task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))] - - # Group messages by task type - task_groups = {} - for i, task_name in enumerate(task_names): - if task_name not in task_groups: - task_groups[task_name] = [] - task_groups[task_name].append((i, to_env[i])) - - # Calculate rewards for each task group concurrently - futures = [] - future_to_indices = {} # Map future to its corresponding indices - for task_name, group in task_groups.items(): - if task_name not in task_to_env: - raise ValueError(f"No environment found for task type: {task_name}") - - # Extract indices and messages for this group - indices = [idx for idx, _ in group] - messages = [msg for _, msg in group] - - # Get corresponding environment info - env_info = [batch["extra_env_info"][i] for i in indices] - - # Submit task to environment and store future - future = task_to_env[task_name].step.remote(messages, env_info) - futures.append(future) - future_to_indices[future] = indices - - results = ray.get(futures) - all_rewards = [] - for future, result in zip(futures, results): - indices = future_to_indices[future] - _, _, task_rewards, _ = result - - # Store results with their original indices - for idx, reward in zip(indices, task_rewards): - all_rewards.append((idx, reward)) - - # Sort results by original index to maintain order - all_rewards.sort(key=lambda x: x[0]) - rewards = 
torch.tensor([reward for _, reward in all_rewards]) - - return rewards, to_env - - # =============================================================================== # Training & Validation # =============================================================================== @@ -463,7 +354,7 @@ def grpo_train( print("▶ Preparing batch...") with timer.time("data_processing"): # Repeat batch items - repeated_batch = batch.repeat_interleave( + repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave( master_config["grpo"]["num_generations_per_prompt"] ) # Convert LLMMessageLogType to FlatMessagesType for generation @@ -472,36 +363,33 @@ def grpo_train( pad_value_dict={"token_ids": tokenizer.pad_token_id}, ) input_ids = batched_flat["token_ids"] - # Create generation-specific input structure - generation_input_data = BatchedDataDict[GenerationDatumSpec]( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - } - ) # Generate responses - this updates the LLMMessageLogType in repeated_batch - print(f"▶ Generating responses for batch of size {len(input_ids)}...") + print(f"▶ Generating responses for batch of size {repeated_batch.size}...") with timer.time("prepare_for_generation"): if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation(policy, policy_generation) POLICY_GENERATION_STALE = False else: policy_generation.prepare_for_generation() + with timer.time("generation"): - repeated_batch, _, gen_metrics = generate_responses( - policy_generation, - generation_input_data, - repeated_batch, - tokenizer, - input_lengths, + repeated_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=policy_generation, + input_batch=repeated_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=master_config["grpo"]["max_rollout_turns"], + greedy=False, ) policy_generation.finish_generation() - # Calculate rewards & advantages based on the updated 
LLMMessageLogType - print("▶ Calculating rewards...") + # Calculate rewards & advantages + print("▶ Processing rewards...") with timer.time("reward_calculation"): - rewards, _ = calculate_rewards(repeated_batch, task_to_env) + # Extract rewards from final_batch + rewards = repeated_batch["total_reward"] print("▶ Computing advantages...") baseline, std = calculate_baseline_and_std_per_prompt( @@ -665,14 +553,14 @@ def grpo_train( metrics[k] = np.sum(v).item() else: metrics[k] = np.mean(v).item() - metrics.update(gen_metrics) + metrics.update(rollout_metrics) timing_metrics = timer.get_timing_metrics(reduction_op="sum") print(f" • Loss: {metrics['loss']:.4f}") print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print( - f" • Mean Generation Length: {gen_metrics['mean_generation_length']:.4f}" + f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}" ) print("\n⏱️ Timing:") @@ -726,39 +614,25 @@ def validate( if batch_idx >= max_batches: break - # Convert LLMMessageLogType to FlatMessagesType for generation - batched_flat, input_lengths = batched_message_log_to_flat_message( - val_batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - ) - # Extract input IDs - input_ids = batched_flat["token_ids"] - # Create generation-specific input structure - generation_input_data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - } - ) - # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) - val_batch, generated_ids, gen_metrics = generate_responses( + val_batch, gen_metrics = run_multi_turn_rollout( policy_generation, - generation_input_data, val_batch, tokenizer, - input_lengths, - include_logprobs=False, + val_task_to_env, + max_seq_len=master_config["policy"]["max_total_sequence_length"], + max_rollout_turns=master_config["grpo"]["max_rollout_turns"], + greedy=False, ) - - # Calculate rewards based on the updated LLMMessageLogType - with 
timer.time("reward_calculation"): - rewards, to_env = calculate_rewards(val_batch, val_task_to_env) + rewards = val_batch["total_reward"] total_rewards.extend(rewards.tolist()) - total_lengths.extend([len(ids) for ids in generated_ids]) + total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) # Collect message logs for later display + to_env = get_keys_from_message_log( + val_batch["message_log"], ["role", "content"] + ) all_message_logs.extend(to_env) # Calculate validation metrics diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 0d2e61a9ac..0286bcd8d0 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -98,6 +98,9 @@ def __call__( lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) mult_prob_error = masked_mean(torch.exp(lp_error), mask).item() + if mult_prob_error == 0.0: + # this sometimes gets 0 (everything masked/invalid). 
Doing this to avoid screwing up stats too much + mult_prob_error = 1.0 next_token_logits = next_token_logits.to(torch.float32) diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 8d8ca78371..f872b4e9f5 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -124,6 +124,9 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: idx = [datum_spec["idx"] for datum_spec in data_batch] batch_max_length = torch.ones_like(length) * length.max() + # Extract stop_strings if present + stop_strings = [datum.get("stop_strings", None) for datum in data_batch] + output = BatchedDataDict( message_log=message_log, length=length, @@ -132,6 +135,7 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: task_name=task_names, idx=idx, batch_max_length=batch_max_length, + stop_strings=stop_strings, ) return output diff --git a/nemo_reinforcer/data/interfaces.py b/nemo_reinforcer/data/interfaces.py index 6ae1152be7..33a44d555d 100644 --- a/nemo_reinforcer/data/interfaces.py +++ b/nemo_reinforcer/data/interfaces.py @@ -32,6 +32,7 @@ class DatumSpec(TypedDict): loss_multiplier: float # multiplier for the loss for this datum. 
0 to mask out (say the sample is invalid) idx: int task_name: Optional[str] = "default" + stop_strings: Optional[List[str]] = None # Optional stop strings for generation __extra__: Any # This allows additional fields of any type diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index 362183d978..db908f7bc9 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -289,8 +289,11 @@ def batched_message_log_to_flat_message( # Create input_lengths tensor input_lengths = [] for seq in sequenced_lists: - seq_len = next( - (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0 + # Find the maximum length among all tensors in the dictionary, default to 0 if none exist + # Use maximum here since there may be keys that aren't populated for all messages yet. + # For example, logprobs don't get populated for non-generated tokens until post-processing. + seq_len = max( + (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), default=0 ) input_lengths.append(seq_len) input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32) diff --git a/nemo_reinforcer/distributed/batched_data_dict.py b/nemo_reinforcer/distributed/batched_data_dict.py index a1711ae8c2..8e00e3b13b 100644 --- a/nemo_reinforcer/distributed/batched_data_dict.py +++ b/nemo_reinforcer/distributed/batched_data_dict.py @@ -13,7 +13,7 @@ # limitations under the License. 
from copy import deepcopy from collections import UserDict -from typing import List, Dict, Optional, Iterator, TypeVar, Any, Generic +from typing import List, Dict, Optional, Iterator, TypeVar, Any, Generic, Union from typing_extensions import Self import torch @@ -137,7 +137,10 @@ def chunk(self, rank: int, chunks: int) -> "SlicedDataDict": return chunked_batch def shard_by_batch_size( - self, shards: int, batch_size: Optional[int] = None + self, + shards: int, + batch_size: Optional[int] = None, + allow_uneven_shards: bool = False, ) -> List["SlicedDataDict"]: """Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. @@ -150,10 +153,47 @@ def shard_by_batch_size( Args: shards (int): The number of shards to divide each batch_size chunk into. batch_size (int): The size of each initial chunk. + allow_uneven_shards (bool): Whether to allow shards to be unevenly sized. + If True, the last shard may be smaller than the others. Returns: List[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. + + Examples: + ```{doctest} + >>> from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + >>> # Create a batch of two message logs with different lengths + >>> batch = BatchedDataDict({ + ... 'problem_id': [0, 0, 1, 1, 2, 2, 3, 3], + ... 'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8] + ... }) + >>> shards = batch.shard_by_batch_size(shards=2) + >>> shards + [{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}] + >>> # Now say that I'm training with a GBS of 4 and I want to take gradients steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO) + >>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch. 
+ >>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first. + >>> shards = batch.shard_by_batch_size(shards=2, batch_size=4) + >>> shards + [{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}] + >>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first. + >>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary. + >>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes. + >>> batch = BatchedDataDict({ + ... 'problem_id': [0, 1, 2, 3, 4], + ... 'arbitrary_data': [10, 11, 12, 13, 14] + ... }) + >>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True) + >>> shards + [{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}] + >>> # This is incompatible with the batch_size argument + ``` """ + if allow_uneven_shards: + assert batch_size is None, ( + "batch_size must be None if allow_uneven_shards is True" + ) + # Get the total batch size batch_sizes = set() for val in self.data.values(): @@ -173,13 +213,18 @@ def shard_by_batch_size( assert total_batch_size % batch_size == 0, ( f"Total batch size ({total_batch_size}) is not a multiple of batch_size ({batch_size})" ) - assert batch_size % shards == 0, ( - f"Batch size ({batch_size}) is not a multiple of shards ({shards})" - ) + if not allow_uneven_shards: + assert batch_size % shards == 0, ( + f"Batch size ({batch_size}) is not a multiple of shards ({shards})" + ) num_chunks = total_batch_size // batch_size - shard_size = batch_size // shards - # Create one BatchedDataDict per shard position + # Calculate shard size, rounding up if not evenly divisible + shard_size = ( + (batch_size + shards - 1) // 
def to(self, device: torch.device) -> "Self":
    """Move every tensor value of the batch to ``device`` and return ``self``.

    The move happens in place on ``self.data``; values that are not tensors
    (for example plain Python lists) are left untouched.

    Args:
        device: Target device handed to ``Tensor.to``.

    Returns:
        The same object, to allow call chaining.
    """
    for key, value in self.data.items():
        if torch.is_tensor(value):
            # Re-assigning an existing key does not change the dict size, so
            # mutating while iterating over ``items()`` is safe here.
            self.data[key] = value.to(device)
    return self


def select_indices(
    self, indices: Union[List[int], torch.Tensor]
) -> "BatchedDataDict":
    """Select specific rows from the batch based on indices.

    Tensor values are indexed with ``indices`` directly; list values are
    rebuilt element by element in the order given by ``indices``.

    Args:
        indices: A list or tensor of integer indices to select.

    Returns:
        BatchedDataDict: A new BatchedDataDict containing only the selected rows.

    Raises:
        TypeError: If a value in the batch is neither a tensor nor a list.
    """
    # Convert tensor indices to plain ints once. Iterating a tensor yields 0-d
    # tensors, which would otherwise be converted one by one for every list
    # column in the batch.
    index_list = indices.tolist() if torch.is_tensor(indices) else list(indices)

    selected_batch = BatchedDataDict()
    for key, value in self.data.items():
        if torch.is_tensor(value):
            selected_batch[key] = value[indices]
        elif isinstance(value, list):
            selected_batch[key] = [value[i] for i in index_list]
        else:
            # Only tensor and list columns have a well-defined row selection.
            raise TypeError(
                f"Unsupported type {type(value)} for index selection in BatchedDataDict"
            )
    return selected_batch
+ """ + + observations: List[Dict[str, str]] + metadata: List[Optional[dict]] + next_stop_strings: List[Optional[List[str]]] + rewards: Tensor + terminateds: Tensor class EnvironmentInterface(abc.ABC): @abc.abstractmethod def step( self, - message_log_batch: List[List[Dict[str, str]]], - metadata: List[Dict], + message_log_batch: List[LLMMessageLogType], + metadata: List[Optional[dict]], *args, **kwargs, ) -> EnvironmentReturn: """Runs a step in the environment. Allows for asynchrony with remote servers, but it's not required (this function is a ray remote). message_log_batch: batch of OpenAI-API-like message logs that represent interactions with the LLM. + Each element is a List[Dict[str, Union[str, torch.Tensor]]]. For example, if this were a Math Environment, then the message log would be [ @@ -48,13 +72,10 @@ def step( {"role": "assistant", "content": "model response"}, ] metadata: batch of whatever the environment needs to keep track of. I.e. - math solutions, code unit tests, or agent states. + math solutions, code unit tests, or agent states. Can be None if episode terminated. Returns: - - List[Dict[str, str]]: An observation/response batch in an OpenAI-API-like message format that is the result of the step. - - List[Dict]: An updated batch of metadata. - - Tensor: A tensor of rewards. - - Tensor: A tensor of done flags. + - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags. """ @abc.abstractmethod diff --git a/nemo_reinforcer/environments/math_environment.py b/nemo_reinforcer/environments/math_environment.py index 65a5cc0e27..cc61fdcb2c 100644 --- a/nemo_reinforcer/environments/math_environment.py +++ b/nemo_reinforcer/environments/math_environment.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from itertools import tee -from typing import Dict, List, Tuple, TypedDict +from typing import Dict, List, Tuple, TypedDict, Optional import ray import torch from math_verify import parse, verify from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict -from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) from nemo_reinforcer.environments.metrics import ( calculate_pass_rate_per_prompt, ) @@ -29,6 +32,7 @@ class MathEnvConfig(TypedDict): num_workers: int + stop_strings: Optional[List[str]] = None # Default stop strings for this env @ray.remote @@ -66,7 +70,8 @@ class MathEnvironmentMetadata(TypedDict): class MathEnvironment(EnvironmentInterface): DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM - def __init__(self, cfg: Dict): + def __init__(self, cfg: MathEnvConfig): + self.cfg = cfg self.num_workers = cfg["num_workers"] self.workers = [ HFVerifyWorker.options( @@ -84,7 +89,7 @@ def step( self, message_log_batch: List[List[Dict[str, str]]], metadata: List[MathEnvironmentMetadata], - ): + ) -> EnvironmentReturn: """Runs a step in the math environment. 
Args: @@ -95,6 +100,7 @@ def step( EnvironmentReturn: A tuple containing: - List[Dict[str, str]]: Observations/responses batch - List[Dict]: Updated metadata + - List[str]: Next stop strings for the next turn - Tensor: Rewards tensor - Tensor: Done flags tensor """ @@ -129,7 +135,12 @@ def step( # flatten the results results = [item for sublist in results for item in sublist] observations = [ - {"role": "user", "content": "correct" if result else "incorrect"} + { + "role": "environment", + "content": "Environment: correct" + if result + else "Environment: incorrect", + } for result in results ] @@ -137,7 +148,15 @@ def step( rewards = torch.tensor(results).cpu() done = torch.ones_like(rewards).cpu() - return observations, metadata, rewards, done + next_stop_strings = [None] * len(message_log_batch) + + return EnvironmentReturn( + observations=observations, + metadata=metadata, + next_stop_strings=next_stop_strings, + rewards=rewards, + terminateds=done, + ) def global_post_process_and_metrics( self, batch: BatchedDataDict diff --git a/nemo_reinforcer/experience/rollouts.py b/nemo_reinforcer/experience/rollouts.py new file mode 100644 index 0000000000..f41661efaf --- /dev/null +++ b/nemo_reinforcer/experience/rollouts.py @@ -0,0 +1,379 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generate rollouts for arbitrary environments +# Supports multi-turn rollouts and many simultaneous environments (E.g. 
def generate_responses(
    policy_generation: "GenerationInterface",
    generation_input_data: "BatchedDataDict[GenerationDatumSpec]",
    batch: "BatchedDataDict[DatumSpec]",
    tokenizer: "AutoTokenizer",
    input_lengths: torch.Tensor,
    include_logprobs: bool = True,
    greedy: bool = False,
) -> "Tuple[BatchedDataDict[DatumSpec], List[torch.Tensor], dict]":
    """Generate one assistant message per sample and append it to the message logs.

    Args:
        policy_generation: Object whose ``generate(data, greedy=...)`` returns a
            mapping with ``output_ids`` and ``unpadded_sequence_lengths`` (and
            optionally ``logprobs``).
        generation_input_data: Input handed to ``generate``. Its ``stop_strings``
            entry is filled in here when it is missing or ``None``.
        batch: Batch whose ``message_log`` entries are extended in place.
        tokenizer: Used to decode the generated token ids to text.
        input_lengths: Prompt length of every sample; generated tokens start at
            this offset in ``output_ids``.
        include_logprobs: Attach ``generation_logprobs`` to each new message when
            the generation output provides ``logprobs``.
        greedy: Forwarded to ``policy_generation.generate``.

    Returns:
        A tuple of the (mutated) batch, the generated token ids per sample, and
        a metrics dict with ``mean_generation_length`` and ``max_seqlen``.
    """
    # Stop strings already present on the generation input (e.g. per-turn stop
    # strings chosen by the caller) take precedence; only fall back to the
    # batch-level ones, or to "no stop strings", when none were supplied.
    if generation_input_data.get("stop_strings") is None:
        if "stop_strings" in batch:
            generation_input_data["stop_strings"] = batch["stop_strings"]
        else:
            # Ensure the key exists even if it's None, matching GenerationDatumSpec
            generation_input_data["stop_strings"] = [None] * len(input_lengths)

    generation_outputs = policy_generation.generate(
        generation_input_data, greedy=greedy
    )
    output_ids = generation_outputs["output_ids"]
    unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]

    # Keep only the newly generated span of every row (prompt and padding removed).
    generated_ids = [
        row[input_length:total_length]
        for row, input_length, total_length in zip(
            output_ids, input_lengths, unpadded_sequence_lengths
        )
    ]
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    logprobs = None
    if include_logprobs and "logprobs" in generation_outputs:
        logprobs = generation_outputs["logprobs"]

    # Append the assistant turn to each sample's message log.
    for i, (text, token_ids, input_length, total_length) in enumerate(
        zip(generated_texts, generated_ids, input_lengths, unpadded_sequence_lengths)
    ):
        message = {
            "role": "assistant",
            "content": text,
            "token_ids": token_ids,
        }
        if logprobs is not None:
            message["generation_logprobs"] = logprobs[i, input_length:total_length]
        batch["message_log"][i].append(message)

    num_samples = len(unpadded_sequence_lengths)
    if num_samples == 0:
        # Nothing was generated: avoid dividing by zero / reducing an empty tensor.
        metrics = {"mean_generation_length": 0.0, "max_seqlen": 0}
    else:
        generated_token_count = (
            torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths)
        ).item()
        metrics = {
            "mean_generation_length": generated_token_count / num_samples,
            "max_seqlen": torch.max(unpadded_sequence_lengths).item(),
        }

    return batch, generated_ids, metrics


def calculate_rewards(
    batch: "BatchedDataDict[DatumSpec]",
    task_to_env: "Dict[str, EnvironmentInterface]",
) -> "EnvironmentReturn":
    """Calculate rewards for generated responses and get environment feedback.

    Samples are grouped by ``task_name``, each group is sent to its environment's
    remote ``step``, and the per-group results are scattered back so that every
    returned field is in the original batch order.

    Args:
        batch: Batch containing message_log (LLMMessageLogType) with generated
            responses, plus ``task_name`` and ``extra_env_info`` per sample.
        task_to_env: Dictionary mapping task names to their corresponding environments.

    Returns:
        EnvironmentReturn namedtuple containing:
            - observations: List of observations from the environment for the next turn.
            - metadata: List of metadata returned by the environment.
            - next_stop_strings: List of stop strings for the next generation step.
            - rewards: Tensor of rewards for the last turn.
            - terminateds: Tensor of per-sample termination flags.

    Raises:
        ValueError: If a task name in the batch has no environment in ``task_to_env``.
    """
    # Environments only receive the role/content part of every message.
    to_env = [
        get_keys_from_message_log(message_log, ["role", "content"])
        for message_log in batch["message_log"]
    ]
    task_names = batch["task_name"]

    # Group sample indices by task type (insertion order is preserved).
    task_to_indices = {}
    for idx, task_name in enumerate(task_names):
        task_to_indices.setdefault(task_name, []).append(idx)

    # Validate before submitting anything so that no remote step is launched
    # for a batch that is going to be rejected.
    for task_name in task_to_indices:
        if task_name not in task_to_env:
            raise ValueError(f"No environment found for task type: {task_name}")

    # Submit one remote step per task group; all groups run concurrently.
    index_groups = list(task_to_indices.values())
    futures = [
        task_to_env[task_name].step.remote(
            [to_env[i] for i in indices],
            [batch["extra_env_info"][i] for i in indices],
        )
        for task_name, indices in task_to_indices.items()
    ]
    results = ray.get(futures)

    # Scatter the per-group results back to their original batch positions.
    num_samples = len(task_names)
    all_rewards = [None] * num_samples
    all_env_observations = [None] * num_samples
    all_terminateds = [None] * num_samples
    all_next_stop_strings = [None] * num_samples
    all_metadata = [None] * num_samples

    for indices, result in zip(index_groups, results):
        # Environment step returns: EnvironmentReturn
        env_observations, metadata, next_stop_strings, task_rewards, terminateds = (
            result
        )
        if next_stop_strings is None:
            next_stop_strings = [None] * len(task_rewards)

        for local_idx, global_idx in enumerate(indices):
            all_rewards[global_idx] = task_rewards[local_idx]
            all_env_observations[global_idx] = env_observations[local_idx]
            all_terminateds[global_idx] = terminateds[local_idx]
            all_next_stop_strings[global_idx] = next_stop_strings[local_idx]
            all_metadata[global_idx] = metadata[local_idx]

    return EnvironmentReturn(
        observations=all_env_observations,
        metadata=all_metadata,
        next_stop_strings=all_next_stop_strings,
        rewards=torch.tensor(all_rewards),
        terminateds=torch.tensor(all_terminateds),
    )
+ + Returns: + Tuple containing: + - BatchedDataDict with the full interaction history and accumulated rewards + - Dictionary of rollout metrics + """ + current_batch = input_batch.copy() # Work on a copy + batch_size = len(current_batch["message_log"]) + active_indices = torch.arange(batch_size) + total_rewards = torch.zeros(batch_size, dtype=torch.float32) + + # Initialize stop_strings from the initial batch if present + current_stop_strings = current_batch.get("stop_strings", [None] * batch_size) + + # Tracking metrics for each sample + sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_assistant_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_env_token_counts = torch.zeros(batch_size, dtype=torch.int32) + sample_terminated = torch.zeros(batch_size, dtype=torch.bool) + sample_truncated = torch.zeros(batch_size, dtype=torch.bool) + sample_max_turns_reached = torch.zeros(batch_size, dtype=torch.bool) + + # Tracking per-turn metrics + total_gen_tokens_per_turn = [] + active_samples_per_turn = [] + + for turn in range(max_rollout_turns): + if len(active_indices) == 0: + break + + active_samples_per_turn.append(len(active_indices)) + + # Convert LLMMessageLogType to FlatMessagesType for generation + active_batch = current_batch.select_indices(active_indices) + active_stop_strings = [current_stop_strings[i] for i in active_indices.tolist()] + + active_flat_messages: FlatMessagesType + active_flat_messages, active_input_lengths = ( + batched_message_log_to_flat_message( + active_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + ) + + # Extract input_ids and lengths from the flat messages + active_input_ids = active_flat_messages["token_ids"] + + generation_input_data = BatchedDataDict[GenerationDatumSpec]( + { + "input_ids": active_input_ids, + "input_lengths": active_input_lengths, + "stop_strings": active_stop_strings, + } + ) 
+ + # generate_responses updates active_batch["message_log"] in-place + active_batch, generated_ids, gen_metrics = generate_responses( + policy_generation, + generation_input_data, + active_batch, + tokenizer, + active_input_lengths, + greedy=greedy, + ) + + # Record token usage - assistant + for i, global_idx in enumerate(active_indices.tolist()): + sample_assistant_token_counts[global_idx] += len(generated_ids[i]) + sample_token_counts[global_idx] += len(generated_ids[i]) + + # Track total generated tokens this turn + total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids)) + + # Calculate rewards and get environment feedback + env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env) + + total_rewards[active_indices] += env_output.rewards + + # Update message log for ALL active samples with env observation + # This must happen BEFORE filtering based on done flags + truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool) + for i, global_idx in enumerate(active_indices.tolist()): + env_obs_content = env_output.observations[i]["content"] + # Tokenize the raw content from the environment + # TODO @sahilj: handle if we want these subsequent messages to have a chat template + tokenized_obs = tokenizer( + env_obs_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + + # check if new message overflows max_seq_len + if ( + len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i] + >= max_seq_len + ): + # truncate + tokenized_obs = tokenized_obs[: max_seq_len - active_input_lengths[i]] + truncation_mask[i] = True + # Record truncation + sample_truncated[active_indices[i]] = True + + tokenized_env_obs_message = { + "role": env_output.observations[i]["role"], + "content": env_obs_content, + "token_ids": tokenized_obs, + } + current_batch["message_log"][global_idx].append(tokenized_env_obs_message) + + # Record token usage - environment + sample_env_token_counts[global_idx] += 
len(tokenized_obs) + sample_token_counts[global_idx] += len(tokenized_obs) + + # Increment turn count + sample_turn_counts[global_idx] += 1 + + # Determine done samples and update active set + terminateds = env_output.terminateds.bool() + done = truncation_mask | terminateds + sample_terminated[active_indices] |= done + + # Update active indices for the next iteration + active_indices_local_next = torch.where(~done)[0] + active_indices = active_indices[active_indices_local_next] + continuing_indices_global = active_indices # Indices relative to original batch + # Get next stop strings and infos corresponding to the indices that are *continuing* + continuing_next_stops = [ + env_output.next_stop_strings[i] for i in active_indices_local_next.tolist() + ] + # Get metadata corresponding to continuing indices, using the correct field name + continuing_metadata = [ + env_output.metadata[i] for i in active_indices_local_next.tolist() + ] + + for i, global_idx in enumerate(continuing_indices_global.tolist()): + # Update stop strings for the next turn + current_stop_strings[global_idx] = continuing_next_stops[i] + # Update metadata (extra_env_info) using info from environment + if continuing_metadata[i] is not None: + current_batch["extra_env_info"][global_idx] = continuing_metadata[i] + + # Record samples that reached max turns + sample_max_turns_reached[active_indices] = True + + # Add total rewards to the final batch + current_batch["total_reward"] = total_rewards + + # Calculate aggregate metrics + rollout_metrics = { + # Overall metrics + "total_turns": int(sample_turn_counts.sum().item()), + "avg_turns_per_sample": float(sample_turn_counts.float().mean().item()), + "max_turns_per_sample": int(sample_turn_counts.max().item()), + "natural_termination_rate": float(sample_terminated.float().mean().item()), + "truncation_rate": float(sample_truncated.float().mean().item()), + "max_turns_reached_rate": float(sample_max_turns_reached.float().mean().item()), + # Token usage 
metrics + "mean_gen_tokens_per_sample": float(sample_token_counts.float().mean().item()), + "mean_env_tokens_per_sample": float( + sample_env_token_counts.float().mean().item() + ), + } + return current_batch, rollout_metrics diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py index 468714899f..48ed8554d8 100644 --- a/nemo_reinforcer/models/generation/interfaces.py +++ b/nemo_reinforcer/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, TypedDict, Union, Tuple, List +from typing import Any, TypedDict, Union, Tuple, List, Optional import torch from transformers import AutoTokenizer @@ -139,6 +139,7 @@ class GenerationDatumSpec(TypedDict): - input_ids: Tensor of token IDs representing the input sequences (right padded) - input_lengths: Tensor containing the actual length of each sequence (without padding) + - stop_strings: Optional list of strings to stop generation (per sample) - __extra__: Additional model-specific data fields Example of a batch with 4 entries with different sequence lengths: @@ -163,6 +164,7 @@ class GenerationDatumSpec(TypedDict): input_ids: torch.Tensor input_lengths: torch.Tensor + stop_strings: Optional[List[str]] = None __extra__: Any diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index ada0bf2623..03851f2de7 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -224,24 +224,37 @@ def generate( - generation_lengths: Lengths of each response - unpadded_sequence_lengths: Lengths of each input + generated sequence """ - # Verify input is right padded - assert isinstance(data, BatchedDataDict), ( - f"data must be a BatchedDataDict, got type: {type(data)}" - ) - assert "input_ids" in data and "input_lengths" in data, ( - 
f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" - ) - is_right_padded, error_msg = verify_right_padding( - data, pad_value=self.cfg["pad_token_id"] - ) - if not is_right_padded: - warnings.warn( - f"Input to vLLM worker is not properly right-padded: {error_msg}" + # Handle empty input case + if len(data["input_ids"]) == 0: + # Return empty BatchedDataDict with all required fields + return BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": torch.zeros((0, 0), dtype=torch.long), + "logprobs": torch.zeros((0, 0), dtype=torch.float), + "generation_lengths": torch.zeros(0, dtype=torch.long), + "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long), + } ) - # Convert inputs to vLLM format input_ids = data["input_ids"] input_lengths = data["input_lengths"] + # this function requires all generations have the same stop strings, so we collect all here + batch_stop_strings = data.get("stop_strings", []) + stop_strings = set() + for sample_stop_strings in batch_stop_strings: + if sample_stop_strings: + stop_strings.update(sample_stop_strings) + + # Add default stop strings from config + if self.cfg.get("stop_strings", None): + stop_strings.update(self.cfg["stop_strings"]) + + stop_strings = list(stop_strings) + + # verify inputs have correct padding + verify_right_padding(data, pad_value=self.cfg["pad_token_id"]) + + # Convert inputs to vLLM format batch_size = input_ids.shape[0] # Original input length with padding padded_input_length = input_ids.size(1) @@ -269,7 +282,7 @@ def generate( max_tokens=self.cfg["max_new_tokens"], logprobs=0, # Return logprobs for the generated tokens stop_token_ids=self.cfg["stop_token_ids"], - stop=self.cfg["stop_strings"], + stop=stop_strings, include_stop_str_in_output=True, # returning stop strings like hf ) @@ -359,6 +372,23 @@ def generate_text( BatchedDataDict containing: - texts: List of generated text responses """ + # Extract stop_strings if provided, else use default from 
config + batch_stop_strings = data.get( + "stop_strings", [self.cfg.get("stop_strings")] * len(data["prompts"]) + ) + + # This function requires all generations have the same stop strings, so we collect all here + stop_strings = set() + for sample_stop_strings in batch_stop_strings: + if sample_stop_strings: + stop_strings.update(sample_stop_strings) + + # Add default stop strings from config + if self.cfg.get("stop_strings", None): + stop_strings.update(self.cfg["stop_strings"]) + + stop_strings = list(stop_strings) if len(stop_strings) > 0 else None + # Read generation parameters from config top_k = self.cfg["top_k"] if self.cfg["top_k"] is not None else -1 sampling_params = self.SamplingParams( @@ -367,7 +397,7 @@ def generate_text( top_k=top_k if not greedy else 1, max_tokens=self.cfg["max_new_tokens"], stop_token_ids=self.cfg["stop_token_ids"], - stop=self.cfg["stop_strings"], + stop=stop_strings, include_stop_str_in_output=True, # returning stop strings like hf ) @@ -517,10 +547,8 @@ def generate( "input_ids and input_lengths are required in data for vLLM generation" ) - batch_size = data["input_ids"].shape[0] - # Shard the data across the tied worker groups - sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=batch_size) + sharded_data = data.shard_by_batch_size(self.dp_size, allow_uneven_shards=True) future_bundle = self.worker_group.run_all_workers_multiple_data( "generate", sharded_data, diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index 192d51ce88..3ff7faed70 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -517,6 +517,19 @@ def generate( # Set attention mask for the actual tokens (at the end for left padding) left_padded_attention_mask[i, seq_len - length :] = 1 + # this function requires all generations have the same stop strings, so we collect all here + batch_stop_strings = 
gen_batch.get("stop_strings", []) + stop_strings = set() + for sample_stop_strings in batch_stop_strings: + if sample_stop_strings: + stop_strings.update(sample_stop_strings) + + # Add default stop strings from config + if gen_cfg.get("stop_strings", None): + stop_strings.update(gen_cfg["stop_strings"]) + + stop_strings = list(stop_strings) if len(stop_strings) > 0 else None + if isinstance( self.model, torch.distributed.fsdp.FullyShardedDataParallel ): @@ -533,7 +546,7 @@ def generate( top_k=gen_cfg["top_k"], pad_token_id=gen_cfg["pad_token_id"], eos_token_id=gen_cfg["stop_token_ids"], - stop_strings=gen_cfg["stop_strings"], + stop_strings=stop_strings, tokenizer=self.tokenizer, # needs for stop_strings return_dict_in_generate=True, output_scores=True, diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index c6491e02d6..da1a21244f 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -16,10 +16,13 @@ import ray from typing import Dict, List, Tuple -from nemo_reinforcer.algorithms.grpo import calculate_rewards +from nemo_reinforcer.experience.rollouts import calculate_rewards from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType -from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) @ray.remote(num_cpus=0) @@ -30,9 +33,15 @@ def __init__(self, rewards: List[float]): def step( self, messages: List[LLMMessageLogType], env_info: List[dict] - ) -> Tuple[None, None, List[float], None]: + ) -> EnvironmentReturn: self._calls += 1 - return None, None, self.rewards, None + return ( + [{"role": "environment", "content": "observation"}] * len(messages), + [{}] * len(messages), + [[]] * len(messages), + self.rewards, + [True] * len(messages), + ) def get_calls(self): return self._calls @@ 
-117,11 +126,17 @@ def test_calculate_rewards_single_task(mock_env): batch = create_mock_batch(2, task_names, message_logs) # Calculate rewards - rewards, to_env = calculate_rewards(batch, task_to_env) + env_observations, metadata, next_stop_strings, rewards, terminateds = ( + calculate_rewards(batch, task_to_env) + ) # Verify results assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) - assert len(to_env) == 2 + assert len(env_observations) == 2 + assert len(terminateds) == 2 + assert len(next_stop_strings) == 2 + assert len(metadata) == 2 + assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) assert ( ray.get(mock_env.get_calls.remote()) == 1 ) # Should only call once for all samples of same task @@ -146,11 +161,17 @@ def test_calculate_rewards_multiple_tasks(mock_envs): batch = create_mock_batch(4, task_names, message_logs) # Calculate rewards - rewards, to_env = calculate_rewards(batch, mock_envs) + env_observations, metadata, next_stop_strings, rewards, terminateds = ( + calculate_rewards(batch, mock_envs) + ) # Verify results assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) - assert len(to_env) == 4 + assert len(env_observations) == 4 + assert len(terminateds) == 4 + assert len(next_stop_strings) == 4 + assert len(metadata) == 4 + assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) assert ( ray.get(mock_envs["math"].get_calls.remote()) == 1 ) # One call for all math samples @@ -167,11 +188,16 @@ def test_calculate_rewards_empty_batch(mock_env): batch = create_mock_batch(0, [], []) # Calculate rewards - rewards, to_env = calculate_rewards(batch, task_to_env) + env_observations, metadata, next_stop_strings, rewards, terminateds = ( + calculate_rewards(batch, task_to_env) + ) # Verify results assert len(rewards) == 0 - assert len(to_env) == 0 + assert len(env_observations) == 0 + assert len(terminateds) == 0 + assert len(next_stop_strings) == 0 + assert len(metadata) == 0 assert ( ray.get(mock_env.get_calls.remote()) == 
class GameInterface:
    """Contract that every text game used by the test environments implements.

    All hooks are static and operate on a plain dictionary game state; the
    base versions only raise ``NotImplementedError``.
    """

    @staticmethod
    def generate(config: Dict[str, Any]) -> Dict[str, Any]:
        """Build a fresh game state from ``config``.

        Args:
            config: Game configuration dictionary.

        Returns:
            Dictionary holding the complete game state.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Each game must implement generate()")

    @staticmethod
    def init(game_state: Dict[str, Any]) -> str:
        """Produce the welcome message and instructions for a game.

        Args:
            game_state: The game state dictionary.

        Returns:
            Welcome text shown to the player.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Each game must implement init()")

    @staticmethod
    def step(
        action: str, game_state: Dict[str, Any]
    ) -> Tuple[str, float, bool, Dict[str, Any]]:
        """Apply one player action to the game.

        Args:
            action: String representing the player's action.
            game_state: Current game state dictionary.

        Returns:
            ``(response message, reward, terminated flag, updated game state)``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Each game must implement step()")

    @staticmethod
    def render(game_state: Dict[str, Any]) -> str:
        """Turn the current game state into a printable string.

        Args:
            game_state: The game state dictionary.

        Returns:
            String representation of the game state.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Each game must implement render()")

import random
import copy
from typing import List, Tuple, Dict, Any, Optional
from .game_interface import GameInterface


class SlidingPuzzleGame(GameInterface):
    """Sliding-tile puzzle (e.g. the 15-puzzle) exposed through GameInterface."""

    @staticmethod
    def generate(config: Dict[str, Any]) -> Dict[str, Any]:
        """Create a shuffled puzzle.

        Reads ``size`` (default 4) and ``shuffle_moves`` (default 100) from
        ``config``. The board starts solved and is scrambled by random legal
        slides of the empty cell.
        """
        size = config.get("size", 4)  # 4x4 is the classic 15-puzzle
        shuffle_moves = config.get("shuffle_moves", 100)
        last = size - 1

        # Solved layout: 1..size*size-1 row by row, 0 (empty) in the last cell
        board = [[row * size + col + 1 for col in range(size)] for row in range(size)]
        board[last][last] = 0

        # Keep an independent copy of the solved layout
        solved = [list(line) for line in board]

        blank = (last, last)
        for _ in range(shuffle_moves):
            blank_r, blank_c = blank
            # In-bounds neighbours of the blank, in Right, Down, Left, Up order
            candidates = [
                (blank_r + dr, blank_c + dc)
                for dr, dc in ((0, 1), (1, 0), (0, -1), (-1, 0))
                if 0 <= blank_r + dr < size and 0 <= blank_c + dc < size
            ]
            if not candidates:
                continue
            pick_r, pick_c = random.choice(candidates)
            # Swap the blank with the chosen neighbour
            board[blank_r][blank_c], board[pick_r][pick_c] = (
                board[pick_r][pick_c],
                board[blank_r][blank_c],
            )
            blank = (pick_r, pick_c)

        return {
            "size": size,
            "grid": board,
            "solution": solved,
            "empty_pos": blank,
            "commands": {
                "slide r c": "Slide tile at row r, column c into the empty space",
                "up": "Slide tile below empty space up",
                "down": "Slide tile above empty space down",
                "left": "Slide tile to the right of empty space left",
                "right": "Slide tile to the left of empty space right",
            },
        }

    @staticmethod
    def init(game_state: Dict[str, Any]) -> str:
        """Return the welcome banner and instructions for this puzzle."""
        size = game_state["size"]
        welcome = (
            f"\n===== SLIDING PUZZLE =====\n"
            f"Arrange the {size}x{size} grid by sliding tiles into the empty space.\n"
            f"- The goal is to arrange numbers from 1 to {size * size - 1} in order\n"
            f"- Use 'slide r c' to slide a specific tile\n"
            f"- Or use 'up', 'down', 'left', 'right' to slide in that direction"
        )
        return welcome

    @staticmethod
    def step(
        action: str, game_state: Dict[str, Any]
    ) -> Tuple[str, float, bool, Dict[str, Any]]:
        """Apply ``action`` and return ``(response, reward, terminated, new_state)``.

        ``game_state`` is never modified; the returned state is a deep copy.
        Rejected or unknown actions get a reward of -0.05, a legal slide gets
        0, and the slide that solves the puzzle gets 1.0 and terminates.
        """
        size = game_state["size"]
        grid = game_state["grid"]
        empty_r, empty_c = game_state["empty_pos"]

        # Work on a copy so the caller's state stays untouched
        new_state = copy.deepcopy(game_state)

        def reject(message):
            # Small penalty; the game continues with an unchanged board
            return message, -0.05, False, new_state

        # Where the tile to move sits relative to the empty cell
        offsets = {"up": (1, 0), "down": (-1, 0), "left": (0, 1), "right": (0, -1)}
        dir_text = None

        if action.startswith("slide "):
            try:
                _, row_text, col_text = action.split()
                r, c = int(row_text) - 1, int(col_text) - 1
            except ValueError:
                return reject("Invalid input format. Use: slide row col")
            if r < 0 or r >= size or c < 0 or c >= size:
                return reject(
                    f"Invalid position. Row/column must be between 1 and {size}."
                )
            # Only a tile directly next to the empty cell may move
            if abs(r - empty_r) + abs(c - empty_c) != 1:
                return reject("Tile must be adjacent to the empty space.")
        elif action in offsets:
            dir_text = action
            dr, dc = offsets[action]
            r, c = empty_r + dr, empty_c + dc
            if r < 0 or r >= size or c < 0 or c >= size:
                return reject(f"Cannot slide {dir_text}.")
        else:
            return reject("Unknown command. Type 'help' to see available commands.")

        # Slide the tile into the empty cell
        new_state["grid"][empty_r][empty_c] = grid[r][c]
        new_state["grid"][r][c] = 0
        new_state["empty_pos"] = (r, c)

        if new_state["grid"] == new_state["solution"]:
            return "Congratulations! You've solved the puzzle!", 1.0, True, new_state

        if dir_text is None:
            response = f"Slid tile {grid[r][c]} into the empty space."
        else:
            response = f"Slid tile {grid[r][c]} {dir_text}."
        return response, 0, False, new_state

    @staticmethod
    def render(game_state: Dict[str, Any]) -> str:
        """Draw the board as text with row numbers, a border and column labels."""
        grid = game_state["grid"]
        size = game_state["size"]

        # Width needed for the largest tile number
        max_digits = len(str(size * size - 1))
        border = " " + "+" + "-" * (max_digits + 2) * size + "+"

        def cell(val):
            # The empty space is drawn as blanks of the same width as a tile
            if val == 0:
                return " " * (max_digits + 2)
            return f" {val:>{max_digits}} "

        lines = ["\n", border]
        for i, row in enumerate(grid):
            lines.append(f"{i + 1} |" + "".join(cell(val) for val in row) + "|")
        lines.append(border)
        lines.append(
            " " + "".join(f"{i + 1:^{max_digits + 2}}" for i in range(size))
        )

        return "\n".join(lines)


def is_solvable(grid: List[List[int]], size: int) -> bool:
    """Check if a sliding puzzle is solvable (inversion-parity test)."""
    # Row-major tile order without the empty cell
    tiles = [num for row in grid for num in row if num != 0]

    # Pairs that appear in the wrong relative order
    inversions = sum(
        1 for i, first in enumerate(tiles) for later in tiles[i + 1 :] if first > later
    )

    # Row of the empty cell counted from the bottom (1-based); 0 if absent
    empty_row = 0
    for from_bottom in range(1, size + 1):
        row = grid[size - from_bottom]
        for col in range(size):
            if row[col] == 0:
                empty_row = from_bottom
                break

    if size % 2 == 1:
        # Odd width: solvable exactly when the inversion count is even
        return inversions % 2 == 0
    # Even width: inversion parity and blank-row parity must differ
    return (inversions + empty_row) % 2 == 1


def play_sliding_puzzle(config=None):
    """Backward-compatible wrapper that starts a game via ``play_game``."""
    from play_game import play_game

    play_game(SlidingPuzzleGame, config)
shape (3,)" - assert all(done == 1.0), "All done flags should be 1.0" + assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert all(result.rewards == 1.0), "All rewards should be 1.0 for correct answers" + assert result.terminateds.shape == (3,), ( + "Terminated flags should be a tensor of shape (3,)" + ) + assert all(result.terminateds == 1.0), "All terminated flags should be 1.0" def test_math_env_step_mixed(math_env, mixed_test_data): """Test MathEnvironment step with a mix of correct and incorrect responses.""" - observations, updated_metadata, rewards, done = ray.get( + result = ray.get( math_env.step.remote( mixed_test_data["message_log_batch"], mixed_test_data["metadata"] ) ) # Check observations and rewards - assert len(observations) == 3, "Should return observations for all 3 messages" - assert observations[0]["content"] == "correct", "First response should be correct" - assert observations[1]["content"] == "incorrect", ( + assert len(result.observations) == 3, ( + "Should return observations for all 3 messages" + ) + assert result.observations[0]["content"] == "Environment: correct", ( + "First response should be correct" + ) + assert result.observations[1]["content"] == "Environment: incorrect", ( "Second response should be incorrect" ) - assert observations[2]["content"] == "correct", "Third response should be correct" + assert result.observations[2]["content"] == "Environment: correct", ( + "Third response should be correct" + ) - assert rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" - assert rewards[0] == 1.0, "First reward should be 1.0" - assert rewards[1] == 0.0, "Second reward should be 0.0" - assert rewards[2] == 1.0, "Third reward should be 1.0" + assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert result.rewards[0] == 1.0, "First reward should be 1.0" + assert result.rewards[1] == 0.0, "Second reward should be 0.0" + assert result.rewards[2] == 1.0, "Third 
reward should be 1.0" def test_math_env_step_empty(math_env): """Test MathEnvironment step with empty input.""" - observations, updated_metadata, rewards, done = ray.get( - math_env.step.remote([], []) - ) + result = ray.get(math_env.step.remote([], [])) # Check all outputs are empty - assert len(observations) == 0, "Should return empty observations list" - assert len(updated_metadata) == 0, "Should return empty metadata list" - assert rewards.shape == (0,), "Should return empty rewards tensor" - assert done.shape == (0,), "Should return empty done tensor" + assert len(result.observations) == 0, "Should return empty observations list" + assert len(result.metadata) == 0, "Should return empty metadata list" + assert result.rewards.shape == (0,), "Should return empty rewards tensor" + assert result.terminateds.shape == (0,), "Should return empty terminateds tensor" def test_math_env_step_multiple_assistant_messages( math_env, multiple_assistant_test_data ): """Test MathEnvironment step with multiple assistant messages in a conversation.""" - observations, updated_metadata, rewards, done = ray.get( + result = ray.get( math_env.step.remote( multiple_assistant_test_data["message_log_batch"], multiple_assistant_test_data["metadata"], @@ -184,11 +192,13 @@ def test_math_env_step_multiple_assistant_messages( ) # Check that only the last assistant message is used - assert len(observations) == 2, "Should return observations for both conversations" - assert all(obs["content"] == "correct" for obs in observations), ( - "All responses should be correct" + assert len(result.observations) == 2, ( + "Should return observations for both conversations" ) - assert all(rewards == 1.0), "All rewards should be 1.0" + assert all( + obs["content"] == "Environment: correct" for obs in result.observations + ), "All responses should be correct" + assert all(result.rewards == 1.0), "All rewards should be 1.0" @pytest.mark.parametrize("batch_size", [1, 2, 10, 25, 101]) @@ -202,16 +212,20 @@ def 
test_math_env_various_batches(math_env, batch_size): ] * batch_size metadata = [{"ground_truth": "3.33333333"}] * batch_size - observations, updated_metadata, rewards, done = ray.get( - math_env.step.remote(message_log_batch, metadata) - ) + result = ray.get(math_env.step.remote(message_log_batch, metadata)) # Check outputs - assert len(observations) == batch_size, ( + assert len(result.observations) == batch_size, ( f"Should return observations for all {batch_size} messages" ) - assert all(obs["content"] == "correct" for obs in observations), ( - "All responses should be correct" + assert all( + obs["content"] == "Environment: correct" for obs in result.observations + ), "All responses should be correct" + assert result.rewards.shape == (batch_size,), ( + "Rewards should be a tensor of shape (batch_size,)" + ) + assert all(result.rewards == 1.0), "All rewards should be 1.0" + assert result.terminateds.shape == (batch_size,), ( + "Terminated flags should be a tensor of shape (batch_size,)" ) - assert all(rewards == 1.0), "All rewards should be 1.0" - assert all(done == 1.0), "All done flags should be 1.0" + assert all(result.terminateds == 1.0), "All terminated flags should be 1.0" diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py new file mode 100644 index 0000000000..daeecb2bc6 --- /dev/null +++ b/tests/unit/experience/test_rollouts.py @@ -0,0 +1,626 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import ray +import torch +from copy import deepcopy +import gc + +from transformers import AutoTokenizer + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.models.generation.interfaces import configure_generation_config +from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout + +# Import the test environment definitions +from tests.unit.test_envs import ( + MultiStepCalculatorEnv, + _MultiStepCalculatorLogic, + MultiStepCalcMetadata, + SlidingPuzzleEnv, + SlidingPuzzleMetadata, +) + +# Import the game logic for generating initial state from its new location +from tests.unit.environments.sliding_puzzle_game import SlidingPuzzleGame + +from nemo_reinforcer.models.generation.vllm import VllmConfig, VllmGeneration + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="function") +def rollout_tokenizer(): + """Loads the tokenizer for the tests.""" + print(f"Loading tokenizer: {MODEL_NAME}") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + print( + f"Tokenizer loaded. 
Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})" + ) + return tokenizer + + +# Separate fixture for cluster setup and teardown +@pytest.fixture(scope="function") +def rollout_cluster(): + cluster_instance = None + cluster_name = f"test-rollout-cluster-{id(cluster_instance)}" # Unique name + print(f"\nCreating virtual cluster '{cluster_name}'...") + try: + # Use 1 GPU for simplicity + cluster_instance = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[1], + use_gpus=True, + num_gpus_per_node=1, + max_colocated_worker_groups=2, # Allow policy and env + ) + yield cluster_instance + finally: + print(f"\nCleaning up cluster '{cluster_name}'...") + if cluster_instance: + cluster_instance.shutdown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print(f"Cluster '{cluster_name}' cleanup finished.") + + +# Fixture for the multi-step calculator environment actor +@pytest.fixture(scope="function") +def multi_step_calculator_environment(rollout_cluster): + env_actor = None + print("Creating MultiStepCalculatorEnv actor...") + try: + env_actor = MultiStepCalculatorEnv.remote() + task_to_env = {"multi_step_calculator_game": env_actor} + yield task_to_env, env_actor + finally: + print("Cleaning up multi_step_calculator_environment...") + if env_actor: + ray.kill(env_actor) + print("multi_step_calculator_environment cleanup finished.") + + +# Fixture for the multi-step calculator initial batch data +@pytest.fixture(scope="function") +def initial_multi_step_calculator_batch(rollout_tokenizer): + print("Creating initial multi-step calculator test batch...") + batch_size = 1 # Simpler to debug with one sample + problem = "(5 + 3) * 2" # Example problem + expected_answer = 16.0 + max_steps = 5 # Allow a few steps + tool_instructions = ( + "You have a calculator tool. 
To use it, respond with:\n" + "'[operand1, operand2, operation_name]'\n" + "The valid 'operation_name' values are exactly: 'sum', 'diff', 'prod', 'div'.\n" + "Example: [5, 3, sum]\n" + "You will receive the result of your calculation as ...\n" + "Use this result to make the next calculation if needed.\n" + "IMPORTANT: Only perform one calculation step (one tool call) before waiting for a result and making a new tool call.\n" + "IMPORTANT: Do not perform any other calculations or operations aside from the tool call and result. Doing so will result in failure.\n" + "To give the final answer, just output the number. numbers inside of don't count, so output just the final number yourself outside of this.\n" + "Example full output: [2, 4, sum]\n6.0\n[6, 6, diff]\n0.0 0\n(note how you have to output the final 0 outside of the tags)" + "------\n" + f"Solve: {problem}" + ) + batch_message_logs = [] + batch_extra_env_info = [] + batch_loss_multipliers = [] + batch_indices = [] + batch_task_names = [] + + for i in range(batch_size): + # Apply chat template to the initial prompt + initial_prompt_content = rollout_tokenizer.apply_chat_template( + [{"role": "user", "content": tool_instructions}], + tokenize=False, + add_system_prompt=False, + add_generation_prompt=True, + add_special_tokens=False, + ) + tokenized_prompt = rollout_tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + metadata = MultiStepCalcMetadata( + problem=problem, + expected_final_answer=expected_answer, + max_steps=max_steps, + current_step=0, + ) + + batch_message_logs.append(message_log) + batch_extra_env_info.append(metadata) + batch_loss_multipliers.append(1.0) + batch_indices.append(i) + batch_task_names.append("multi_step_calculator_game") + + initial_batch_dict = { + "message_log": batch_message_logs, + "extra_env_info": 
batch_extra_env_info, + "loss_multiplier": batch_loss_multipliers, + "idx": batch_indices, + "task_name": batch_task_names, + "stop_strings": [[""]] * batch_size, + } + return BatchedDataDict(initial_batch_dict) + + +# Keep the base config separate +base_hf_test_config: PolicyConfig = { + "policy_type": "hf", + "model_name": MODEL_NAME, + "tokenizer_name": None, + "model_path": None, + "num_workers": 1, + "train_global_batch_size": 2, + "train_micro_batch_size": 1, + "logprob_batch_size": 2, + "generation_batch_size": 1, # Smaller for simpler testing + "learning_rate": 5e-6, + "precision": "float32", + "activation_checkpointing_enabled": False, + "fsdp_offload_enabled": False, + "generation": { + "backend": "hf", + "max_new_tokens": 50, # Increased for tool call format + "temperature": 0.01, + "top_p": 1.0, + "top_k": None, + "stop_token_ids": None, + "stop_strings": None, + }, + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5e-6, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-8, + }, + }, + "dtensor_cfg": {"enabled": False}, +} + +base_vllm_test_config: VllmConfig = { + "backend": "vllm", + "model_name": MODEL_NAME, + "tokenizer_name": None, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.6, + "max_new_tokens": 50, # Increased for tool call format + "temperature": 0.01, # Near-greedy + "top_p": 1.0, + "top_k": None, + "stop_token_ids": None, + "stop_strings": None, + "vllm_cfg": { + "tensor_parallel_size": 1, + "max_model_len": 2048, + "disable_log_stats": True, + "disable_log_requests": True, + "gpu_memory_utilization": 0.6, + }, +} + + +@pytest.fixture(scope="function") +def multi_step_setup_hf( + rollout_cluster, + rollout_tokenizer, + multi_step_calculator_environment, + initial_multi_step_calculator_batch, +): + """Sets up components for multi-step calculator tests using HfPolicy.""" + policy = None + task_to_env, _ = multi_step_calculator_environment + print("Creating HfPolicy for Multi-Step Calculator Test...") + 
try: + config = deepcopy(base_hf_test_config) + config["tokenizer_name"] = rollout_tokenizer.name_or_path + if "gpt2" in rollout_tokenizer.name_or_path.lower(): + config["model_name"] = "gpt2" + config["generation"] = configure_generation_config( + config["generation"], rollout_tokenizer + ) + config["generation"]["stop_strings"] = None + policy = HfPolicy( + cluster=rollout_cluster, + config=config, + tokenizer=rollout_tokenizer, + init_reference_model=False, + init_optimizer=False, + ) + yield ( + policy, + rollout_tokenizer, + task_to_env, + initial_multi_step_calculator_batch, + rollout_cluster, + ) + finally: + print("Cleaning up HfPolicy (Multi-Step Calc Test)...") + if policy: + policy.shutdown() + print("HfPolicy cleanup finished (Multi-Step Calc Test).") + + +@pytest.fixture(scope="function") +def multi_step_setup_vllm( + rollout_cluster, + rollout_tokenizer, + multi_step_calculator_environment, + initial_multi_step_calculator_batch, +): + """Sets up components for multi-step calculator tests using VllmGeneration.""" + vllm_generation = None + task_to_env, _ = multi_step_calculator_environment + is_eval = True + print("Creating VllmGeneration for Multi-Step Calculator Test...") + try: + vllm_config = deepcopy(base_vllm_test_config) + vllm_config["tokenizer_name"] = rollout_tokenizer.name_or_path + if "gpt2" in rollout_tokenizer.name_or_path.lower(): + vllm_config["model_name"] = "gpt2" + vllm_config = configure_generation_config( + vllm_config, rollout_tokenizer, is_eval=is_eval + ) + vllm_generation = VllmGeneration(rollout_cluster, vllm_config) + vllm_generation.finish_generation() + yield ( + vllm_generation, + rollout_tokenizer, + task_to_env, + initial_multi_step_calculator_batch, + rollout_cluster, + ) + finally: + print("Cleaning up VllmGeneration (Multi-Step Calc Test)...") + if vllm_generation: + vllm_generation.shutdown() + # Force garbage collection to help release resources + import gc + + gc.collect() + torch.cuda.empty_cache() + 
print("VllmGeneration cleanup finished (Multi-Step Calc Test).") + + +def test_run_multi_step_calculator_hf(multi_step_setup_hf): + """Tests multi-step calculator rollout with HfPolicy.""" + policy, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( + multi_step_setup_hf + ) + max_rollout_turns = ( + initial_batch["extra_env_info"][0]["max_steps"] + 1 + ) # Allow max steps + final answer + max_seq_len = 1024 # Increased for potentially longer interaction + + print("\nRunning multi-step calculator rollout (HF)...") + policy.prepare_for_generation() + final_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=policy, + input_batch=initial_batch, + tokenizer=rollout_tokenizer, + task_to_env=task_to_env, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + ) + policy.finish_generation() + print("Multi-step calculator rollout complete (HF).") + + # --- Assertions --- + assert isinstance(final_batch, BatchedDataDict) + assert "message_log" in final_batch + assert "total_reward" in final_batch + assert len(final_batch["message_log"]) == len(initial_batch["message_log"]) + + sample_log = final_batch["message_log"][0] + expected_final_answer = initial_batch["extra_env_info"][0]["expected_final_answer"] + print("\nSample Interaction Log (Multi-Step Calculator - HF):") + tool_call_count = 0 + final_answer_msg = None + for i, msg in enumerate(sample_log): + print(f" {i}: Role={msg['role']}, Content='{msg['content']}'") + if msg["role"] == "assistant": + if msg["content"].strip().endswith(""): + tool_call_count += 1 + else: + final_answer_msg = msg["content"].strip() + + assert tool_call_count >= 1, "Expected at least one tool call" + assert final_answer_msg is not None, ( + "Expected a final answer message from assistant" + ) + + # Check final answer correctness (allowing for different final answer formats) + final_answer_logic = _MultiStepCalculatorLogic() + extracted_final_answer = 
final_answer_logic._is_final_answer(final_answer_msg) + assert extracted_final_answer is not None, ( + f"Could not parse final answer from: {final_answer_msg}" + ) + assert abs(extracted_final_answer - expected_final_answer) < 1e-6, ( + f"Final answer incorrect. Expected {expected_final_answer}, Got {extracted_final_answer}" + ) + + # Check total reward (should be 1.0 if correct) + assert torch.all(final_batch["total_reward"] == 1.0), ( + f"Expected total reward 1.0, got {final_batch['total_reward']}" + ) + + print("\nMulti-Step Calculator HF Test assertions passed.") + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 1, + reason="VLLM test requires at least 1 GPU", +) +def test_run_multi_step_calculator_vllm(multi_step_setup_vllm): + """Tests multi-step calculator rollout with VllmGeneration.""" + vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( + multi_step_setup_vllm + ) + max_rollout_turns = initial_batch["extra_env_info"][0]["max_steps"] + 1 + max_seq_len = 1024 + + print("\nRunning multi-step calculator rollout (VLLM)...") + vllm_generation.prepare_for_generation() + final_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=initial_batch, + tokenizer=rollout_tokenizer, + task_to_env=task_to_env, + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + ) + vllm_generation.finish_generation() + print("Multi-step calculator rollout complete (VLLM).") + + # --- Assertions --- + assert isinstance(final_batch, BatchedDataDict) + assert "message_log" in final_batch + assert "total_reward" in final_batch + assert len(final_batch["message_log"]) == len(initial_batch["message_log"]) + + sample_log = final_batch["message_log"][0] + expected_final_answer = initial_batch["extra_env_info"][0]["expected_final_answer"] + print("\nSample Interaction Log (Multi-Step Calculator - VLLM):") + tool_call_count = 0 + final_answer_msg = None + for i, msg 
in enumerate(sample_log): + print(f" {i}: Role={msg['role']}, Content='{msg['content']}'") + if msg["role"] == "assistant": + if msg["content"].strip().endswith(""): + tool_call_count += 1 + else: + final_answer_msg = msg["content"].strip() + + assert tool_call_count >= 1, "Expected at least one tool call" + assert final_answer_msg is not None, ( + "Expected a final answer message from assistant" + ) + + final_answer_logic = _MultiStepCalculatorLogic() + extracted_final_answer = final_answer_logic._is_final_answer(final_answer_msg) + assert extracted_final_answer is not None, ( + f"Could not parse final answer from: {final_answer_msg}" + ) + assert abs(extracted_final_answer - expected_final_answer) < 1e-6, ( + f"Final answer incorrect. Expected {expected_final_answer}, Got {extracted_final_answer}" + ) + + assert torch.all(final_batch["total_reward"] == 1.0), ( + f"Expected total reward 1.0, got {final_batch['total_reward']}" + ) + + print("\nMulti-Step Calculator VLLM Test assertions passed.") + + +# --- Fixture for Sliding Puzzle Environment --- +@pytest.fixture(scope="function") +def sliding_puzzle_environment(rollout_cluster): + env_actor = None + print("Creating SlidingPuzzleEnv actor...") + try: + # Pass game config if needed, e.g., {"game_config": {"size": 3}} + env_actor = SlidingPuzzleEnv.remote() + task_to_env = {"sliding_puzzle_game": env_actor} + yield task_to_env, env_actor + finally: + print("Cleaning up sliding_puzzle_environment...") + if env_actor: + env_actor.shutdown.remote() + ray.kill(env_actor) + print("sliding_puzzle_environment cleanup finished.") + + +# --- Fixture for Sliding Puzzle Initial Batch --- +@pytest.fixture(scope="function") +def initial_sliding_puzzle_batch(rollout_tokenizer): + print("Creating initial sliding puzzle test batch...") + batch_size = 1 + game_config = { + "size": 2, + "shuffle_moves": 1, + } + max_moves = 10 # Set a limit for the test + + # Generate initial game state + initial_game_state = 
SlidingPuzzleGame.generate(game_config) + initial_render = SlidingPuzzleGame.render(initial_game_state) + welcome_message = SlidingPuzzleGame.init(initial_game_state) + + prompt_instructions = ( + f"{welcome_message}\n\n" + f"Current Board State:\n{initial_render}\n\n" + f"Reach the goal state where numbers are ordered 1 through {game_config['size'] ** 2 - 1} " + f"with the empty space (0) at the bottom right.\n" + f"Valid actions: 'up', 'down', 'left', 'right'\n" + f"After thinking, output your chosen action on a new line starting with '' like this:\nyour_action" + f"\nIf you just want to see the board, output view" + f"\nThink carefully step-by-step before acting. If you get a 'cannot slide' error, try something different\n" + ) + + batch_message_logs = [] + batch_extra_env_info = [] + batch_loss_multipliers = [] + batch_indices = [] + batch_task_names = [] + + for i in range(batch_size): + # Apply chat template to the initial prompt + initial_prompt_content = rollout_tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_instructions}], + tokenize=False, + add_system_prompt=True, # Include system prompt for Qwen + add_generation_prompt=True, + add_special_tokens=False, + ).strip() + tokenized_prompt = rollout_tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message_log = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + + metadata = SlidingPuzzleMetadata( + game_state=initial_game_state, num_moves=0, max_moves=max_moves + ) + + batch_message_logs.append(message_log) + batch_extra_env_info.append(metadata) + batch_loss_multipliers.append(1.0) + batch_indices.append(i) + batch_task_names.append("sliding_puzzle_game") + + initial_batch_dict = { + "message_log": batch_message_logs, + "extra_env_info": batch_extra_env_info, + "loss_multiplier": batch_loss_multipliers, + "idx": batch_indices, + "task_name": batch_task_names, + "stop_strings": 
[""], + } + return BatchedDataDict(initial_batch_dict) + + +@pytest.fixture(scope="function") +def sliding_puzzle_setup_vllm( + rollout_cluster, + rollout_tokenizer, + sliding_puzzle_environment, + initial_sliding_puzzle_batch, +): + """Sets up components for sliding puzzle tests using VllmGeneration.""" + vllm_generation = None + task_to_env, _ = sliding_puzzle_environment + is_eval = True + print("Creating VllmGeneration for Sliding Puzzle Test...") + try: + vllm_config = deepcopy(base_vllm_test_config) + # Qwen model name is already in base config + vllm_config["tokenizer_name"] = rollout_tokenizer.name_or_path + vllm_config = configure_generation_config( + vllm_config, rollout_tokenizer, is_eval=is_eval + ) + # Ensure max_new_tokens is sufficient + vllm_config["max_new_tokens"] = 500 + vllm_generation = VllmGeneration(rollout_cluster, vllm_config) + vllm_generation.finish_generation() + yield ( + vllm_generation, + rollout_tokenizer, + task_to_env, + initial_sliding_puzzle_batch, + rollout_cluster, + ) + finally: + print("Cleaning up VllmGeneration (Sliding Puzzle Test)...") + if vllm_generation: + vllm_generation.shutdown() + print("VllmGeneration cleanup finished (Sliding Puzzle Test).") + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 1, + reason="VLLM test requires at least 1 GPU", +) +def test_run_sliding_puzzle_vllm(sliding_puzzle_setup_vllm): + """Tests sliding puzzle rollout with VllmGeneration.""" + vllm_generation, rollout_tokenizer, task_to_env, initial_batch, rollout_cluster = ( + sliding_puzzle_setup_vllm + ) + max_moves = initial_batch["extra_env_info"][0]["max_moves"] + max_rollout_turns = max_moves + 1 + max_seq_len = 2048 + + print("\nRunning sliding puzzle rollout (VLLM)...") + vllm_generation.prepare_for_generation() + final_batch, rollout_metrics = run_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=initial_batch, + tokenizer=rollout_tokenizer, + task_to_env=task_to_env, + 
max_rollout_turns=max_rollout_turns, + max_seq_len=max_seq_len, + greedy=True, + ) + print(rollout_metrics) + vllm_generation.finish_generation() + print("Sliding puzzle rollout complete (VLLM).") + + # --- Assertions --- + assert isinstance(final_batch, BatchedDataDict) + assert "message_log" in final_batch + assert "total_reward" in final_batch + assert len(final_batch["message_log"]) == len(initial_batch["message_log"]) + + sample_log = final_batch["message_log"][0] + print(f"Final Total Reward: {final_batch['total_reward'][0].item()}") + + # Count the number of tags and environment messages + action_tag_count = 0 + environment_message_count = 0 + + for msg in sample_log: + if msg["role"] == "assistant" and "" in msg["content"]: + action_tag_count += 1 + elif msg["role"] == "environment": + environment_message_count += 1 + + print(f"Found {action_tag_count} messages with tags") + print(f"Found {environment_message_count} environment messages") + + # Assert that we have multiple action tags and environment messages + assert action_tag_count > 3, "Expected at least one message with tag" + assert environment_message_count > 3, "Expected at least one environment message" + + print("\nSliding Puzzle VLLM Test assertions passed.") diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 04e0cd5969..3f361394a9 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -814,3 +814,33 @@ def test_vllm_generation_with_stop( vllm_generation.shutdown() if not is_eval: hf_policy.shutdown() + + +def test_vllm_non_divisible_batch_handling(policy): + """Test that VLLM generation handles non divisible input batches correctly.""" + # This test runs on 2 GPUs but has a batch size of 1. The first GPU will run a batch + # and the second will run a batch of size 0. 
+ + # Create and run with non divisible batch + empty_batch = BatchedDataDict( + { + "input_ids": torch.zeros((1, 1), dtype=torch.long), + "input_lengths": torch.ones(1, dtype=torch.long), + } + ) + + outputs = policy.generate(empty_batch) + + # Verify output structure and dimensions + required_keys = [ + "output_ids", + "logprobs", + "generation_lengths", + "unpadded_sequence_lengths", + ] + assert all(key in outputs for key in required_keys), ( + "Missing required output fields" + ) + assert all(outputs[key].shape[0] == 1 for key in required_keys), ( + "Output tensors should have a batch dimension of 1" + ) diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py new file mode 100644 index 0000000000..5d410f5895 --- /dev/null +++ b/tests/unit/test_envs.py @@ -0,0 +1,418 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import ray
import torch
from typing import Dict, List, Tuple, Optional, TypedDict, Literal, Any

from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
from nemo_reinforcer.data.interfaces import LLMMessageLogType
from nemo_reinforcer.environments.interfaces import (
    EnvironmentInterface,
    EnvironmentReturn,
)
from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES
from .environments.sliding_puzzle_game import SlidingPuzzleGame


class MultiStepCalcMetadata(TypedDict):
    """Per-sample state carried between turns of the calculator task."""

    problem: str
    expected_final_answer: float
    max_steps: int
    current_step: int


class SlidingPuzzleMetadata(TypedDict):
    """Per-sample state carried between turns of the sliding puzzle task."""

    game_state: Dict[str, Any]  # Stores the dict returned by SlidingPuzzleGame methods
    num_moves: int
    max_moves: int


# One processed turn, in the order both logic classes return it:
# (observation message, reward, terminated, next stop strings, next metadata).
_TurnResult = Tuple[
    Dict[str, str],
    float,
    bool,
    Optional[List[str]],
    Optional[Dict[str, Any]],
]


def _build_environment_return(turn_results: List[_TurnResult]) -> EnvironmentReturn:
    """Collate per-sample turn results into one batched EnvironmentReturn.

    Rewards become a float32 tensor and termination flags a bool tensor;
    observations, metadata and stop strings stay as per-sample lists. An empty
    input produces empty lists and tensors of shape (0,).
    """
    observations = []
    rewards = []
    terminateds = []
    all_stop_strings = []  # each entry is a list of strings or None
    all_next_metadata = []  # each entry is a metadata dict or None

    for obs, rew, term, stops, meta in turn_results:
        observations.append(obs)
        rewards.append(rew)
        terminateds.append(term)
        all_stop_strings.append(stops)
        all_next_metadata.append(meta)

    return EnvironmentReturn(
        observations=observations,
        metadata=all_next_metadata,
        next_stop_strings=all_stop_strings,
        rewards=torch.tensor(rewards, dtype=torch.float32),
        terminateds=torch.tensor(terminateds, dtype=torch.bool),
    )


class _MultiStepCalculatorLogic:
    """Turn-by-turn rules for the multi-step calculator task.

    Each assistant message is either a tool call of the form
    '[opA, opB, operation]' (answered with the computed result) or a final
    numeric answer (which ends the episode and is scored against the expected
    answer).

    NOTE(review): the tool-call / final-answer tag literals below are empty
    strings in this revision -- possibly markup lost upstream; confirm the
    intended tags. The parsing code is written so it is correct for empty and
    non-empty tags alike.
    """

    def __init__(self):
        pass

    def _parse_tool_call(self, text: str) -> Optional[Tuple[float, float, str]]:
        """Parses '[opA, opB, operation]'.

        Returns:
            (op_a, op_b, operation) with the operation lower-cased, or None
            when the text does not end with the tool-call suffix, is not a
            bracketed three-item list, or has non-numeric operands.
        """
        tool_call_suffix = ""
        stripped = text.strip()
        if not stripped.endswith(tool_call_suffix):
            return None

        # Slice by an explicit end index: `stripped[:-len(suffix)]` would be
        # `stripped[:0]` (empty) whenever the suffix itself is empty.
        content = stripped[: len(stripped) - len(tool_call_suffix)].strip()
        if not (content.startswith("[") and content.endswith("]")):
            return None

        parts = content[1:-1].split(",")
        if len(parts) != 3:
            return None
        try:
            op_a = float(parts[0].strip())
            op_b = float(parts[1].strip())
        except ValueError:
            return None
        return op_a, op_b, parts[2].strip().lower()

    def _calculate(self, op_a: float, op_b: float, operation: str) -> Optional[float]:
        """Performs the calculation.

        Supports 'sum', 'diff', 'prod' and 'div'. Returns None for an unknown
        operation or when dividing by a value within 1e-6 of zero.
        """
        if operation == "sum":
            return op_a + op_b
        if operation == "diff":
            return op_a - op_b
        if operation == "prod":
            return op_a * op_b
        if operation == "div":
            if abs(op_b) < 1e-6:
                return None  # Division by zero error
            return op_a / op_b
        return None  # Unknown operation

    def _is_final_answer(self, text: str) -> Optional[float]:
        """Checks if the text is just a final numerical answer.

        The number may optionally be wrapped in the final-answer prefix and
        suffix tags. Returns the parsed float, or None when the text is not a
        bare number.
        """
        prefix = ""
        suffix = ""
        processed_text = text.strip()
        if processed_text.startswith(prefix) and processed_text.endswith(suffix):
            # Explicit end index so an empty suffix does not erase the text
            # (a `-len(suffix)` bound would be `-0`, i.e. 0).
            processed_text = processed_text[
                len(prefix) : len(processed_text) - len(suffix)
            ]
        try:
            return float(processed_text)
        except ValueError:
            return None

    def process_turn(
        self,
        message_log: LLMMessageLogType,
        metadata: MultiStepCalcMetadata,
    ) -> Tuple[
        Dict[str, str],
        float,
        bool,
        Optional[List[str]],
        Optional[MultiStepCalcMetadata],
    ]:
        """Processes a single turn for the multi-step calculator task.

        Args:
            message_log: Conversation so far; only the last message is read,
                and only if its role is "assistant".
            metadata: Current per-sample state. It is not mutated; a shallow
                copy is updated and returned.

        Returns:
            (observation, reward, terminated, next_stop_strings, next_metadata).
            Metadata is None whenever the episode terminates; stop strings are
            None only when the step budget was already exhausted.
        """
        last_assistant_msg = ""
        if message_log and message_log[-1]["role"] == "assistant":
            last_assistant_msg = message_log[-1]["content"].strip()

        expected_final_answer = metadata["expected_final_answer"]

        # Budget check comes first: nothing the assistant said is evaluated
        # once the step limit has been reached.
        if metadata["current_step"] >= metadata["max_steps"]:
            return (
                {"role": "environment", "content": "Maximum steps reached."},
                0.0,
                True,
                None,
                None,
            )

        turn_reward = 0.0
        is_terminated = True  # every outcome except a successful tool call ends the episode
        next_stop_strings = [""]
        next_metadata = None

        # A final answer takes precedence over a tool call.
        final_answer = self._is_final_answer(last_assistant_msg)
        if final_answer is not None:
            if abs(final_answer - expected_final_answer) < 1e-6:
                turn_reward = 1.0  # Correct final answer
                next_observation_content = (
                    f"Correct! The final answer is {final_answer:.2f}."
                )
            else:
                next_observation_content = f"Incorrect final answer. Expected {expected_final_answer:.2f}, got {final_answer:.2f}."
        else:
            parsed_call = self._parse_tool_call(last_assistant_msg)
            if not parsed_call:
                # No final answer and no valid tool call
                next_observation_content = (
                    "Invalid response. Expected tool call or final answer."
                )
            else:
                req_op_a, req_op_b, req_op = parsed_call
                result = self._calculate(req_op_a, req_op_b, req_op)
                if result is None:
                    next_observation_content = "Calculation failed."
                else:
                    # Tool call succeeded: report the result, consume one step
                    # and keep the episode running.
                    next_observation_content = f"{result:.5f}"
                    next_metadata = metadata.copy()
                    next_metadata["current_step"] += 1
                    is_terminated = False

        return (
            {"role": "environment", "content": next_observation_content},
            turn_reward,
            is_terminated,
            next_stop_strings,
            next_metadata,
        )


class _SlidingPuzzleLogic:
    """Turn-by-turn rules for the sliding puzzle task.

    Holds no state of its own: the game state travels in the per-sample
    metadata and is advanced through SlidingPuzzleGame.

    NOTE(review): the action tag literals below are empty strings in this
    revision -- possibly markup lost upstream; confirm the intended tags.
    """

    def __init__(self):
        pass  # No initialization needed as game methods are static

    def _parse_action(self, text: str) -> Optional[str]:
        """Parses the action enclosed by the last action prefix/suffix pair.

        Tag matching is case-insensitive and uses the last occurrence of the
        prefix, so earlier free-form text is ignored. Returns the stripped
        text between the tags, or None if either tag is missing.
        """
        prefix = ""
        suffix = ""
        text_lower = text.lower()

        start_idx = text_lower.rfind(prefix.lower())  # Find the last occurrence
        if start_idx == -1:
            return None

        content_start = start_idx + len(prefix)
        end_idx = text_lower.find(suffix.lower(), content_start)
        if end_idx == -1:
            return None
        return text[content_start:end_idx].strip()

    def process_turn(
        self,
        message_log: LLMMessageLogType,
        metadata: SlidingPuzzleMetadata,
    ) -> Tuple[
        Dict[str, str],
        float,
        bool,
        Optional[List[str]],
        Optional[SlidingPuzzleMetadata],
    ]:
        """Processes a single turn for the sliding puzzle task.

        Args:
            message_log: Conversation so far; only the last message is read,
                and only if its role is "assistant".
            metadata: Current per-sample state. It is not mutated; a shallow
                copy is updated and returned.

        Returns:
            (observation, reward, terminated, next_stop_strings, next_metadata).
            Metadata is None only when the episode terminates.
        """
        game_state = metadata["game_state"]
        current_moves = metadata["num_moves"]
        max_moves = metadata["max_moves"]

        # Budget check comes first: no action is parsed once the move limit
        # has been reached.
        if current_moves >= max_moves:
            return (
                {
                    "role": "environment",
                    "content": f"Maximum moves ({max_moves}) reached.",
                },
                0.0,
                True,
                None,
                None,
            )

        turn_reward = 0.0
        is_terminated = False
        next_stop_strings = [""]
        next_metadata = metadata.copy()

        last_assistant_msg_content = ""
        if message_log and message_log[-1]["role"] == "assistant":
            last_assistant_msg_content = message_log[-1]["content"].strip()

        parsed_action = self._parse_action(last_assistant_msg_content)

        if parsed_action is None:
            # Malformed reply: show the board and let the model retry. The
            # episode is NOT terminated here, so the metadata must be kept
            # (as in the "view" branch); dropping it would leave a following
            # turn without a game state to read.
            rendered_board = SlidingPuzzleGame.render(game_state)
            next_observation_content = f"\n{rendered_board}\n\nInvalid response format no move made. Try like this: your_action"
        elif parsed_action == "view":
            rendered_board = SlidingPuzzleGame.render(game_state)
            next_observation_content = f"\n{rendered_board}\n\nViewing the board. No move made."
        else:
            # Execute the game step
            step_response, reward, game_over, next_game_state = SlidingPuzzleGame.step(
                parsed_action, game_state
            )
            turn_reward = reward
            is_terminated = game_over
            next_metadata["game_state"] = next_game_state
            next_metadata["num_moves"] = current_moves + 1
            next_observation_content = f"\n{step_response}\n"

        if is_terminated:
            next_metadata = None  # Clear metadata on termination

        return (
            {"role": "environment", "content": next_observation_content + "\n"},
            turn_reward,
            is_terminated,
            next_stop_strings,
            next_metadata,
        )


@ray.remote
class MultiStepCalculatorEnv(EnvironmentInterface):
    """Multi-step calculator environment (Ray Actor)."""

    DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM

    def __init__(self, cfg: Optional[Dict] = None):
        self.logic = _MultiStepCalculatorLogic()

    def step(
        self,
        message_log_batch: List[LLMMessageLogType],
        metadata_batch: List[MultiStepCalcMetadata],
    ) -> EnvironmentReturn:
        """Processes a batch of interactions using the calculator logic.

        Samples are processed one at a time; message logs and metadata are
        paired positionally.
        """
        results = [
            self.logic.process_turn(log, meta)
            for log, meta in zip(message_log_batch, metadata_batch)
        ]
        return _build_environment_return(results)

    def shutdown(self):
        pass

    def global_post_process_and_metrics(
        self, batch: BatchedDataDict
    ) -> Tuple[BatchedDataDict, dict]:
        """Returns the batch unchanged plus the mean total reward as "success_rate"."""
        final_rewards = batch.get(
            "total_reward", torch.tensor([0.0] * len(batch["idx"]))
        )
        success_rate = final_rewards.mean().item() if len(final_rewards) > 0 else 0.0
        return batch, {"success_rate": success_rate}


@ray.remote
class SlidingPuzzleEnv(EnvironmentInterface):
    """Sliding Puzzle environment (Ray Actor)."""

    DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM

    def __init__(self, cfg: Optional[Dict] = None):
        # cfg could contain game generation config like {'size': 3, 'shuffle_moves': 50}
        self.game_config = cfg.get("game_config", {}) if cfg else {}
        self.logic = _SlidingPuzzleLogic()

    def step(
        self,
        message_log_batch: List[LLMMessageLogType],
        metadata_batch: List[SlidingPuzzleMetadata],
    ) -> EnvironmentReturn:
        """Processes a batch of sliding puzzle interactions.

        Samples are processed one at a time; message logs and metadata are
        paired positionally.
        """
        results = [
            self.logic.process_turn(log, meta)
            for log, meta in zip(message_log_batch, metadata_batch)
        ]
        return _build_environment_return(results)

    def shutdown(self):
        pass

    def global_post_process_and_metrics(
        self, batch: BatchedDataDict
    ) -> Tuple[BatchedDataDict, dict]:
        """Returns the batch unchanged plus the fraction of samples whose total reward is exactly 1.0."""
        final_rewards = batch.get(
            "total_reward", torch.tensor([0.0] * len(batch["idx"]))
        )
        success_rate = (
            (final_rewards == 1.0).float().mean().item()
            if len(final_rewards) > 0
            else 0.0
        )
        return batch, {"sliding_puzzle_success_rate": success_rate}