Skip to content

fix: Handle disabled validation in SFT training#1611

Merged
terrykong merged 2 commits intoNVIDIA-NeMo:mainfrom
sahgerlad:fix/sft-disabled-validation
Dec 19, 2025
Merged

fix: Handle disabled validation in SFT training#1611
terrykong merged 2 commits intoNVIDIA-NeMo:mainfrom
sahgerlad:fix/sft-disabled-validation

Conversation

@sahgerlad
Copy link
Contributor

@sahgerlad sahgerlad commented Dec 8, 2025

What does this PR do ?

Handle disabled validation in SFT training

Summary by CodeRabbit

  • Bug Fixes
    • Improved robustness when validation data is not provided
    • Fixed potential errors in token counting and validation timing metric calculations
    • Enhanced validation configuration error messages

✏️ Tip: You can customize this high-level summary in your review settings.

@sahgerlad sahgerlad requested a review from a team as a code owner December 8, 2025 18:23
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Dec 8, 2025

📝 Walkthrough

Walkthrough

This pull request improves robustness in the SFT algorithm implementation by conditionally instantiating the validation dataloader only when a validation dataset exists, adjusting validation flow assertions, and using safer dictionary access patterns with defaults to prevent KeyErrors when metrics are unavailable.

Changes

Cohort / File(s) Summary
Validation dataloader and metrics robustness
nemo_rl/algorithms/sft.py
Conditional creation of val_dataloader (only if val_dataset is not None); updated validation assertion to reference sft.val_period instead of dpo.val_period; replaced direct dictionary indexing with .get() calls for global_valid_toks metric access to prevent KeyErrors in token counting and timing computation

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Defensive programming improvements: All .get() substitutions follow a consistent pattern and are low-risk
  • Conditional dataloader logic: Straightforward null-check guard with no complex downstream dependencies
  • Assertion correction: Verify that the assertion change (from dpo.val_period to sft.val_period) aligns with the correct module context and intended validation behavior

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'fix: Handle disabled validation in SFT training' directly and clearly describes the main change: handling disabled validation in SFT training, which aligns with the PR objectives and file modifications.
Test Results For Major Changes ✅ Passed PR contains only minor changes: bug fixes (incorrect variable reference) and defensive programming (None checks, safe dictionary access). No new features, breaking changes, or performance/convergence impacts.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/sft.py (1)

234-252: Ensure validate() always returns a 2‑tuple when called, even when val_dataloader is None

With the new early return:

if val_dataloader is None:
    assert master_config["sft"]["val_period"] <= 0, (
        "val_dataloader is None, so sft.val_period must be <= 0"
    )
    print("  ⚠️ No validation dataloader provided, skipping validation")
    return

validate() returns None when val_dataloader is None and sft.val_period <= 0. In sft_train() you always unpack:

val_metrics, validation_timings = validate(...)

This will raise TypeError: cannot unpack non-iterable NoneType object in configurations like:

  • val_dataset is None
  • sft.val_period <= 0 (validation disabled)
  • sft.val_at_start is True (so the initial validation call runs)

To make this path safe while preserving the assertion for misconfigurations, have validate() always return a 2‑tuple when it returns normally, e.g.:

-    if val_dataloader is None:
-        assert master_config["sft"]["val_period"] <= 0, (
-            "val_dataloader is None, so sft.val_period must be <= 0"
-        )
-        print("  ⚠️ No validation dataloader provided, skipping validation")
-        return
+    if val_dataloader is None:
+        assert master_config["sft"]["val_period"] <= 0, (
+            "val_dataloader is None, so sft.val_period must be <= 0"
+        )
+        print("  ⚠️ No validation dataloader provided, skipping validation")
+        # Return empty metric dicts so callers that unpack and log still work.
+        return {}, {}

Callers then safely unpack and log, and you still get a hard failure if val_period > 0 with no validation dataloader.

🧹 Nitpick comments (2)
nemo_rl/algorithms/sft.py (2)

146-162: Conditional val_dataloader creation is good; update typing to reflect Optional usage

Making val_dataloader conditional on val_dataset is not None correctly avoids constructing a dataloader when validation data is absent. However, the type hints now no longer match behavior:

  • setup() still annotates val_dataset: AllTaskProcessedDataset and the 4th return element as StatefulDataLoader, but it can now be None.
  • validate() still takes val_dataloader: StatefulDataLoader, though it explicitly handles None.

To keep static typing accurate, consider:

