Merged (changes from 10 commits)
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
@@ -26,6 +26,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 50
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
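
Note: checkpoint_must_save_by is the new checkpointing key threaded through every example config in this PR. It feeds the TimeoutChecker wired into dpo.py, grpo.py, and sft.py at the bottom of this diff; null, the value added everywhere here, presumably leaves behavior unchanged, with checkpoints written only every save_period steps. A non-null value asks the trainer to also force a checkpoint before that much wall-clock time elapses, e.g. shortly before a cluster job's time limit runs out. The accepted value format is not visible in this diff, so a duration such as "03:50:00" is only a hypothetical example.
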
2 changes: 1 addition & 1 deletion examples/configs/grpo-deepscaler-1.5b-16K.yaml
@@ -11,4 +11,4 @@ policy:
 
 
 dynamic_batching:
-  enabled: False
\ No newline at end of file
+  enabled: False
1 change: 1 addition & 0 deletions examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -28,6 +28,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 10
   save_period: 10
+  checkpoint_must_save_by: null
 
 policy:
   # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -28,6 +28,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 
 policy:
   # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B_megatron.yaml
@@ -29,6 +29,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "Qwen/Qwen2.5-1.5B"
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B_megatron.yaml
@@ -8,6 +8,7 @@ grpo:
 checkpointing:
   enabled: false
   checkpoint_dir: "results/grpo_8b_megatron"
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"
1 change: 1 addition & 0 deletions examples/configs/grpo_sliding_puzzle.yaml
@@ -14,6 +14,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "Qwen/Qwen2.5-1.5B-Instruct"
1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10000
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"

1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10000
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"

1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10000
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"

1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10000
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"

1 change: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 50
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: google/gemma-3-1b-it
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: google/gemma-3-27b-it
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: meta-llama/Llama-3.1-8B-Instruct
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: meta-llama/Llama-3.2-1B-Instruct
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: Qwen/Qwen2.5-32B
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: Qwen/Qwen2.5-32B
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: Qwen/Qwen2.5-7B-Instruct
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ checkpointing:
   higher_is_better: true
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: Qwen/Qwen2.5-Math-1.5B-Instruct
   tokenizer:
2 changes: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: meta-llama/Llama-3.1-8B-Instruct
   tokenizer:
@@ -73,3 +74,4 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
+

1 change: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: meta-llama/Llama-3.1-8B-Instruct
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: meta-llama/Llama-3.1-8B-Instruct
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: meta-llama/Llama-3.2-1B
   tokenizer:

1 change: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 policy:
   model_name: Qwen/Qwen2.5-32B
   tokenizer:
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
@@ -19,6 +19,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 3
   save_period: 10
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.2-1B"
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2.yaml
@@ -16,6 +16,7 @@ checkpointing:
   higher_is_better: false
   keep_top_k: 100
   save_period: 500
+  checkpoint_must_save_by: null
 
 policy:
   model_name: "meta-llama/Llama-3.1-8B"
20 changes: 17 additions & 3 deletions nemo_rl/algorithms/dpo.py
@@ -37,7 +37,7 @@
 from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
 from nemo_rl.utils.logger import Logger, LoggerConfig
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer
+from nemo_rl.utils.timer import TimeoutChecker, Timer
 
 
 class DPOSaveState(TypedDict):
@@ -354,6 +354,11 @@ def dpo_train(
 ):
     # Run dpo training
     timer = Timer()
+    timeout = TimeoutChecker(
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
+    )
+    timeout.start_iterations()
 
     if dpo_save_state is None:
         dpo_save_state = _default_dpo_save_state()
@@ -447,11 +452,20 @@ def dpo_train(
             dpo_save_state["consumed_samples"] += master_config["policy"][
                 "train_global_batch_size"
             ]
-            if master_config["checkpointing"]["enabled"] and (
+            timeout.mark_iteration()
+
+            should_save_by_step = (
                 is_last_step
                 or (total_steps + 1) % master_config["checkpointing"]["save_period"]
                 == 0
-            ):  # +1 because step is 0-indexed
+            )
+            # +1 because step is 0-indexed
+            # Check if timeout-based checkpointing is enabled in config.
+            should_save_by_timeout = timeout.check_save()
+
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
+            ):
                 dpo_save_state["step"] = (current_step + 1) % len(train_dataloader)
                 dpo_save_state["total_steps"] = total_steps + 1
                 dpo_save_state["epoch"] = current_epoch
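
The dpo.py diff above, and the grpo.py and sft.py diffs below, all follow the same pattern: build a TimeoutChecker from checkpointing.checkpoint_must_save_by, call start_iterations() once before the training loop, mark_iteration() after every step, and treat check_save() as a second checkpoint trigger alongside the save_period rule. TimeoutChecker itself (in nemo_rl/utils/timer.py) is not shown in this diff, so the sketch below is only an illustration of how such a checker could work, not the library's implementation; the class name TimeoutCheckerSketch, the seconds-based timeout argument, and the use of time.monotonic() are assumptions made for the example.

import time
from typing import Optional


class TimeoutCheckerSketch:
    """Hypothetical stand-in for nemo_rl.utils.timer.TimeoutChecker.

    check_save() returns True once there may not be enough wall-clock
    time left to fit one more training step (plus one checkpoint write,
    when fit_last_save_time=True) before the configured deadline.
    """

    def __init__(self, timeout: Optional[float], fit_last_save_time: bool = False):
        # timeout=None mirrors checkpoint_must_save_by: null, i.e. never trigger.
        self.deadline = None if timeout is None else time.monotonic() + timeout
        self.fit_last_save_time = fit_last_save_time
        self.last_iteration = 0.0  # duration of the most recent step, in seconds
        self.last_save = 0.0  # a real checker would also measure checkpoint writes
        self._mark = time.monotonic()

    def start_iterations(self) -> None:
        # Begin timing the first step.
        self._mark = time.monotonic()

    def mark_iteration(self) -> None:
        # Record how long the step that just finished took.
        now = time.monotonic()
        self.last_iteration = now - self._mark
        self._mark = now

    def check_save(self) -> bool:
        # Save now if one more step (and optionally one more save)
        # might overshoot the deadline.
        if self.deadline is None:
            return False
        needed = self.last_iteration
        if self.fit_last_save_time:
            needed += self.last_save
        return time.monotonic() + needed >= self.deadline


# Call pattern mirroring the training loops (sleep stands in for a step):
checker = TimeoutCheckerSketch(timeout=0.05, fit_last_save_time=True)
checker.start_iterations()
for step in range(1000):
    time.sleep(0.01)  # one "training step"
    checker.mark_iteration()
    if checker.check_save():
        print(f"forcing a checkpoint at step {step} before the deadline")
        break

One detail worth noting in the real diffs: the timeout trigger is folded into the same conditional that handles save_period, so a timeout-forced save updates the save state (step, total_steps, epoch) exactly like a periodic save, and resuming from it behaves identically.
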
21 changes: 18 additions & 3 deletions nemo_rl/algorithms/grpo.py
@@ -64,7 +64,7 @@
     print_message_log_samples,
 )
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer
+from nemo_rl.utils.timer import TimeoutChecker, Timer
 
 # ===============================================================================
 # Configuration
@@ -476,6 +476,12 @@ def grpo_train(
 ) -> None:
     """Run GRPO training algorithm."""
     timer = Timer()
+    timeout = TimeoutChecker(
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
+    )
+    timeout.start_iterations()
+
     NEED_REFIT = True
     # If policy_generation is None, use the policy as the generation interface (megatron framework backend)
     if policy_generation is None:
@@ -696,10 +702,19 @@
 
             ## Checkpointing
             consumed_samples += master_config["grpo"]["num_prompts_per_step"]
-            if master_config["checkpointing"]["enabled"] and (
+            timeout.mark_iteration()
+
+            should_save_by_step = (
                 is_last_step
                 or (step + 1) % master_config["checkpointing"]["save_period"] == 0
-            ):  # +1 because step is 0-indexed
+            )
+            # +1 because step is 0-indexed
+            # Check if timeout-based checkpointing is enabled in config.
+            should_save_by_timeout = timeout.check_save()
+
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
+            ):
                 policy.prepare_for_training()
 
                 grpo_save_state["step"] = step + 1
18 changes: 15 additions & 3 deletions nemo_rl/algorithms/sft.py
@@ -40,7 +40,7 @@
 from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
 from nemo_rl.utils.logger import Logger, LoggerConfig
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer
+from nemo_rl.utils.timer import TimeoutChecker, Timer
 
 
 class SFTSaveState(TypedDict):
@@ -326,6 +326,11 @@ def sft_train(
 ):
     # Run basic sft training
     timer = Timer()
+    timeout = TimeoutChecker(
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
+    )
+    timeout.start_iterations()
 
     if sft_save_state is None:
         sft_save_state = _default_sft_save_state()
@@ -439,12 +444,19 @@
             sft_save_state["consumed_samples"] += master_config["policy"][
                 "train_global_batch_size"
             ]
-            if master_config["checkpointing"]["enabled"] and (
+            timeout.mark_iteration()
+            should_save_by_step = (
                 is_last_step
                 or (total_steps + 1) % master_config["checkpointing"]["save_period"]
                 == 0
-            ):
+            )
+            # +1 because step is 0-indexed
+            # Check if timeout-based checkpointing is enabled in config.
+            should_save_by_timeout = timeout.check_save()
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
+            ):
                 ## +1 because step is 0-indexed
                 sft_save_state["step"] = (current_step + 1) % len(train_dataloader)
                 sft_save_state["total_steps"] = total_steps + 1
                 sft_save_state["epoch"] = current_epoch