
Allow custom checkpoint impl via checkpoint manager interface#1955

Merged
yaoyu-33 merged 7 commits into NVIDIA-NeMo:main from ananthsub:ckpt-interface
Mar 25, 2026
Conversation


@ananthsub ananthsub commented Jan 15, 2026

What does this PR do ?

Currently, checkpointing save/load is implemented purely through functional calls, making it impossible for users to customize without directly modifying or forking the code.

This PR introduces a CheckpointManager protocol to allow users to inject custom save/load implementations without needing to fork the code.

Protocol overview:

  • CheckpointManager: the main protocol through which save/load is implemented. Finalizing async calls is also routed through this protocol now, for cleanup
  • Auxiliary dataclasses to encapsulate save/load argument inputs

CheckpointConfig is augmented with a field to specify the path to the custom manager class. If specified, the class is imported and used during training. A factory function is added to instantiate the manager to use for training. A default implementation is provided, which wraps around the existing functions. The default implementation also stores the checkpoint context map, used for local checkpointing. There are no changes made to these functions in this PR, so the core checkpoint saving/loading logic is untouched.
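A rough sketch of what such a protocol and default wrapper could look like (illustrative Python; the method signatures, dataclass fields, and bodies are assumptions, not the exact merged API):

```python
# Minimal sketch of the CheckpointManager protocol described above.
# Signatures and fields are illustrative stand-ins for the real PR API.
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@dataclass
class CheckpointSaveContext:
    """Auxiliary dataclass bundling save arguments (illustrative fields)."""
    state: Any = None
    model: Any = None
    optimizer: Any = None


@runtime_checkable
class CheckpointManager(Protocol):
    """Main protocol: save/load, plus async finalization for cleanup."""

    def save(self, ctx: CheckpointSaveContext) -> None: ...

    def load(self, path: str) -> Any: ...

    def finalize_async_saves(self, terminate: bool = False) -> None: ...


class DefaultCheckpointManager:
    """Default implementation that would wrap the existing functional calls."""

    def __init__(self, config: Any) -> None:
        self.config = config
        # context map used for local checkpointing, per the PR description
        self.checkpoint_context: dict = {}

    def save(self, ctx: CheckpointSaveContext) -> None:
        pass  # would delegate to the existing save function

    def load(self, path: str) -> Any:
        return None  # would delegate to the existing load function

    def finalize_async_saves(self, terminate: bool = False) -> None:
        pass  # would flush/terminate the async save worker
```

Because the protocol is `runtime_checkable`, the factory can verify a user-supplied class with a simple `isinstance` check before using it.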

Integration points:

  • setup.py is updated to create the checkpoint manager and use it for the initial checkpoint load
  • in train.py in the main training loop, the checkpoint manager is used for all saves + async finalization
  • in pretrain.py, the checkpoint manager is used for the end of train checkpoint save

Changelog

  • Enable custom checkpoint implementations via CheckpointManager interface

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • New Features

    • Custom checkpoint manager configuration - users can now specify a custom checkpoint manager class through configuration to customize checkpoint save/load behavior.
  • Documentation

    • Comprehensive documentation added for creating and using custom checkpoint managers, including configuration guidance, usage examples, and implementation patterns.


copy-pr-bot bot commented Jan 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


@ananthsub

/ok to test 7d3064c

@ananthsub

/ok to test 0425c2a


@maanug-nv maanug-nv left a comment


I think this will work great for pretraining, but I have some questions outside of that use case.

  1. I think several functions in model_load_save.py would fail if using a custom checkpoint. Probably we don't need to support HF <-> custom format, but what about loading for inference? I think we would either have to make all the functions that model_load_save.py imports from checkpointing part of the protocol (load model weights, get run config), or just document/warn that it is undefined behavior.

  2. if a user wants to do PEFT on a custom checkpoint, seems like loading would fail because we call _load_checkpoint_from_path() in the peft_pre_wrap_hook(). How would I replace that hook after setup? With a callback? Generally less familiar with PEFT, so maybe there's other concerns.

```diff
@@ -247,14 +247,13 @@ def modelopt_pre_wrap_hook(model):
```

What if the user's custom logic does not save the same train_state.pt, or changes its name, location, etc.? Then we never load, right?


similar q's -

  1. I suspect cfg.checkpoint.load is expected to be the same as the save prop? i.e. the parent dir, not a specific checkpoint version? Is there any defaulting logic for it to be the same as save, or do users have to explicitly set it?

  2. Is it possible to make this condition customizable, so that the checkpoint manager can determine if it can load or not, and then how to load? E.g. add an API to CheckpointManager like should_load(...) -> bool or something. That way it can also take into account any custom configuration, plus specific logic it wants to run.

    • There are 2 conditions I'm thinking of for loading a checkpoint - whether the user has configured loading (usually yes to auto handle recovering after an interruption, for example) and whether there is an available checkpoint to load from. Both have to be true to successfully load, otherwise either train from scratch or fail if the user wants to force loading from a checkpoint.
  3. Where is the downstream logic that will use this? How does it know whether to train from scratch or from a pre-loaded checkpoint?
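The `should_load(...)` API suggested in point 2 might look roughly like this (a hypothetical helper; the directory scan is a stand-in for however the manager would actually detect an available checkpoint):

```python
# Hedged sketch of the suggested should_load(...) idea: resume only when
# loading is configured AND a checkpoint is available; optionally fail hard
# instead of silently training from scratch.
import os


def should_load(load_dir, force_load=False):
    """Return True only if loading is configured and a checkpoint exists."""
    configured = load_dir is not None
    # Stand-in availability check: any entry under the load directory.
    available = configured and os.path.isdir(load_dir) and bool(os.listdir(load_dir))
    if force_load and not available:
        raise FileNotFoundError(f"no checkpoint available under {load_dir!r}")
    return available
```

This covers both conditions from the comment: "user configured loading" and "a checkpoint actually exists", with `force_load` expressing the fail-instead-of-scratch case.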

@ananthsub

> I think several functions in model_load_save.py would fail if using a custom checkpoint. Probably we don't need to support HF <-> custom format, but what about loading for inference? I think we would either have to make all the functions that model_load_save.py imports from checkpointing part of the protocol (load model weights, get run config), or just document/warn that it is undefined behavior.

> If a user wants to do PEFT on a custom checkpoint, seems like loading would fail because we call _load_checkpoint_from_path() in the peft_pre_wrap_hook(). How would I replace that hook after setup? With a callback? Generally less familiar with PEFT, so maybe there's other concerns.

I think supporting a custom checkpoint format is beyond the initial scope at the moment. This customization is intended for users to be able to handle the save/load operations, but we should expect to have the same checkpoint format on disk. Otherwise there are too many dependencies both within bridge and outside that would then need to access the checkpoint protocol to figure out how to load the checkpoint, which will be very complicated for users to deal with.

I've updated the docs with a limitations section to make this clear

@ananthsub

/ok to test b3d9164

yaoyu-33 previously approved these changes Feb 5, 2026

```python
barrier_and_log("after training is done")
ckpt_config = config.checkpoint
if ckpt_config.save and state.train_state.step != 0 and ckpt_config.save_interval != 0:
```

Hey @ananthsub , question: will this checkpoint every step so long as these properties are truthy? I'm wondering where the evaluation of something like `if curr_step % save_interval == 0: do save` is happening, to know when to call the save.

Alternatively, the step info can be passed in to checkpoint_manager.save (I think it is in state.training_state), and then the manager can make the decision. That would be ideal for us actually, but doesn't look like the default impl does that check, so I presume it is happening elsewhere, unless I missed it.

(I'm still reviewing to see if I have any other q's, but wanted to post this q in the meantime to avoid a long delay)
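For reference, the interval condition being asked about boils down to something like this tiny hypothetical helper, regardless of whether it ends up living in the training loop or inside the manager:

```python
# Illustrative save-interval check; names are assumptions, not the real API.
def is_save_step(step: int, save_interval: int) -> bool:
    """True when a checkpoint should be written at this step."""
    return save_interval > 0 and step > 0 and step % save_interval == 0
```

If the step is passed into `checkpoint_manager.save` via the training state, the manager itself could apply this check, which is the alternative the comment proposes.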


@g-husam g-husam left a comment


Thanks again!

````markdown
)
```

```{note}
````

Btw unrelated, but when reading the rendered doc I noticed this seems to throw off the formatting (I think the backticks are not closed off at the end)

Could also use

> [!NOTE]
> The `load` parameter should ...

```python
except AttributeError as e:
    raise AttributeError(f"Module '{module_path}' does not have class '{class_name}': {e}") from e

manager = custom_manager_class(checkpoint_config)
```

We would need to provide some additional config props to our custom checkpoint managers, which we can do by creating a subtype of CheckpointConfig and using that in the recipe rather than CheckpointConfig itself. That looks like it should work fine, but it would be great to get some confirmation/guarantee that subclassing will be supported, perhaps with documentation and a test case.


@g-husam yes, that is the intended way to customize this. I will update the documentation to make that clearer
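That subclassing pattern could look like this (a sketch only; the base-class fields shown here are illustrative stand-ins for the real CheckpointConfig dataclass):

```python
# Hedged sketch of extending CheckpointConfig with custom props.
# All field names below are assumptions for illustration.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointConfig:
    save: Optional[str] = None
    load: Optional[str] = None
    save_interval: int = 0
    checkpoint_manager_cls: Optional[str] = None  # dotted path to custom manager


@dataclass
class MyCheckpointConfig(CheckpointConfig):
    """Custom subtype carrying extra props for a custom checkpoint manager."""
    upload_bucket: str = "s3://my-bucket/ckpts"  # hypothetical extra prop
    keep_last_n: int = 3                         # hypothetical extra prop
```

The recipe would then construct `MyCheckpointConfig` instead of `CheckpointConfig`, and the custom manager would receive the subtype through the factory.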



g-husam commented Feb 26, 2026

👋 Hey @ananthsub , just checking in on my prior questions and to see when this could be merged. Thanks!


g-husam commented Mar 11, 2026

> 👋 Hey @ananthsub , just checking in on my prior questions and to see when this could be merged. Thanks!

Hi @ananthsub, just checking in again on this. Thanks!

@yaoyu-33

/ok to test 61c17d7

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
@yaoyu-33

/ok to test 7695c03

@yaoyu-33

/claude review

Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33

/ok to test 3e86223

Comment on lines 160 to 171
```python
                CheckpointSaveContext(
                    state=state,
                    model=model,
                    optimizer=optimizer,
                    opt_param_scheduler=scheduler,
                    num_floating_point_operations_so_far=int(state.train_state.floating_point_operations_so_far),
                    train_data_iterator=train_data_iterator,
                )
            )

    else:
        print_rank_0("skipping training ...")
```

Bug: train() already calls checkpoint_manager.finalize_async_saves(terminate=True) before returning (train.py L626), which shuts down the async save worker. This subsequent save() may fail or behave unexpectedly if async checkpointing is enabled.

Also, if the final training step happens to align with save_interval, this duplicates the checkpoint that train() already saved.

```python
mock_init_ctx.return_value = {"context": "data"}
manager = DefaultCheckpointManager(config)
manager.save(ctx)
```


Missing test: The TypeError branch in create_checkpoint_manager (when a custom class doesn't implement the CheckpointManager protocol) has no test coverage. Consider adding a test with a class that's missing one of the required methods (e.g., save) to verify the error is raised.
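Such a test could look roughly like this (the factory below is a simplified stand-in for the real create_checkpoint_manager protocol check, and the incomplete class is hypothetical):

```python
# Hypothetical coverage for the TypeError branch: a class that omits save()
# should be rejected by the factory's protocol check.
def create_checkpoint_manager(manager_cls):
    """Simplified stand-in: reject classes missing CheckpointManager methods."""
    required = ("save", "load", "finalize_async_saves")
    missing = [name for name in required if not callable(getattr(manager_cls, name, None))]
    if missing:
        raise TypeError(
            f"{manager_cls.__name__} does not implement CheckpointManager: missing {missing}"
        )
    return manager_cls()


class IncompleteManager:
    """Implements everything except save(), to exercise the error branch."""

    def load(self, path):
        return None

    def finalize_async_saves(self, terminate=False):
        pass
```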

- Add exception chaining (from err) to ValueError in create_checkpoint_manager
- Rename unused ctx -> _ctx in CustomTestManager test stub
- Add pytestmark = pytest.mark.unit to test_train.py
- Remove redundant final checkpoint save in _pretrain: train() already saves
  the final checkpoint and terminates the async worker before returning; the
  duplicate save in pretrain.py was both redundant and unsafe for async ckpt

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33

/ok to test 26e6b08

…checkpoint_manager context

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33

/ok to test da25dec


Labels

  • needs-more-tests — Requires additional L0 and L1 test coverage before merge
  • ready-to-merge — PR is approved, current, and only waiting for CI to pass before merge
  • x-googlevertex


5 participants