fix: Handle disabled validation in SFT training #1611
terrykong merged 2 commits into NVIDIA-NeMo:main from
Conversation
📝 Walkthrough

This pull request improves robustness in the SFT algorithm implementation by conditionally instantiating the validation dataloader only when a validation dataset exists, adjusting validation flow assertions, and using safer dictionary access patterns with defaults to prevent KeyErrors when metrics are unavailable.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (3 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/algorithms/sft.py (1)
234-252: Ensure `validate()` always returns a 2-tuple when called, even when `val_dataloader` is `None`

With the new early return:

```python
if val_dataloader is None:
    assert master_config["sft"]["val_period"] <= 0, (
        "val_dataloader is None, so sft.val_period must be <= 0"
    )
    print(" ⚠️ No validation dataloader provided, skipping validation")
    return
```
`validate()` returns `None` when `val_dataloader is None` and `sft.val_period <= 0`. In `sft_train()` you always unpack:

```python
val_metrics, validation_timings = validate(...)
```

This will raise `TypeError: cannot unpack non-iterable NoneType object` in configurations like:
- `val_dataset is None`
- `sft.val_period <= 0` (validation disabled)
- `sft.val_at_start is True` (so the initial validation call runs)

To make this path safe while preserving the assertion for misconfigurations, have `validate()` always return a 2-tuple when it returns normally, e.g.:

```diff
-    if val_dataloader is None:
-        assert master_config["sft"]["val_period"] <= 0, (
-            "val_dataloader is None, so sft.val_period must be <= 0"
-        )
-        print(" ⚠️ No validation dataloader provided, skipping validation")
-        return
+    if val_dataloader is None:
+        assert master_config["sft"]["val_period"] <= 0, (
+            "val_dataloader is None, so sft.val_period must be <= 0"
+        )
+        print(" ⚠️ No validation dataloader provided, skipping validation")
+        # Return empty metric dicts so callers that unpack and log still work.
+        return {}, {}
```

Callers then safely unpack and log, and you still get a hard failure if `val_period > 0` with no validation dataloader.
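The suggested contract can be sketched as a small self-contained example. This is an illustration, not the repository code: the function body is simplified, and in the real code `val_dataloader` would be a `StatefulDataLoader` and the period would come from `master_config["sft"]["val_period"]`.

```python
def validate(val_dataloader, val_period):
    """Sketch of the proposed contract: always return a (metrics, timings) 2-tuple."""
    if val_dataloader is None:
        # Validation is disabled; assert that the config agrees.
        assert val_period <= 0, "val_dataloader is None, so sft.val_period must be <= 0"
        # Empty metric dicts keep `a, b = validate(...)` unpacking safe.
        return {}, {}
    # Real validation would run here; dummy metrics stand in for the sketch.
    return {"val_loss": 0.0}, {"validation_total_s": 0.0}


# Callers can always unpack, even when validation is disabled:
val_metrics, validation_timings = validate(None, 0)
```

With `val_period > 0` and no dataloader, the assertion still fires, so misconfiguration remains a hard error rather than a silent skip.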
🧹 Nitpick comments (2)
nemo_rl/algorithms/sft.py (2)
146-162: Conditional `val_dataloader` creation is good; update typing to reflect `Optional` usage

Making `val_dataloader` conditional on `val_dataset is not None` correctly avoids constructing a dataloader when validation data is absent. However, the type hints no longer match behavior:

- `setup()` still annotates `val_dataset: AllTaskProcessedDataset` and the 4th return element as `StatefulDataLoader`, but both can now be `None`.
- `validate()` still takes `val_dataloader: StatefulDataLoader`, though it explicitly handles `None`.

To keep static typing accurate, consider:
```diff
-def setup(
-    master_config: MasterConfig,
-    tokenizer: AutoTokenizer,
-    train_dataset: AllTaskProcessedDataset,
-    val_dataset: AllTaskProcessedDataset,
-) -> tuple[
-    Policy,
-    RayVirtualCluster,
-    StatefulDataLoader,
-    StatefulDataLoader,
+def setup(
+    master_config: MasterConfig,
+    tokenizer: AutoTokenizer,
+    train_dataset: AllTaskProcessedDataset,
+    val_dataset: Optional[AllTaskProcessedDataset],
+) -> tuple[
+    Policy,
+    RayVirtualCluster,
+    StatefulDataLoader,
+    Optional[StatefulDataLoader],
```

and in `validate`:

```diff
-def validate(
-    policy: PolicyInterface,
-    val_dataloader: StatefulDataLoader,
+def validate(
+    policy: PolicyInterface,
+    val_dataloader: Optional[StatefulDataLoader],
```

This keeps annotations aligned with the new "validation may be disabled" behavior.
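As a quick sanity check of the `Optional` pattern, here is a minimal runnable sketch. The `StatefulDataLoader` class below is a placeholder stand-in so the example does not require torchdata installed; the real class lives in `torchdata.stateful_dataloader`, and `build_val_dataloader` is a hypothetical helper name.

```python
from typing import Optional


class StatefulDataLoader:
    """Placeholder stand-in for torchdata's StatefulDataLoader (illustration only)."""

    def __init__(self, dataset):
        self.dataset = dataset


def build_val_dataloader(val_dataset: Optional[list]) -> Optional[StatefulDataLoader]:
    """Only construct a validation dataloader when a validation dataset exists."""
    if val_dataset is None:
        return None
    return StatefulDataLoader(val_dataset)
```

A type checker such as mypy will then force every caller to handle the `None` case before using the dataloader, which is exactly the "validation may be disabled" contract.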
612-618: Guard `valid_tokens_per_sec_per_gpu` against zero `total_time`

Using `.get` here avoids a `KeyError` when `global_valid_toks` is missing:

```python
timing_metrics["valid_tokens_per_sec_per_gpu"] = (
    metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
)
```

However, if `total_time` is 0 (e.g., very fast steps or missing timer entries), this will still raise `ZeroDivisionError`. You already guard against this when printing percentages above, so you can mirror that pattern:

```diff
-    timing_metrics["valid_tokens_per_sec_per_gpu"] = (
-        metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
-    )
+    if total_time > 0:
+        timing_metrics["valid_tokens_per_sec_per_gpu"] = (
+            metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
+        )
+    else:
+        timing_metrics["valid_tokens_per_sec_per_gpu"] = 0.0
```

This keeps logging robust even in edge timing cases.
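A tiny standalone version of that guard (the helper name is invented for illustration; the real code computes this inline):

```python
def safe_tokens_per_sec_per_gpu(metrics: dict, total_time: float, total_num_gpus: int) -> float:
    """Compute valid tokens/sec/GPU, defaulting to 0.0 when timing data is unusable."""
    if total_time > 0:
        # .get with a default of 0 also covers a missing "global_valid_toks" key.
        return metrics.get("global_valid_toks", 0) / total_time / total_num_gpus
    # Zero or missing total_time: report 0.0 instead of raising ZeroDivisionError.
    return 0.0
```

With `total_time == 0` the helper reports `0.0`, so a logging call never takes down a training run over a degenerate timer reading.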
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
nemo_rl/algorithms/sft.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
nemo_rl/algorithms/sft.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
nemo_rl/algorithms/sft.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
nemo_rl/algorithms/sft.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
nemo_rl/algorithms/sft.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/sft.py (1)
nemo_rl/data/collate_fn.py (1)
rl_collate_fn (29-73)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/sft.py (1)
488-503: Using `metrics.get("global_valid_toks", 0)` for accumulation is a solid robustness improvement

Switching from direct indexing to:

```python
total_valid_tokens += metrics.get("global_valid_toks", 0)
```

prevents a `KeyError` when `global_valid_toks` is absent from `train_results["all_mb_metrics"]`, while preserving the previous behavior when it is present. This also keeps backward compatibility with older checkpoints via the default initialization of `total_valid_tokens`.

No further changes needed here.
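The difference is easy to demonstrate with a toy metrics stream containing a step where the key is absent (the values below are invented for illustration):

```python
all_mb_metrics = [
    {"global_valid_toks": 1024},
    {},  # e.g. a step restored from an older checkpoint without this key
    {"global_valid_toks": 512},
]

total_valid_tokens = 0
for metrics in all_mb_metrics:
    # metrics["global_valid_toks"] would raise KeyError on the second step;
    # .get with a default of 0 simply contributes nothing for that step.
    total_valid_tokens += metrics.get("global_valid_toks", 0)
```

When every step does carry the key, the accumulated total is identical to what direct indexing would have produced.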
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Tests that SFT training works correctly when validation is disabled by passing val_dataloader=None with val_period set to 0 or negative.

Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
ashors1 left a comment
Looks good to me! Thanks for the PR!
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Signed-off-by: Parth Mannan <pmannan@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do?
Handle disabled validation in SFT training
Summary by CodeRabbit