cp: fix: grpo early exit edge case (1361) into r0.4.0#1364
Conversation
Signed-off-by: Terry Kong <terryk@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
📝 Walkthrough
Adds timeout-aware early exits and messages across training loops. distillation switches from loop breaks to early returns on timeout/max-steps. dpo/rm/sft add print messages before existing early returns. grpo and async_grpo integrate TimeoutChecker-driven save/exit logic. Unit tests are added or extended for timeout, max-steps, and max-epochs across algorithms, including async GRPO stubs.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Trainer
participant TimeoutChecker as TimeoutChecker
participant Checkpointer
participant Logger
Note over Trainer,TimeoutChecker: Synchronous training loop (distillation/dpo/rm/sft/grpo)
Trainer->>TimeoutChecker: start(checkpoint_must_save_by)
loop For each step
Trainer->>Trainer: train_step()
Trainer->>TimeoutChecker: mark_iteration()
TimeoutChecker-->>Trainer: check_save() => should_save_by_timeout
alt should_save_by_timeout
Trainer->>Logger: print "Timeout reached, exiting"
Trainer->>Checkpointer: save_if_needed()
Trainer-->>Trainer: return
else max_steps reached
Trainer->>Logger: print "Max steps reached, exiting"
Trainer->>Checkpointer: save_if_needed()
Trainer-->>Trainer: return
end
end
sequenceDiagram
autonumber
participant AsyncTrainer
participant TimeoutChecker as TimeoutChecker
participant Collector as AsyncTrajectoryCollector (stubbed)
participant Replay as ReplayBuffer (stubbed)
participant Logger
participant Checkpointer
Note over AsyncTrainer,Collector: Async GRPO training loop
AsyncTrainer->>TimeoutChecker: start(checkpoint_must_save_by)
loop Steps/Epochs
AsyncTrainer->>Collector: request_rollouts()
Collector-->>AsyncTrainer: rollouts
AsyncTrainer->>Replay: add(rollouts)
AsyncTrainer->>TimeoutChecker: mark_iteration()
TimeoutChecker-->>AsyncTrainer: check_save() => should_save_by_timeout
alt should_save_by_timeout
AsyncTrainer->>Logger: print "Timeout reached, exiting"
AsyncTrainer->>Checkpointer: save_if_needed()
AsyncTrainer-->>AsyncTrainer: return
else max_steps reached
AsyncTrainer->>Logger: print "Max steps reached, exiting"
AsyncTrainer->>Checkpointer: save_if_needed()
AsyncTrainer-->>AsyncTrainer: return
end
end
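The difference between the old loop break and the new early return can be illustrated with a toy nested loop. This is a sketch under stated assumptions, not the actual distillation code: `run_loop`, `timeout_at_step`, and `exit_with_return` are hypothetical names, and the timeout is modeled as a simple step threshold.

```python
def run_loop(
    num_epochs: int,
    batches_per_epoch: int,
    timeout_at_step: int,
    exit_with_return: bool,
) -> int:
    """Toy epoch/step loop; returns how many training steps were executed."""
    steps = 0
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            steps += 1  # stands in for one training step
            if steps >= timeout_at_step:  # hypothetical timeout condition
                if exit_with_return:
                    return steps  # leaves both loops at once
                break  # only leaves the inner loop; the next epoch still starts
    return steps


steps_with_return = run_loop(2, 10, 8, exit_with_return=True)
steps_with_break = run_loop(2, 10, 8, exit_with_return=False)
```

With `break`, the outer epoch loop keeps going and runs one more step before the condition fires again, which is the kind of post-timeout progress the new tests check for.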
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks
❌ Failed checks (3 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 1
🧹 Nitpick comments (5)
nemo_rl/algorithms/rm.py (1)
648-659: Early-exit messaging added correctly
Print with flush before returns is good; behavior unchanged and test-friendly.
Consider also logging the reason via Logger at info level for structured logs (in addition to print) in a follow-up.
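The follow-up suggested here could look roughly like the sketch below. It uses the standard-library `logging` module as a stand-in, since the review does not say which logger interface rm.py would use; `announce_exit` and the logger name are hypothetical.

```python
import logging

# Hypothetical logger name; the real module may route logs differently.
logger = logging.getLogger("early_exit_example")


def announce_exit(reason: str) -> str:
    """Print the exit reason (flushed) and also emit it as an info-level log."""
    message = f"{reason}, exiting"
    print(message, flush=True)  # keeps the existing, test-friendly behavior
    logger.info("Early exit: %s", reason)  # structured-log counterpart
    return message


last_message = announce_exit("Timeout reached")
```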
tests/unit/algorithms/test_grpo.py (3)
36-97: Lightweight stubs for async infra are well-designed
Property-based .remote mocks mimic Ray ObjectRef semantics adequately for unit tests. Keeps tests hermetic.
Optionally add brief docstrings on stub methods to clarify returned shapes for future maintainers.
700-762: GRPO termination tests (steps/epochs/timeout) are thorough
- Max-steps and max-epochs counts verified.
- Timeout test asserts message and no post-timeout progress; async variant checks cleanup messages.
Minor: to reduce duplication, extract shared rollout_metrics/mock_batch setup into a small helper within the test module.
Also applies to: 768-813, 816-923
64-64: Static analysis: unused args in test stubs
Benign in tests; no action needed. If you want to silence linters, prefix unused args with an underscore or add noqa comments.
Also applies to: 205-205, 593-593, 601-601
tests/unit/algorithms/test_dpo.py (1)
155-156: Optional: Remove unused self parameter.
The nested functions train_iter and val_iter don't use the self parameter. You can remove it to clean up the code.
Apply this diff:
-    def train_iter(self):
+    def train_iter():
         return iter([mock_batch] * 10)

     train_dataloader.__iter__ = train_iter

-    def val_iter(self):
+    def val_iter():
         return iter([mock_batch] * 10)

     val_dataloader.__iter__ = val_iter
Also applies to: 163-164
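One caveat on this suggestion, assuming the dataloaders in the fixture are `MagicMock` instances (the review does not show how they are created): when a plain function is assigned to a magic method on a mock, `unittest.mock` calls it with the mock as the first argument, so the `self` parameter is required even though the body never uses it. A minimal sketch with a placeholder batch:

```python
from unittest.mock import MagicMock

mock_batch = {"input_ids": [1, 2, 3]}  # placeholder, not the real test batch

# Function that accepts `self`, as in the existing test code.
train_dataloader = MagicMock()


def train_iter(self):
    return iter([mock_batch] * 10)


train_dataloader.__iter__ = train_iter
batches = list(train_dataloader)  # iterates over 10 batches

# Same function without `self`, as in the suggested diff.
broken_dataloader = MagicMock()


def train_iter_no_self():
    return iter([mock_batch] * 10)


broken_dataloader.__iter__ = train_iter_no_self
try:
    list(broken_dataloader)
    zero_arg_works = True
except TypeError:
    # The mock passes itself as the first positional argument.
    zero_arg_works = False
```

If the linter warning is the only concern, renaming the parameter to `_self` or adding a `noqa` comment keeps the signature valid.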
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
nemo_rl/algorithms/distillation.py (1 hunks)
nemo_rl/algorithms/dpo.py (1 hunks)
nemo_rl/algorithms/grpo.py (4 hunks)
nemo_rl/algorithms/rm.py (1 hunks)
nemo_rl/algorithms/sft.py (1 hunks)
tests/unit/algorithms/test_distillation.py (2 hunks)
tests/unit/algorithms/test_dpo.py (2 hunks)
tests/unit/algorithms/test_grpo.py (3 hunks)
tests/unit/algorithms/test_rm.py (2 hunks)
tests/unit/algorithms/test_sft.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
tests/unit/algorithms/test_sft.py
tests/unit/algorithms/test_grpo.py
nemo_rl/algorithms/grpo.py
tests/unit/algorithms/test_rm.py
nemo_rl/algorithms/sft.py
nemo_rl/algorithms/distillation.py
tests/unit/algorithms/test_distillation.py
nemo_rl/algorithms/dpo.py
nemo_rl/algorithms/rm.py
tests/unit/algorithms/test_dpo.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/algorithms/grpo.py
nemo_rl/algorithms/sft.py
nemo_rl/algorithms/distillation.py
nemo_rl/algorithms/dpo.py
nemo_rl/algorithms/rm.py
🧬 Code graph analysis (5)
tests/unit/algorithms/test_sft.py (2)
  nemo_rl/algorithms/sft.py (2): _default_sft_save_state(56-63), sft_train(347-612)
  nemo_rl/utils/timer.py (1): check_save(284-308)
nemo_rl/algorithms/grpo.py (1)
  nemo_rl/utils/timer.py (4): TimeoutChecker(264-321), start_iterations(310-311), mark_iteration(313-321), check_save(284-308)
tests/unit/algorithms/test_rm.py (5)
  tests/unit/algorithms/test_distillation.py (2): test_exit_on_timeout(215-269), mock_components(34-186)
  tests/unit/algorithms/test_dpo.py (1): test_exit_on_timeout(268-318)
  tests/unit/algorithms/test_sft.py (2): test_exit_on_timeout(156-207), mock_components(26-102)
  nemo_rl/algorithms/rm.py (2): _default_rm_save_state(52-59), rm_train(420-662)
  nemo_rl/utils/timer.py (1): check_save(284-308)
tests/unit/algorithms/test_distillation.py (3)
  tests/unit/algorithms/test_dpo.py (1): test_exit_on_timeout(268-318)
  nemo_rl/algorithms/distillation.py (2): _default_distillation_save_state(96-102), distillation_train(468-849)
  nemo_rl/utils/timer.py (1): check_save(284-308)
tests/unit/algorithms/test_dpo.py (4)
  nemo_rl/algorithms/dpo.py (3): _default_dpo_save_state(51-58), add_ref_logprobs_to_data(270-303), dpo_train(486-741)
  nemo_rl/algorithms/loss_functions.py (1): PreferenceLoss(449-542)
  nemo_rl/distributed/named_sharding.py (3): NamedSharding(19-222), layout(99-101), names(84-86)
  nemo_rl/distributed/batched_data_dict.py (1): BatchedDataDict(75-860)
🪛 Ruff (0.14.0)
tests/unit/algorithms/test_grpo.py
64-64: Unused function argument: current_weight_version (ARG001)
64-64: Unused function argument: max_age_steps (ARG001)
205-205: Unused lambda argument: kwargs (ARG005)
593-593: Unused function argument: self (ARG001)
601-601: Unused function argument: self (ARG001)
tests/unit/algorithms/test_dpo.py
155-155: Unused function argument: self (ARG001)
163-163: Unused function argument: self (ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Lint check
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (18)
nemo_rl/algorithms/dpo.py (1)
730-738: Timeout/max-steps prints before returns look good
Clear, flushed messages right before early exit; no logic change.
tests/unit/algorithms/test_rm.py (2)
15-15: Import patch added
Necessary for TimeoutChecker mocking; looks good.
178-229: Timeout exit test is sound and isolated
Mocks TimeoutChecker, asserts step count and message presence; checks no further epochs/steps. Solid.
nemo_rl/algorithms/distillation.py (1)
841-849: Return-on-timeout/max-steps with flushed prints
Matches intended behavior; step increment happens immediately before exit, consistent with tests.
tests/unit/algorithms/test_distillation.py (1)
215-269: Timeout test mirrors training behavior well
Patches TimeoutChecker, verifies call count and message, and ensures no further steps after timeout. Good coverage.
nemo_rl/algorithms/grpo.py (4)
1028-1035: Synchronous GRPO: early-exit prints before return
Messages are flushed; no side effects. OK.
1203-1208: Async GRPO: TimeoutChecker initialization
start_iterations is called before mark_iteration; correct pattern.
Please confirm master_config["checkpointing"]["checkpoint_must_save_by"] is None or a valid "DD:HH:MM:SS" string at runtime to avoid parsing issues.
1698-1706: Async GRPO: timeout check integrated into checkpoint gating
mark_iteration then check_save is correct; gating by checkpointing flag for saves, independent early-exit below. Looks good.
Also applies to: 1707-1708
1810-1818: Async GRPO: early-exit prints and return
Prints before return; finally-block ensures cleanup. Good.
tests/unit/algorithms/test_grpo.py (1)
540-697: Fixture mock_grpo_components covers both sync/async needs
Includes proper batch structure, sharding annotations, configs (incl. async flags). Good scaffolding.
tests/unit/algorithms/test_dpo.py (4)
102-216: LGTM! Comprehensive test fixture.
The mock_dpo_components fixture provides a well-structured setup for DPO training tests, including properly mocked policy, dataloaders, tokenizer, loss function, and configuration.
219-241: LGTM! Test correctly validates max_num_steps termination.
The test properly verifies that training exits after exactly 12 steps when max_num_steps is reached, which is less than the total possible steps (2 epochs × 10 batches = 20).
243-265: LGTM! Test correctly validates max_num_epochs termination.
The test properly verifies that training completes exactly 2 epochs (20 batches) when max_num_epochs is reached before max_num_steps.
268-318: LGTM! Comprehensive timeout test with proper validation.
The test effectively validates timeout-driven early exit by:
- Mocking TimeoutChecker to trigger timeout after 7 steps
- Verifying training stops at exactly 8 steps
- Confirming the timeout message appears in output
- Ensuring no new epoch starts after timeout (which would indicate incorrect use of break instead of return)
nemo_rl/algorithms/sft.py (2)
602-602: LGTM! Clear timeout notification.
The print statement correctly notifies users that training is stopping due to timeout, with flush=True ensuring the message appears immediately before the function returns.
605-608: LGTM! Clear max_num_steps notification.
The print statement correctly notifies users that training is stopping due to reaching max_num_steps, with flush=True ensuring the message appears immediately before the function returns.
tests/unit/algorithms/test_sft.py (2)
15-15: LGTM! Required import for timeout test.
The patch import is necessary for mocking TimeoutChecker in the new test_exit_on_timeout test.
156-207: LGTM! Comprehensive timeout test.
The test effectively validates timeout-driven early exit behavior by:
- Mocking TimeoutChecker to trigger timeout after 7 steps
- Verifying training stops at exactly 8 steps
- Confirming the timeout message appears in captured output
- Ensuring no new epoch starts after timeout (which would indicate a bug using break instead of return)
The test pattern is consistent with similar tests in test_dpo.py and properly validates the early exit mechanism.
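The test pattern described in these comments can be sketched against a toy loop. Everything here is illustrative: `toy_train` is a hypothetical stand-in for the real training functions, and the checker is a bare `MagicMock` whose `mark_iteration` and `check_save` names are taken from the review, with call signatures assumed.

```python
from unittest.mock import MagicMock


def toy_train(checker, num_epochs: int, batches_per_epoch: int, train_step) -> None:
    """Toy loop that returns as soon as the checker reports a timeout."""
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            train_step()
            checker.mark_iteration()
            if checker.check_save():
                print("Timeout reached, exiting", flush=True)
                return


# check_save reports False for 7 steps, then True on the 8th.
checker = MagicMock()
checker.check_save.side_effect = [False] * 7 + [True]
train_step = MagicMock()

toy_train(checker, num_epochs=2, batches_per_epoch=10, train_step=train_step)
steps_run = train_step.call_count
```

Because the `side_effect` list has exactly eight entries, any further call to `check_save` would exhaust it and raise, so a loop that kept going after the timeout would fail the test instead of passing quietly.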
def mock_async_grpo_infrastructure(mock_batch, mock_rollout_metrics):
    """
    Context manager that mocks all async GRPO infrastructure (Ray actors, venv, etc).

    Returns a dict of patches that can be used as a context manager stack.
    """
    from contextlib import ExitStack

    stack = ExitStack()

    # Create stub instances with mock data
    stub_buffer = StubReplayBuffer(
        initial_size=10,
        mock_batch=mock_batch,
        mock_rollout_metrics=mock_rollout_metrics,
    )
    stub_collector = StubAsyncTrajectoryCollector()

    # Patch venv creation
    stack.enter_context(
        patch(
            "nemo_rl.algorithms.grpo.create_local_venv_on_each_node",
            return_value="/fake/venv",
        )
    )
    stack.enter_context(
        patch(
            "nemo_rl.algorithms.grpo.get_actor_python_env", return_value="/fake/python"
        )
    )

    # Patch Ray actor classes to return our stubs
    mock_buffer_cls = MagicMock()
    mock_buffer_cls.options.return_value.remote.return_value = stub_buffer
    stack.enter_context(
        patch("nemo_rl.algorithms.async_utils.ReplayBuffer", mock_buffer_cls)
    )

    mock_collector_cls = MagicMock()
    mock_collector_cls.options.return_value.remote.return_value = stub_collector
    stack.enter_context(
        patch(
            "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector",
            mock_collector_cls,
        )
    )

    # Patch ray.get to return values from our stubs (not remote refs)
    def mock_ray_get(ref):
        # If it's already a plain value (from our stubs), return it
        if isinstance(ref, (int, str, dict, list)):
            return ref
        # If it's a MagicMock, return a default response
        return None

    stack.enter_context(patch("ray.get", side_effect=mock_ray_get))
    stack.enter_context(
        patch("ray.wait", side_effect=lambda refs, **kwargs: (refs, []))
    )
    stack.enter_context(
        patch("ray.kill", return_value=None)
    )  # Mock ray.kill for cleanup

    # Patch the rollout functions used inside async_grpo_train
    stack.enter_context(
        patch(
            "nemo_rl.algorithms.grpo.run_multi_turn_rollout",
            return_value=(mock_batch, mock_rollout_metrics),
        )
    )
    stack.enter_context(
        patch(
            "nemo_rl.algorithms.grpo.run_async_multi_turn_rollout",
            return_value=(mock_batch, mock_rollout_metrics),
        )
    )

    # Patch refit and validate functions
    stack.enter_context(
        patch("nemo_rl.algorithms.grpo.refit_policy_generation", return_value=None)
    )
    stack.enter_context(
        patch("nemo_rl.algorithms.grpo.validate", return_value=({}, {}))
    )

    # Mock print_performance_metrics to avoid needing real timing metrics
    stack.enter_context(
        patch("nemo_rl.algorithms.grpo.print_performance_metrics", return_value={})
    )

    return stack
mock_async_grpo_infrastructure: comprehensive patching with ExitStack
Patches venv, Ray actors, ray.get/wait/kill, rollouts, refit, validate, and metrics. Scoped and readable.
The mocked ray.get returns None for MagicMock refs; ensure all ray.get usages in async path either return primitives or are not relied on. Current tests pass, but a stray ray.get on a MagicMock could silently yield None. Consider asserting expected types in the helper to catch regressions.
🧰 Tools
🪛 Ruff (0.14.0)
205-205: Unused lambda argument: kwargs (ARG005)
🤖 Prompt for AI Agents
In tests/unit/algorithms/test_grpo.py around lines 148-239, the mock_ray_get
helper currently returns None for MagicMock refs which can hide regressions if
production code unexpectedly calls ray.get on a MagicMock; update mock_ray_get
to validate returned types instead of silently returning None—either raise a
clear exception when encountering a MagicMock/ref type you don't expect or
assert that ref resolves to a primitive/collection and include a descriptive
error message so tests fail loudly when ray.get is used on unsupported objects.
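A minimal sketch of the stricter helper this prompt describes is shown below. It assumes that no call site in the async path relies on the current `None` fallback; if one does, the allowed types would need to be widened. `strict_mock_ray_get` and `_ALLOWED_TYPES` are hypothetical names.

```python
from unittest.mock import MagicMock

# Plain values the stubs are expected to hand back (mirrors the original check).
_ALLOWED_TYPES = (int, str, dict, list)


def strict_mock_ray_get(ref):
    """Stand-in for ray.get that fails loudly on unexpected objects."""
    if isinstance(ref, _ALLOWED_TYPES):
        return ref
    raise TypeError(
        "mock ray.get received an unsupported object of type "
        f"{type(ref).__name__}; expected one of "
        f"{[t.__name__ for t in _ALLOWED_TYPES]}"
    )


plain_value = strict_mock_ray_get({"size": 10})

try:
    strict_mock_ray_get(MagicMock())
    magicmock_rejected = False
except TypeError:
    magicmock_rejected = True
```

In the fixture it would be wired in the same way as the existing helper, via `patch("ray.get", side_effect=strict_mock_ray_get)`.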