From b772f4a6612987da5a09e00559812b2e2bb59b96 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Wed, 23 Jul 2025 16:30:50 -0700
Subject: [PATCH 01/12] save checkpoint before timeout to deal with 4-hour
 cluster running time limit

Signed-off-by: Wei Du
---
 nemo_rl/algorithms/grpo.py | 19 ++++++++---
 nemo_rl/utils/timer.py     | 64 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 77 insertions(+), 6 deletions(-)

diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 530ad4e8f9..9a97ef65cd 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -64,7 +64,7 @@
     print_message_log_samples,
 )
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer
+from nemo_rl.utils.timer import Timer, TimeoutChecker
 
 # ===============================================================================
 # Configuration
 # ===============================================================================
@@ -476,6 +476,12 @@ def grpo_train(
 ) -> None:
     """Run GRPO training algorithm."""
     timer = Timer()
+    timeout = TimeoutChecker(
+        timeout=master_config.get('timeout', '00:03:45:00'),  # three hours and 45 minutes
+        fit_last_save_time=True,
+    )
+    timeout.start_iterations()
+
     NEED_REFIT = True
     # If policy_generation is None, use the policy as the generation interface (megatron framework backend)
     if policy_generation is None:
@@ -696,10 +702,13 @@ def grpo_train(
 
         ## Checkpointing
         consumed_samples += master_config["grpo"]["num_prompts_per_step"]
-        if master_config["checkpointing"]["enabled"] and (
-            is_last_step
-            or (step + 1) % master_config["checkpointing"]["save_period"] == 0
-        ):  # +1 because step is 0-indexed
+        timeout.mark_iteration()
+
+        should_save_by_step = (is_last_step or (step + 1) % master_config["checkpointing"]["save_period"] == 0)
+        # +1 because step is 0-indexed
+        should_save_by_timeout = timeout.check_save()
+
+        if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout):
             policy.prepare_for_training()
 
             grpo_save_state["step"] = step + 1
diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py
index 4fdaffee98..35a7b64857 100644
--- a/nemo_rl/utils/timer.py
+++ b/nemo_rl/utils/timer.py
@@ -14,7 +14,7 @@
 import time
 from contextlib import contextmanager
 from typing import Callable, Generator, Optional, Sequence, Union
-
+import sys
 import numpy as np
 
 
@@ -245,3 +245,65 @@ def reset(self, label: Optional[str] = None) -> None:
         else:
             self._timers = {}
             self._start_times = {}
+
+
+
+
+
+
+
+def convert_to_seconds(time_string):
+    # Split the string into components
+    days, hours, minutes, seconds = map(int, time_string.split(':'))
+
+    # Calculate total seconds
+    total_seconds = days * 86400 + hours * 3600 + minutes * 60 + seconds
+
+    return total_seconds
+
+
+class TimeoutChecker:
+    def __init__(self, timeout: str = '00:03:45:00', fit_last_save_time=False):
+        super().__init__()
+        self.last_save_time = convert_to_seconds(timeout)
+        self.start_time = time.time()
+        self.last_saved = False
+        self.iteration_times = []
+        self.previous_iteration_time = None
+        self.fit_last_save_time = fit_last_save_time
+
+    def check_save(self):
+        # Flush
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+        # Already saved after timeout
+        if self.last_saved:
+            return False
+
+        current_time = time.time()
+        elapsed_time = current_time - self.start_time
+
+        if self.fit_last_save_time:
+            average_iteration_time = sum(self.iteration_times) / len(self.iteration_times)
+            if elapsed_time + average_iteration_time >= self.last_save_time:
+                self.last_saved = True
+                return True
+
+        if elapsed_time >= self.last_save_time:
+            self.last_saved = True
+            return True
+
+        return False
+
+    def start_iterations(self):
+        self.previous_iteration_time = time.time()
+
+    def mark_iteration(self):
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+        current_time = time.time()
+        elapsed_time = current_time - self.previous_iteration_time
+        self.previous_iteration_time = current_time
+        self.iteration_times.append(elapsed_time)
\ No newline at end of file
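PATCH 01's TimeoutChecker is driven by three calls: start_iterations() once before the loop, mark_iteration() after every step, and check_save() at the checkpoint decision. A minimal sketch of that wiring in a generic loop, where run_training and the time.sleep stand-in for a training step are illustrative rather than part of the patch:

import time

from nemo_rl.utils.timer import TimeoutChecker


def run_training(num_steps: int) -> None:
    # Budget is 'DD:HH:MM:SS'; 3 h 45 min leaves margin under a 4-hour cluster limit.
    checker = TimeoutChecker(timeout='00:03:45:00', fit_last_save_time=True)
    checker.start_iterations()

    for step in range(num_steps):
        time.sleep(0.1)  # stand-in for one real training step
        checker.mark_iteration()  # records this iteration's duration

        # check_save() returns True at most once: when the elapsed time, plus the
        # average iteration time because fit_last_save_time=True, reaches the budget.
        if checker.check_save():
            print(f'budget nearly exhausted, saving a checkpoint at step {step}')
            break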
From 730c313f97aadb97dc8c634462f6d9cdabbfad59 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Wed, 23 Jul 2025 17:17:30 -0700
Subject: [PATCH 02/12] remove unused blank lines

Signed-off-by: Wei Du
---
 nemo_rl/utils/timer.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py
index 35a7b64857..d5242f39a6 100644
--- a/nemo_rl/utils/timer.py
+++ b/nemo_rl/utils/timer.py
@@ -247,11 +247,6 @@ def reset(self, label: Optional[str] = None) -> None:
             self._start_times = {}
-
-
-
-
-
 
 
 def convert_to_seconds(time_string):
     # Split the string into components
     days, hours, minutes, seconds = map(int, time_string.split(':'))
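The budgets above are 'DD:HH:MM:SS' strings parsed by convert_to_seconds from PATCH 01. A quick worked example of the default budget, assuming only that helper:

from nemo_rl.utils.timer import convert_to_seconds

# '00:03:45:00' is 0 days, 3 hours, 45 minutes, 0 seconds.
assert convert_to_seconds('00:03:45:00') == 3 * 3600 + 45 * 60  # 13500 seconds
assert convert_to_seconds('01:00:00:30') == 86400 + 30  # one day plus 30 seconds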
From 6fc99bf5615be053c66d809f35e5c998b01c3fa7 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Wed, 23 Jul 2025 17:38:56 -0700
Subject: [PATCH 03/12] make timeout optional

Signed-off-by: Wei Du
---
 nemo_rl/algorithms/grpo.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 9a97ef65cd..f1676808f3 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -706,7 +706,11 @@ def grpo_train(
 
         should_save_by_step = (is_last_step or (step + 1) % master_config["checkpointing"]["save_period"] == 0)
         # +1 because step is 0-indexed
-        should_save_by_timeout = timeout.check_save()
+        # Check if timeout-based checkpointing is enabled in config.
+        # If so, use TimeoutChecker to determine whether we should save due to timeout.Otherwise, default to False (no timeout-based saving).
+        if 'timeout' in master_config: should_save_by_timeout = timeout.check_save()
+        else: should_save_by_timeout = False
+
 
         if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout):
             policy.prepare_for_training()

From 5d92547fb57355ec0e24366cad82b0cf306af739 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Mon, 28 Jul 2025 13:41:53 -0700
Subject: [PATCH 04/12] update timeout for all algorithms and configs and add
 unit tests

Signed-off-by: Wei Du
---
 examples/configs/dpo.yaml                     |  3 ++
 .../configs/grpo-deepscaler-1.5b-16K.yaml     |  5 +-
 examples/configs/grpo-deepscaler-1.5b-8K.yaml |  3 ++
 .../configs/grpo_deepscaler-1.5b-24K.yaml     |  3 ++
 examples/configs/grpo_math_1B.yaml            |  3 ++
 examples/configs/grpo_math_1B_megatron.yaml   |  3 ++
 examples/configs/grpo_math_70B_megatron.yaml  |  3 ++
 examples/configs/grpo_math_8B.yaml            |  3 ++
 examples/configs/grpo_math_8B_megatron.yaml   |  3 ++
 .../grpo_math_qwen30ba3b_megatron.yaml        |  3 ++
 examples/configs/grpo_sliding_puzzle.yaml     |  3 ++
 ...llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml |  3 ++
 ....1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml |  2 +
 ...po-llama3.1-8b-instruct-4n8g-megatron.yaml |  2 +
 ...8b-instruct-4n8g-megatrontp2pp2-quick.yaml |  2 +
 ...llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml |  2 +
 .../llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml  |  2 +
 ...-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml |  2 +
 ...3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml |  2 +
 ...llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml |  2 +
 ...-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml |  2 +
 ...en2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml |  2 +
 ...wen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml |  2 +
 ...5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml |  2 +
 ...3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml |  2 +
 ...ama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml |  2 +
 ...ft-llama3.1-8b-instruct-1n8g-megatron.yaml |  2 +
 .../llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml |  2 +
 ...wen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml |  2 +
 examples/configs/sft.yaml                     |  3 ++
 examples/configs/sft_openmathinstruct2.yaml   |  3 ++
 nemo_rl/algorithms/dpo.py                     | 21 +++++---
 nemo_rl/algorithms/grpo.py                    |  9 ++--
 nemo_rl/algorithms/sft.py                     | 21 +++++---
 nemo_rl/utils/timer.py                        | 31 ++++++++----
 tests/unit/utils/test_timer.py                | 49 ++++++++++++++++++-
 36 files changed, 180 insertions(+), 29 deletions(-)

diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
index bcbffb0761..20671cb334 100755
--- a/examples/configs/dpo.yaml
+++ b/examples/configs/dpo.yaml
@@ -173,3 +173,6 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
+
+
+checkpoint_must_save_by: null
\ No newline at end of file
diff --git a/examples/configs/grpo-deepscaler-1.5b-16K.yaml b/examples/configs/grpo-deepscaler-1.5b-16K.yaml
index 575db2f538..312b8ecdd5 100644
--- a/examples/configs/grpo-deepscaler-1.5b-16K.yaml
+++ b/examples/configs/grpo-deepscaler-1.5b-16K.yaml
@@ -11,4 +11,7 @@ policy:
 
 
   dynamic_batching:
-    enabled: False
\ No newline at end of file
+    enabled: False
+
+
+checkpoint_must_save_by: null
\ No newline at end of file
diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml
index 08d021f582..3f6097150c 100644
--- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml
+++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -146,3 +146,6 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
+
+
+checkpoint_must_save_by: null
diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml
index dc9db4ceab..28db7b449c 100644
---
a/examples/configs/grpo_deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml @@ -46,3 +46,6 @@ policy: # For most cases, use "dummy" to load the initial weights, since they will be overwritten during refit # For Gemma models, we need to use "auto" due to a vllm bug load_format: dummy + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index a388f7b2cc..a6a414c604 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -158,3 +158,6 @@ logger: cluster: gpus_per_node: 1 num_nodes: 1 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index d58eb47aae..d484239e7c 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -174,3 +174,6 @@ logger: cluster: gpus_per_node: 1 num_nodes: 1 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_70B_megatron.yaml b/examples/configs/grpo_math_70B_megatron.yaml index 1317e45a04..9d2fa8d6fe 100644 --- a/examples/configs/grpo_math_70B_megatron.yaml +++ b/examples/configs/grpo_math_70B_megatron.yaml @@ -68,3 +68,6 @@ policy: cluster: gpus_per_node: 8 num_nodes: 8 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 6a958957c4..7f9a69ff26 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -62,3 +62,6 @@ policy: cluster: gpus_per_node: 8 num_nodes: 1 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_8B_megatron.yaml b/examples/configs/grpo_math_8B_megatron.yaml index 004bc738b0..84f85ee1bc 100644 --- a/examples/configs/grpo_math_8B_megatron.yaml +++ b/examples/configs/grpo_math_8B_megatron.yaml @@ -73,3 +73,6 @@ policy: cluster: gpus_per_node: 8 num_nodes: 1 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml index 84d6736cec..dcfd92f714 100644 --- a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml +++ b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml @@ -75,3 +75,6 @@ policy: cluster: gpus_per_node: 8 num_nodes: 8 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index aeb6b48da4..4f46fff85c 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -64,3 +64,6 @@ logger: gpu_monitoring: collection_interval: 10 # How often to collect GPU usage metrics (in seconds) flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml index b060004882..df55f027a0 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml @@ -94,3 +94,6 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git 
a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index c34771595b..ec31ff84c3 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -94,3 +94,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml index abc42f30eb..1b10add6c6 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml @@ -127,3 +127,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml index a571f32582..262a878bcc 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml @@ -127,3 +127,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index 832d989b59..5977d19112 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -94,3 +94,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index b503afad4b..c491218f18 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -122,3 +122,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index ea3188b9ae..2e80ae7b40 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 16 + +checkpoint_must_save_by: null diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index d29b88c4e0..dc4b7def52 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + +checkpoint_must_save_by: null \ No newline at end of file diff --git 
a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 355cd3a5d3..931d6c3a4c 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 0ce93de5ae..5a460a675d 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 32 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index 45788b3172..c857445f50 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 32 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index ae0add9bd2..af0253bf58 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index cce3f5b327..3ca6ede24e 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -123,3 +123,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml index 50aa3b96c6..ed136dba49 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml @@ -73,3 +73,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml index 7a774c3654..7c8187084c 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml @@ -73,3 +73,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at 
end of file diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml index 14c2f9692e..0e5ed586d7 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml @@ -117,3 +117,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml index 617ce45096..24f07f245f 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml @@ -73,3 +73,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml index 6761e2f015..0f87c6c15d 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml @@ -73,3 +73,5 @@ logger: cluster: gpus_per_node: 8 num_nodes: 4 + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 7ddcd3d867..e8f2fff8a9 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -156,3 +156,6 @@ logger: cluster: gpus_per_node: 1 num_nodes: 1 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index aa128e5a99..88c5a33f3f 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -91,3 +91,6 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + + +checkpoint_must_save_by: null \ No newline at end of file diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 259b23d665..fcc35c058a 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -37,7 +37,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import Timer, TimeoutChecker class DPOSaveState(TypedDict): @@ -85,6 +85,7 @@ class MasterConfig(TypedDict): logger: LoggerConfig cluster: ClusterConfig checkpointing: CheckpointingConfig + checkpoint_must_save_by: NotRequired[str] # ======================================================= @@ -354,6 +355,11 @@ def dpo_train( ): # Run dpo training timer = Timer() + timeout = TimeoutChecker( + timeout=master_config['checkpoint_must_save_by'], + fit_last_save_time=True, + ) + timeout.start_iterations() if dpo_save_state is None: dpo_save_state = _default_dpo_save_state() @@ -447,11 +453,14 @@ def dpo_train( dpo_save_state["consumed_samples"] += master_config["policy"][ "train_global_batch_size" ] - if master_config["checkpointing"]["enabled"] and ( - is_last_step - or (total_steps + 1) % master_config["checkpointing"]["save_period"] - == 0 - ): # +1 because step is 0-indexed + timeout.mark_iteration() + + should_save_by_step = (is_last_step or (total_steps + 1) % 
master_config["checkpointing"]["save_period"] == 0) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout): dpo_save_state["step"] = (current_step + 1) % len(train_dataloader) dpo_save_state["total_steps"] = total_steps + 1 dpo_save_state["epoch"] = current_epoch diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f1676808f3..f238fdd355 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -112,6 +112,8 @@ class MasterConfig(TypedDict): logger: GRPOLoggerConfig cluster: ClusterConfig checkpointing: CheckpointingConfig + checkpoint_must_save_by: NotRequired[str] + # =============================================================================== @@ -477,7 +479,7 @@ def grpo_train( """Run GRPO training algorithm.""" timer = Timer() timeout = TimeoutChecker( - timeout=master_config.get('timeout', '00:03:45:00'), # three hours and 45 minutes + timeout=master_config['checkpoint_must_save_by'], fit_last_save_time=True, ) timeout.start_iterations() @@ -707,11 +709,8 @@ def grpo_train( should_save_by_step = (is_last_step or (step + 1) % master_config["checkpointing"]["save_period"] == 0) # +1 because step is 0-indexed # Check if timeout-based checkpointing is enabled in config. - # If so, use TimeoutChecker to determine whether we should save due to timeout.Otherwise, default to False (no timeout-based saving). - if 'timeout' in master_config: should_save_by_timeout = timeout.check_save() - else: should_save_by_timeout = False + should_save_by_timeout = timeout.check_save() - if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout): policy.prepare_for_training() diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index ee227e0aa6..db44bdd867 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -40,7 +40,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import Timer, TimeoutChecker class SFTSaveState(TypedDict): @@ -78,6 +78,7 @@ class MasterConfig(TypedDict): logger: LoggerConfig cluster: ClusterConfig checkpointing: CheckpointingConfig + checkpoint_must_save_by: NotRequired[str] # ======================================================= @@ -326,6 +327,11 @@ def sft_train( ): # Run basic sft training timer = Timer() + timeout = TimeoutChecker( + timeout=master_config['checkpoint_must_save_by'], + fit_last_save_time=True, + ) + timeout.start_iterations() if sft_save_state is None: sft_save_state = _default_sft_save_state() @@ -439,12 +445,13 @@ def sft_train( sft_save_state["consumed_samples"] += master_config["policy"][ "train_global_batch_size" ] - if master_config["checkpointing"]["enabled"] and ( - is_last_step - or (total_steps + 1) % master_config["checkpointing"]["save_period"] - == 0 - ): - ## +1 because step is 0-indexed + timeout.mark_iteration() + should_save_by_step = (is_last_step or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. 
+ should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout): sft_save_state["step"] = (current_step + 1) % len(train_dataloader) sft_save_state["total_steps"] = total_steps + 1 sft_save_state["epoch"] = current_epoch diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py index d5242f39a6..c1c7de8ac4 100644 --- a/nemo_rl/utils/timer.py +++ b/nemo_rl/utils/timer.py @@ -16,7 +16,7 @@ from typing import Callable, Generator, Optional, Sequence, Union import sys import numpy as np - +from typing import Optional class Timer: """A utility for timing code execution. @@ -247,26 +247,39 @@ def reset(self, label: Optional[str] = None) -> None: self._start_times = {} -def convert_to_seconds(time_string): - # Split the string into components - days, hours, minutes, seconds = map(int, time_string.split(':')) +def convert_to_seconds(time_string: str) -> int: + """ + Converts a time string in the format 'DD:HH:MM:SS' to total seconds. - # Calculate total seconds - total_seconds = days * 86400 + hours * 3600 + minutes * 60 + seconds + Args: + time_string (str): Time duration string, e.g., '00:03:45:00'. + + Returns: + int: Total time in seconds. + """ + days, hours, minutes, seconds = map(int, time_string.split(':')) + return days * 86400 + hours * 3600 + minutes * 60 + seconds - return total_seconds class TimeoutChecker: - def __init__(self, timeout: str = '00:03:45:00', fit_last_save_time=False): + def __init__(self, timeout: Optional[str] = '00:03:45:00', fit_last_save_time: bool = False): + """ + Initializes the TimeoutChecker. + + Args: + timeout (str or None): Timeout in format 'DD:HH:MM:SS'. If None, timeout is considered infinite. + fit_last_save_time (bool): If True, considers average iteration time when checking timeout. 
+ """ super().__init__() - self.last_save_time = convert_to_seconds(timeout) + self.last_save_time = float('inf') if timeout is None else convert_to_seconds(timeout) self.start_time = time.time() self.last_saved = False self.iteration_times = [] self.previous_iteration_time = None self.fit_last_save_time = fit_last_save_time + def check_save(self): # Flush sys.stdout.flush() diff --git a/tests/unit/utils/test_timer.py b/tests/unit/utils/test_timer.py index 56ba315b55..a222a5c736 100644 --- a/tests/unit/utils/test_timer.py +++ b/tests/unit/utils/test_timer.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import Timer, TimeoutChecker class TestTimer: @@ -188,3 +188,50 @@ def test_precise_timing(self, mock_perf_counter, timer): # Check the elapsed time assert elapsed == 5.0 assert timer._timers["precise_test"][0] == 5.0 + + + +class TestTimeoutChecker: + def test_infinite_timeout(self): + checker = TimeoutChecker(timeout=None) + time.sleep(0.1) + assert checker.check_save() is False + + def test_short_timeout(self): + checker = TimeoutChecker(timeout='00:00:00:01') + time.sleep(1.1) + assert checker.check_save() is True + + def test_double_save_prevented(self): + checker = TimeoutChecker(timeout='00:00:00:01') + time.sleep(1.1) + assert checker.check_save() is True + assert checker.check_save() is False + + def test_fit_last_save_time_enabled(self): + # Create a TimeoutChecker with a 3-second timeout and enable fit_last_save_time logic + checker = TimeoutChecker(timeout='00:00:00:03', fit_last_save_time=True) + checker.start_iterations() + + # Simulate 10 iterations, each taking about 0.1 seconds + # This builds up a stable average iteration time + for _ in range(10): + time.sleep(0.1) + checker.mark_iteration() + + # Wait an additional ~2.0 seconds so that: + # elapsed time + avg iteration time >= timeout (3 seconds) + time.sleep(2.0) + + result = checker.check_save() + # Assert that the checker triggers a save due to timeout + assert result is True + + + def test_iteration_tracking(self): + checker = TimeoutChecker() + checker.start_iterations() + time.sleep(0.05) + checker.mark_iteration() + assert len(checker.iteration_times) == 1 + assert checker.iteration_times[0] > 0 \ No newline at end of file From 94d1cc443dbef92b4ece653e4060f1fb7c14a88a Mon Sep 17 00:00:00 2001 From: Wei Du Date: Tue, 29 Jul 2025 12:13:03 -0700 Subject: [PATCH 05/12] put checkpoint_must_save_by under checkpointing Signed-off-by: Wei Du --- examples/configs/dpo.yaml | 2 +- examples/configs/grpo-deepscaler-1.5b-16K.yaml | 1 - examples/configs/grpo-deepscaler-1.5b-8K.yaml | 3 ++- examples/configs/grpo_deepscaler-1.5b-24K.yaml | 1 - examples/configs/grpo_math_1B.yaml | 2 +- examples/configs/grpo_math_1B_megatron.yaml | 2 +- examples/configs/grpo_math_70B_megatron.yaml | 1 - examples/configs/grpo_math_8B.yaml | 1 - examples/configs/grpo_math_8B_megatron.yaml | 2 +- examples/configs/grpo_math_qwen30ba3b_megatron.yaml | 1 - examples/configs/grpo_sliding_puzzle.yaml | 2 +- .../llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml | 2 +- .../llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml | 2 +- .../recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml | 2 +- .../dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml | 2 +- .../llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml | 2 +- .../configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml | 2 +- .../llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml | 3 ++- 
.../llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml | 2 +- .../llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml | 2 +- .../grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml | 2 +- .../llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml | 2 +- .../llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml | 2 +- .../llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml | 2 +- .../llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml | 2 +- .../llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml | 2 +- .../recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml | 2 +- .../configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml | 2 +- .../llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml | 2 +- examples/configs/sft.yaml | 2 +- examples/configs/sft_openmathinstruct2.yaml | 2 +- nemo_rl/algorithms/dpo.py | 4 ++-- nemo_rl/algorithms/grpo.py | 4 ++-- nemo_rl/algorithms/sft.py | 4 ++-- nemo_rl/utils/timer.py | 2 +- 35 files changed, 35 insertions(+), 38 deletions(-) diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 20671cb334..f7bbe72f1e 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -26,6 +26,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 50 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" @@ -175,4 +176,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo-deepscaler-1.5b-16K.yaml b/examples/configs/grpo-deepscaler-1.5b-16K.yaml index 312b8ecdd5..0129f22da8 100644 --- a/examples/configs/grpo-deepscaler-1.5b-16K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-16K.yaml @@ -14,4 +14,3 @@ policy: enabled: False -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index 3f6097150c..db480f510d 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -28,6 +28,7 @@ checkpointing: higher_is_better: true keep_top_k: 10 save_period: 10 + checkpoint_must_save_by: null policy: # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227) @@ -148,4 +149,4 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null + diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml index 28db7b449c..70125ef54f 100644 --- a/examples/configs/grpo_deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml @@ -48,4 +48,3 @@ policy: load_format: dummy -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index a6a414c604..b5f4e3799a 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -28,6 +28,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227) @@ -160,4 +161,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index d484239e7c..7154ed6c9b 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ 
-29,6 +29,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: "Qwen/Qwen2.5-1.5B" @@ -176,4 +177,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_70B_megatron.yaml b/examples/configs/grpo_math_70B_megatron.yaml index 9d2fa8d6fe..86e4cd2c39 100644 --- a/examples/configs/grpo_math_70B_megatron.yaml +++ b/examples/configs/grpo_math_70B_megatron.yaml @@ -70,4 +70,3 @@ cluster: num_nodes: 8 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 7f9a69ff26..1f403e5dc2 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -64,4 +64,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_8B_megatron.yaml b/examples/configs/grpo_math_8B_megatron.yaml index 84f85ee1bc..5ff13a8874 100644 --- a/examples/configs/grpo_math_8B_megatron.yaml +++ b/examples/configs/grpo_math_8B_megatron.yaml @@ -8,6 +8,7 @@ grpo: checkpointing: enabled: false checkpoint_dir: "results/grpo_8b_megatron" + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" @@ -75,4 +76,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml index dcfd92f714..0c839cce66 100644 --- a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml +++ b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml @@ -77,4 +77,3 @@ cluster: num_nodes: 8 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index 4f46fff85c..bab614b19c 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: "Qwen/Qwen2.5-1.5B-Instruct" @@ -66,4 +67,3 @@ logger: flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml index df55f027a0..22ee10c8cb 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" @@ -96,4 +97,3 @@ cluster: num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index ec31ff84c3..a509e74381 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: 
"meta-llama/Llama-3.1-8B-Instruct" @@ -95,4 +96,3 @@ cluster: gpus_per_node: 8 num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml index 1b10add6c6..1788eaffe7 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" @@ -128,4 +129,3 @@ cluster: gpus_per_node: 8 num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml index 262a878bcc..7344adc94c 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" @@ -128,4 +129,3 @@ cluster: gpus_per_node: 8 num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index 5977d19112..8f62c3ee50 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 50 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" @@ -95,4 +96,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index c491218f18..d8ac0ac7a2 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: google/gemma-3-1b-it tokenizer: @@ -123,4 +124,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index 2e80ae7b40..111eae8f9f 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: google/gemma-3-27b-it tokenizer: @@ -124,4 +125,4 @@ cluster: gpus_per_node: 8 num_nodes: 16 -checkpoint_must_save_by: null + diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml 
b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index dc4b7def52..90893b72d7 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: @@ -124,4 +125,3 @@ cluster: gpus_per_node: 8 num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 931d6c3a4c..47b7ee9cf6 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.2-1B-Instruct tokenizer: @@ -124,4 +125,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 5a460a675d..2e0a103298 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-32B tokenizer: @@ -124,4 +125,3 @@ cluster: gpus_per_node: 8 num_nodes: 32 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index c857445f50..40581f74a8 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-32B tokenizer: @@ -124,4 +125,3 @@ cluster: gpus_per_node: 8 num_nodes: 32 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index af0253bf58..953ca406b1 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-7B-Instruct tokenizer: @@ -124,4 +125,3 @@ cluster: gpus_per_node: 8 num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index 3ca6ede24e..d81e26bc57 100644 --- 
a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-Math-1.5B-Instruct tokenizer: @@ -124,4 +125,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml index ed136dba49..ae2599b2ae 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: @@ -74,4 +75,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml index 7c8187084c..669714ee5d 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: @@ -74,4 +75,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml index 0e5ed586d7..903a987365 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: @@ -118,4 +119,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml index 24f07f245f..6edb64262b 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.2-1B tokenizer: @@ -74,4 +75,3 @@ cluster: gpus_per_node: 8 num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml index 0f87c6c15d..66dfdc3f24 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 
save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-32B tokenizer: @@ -74,4 +75,3 @@ cluster: gpus_per_node: 8 num_nodes: 4 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index e8f2fff8a9..d6bee25cb5 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -19,6 +19,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B" @@ -158,4 +159,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index 88c5a33f3f..63731f37fe 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -16,6 +16,7 @@ checkpointing: higher_is_better: false keep_top_k: 100 save_period: 500 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B" @@ -93,4 +94,3 @@ cluster: num_nodes: 1 -checkpoint_must_save_by: null \ No newline at end of file diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index fcc35c058a..2ff052859a 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -85,7 +85,7 @@ class MasterConfig(TypedDict): logger: LoggerConfig cluster: ClusterConfig checkpointing: CheckpointingConfig - checkpoint_must_save_by: NotRequired[str] + # ======================================================= @@ -356,7 +356,7 @@ def dpo_train( # Run dpo training timer = Timer() timeout = TimeoutChecker( - timeout=master_config['checkpoint_must_save_by'], + timeout=master_config["checkpointing"]['checkpoint_must_save_by'], fit_last_save_time=True, ) timeout.start_iterations() diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f238fdd355..4bce368f2a 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -112,7 +112,7 @@ class MasterConfig(TypedDict): logger: GRPOLoggerConfig cluster: ClusterConfig checkpointing: CheckpointingConfig - checkpoint_must_save_by: NotRequired[str] + @@ -479,7 +479,7 @@ def grpo_train( """Run GRPO training algorithm.""" timer = Timer() timeout = TimeoutChecker( - timeout=master_config['checkpoint_must_save_by'], + timeout=master_config["checkpointing"]['checkpoint_must_save_by'], fit_last_save_time=True, ) timeout.start_iterations() diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index db44bdd867..878638e07b 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -78,7 +78,7 @@ class MasterConfig(TypedDict): logger: LoggerConfig cluster: ClusterConfig checkpointing: CheckpointingConfig - checkpoint_must_save_by: NotRequired[str] + # ======================================================= @@ -328,7 +328,7 @@ def sft_train( # Run basic sft training timer = Timer() timeout = TimeoutChecker( - timeout=master_config['checkpoint_must_save_by'], + timeout=master_config["checkpointing"]['checkpoint_must_save_by'], fit_last_save_time=True, ) timeout.start_iterations() diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py index c1c7de8ac4..6acec91cbe 100644 --- a/nemo_rl/utils/timer.py +++ b/nemo_rl/utils/timer.py @@ -292,7 +292,7 @@ def check_save(self): current_time = time.time() elapsed_time = current_time - self.start_time - if self.fit_last_save_time: + if self.fit_last_save_time and self.iteration_times: average_iteration_time = sum(self.iteration_times) / 
len(self.iteration_times)
             if elapsed_time + average_iteration_time >= self.last_save_time:
                 self.last_saved = True

From 421d34cde78231e2a474eea087d795618831b6b9 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Wed, 30 Jul 2025 14:01:17 -0700
Subject: [PATCH 06/12] style: fix formatting via pre-commit

Signed-off-by: Wei Du
---
 nemo_rl/algorithms/dpo.py      | 21 +++++++++++++--------
 nemo_rl/algorithms/grpo.py     | 19 +++++++++++--------
 nemo_rl/algorithms/sft.py      | 21 +++++++++++++--------
 nemo_rl/utils/timer.py         | 29 ++++++++++++++++-------------
 tests/unit/utils/test_timer.py | 12 +++++-------
 5 files changed, 58 insertions(+), 44 deletions(-)

diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py
index 2ff052859a..b39c6712eb 100644
--- a/nemo_rl/algorithms/dpo.py
+++ b/nemo_rl/algorithms/dpo.py
@@ -37,7 +37,7 @@
 from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
 from nemo_rl.utils.logger import Logger, LoggerConfig
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer, TimeoutChecker
+from nemo_rl.utils.timer import TimeoutChecker, Timer


 class DPOSaveState(TypedDict):
@@ -87,7 +87,6 @@ class MasterConfig(TypedDict):
     checkpointing: CheckpointingConfig


-
 # =======================================================
 # Setup & Initialization
 # =======================================================
@@ -356,8 +355,8 @@ def dpo_train(
     # Run dpo training
     timer = Timer()
     timeout = TimeoutChecker(
-        timeout=master_config["checkpointing"]['checkpoint_must_save_by'],
-        fit_last_save_time=True, 
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
     )
     timeout.start_iterations()

@@ -454,13 +453,19 @@ def dpo_train(
                 "train_global_batch_size"
             ]
             timeout.mark_iteration()
-            
-            should_save_by_step = (is_last_step or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0)
+
+            should_save_by_step = (
+                is_last_step
+                or (total_steps + 1) % master_config["checkpointing"]["save_period"]
+                == 0
+            )
             # +1 because step is 0-indexed
             # Check if timeout-based checkpointing is enabled in config.
             should_save_by_timeout = timeout.check_save()
-
-            if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout):
+
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
+            ):
                 dpo_save_state["step"] = (current_step + 1) % len(train_dataloader)
                 dpo_save_state["total_steps"] = total_steps + 1
                 dpo_save_state["epoch"] = current_epoch
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 4bce368f2a..e4d7dc0f52 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -64,7 +64,7 @@
     print_message_log_samples,
 )
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer, TimeoutChecker
+from nemo_rl.utils.timer import TimeoutChecker, Timer

 # ===============================================================================
 # Configuration
@@ -114,8 +114,6 @@ class MasterConfig(TypedDict):
     checkpointing: CheckpointingConfig


-
-
 # ===============================================================================
 # Setup & Initialization
 # ===============================================================================
@@ -479,8 +477,8 @@ def grpo_train(
     """Run GRPO training algorithm."""
     timer = Timer()
     timeout = TimeoutChecker(
-        timeout=master_config["checkpointing"]['checkpoint_must_save_by'],
-        fit_last_save_time=True, 
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
    )
     timeout.start_iterations()

@@ -706,12 +704,17 @@ def grpo_train(
         consumed_samples += master_config["grpo"]["num_prompts_per_step"]
         timeout.mark_iteration()

-        should_save_by_step = (is_last_step or (step + 1) % master_config["checkpointing"]["save_period"] == 0)
+        should_save_by_step = (
+            is_last_step
+            or (step + 1) % master_config["checkpointing"]["save_period"] == 0
+        )
         # +1 because step is 0-indexed
         # Check if timeout-based checkpointing is enabled in config.
         should_save_by_timeout = timeout.check_save()
-
-        if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout):
+
+        if master_config["checkpointing"]["enabled"] and (
+            should_save_by_step or should_save_by_timeout
+        ):
             policy.prepare_for_training()

             grpo_save_state["step"] = step + 1
diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py
index 878638e07b..816141bafc 100644
--- a/nemo_rl/algorithms/sft.py
+++ b/nemo_rl/algorithms/sft.py
@@ -40,7 +40,7 @@
 from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
 from nemo_rl.utils.logger import Logger, LoggerConfig
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
-from nemo_rl.utils.timer import Timer, TimeoutChecker
+from nemo_rl.utils.timer import TimeoutChecker, Timer


 class SFTSaveState(TypedDict):
@@ -80,7 +80,6 @@ class MasterConfig(TypedDict):
     checkpointing: CheckpointingConfig


-
 # =======================================================
 # Setup & Initialization
 # =======================================================
@@ -328,8 +327,8 @@ def sft_train(
     # Run basic sft training
     timer = Timer()
     timeout = TimeoutChecker(
-        timeout=master_config["checkpointing"]['checkpoint_must_save_by'],
-        fit_last_save_time=True, 
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
     )
     timeout.start_iterations()

@@ -446,12 +445,18 @@ def sft_train(
                 "train_global_batch_size"
             ]
             timeout.mark_iteration()
-            should_save_by_step = (is_last_step or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0)
+            should_save_by_step = (
+                is_last_step
+                or (total_steps + 1) % master_config["checkpointing"]["save_period"]
+                == 0
+            )
             # +1 because step is 0-indexed
             # Check if timeout-based checkpointing is enabled in config.
-            should_save_by_timeout = timeout.check_save() 
-
-            if master_config["checkpointing"]["enabled"] and (should_save_by_step or should_save_by_timeout):
+            should_save_by_timeout = timeout.check_save()
+
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
+            ):
                 sft_save_state["step"] = (current_step + 1) % len(train_dataloader)
                 sft_save_state["total_steps"] = total_steps + 1
                 sft_save_state["epoch"] = current_epoch
diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py
index 6acec91cbe..3fb4ffb61c 100644
--- a/nemo_rl/utils/timer.py
+++ b/nemo_rl/utils/timer.py
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import time
 from contextlib import contextmanager
 from typing import Callable, Generator, Optional, Sequence, Union
-import sys
+
 import numpy as np
-from typing import Optional
+

 class Timer:
     """A utility for timing code execution.
@@ -248,8 +249,7 @@ def reset(self, label: Optional[str] = None) -> None:


 def convert_to_seconds(time_string: str) -> int:
-    """
-    Converts a time string in the format 'DD:HH:MM:SS' to total seconds.
+    """Converts a time string in the format 'DD:HH:MM:SS' to total seconds.

     Args:
         time_string (str): Time duration string, e.g., '00:03:45:00'.
@@ -257,29 +257,30 @@ def convert_to_seconds(time_string: str) -> int:
     Returns:
         int: Total time in seconds.
     """
-    days, hours, minutes, seconds = map(int, time_string.split(':'))
+    days, hours, minutes, seconds = map(int, time_string.split(":"))

     return days * 86400 + hours * 3600 + minutes * 60 + seconds


-
 class TimeoutChecker:
-    def __init__(self, timeout: Optional[str] = '00:03:45:00', fit_last_save_time: bool = False):
-        """
-        Initializes the TimeoutChecker.
+    def __init__(
+        self, timeout: Optional[str] = "00:03:45:00", fit_last_save_time: bool = False
+    ):
+        """Initializes the TimeoutChecker.

         Args:
             timeout (str or None): Timeout in format 'DD:HH:MM:SS'. If None, timeout is considered infinite.
             fit_last_save_time (bool): If True, considers average iteration time when checking timeout.
         """
         super().__init__()
-        self.last_save_time = float('inf') if timeout is None else convert_to_seconds(timeout)
+        self.last_save_time = (
+            float("inf") if timeout is None else convert_to_seconds(timeout)
+        )
         self.start_time = time.time()
         self.last_saved = False
         self.iteration_times = []
         self.previous_iteration_time = None
         self.fit_last_save_time = fit_last_save_time
-

     def check_save(self):
         # Flush
@@ -293,7 +294,9 @@ def check_save(self):
         elapsed_time = current_time - self.start_time

         if self.fit_last_save_time and self.iteration_times:
-            average_iteration_time = sum(self.iteration_times) / len(self.iteration_times)
+            average_iteration_time = sum(self.iteration_times) / len(
+                self.iteration_times
+            )
             if elapsed_time + average_iteration_time >= self.last_save_time:
                 self.last_saved = True
                 return True
@@ -314,4 +317,4 @@ def mark_iteration(self):
         current_time = time.time()
         elapsed_time = current_time - self.previous_iteration_time
         self.previous_iteration_time = current_time
-        self.iteration_times.append(elapsed_time)
\ No newline at end of file
+        self.iteration_times.append(elapsed_time)
diff --git a/tests/unit/utils/test_timer.py b/tests/unit/utils/test_timer.py
index a222a5c736..041193b777 100644
--- a/tests/unit/utils/test_timer.py
+++ b/tests/unit/utils/test_timer.py
@@ -18,7 +18,7 @@
 import numpy as np
 import pytest

-from nemo_rl.utils.timer import Timer, TimeoutChecker
+from nemo_rl.utils.timer import TimeoutChecker, Timer


 class TestTimer:
@@ -190,7 +190,6 @@ def test_precise_timing(self, mock_perf_counter, timer):

         assert timer._timers["precise_test"][0] == 5.0

-
 class TestTimeoutChecker:
     def test_infinite_timeout(self):
         checker = TimeoutChecker(timeout=None)
@@ -198,19 +197,19 @@ class TestTimeoutChecker:
         assert checker.check_save() is False

     def test_short_timeout(self):
-        checker = TimeoutChecker(timeout='00:00:00:01')
+        checker = TimeoutChecker(timeout="00:00:00:01")
         time.sleep(1.1)
         assert checker.check_save() is True

     def test_double_save_prevented(self):
-        checker = TimeoutChecker(timeout='00:00:00:01')
+        checker = TimeoutChecker(timeout="00:00:00:01")
         time.sleep(1.1)
         assert checker.check_save() is True
         assert checker.check_save() is False

     def test_fit_last_save_time_enabled(self):
         # Create a TimeoutChecker with a 3-second timeout and enable fit_last_save_time logic
-        checker = TimeoutChecker(timeout='00:00:00:03', fit_last_save_time=True)
+        checker = TimeoutChecker(timeout="00:00:00:03", fit_last_save_time=True)
         checker.start_iterations()

         # Simulate 10 iterations, each taking about 0.1 seconds
@@ -227,11 +226,10 @@ def test_fit_last_save_time_enabled(self):
         # Assert that the checker triggers a save due to timeout
         assert result is True

-
     def test_iteration_tracking(self):
         checker = TimeoutChecker()
         checker.start_iterations()
         time.sleep(0.05)
         checker.mark_iteration()
         assert len(checker.iteration_times) == 1
-        assert checker.iteration_times[0] > 0
\ No newline at end of file
+        assert checker.iteration_times[0] > 0

From 74abdb96756ce679a0146537e42df70c03f37b81 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Wed, 30 Jul 2025 20:46:58 -0700
Subject: [PATCH 07/12] fix: add type annotation and null check for previous_iteration_time

Signed-off-by: Wei Du
---
 nemo_rl/utils/timer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py
index 3fb4ffb61c..5366d3f339 100644
--- a/nemo_rl/utils/timer.py
+++ b/nemo_rl/utils/timer.py
@@ -278,7 +278,7 @@ def __init__(
         self.start_time = time.time()
         self.last_saved = False
         self.iteration_times = []
-        self.previous_iteration_time = None
+        self.previous_iteration_time: Optional[float] = None
         self.fit_last_save_time = fit_last_save_time

     def check_save(self):
@@ -315,6 +315,7 @@ def mark_iteration(self):
         sys.stdout.flush()
         sys.stderr.flush()
         current_time = time.time()
-        elapsed_time = current_time - self.previous_iteration_time
-        self.previous_iteration_time = current_time
-        self.iteration_times.append(elapsed_time)
+        if self.previous_iteration_time is not None:
+            elapsed_time = current_time - self.previous_iteration_time
+            self.iteration_times.append(elapsed_time)
+        self.previous_iteration_time = current_time

From 2dc53469e61078e339b6a94c5581e1bcd4c7f38e Mon Sep 17 00:00:00 2001
From: Terry Kong
Date: Wed, 30 Jul 2025 22:15:24 -0700
Subject: [PATCH 08/12] remove a bunch of unneeded new lines

Signed-off-by: Terry Kong
---
 examples/configs/dpo.yaml | 2 --
 examples/configs/grpo-deepscaler-1.5b-16K.yaml | 2 --
 examples/configs/grpo-deepscaler-1.5b-8K.yaml | 3 ---
 examples/configs/grpo_deepscaler-1.5b-24K.yaml | 2 --
 examples/configs/grpo_math_1B.yaml | 2 --
 examples/configs/grpo_math_1B_megatron.yaml | 2 --
 examples/configs/grpo_math_70B_megatron.yaml | 2 --
 examples/configs/grpo_math_8B.yaml | 2 --
 examples/configs/grpo_math_8B_megatron.yaml | 2 --
 examples/configs/grpo_math_qwen30ba3b_megatron.yaml | 2 --
 examples/configs/grpo_sliding_puzzle.yaml | 2 --
 .../recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml | 2 --
 .../llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml | 1 -
 .../recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml | 1 -
 .../dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml | 1 -
 .../recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml | 1 -
 .../configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml | 1 -
 .../llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml | 2 --
 .../llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml | 1 -
 .../llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml | 1 -
 .../llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml | 1 -
 .../llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml | 1 -
 .../llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml | 1 -
 .../llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml | 1 -
 .../llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml | 1 -
 .../recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml | 1 -
 .../configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml | 1 -
 .../llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml | 1 -
 examples/configs/sft.yaml | 2 --
 examples/configs/sft_openmathinstruct2.yaml | 2 --
 30 files changed, 46 deletions(-)

diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
index f7bbe72f1e..f8627ae7af 100755
--- a/examples/configs/dpo.yaml
+++ b/examples/configs/dpo.yaml
@@ -174,5 +174,3 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
-
-
diff --git a/examples/configs/grpo-deepscaler-1.5b-16K.yaml b/examples/configs/grpo-deepscaler-1.5b-16K.yaml
index 0129f22da8..81988e5902 100644
--- a/examples/configs/grpo-deepscaler-1.5b-16K.yaml
+++ b/examples/configs/grpo-deepscaler-1.5b-16K.yaml
@@ -12,5 +12,3 @@ policy:

   dynamic_batching:
     enabled: False
-
-
diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml
index db480f510d..5c8f2638ea 100644
--- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml
+++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -147,6 +147,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
-
-
diff --git a/examples/configs/grpo_deepscaler-1.5b-24K.yaml b/examples/configs/grpo_deepscaler-1.5b-24K.yaml
index 70125ef54f..dc9db4ceab 100644
--- a/examples/configs/grpo_deepscaler-1.5b-24K.yaml
+++ b/examples/configs/grpo_deepscaler-1.5b-24K.yaml
@@ -46,5 +46,3 @@ policy:
     # For most cases, use "dummy" to load the initial weights, since they will be overwritten during refit
     # For Gemma models, we need to use "auto" due to a vllm bug
     load_format: dummy
-
-
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index b5f4e3799a..852f339cd7 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -159,5 +159,3 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
-
-
diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml
index 7154ed6c9b..fc7bed6e0a 100644
--- a/examples/configs/grpo_math_1B_megatron.yaml
+++ b/examples/configs/grpo_math_1B_megatron.yaml
@@ -175,5 +175,3 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
-
-
diff --git a/examples/configs/grpo_math_70B_megatron.yaml b/examples/configs/grpo_math_70B_megatron.yaml
index 86e4cd2c39..1317e45a04 100644
--- a/examples/configs/grpo_math_70B_megatron.yaml
+++ b/examples/configs/grpo_math_70B_megatron.yaml
@@ -68,5 +68,3 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 8
-
-
diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml
index 1f403e5dc2..6a958957c4 100644
--- a/examples/configs/grpo_math_8B.yaml
+++ b/examples/configs/grpo_math_8B.yaml
@@ -62,5 +62,3 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
-
diff --git a/examples/configs/grpo_math_8B_megatron.yaml b/examples/configs/grpo_math_8B_megatron.yaml
index 5ff13a8874..efa020934e 100644
--- a/examples/configs/grpo_math_8B_megatron.yaml
+++ b/examples/configs/grpo_math_8B_megatron.yaml
@@ -74,5 +74,3 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
-
diff --git a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
index 0c839cce66..84d6736cec 100644
--- a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
+++ b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml
@@ -75,5 +75,3 @@ policy:
 cluster:
   gpus_per_node: 8
   num_nodes: 8
-
-
diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml
index bab614b19c..4bf48e4c36 100644
--- a/examples/configs/grpo_sliding_puzzle.yaml
+++ b/examples/configs/grpo_sliding_puzzle.yaml
@@ -65,5 +65,3 @@ logger:
   gpu_monitoring:
     collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
-
-
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml
index 22ee10c8cb..043f955060 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml
@@ -95,5 +95,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
-
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
index a509e74381..96366c458b 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml
@@ -95,4 +95,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
index 1788eaffe7..8575b5afec 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
@@ -128,4 +128,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
index 7344adc94c..ef3c76dee7 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
@@ -128,4 +128,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml
index 8f62c3ee50..5639d62e1f 100644
--- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml
@@ -95,4 +95,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml
index d8ac0ac7a2..3f25e3344d 100644
--- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml
+++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml
@@ -123,4 +123,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml
index 111eae8f9f..b35c635b2a 100644
--- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml
+++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml
@@ -124,5 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 16
-
-
diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml
index 90893b72d7..9ac33dc7e4 100644
--- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml
@@ -124,4 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml
index 47b7ee9cf6..ffffe82b90 100644
--- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml
@@ -124,4 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml
index 2e0a103298..4935c1c384 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml
@@ -124,4 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 32
-
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml
index 40581f74a8..aaeba34142 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml
@@ -124,4 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 32
-
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml
index 953ca406b1..8af7b39a83 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml
@@ -124,4 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml
index d81e26bc57..3989cc135e 100644
--- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml
+++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml
@@ -124,4 +124,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
index 669714ee5d..118a122dda 100644
--- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml
@@ -74,4 +74,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
index 903a987365..475ce53f25 100644
--- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
@@ -118,4 +118,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml
index 6edb64262b..fa24499b31 100644
--- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml
@@ -74,4 +74,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml
index 66dfdc3f24..45402fc1ae 100644
--- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml
+++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml
@@ -74,4 +74,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 4
-
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index d6bee25cb5..9a342f2e3a 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -157,5 +157,3 @@ logger:
 cluster:
   gpus_per_node: 1
   num_nodes: 1
-
-
diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml
index 63731f37fe..0649767823 100644
--- a/examples/configs/sft_openmathinstruct2.yaml
+++ b/examples/configs/sft_openmathinstruct2.yaml
@@ -92,5 +92,3 @@ logger:
 cluster:
   gpus_per_node: 8
   num_nodes: 1
-
-

From cd306d83c2ecae5c63a1fa939d0c4ef51274d60e Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Thu, 31 Jul 2025 20:18:28 -0700
Subject: [PATCH 09/12] test: add unit test for sft edge case

Signed-off-by: Wei Du
---
 tests/unit/algorithms/test_sft.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py
index 4b6d9ee2ce..6e84e772ec 100644
--- a/tests/unit/algorithms/test_sft.py
+++ b/tests/unit/algorithms/test_sft.py
@@ -78,7 +78,7 @@ def val_iter(self):
             "train_global_batch_size": 1,
             "make_sequence_length_divisible_by": 8,
         },
-        "checkpointing": {"enabled": False},
+        "checkpointing": {"enabled": False, "checkpoint_must_save_by": None},
     }

     return {

From 514b8edf6813c515830728b2c26a07ca4458d5c4 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Mon, 4 Aug 2025 11:01:32 -0700
Subject: [PATCH 10/12] fix unit test error

Signed-off-by: Wei Du
---
 tests/unit/algorithms/test_sft.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py
index 6e84e772ec..538c29ff14 100644
--- a/tests/unit/algorithms/test_sft.py
+++ b/tests/unit/algorithms/test_sft.py
@@ -78,7 +78,11 @@ def val_iter(self):
             "train_global_batch_size": 1,
             "make_sequence_length_divisible_by": 8,
         },
-        "checkpointing": {"enabled": False, "checkpoint_must_save_by": None},
+        "checkpointing": {
+            "enabled": False,
+            "checkpoint_must_save_by": None,
+            "save_period": 10,
+        },
     }

     return {

From 3c55472b4acb2fb3eebc8c4ccfc05d7fac2db170 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Mon, 4 Aug 2025 12:10:04 -0700
Subject: [PATCH 11/12] fix unit test error

Signed-off-by: Wei Du
---
 nemo_rl/utils/checkpoint.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py
index ebc276bba6..cda2b111fc 100644
--- a/nemo_rl/utils/checkpoint.py
+++ b/nemo_rl/utils/checkpoint.py
@@ -49,6 +49,7 @@ class CheckpointingConfig(TypedDict):
     higher_is_better: bool
     save_period: int
     keep_top_k: NotRequired[int]
+    checkpoint_must_save_by: NotRequired[int | None]


 class CheckpointManager:

From 51aaeaa75df65543a4a3c2f394eb089c6cdcab26 Mon Sep 17 00:00:00 2001
From: Wei Du
Date: Mon, 4 Aug 2025 12:13:31 -0700
Subject: [PATCH 12/12] fix unit test error

Signed-off-by: Wei Du
---
 nemo_rl/utils/checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py
index cda2b111fc..388b1ff83c 100644
--- a/nemo_rl/utils/checkpoint.py
+++ b/nemo_rl/utils/checkpoint.py
@@ -49,7 +49,7 @@ class CheckpointingConfig(TypedDict):
     higher_is_better: bool
     save_period: int
     keep_top_k: NotRequired[int]
-    checkpoint_must_save_by: NotRequired[int | None]
+    checkpoint_must_save_by: NotRequired[str | None]


 class CheckpointManager:
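
Applied in order, the series leaves checkpointing.checkpoint_must_save_by as an optional str | None field of CheckpointingConfig, read by the DPO, GRPO, and SFT training loops. A minimal sketch of how a recipe could enable it (the surrounding values are illustrative, not copied from any config in the series):

    checkpointing:
      enabled: true
      save_period: 10
      checkpoint_must_save_by: "00:03:45:00"  # DD:HH:MM:SS; null disables timeout-based saving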
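
The duration strings are parsed by convert_to_seconds in nemo_rl/utils/timer.py: '00:03:45:00' is 0 days, 3 hours, 45 minutes, 0 seconds, i.e. 3 * 3600 + 45 * 60 = 13500 seconds, which leaves a 15-minute margin inside the 4-hour cluster limit named in PATCH 01. A quick sanity check against the helper as it stands after this series:

    from nemo_rl.utils.timer import convert_to_seconds

    # 0 days, 3 hours, 45 minutes, 0 seconds
    assert convert_to_seconds("00:03:45:00") == 3 * 3600 + 45 * 60  # 13500
    # 1 day, 0 hours, 0 minutes, 30 seconds
    assert convert_to_seconds("01:00:00:30") == 86400 + 30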
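
The call pattern the three training loops share is: construct the checker, call start_iterations() once before the loop, call mark_iteration() after every step, and treat check_save() as an additional save trigger alongside the periodic one. The snippet below is a self-contained sketch of that pattern only; the step function, step count, and save_period value are stand-ins, not NeMo-RL APIs:

    import time

    from nemo_rl.utils.timer import TimeoutChecker

    def run_one_training_step() -> None:
        time.sleep(0.1)  # stand-in for real training work

    timeout = TimeoutChecker(
        timeout="00:03:45:00",    # DD:HH:MM:SS budget for the whole run
        fit_last_save_time=True,  # fire one average iteration before the deadline
    )
    timeout.start_iterations()

    save_period = 10
    for step in range(100):
        run_one_training_step()
        timeout.mark_iteration()  # record this iteration's duration
        # check_save() latches after returning True once, so the timeout save is not repeated
        if (step + 1) % save_period == 0 or timeout.check_save():
            print(f"checkpoint would be saved after step {step + 1}")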