diff --git a/docs/training/checkpointing.md b/docs/training/checkpointing.md index 0de08a3b9a..e7bdcfea03 100644 --- a/docs/training/checkpointing.md +++ b/docs/training/checkpointing.md @@ -277,6 +277,147 @@ Local checkpointing leverages the [NVIDIA Resiliency Extension](https://nvidia.g | `dist_ckpt_optim_fully_reshardable` | `bool` | `False` | Make optimizer distributed checkpoint fully reshardable (TP/PP/EP/DP) as opposed to plain DP reshardability | | `distrib_optim_fully_reshardable_mem_efficient` | `bool` | `False` | Use as little memory as possible during save and load by using Gloo. Has affect only with `dist_ckpt_optim_fully_reshardable` flag | +## Custom Checkpoint Manager + +For advanced use cases, you can provide a custom checkpoint manager implementation to override the default save/load behavior. This enables integration with custom storage backends, alternative checkpoint formats, or organization-specific checkpointing workflows. + +### Configuration + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `custom_manager_class` | `str \| None` | `None` | Fully qualified class name for a custom `CheckpointManager` implementation | + +### Usage + +Specify a custom checkpoint manager class in your configuration: + +**YAML:** +```yaml +checkpoint: + save: /path/to/checkpoints + custom_manager_class: "mypackage.checkpoint.MyCheckpointManager" +``` + +**Python:** +```python +from megatron.bridge.training.config import CheckpointConfig + +checkpoint = CheckpointConfig( + save="/path/to/checkpoints", + custom_manager_class="mypackage.checkpoint.MyCheckpointManager", +) +``` + +### Implementing a Custom Manager + +Your custom manager must implement the `CheckpointManager` protocol defined in `megatron.bridge.training.checkpointing`: + +```python +from megatron.bridge.training.checkpointing import ( + CheckpointManager, + CheckpointSaveContext, + CheckpointLoadContext, + save_checkpoint, + load_checkpoint, + 
init_checkpointing_context, +) +from megatron.bridge.training.config import CheckpointConfig +from megatron.bridge.training.state import GlobalState + + +class MyCheckpointManager: + """Custom checkpoint manager example.""" + + def __init__(self, checkpoint_config: CheckpointConfig) -> None: + self.checkpoint_config = checkpoint_config + # Initialize internal context for caching strategies + self._context = init_checkpointing_context(checkpoint_config) + + def save(self, ctx: CheckpointSaveContext) -> None: + """Save a checkpoint with custom logic.""" + # Option 1: Completely custom implementation + # my_custom_save(ctx.state, ctx.model, ...) + + # Option 2: Wrap the default implementation + save_checkpoint( + state=ctx.state, + model=ctx.model, + optimizer=ctx.optimizer, + opt_param_scheduler=ctx.opt_param_scheduler, + num_floating_point_operations_so_far=ctx.num_floating_point_operations_so_far, + checkpointing_context=self._context, + non_persistent_ckpt=ctx.non_persistent_ckpt, + train_data_iterator=ctx.train_data_iterator, + ) + # Add custom post-processing (e.g., upload to cloud) + upload_to_s3(ctx.state.cfg.checkpoint.save) + + def load(self, ctx: CheckpointLoadContext) -> tuple[int, int]: + """Load a checkpoint with custom logic.""" + # Returns (iteration, num_floating_point_operations_so_far) + return load_checkpoint( + state=ctx.state, + model=ctx.model, + optimizer=ctx.optimizer, + opt_param_scheduler=ctx.opt_param_scheduler, + strict=ctx.strict, + checkpointing_context=self._context, + skip_load_to_model_and_opt=ctx.skip_load_to_model_and_opt, + ) + + def finalize_async_saves( + self, state: GlobalState, blocking: bool = False, terminate: bool = False + ) -> None: + """Finalize any pending asynchronous saves.""" + from megatron.bridge.training.checkpointing import maybe_finalize_async_save + + maybe_finalize_async_save( + global_state=state, + ckpt_cfg=self.checkpoint_config, + blocking=blocking, + terminate=terminate, + ) +``` + +### Context 
Dataclasses + +The save and load methods receive context dataclasses that bundle all required parameters: + +**`CheckpointSaveContext`:** +| Field | Type | Description | +|-------|------|-------------| +| `state` | `GlobalState` | Global training state (config, train_state, loggers) | +| `model` | `list[MegatronModule]` | Model modules to save | +| `optimizer` | `MegatronOptimizer \| None` | Optimizer instance | +| `opt_param_scheduler` | `Any \| None` | Learning rate scheduler | +| `num_floating_point_operations_so_far` | `int` | Cumulative FLOPs | +| `train_data_iterator` | `Any \| None` | Data iterator (optional) | +| `non_persistent_ckpt` | `bool` | Whether this is a non-persistent checkpoint | + +**`CheckpointLoadContext`:** +| Field | Type | Description | +|-------|------|-------------| +| `state` | `GlobalState` | Global training state | +| `model` | `list[MegatronModule]` | Model modules to load into | +| `optimizer` | `MegatronOptimizer \| None` | Optimizer instance | +| `opt_param_scheduler` | `Any \| None` | Learning rate scheduler | +| `strict` | `bool` | Enforce strict loading (default: `True`) | +| `skip_load_to_model_and_opt` | `bool` | Skip loading into model/optimizer (default: `False`) | + +### Limitations + +The custom checkpoint manager is designed for customizing the save/load **operations** during training. The following limitations apply: + +**Checkpoint format compatibility**: Custom managers that change the checkpoint directory structure or metadata files (e.g., `latest_train_state.pt`, `run_config.yaml`) are not well supported. Many utilities in Megatron Bridge assume the standard Megatron checkpoint format. For instance, HuggingFace ↔ custom format conversion is not supported. + +**PEFT with custom checkpoints**: When using PEFT (Parameter-Efficient Fine-Tuning), the `pretrained_checkpoint` path must point to a Megatron-format checkpoint. 
The custom manager only applies to the training save/load flow (the `save` and `load` configuration paths), not to base model loading for PEFT. + +**Inference loading**: Loading checkpoints for inference via `model_load_save.py` utilities is undefined behavior with custom checkpoint formats. Use your custom format's loading utilities instead. + +### Default Behavior + +When `custom_manager_class` is not set, Megatron Bridge uses `DefaultCheckpointManager`, which wraps the existing `save_checkpoint` and `load_checkpoint` functions. This ensures full backward compatibility—the checkpoint manager abstraction introduces no changes to existing training workflows. + ## Related Documentation - {doc}`megatron-fsdp` - Megatron FSDP configuration and `fsdp_dtensor` format requirements diff --git a/src/megatron/bridge/training/checkpointing.py b/src/megatron/bridge/training/checkpointing.py index 00501a157f..111d55f180 100644 --- a/src/megatron/bridge/training/checkpointing.py +++ b/src/megatron/bridge/training/checkpointing.py @@ -20,11 +20,12 @@ import shutil import sys import threading +from dataclasses import dataclass from enum import Enum, auto from logging import getLogger from pathlib import Path from time import time -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Protocol, Union, runtime_checkable import numpy as np import torch @@ -464,6 +465,256 @@ class CheckpointType(Enum): FSDP_DTENSOR = auto() +@dataclass +class CheckpointSaveContext: + """Context containing all state needed for a checkpoint save operation. + + Attributes: + state: The GlobalState object containing config, train_state, etc. + model: List of model modules (MegatronModule instances). + optimizer: The optimizer instance (may be None for inference checkpoints). + opt_param_scheduler: The learning rate scheduler instance. + num_floating_point_operations_so_far: Cumulative FLOPs computed up to this point. 
+ train_data_iterator: Optional training data iterator to save its state. + non_persistent_ckpt: If True, saves as a non-persistent (temporary) checkpoint. + """ + + state: GlobalState + model: list[MegatronModule] + optimizer: MegatronOptimizer | None + opt_param_scheduler: Any | None + num_floating_point_operations_so_far: int + + train_data_iterator: Any | None = None + non_persistent_ckpt: bool = False + + +@dataclass +class CheckpointLoadContext: + """Context containing all state needed for a checkpoint load operation. + + Attributes: + state: The GlobalState object containing config, train_state, etc. + model: List of model modules to load state into. + optimizer: The optimizer instance to load state into. + opt_param_scheduler: The learning rate scheduler instance. + strict: Whether to enforce strict loading (see torch.nn.Module.load_state_dict). + skip_load_to_model_and_opt: If True, only loads metadata but skips loading + state into model and optimizer modules. + """ + + state: GlobalState + model: list[MegatronModule] + optimizer: MegatronOptimizer | None + opt_param_scheduler: Any | None + + strict: bool = True + skip_load_to_model_and_opt: bool = False + + +@runtime_checkable +class CheckpointManager(Protocol): + """Protocol defining the checkpoint manager interface. + + Implement this protocol to create custom checkpoint save/load behavior. + The default implementation (DefaultCheckpointManager) delegates to the + existing functional checkpoint code. + """ + + def __init__(self, checkpoint_config: CheckpointConfig) -> None: + """Initialize the checkpoint manager. + + Args: + checkpoint_config: The checkpoint configuration. + """ + ... + + def save(self, ctx: CheckpointSaveContext) -> None: + """Save a checkpoint. + + Args: + ctx: CheckpointSaveContext containing all state needed for save. + """ + ... + + def load(self, ctx: CheckpointLoadContext) -> tuple[int, int]: + """Load a checkpoint. 
+ + Args: + ctx: CheckpointLoadContext containing all state needed for load. + + Returns: + A tuple of (iteration, num_floating_point_operations_so_far). + Returns (0, 0) if no checkpoint was loaded. + """ + ... + + def finalize_async_saves(self, state: GlobalState, blocking: bool = False, terminate: bool = False) -> None: + """Finalize any pending asynchronous checkpoint saves. + + Args: + state: The GlobalState object (needed for async_calls_queue access). + blocking: If True, wait for all pending saves to complete. + terminate: If True, close the async queue after finalization. + """ + ... + + +class DefaultCheckpointManager: + """Default checkpoint manager that delegates to existing functional code. + + This implementation wraps the default save_checkpoint and load_checkpoint + functions. + + The manager owns the checkpointing_context dictionary, which is used to + cache strategies and local checkpoint managers across save/load operations. + + Attributes: + checkpoint_config: The CheckpointConfig instance. + _context: Internal context dictionary for caching checkpoint strategies. + """ + + def __init__(self, checkpoint_config: CheckpointConfig) -> None: + """Initialize the checkpoint manager. + + Args: + checkpoint_config: The checkpoint configuration. + """ + self.checkpoint_config = checkpoint_config + self._context: dict[str, Any] = init_checkpointing_context(checkpoint_config) + + @property + def checkpointing_context(self) -> dict[str, Any]: + """The internal checkpointing context dictionary. + + This context is passed to save/load functions and caches: + - Save/load strategies for distributed checkpointing + - Local checkpoint manager (for non-persistent local checkpoints) + - Cached metadata for constant-structure optimization + """ + return self._context + + def save(self, ctx: CheckpointSaveContext) -> None: + """Save a checkpoint using the default implementation. + + Delegates to save_checkpoint function. 
+ + Args: + ctx: CheckpointSaveContext containing all state needed for save. + """ + save_checkpoint( + state=ctx.state, + model=ctx.model, + optimizer=ctx.optimizer, + opt_param_scheduler=ctx.opt_param_scheduler, + num_floating_point_operations_so_far=ctx.num_floating_point_operations_so_far, + checkpointing_context=self._context, + non_persistent_ckpt=ctx.non_persistent_ckpt, + train_data_iterator=ctx.train_data_iterator, + ) + + def load(self, ctx: CheckpointLoadContext) -> tuple[int, int]: + """Load a checkpoint using the default implementation. + + Delegates to load_checkpoint function. + + Args: + ctx: CheckpointLoadContext containing all state needed for load. + + Returns: + A tuple of (iteration, num_floating_point_operations_so_far). + """ + return load_checkpoint( + state=ctx.state, + model=ctx.model, + optimizer=ctx.optimizer, + opt_param_scheduler=ctx.opt_param_scheduler, + strict=ctx.strict, + checkpointing_context=self._context, + skip_load_to_model_and_opt=ctx.skip_load_to_model_and_opt, + ) + + def finalize_async_saves(self, state: GlobalState, blocking: bool = False, terminate: bool = False) -> None: + """Finalize any pending asynchronous checkpoint saves. + + Args: + state: The GlobalState object (needed for async_calls_queue access). + blocking: If True, wait for all pending saves to complete. + terminate: If True, close the async queue after finalization. + """ + maybe_finalize_async_save( + global_state=state, + ckpt_cfg=self.checkpoint_config, + blocking=blocking, + terminate=terminate, + ) + + +def create_checkpoint_manager(checkpoint_config: CheckpointConfig) -> CheckpointManager: + """Factory function to create a checkpoint manager. + + Creates either the default checkpoint manager or a custom manager + based on the checkpoint_config.custom_manager_class setting. + + Args: + checkpoint_config: The checkpoint configuration. If custom_manager_class is set, + it should be a fully qualified class name (e.g., "mypackage.module.MyManager"). 
+ + Returns: + A CheckpointManager instance. + + Raises: + ImportError: If the custom manager module cannot be imported. + AttributeError: If the custom manager class is not found in the module. + ValueError: If custom_manager_class format is invalid. + TypeError: If the custom manager does not implement the CheckpointManager protocol. + + Example: + # Default manager + config = CheckpointConfig(save="/path/to/checkpoints") + manager = create_checkpoint_manager(config) + + # Custom manager + config = CheckpointConfig( + save="/path/to/checkpoints", + custom_manager_class="mypackage.checkpoint.MyCheckpointManager", + ) + manager = create_checkpoint_manager(config) + """ + if checkpoint_config.custom_manager_class is not None: + import importlib + + try: + module_path, class_name = checkpoint_config.custom_manager_class.rsplit(".", 1) + except ValueError as err: + raise ValueError( + f"Invalid custom_manager_class format: '{checkpoint_config.custom_manager_class}'. " + f"Expected fully qualified class name like 'mypackage.module.ClassName'." + ) from err + + try: + module = importlib.import_module(module_path) + except ImportError as e: + raise ImportError(f"Could not import module '{module_path}' for custom checkpoint manager: {e}") from e + + try: + custom_manager_class = getattr(module, class_name) + except AttributeError as e: + raise AttributeError(f"Module '{module_path}' does not have class '{class_name}': {e}") from e + + manager = custom_manager_class(checkpoint_config) + + if not isinstance(manager, CheckpointManager): + raise TypeError( + f"Custom checkpoint manager '{checkpoint_config.custom_manager_class}' " + f"does not implement the CheckpointManager protocol." 
+ ) + + return manager + + return DefaultCheckpointManager(checkpoint_config) + + def save_checkpoint( state: GlobalState, model: list[MegatronModule], diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py index 1afecd838c..daf6853f43 100644 --- a/src/megatron/bridge/training/config.py +++ b/src/megatron/bridge/training/config.py @@ -985,6 +985,16 @@ class CheckpointConfig: replication_factor: int = 2 """Number of machines storing the replica of a given rank's data.""" + custom_manager_class: str | None = None + """Fully qualified class name for a custom CheckpointManager implementation. + + When set, checkpoint operations will instantiate and delegate to this class instead of the default + checkpoint manager. The custom class must implement the `CheckpointManager` protocol + defined in `megatron.bridge.training.checkpointing`. + + Example: ``'mypackage.checkpoint.MyCheckpointManager'`` + """ + def finalize(self) -> None: """Post-initialization checks for checkpoint config.""" if self.pretrained_checkpoint is not None: diff --git a/src/megatron/bridge/training/pretrain.py b/src/megatron/bridge/training/pretrain.py index e8a5f91ce1..6c26bf78b1 100644 --- a/src/megatron/bridge/training/pretrain.py +++ b/src/megatron/bridge/training/pretrain.py @@ -133,7 +133,7 @@ def _pretrain( train_data_iterator = setup_output.train_data_iterator valid_data_iterator = setup_output.valid_data_iterator test_data_iterator = setup_output.test_data_iterator - ckpt_context = setup_output.checkpointing_context + checkpoint_manager = setup_output.checkpoint_manager pg_collection = setup_output.pg_collection # TRAINING @@ -147,7 +147,7 @@ def _pretrain( train_data_iterator, valid_data_iterator, state, - ckpt_context, + checkpoint_manager, pg_collection, callback_manager=callback_manager, ) @@ -188,7 +188,7 @@ def _pretrain( is_test=True, ) - _finish_train(state) + _finish_train(state, checkpoint_manager) 
_maybe_destroy_process_group(should_destroy_process_group) diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py index bc6fa199c9..dadba57649 100644 --- a/src/megatron/bridge/training/setup.py +++ b/src/megatron/bridge/training/setup.py @@ -40,8 +40,9 @@ from megatron.bridge.training.checkpointing import ( _load_checkpoint_from_path, checkpoint_exists, - init_checkpointing_context, - load_checkpoint, + CheckpointLoadContext, + CheckpointManager, + create_checkpoint_manager, ) from megatron.bridge.training.config import ConfigContainer, runtime_config_update from megatron.bridge.training.initialize import initialize_megatron, set_jit_fusion_options @@ -69,8 +70,7 @@ class SetupOutput(NamedTuple): train_data_iterator: The data iterator for the training dataset, if applicable. valid_data_iterator: The data iterator for the validation dataset, if applicable. test_data_iterator: The data iterator for the testing dataset, if applicable. - checkpointing_context: A dictionary holding context for checkpointing operations, - especially for non-persistent local checkpointing. + checkpoint_manager: The checkpoint manager for save/load operations. pg_collection: The process group collection initialized for this run. """ @@ -81,7 +81,7 @@ class SetupOutput(NamedTuple): train_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]] valid_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]] test_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]] - checkpointing_context: dict[str, Any] + checkpoint_manager: CheckpointManager pg_collection: ProcessGroupCollection @@ -173,8 +173,8 @@ def setup( print_rank_0("time to initialize megatron (seconds): {:.3f}".format(time.time() - state.start_time)) barrier_and_log("after megatron is initialized") - # Context used for persisting some state between checkpoint saves. 
- checkpointing_context = init_checkpointing_context(cfg.checkpoint) + # Create checkpoint manager for save/load operations. + checkpoint_manager = create_checkpoint_manager(cfg.checkpoint) # Tokenizer timers("tokenizer-setup", log_level=0).start(barrier=True) @@ -244,9 +244,10 @@ def modelopt_pre_wrap_hook(model): # checkpoints are independent of global ones — they don't write # latest_train_state.pt to load_dir, so checkpoint_exists() won't # find them. + _ckpt_ctx = getattr(checkpoint_manager, "checkpointing_context", {}) has_local_checkpoint = ( - "local_checkpoint_manager" in checkpointing_context - and checkpointing_context["local_checkpoint_manager"].find_latest() != -1 + "local_checkpoint_manager" in _ckpt_ctx + and _ckpt_ctx["local_checkpoint_manager"].find_latest() != -1 ) # For PEFT, the pretrained checkpoint is loaded in the pre-wrap hook @@ -265,14 +266,13 @@ def modelopt_pre_wrap_hook(model): if should_load_checkpoint: timers("load-checkpoint", log_level=0).start(barrier=True) - load_checkpoint( - state, - model, - optimizer, - scheduler, - checkpointing_context=checkpointing_context, + checkpoint_manager.load(CheckpointLoadContext( + state=state, + model=model, + optimizer=optimizer, + opt_param_scheduler=scheduler, skip_load_to_model_and_opt=cfg.dist.use_torch_fsdp2 or cfg.dist.use_megatron_fsdp, - ) + )) timers("load-checkpoint").stop(barrier=True) timers.log(["load-checkpoint"]) @@ -329,7 +329,7 @@ def modelopt_pre_wrap_hook(model): train_data_iterator, valid_data_iterator, test_data_iterator, - checkpointing_context, + checkpoint_manager, pg_collection, ) diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py index 858a2bb75a..250001522e 100644 --- a/src/megatron/bridge/training/train.py +++ b/src/megatron/bridge/training/train.py @@ -53,7 +53,10 @@ from megatron.bridge.data.iterator_utils import make_data_iterator_list from megatron.bridge.training import fault_tolerance from megatron.bridge.training.callbacks 
import CallbackContext, CallbackManager, should_fire -from megatron.bridge.training.checkpointing import maybe_finalize_async_save, save_checkpoint +from megatron.bridge.training.checkpointing import ( + CheckpointManager, + CheckpointSaveContext, +) from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.eval import evaluate_and_print_results from megatron.bridge.training.forward_step_func_types import ForwardStepCallable @@ -93,7 +96,7 @@ def train( train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]], valid_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]], global_state: GlobalState, - checkpointing_context: dict[str, Any], + checkpoint_manager: CheckpointManager, pg_collection: ProcessGroupCollection, process_non_loss_data_func: Optional[Callable] = None, non_loss_data_func: Optional[Callable] = None, @@ -112,7 +115,7 @@ def train( train_data_iterator: Iterator for the training dataset. valid_data_iterator: Iterator for the validation dataset. global_state: The GlobalState object holding various training states. - checkpointing_context: Context dictionary for checkpointing. + checkpoint_manager: The checkpoint manager for save/load operations. process_non_loss_data_func: Optional function to process non-loss data during evaluation. non_loss_data_func: Optional function to compute non-loss data during evaluation. callback_manager: Optional CallbackManager for custom callback execution. 
@@ -296,7 +299,7 @@ def train( nsys_nvtx_context = nvtx_ctx fault_tolerance.on_checkpointing_start(global_state) - maybe_finalize_async_save(global_state=global_state, ckpt_cfg=config.checkpoint, blocking=False) + checkpoint_manager.finalize_async_saves(state=global_state, blocking=False) fault_tolerance.on_checkpointing_end(global_state=global_state, is_async_finalization=True) # Update the timeout for all process groups after initialization @@ -324,7 +327,7 @@ def train( optimizer, scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, non_persistent_ckpt=False, # TODO: implement non-persistent checkpointing train_data_iterator=train_data_iterator, ) @@ -415,7 +418,7 @@ def train( optimizer, scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, non_persistent_ckpt=False, # TODO: implement non-persistent checkpointing ) @@ -582,7 +585,7 @@ def train( optimizer, scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator, ) if should_exit: @@ -605,7 +608,7 @@ def train( optimizer, scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, ) @@ -623,7 +626,7 @@ def train( # This will finalize all unfinalized async request and terminate # a persistent async worker if persistent ckpt worker is enabled fault_tolerance.on_checkpointing_start(global_state) - maybe_finalize_async_save(global_state=global_state, ckpt_cfg=config.checkpoint, blocking=True, terminate=True) + checkpoint_manager.finalize_async_saves(state=global_state, blocking=True, terminate=True) fault_tolerance.on_checkpointing_end(global_state=global_state, is_async_finalization=True) # Shutdown NVRx straggler detection if enabled @@ -639,7 +642,7 @@ def train( if should_exit: # Close NVIDIA DLFw Inspect if enabled 
tensor_inspect_end_if_enabled(config.tensor_inspect) - maybe_finalize_async_save(global_state=global_state, ckpt_cfg=config.checkpoint, blocking=True, terminate=True) + checkpoint_manager.finalize_async_saves(state=global_state, blocking=True, terminate=True) wandb_writer = global_state.wandb_logger if wandb_writer: wandb_writer.finish() @@ -1091,13 +1094,13 @@ def save_checkpoint_and_time( optimizer: MegatronOptimizer, opt_param_scheduler: OptimizerParamScheduler, num_floating_point_operations_so_far: float, - checkpointing_context: dict[str, Any], + checkpoint_manager: CheckpointManager, non_persistent_ckpt: bool = False, train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]] = None, ) -> None: """Saves a checkpoint and logs the timing. - Wraps the `save_checkpoint` function with timers and forces parameter + Wraps the checkpoint manager's save method with timers and forces parameter synchronization when using distributed optimizer with overlapped parameter gather to ensure checkpoint correctness. @@ -1107,7 +1110,7 @@ def save_checkpoint_and_time( optimizer: The optimizer instance. opt_param_scheduler: The optimizer parameter scheduler instance. num_floating_point_operations_so_far: Cumulative Model TFLOPs up to this point. - checkpointing_context: Dictionary holding checkpointing-related state. + checkpoint_manager: The checkpoint manager for save operations. non_persistent_ckpt: Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False. train_data_iterator: Optional training data iterator to save its state. 
@@ -1142,16 +1145,18 @@ def save_checkpoint_and_time( model_chunk.free_overlap_buffers() torch.cuda.empty_cache() - save_checkpoint( - state, - model, - optimizer, - opt_param_scheduler, - num_floating_point_operations_so_far, - checkpointing_context=checkpointing_context, - non_persistent_ckpt=non_persistent_ckpt, - train_data_iterator=train_data_iterator, + checkpoint_manager.save( + CheckpointSaveContext( + state=state, + model=model, + optimizer=optimizer, + opt_param_scheduler=opt_param_scheduler, + num_floating_point_operations_so_far=int(num_floating_point_operations_so_far), + train_data_iterator=train_data_iterator, + non_persistent_ckpt=non_persistent_ckpt, + ) ) + if state.cfg.model.fp8 is not None: # Run garbage collection after checkpoint saving to free memory from # dequantized bf16 tensors that were temporarily created during fp8 @@ -1175,7 +1180,7 @@ def checkpoint_and_decide_exit( optimizer: MegatronOptimizer, opt_param_scheduler: OptimizerParamScheduler, num_floating_point_operations_so_far: float, - checkpointing_context: dict[str, Any], + checkpoint_manager: CheckpointManager, train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]], ) -> bool: """Handles checkpointing decisions and determines if training should exit. @@ -1190,7 +1195,7 @@ def checkpoint_and_decide_exit( optimizer: The optimizer instance. opt_param_scheduler: The optimizer parameter scheduler instance. num_floating_point_operations_so_far: Cumulative TFLOPs up to this point. - checkpointing_context: Dictionary holding checkpointing-related state. + checkpoint_manager: The checkpoint manager for save operations. train_data_iterator: Optional training data iterator to save its state. 
Returns: @@ -1209,7 +1214,7 @@ def checkpoint_and_decide_exit( optimizer, opt_param_scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, ) barrier_and_log("exiting program after receiving SIGTERM.") @@ -1228,7 +1233,7 @@ def checkpoint_and_decide_exit( optimizer, opt_param_scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, ) saved_checkpoint = True @@ -1244,7 +1249,7 @@ def checkpoint_and_decide_exit( optimizer, opt_param_scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, non_persistent_ckpt=True, train_data_iterator=train_data_iterator, ) @@ -1264,7 +1269,7 @@ def checkpoint_and_decide_exit( optimizer, opt_param_scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, ) barrier_and_log(f"exiting program after {train_time} minutes") @@ -1280,7 +1285,7 @@ def checkpoint_and_decide_exit( optimizer, opt_param_scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, ) barrier_and_log(f"exiting program at iteration {state.train_state.step}") @@ -1296,7 +1301,7 @@ def checkpoint_and_decide_exit( optimizer, opt_param_scheduler, num_floating_point_operations_so_far, - checkpointing_context, + checkpoint_manager, train_data_iterator=train_data_iterator, ) barrier_and_log("Exiting program due to straggler detection.") @@ -1305,14 +1310,18 @@ def checkpoint_and_decide_exit( return False -def _finish_train(global_state: GlobalState): - ckpt_cfg = global_state.cfg.checkpoint +def _finish_train(global_state: GlobalState, checkpoint_manager: CheckpointManager): + """Cleanup function called at the end of training. + Args: + global_state: The global training state. 
+ checkpoint_manager: The checkpoint manager for finalizing async saves. + """ # Shutdown NVRx straggler detection if enabled safe_shutdown_nvrx_straggler_manager(global_state.nvrx_straggler_manager) fault_tolerance.on_checkpointing_start(global_state) - maybe_finalize_async_save(global_state=global_state, blocking=True, terminate=True, ckpt_cfg=ckpt_cfg) + checkpoint_manager.finalize_async_saves(state=global_state, blocking=True, terminate=True) fault_tolerance.on_checkpointing_end(global_state=global_state, is_async_finalization=True) fault_tolerance.shutdown(global_state) diff --git a/tests/unit_tests/training/test_checkpointing.py b/tests/unit_tests/training/test_checkpointing.py index 3eb3dd7d6d..9cc3743902 100644 --- a/tests/unit_tests/training/test_checkpointing.py +++ b/tests/unit_tests/training/test_checkpointing.py @@ -24,7 +24,11 @@ from megatron.bridge.training.checkpointing import ( _DIRECT_ITERATION_DIR_SENTINEL, + CheckpointLoadContext, + CheckpointManager, + CheckpointSaveContext, CheckpointType, + DefaultCheckpointManager, _extract_megatron_lm_args_from_state_dict, _get_checkpoint_format, _get_non_persistent_iteration, @@ -32,6 +36,7 @@ _load_model_state_dict, checkpoint_exists, cleanup_old_non_persistent_checkpoint, + create_checkpoint_manager, delete_extra_state, ensure_directory_exists, find_checkpoint_rank_0, @@ -2797,3 +2802,278 @@ def test_fsdp_dtensor_skips_tracker_resolution(self, mock_reader, mock_unwrap, m # with the direct path (not a tracker-resolved path). 
class TestCheckpointManager:
    """Tests for the CheckpointManager interface and DefaultCheckpointManager.

    Covers three areas: the CheckpointSaveContext/CheckpointLoadContext
    dataclasses (defaults and optional fields), the create_checkpoint_manager
    factory (default path, custom-class loading, and its error modes), and
    DefaultCheckpointManager's delegation to the module-level save/load/
    finalize functions.
    """

    def test_checkpoint_save_context_creation(self):
        """Test CheckpointSaveContext dataclass can be created with required fields."""
        mock_state = Mock(spec=GlobalState)
        mock_model = [Mock()]
        mock_optimizer = Mock()
        mock_scheduler = Mock()

        ctx = CheckpointSaveContext(
            state=mock_state,
            model=mock_model,
            optimizer=mock_optimizer,
            opt_param_scheduler=mock_scheduler,
            num_floating_point_operations_so_far=1000,
        )

        # Identity checks: the context must hold the exact objects passed in.
        assert ctx.state is mock_state
        assert ctx.model is mock_model
        assert ctx.optimizer is mock_optimizer
        assert ctx.opt_param_scheduler is mock_scheduler
        assert ctx.num_floating_point_operations_so_far == 1000
        # Optional fields fall back to their declared defaults.
        assert ctx.train_data_iterator is None
        assert ctx.non_persistent_ckpt is False

    def test_checkpoint_save_context_with_optional_fields(self):
        """Test CheckpointSaveContext with optional fields set."""
        mock_state = Mock(spec=GlobalState)
        mock_iterator = Mock()

        ctx = CheckpointSaveContext(
            state=mock_state,
            model=[Mock()],
            optimizer=None,
            opt_param_scheduler=None,
            num_floating_point_operations_so_far=0,
            train_data_iterator=mock_iterator,
            non_persistent_ckpt=True,
        )

        assert ctx.train_data_iterator is mock_iterator
        assert ctx.non_persistent_ckpt is True

    def test_checkpoint_load_context_creation(self):
        """Test CheckpointLoadContext dataclass can be created with required fields."""
        mock_state = Mock(spec=GlobalState)
        mock_model = [Mock()]
        mock_optimizer = Mock()
        mock_scheduler = Mock()

        ctx = CheckpointLoadContext(
            state=mock_state,
            model=mock_model,
            optimizer=mock_optimizer,
            opt_param_scheduler=mock_scheduler,
        )

        assert ctx.state is mock_state
        assert ctx.model is mock_model
        assert ctx.optimizer is mock_optimizer
        assert ctx.opt_param_scheduler is mock_scheduler
        # Load contexts default to strict loading of both model and optimizer.
        assert ctx.strict is True
        assert ctx.skip_load_to_model_and_opt is False

    def test_checkpoint_load_context_with_optional_fields(self):
        """Test CheckpointLoadContext with optional fields set."""
        mock_state = Mock(spec=GlobalState)

        ctx = CheckpointLoadContext(
            state=mock_state,
            model=[Mock()],
            optimizer=None,
            opt_param_scheduler=None,
            strict=False,
            skip_load_to_model_and_opt=True,
        )

        assert ctx.strict is False
        assert ctx.skip_load_to_model_and_opt is True

    def test_create_checkpoint_manager_returns_default(self):
        """Test create_checkpoint_manager returns DefaultCheckpointManager when no custom class."""
        config = CheckpointConfig()

        manager = create_checkpoint_manager(config)

        assert isinstance(manager, DefaultCheckpointManager)
        assert manager.checkpoint_config is config

    def test_create_checkpoint_manager_invalid_format_raises(self):
        """Test create_checkpoint_manager raises ValueError for invalid class format."""
        # No dots: cannot be split into "module.path.ClassName".
        config = CheckpointConfig(custom_manager_class="InvalidClassName")

        with pytest.raises(ValueError, match="Invalid custom_manager_class format"):
            create_checkpoint_manager(config)

    def test_create_checkpoint_manager_missing_module_raises(self):
        """Test create_checkpoint_manager raises ImportError for non-existent module."""
        config = CheckpointConfig(custom_manager_class="nonexistent.module.ClassName")

        with pytest.raises(ImportError, match="Could not import module"):
            create_checkpoint_manager(config)

    def test_create_checkpoint_manager_missing_class_raises(self):
        """Test create_checkpoint_manager raises AttributeError for non-existent class."""
        # Use a real module but non-existent class
        config = CheckpointConfig(custom_manager_class="os.path.NonExistentClass")

        with pytest.raises(AttributeError, match="does not have class"):
            create_checkpoint_manager(config)

    def test_create_checkpoint_manager_custom_class(self):
        """Test create_checkpoint_manager loads and instantiates a custom class."""

        # Create a simple custom manager class for testing
        class CustomTestManager:
            def __init__(self, checkpoint_config):
                self.checkpoint_config = checkpoint_config
                self.initialized = True

            def save(self, _ctx):
                pass

            def load(self, _ctx):
                return (0, 0)

            def finalize_async_saves(self, state, blocking=False, terminate=False):
                pass

        # Patch the import to return our test class
        with patch("importlib.import_module") as mock_import:
            mock_module = Mock()
            mock_module.CustomTestManager = CustomTestManager
            mock_import.return_value = mock_module

            config = CheckpointConfig(custom_manager_class="test.module.CustomTestManager")
            manager = create_checkpoint_manager(config)

        # The factory must instantiate the custom class with the config.
        assert isinstance(manager, CustomTestManager)
        assert manager.checkpoint_config is config
        assert manager.initialized is True

    def test_default_checkpoint_manager_init(self):
        """Test DefaultCheckpointManager initialization."""
        config = CheckpointConfig()

        # Patch init_checkpointing_context where DefaultCheckpointManager
        # looks it up, so construction uses the mock.
        with patch("megatron.bridge.training.checkpointing.init_checkpointing_context") as mock_init_ctx:
            mock_init_ctx.return_value = {"test_key": "test_value"}
            manager = DefaultCheckpointManager(config)

        assert manager.checkpoint_config is config
        mock_init_ctx.assert_called_once_with(config)
        assert manager.checkpointing_context == {"test_key": "test_value"}

    def test_default_checkpoint_manager_save_delegates(self):
        """Test DefaultCheckpointManager.save() delegates to save_checkpoint."""
        config = CheckpointConfig()
        mock_state = Mock(spec=GlobalState)
        mock_model = [Mock()]
        mock_optimizer = Mock()
        mock_scheduler = Mock()

        ctx = CheckpointSaveContext(
            state=mock_state,
            model=mock_model,
            optimizer=mock_optimizer,
            opt_param_scheduler=mock_scheduler,
            num_floating_point_operations_so_far=5000,
            train_data_iterator=Mock(),
            non_persistent_ckpt=True,
        )

        with (
            patch("megatron.bridge.training.checkpointing.init_checkpointing_context") as mock_init_ctx,
            patch("megatron.bridge.training.checkpointing.save_checkpoint") as mock_save,
        ):
            mock_init_ctx.return_value = {"context": "data"}
            manager = DefaultCheckpointManager(config)
            manager.save(ctx)

        # Every field of the save context must be forwarded, plus the
        # checkpointing context cached at construction time.
        mock_save.assert_called_once_with(
            state=mock_state,
            model=mock_model,
            optimizer=mock_optimizer,
            opt_param_scheduler=mock_scheduler,
            num_floating_point_operations_so_far=5000,
            checkpointing_context={"context": "data"},
            non_persistent_ckpt=True,
            train_data_iterator=ctx.train_data_iterator,
        )

    def test_default_checkpoint_manager_load_delegates(self):
        """Test DefaultCheckpointManager.load() delegates to load_checkpoint."""
        config = CheckpointConfig()
        mock_state = Mock(spec=GlobalState)
        mock_model = [Mock()]
        mock_optimizer = Mock()
        mock_scheduler = Mock()

        ctx = CheckpointLoadContext(
            state=mock_state,
            model=mock_model,
            optimizer=mock_optimizer,
            opt_param_scheduler=mock_scheduler,
            strict=False,
            skip_load_to_model_and_opt=True,
        )

        with (
            patch("megatron.bridge.training.checkpointing.init_checkpointing_context") as mock_init_ctx,
            patch("megatron.bridge.training.checkpointing.load_checkpoint") as mock_load,
        ):
            mock_init_ctx.return_value = {"context": "data"}
            mock_load.return_value = (100, 50000)

            manager = DefaultCheckpointManager(config)
            result = manager.load(ctx)

        mock_load.assert_called_once_with(
            state=mock_state,
            model=mock_model,
            optimizer=mock_optimizer,
            opt_param_scheduler=mock_scheduler,
            strict=False,
            checkpointing_context={"context": "data"},
            skip_load_to_model_and_opt=True,
        )
        # load() must pass through (iteration, num_floating_point_operations_so_far).
        assert result == (100, 50000)

    def test_default_checkpoint_manager_finalize_async_saves(self):
        """Test DefaultCheckpointManager.finalize_async_saves() delegates correctly."""
        config = CheckpointConfig()
        mock_state = Mock(spec=GlobalState)

        with (
            patch("megatron.bridge.training.checkpointing.init_checkpointing_context") as mock_init_ctx,
            patch("megatron.bridge.training.checkpointing.maybe_finalize_async_save") as mock_finalize,
        ):
            mock_init_ctx.return_value = {}

            manager = DefaultCheckpointManager(config)
            manager.finalize_async_saves(state=mock_state, blocking=True, terminate=True)

        mock_finalize.assert_called_once_with(
            global_state=mock_state,
            ckpt_cfg=config,
            blocking=True,
            terminate=True,
        )

    def test_checkpoint_manager_protocol_compliance(self):
        """Test that DefaultCheckpointManager satisfies the CheckpointManager protocol."""
        config = CheckpointConfig()

        with patch("megatron.bridge.training.checkpointing.init_checkpointing_context"):
            manager = DefaultCheckpointManager(config)

        # Protocol check using isinstance (works because @runtime_checkable)
        assert isinstance(manager, CheckpointManager)

    def test_checkpointing_context_property(self):
        """Test DefaultCheckpointManager.checkpointing_context property."""
        config = CheckpointConfig()
        expected_context = {"local_checkpoint_manager": Mock(), "save_strategy": Mock()}

        with patch("megatron.bridge.training.checkpointing.init_checkpointing_context") as mock_init_ctx:
            mock_init_ctx.return_value = expected_context

            manager = DefaultCheckpointManager(config)

        # The property must expose the exact dict produced by
        # init_checkpointing_context, same object stored on _context.
        assert manager.checkpointing_context is expected_context
        assert manager._context is expected_context
megatron.bridge.training.utils.train_utils import maybe_inject_state +pytestmark = pytest.mark.unit + + class TestFSDPRegistration: """Unit tests for FSDP buffer manual registration.""" @@ -505,15 +509,14 @@ def _make_state(self) -> tuple[SimpleNamespace, Mock]: @patch("megatron.bridge.training.train.force_param_sync") @patch("megatron.bridge.training.train.should_disable_forward_pre_hook", return_value=True) - @patch("megatron.bridge.training.train.save_checkpoint") def test_param_sync_forced_when_overlap_enabled( self, - mock_save_checkpoint, mock_should_disable, mock_force_param_sync, ): state, _ = self._make_state() model = [Mock()] + mock_checkpoint_manager = Mock() save_checkpoint_and_time( state=state, @@ -521,24 +524,23 @@ def test_param_sync_forced_when_overlap_enabled( optimizer=Mock(), opt_param_scheduler=Mock(), num_floating_point_operations_so_far=123.0, - checkpointing_context={}, + checkpoint_manager=mock_checkpoint_manager, ) mock_should_disable.assert_called_once_with(False, True, True) mock_force_param_sync.assert_called_once_with(model) - mock_save_checkpoint.assert_called_once() + mock_checkpoint_manager.save.assert_called_once() @patch("megatron.bridge.training.train.force_param_sync") @patch("megatron.bridge.training.train.should_disable_forward_pre_hook", return_value=False) - @patch("megatron.bridge.training.train.save_checkpoint") def test_param_sync_skipped_when_not_required( self, - mock_save_checkpoint, mock_should_disable, mock_force_param_sync, ): state, _ = self._make_state() model = [Mock()] + mock_checkpoint_manager = Mock() save_checkpoint_and_time( state=state, @@ -546,12 +548,12 @@ def test_param_sync_skipped_when_not_required( optimizer=Mock(), opt_param_scheduler=Mock(), num_floating_point_operations_so_far=123.0, - checkpointing_context={}, + checkpoint_manager=mock_checkpoint_manager, ) mock_should_disable.assert_called_once_with(False, True, True) mock_force_param_sync.assert_not_called() - 
mock_save_checkpoint.assert_called_once() + mock_checkpoint_manager.save.assert_called_once() class TestCheckpointAndDecideExit: @@ -598,7 +600,7 @@ def _create_mock_args(self): "optimizer": Mock(), "opt_param_scheduler": Mock(), "num_floating_point_operations_so_far": 1000.0, - "checkpointing_context": {}, + "checkpoint_manager": Mock(), "train_data_iterator": None, }