Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ checkpointing:
higher_is_better: false
keep_top_k: 3
save_period: ${rm.val_period}
checkpoint_must_save_by: null

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ def validate_one_dataset(
):
"""Run validation on one validation dataset."""
if val_dataloader is None:
assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
"val_dataloader is None, so dpo.val_period must be 0"
)
print(" ⚠️ No validation dataloader provided, skipping validation")
return

Expand Down Expand Up @@ -707,6 +710,8 @@ def dpo_train(
current_step += 1
total_steps += 1

if should_save_by_timeout:
return
if total_steps >= master_config["dpo"]["max_num_steps"]:
return

Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,8 @@ def grpo_train(
timer.reset()
current_step += 1
total_steps += 1
if should_save_by_timeout:
break
if total_steps >= max_num_steps:
break

Expand All @@ -978,6 +980,9 @@ def validate(
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Run validation on the validation dataset."""
if val_dataloader is None:
    assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
        "val_dataloader is None, so grpo.val_period must be 0"
    )
print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
return {}, {}

Expand Down
23 changes: 20 additions & 3 deletions nemo_rl/algorithms/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
from nemo_rl.utils.logger import Logger, LoggerConfig
from nemo_rl.utils.nsys import maybe_gpu_profile_step
from nemo_rl.utils.timer import Timer
from nemo_rl.utils.timer import TimeoutChecker, Timer


class RMSaveState(TypedDict):
Expand Down Expand Up @@ -305,6 +305,9 @@ def validate_one_dataset(
):
"""Run validation on one validation dataset."""
if val_dataloader is None:
    assert val_dataloader is not None or master_config["rm"]["val_period"] == 0, (
        "val_dataloader is None, so rm.val_period must be 0"
    )
print(" ⚠️ No validation dataloader provided, skipping validation")
return

Expand Down Expand Up @@ -426,7 +429,11 @@ def rm_train(
):
# Run basic rm training
timer = Timer()

timeout = TimeoutChecker(
timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
fit_last_save_time=True,
)
timeout.start_iterations()
if rm_save_state is None:
rm_save_state = _default_rm_save_state()
current_epoch = 0
Expand Down Expand Up @@ -512,13 +519,21 @@ def rm_train(
)

## Checkpointing
timeout.mark_iteration()

rm_save_state["consumed_samples"] += master_config["policy"][
"train_global_batch_size"
]
if master_config["checkpointing"]["enabled"] and (

should_save_by_step = (
is_last_step
or (total_steps + 1) % master_config["checkpointing"]["save_period"]
== 0
)
should_save_by_timeout = timeout.check_save()

if master_config["checkpointing"]["enabled"] and (
should_save_by_step or should_save_by_timeout
):
## +1 because step is 0-indexed
rm_save_state["step"] = (current_step + 1) % len(train_dataloader)
Expand Down Expand Up @@ -615,6 +630,8 @@ def rm_train(
current_step += 1
total_steps += 1

if should_save_by_timeout:
return
if (
master_config["rm"]["max_num_steps"] != -1
and total_steps >= master_config["rm"]["max_num_steps"]
Expand Down
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ def validate(
):
"""Run validation on the validation dataset."""
if val_dataloader is None:
    assert val_dataloader is not None or master_config["sft"]["val_period"] == 0, (
        "val_dataloader is None, so sft.val_period must be 0"
    )
print(" ⚠️ No validation dataloader provided, skipping validation")
return

Expand Down Expand Up @@ -578,6 +581,8 @@ def sft_train(
current_step += 1
total_steps += 1

if should_save_by_timeout:
return
if total_steps >= master_config["sft"]["max_num_steps"]:
return

Expand Down
6 changes: 5 additions & 1 deletion tests/unit/algorithms/test_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def val_iter(self):
},
"train_micro_batch_size": 1,
},
"checkpointing": {"enabled": False},
"checkpointing": {
"enabled": False,
"checkpoint_must_save_by": None,
"save_period": 10,
},
}

return {
Expand Down