Skip to content

cp: fix: grpo early exit edge case (1361) into r0.4.0#1364

Merged
terrykong merged 1 commit intor0.4.0from
cherry-pick-1361-r0.4.0
Oct 16, 2025
Merged

cp: fix: grpo early exit edge case (1361) into r0.4.0#1364
terrykong merged 1 commit intor0.4.0from
cherry-pick-1361-r0.4.0

Conversation

@chtruong814
Copy link
Contributor

@chtruong814 chtruong814 commented Oct 15, 2025

beep boop [🤖]: Hi @terrykong 👋,

we've cherry picked #1361 into  for you! 🚀

Please review and approve this cherry pick by your convenience!

Summary by CodeRabbit

  • New Features

    • Added timeout-based early stopping across training workflows, with clear runtime messages when training halts due to timeouts or reaching maximum steps.
    • Unified early-exit behavior for consistent, predictable termination of training across algorithms, including asynchronous paths.
    • Improved visibility into checkpoint-triggered exits via printed notifications.
  • Tests

    • Expanded unit test coverage to validate early stopping on timeouts, max steps, and max epochs.
    • Added tests for synchronous and asynchronous training paths, ensuring no further steps or epochs occur after early exit.

Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 15, 2025

📝 Walkthrough

Walkthrough

Adds timeout-aware early exits and messages across training loops. distillation switches from loop breaks to early returns on timeout/max-steps. dpo/rm/sft add print messages before existing early returns. grpo and async_grpo integrate TimeoutChecker-driven save/exit logic. Unit tests added/extended for timeout, max-steps, and max-epochs across algorithms, including async GRPO stubs.

Changes

Cohort / File(s) Summary of changes
Algorithms: Distillation early-exit behavior
nemo_rl/algorithms/distillation.py
Replace break-on-max-steps with early return; add early return on timeout with print/log messages. No data/validation changes.
Algorithms: DPO messaging
nemo_rl/algorithms/dpo.py
Add print messages before existing early returns for timeout and max-steps; no control-flow changes beyond messaging.
Algorithms: GRPO timeout integration (sync + async)
nemo_rl/algorithms/grpo.py
Integrate TimeoutChecker with step/iteration marking; compute should_save_by_timeout and combine with step-based saves; add printed messages and early returns on timeout and max-steps in both grpo_train and async_grpo_train; adjust async paths to respect timeout during validation/collection.
Algorithms: RM messaging
nemo_rl/algorithms/rm.py
Add print messages before early returns on timeout and max-steps; logic unchanged.
Algorithms: SFT messaging
nemo_rl/algorithms/sft.py
Add print messages before early returns on timeout and max-steps; logic unchanged.
Tests: Distillation timeout
tests/unit/algorithms/test_distillation.py
Add test_exit_on_timeout using patched TimeoutChecker; update imports to include patch; assert early stop and printed message.
Tests: DPO coverage (steps/epochs/timeout)
tests/unit/algorithms/test_dpo.py
Expand fixtures and imports; add tests for max-steps, max-epochs, and timeout with output assertions; import utilities/functions; scaffold mock components.
Tests: GRPO sync/async with stubs
tests/unit/algorithms/test_grpo.py
Add extensive non-Ray stubs and mocks; fixtures for GRPO components; tests for step-, epoch-, and timeout-based exits in grpo_train and async_grpo_train; verify cleanup/output.
Tests: RM timeout
tests/unit/algorithms/test_rm.py
Add test_exit_on_timeout using patched TimeoutChecker; assert call counts and output; import patch.
Tests: SFT timeout
tests/unit/algorithms/test_sft.py
Add test_exit_on_timeout with patched TimeoutChecker; assert early stop and no further epochs; import patch.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant TimeoutChecker as TimeoutChecker
  participant Checkpointer
  participant Logger

  Note over Trainer,TimeoutChecker: Synchronous training loop (distillation/dpo/rm/sft/grpo)
  Trainer->>TimeoutChecker: start(checkpoint_must_save_by)
  loop For each step
    Trainer->>Trainer: train_step()
    Trainer->>TimeoutChecker: mark_iteration()
    TimeoutChecker-->>Trainer: check_save() => should_save_by_timeout
    alt should_save_by_timeout
      Trainer->>Logger: print "Timeout reached, exiting"
      Trainer->>Checkpointer: save_if_needed()
      Trainer-->>Trainer: return
    else max_steps reached
      Trainer->>Logger: print "Max steps reached, exiting"
      Trainer->>Checkpointer: save_if_needed()
      Trainer-->>Trainer: return
    end
  end