-def setup(
-    master_config: MasterConfig,
-    tokenizer: AutoTokenizer,
-    train_dataset: AllTaskProcessedDataset,
-    val_dataset: AllTaskProcessedDataset,
-) -> tuple[
-    Policy,
-    RayVirtualCluster,
-    StatefulDataLoader,
-    StatefulDataLoader,
+def setup(
+    master_config: MasterConfig,
+    tokenizer: AutoTokenizer,
+    train_dataset: AllTaskProcessedDataset,
+    val_dataset: Optional[AllTaskProcessedDataset],
+) -> tuple[
+    Policy,
+    RayVirtualCluster,
+    StatefulDataLoader,
+    Optional[StatefulDataLoader],

and in validate:

-def validate(
-    policy: PolicyInterface,
-    val_dataloader: StatefulDataLoader,
+def validate(
+    policy: PolicyInterface,
+    val_dataloader: Optional[StatefulDataLoader],

This keeps annotations aligned with the new “validation may be disabled” behavior.


612-618: Guard valid_tokens_per_sec_per_gpu against zero total_time

Using .get here avoids a KeyError when global_valid_toks is missing:

timing_metrics["valid_tokens_per_sec_per_gpu"] = (
    metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
)

However, if total_time is 0 (e.g., very fast steps or missing timer entries), this will still raise ZeroDivisionError. You already guard against this when printing percentages above, so you can mirror that pattern:

-            timing_metrics["valid_tokens_per_sec_per_gpu"] = (
-                metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
-            )
+            if total_time > 0:
+                timing_metrics["valid_tokens_per_sec_per_gpu"] = (
+                    metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
+                )
+            else:
+                timing_metrics["valid_tokens_per_sec_per_gpu"] = 0.0

This keeps logging robust even in edge timing cases.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 64ab08d and 25a0edd.

📒 Files selected for processing (1)
  • nemo_rl/algorithms/sft.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/algorithms/sft.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/algorithms/sft.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/algorithms/sft.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/algorithms/sft.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/sft.py (1)
nemo_rl/data/collate_fn.py (1)
  • rl_collate_fn (29-73)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/sft.py (1)

488-503: Using metrics.get("global_valid_toks", 0) for accumulation is a solid robustness improvement

Switching from direct indexing to:

total_valid_tokens += metrics.get("global_valid_toks", 0)

prevents a KeyError when global_valid_toks is absent from train_results["all_mb_metrics"], while preserving the previous behavior when it is present. This also keeps backward compatibility with older checkpoints via the default initialization of total_valid_tokens.

No further changes needed here.

Copy link
Collaborator

@terrykong terrykong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ashors1 to review

@terrykong terrykong requested a review from ashors1 December 10, 2025 19:34
@sahgerlad sahgerlad force-pushed the fix/sft-disabled-validation branch 5 times, most recently from 71f4c13 to 74af223 Compare December 12, 2025 07:25
@sahgerlad sahgerlad requested a review from a team as a code owner December 12, 2025 18:26
@sahgerlad sahgerlad force-pushed the fix/sft-disabled-validation branch 2 times, most recently from ea2aa6a to c3ebd12 Compare December 12, 2025 19:09
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Tests that SFT training works correctly when validation is disabled
by passing val_dataloader=None with val_period set to 0 or negative.

Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
@sahgerlad sahgerlad force-pushed the fix/sft-disabled-validation branch from c3ebd12 to 5368ed4 Compare December 15, 2025 15:14
Copy link
Contributor

@ashors1 ashors1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! Thanks for the PR!

@terrykong terrykong added the CI:L1 Run doctests, unit tests, and functional tests label Dec 18, 2025
@terrykong terrykong enabled auto-merge (squash) December 18, 2025 23:18
@terrykong terrykong merged commit 4794ca7 into NVIDIA-NeMo:main Dec 19, 2025
39 of 41 checks passed
@sahgerlad sahgerlad deleted the fix/sft-disabled-validation branch December 19, 2025 03:02
DeL-TaiseiOzaki pushed a commit to DeL-TaiseiOzaki/RL that referenced this pull request Jan 8, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
parthmannan pushed a commit to parthmannan/RL that referenced this pull request Jan 15, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Signed-off-by: Parth Mannan <pmannan@nvidia.com>
xavier-owkin pushed a commit to owkin/Owkin-NeMo-RL that referenced this pull request Feb 10, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 12, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
seonjinn pushed a commit that referenced this pull request Mar 9, 2026
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests community-request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants