Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/distillation_math.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ distillation:
val_batch_size: 64
val_period: 20
val_at_start: false
val_at_end: false
max_val_samples: 512
topk_logits_k: 64
seed: 42
Expand Down
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dpo:
val_global_batch_size: 8
val_micro_batch_size: 1
val_at_start: true
val_at_end: false
seed: 42

reference_policy_kl_penalty: 0.05
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
val_at_end: false
overlong_filtering: false
max_val_samples: 256
val_batch_size: 256
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ grpo:
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
val_at_end: false
max_val_samples: 256
val_batch_size: 256
async_grpo:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ rm:
val_global_batch_size: 32
val_micro_batch_size: 1
val_at_start: false
val_at_end: false
seed: 42

checkpointing:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ sft:
val_global_batch_size: 32
val_micro_batch_size: 1
val_at_start: true
val_at_end: false
seed: 42

checkpointing:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ sft:
val_global_batch_size: 128
val_micro_batch_size: 2
val_at_start: true
val_at_end: false
seed: 42

checkpointing:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ sft:
val_global_batch_size: 128
val_micro_batch_size: 1
val_at_start: true
val_at_end: false
seed: 42

checkpointing:
Expand Down
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ grpo:
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
val_at_end: false
overlong_filtering: false
max_val_samples: 256
val_batch_size: 256
Expand Down
1 change: 1 addition & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ grpo:
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
val_at_end: false
overlong_filtering: false
max_val_samples: 256
val_batch_size: 256
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ grpo:
use_leave_one_out_baseline: true
val_period: 10
val_at_start: true
val_at_end: false
overlong_filtering: false
max_val_samples: null # inferred from size of val dataset. for multi evals, repeat val ds via `num_repeats` in `ng_prepare_data`.
val_batch_size: null
Expand Down
16 changes: 13 additions & 3 deletions nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class DistillationConfig(TypedDict):
val_batch_size: int
val_period: int
val_at_start: bool
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics; get_best_checkpoint_path() skips checkpoints
# that lack the metric when ranking, so the final checkpoint could otherwise be ignored.
val_at_end: bool
max_val_samples: int
topk_logits_k: int
seed: int
Expand Down Expand Up @@ -257,7 +260,11 @@ def setup(
# Load validation dataset if provided
val_dataloader: Optional[StatefulDataLoader] = None
# If validation is enabled, load the validation dataloader
if distillation_config["val_period"] > 0 or distillation_config["val_at_start"]:
if (
distillation_config["val_period"] > 0
or distillation_config["val_at_start"]
or distillation_config["val_at_end"]
):
assert val_dataset is not None, (
"Validation dataset is required if validation is enabled"
)
Expand Down Expand Up @@ -539,6 +546,7 @@ def distillation_train(
total_valid_tokens = distillation_save_state["total_valid_tokens"]
val_period = master_config["distillation"]["val_period"]
val_at_start = master_config["distillation"]["val_at_start"]
val_at_end = master_config["distillation"]["val_at_end"]
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
max_epochs = master_config["distillation"][
"max_num_epochs"
Expand Down Expand Up @@ -721,8 +729,10 @@ def distillation_train(
and (current_step + 1 == len(dataloader))
)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
# Run validation if it's a validation step or last step with val_at_end
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
val_at_end and is_last_step
):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
student_policy, student_generation, colocated_inference
Expand Down
10 changes: 8 additions & 2 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class DPOConfig(TypedDict):
val_global_batch_size: int
val_micro_batch_size: int
val_at_start: bool
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics; get_best_checkpoint_path() skips checkpoints
# that lack the metric when ranking, so the final checkpoint could otherwise be ignored.
val_at_end: bool
seed: int

reference_policy_kl_penalty: float
Expand Down Expand Up @@ -524,6 +527,7 @@ def dpo_train(
# Validation configuration
val_period = dpo_config["val_period"]
val_at_start = dpo_config["val_at_start"]
val_at_end = dpo_config["val_at_end"]
max_num_epochs = dpo_config["max_num_epochs"]

# Run validation at the start if configured
Expand Down Expand Up @@ -582,8 +586,10 @@ def dpo_train(
and current_step + 1 == len(train_dataloader)
)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
# Run validation if it's a validation step or last step with val_at_end
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
val_at_end and is_last_step
):
validation_result = validate(
policy,
val_dataloader,
Expand Down
22 changes: 18 additions & 4 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class GRPOConfig(TypedDict):
val_period: int
val_batch_size: int
val_at_start: bool
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics; get_best_checkpoint_path() skips checkpoints
# that lack the metric when ranking, so the final checkpoint could otherwise be ignored.
val_at_end: bool
max_val_samples: int
skip_reference_policy_logprobs_calculation: NotRequired[bool]
seed: int
Expand Down Expand Up @@ -279,7 +282,11 @@ def setup(
# Load validation dataset if provided
val_dataloader: Optional[StatefulDataLoader] = None
# If validation is enabled, load the validation dataloader
if grpo_config["val_period"] > 0 or grpo_config["val_at_start"]:
if (
grpo_config["val_period"] > 0
or grpo_config["val_at_start"]
or grpo_config["val_at_end"]
):
assert val_dataset is not None, (
"Validation dataset is required if validation is enabled"
)
Expand Down Expand Up @@ -1157,6 +1164,7 @@ def grpo_train(
"total_valid_tokens", 0
) # total valid tokens processed across all epochs; default to 0 for backward compatibility with older checkpoints
val_at_start = master_config["grpo"]["val_at_start"]
val_at_end = master_config["grpo"]["val_at_end"]
val_period = master_config["grpo"]["val_period"]
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

Expand Down Expand Up @@ -1564,8 +1572,10 @@ def grpo_train(
and (current_step + 1 == len(dataloader))
)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
# Run validation if it's a validation step or last step with val_at_end
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
val_at_end and is_last_step
):
memory_tracker.snapshot_start_of_stage("Validation", dir())
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
Expand Down Expand Up @@ -2135,6 +2145,7 @@ def async_grpo_train(
) # Default to 0 for backward compatibility with older checkpoints
val_period = master_config["grpo"]["val_period"]
val_at_start = master_config["grpo"]["val_at_start"]
val_at_end = master_config["grpo"]["val_at_end"]
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

assert not colocated_inference, (
Expand Down Expand Up @@ -2574,7 +2585,10 @@ def async_grpo_train(
val_metrics, validation_timings = None, None
is_last_step = step + 1 == master_config["grpo"]["max_num_steps"]

if val_period > 0 and (step + 1) % val_period == 0:
# Run validation if it's a validation step or last step with val_at_end
if (val_period > 0 and (step + 1) % val_period == 0) or (
val_at_end and is_last_step
):
# Pause trajectory collection during validation to reduce memory pressure
trajectory_collector.pause.remote()

Expand Down
10 changes: 8 additions & 2 deletions nemo_rl/algorithms/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class RMConfig(TypedDict):
val_global_batch_size: int
val_micro_batch_size: int
val_at_start: bool
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics; get_best_checkpoint_path() skips checkpoints
# that lack the metric when ranking, so the final checkpoint could otherwise be ignored.
val_at_end: bool
seed: int


Expand Down Expand Up @@ -459,6 +462,7 @@ def rm_train(
# Validation configuration
val_period = rm_config["val_period"]
val_at_start = rm_config["val_at_start"]
val_at_end = rm_config["val_at_end"]
max_num_epochs = rm_config["max_num_epochs"]

# Run validation at the start if configured
Expand Down Expand Up @@ -515,8 +519,10 @@ def rm_train(
and current_step + 1 == len(train_dataloader)
)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
# Run validation if it's a validation step or last step with val_at_end
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
val_at_end and is_last_step
):
val_metrics, validation_timings = validate(
policy,
val_dataloader,
Expand Down
10 changes: 8 additions & 2 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class SFTConfig(TypedDict):
val_global_batch_size: int
val_micro_batch_size: int
val_at_start: bool
# Whether to run validation on the last training step. Setting this to True ensures the
# final checkpoint has validation metrics; get_best_checkpoint_path() skips checkpoints
# that lack the metric when ranking, so the final checkpoint could otherwise be ignored.
val_at_end: bool
seed: int


Expand Down Expand Up @@ -385,6 +388,7 @@ def sft_train(
# Validation configuration
val_period = sft_config["val_period"]
val_at_start = sft_config["val_at_start"]
val_at_end = sft_config["val_at_end"]
max_num_epochs = sft_config["max_num_epochs"]

# Run validation at the start if configured
Expand Down Expand Up @@ -465,8 +469,10 @@ def sft_train(
and current_step + 1 == len(train_dataloader)
)

# Run validation if it's a validation step
if val_period > 0 and (total_steps + 1) % val_period == 0:
# Run validation if it's a validation step or last step with val_at_end
if (val_period > 0 and (total_steps + 1) % val_period == 0) or (
val_at_end and is_last_step
):
val_metrics, validation_timings = validate(
policy,
val_dataloader,
Expand Down
30 changes: 23 additions & 7 deletions nemo_rl/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,25 +230,41 @@ def get_best_checkpoint_path(self) -> Optional[str]:
"""Get the path to the best checkpoint based on the metric.

Returns the path to the checkpoint with the best metric value. If no checkpoints
exist, returns None. If the metric isn't found, we warn and return the latest checkpoint.
exist, returns None. If some checkpoints are missing the metric, they are filtered
out with a warning. If no checkpoints have the metric, returns the latest checkpoint.

Returns:
Optional[str]: Path to the best checkpoint, or None if no valid checkpoints exist.
Optional[str]: Path to the best checkpoint, or None if no checkpoints exist.
"""
checkpoint_history = _load_checkpoint_history(self.checkpoint_dir)
if len(checkpoint_history) == 0:
return None
# sort by metric value
if self.metric_name not in checkpoint_history[0][2]:

# Filter checkpoints that have the metric
valid_checkpoints = [c for c in checkpoint_history if self.metric_name in c[2]]
ignored_count = len(checkpoint_history) - len(valid_checkpoints)

if ignored_count > 0:
ignored_steps = [
c[0] for c in checkpoint_history if self.metric_name not in c[2]
]
warnings.warn(
f"Ignoring {ignored_count} checkpoint(s) at step(s) {ignored_steps} that do not have "
f"metric '{self.metric_name}'. Consider enabling val_at_end or adjusting val_period "
f"to align with max_steps."
)

if len(valid_checkpoints) == 0:
warnings.warn(
f"Metric {self.metric_name} not found in checkpoint history. Returning last"
f"No checkpoints contain metric '{self.metric_name}'. Returning latest checkpoint. "
f"Consider enabling val_at_end or adjusting val_period to align with max_steps."
)
return self.get_latest_checkpoint_path()

checkpoint_history.sort(
valid_checkpoints.sort(
key=lambda x: x[2][self.metric_name], reverse=self.higher_is_better
)
return str(checkpoint_history[0][1])
return str(valid_checkpoints[0][1])

def get_latest_checkpoint_path(self) -> Optional[str]:
"""Get the path to the latest checkpoint.
Expand Down
1 change: 1 addition & 0 deletions research/template_project/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ grpo:
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
val_at_end: false
overlong_filtering: false
max_val_samples: 256
val_batch_size: 256
Expand Down
1 change: 1 addition & 0 deletions tests/functional/test_converter_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def create_test_config() -> Dict[str, Any]:
"val_global_batch_size": 4,
"val_micro_batch_size": 2,
"val_at_start": False,
"val_at_end": False,
"seed": 42,
},
"checkpointing": {
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/algorithms/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def val_iter(self):
"val_period": 100,
"val_batch_size": 1,
"val_at_start": False,
"val_at_end": False,
"max_val_samples": 10,
"topk_logits_k": 64,
"num_prompts_per_step": 1,
Expand Down Expand Up @@ -444,6 +445,7 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
"num_prompts_per_step": 1, # Config extraction requires this key
"val_period": 0, # Config extraction requires this key
"val_at_start": False, # Config extraction requires this key
"val_at_end": False, # Config extraction requires this key
},
"data": {"shuffle": False},
"logger": {}, # Config extraction requires this key
Expand Down Expand Up @@ -516,6 +518,7 @@ def test_distillation_setup_non_colocated_smoke(monkeypatch):
"max_num_steps": 100,
"val_period": 0,
"val_at_start": False,
"val_at_end": False,
},
"data": {"shuffle": False},
"logger": {},
Expand Down Expand Up @@ -624,6 +627,7 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node():
"num_prompts_per_step": 1, # Config extraction requires this key
"val_period": 0, # Config extraction requires this key
"val_at_start": False, # Config extraction requires this key
"val_at_end": False, # Config extraction requires this key
},
"data": {"shuffle": False},
"logger": {}, # Config extraction requires this key
Expand Down
1 change: 1 addition & 0 deletions tests/unit/algorithms/test_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def val_iter(self):
"val_global_batch_size": 1,
"val_micro_batch_size": 1,
"val_at_start": False,
"val_at_end": False,
},
"policy": {
"train_global_batch_size": 2,
Expand Down
Loading
Loading