Loading
sequenceDiagram
  autonumber
  participant AsyncTrainer
  participant TimeoutChecker as TimeoutChecker
  participant Collector as AsyncTrajectoryCollector (stubbed)
  participant Replay as ReplayBuffer (stubbed)
  participant Logger
  participant Checkpointer

  Note over AsyncTrainer,Collector: Async GRPO training loop
  AsyncTrainer->>TimeoutChecker: start(checkpoint_must_save_by)
  loop Steps/Epochs
    AsyncTrainer->>Collector: request_rollouts()
    Collector-->>AsyncTrainer: rollouts
    AsyncTrainer->>Replay: add(rollouts)
    AsyncTrainer->>TimeoutChecker: mark_iteration()
    TimeoutChecker-->>AsyncTrainer: check_save() => should_save_by_timeout
    alt should_save_by_timeout
      AsyncTrainer->>Logger: print "Timeout reached, exiting"
      AsyncTrainer->>Checkpointer: save_if_needed()
      AsyncTrainer-->>AsyncTrainer: return
    else max_steps reached
      AsyncTrainer->>Logger: print "Max steps reached, exiting"
      AsyncTrainer->>Checkpointer: save_if_needed()
      AsyncTrainer-->>AsyncTrainer: return
    end
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

r0.4.0, CI:L1

Suggested reviewers

  • terrykong
  • yfw

Pre-merge checks and finishing touches

❌ Failed checks (3 warnings)
Check name Status Explanation Resolution
Title Check ⚠️ Warning The title focuses on the cherry-pick operation instead of summarizing the actual fix and includes noise such as backticks, the PR number, and the target branch, so it does not clearly convey the primary change. Rename the PR to a concise summary of the change, for example “Fix GRPO early-exit edge case in timeout and step logic,” removing backticks, branch names, and PR numbers from the title.
Docstring Coverage ⚠️ Warning Docstring coverage is 70.73% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Test Results For Major Changes ⚠️ Warning The pull request introduces significant new behavior across multiple training algorithms by adding timeout-based early exits and extensive new tests, which could affect convergence or control flow, yet the PR description only notes the cherry-pick and reviewer request without any mention of test outcomes, regression validation, or performance data. Please update the PR description to summarize the testing performed—such as the newly added timeout and max-step unit tests—and include confirmation that these changes introduce no regressions in numerics, convergence, or performance, with any relevant before/after metrics or configurations.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch cherry-pick-1361-r0.4.0

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (5)
nemo_rl/algorithms/rm.py (1)

648-659: Early-exit messaging added correctly

Print with flush before returns is good; behavior unchanged and test-friendly.

Consider also logging the reason via Logger at info level for structured logs (in addition to print) in a follow-up.

tests/unit/algorithms/test_grpo.py (3)

36-97: Lightweight stubs for async infra are well-designed

Property-based .remote mocks mimic Ray ObjectRef semantics adequately for unit tests. Keeps tests hermetic.

Optionally add brief docstrings on stub methods to clarify returned shapes for future maintainers.


700-762: GRPO termination tests (steps/epochs/timeout) are thorough

  • Max-steps and max-epochs counts verified.
  • Timeout test asserts message and no post-timeout progress; async variant checks cleanup messages.

Minor: to reduce duplication, extract shared rollout_metrics/mock_batch setup into a small helper within the test module.

Also applies to: 768-813, 816-923


64-64: Static analysis: unused args in test stubs

Benign in tests; no action needed. If you want silence linters, prefix unused args with underscore or add noqa comments.

Also applies to: 205-205, 593-593, 601-601

tests/unit/algorithms/test_dpo.py (1)

155-156: Optional: Remove unused self parameter.

