Skip to content
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@ def validate_one_dataset(
):
"""Run validation on one validation dataset."""
if val_dataloader is None:
assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
"val_dataloader is None, so dpo.val_period must be 0"
)
print(" ⚠️ No validation dataloader provided, skipping validation")
return

Expand Down Expand Up @@ -640,6 +643,8 @@ def dpo_train(
current_step += 1
total_steps += 1

if should_save_by_timeout:
return
if total_steps >= master_config["dpo"]["max_num_steps"]:
return

Expand Down
6 changes: 6 additions & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,9 @@ def grpo_train(

timer.reset()
step += 1

if should_save_by_timeout:
break
if step >= master_config["grpo"]["max_num_steps"]:
break

Expand All @@ -925,6 +928,9 @@ def validate(
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Run validation on the validation dataset."""
if val_dataloader is None:
assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
"val_dataloader is None, so dpo.val_period must be 0"
)
print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
return {}, {}

Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ def validate(
):
"""Run validation on the validation dataset."""
if val_dataloader is None:
assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
"val_dataloader is None, so dpo.val_period must be 0"
)
Comment thread
terrykong marked this conversation as resolved.
print(" ⚠️ No validation dataloader provided, skipping validation")
return

Expand Down Expand Up @@ -564,6 +567,8 @@ def sft_train(
current_step += 1
total_steps += 1

if should_save_by_timeout:
return
if total_steps >= master_config["sft"]["max_num_steps"]:
return

Expand Down