fix: stop jobs after timeout and add warning for validation#1069

Merged
terrykong merged 9 commits into NVIDIA-NeMo:main from
wedu-nvidia:wedu/stop-gpu-job-timeout
Sep 13, 2025

Conversation

@wedu-nvidia
Contributor

@wedu-nvidia wedu-nvidia commented Sep 4, 2025

What does this PR do ?

Stop job training after the timeout is reached, and add a warning message for validation when val_period > 0 but val_dataloader is None

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Time-based checkpointing with auto-save and early stop on timeout.
    • GRPO adds max_num_epochs, richer checkpoint state, and detailed per-epoch/step metrics.
    • Config adds checkpoint_must_save_by to control checkpoint timing.
  • Refactor

    • GRPO training loop reworked to be epoch-driven with integrated validation and checkpointing.
  • Bug Fixes

    • Enforced val_period=0 when no validation data to avoid inconsistent validation.
    • Early return after timeout-triggered saves to prevent extra training steps.
  • Tests

    • RM tests updated to include the new checkpoint config key.
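The validation fix above boils down to a setup-time assertion. A minimal sketch of that rule follows; it is illustrative only, and check_validation_config is a hypothetical helper, not the actual nemo_rl function:

```python
import warnings

def check_validation_config(val_dataloader, val_period):
    """Fail fast when validation is configured inconsistently.

    Hypothetical helper mirroring the PR's rule: no validation data
    means val_period must be 0.
    """
    if val_dataloader is None:
        assert val_period == 0, (
            f"val_period is {val_period} but no val_dataloader was provided; "
            "set val_period=0 or supply validation data"
        )
    elif val_period <= 0:
        warnings.warn(
            "val_dataloader provided but val_period <= 0; validation will never run"
        )

check_validation_config(None, 0)  # consistent: no data, validation disabled
try:
    check_validation_config(None, 10)  # inconsistent: raises AssertionError
except AssertionError as e:
    print("caught:", e)
```

Failing fast here is preferable to the previous behavior, where a non-zero val_period with no dataloader could silently skip or crash validation mid-run.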

Signed-off-by: Wei Du <wedu@nvidia.com>
Collaborator

@terrykong terrykong left a comment


Signed-off-by: Wei Du <wedu@nvidia.com>
@wedu-nvidia
Contributor Author

@terrykong thanks, I added the timeout feature for RM as well.

Signed-off-by: Wei Du <wedu@nvidia.com>
terrykong
terrykong previously approved these changes Sep 4, 2025
@terrykong terrykong enabled auto-merge September 4, 2025 16:54
@terrykong terrykong added this pull request to the merge queue Sep 4, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Sep 4, 2025
Signed-off-by: Wei Du <wedu@nvidia.com>
@wedu-nvidia
Contributor Author

@terrykong I pushed another commit to fix the unit-test bug; can you add it to the merge queue again? Thanks.

Signed-off-by: Wei Du <wedu@nvidia.com>
@coderabbitai
Contributor

coderabbitai bot commented Sep 10, 2025

Walkthrough

Adds a timeout-based checkpointing key and TimeoutChecker integration; enforces validation-period assertions when no val dataloader is provided; introduces early-return-on-timeout in training loops (RM, SFT, DPO); overhauls GRPO to epoch-driven training with expanded save state, logging, and Megatron train_iters propagation; tests and example config updated.

Changes

Cohort / File(s) / Summary of Changes

  • Config: RM checkpointing (examples/configs/rm.yaml)
    Added rm.checkpointing.checkpoint_must_save_by: null.
  • RM: Timeout-based checkpointing (nemo_rl/algorithms/rm.py)
    Added TimeoutChecker usage with start/mark iteration calls; combined step- and timeout-based save logic; updates and saves rm_save_state; prunes stale val keys; returns early when a timeout-triggered save occurs; added an assertion tying an absent val dataloader to val_period.
  • DPO: Validation assertion & timeout exit (nemo_rl/algorithms/dpo.py)
    Asserts that val_dataloader is None implies master_config["dpo"]["val_period"] == 0; returns early from dpo_train, after updating total_steps, when a timeout-based save is triggered.
  • SFT: Validation assertion & timeout exit (nemo_rl/algorithms/sft.py)
    Same validation assertion as DPO; returns early from sft_train after a step when a timeout-based save is triggered.
  • GRPO: Epoch-driven training & state (nemo_rl/algorithms/grpo.py)
    Added max_num_epochs to GRPOConfig; replaced the single-step save state with consumed_samples, current_step, current_epoch, and total_steps; switched to an epoch-driven loop with epoch/step/total tracking; converted message types and generation flow into prepare/compute/train phases; propagated train_iters to Megatron; richer logging, metrics, and checkpointing; guarded validation when no dataloader is provided and val_period is non-zero.
  • Tests: RM config (tests/unit/algorithms/test_rm.py)
    Extended master_config['rm']['checkpointing'] with checkpoint_must_save_by: None.
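The shared RM/SFT/DPO pattern (mark each iteration, save on timeout, then return early) can be sketched as below. The TimeoutChecker here is a simplified stand-in written for this example; nemo_rl.utils.timer.TimeoutChecker exposes the same method names (start_iterations, mark_iteration, check_save) but its internals may differ, and the train function is a hypothetical loop:

```python
import time

class TimeoutChecker:
    """Simplified stand-in: signals a save when the remaining budget
    likely cannot fit another iteration before the deadline."""

    def __init__(self, timeout, fit_last_save_time=True):
        self.timeout = timeout          # seconds; None disables the check
        self.fit_last_save_time = fit_last_save_time
        self.last_iter_time = 0.0

    def start_iterations(self):
        self.start = time.monotonic()
        self.last_mark = self.start

    def mark_iteration(self):
        now = time.monotonic()
        self.last_iter_time = now - self.last_mark
        self.last_mark = now

    def check_save(self):
        if self.timeout is None:
            return False
        elapsed = time.monotonic() - self.start
        # Leave headroom for one more iteration's worth of time.
        margin = self.last_iter_time if self.fit_last_save_time else 0.0
        return elapsed + margin >= self.timeout

def train(max_steps, save_period, timeout_s):
    saved, timeout = [], TimeoutChecker(timeout_s)
    timeout.start_iterations()
    for step in range(1, max_steps + 1):
        # ... one training step would run here ...
        timeout.mark_iteration()
        if timeout.check_save():
            saved.append(("timeout", step))
            return saved  # early return: stop the job after the save
        if step % save_period == 0:
            saved.append(("step", step))
    return saved

print(train(max_steps=6, save_period=2, timeout_s=None))
# → [('step', 2), ('step', 4), ('step', 6)]
```

With timeout_s=None the loop degenerates to plain step-periodic saving; with a finite budget, the first timeout-triggered save ends the job, which is the "stop jobs after timeout" behavior in the PR title.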

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer as RM/SFT/DPO Trainer
  participant Timeout as TimeoutChecker
  participant CKPT as Checkpoint Manager

  Note over Trainer: Per-step loop
  Trainer->>Timeout: mark_iteration()
  Trainer->>Timeout: check_save()
  alt Timeout-triggered save
    Timeout-->>Trainer: should_save_by_timeout = true
    Trainer->>CKPT: save_checkpoint(timeout-based)
    CKPT-->>Trainer: saved
    Note over Trainer: Early return (exit training)
  else Step-periodic save
    Trainer->>CKPT: save_checkpoint(step-based)
    CKPT-->>Trainer: saved
  end
sequenceDiagram
  autonumber
  participant GRPO as GRPO Trainer
  participant Data as Dataloader
  participant Policy as Policy Model
  participant Ref as Reference Model
  participant CKPT as Checkpoint Manager
  participant Log as Logger

  Note over GRPO: while current_epoch < max_num_epochs and total_steps < max_num_steps
  GRPO->>Data: next batch (epoch loop)
  GRPO->>Policy: prepare_for_lp_inference
  GRPO->>Policy: generate (sync/async)
  GRPO->>GRPO: compute rewards/advantages
  GRPO->>Policy: prepare_for_training
  par Forward/Backward
    GRPO->>Policy: train step
    GRPO->>Ref: fprop logprobs (as needed)
  end
  GRPO->>Log: metrics (loss, reward, tokens, FLOPS, etc.)
  GRPO->>CKPT: maybe save (period/timeout)
  alt End of epoch
    GRPO->>GRPO: increment current_epoch, reset current_step
  else Continue
    GRPO->>GRPO: increment current_step and total_steps
  end
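The control flow in the GRPO diagram can be sketched as a plain nested loop. This is an illustrative skeleton, not the actual grpo.py code; only the save-state keys (consumed_samples, current_step, current_epoch, total_steps) come from the summary above, and the function name is hypothetical:

```python
def epoch_driven_loop(num_batches_per_epoch, max_num_epochs, max_num_steps, batch_size):
    """Skeleton of an epoch-driven training loop with GRPO-style save state."""
    state = {"consumed_samples": 0, "current_step": 0,
             "current_epoch": 0, "total_steps": 0}
    while state["current_epoch"] < max_num_epochs and (
        max_num_steps == -1 or state["total_steps"] < max_num_steps
    ):
        for _ in range(num_batches_per_epoch):
            # ... generate, compute rewards/advantages, train step ...
            state["consumed_samples"] += batch_size
            state["current_step"] += 1
            state["total_steps"] += 1
            if max_num_steps != -1 and state["total_steps"] >= max_num_steps:
                return state  # step budget exhausted mid-epoch
        # End of epoch: advance the epoch counter, reset the step counter.
        state["current_epoch"] += 1
        state["current_step"] = 0
    return state

print(epoch_driven_loop(num_batches_per_epoch=5, max_num_epochs=3,
                        max_num_steps=12, batch_size=8))
# → {'consumed_samples': 96, 'current_step': 2, 'current_epoch': 2, 'total_steps': 12}
```

Keeping total_steps separate from current_step lets max_num_steps cap the whole run while per-epoch counters still reset cleanly, matching the diagram's "End of epoch" branch.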

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title "fix: stop jobs after timeout and add warning for validation" succinctly and accurately summarizes the PR's primary changes (timeout-driven job termination and a warning when validation is configured but no dataloader is provided), matches the PR objectives and diffs, and is a single clear sentence free of noise.

Poem

I tap my paw to ticking time,
Checkpoints bloom—then I resign! ⏱️
Epoch moons and steppy stars,
GRPO sails past token bars.
A rabbit signs off: “Ship it, fast!” 🐇✨



📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e9a88d and a15a0fe.

📒 Files selected for processing (1)
  • tests/unit/algorithms/test_rm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/unit/algorithms/test_rm.py

Signed-off-by: Wei Du <wedu@nvidia.com>
@wedu-nvidia
Contributor Author

@terrykong I’ve resolved the conflict. Could you please add this to the merge queue, since others are actively making updates that may touch the same files? Thanks!

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 08b1d22 and 6e9a88d.

📒 Files selected for processing (4)
  • examples/configs/rm.yaml (1 hunks)
  • nemo_rl/algorithms/dpo.py (2 hunks)
  • nemo_rl/algorithms/rm.py (5 hunks)
  • nemo_rl/algorithms/sft.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/configs/rm.yaml
  • nemo_rl/algorithms/dpo.py
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/algorithms/rm.py (2)
tests/unit/utils/test_timer.py (3)
  • timer (26-27)
  • TestTimeoutChecker (193-235)
  • test_double_save_prevented (204-208)
nemo_rl/utils/timer.py (4)
  • TimeoutChecker (264-321)
  • start_iterations (310-311)
  • mark_iteration (313-321)
  • check_save (284-308)
🔇 Additional comments (3)
nemo_rl/algorithms/sft.py (1)

584-585: LGTM! Timeout-based early exit is properly implemented.

The early return after timeout-based checkpoint saving ensures that training stops gracefully when the timeout is reached, which aligns with the PR objectives to stop jobs after timeout.

nemo_rl/algorithms/rm.py (2)

432-436: LGTM! TimeoutChecker integration follows best practices.

The timeout checker is properly initialized with the configuration value and fit_last_save_time enabled, which intelligently considers average iteration time when determining if a timeout is approaching.


633-634: LGTM! Timeout-based early termination is correctly implemented.

The early return after timeout ensures the training loop exits gracefully after saving a checkpoint, consistent with the implementation in the SFT module.

terrykong
terrykong previously approved these changes Sep 10, 2025
@terrykong terrykong enabled auto-merge September 10, 2025 15:42
@wedu-nvidia
Contributor Author

@terrykong, I see auto-merge was enabled, so why did it still get stuck here?

@terrykong terrykong added this pull request to the merge queue Sep 12, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Sep 12, 2025
@terrykong
Collaborator

@wedu-nvidia there was a test failure:

____________________________ test_exit_on_max_steps ____________________________

mock_components = {'checkpointer': <MagicMock id='140307635541616'>, 'logger': <MagicMock id='140303769157504'>, 'loss_fn': <nemo_rl.alg..._batch_size': 1}, 'rm': {'max_num_epochs': 2, 'max_num_steps': 12, 'val_at_start': False, 'val_batches': 1, ...}}, ...}

    def test_exit_on_max_steps(mock_components):
        """Test that training loop exits when max_num_steps is reached"""
        # Set max steps to 12, which is less than len(train_dataloader) * max_num_epochs
        mock_components["master_config"]["rm"]["max_num_steps"] = 12
    
        rm_save_state = _default_rm_save_state()
    
        # Run training
>       rm_train(
            mock_components["policy"],
            mock_components["train_dataloader"],
            mock_components["val_dataloader"],
            mock_components["tokenizer"],
            mock_components["loss_fn"],
            mock_components["master_config"],
            mock_components["logger"],
            mock_components["rm_task_spec"],
            mock_components["checkpointer"],
            rm_save_state,
        )

unit/algorithms/test_rm.py:129: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

policy = <MagicMock id='140307632684832'>
train_dataloader = <MagicMock spec='StatefulDataLoader' id='140306801623696'>
val_dataloader = <MagicMock spec='StatefulDataLoader' id='140307632837728'>
tokenizer = <MagicMock id='140306801922016'>
loss_fn = <nemo_rl.algorithms.loss_functions.PreferenceLoss object at 0x7f9b045209e0>
master_config = {'checkpointing': {'checkpoint_must_save_by': None, 'enabled': False}, 'policy': {'make_sequence_length_divisible_by':..._micro_batch_size': 1}, 'rm': {'max_num_epochs': 2, 'max_num_steps': 12, 'val_at_start': False, 'val_batches': 1, ...}}
logger = <MagicMock id='140303769157504'>
rm_task_spec = <MagicMock id='140307635545216'>
checkpointer = <MagicMock id='140307635541616'>
rm_save_state = {'consumed_samples': 1, 'epoch': 0, 'step': 0, 'total_steps': 0}

    def rm_train(
        policy,
        train_dataloader,
        val_dataloader,
        tokenizer,
        loss_fn,
        master_config,
        logger,
        rm_task_spec,
        checkpointer,
        rm_save_state,
    ):
        # Run basic rm training
        timer = Timer()
        timeout = TimeoutChecker(
            timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
            fit_last_save_time=True,
        )
        timeout.start_iterations()
        if rm_save_state is None:
            rm_save_state = _default_rm_save_state()
            current_epoch = 0
            current_step = 0
            total_steps = 0
        else:
            current_epoch = rm_save_state["epoch"]
            current_step = rm_save_state["step"]
            total_steps = rm_save_state["total_steps"]
    
        rm_config = master_config["rm"]
        # Validation configuration
        val_period = rm_config["val_period"]
        val_at_start = rm_config["val_at_start"]
        max_num_epochs = rm_config["max_num_epochs"]
    
        # Run validation at the start if configured
        if val_at_start and total_steps == 0:
            print("\n🔍 Running initial validation...")
            val_metrics, validation_timings = validate(
                policy,
                val_dataloader,
                tokenizer,
                loss_fn,
                step=0,
                master_config=master_config,
                val_batches=rm_config["val_batches"],
                val_batch_size=rm_config["val_global_batch_size"],
                val_mbs=rm_config["val_micro_batch_size"],
                logger=logger,
            )
    
        policy.prepare_for_training()
    
        while current_epoch < max_num_epochs and (
            master_config["rm"]["max_num_steps"] == -1
            or total_steps < master_config["rm"]["max_num_steps"]
        ):
            print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}")
    
            for batch in train_dataloader:
                print(
                    f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['rm']['max_num_steps'] if master_config['rm']['max_num_steps'] != -1 else len(train_dataloader))} {'=' * 25}"
                )
                maybe_gpu_profile_step(policy, total_steps + 1)
                val_metrics, validation_timings = None, None
    
                with timer.time("total_step_time"):
                    # Prepare batch and generate responses
                    print("▶ Taking a training step...")
    
                    train_results = policy.train(
                        batch,
                        loss_fn,
                        eval_mode=False,
                        ## NOTE: we double the batch size here because each preference example corresponds to a pair of
                        ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
                        gbs=master_config["policy"]["train_global_batch_size"] * 2,
                        mbs=master_config["policy"]["train_micro_batch_size"] * 2,
                    )
    
                    is_last_step = (
                        master_config["rm"]["max_num_steps"] != -1
                        and total_steps + 1 >= master_config["rm"]["max_num_steps"]
                    ) or (
                        current_epoch + 1 == max_num_epochs
                        and current_step + 1 == len(train_dataloader)
                    )
    
                    # Run validation if it's a validation step
                    if val_period > 0 and (total_steps + 1) % val_period == 0:
                        val_metrics, validation_timings = validate(
                            policy,
                            val_dataloader,
                            tokenizer,
                            loss_fn,
                            step=total_steps + 1,
                            master_config=master_config,
                            val_batches=rm_config["val_batches"],
                            val_batch_size=rm_config["val_global_batch_size"],
                            val_mbs=rm_config["val_micro_batch_size"],
                            logger=logger,
                        )
    
                    ## Checkpointing
                    timeout.mark_iteration()
    
                    rm_save_state["consumed_samples"] += master_config["policy"][
                        "train_global_batch_size"
                    ]
    
                    should_save_by_step = (
                        is_last_step
>                       or (total_steps + 1) % master_config["checkpointing"]["save_period"]
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        == 0
                    )
E                   KeyError: 'save_period'

../nemo_rl/algorithms/rm.py:530: KeyError
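The KeyError comes from indexing checkpointing["save_period"] when the test's checkpointing config omits that key. One defensive pattern (illustrative only; the PR's actual fix may differ) is to read optional config keys with a default, so a missing or disabled save_period simply means "no periodic saves":

```python
def should_save_by_step(total_steps, checkpointing_cfg, is_last_step):
    """Decide whether to save at this step, tolerating a missing save_period key."""
    save_period = checkpointing_cfg.get("save_period")  # None if absent
    periodic = (
        save_period is not None
        and save_period > 0
        and (total_steps + 1) % save_period == 0
    )
    return is_last_step or periodic

cfg = {"checkpoint_must_save_by": None, "enabled": False}  # no save_period key
print(should_save_by_step(4, cfg, is_last_step=False))            # no KeyError
print(should_save_by_step(4, {"save_period": 5}, is_last_step=False))
```

Using dict.get keeps the training loop robust to partial configs like the test fixture above, instead of requiring every caller to define save_period.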

@wedu-nvidia
Contributor Author

@terrykong Thanks for the info, I fixed the bug.

@terrykong terrykong enabled auto-merge September 12, 2025 15:46
@terrykong terrykong added this pull request to the merge queue Sep 12, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Sep 12, 2025
@wedu-nvidia
Contributor Author

@terrykong it failed due to some network issues, maybe? Can you please take a look when you have time?

@terrykong terrykong added this pull request to the merge queue Sep 12, 2025
Merged via the queue into NVIDIA-NeMo:main with commit 94a3d49 Sep 13, 2025
24 checks passed
guyueh1 pushed a commit to guyueh1/NeMo-RL that referenced this pull request Sep 15, 2025
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025