Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,15 @@ def distillation_train(

timer.reset()
step += 1
if should_save_by_timeout:
print("Timeout has been reached, stopping training early", flush=True)
return
if step >= max_steps:
break
print(
"Max number of steps has been reached, stopping training early",
flush=True,
)
return


def validate(
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,13 @@ def dpo_train(
total_steps += 1

if should_save_by_timeout:
print("Timeout has been reached, stopping training early", flush=True)
return
if total_steps >= master_config["dpo"]["max_num_steps"]:
print(
"Max number of steps has been reached, stopping training early",
flush=True,
)
return

current_epoch += 1
Expand Down
34 changes: 31 additions & 3 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,9 +1025,14 @@ def grpo_train(
current_step += 1
total_steps += 1
if should_save_by_timeout:
break
print("Timeout has been reached, stopping training early", flush=True)
return
if total_steps >= max_num_steps:
break
print(
"Max number of steps has been reached, stopping training early",
flush=True,
)
return

current_epoch += 1
current_step = 0 # Reset step counter for new epoch
Expand Down Expand Up @@ -1195,6 +1200,11 @@ def async_grpo_train(
from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer

timer = Timer()
timeout = TimeoutChecker(
timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
fit_last_save_time=True,
)
timeout.start_iterations()
NEED_REFIT = True

# Setup generation interface
Expand Down Expand Up @@ -1684,9 +1694,18 @@ def async_grpo_train(

# Checkpointing (same as sync version)
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
if master_config["checkpointing"]["enabled"] and (
timeout.mark_iteration()

should_save_by_step = (
is_last_step
or (step + 1) % master_config["checkpointing"]["save_period"] == 0
)
# +1 because step is 0-indexed
# Check if timeout-based checkpointing is enabled in config.
should_save_by_timeout = timeout.check_save()

if master_config["checkpointing"]["enabled"] and (
should_save_by_step or should_save_by_timeout
):
policy.prepare_for_training()

Expand Down Expand Up @@ -1788,6 +1807,15 @@ def async_grpo_train(

timer.reset()
step += 1
if should_save_by_timeout:
print("Timeout has been reached, stopping training early", flush=True)
return
if step >= master_config["grpo"]["max_num_steps"]:
print(
"Max number of steps has been reached, stopping training early",
flush=True,
)
return

finally:
# Clean up
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,11 +646,16 @@ def rm_train(
total_steps += 1

if should_save_by_timeout:
print("Timeout has been reached, stopping training early", flush=True)
return
if (
master_config["rm"]["max_num_steps"] != -1
and total_steps >= master_config["rm"]["max_num_steps"]
):
print(
"Max number of steps has been reached, stopping training early",
flush=True,
)
return

current_epoch += 1
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,13 @@ def sft_train(
total_steps += 1

if should_save_by_timeout:
print("Timeout has been reached, stopping training early", flush=True)
return
if total_steps >= master_config["sft"]["max_num_steps"]:
print(
"Max number of steps has been reached, stopping training early",
flush=True,
)
return

current_epoch += 1
Expand Down
59 changes: 58 additions & 1 deletion tests/unit/algorithms/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
import torch
Expand Down Expand Up @@ -212,6 +212,63 @@ def test_distillation_train_max_steps(mock_components):
assert mock_components["student_policy"].train.call_count == 5


def test_exit_on_timeout(mock_components, capsys):
    """Test that training loop exits when timeout is reached"""
    # Give the run a step budget far larger than the timeout point so that
    # the timeout — not max_num_steps — is what terminates training.
    mock_components["master_config"]["distillation"]["max_num_steps"] = 100

    distillation_save_state = _default_distillation_save_state()

    # Patch TimeoutChecker so check_save() reports "no timeout" for the
    # first 7 polls and signals a timeout on the 8th.
    with patch("nemo_rl.algorithms.distillation.TimeoutChecker") as mock_timeout_class:
        mock_timeout_instance = MagicMock()
        mock_timeout_instance.check_save.side_effect = [False] * 7 + [True]
        mock_timeout_class.return_value = mock_timeout_instance

        # Run training under the patched timeout checker.
        distillation_train(
            mock_components["student_policy"],
            mock_components["teacher_policy"],
            mock_components["student_generation"],
            mock_components["train_dataloader"],
            mock_components["val_dataloader"],
            mock_components["tokenizer"],
            mock_components["loss_fn"],
            mock_components["task_to_env"],
            mock_components["val_task_to_env"],
            mock_components["logger"],
            mock_components["checkpointer"],
            distillation_save_state,
            mock_components["master_config"],
        )

    # Training should have performed exactly 8 steps: 7 "keep going"
    # polls plus the step on which check_save returned True.
    assert mock_components["student_policy"].train.call_count == 8

    captured = capsys.readouterr()
    output_lines = captured.out.strip().split("\n")

    # Locate the first line carrying the early-stop message.
    timeout_line_idx = next(
        (
            i
            for i, line in enumerate(output_lines)
            if "Timeout has been reached, stopping training early" in line
        ),
        None,
    )

    assert timeout_line_idx is not None, "Timeout message not found in output"

    # For distillation, verify we don't see more step messages after timeout.
    # (Distillation doesn't have epochs, so check step markers instead.)
    for line in output_lines[timeout_line_idx:]:
        assert not line.startswith("Step ") or "Step 8" in line, (
            f"Training continued after timeout: {line}"
        )


def test_validate_function(mock_components):
"""Test independent validation function to ensure validation logic correctness."""
# Run validation
Expand Down
Loading
Loading