diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 4524338e4f..62d322245b 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -26,6 +26,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 50 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index e742480739..b082b66511 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -28,6 +28,7 @@ checkpointing: higher_is_better: true keep_top_k: 10 save_period: 10 + checkpoint_must_save_by: null policy: # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index b9be32bdda..12212288e7 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -28,6 +28,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227) diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index cf6ba44d75..40ee4ee74c 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -29,6 +29,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: "Qwen/Qwen2.5-1.5B" diff --git a/examples/configs/grpo_math_8B_megatron.yaml b/examples/configs/grpo_math_8B_megatron.yaml index 3f68344417..90cfb49e6b 100644 --- a/examples/configs/grpo_math_8B_megatron.yaml +++ b/examples/configs/grpo_math_8B_megatron.yaml @@ -8,6 +8,7 @@ grpo: 
checkpointing: enabled: false checkpoint_dir: "results/grpo_8b_megatron" + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" diff --git a/examples/configs/grpo_sliding_puzzle.yaml b/examples/configs/grpo_sliding_puzzle.yaml index 97f54cc67a..2d9392de95 100644 --- a/examples/configs/grpo_sliding_puzzle.yaml +++ b/examples/configs/grpo_sliding_puzzle.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: "Qwen/Qwen2.5-1.5B-Instruct" diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml index e7eaef706a..f49f6a7886 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp1.v2.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index 4906550001..7d50bd27aa 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml index 789f4fcbdf..2faacb5c39 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml +++ 
b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml index 7d480f58a3..ea5cc67876 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10000 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index 8863fad45f..eb71e87efc 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -21,6 +21,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 50 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index 102c274bd6..9d82de0338 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: google/gemma-3-1b-it tokenizer: diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml 
b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml index ff89e45881..2a607463ed 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: google/gemma-3-27b-it tokenizer: diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index d778674238..da1da9f173 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index ea4f5e66e0..d87d0c8387 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.2-1B-Instruct tokenizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 9b8ecb47b9..011d7ff3b2 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ 
-24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-32B tokenizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index 4a21332a07..deaf512cd4 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-32B tokenizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index 54b60a3cfb..26da69ad31 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-7B-Instruct tokenizer: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index b0930e76c2..b16e9d8bb8 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -24,6 +24,7 @@ checkpointing: higher_is_better: true keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-Math-1.5B-Instruct tokenizer: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml 
b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml index 8535855965..9bde2d6f1e 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp1-long.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: @@ -75,3 +76,4 @@ logger: cluster: gpus_per_node: 8 num_nodes: 1 + diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml index 2eff0aabf6..00489a4d4c 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml index 07f5524000..8b64b71f63 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.1-8B-Instruct tokenizer: diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml index c6311cf357..c0fd12288c 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v2.yaml @@ -14,6 +14,7 @@ 
checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: meta-llama/Llama-3.2-1B tokenizer: diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml index 54d30dd80b..77a1dd51fc 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v2.yaml @@ -14,6 +14,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: Qwen/Qwen2.5-32B tokenizer: diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index a592321cfe..90435051d5 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -19,6 +19,7 @@ checkpointing: higher_is_better: false keep_top_k: 3 save_period: 10 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.2-1B" diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index 1f1b88a8a9..e9f4b595c8 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -16,6 +16,7 @@ checkpointing: higher_is_better: false keep_top_k: 100 save_period: 500 + checkpoint_must_save_by: null policy: model_name: "meta-llama/Llama-3.1-8B" diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 30ba78f6f2..0702502182 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -36,7 +36,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import TimeoutChecker, Timer class DPOSaveState(TypedDict): @@ -364,6 +364,11 @@ def dpo_train( ) -> None: # Run dpo training timer 
= Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() if dpo_save_state is None: dpo_save_state = _default_dpo_save_state() @@ -465,11 +470,20 @@ def dpo_train( dpo_save_state["consumed_samples"] += master_config["policy"][ "train_global_batch_size" ] - if master_config["checkpointing"]["enabled"] and ( + timeout.mark_iteration() + + should_save_by_step = ( is_last_step or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0 - ): # +1 because step is 0-indexed + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): dpo_save_state["step"] = (current_step + 1) % len(train_dataloader) dpo_save_state["total_steps"] = total_steps + 1 dpo_save_state["epoch"] = current_epoch diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index fceb2173c6..e1fb82f81f 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -65,7 +65,7 @@ print_message_log_samples, ) from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import TimeoutChecker, Timer # =============================================================================== # Configuration @@ -487,6 +487,12 @@ def grpo_train( ) -> None: """Run GRPO training algorithm.""" timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() + NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) if policy_generation is None: @@ -707,10 +713,19 @@ def grpo_train( ## Checkpointing consumed_samples += 
master_config["grpo"]["num_prompts_per_step"] - if master_config["checkpointing"]["enabled"] and ( + timeout.mark_iteration() + + should_save_by_step = ( is_last_step or (step + 1) % master_config["checkpointing"]["save_period"] == 0 - ): # +1 because step is 0-indexed + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): policy.prepare_for_training() grpo_save_state["step"] = step + 1 diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 804909c2c4..e2fb884c5e 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -40,7 +40,7 @@ from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.nsys import maybe_gpu_profile_step -from nemo_rl.utils.timer import Timer +from nemo_rl.utils.timer import TimeoutChecker, Timer class SFTSaveState(TypedDict): @@ -326,6 +326,11 @@ def sft_train( ) -> None: # Run basic sft training timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() if sft_save_state is None: sft_save_state = _default_sft_save_state() @@ -439,12 +444,19 @@ def sft_train( sft_save_state["consumed_samples"] += master_config["policy"][ "train_global_batch_size" ] - if master_config["checkpointing"]["enabled"] and ( + timeout.mark_iteration() + should_save_by_step = ( is_last_step or (total_steps + 1) % master_config["checkpointing"]["save_period"] == 0 + ) + # +1 because step is 0-indexed + # Check if timeout-based checkpointing is enabled in config. 
def convert_to_seconds(time_string: str) -> int:
    """Convert a ``'DD:HH:MM:SS'`` duration string to a total number of seconds.

    Args:
        time_string: Duration made of exactly four colon-separated,
            non-negative integer fields (days, hours, minutes, seconds),
            e.g. ``'00:03:45:00'``. Fields are not range-checked, so
            ``'00:00:90:00'`` is accepted and means 90 minutes.

    Returns:
        Total duration in seconds.

    Raises:
        ValueError: If the string does not have four fields, if a field is
            not an integer, or if a field is negative.
    """
    fields = time_string.split(":")
    if len(fields) != 4:
        raise ValueError(
            f"Expected a duration in 'DD:HH:MM:SS' format, got {time_string!r}"
        )
    try:
        days, hours, minutes, seconds = (int(field) for field in fields)
    except ValueError as err:
        raise ValueError(
            f"Expected integer fields in 'DD:HH:MM:SS' format, got {time_string!r}"
        ) from err
    if min(days, hours, minutes, seconds) < 0:
        # A negative duration would make the deadline fire immediately.
        raise ValueError(
            f"Fields of a 'DD:HH:MM:SS' duration must be non-negative, "
            f"got {time_string!r}"
        )
    return days * 86400 + hours * 3600 + minutes * 60 + seconds


class TimeoutChecker:
    """Report, at most once, that a wall-clock deadline has been reached.

    The deadline is measured from the moment the checker is constructed.
    ``check_save`` returns ``True`` the first time the deadline is reached
    (optionally one average iteration early) and ``False`` on every other
    call, so a caller can use it to trigger a single last save.

    Attributes are plain instance variables:
        last_save_time: Deadline in seconds after ``start_time``
            (``inf`` when no timeout was given).
        start_time: ``time.time()`` captured in ``__init__``.
        last_saved: ``True`` once ``check_save`` has returned ``True``.
        iteration_times: Durations, in seconds, recorded by ``mark_iteration``.
        previous_iteration_time: Timestamp of the latest iteration boundary,
            or ``None`` before the first one.
        fit_last_save_time: Whether the average iteration time is added to the
            elapsed time when testing the deadline.
    """

    def __init__(
        self, timeout: Optional[str] = "00:03:45:00", fit_last_save_time: bool = False
    ):
        """Initializes the TimeoutChecker.

        Args:
            timeout (str or None): Timeout in format 'DD:HH:MM:SS'. If None,
                timeout is considered infinite.
            fit_last_save_time (bool): If True, considers average iteration
                time when checking timeout.

        Raises:
            ValueError: If ``timeout`` is a malformed duration string.
        """
        self.last_save_time: float = (
            float("inf") if timeout is None else convert_to_seconds(timeout)
        )
        self.start_time: float = time.time()
        self.last_saved: bool = False
        self.iteration_times: list[float] = []
        self.previous_iteration_time: Optional[float] = None
        self.fit_last_save_time: bool = fit_last_save_time

    def check_save(self) -> bool:
        """Return ``True`` exactly once, when the deadline has been reached.

        Flushes ``sys.stdout`` and ``sys.stderr`` on every call. When
        ``fit_last_save_time`` is set and at least one iteration has been
        recorded, the mean recorded iteration time is added to the elapsed
        time, so the deadline is reported before an iteration that would
        overrun it.

        Returns:
            ``True`` the first time the (possibly padded) elapsed time is at
            least ``last_save_time``; ``False`` otherwise, including on every
            call after that first ``True``.
        """
        sys.stdout.flush()
        sys.stderr.flush()

        # The deadline was already reported once; never report it again.
        if self.last_saved:
            return False

        elapsed = time.time() - self.start_time

        margin = 0.0
        if self.fit_last_save_time and self.iteration_times:
            average = sum(self.iteration_times) / len(self.iteration_times)
            # Clamp at zero so a negative average (wall clock stepped
            # backwards) can never postpone the plain elapsed-time deadline.
            margin = max(average, 0.0)

        if elapsed + margin >= self.last_save_time:
            self.last_saved = True
            return True
        return False

    def start_iterations(self) -> None:
        """Record the current time as the start of the first iteration."""
        self.previous_iteration_time = time.time()

    def mark_iteration(self) -> None:
        """Record the end of an iteration and append its duration.

        Flushes ``sys.stdout`` and ``sys.stderr``. The duration is the time
        since the previous boundary (``start_iterations`` or the previous
        ``mark_iteration``). If no boundary exists yet, nothing is appended
        and this call becomes the first boundary, so later calls still
        record durations.
        """
        sys.stdout.flush()
        sys.stderr.flush()

        now = time.time()
        if self.previous_iteration_time is not None:
            self.iteration_times.append(now - self.previous_iteration_time)
        self.previous_iteration_time = now
class TestTimeoutChecker:
    """Unit tests for ``TimeoutChecker`` that never sleep.

    Elapsed time is simulated by moving the checker's recorded timestamps
    (``start_time`` and ``previous_iteration_time``) into the past. The
    checker only ever looks at differences between "now" and those
    timestamps, so shifting them is equivalent to waiting, but the suite
    runs in milliseconds and does not depend on scheduler timing.
    """

    @staticmethod
    def _age(checker: "TimeoutChecker", seconds: float) -> None:
        """Make ``checker`` behave as if it was created ``seconds`` earlier."""
        checker.start_time -= seconds

    @staticmethod
    def _record_iteration(checker: "TimeoutChecker", seconds: float) -> None:
        """Record one iteration that appears to have lasted ``seconds``."""
        checker.previous_iteration_time -= seconds
        checker.mark_iteration()

    def test_infinite_timeout(self):
        checker = TimeoutChecker(timeout=None)
        self._age(checker, 10**9)
        assert checker.check_save() is False

    def test_not_expired_before_deadline(self):
        checker = TimeoutChecker(timeout="00:01:00:00")
        assert checker.check_save() is False

    def test_short_timeout(self):
        checker = TimeoutChecker(timeout="00:00:00:01")
        self._age(checker, 1.1)
        assert checker.check_save() is True

    def test_double_save_prevented(self):
        checker = TimeoutChecker(timeout="00:00:00:01")
        self._age(checker, 1.1)
        assert checker.check_save() is True
        assert checker.check_save() is False

    def test_fit_last_save_time_enabled(self):
        # 8.5 s elapsed is short of the 10 s deadline on its own; only the
        # ~2 s average iteration time pushes it over.
        checker = TimeoutChecker(timeout="00:00:00:10", fit_last_save_time=True)
        checker.start_iterations()
        for _ in range(3):
            self._record_iteration(checker, 2.0)
        self._age(checker, 8.5)
        assert checker.check_save() is True

        # Control: the same timeline without the fit logic must not trigger,
        # which proves the assertion above exercises the fit path.
        control = TimeoutChecker(timeout="00:00:00:10", fit_last_save_time=False)
        control.start_iterations()
        for _ in range(3):
            self._record_iteration(control, 2.0)
        self._age(control, 8.5)
        assert control.check_save() is False

    def test_iteration_tracking(self):
        checker = TimeoutChecker()
        checker.start_iterations()
        self._record_iteration(checker, 0.05)
        assert len(checker.iteration_times) == 1
        assert checker.iteration_times[0] > 0