From 991393ad89aebcba03c643b10ffd87bd63e94041 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Wed, 15 Oct 2025 02:17:39 -0700 Subject: [PATCH] fix: grpo early exit edge case (#1361) Signed-off-by: Terry Kong Signed-off-by: NeMo Bot --- nemo_rl/algorithms/distillation.py | 9 +- nemo_rl/algorithms/dpo.py | 5 + nemo_rl/algorithms/grpo.py | 34 +- nemo_rl/algorithms/rm.py | 5 + nemo_rl/algorithms/sft.py | 5 + tests/unit/algorithms/test_distillation.py | 59 +- tests/unit/algorithms/test_dpo.py | 230 +++++++- tests/unit/algorithms/test_grpo.py | 599 +++++++++++++++++++++ tests/unit/algorithms/test_rm.py | 56 +- tests/unit/algorithms/test_sft.py | 56 +- 10 files changed, 1049 insertions(+), 9 deletions(-) diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index fe71e56b02..cc6a935aa8 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -838,8 +838,15 @@ def distillation_train( timer.reset() step += 1 + if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) + return if step >= max_steps: - break + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + return def validate( diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index ed4cf7f127..68ac6542e5 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -728,8 +728,13 @@ def dpo_train( total_steps += 1 if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) return if total_steps >= master_config["dpo"]["max_num_steps"]: + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) return current_epoch += 1 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d69b45fdfb..589046df62 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1025,9 +1025,14 @@ def grpo_train( current_step += 1 total_steps += 1 if 
should_save_by_timeout: - break + print("Timeout has been reached, stopping training early", flush=True) + return if total_steps >= max_num_steps: - break + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + return current_epoch += 1 current_step = 0 # Reset step counter for new epoch @@ -1195,6 +1200,11 @@ def async_grpo_train( from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() NEED_REFIT = True # Setup generation interface @@ -1684,9 +1694,18 @@ def async_grpo_train( # Checkpointing (same as sync version) consumed_samples += master_config["grpo"]["num_prompts_per_step"] - if master_config["checkpointing"]["enabled"] and ( + timeout.mark_iteration() + + should_save_by_step = ( is_last_step or (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. 
+ should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout ): policy.prepare_for_training() @@ -1788,6 +1807,15 @@ def async_grpo_train( timer.reset() step += 1 + if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) + return + if step >= master_config["grpo"]["max_num_steps"]: + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + return finally: # Clean up diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index 79ed068e99..5c78df4129 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -646,11 +646,16 @@ def rm_train( total_steps += 1 if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) return if ( master_config["rm"]["max_num_steps"] != -1 and total_steps >= master_config["rm"]["max_num_steps"] ): + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) return current_epoch += 1 diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 7d1dec65fb..877a3bff25 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -599,8 +599,13 @@ def sft_train( total_steps += 1 if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) return if total_steps >= master_config["sft"]["max_num_steps"]: + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) return current_epoch += 1 diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index d51e9c1eed..bc1d4b734e 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch @@ -212,6 +212,63 @@ def test_distillation_train_max_steps(mock_components): assert mock_components["student_policy"].train.call_count == 5 +def test_exit_on_timeout(mock_components, capsys): + """Test that training loop exits when timeout is reached""" + # Set max steps to large number + mock_components["master_config"]["distillation"]["max_num_steps"] = 100 + + distillation_save_state = _default_distillation_save_state() + + # Mock TimeoutChecker to return False for first 7 checks, then True (timeout) + with patch("nemo_rl.algorithms.distillation.TimeoutChecker") as mock_timeout_class: + mock_timeout_instance = MagicMock() + # Create a side_effect that returns False 7 times, then True + check_results = [False] * 7 + [True] + mock_timeout_instance.check_save.side_effect = check_results + mock_timeout_class.return_value = mock_timeout_instance + + # Run training + distillation_train( + mock_components["student_policy"], + mock_components["teacher_policy"], + mock_components["student_generation"], + mock_components["train_dataloader"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["loss_fn"], + mock_components["task_to_env"], + mock_components["val_task_to_env"], + mock_components["logger"], + mock_components["checkpointer"], + distillation_save_state, + mock_components["master_config"], + ) + + # Verify training stopped at 8 steps (when check_save returned True) + assert mock_components["student_policy"].train.call_count == 8 + + # Verify the timeout message was printed and training actually stopped + captured = capsys.readouterr() + output_lines = captured.out.strip().split("\n") + + # Find the timeout message + timeout_line_idx = None + for i, line in enumerate(output_lines): + if "Timeout has been reached, stopping training early" in line: + timeout_line_idx = i + break + + assert timeout_line_idx is not None, 
"Timeout message not found in output" + + # For distillation, verify we don't see more step messages after timeout + remaining_lines = output_lines[timeout_line_idx:] + for line in remaining_lines: + # Distillation doesn't have epochs, but check for step markers + assert not line.startswith("Step ") or "Step 8" in line, ( + f"Training continued after timeout: {line}" + ) + + def test_validate_function(mock_components): """Test independent validation function to ensure validation logic correctness.""" # Run validation diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py index c81f4da515..45d59fe990 100644 --- a/tests/unit/algorithms/test_dpo.py +++ b/tests/unit/algorithms/test_dpo.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np +import pytest import torch +from torchdata.stateful_dataloader import StatefulDataLoader -from nemo_rl.algorithms.dpo import add_ref_logprobs_to_data +from nemo_rl.algorithms.dpo import ( + _default_dpo_save_state, + add_ref_logprobs_to_data, + dpo_train, +) +from nemo_rl.algorithms.loss_functions import PreferenceLoss from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding @@ -90,3 +97,222 @@ def test_add_logprobs_to_batch(): # Verify the logprobs were rolled by -1 as expected expected_logprobs = torch.roll(mock_logprobs, -1, dims=-1) assert torch.equal(augmented_batch["reference_policy_logprobs"], expected_logprobs) + + +@pytest.fixture +def mock_dpo_components(): + # Create mock components + policy = MagicMock() + policy.train.return_value = { + "loss": torch.tensor(0.5), + "grad_norm": torch.tensor(1.0), + "all_mb_metrics": { + "loss": [0.5], + "sft_loss": [0.3], + "preference_loss": [0.2], + "accuracy": [1.0], + "rewards_chosen_mean": [4.5], + 
"rewards_rejected_mean": [3.5], + "num_valid_samples": [1.0], + "global_valid_seqs": [1.0], + "global_valid_toks": [10], + }, + } + policy.get_reference_policy_logprobs.return_value = { + "reference_logprobs": torch.randn(2, 10) + } + policy.sharding_annotations = NamedSharding( + layout=np.arange(1).reshape(1, -1, 1, 1), # 1 GPU to match cluster config + names=[ + "pipeline_parallel", + "data_parallel", + "context_parallel", + "tensor_parallel", + ], + ) + + # Create a proper message log structure with token_ids + mock_batch = BatchedDataDict( + { + "message_log": [ + [ # chosen + {"role": "user", "token_ids": torch.tensor([1, 2, 3])}, + {"role": "assistant", "token_ids": torch.tensor([4, 5, 6])}, + ], + [ # rejected + {"role": "user", "token_ids": torch.tensor([1, 2, 3])}, + {"role": "assistant", "token_ids": torch.tensor([7, 8, 9, 10, 11])}, + ], + ], + "length": torch.tensor([6, 8]), + "loss_multiplier": torch.tensor([1.0, 1.0]), + } + ) + + # Create mock dataloader with 10 batches that can be iterated multiple times + train_dataloader = MagicMock(spec=StatefulDataLoader) + + def train_iter(self): + return iter([mock_batch] * 10) + + train_dataloader.__iter__ = train_iter + train_dataloader.__len__ = MagicMock(return_value=10) + + val_dataloader = MagicMock(spec=StatefulDataLoader) + + def val_iter(self): + return iter([mock_batch] * 10) + + val_dataloader.__iter__ = val_iter + val_dataloader.__len__ = MagicMock(return_value=10) + + tokenizer = MagicMock() + tokenizer.pad_token_id = 0 + + loss_fn = PreferenceLoss() + logger = MagicMock() + checkpointer = MagicMock() + + # Create mock master config + master_config = { + "dpo": { + "max_num_steps": 5, + "max_num_epochs": 2, + "val_period": 100, + "val_batches": 1, + "val_global_batch_size": 1, + "val_micro_batch_size": 1, + "val_at_start": False, + }, + "policy": { + "train_global_batch_size": 2, + "make_sequence_length_divisible_by": 1, + "reward_model_cfg": { + "enabled": True, + "reward_model_type": 
"bradley_terry", + }, + "train_micro_batch_size": 1, + }, + "checkpointing": { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 10, + }, + "cluster": { + "num_nodes": 1, + "gpus_per_node": 1, + }, + } + + return { + "policy": policy, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + "tokenizer": tokenizer, + "loss_fn": loss_fn, + "logger": logger, + "checkpointer": checkpointer, + "master_config": master_config, + } + + +def test_exit_on_max_steps(mock_dpo_components): + """Test that training loop exits when max_num_steps is reached""" + # Set max steps to 12, which is less than len(train_dataloader) * max_num_epochs + mock_dpo_components["master_config"]["dpo"]["max_num_steps"] = 12 + + dpo_save_state = _default_dpo_save_state() + + # Run training + dpo_train( + mock_dpo_components["policy"], + mock_dpo_components["train_dataloader"], + mock_dpo_components["val_dataloader"], + mock_dpo_components["tokenizer"], + mock_dpo_components["loss_fn"], + mock_dpo_components["master_config"], + mock_dpo_components["logger"], + mock_dpo_components["checkpointer"], + dpo_save_state, + ) + + # Verify we only trained for 12 steps. 
+ assert mock_dpo_components["policy"].train.call_count == 12 + + +def test_exit_on_max_epochs(mock_dpo_components): + """Test that training loop exits when max_num_epochs is reached""" + # Set max epochs to 2 and max steps to a large number + mock_dpo_components["master_config"]["dpo"]["max_num_epochs"] = 2 + mock_dpo_components["master_config"]["dpo"]["max_num_steps"] = 100 + + dpo_save_state = _default_dpo_save_state() + + # Run training + dpo_train( + mock_dpo_components["policy"], + mock_dpo_components["train_dataloader"], + mock_dpo_components["val_dataloader"], + mock_dpo_components["tokenizer"], + mock_dpo_components["loss_fn"], + mock_dpo_components["master_config"], + mock_dpo_components["logger"], + mock_dpo_components["checkpointer"], + dpo_save_state, + ) + + # Verify we trained for exactly two epochs (20 batches). + assert mock_dpo_components["policy"].train.call_count == 20 + + +def test_exit_on_timeout(mock_dpo_components, capsys): + """Test that training loop exits when timeout is reached""" + # Set max steps and epochs to large numbers + mock_dpo_components["master_config"]["dpo"]["max_num_steps"] = 100 + mock_dpo_components["master_config"]["dpo"]["max_num_epochs"] = 10 + + dpo_save_state = _default_dpo_save_state() + + # Mock TimeoutChecker to return False for first 7 checks, then True (timeout) + with patch("nemo_rl.algorithms.dpo.TimeoutChecker") as mock_timeout_class: + mock_timeout_instance = MagicMock() + # Create a side_effect that returns False 7 times, then True + check_results = [False] * 7 + [True] + mock_timeout_instance.check_save.side_effect = check_results + mock_timeout_class.return_value = mock_timeout_instance + + # Run training + dpo_train( + mock_dpo_components["policy"], + mock_dpo_components["train_dataloader"], + mock_dpo_components["val_dataloader"], + mock_dpo_components["tokenizer"], + mock_dpo_components["loss_fn"], + mock_dpo_components["master_config"], + mock_dpo_components["logger"], + 
mock_dpo_components["checkpointer"], + dpo_save_state, + ) + + # Verify training stopped at 8 steps (when check_save returned True) + assert mock_dpo_components["policy"].train.call_count == 8 + + # Verify the timeout message was printed and is near the end (not followed by more training) + captured = capsys.readouterr() + output_lines = captured.out.strip().split("\n") + + # Find the timeout message + timeout_line_idx = None + for i, line in enumerate(output_lines): + if "Timeout has been reached, stopping training early" in line: + timeout_line_idx = i + break + + assert timeout_line_idx is not None, "Timeout message not found in output" + + # Verify no new epoch started after timeout (which would indicate a bug where break was used instead of return) + remaining_lines = output_lines[timeout_line_idx:] + for line in remaining_lines: + assert "Epoch" not in line or "Epoch 1/10" in line, ( + f"Training continued to next epoch after timeout: {line}" + ) diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 13ca63bc1b..6ee5449eb9 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest.mock import MagicMock, patch + import pytest import ray import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from nemo_rl.algorithms.grpo import ( + _default_grpo_save_state, + async_grpo_train, + grpo_train, +) +from nemo_rl.algorithms.loss_functions import ClippedPGLossFn from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( @@ -24,6 +33,210 @@ ) from nemo_rl.experience.rollouts import calculate_rewards +# ============================================================================ +# Stub classes for async GRPO testing (non-Ray versions for easy mocking) +# ============================================================================ + + +class StubReplayBuffer: + """Non-Ray stub of ReplayBuffer for unit testing + + Each method returns a MagicMock with a 'remote' attribute that can be called. + """ + + def __init__(self, initial_size=10, mock_batch=None, mock_rollout_metrics=None): + self._size = initial_size + self._trajectories = [] + self._mock_batch = mock_batch + self._mock_rollout_metrics = mock_rollout_metrics or {} + + @property + def size(self): + """Return a mock that returns buffer size when .remote() is called""" + mock = MagicMock() + mock.remote = MagicMock(return_value=self._size) # ray.get will extract this + return mock + + @property + def sample(self): + """Return a mock that returns sample result when .remote() is called""" + + def _sample(num_prompt_groups, current_weight_version, max_age_steps): + # Return proper trajectory structure expected by async GRPO + trajectories = [ + { + "batch": self._mock_batch, + "rollout_metrics": self._mock_rollout_metrics, + } + for _ in range(num_prompt_groups) + ] + return { + "trajectories": trajectories, + "avg_trajectory_age": 0.5, + } + + mock = MagicMock() + mock.remote = MagicMock( + side_effect=lambda *args, **kwargs: _sample(*args, **kwargs) 
+ ) + return mock + + @property + def get_debug_info(self): + """Return a mock that returns debug info when .remote() is called""" + mock = MagicMock() + mock.remote = MagicMock( + return_value={ + "total_trajectories": self._size, + "trajectory_versions": [0], + "target_weight_versions": [0], + "max_size": 100, + } + ) + return mock + + +class StubAsyncTrajectoryCollector: + """Non-Ray stub of AsyncTrajectoryCollector for unit testing + + Each method is a property that returns a MagicMock with a 'remote' attribute. + """ + + @property + def start_collection(self): + """Start collection - returns a remote-callable mock""" + mock = MagicMock() + mock.remote = MagicMock(return_value=MagicMock()) # Returns a fake ObjectRef + return mock + + @property + def set_weight_version(self): + """Set weight version - returns a remote-callable mock""" + mock = MagicMock() + mock.remote = MagicMock(return_value=MagicMock()) + return mock + + @property + def pause(self): + """Pause collection - returns a remote-callable mock""" + mock = MagicMock() + mock.remote = MagicMock(return_value=MagicMock()) + return mock + + @property + def resume(self): + """Resume collection - returns a remote-callable mock""" + mock = MagicMock() + mock.remote = MagicMock(return_value=MagicMock()) + return mock + + @property + def stop(self): + """Stop collection - returns a remote-callable mock""" + mock = MagicMock() + mock.remote = MagicMock(return_value=MagicMock()) + return mock + + @property + def wait_for_stop(self): + """Wait for stop - returns a remote-callable mock""" + mock = MagicMock() + mock.remote = MagicMock(return_value=MagicMock()) + return mock + + +def mock_async_grpo_infrastructure(mock_batch, mock_rollout_metrics): + """ + Context manager that mocks all async GRPO infrastructure (Ray actors, venv, etc). + + Returns a dict of patches that can be used as a context manager stack. 
+ """ + from contextlib import ExitStack + + stack = ExitStack() + + # Create stub instances with mock data + stub_buffer = StubReplayBuffer( + initial_size=10, + mock_batch=mock_batch, + mock_rollout_metrics=mock_rollout_metrics, + ) + stub_collector = StubAsyncTrajectoryCollector() + + # Patch venv creation + stack.enter_context( + patch( + "nemo_rl.algorithms.grpo.create_local_venv_on_each_node", + return_value="/fake/venv", + ) + ) + stack.enter_context( + patch( + "nemo_rl.algorithms.grpo.get_actor_python_env", return_value="/fake/python" + ) + ) + + # Patch Ray actor classes to return our stubs + mock_buffer_cls = MagicMock() + mock_buffer_cls.options.return_value.remote.return_value = stub_buffer + stack.enter_context( + patch("nemo_rl.algorithms.async_utils.ReplayBuffer", mock_buffer_cls) + ) + + mock_collector_cls = MagicMock() + mock_collector_cls.options.return_value.remote.return_value = stub_collector + stack.enter_context( + patch( + "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector", + mock_collector_cls, + ) + ) + + # Patch ray.get to return values from our stubs (not remote refs) + def mock_ray_get(ref): + # If it's already a plain value (from our stubs), return it + if isinstance(ref, (int, str, dict, list)): + return ref + # If it's a MagicMock, return a default response + return None + + stack.enter_context(patch("ray.get", side_effect=mock_ray_get)) + stack.enter_context( + patch("ray.wait", side_effect=lambda refs, **kwargs: (refs, [])) + ) + stack.enter_context( + patch("ray.kill", return_value=None) + ) # Mock ray.kill for cleanup + + # Patch the rollout functions used inside async_grpo_train + stack.enter_context( + patch( + "nemo_rl.algorithms.grpo.run_multi_turn_rollout", + return_value=(mock_batch, mock_rollout_metrics), + ) + ) + stack.enter_context( + patch( + "nemo_rl.algorithms.grpo.run_async_multi_turn_rollout", + return_value=(mock_batch, mock_rollout_metrics), + ) + ) + + # Patch refit and validate functions + 
stack.enter_context( + patch("nemo_rl.algorithms.grpo.refit_policy_generation", return_value=None) + ) + stack.enter_context( + patch("nemo_rl.algorithms.grpo.validate", return_value=({}, {})) + ) + + # Mock print_performance_metrics to avoid needing real timing metrics + stack.enter_context( + patch("nemo_rl.algorithms.grpo.print_performance_metrics", return_value={}) + ) + + return stack + @ray.remote(num_cpus=0) class MockEnvironment(EnvironmentInterface): @@ -322,3 +535,389 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node(): # Configure mocks to skip checkpoint loading mock_checkpointer.return_value.get_latest_checkpoint_path.return_value = None setup(master_config, tokenizer, dataset, None) + + +@pytest.fixture +def mock_grpo_components(): + # Create mock components + policy = MagicMock() + policy.train.return_value = { + "loss": torch.tensor(0.5), + "grad_norm": torch.tensor(1.0), + "all_mb_metrics": { + "loss": [0.5], + "policy_gradient_loss": [0.3], + "value_loss": [0.2], + "global_valid_toks": [10], + "token_mult_prob_error": [ + 1.0 + ], # Must be <= 1.05 to avoid logging extra plots + }, + } + policy.generate.return_value = { + "output_ids": torch.randint(0, 100, (2, 20)), + "generation_lengths": torch.tensor([10, 15]), + "unpadded_sequence_lengths": torch.tensor([12, 18]), + "logprobs": torch.randn(2, 20), + } + policy.prepare_for_training.return_value = None + # Mock sharding annotations for async GRPO + policy.sharding_annotations.get_axis_size.return_value = 1 # data_parallel size + + # Create mock batch with proper structure + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "test", + "token_ids": torch.tensor([1, 2, 3]), + }, + ] + ], + "task_name": ["math"], + "extra_env_info": [{}], + "loss_multiplier": torch.tensor([1.0]), + "idx": torch.tensor([0]), + "length": torch.tensor([3]), # Add length field for GRPO + "total_reward": torch.tensor( + [1.0] + ), # Add 
total_reward for rollout processing + } + ) + + # Create mock dataloader with 10 batches + train_dataloader = MagicMock(spec=StatefulDataLoader) + + def train_iter(self): + return iter([mock_batch] * 10) + + train_dataloader.__iter__ = train_iter + train_dataloader.__len__ = MagicMock(return_value=10) + + val_dataloader = MagicMock(spec=StatefulDataLoader) + + def val_iter(self): + return iter([mock_batch] * 10) + + val_dataloader.__iter__ = val_iter + val_dataloader.__len__ = MagicMock(return_value=10) + + tokenizer = MagicMock() + tokenizer.pad_token_id = 0 + + loss_fn = ClippedPGLossFn( + { + "reference_policy_kl_penalty": 0.01, + "ratio_clip_min": 0.8, + "ratio_clip_max": 1.2, + "ratio_clip_c": 1.0, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "token_level_loss": True, + } + ) + logger = MagicMock() + checkpointer = MagicMock() + + # Create mock environment + task_to_env = {"math": MagicMock()} + val_task_to_env = {"math": MagicMock()} + + # Mock environment return values + for env in [task_to_env["math"], val_task_to_env["math"]]: + env.step.return_value = ( + [{"role": "environment", "content": "correct"}], # observations + [{}], # metadata + [[]], # next_stop_strings + [1.0], # rewards + [True], # terminateds + [None], # answers + ) + env.global_post_process_and_metrics.return_value = (mock_batch, {}) + + # Create mock master config + master_config = { + "grpo": { + "max_num_steps": 5, + "max_num_epochs": 2, + "num_prompts_per_step": 1, + "num_generations_per_prompt": 1, + "max_rollout_turns": 1, + "val_period": 100, + "val_batch_size": 1, + "val_at_start": False, + "max_val_samples": 10, + "seed": 42, + "advantage_normalization": "global", + "use_leave_one_out_baseline": False, + "normalize_rewards": False, + "overlong_filtering": False, + }, + "policy": { + "train_global_batch_size": 1, + "train_micro_batch_size": 1, + "max_total_sequence_length": 2048, + "make_sequence_length_divisible_by": 1, + "generation": 
{ + "backend": "vllm", + "colocated": {"enabled": True}, + "vllm_cfg": {"async_engine": True}, # Support async mode + }, + }, + "loss_fn": { + "use_importance_sampling_correction": True, # Required for async mode + }, + "checkpointing": { + "enabled": False, + "checkpoint_must_save_by": None, + "save_period": 10, + }, + "cluster": { + "num_nodes": 1, + "gpus_per_node": 2, + }, + "logger": { + "num_val_samples_to_print": 5, + }, + } + + return { + "policy": policy, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + "tokenizer": tokenizer, + "loss_fn": loss_fn, + "logger": logger, + "checkpointer": checkpointer, + "task_to_env": task_to_env, + "val_task_to_env": val_task_to_env, + "master_config": master_config, + } + + +@pytest.mark.parametrize("train_func", [grpo_train, async_grpo_train]) +def test_grpo_exit_on_max_steps(mock_grpo_components, train_func): + """Test that GRPO training loop exits when max_num_steps is reached""" + # Set max steps to 12 + mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 12 + grpo_save_state = _default_grpo_save_state() + + # Async GRPO requires non-colocated inference + if train_func == async_grpo_train: + mock_grpo_components["master_config"]["policy"]["generation"]["colocated"][ + "enabled" + ] = False + + # Prepare mock data + mock_rollout_metrics = { + "mean_gen_tokens_per_sample": 10.0, + "max_gen_tokens": 20, + "min_gen_tokens": 5, + } + mock_batch = next(iter(mock_grpo_components["train_dataloader"])) + + # Use our helper to mock async infrastructure if needed + if train_func == async_grpo_train: + with mock_async_grpo_infrastructure(mock_batch, mock_rollout_metrics): + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + 
mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + else: + # For sync grpo_train, just mock the rollout functions + with patch( + "nemo_rl.algorithms.grpo.run_multi_turn_rollout", + return_value=(mock_batch, mock_rollout_metrics), + ): + with patch( + "nemo_rl.algorithms.grpo.run_async_multi_turn_rollout", + return_value=(mock_batch, mock_rollout_metrics), + ): + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + + # Verify we trained for exactly 12 steps + assert mock_grpo_components["policy"].train.call_count == 12 + + +@pytest.mark.parametrize( + "train_func", [grpo_train] +) # Only test sync version for epochs (async uses steps) +def test_grpo_exit_on_max_epochs(mock_grpo_components, train_func): + """Test that GRPO training loop exits when max_num_epochs is reached""" + # Set max epochs to 2 and max steps to a large number + mock_grpo_components["master_config"]["grpo"]["max_num_epochs"] = 2 + mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 100 + + grpo_save_state = _default_grpo_save_state() + + # Mock rollout functions to return proper metrics + mock_rollout_metrics = { + "mean_gen_tokens_per_sample": 10.0, + "max_gen_tokens": 20, + "min_gen_tokens": 5, + } + + # Get a mock batch to return + mock_batch = next(iter(mock_grpo_components["train_dataloader"])) + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = (mock_batch, mock_rollout_metrics) + + with patch( + 
"nemo_rl.algorithms.grpo.run_async_multi_turn_rollout" + ) as mock_async_rollout: + mock_async_rollout.return_value = (mock_batch, mock_rollout_metrics) + + # Run training + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + + # Verify we trained for exactly two epochs (20 batches) + assert mock_grpo_components["policy"].train.call_count == 20 + + +@pytest.mark.parametrize("train_func", [grpo_train, async_grpo_train]) +def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys): + """Test that GRPO training loop exits when timeout is reached""" + # Set max steps and epochs to large numbers + mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 100 + mock_grpo_components["master_config"]["grpo"]["max_num_epochs"] = 10 + grpo_save_state = _default_grpo_save_state() + + # Async GRPO requires non-colocated inference + if train_func == async_grpo_train: + mock_grpo_components["master_config"]["policy"]["generation"]["colocated"][ + "enabled" + ] = False + + # Prepare mock data + mock_rollout_metrics = { + "mean_gen_tokens_per_sample": 10.0, + "max_gen_tokens": 20, + "min_gen_tokens": 5, + } + mock_batch = next(iter(mock_grpo_components["train_dataloader"])) + + # Mock TimeoutChecker to return False for first 7 checks, then True (timeout) + with patch("nemo_rl.algorithms.grpo.TimeoutChecker") as mock_timeout_class: + mock_timeout_instance = MagicMock() + check_results = [False] * 7 + [True] + mock_timeout_instance.check_save.side_effect = check_results + mock_timeout_class.return_value = mock_timeout_instance + + # Use our helper for async, or simple 
mocking for sync + if train_func == async_grpo_train: + with mock_async_grpo_infrastructure(mock_batch, mock_rollout_metrics): + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + else: + with patch( + "nemo_rl.algorithms.grpo.run_multi_turn_rollout", + return_value=(mock_batch, mock_rollout_metrics), + ): + with patch( + "nemo_rl.algorithms.grpo.run_async_multi_turn_rollout", + return_value=(mock_batch, mock_rollout_metrics), + ): + train_func( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + + # Verify training stopped at 8 steps (when check_save returned True) + assert mock_grpo_components["policy"].train.call_count == 8 + + # Verify the timeout message was printed and training actually stopped + captured = capsys.readouterr() + output_lines = captured.out.strip().split("\n") + + # Find the timeout message + timeout_line_idx = None + for i, line in enumerate(output_lines): + if "Timeout has been reached, stopping training early" in line: + timeout_line_idx = i + break + + assert timeout_line_idx is not None, "Timeout message not found in output" + + # Check what comes after the timeout message + remaining_lines = output_lines[timeout_line_idx + 1 :] + + # For async_grpo_train, we 
expect cleanup messages in the finally block + if train_func.__name__ == "async_grpo_train": + cleanup_found = any( + "Stopping trajectory collection" in line + or "Async GRPO training complete" in line + for line in remaining_lines + ) + assert cleanup_found, ( + "Expected cleanup messages after timeout in async mode" + ) + + # Verify no new epoch/step started after timeout + for line in remaining_lines: + assert "Epoch" not in line or "Epoch 1/10" in line, ( + f"Training continued to next epoch after timeout: {line}" + ) + assert not (line.startswith("Step ") and "Step 9" in line), ( + f"Training continued to next step after timeout: {line}" + ) diff --git a/tests/unit/algorithms/test_rm.py b/tests/unit/algorithms/test_rm.py index 149b37d709..8dabbedcb3 100644 --- a/tests/unit/algorithms/test_rm.py +++ b/tests/unit/algorithms/test_rm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch @@ -173,3 +173,57 @@ def test_exit_on_max_epochs(mock_components): # Verify we trained for exactly two epochs (20 batches). 
# --- tests/unit/algorithms/test_rm.py ---------------------------------------
# NOTE: this function and the SFT variant further down live in separate test
# modules in the patch; they only appear together here.
def test_exit_on_timeout(mock_components, capsys):
    """Test that the RM training loop exits when timeout is reached.

    ``TimeoutChecker`` is patched so that ``check_save`` reports ``False`` for
    the first seven calls and ``True`` on the eighth. The test then checks that
    ``policy.train`` was called exactly eight times, that the early-exit
    message was printed, and that no later epoch banner follows it.

    Args:
        mock_components: Fixture mapping that supplies the components handed
            to ``rm_train``.
        capsys: pytest fixture used to read what the training loop printed.
    """
    # Set max steps and epochs to large numbers so those limits are not what
    # ends the run.
    mock_components["master_config"]["rm"]["max_num_steps"] = 100
    mock_components["master_config"]["rm"]["max_num_epochs"] = 10

    rm_save_state = _default_rm_save_state()

    # Mock TimeoutChecker to return False for first 7 checks, then True (timeout)
    with patch("nemo_rl.algorithms.rm.TimeoutChecker") as mock_timeout_class:
        mock_timeout_instance = MagicMock()
        # The side_effect list has exactly 8 entries, so a 9th check_save call
        # would raise instead of silently letting training continue.
        mock_timeout_instance.check_save.side_effect = [False] * 7 + [True]
        mock_timeout_class.return_value = mock_timeout_instance

        # Run training
        rm_train(
            mock_components["policy"],
            mock_components["train_dataloader"],
            mock_components["val_dataloader"],
            mock_components["tokenizer"],
            mock_components["loss_fn"],
            mock_components["master_config"],
            mock_components["logger"],
            mock_components["rm_task_spec"],
            mock_components["checkpointer"],
            rm_save_state,
        )

    # Verify training stopped at 8 steps (when check_save returned True)
    assert mock_components["policy"].train.call_count == 8

    # Verify the timeout message was printed and is not followed by more training
    captured = capsys.readouterr()
    output_lines = captured.out.strip().split("\n")

    # Find the first line carrying the timeout message
    timeout_line_idx = next(
        (
            i
            for i, line in enumerate(output_lines)
            if "Timeout has been reached, stopping training early" in line
        ),
        None,
    )
    assert timeout_line_idx is not None, "Timeout message not found in output"

    # Verify no new epoch started after timeout (which would indicate a bug
    # where break was used instead of return). Start one past the timeout line
    # so only output printed *after* the message is inspected.
    remaining_lines = output_lines[timeout_line_idx + 1 :]
    for line in remaining_lines:
        assert "Epoch" not in line or "Epoch 1/10" in line, (
            f"Training continued to next epoch after timeout: {line}"
        )


# --- tests/unit/algorithms/test_sft.py --------------------------------------
def test_exit_on_timeout(mock_components, capsys):
    """Test that the SFT training loop exits when timeout is reached.

    ``TimeoutChecker`` is patched so that ``check_save`` reports ``False`` for
    the first seven calls and ``True`` on the eighth. The test then checks that
    ``policy.train`` was called exactly eight times, that the early-exit
    message was printed, and that no later epoch banner follows it.

    Args:
        mock_components: Fixture mapping that supplies the components handed
            to ``sft_train``.
        capsys: pytest fixture used to read what the training loop printed.
    """
    # Set max steps and epochs to large numbers so those limits are not what
    # ends the run.
    mock_components["master_config"]["sft"]["max_num_steps"] = 100
    mock_components["master_config"]["sft"]["max_num_epochs"] = 10

    sft_save_state = _default_sft_save_state()

    # Mock TimeoutChecker to return False for first 7 checks, then True (timeout)
    with patch("nemo_rl.algorithms.sft.TimeoutChecker") as mock_timeout_class:
        mock_timeout_instance = MagicMock()
        # The side_effect list has exactly 8 entries, so a 9th check_save call
        # would raise instead of silently letting training continue.
        mock_timeout_instance.check_save.side_effect = [False] * 7 + [True]
        mock_timeout_class.return_value = mock_timeout_instance

        # Run training
        sft_train(
            mock_components["policy"],
            mock_components["train_dataloader"],
            mock_components["val_dataloader"],
            mock_components["tokenizer"],
            mock_components["loss_fn"],
            mock_components["master_config"],
            mock_components["logger"],
            mock_components["sft_task_spec"],
            mock_components["checkpointer"],
            sft_save_state,
        )

    # Verify training stopped at 8 steps (when check_save returned True)
    assert mock_components["policy"].train.call_count == 8

    # Verify the timeout message was printed and is not followed by more training
    captured = capsys.readouterr()
    output_lines = captured.out.strip().split("\n")

    # Find the first line carrying the timeout message
    timeout_line_idx = next(
        (
            i
            for i, line in enumerate(output_lines)
            if "Timeout has been reached, stopping training early" in line
        ),
        None,
    )
    assert timeout_line_idx is not None, "Timeout message not found in output"

    # Verify no new epoch started after timeout (which would indicate a bug
    # where break was used instead of return). Start one past the timeout line
    # so only output printed *after* the message is inspected.
    remaining_lines = output_lines[timeout_line_idx + 1 :]
    for line in remaining_lines:
        assert "Epoch" not in line or "Epoch 1/10" in line, (
            f"Training continued to next epoch after timeout: {line}"
        )