| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `dist_ckpt_optim_fully_reshardable` | `bool` | `False` | Make the optimizer distributed checkpoint fully reshardable (TP/PP/EP/DP) as opposed to plain DP reshardability |
| `distrib_optim_fully_reshardable_mem_efficient` | `bool` | `False` | Use as little memory as possible during save and load by using Gloo. Has an effect only with the `dist_ckpt_optim_fully_reshardable` flag |

## Custom Checkpoint Manager

For advanced use cases, you can provide a custom checkpoint manager implementation to override the default save/load behavior. This enables integration with custom storage backends, alternative checkpoint formats, or organization-specific checkpointing workflows.

### Configuration

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `custom_manager_class` | `str \| None` | `None` | Fully qualified class name for a custom `CheckpointManager` implementation |

### Usage

Specify a custom checkpoint manager class in your configuration:

**YAML:**
```yaml
checkpoint:
  save: /path/to/checkpoints
  custom_manager_class: "mypackage.checkpoint.MyCheckpointManager"
```

**Python:**
```python
from megatron.bridge.training.config import CheckpointConfig

checkpoint = CheckpointConfig(
    save="/path/to/checkpoints",
    custom_manager_class="mypackage.checkpoint.MyCheckpointManager",
)
```
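
Megatron Bridge resolves this string to a class at runtime via a dynamic import. As a rough sketch of what that resolution amounts to (the helper name `resolve_manager_class` is hypothetical, not part of the library):

```python
import importlib


def resolve_manager_class(qualified_name: str) -> type:
    """Map "mypackage.checkpoint.MyCheckpointManager" to the class object."""
    module_name, _, class_name = qualified_name.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Presumably instantiated with the checkpoint config, matching the
# constructor signature shown in the example below:
# manager = resolve_manager_class(cfg.custom_manager_class)(cfg)
```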

### Implementing a Custom Manager

Your custom manager must implement the `CheckpointManager` protocol defined in `megatron.bridge.training.checkpointing`:

```python
from megatron.bridge.training.checkpointing import (
    CheckpointManager,
    CheckpointSaveContext,
    CheckpointLoadContext,
    save_checkpoint,
    load_checkpoint,
    init_checkpointing_context,
)
from megatron.bridge.training.config import CheckpointConfig
from megatron.bridge.training.state import GlobalState


class MyCheckpointManager:
    """Custom checkpoint manager example."""

    def __init__(self, checkpoint_config: CheckpointConfig) -> None:
        self.checkpoint_config = checkpoint_config
        # Initialize internal context for caching strategies
        self._context = init_checkpointing_context(checkpoint_config)

    def save(self, ctx: CheckpointSaveContext) -> None:
        """Save a checkpoint with custom logic."""
        # Option 1: Completely custom implementation
        # my_custom_save(ctx.state, ctx.model, ...)

        # Option 2: Wrap the default implementation
        save_checkpoint(
            state=ctx.state,
            model=ctx.model,
            optimizer=ctx.optimizer,
            opt_param_scheduler=ctx.opt_param_scheduler,
            num_floating_point_operations_so_far=ctx.num_floating_point_operations_so_far,
            checkpointing_context=self._context,
            non_persistent_ckpt=ctx.non_persistent_ckpt,
            train_data_iterator=ctx.train_data_iterator,
        )
        # Add custom post-processing, e.g. upload to cloud storage
        # (upload_to_s3 is a placeholder for your own logic)
        upload_to_s3(ctx.state.cfg.checkpoint.save)

    def load(self, ctx: CheckpointLoadContext) -> tuple[int, int]:
        """Load a checkpoint with custom logic."""
        # Returns (iteration, num_floating_point_operations_so_far)
        return load_checkpoint(
            state=ctx.state,
            model=ctx.model,
            optimizer=ctx.optimizer,
            opt_param_scheduler=ctx.opt_param_scheduler,
            strict=ctx.strict,
            checkpointing_context=self._context,
            skip_load_to_model_and_opt=ctx.skip_load_to_model_and_opt,
        )

    def finalize_async_saves(
        self, state: GlobalState, blocking: bool = False, terminate: bool = False
    ) -> None:
        """Finalize any pending asynchronous saves."""
        from megatron.bridge.training.checkpointing import maybe_finalize_async_save

        maybe_finalize_async_save(
            global_state=state,
            ckpt_cfg=self.checkpoint_config,
            blocking=blocking,
            terminate=terminate,
        )
```
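
Since `CheckpointManager` is described as a protocol, conforming classes presumably do not need to inherit from it; they only need matching method signatures. Based on the example above, its shape is roughly the following sketch (the exact definition in the library may differ):

```python
from typing import Protocol

from megatron.bridge.training.checkpointing import (
    CheckpointLoadContext,
    CheckpointSaveContext,
)
from megatron.bridge.training.state import GlobalState


class CheckpointManager(Protocol):
    """Anything implementing these three methods satisfies the protocol."""

    def save(self, ctx: CheckpointSaveContext) -> None: ...

    def load(self, ctx: CheckpointLoadContext) -> tuple[int, int]: ...

    def finalize_async_saves(
        self, state: GlobalState, blocking: bool = False, terminate: bool = False
    ) -> None: ...
```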

### Context Dataclasses

The save and load methods receive context dataclasses that bundle all required parameters:

**`CheckpointSaveContext`:**
| Field | Type | Description |
|-------|------|-------------|
| `state` | `GlobalState` | Global training state (config, train_state, loggers) |
| `model` | `list[MegatronModule]` | Model modules to save |
| `optimizer` | `MegatronOptimizer \| None` | Optimizer instance |
| `opt_param_scheduler` | `Any \| None` | Learning rate scheduler |
| `num_floating_point_operations_so_far` | `int` | Cumulative FLOPs |
| `train_data_iterator` | `Any \| None` | Data iterator (optional) |
| `non_persistent_ckpt` | `bool` | Whether this is a non-persistent checkpoint |

**`CheckpointLoadContext`:**
| Field | Type | Description |
|-------|------|-------------|
| `state` | `GlobalState` | Global training state |
| `model` | `list[MegatronModule]` | Model modules to load into |
| `optimizer` | `MegatronOptimizer \| None` | Optimizer instance |
| `opt_param_scheduler` | `Any \| None` | Learning rate scheduler |
| `strict` | `bool` | Enforce strict loading (default: `True`) |
| `skip_load_to_model_and_opt` | `bool` | Skip loading into model/optimizer (default: `False`) |
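
Taken together, the field tables correspond roughly to dataclass definitions like the following sketch (field order, save-side defaults, and the Megatron-Core import paths are assumptions inferred from the tables):

```python
from dataclasses import dataclass
from typing import Any

from megatron.core.optimizer import MegatronOptimizer
from megatron.core.transformer.module import MegatronModule

from megatron.bridge.training.state import GlobalState


@dataclass
class CheckpointSaveContext:
    state: GlobalState
    model: list[MegatronModule]
    optimizer: MegatronOptimizer | None
    opt_param_scheduler: Any | None
    num_floating_point_operations_so_far: int
    train_data_iterator: Any | None = None
    non_persistent_ckpt: bool = False


@dataclass
class CheckpointLoadContext:
    state: GlobalState
    model: list[MegatronModule]
    optimizer: MegatronOptimizer | None
    opt_param_scheduler: Any | None
    strict: bool = True
    skip_load_to_model_and_opt: bool = False
```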

### Limitations

The custom checkpoint manager is designed for customizing the save/load **operations** during training. The following limitations apply:

**Checkpoint format compatibility**: Custom managers that change the checkpoint directory structure or metadata files (e.g., `latest_train_state.pt`, `run_config.yaml`) are not well supported. Many utilities in Megatron Bridge assume the standard Megatron checkpoint format. For instance, HuggingFace ↔ custom format conversion is not supported.

**PEFT with custom checkpoints**: When using PEFT (Parameter-Efficient Fine-Tuning), the `pretrained_checkpoint` path must point to a Megatron-format checkpoint. The custom manager only applies to the training save/load flow (the `save` and `load` configuration paths), not to base model loading for PEFT.
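
For example, a PEFT run that respects this limitation keeps `pretrained_checkpoint` pointing at a standard Megatron checkpoint while the custom manager handles only the training `save`/`load` paths (a hedged sketch; field names follow the YAML example above, and the `pretrained_checkpoint` placement is an assumption):

```yaml
checkpoint:
  pretrained_checkpoint: /path/to/megatron_base_checkpoint  # must be Megatron format
  save: /path/to/peft_checkpoints  # handled by the custom manager
  load: /path/to/peft_checkpoints  # handled by the custom manager
  custom_manager_class: "mypackage.checkpoint.MyCheckpointManager"
```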

**Inference loading**: Loading checkpoints for inference via the `model_load_save.py` utilities results in undefined behavior with custom checkpoint formats. Use your custom format's loading utilities instead.

### Default Behavior

When `custom_manager_class` is not set, Megatron Bridge uses `DefaultCheckpointManager`, which wraps the existing `save_checkpoint` and `load_checkpoint` functions. This ensures full backward compatibility: the checkpoint manager abstraction introduces no changes to existing training workflows.

## Related Documentation

- {doc}`megatron-fsdp` - Megatron FSDP configuration and `fsdp_dtensor` format requirements