The nested functions train_iter and val_iter don't use the self parameter. You can remove it to clean up the code.

Apply this diff:

-    def train_iter(self):
+    def train_iter():
         return iter([mock_batch] * 10)
     
     train_dataloader.__iter__ = train_iter
-    def val_iter(self):
+    def val_iter():
         return iter([mock_batch] * 10)
     
     val_dataloader.__iter__ = val_iter

Also applies to: 163-164

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a9b20d5 and 991393a.

📒 Files selected for processing (10)
  • nemo_rl/algorithms/distillation.py (1 hunks)
  • nemo_rl/algorithms/dpo.py (1 hunks)
  • nemo_rl/algorithms/grpo.py (4 hunks)
  • nemo_rl/algorithms/rm.py (1 hunks)
  • nemo_rl/algorithms/sft.py (1 hunks)
  • tests/unit/algorithms/test_distillation.py (2 hunks)
  • tests/unit/algorithms/test_dpo.py (2 hunks)
  • tests/unit/algorithms/test_grpo.py (3 hunks)
  • tests/unit/algorithms/test_rm.py (2 hunks)
  • tests/unit/algorithms/test_sft.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • tests/unit/algorithms/test_sft.py
  • tests/unit/algorithms/test_grpo.py
  • nemo_rl/algorithms/grpo.py
  • tests/unit/algorithms/test_rm.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/algorithms/distillation.py
  • tests/unit/algorithms/test_distillation.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/algorithms/rm.py
  • tests/unit/algorithms/test_dpo.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/algorithms/distillation.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/algorithms/rm.py
🧬 Code graph analysis (5)
tests/unit/algorithms/test_sft.py (2)
nemo_rl/algorithms/sft.py (2)
  • _default_sft_save_state (56-63)
  • sft_train (347-612)
nemo_rl/utils/timer.py (1)
  • check_save (284-308)
nemo_rl/algorithms/grpo.py (1)
nemo_rl/utils/timer.py (4)
  • TimeoutChecker (264-321)
  • start_iterations (310-311)
  • mark_iteration (313-321)
  • check_save (284-308)
tests/unit/algorithms/test_rm.py (5)
tests/unit/algorithms/test_distillation.py (2)
  • test_exit_on_timeout (215-269)
  • mock_components (34-186)
tests/unit/algorithms/test_dpo.py (1)
  • test_exit_on_timeout (268-318)
tests/unit/algorithms/test_sft.py (2)
  • test_exit_on_timeout (156-207)
  • mock_components (26-102)
nemo_rl/algorithms/rm.py (2)
  • _default_rm_save_state (52-59)
  • rm_train (420-662)
nemo_rl/utils/timer.py (1)
  • check_save (284-308)
tests/unit/algorithms/test_distillation.py (3)
tests/unit/algorithms/test_dpo.py (1)
  • test_exit_on_timeout (268-318)
nemo_rl/algorithms/distillation.py (2)
  • _default_distillation_save_state (96-102)
  • distillation_train (468-849)
nemo_rl/utils/timer.py (1)
  • check_save (284-308)
tests/unit/algorithms/test_dpo.py (4)
nemo_rl/algorithms/dpo.py (3)
  • _default_dpo_save_state (51-58)
  • add_ref_logprobs_to_data (270-303)
  • dpo_train (486-741)
nemo_rl/algorithms/loss_functions.py (1)
  • PreferenceLoss (449-542)
nemo_rl/distributed/named_sharding.py (3)
  • NamedSharding (19-222)
  • layout (99-101)
  • names (84-86)
nemo_rl/distributed/batched_data_dict.py (1)
  • BatchedDataDict (75-860)
🪛 Ruff (0.14.0)
tests/unit/algorithms/test_grpo.py

64-64: Unused function argument: current_weight_version

(ARG001)


64-64: Unused function argument: max_age_steps

(ARG001)


205-205: Unused lambda argument: kwargs

(ARG005)


593-593: Unused function argument: self

(ARG001)


601-601: Unused function argument: self

(ARG001)

tests/unit/algorithms/test_dpo.py

155-155: Unused function argument: self

(ARG001)


163-163: Unused function argument: self

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (18)
nemo_rl/algorithms/dpo.py (1)

730-738: Timeout/max-steps prints before returns look good

Clear, flushed messages right before early exit; no logic change.

tests/unit/algorithms/test_rm.py (2)

15-15: Import patch added

Necessary for TimeoutChecker mocking; looks good.


178-229: Timeout exit test is sound and isolated

Mocks TimeoutChecker, asserts step count and message presence; checks no further epochs/steps. Solid.

nemo_rl/algorithms/distillation.py (1)

841-849: Return-on-timeout/max-steps with flushed prints

Matches intended behavior; step increment happens immediately before exit, consistent with tests.

tests/unit/algorithms/test_distillation.py (1)

215-269: Timeout test mirrors training behavior well

Patches TimeoutChecker, verifies call count and message, and ensures no further steps after timeout. Good coverage.

nemo_rl/algorithms/grpo.py (4)

1028-1035: Synchronous GRPO: early-exit prints before return

Messages are flushed; no side effects. OK.


1203-1208: Async GRPO: TimeoutChecker initialization

start_iterations is called before mark_iteration; correct pattern.

Please confirm master_config["checkpointing"]["checkpoint_must_save_by"] is None or a valid "DD:HH:MM:SS" string at runtime to avoid parsing issues.


1698-1706: Async GRPO: timeout check integrated into checkpoint gating

mark_iteration then check_save is correct; gating by checkpointing flag for saves, independent early-exit below. Looks good.

Also applies to: 1707-1708


1810-1818: Async GRPO: early-exit prints and return

Prints before return; finally-block ensures cleanup. Good.

tests/unit/algorithms/test_grpo.py (1)

540-697: Fixture mock_grpo_components covers both sync/async needs

Includes proper batch structure, sharding annotations, configs (incl. async flags). Good scaffolding.

tests/unit/algorithms/test_dpo.py (4)

102-216: LGTM! Comprehensive test fixture.

The mock_dpo_components fixture provides a well-structured setup for DPO training tests, including properly mocked policy, dataloaders, tokenizer, loss function, and configuration.


219-241: LGTM! Test correctly validates max_num_steps termination.

The test properly verifies that training exits after exactly 12 steps when max_num_steps is reached, which is less than the total possible steps (2 epochs × 10 batches = 20).


243-265: LGTM! Test correctly validates max_num_epochs termination.

The test properly verifies that training completes exactly 2 epochs (20 batches) when max_num_epochs is reached before max_num_steps.


268-318: LGTM! Comprehensive timeout test with proper validation.

The test effectively validates timeout-driven early exit by:

  • Mocking TimeoutChecker to trigger timeout after 7 steps
  • Verifying training stops at exactly 8 steps
  • Confirming the timeout message appears in output
  • Ensuring no new epoch starts after timeout (which would indicate incorrect use of break instead of return)
nemo_rl/algorithms/sft.py (2)

602-602: LGTM! Clear timeout notification.

The print statement correctly notifies users that training is stopping due to timeout, with flush=True ensuring the message appears immediately before the function returns.


605-608: LGTM! Clear max_num_steps notification.

The print statement correctly notifies users that training is stopping due to reaching max_num_steps, with flush=True ensuring the message appears immediately before the function returns.

tests/unit/algorithms/test_sft.py (2)

15-15: LGTM! Required import for timeout test.

The patch import is necessary for mocking TimeoutChecker in the new test_exit_on_timeout test.


156-207: LGTM! Comprehensive timeout test.

The test effectively validates timeout-driven early exit behavior by:

  • Mocking TimeoutChecker to trigger timeout after 7 steps
  • Verifying training stops at exactly 8 steps
  • Confirming the timeout message appears in captured output
  • Ensuring no new epoch starts after timeout (which would indicate a bug using break instead of return)

The test pattern is consistent with similar tests in test_dpo.py and properly validates the early exit mechanism.

Comment on lines +148 to +239
def mock_async_grpo_infrastructure(mock_batch, mock_rollout_metrics):
"""
Context manager that mocks all async GRPO infrastructure (Ray actors, venv, etc).

Returns a dict of patches that can be used as a context manager stack.
"""
from contextlib import ExitStack

stack = ExitStack()

# Create stub instances with mock data
stub_buffer = StubReplayBuffer(
initial_size=10,
mock_batch=mock_batch,
mock_rollout_metrics=mock_rollout_metrics,
)
stub_collector = StubAsyncTrajectoryCollector()

# Patch venv creation
stack.enter_context(
patch(
"nemo_rl.algorithms.grpo.create_local_venv_on_each_node",
return_value="/fake/venv",
)
)
stack.enter_context(
patch(
"nemo_rl.algorithms.grpo.get_actor_python_env", return_value="/fake/python"
)
)

# Patch Ray actor classes to return our stubs
mock_buffer_cls = MagicMock()
mock_buffer_cls.options.return_value.remote.return_value = stub_buffer
stack.enter_context(
patch("nemo_rl.algorithms.async_utils.ReplayBuffer", mock_buffer_cls)
)

mock_collector_cls = MagicMock()
mock_collector_cls.options.return_value.remote.return_value = stub_collector
stack.enter_context(
patch(
"nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector",
mock_collector_cls,
)
)

# Patch ray.get to return values from our stubs (not remote refs)
def mock_ray_get(ref):
# If it's already a plain value (from our stubs), return it
if isinstance(ref, (int, str, dict, list)):
return ref
# If it's a MagicMock, return a default response
return None

stack.enter_context(patch("ray.get", side_effect=mock_ray_get))
stack.enter_context(
patch("ray.wait", side_effect=lambda refs, **kwargs: (refs, []))
)
stack.enter_context(
patch("ray.kill", return_value=None)
) # Mock ray.kill for cleanup

# Patch the rollout functions used inside async_grpo_train
stack.enter_context(
patch(
"nemo_rl.algorithms.grpo.run_multi_turn_rollout",
return_value=(mock_batch, mock_rollout_metrics),
)
)
stack.enter_context(
patch(
"nemo_rl.algorithms.grpo.run_async_multi_turn_rollout",
return_value=(mock_batch, mock_rollout_metrics),
)
)

# Patch refit and validate functions
stack.enter_context(
patch("nemo_rl.algorithms.grpo.refit_policy_generation", return_value=None)
)
stack.enter_context(
patch("nemo_rl.algorithms.grpo.validate", return_value=({}, {}))
)

# Mock print_performance_metrics to avoid needing real timing metrics
stack.enter_context(
patch("nemo_rl.algorithms.grpo.print_performance_metrics", return_value={})
)

return stack

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

mock_async_grpo_infrastructure: comprehensive patching with ExitStack

Patches venv, Ray actors, ray.get/wait/kill, rollouts, refit, validate, and metrics. Scoped and readable.

The mocked ray.get returns None for MagicMock refs; ensure all ray.get usages in async path either return primitives or are not relied on. Current tests pass, but a stray ray.get on a MagicMock could silently yield None. Consider asserting expected types in the helper to catch regressions.

🧰 Tools
🪛 Ruff (0.14.0)

205-205: Unused lambda argument: kwargs

(ARG005)

🤖 Prompt for AI Agents
In tests/unit/algorithms/test_grpo.py around lines 148-239, the mock_ray_get
helper currently returns None for MagicMock refs which can hide regressions if
production code unexpectedly calls ray.get on a MagicMock; update mock_ray_get
to validate returned types instead of silently returning None—either raise a
clear exception when encountering a MagicMock/ref type you don't expect or
assert that ref resolves to a primitive/collection and include a descriptive
error message so tests fail loudly when ray.get is used on unsupported objects.

@terrykong terrykong added the CI:L1 Run doctests, unit tests, and functional tests label Oct 15, 2025
@terrykong terrykong enabled auto-merge (squash) October 15, 2025 16:06
@terrykong terrykong merged commit 411d8db into r0.4.0 Oct 16, 2025
83 of 88 checks passed
@terrykong terrykong deleted the cherry-pick-1361-r0.4.0 branch October 16, 2025 09:31
terrykong added a commit that referenced this pull request Nov 19, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick CI:L1 Run doctests, unit tests, and functional tests Run CICD

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants