Allow custom checkpoint impl via checkpoint manager interface #1955

yaoyu-33 merged 7 commits into NVIDIA-NeMo:main
Conversation
/ok to test 7d3064c

/ok to test 0425c2a
maanug-nv left a comment:
I think this will work great for pretraining, but I have some questions outside of that use case.
- I think several functions in `model_load_save.py` would fail if using a custom checkpoint. Probably we don't need to support HF <-> custom format, but what about loading for inference? I think we would either have to make all the functions that `model_load_save.py` imports from checkpointing part of the protocol (load model weights, get run config), or just document/warn that it is undefined behavior.
- If a user wants to do PEFT on a custom checkpoint, it seems like loading would fail because we call `_load_checkpoint_from_path()` in the `peft_pre_wrap_hook()`. How would I replace that hook after setup? With a callback? Generally less familiar with PEFT, so maybe there are other concerns.
@@ -247,14 +247,13 @@ def modelopt_pre_wrap_hook(model):
What if the user's custom logic does not save the same train_state.pt, or changes its name, location, etc.? Then we never load, right?
similar q's -

- I suspect `cfg.checkpoint.load` is expected to be the same as the save prop, i.e. the parent dir, not a specific checkpoint version? Is there any defaulting logic for it to be the same as save, or do users have to explicitly set it?
- Is it possible to make this condition customizable, so that the checkpoint manager can determine if it can load or not, and then how to load? E.g. add an API to CheckpointManager like `should_load(...) -> bool` or something. That way it can also take into account any custom configuration, plus specific logic it wants to run. There are 2 conditions I'm thinking of for loading a checkpoint: whether the user has configured loading (usually yes, to auto-handle recovering after an interruption, for example) and whether there is an available checkpoint to load from. Both have to be true to successfully load; otherwise either train from scratch, or fail if the user wants to force loading from a checkpoint.
- Where is the downstream logic that will use this? How does it know whether to train from scratch or from a pre-loaded checkpoint?
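To make the two-condition idea concrete, here is a minimal sketch of the proposed `should_load(...) -> bool` API. The class and config fields are hypothetical, not part of this PR:

```python
# Illustrative sketch of a should_load() check combining the two conditions
# above: loading is configured AND a checkpoint actually exists on disk.
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class DummyCheckpointConfig:
    # Hypothetical stand-in for the real CheckpointConfig.
    load: Optional[str] = None  # parent dir to resume from, if configured


class ResumableCheckpointManager:
    def __init__(self, config: DummyCheckpointConfig):
        self.config = config

    def should_load(self) -> bool:
        # Condition 1: the user configured a load path at all.
        if not self.config.load:
            return False
        # Condition 2: at least one checkpoint version exists under it.
        load_dir = Path(self.config.load)
        return load_dir.is_dir() and any(load_dir.iterdir())
```

If both conditions hold the caller would load; otherwise it trains from scratch (or fails, if loading was forced).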
I think supporting a custom checkpoint format is beyond the initial scope at the moment. This customization is intended for users to be able to handle the save/load operations, but we should expect the same checkpoint format on disk. Otherwise there are too many dependencies, both within bridge and outside, that would then need to access the checkpoint protocol to figure out how to load the checkpoint, which would be very complicated for users to deal with. I've updated the docs with a limitations section to make this clear.
/ok to test b3d9164
```python
barrier_and_log("after training is done")
ckpt_config = config.checkpoint
if ckpt_config.save and state.train_state.step != 0 and ckpt_config.save_interval != 0:
```
Hey @ananthsub , question: will this checkpoint every step so long as these properties are truthy? I'm wondering where the evaluation of something like `if curr_step % save_interval == 0: do save` is happening, to know when to call the save.
Alternatively, the step info can be passed in to `checkpoint_manager.save` (I think it is in `state.training_state`), and then the manager can make the decision. That would be ideal for us actually, but it doesn't look like the default impl does that check, so I presume it is happening elsewhere, unless I missed it.
(I'm still reviewing to see if I have any other q's, but wanted to post this q in the meantime to avoid a long delay)
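For reference, the periodic-save condition being asked about is typically a small gate like the following. This is a hypothetical helper for illustration only; where the actual check lives in the training loop is exactly the open question:

```python
def should_save(curr_step: int, save_interval: int) -> bool:
    """Typical periodic-save gate in a training loop (illustrative only)."""
    return save_interval > 0 and curr_step > 0 and curr_step % save_interval == 0
```

The quoted `ckpt_config.save and step != 0 and save_interval != 0` condition only guards the end-of-training save; a modulo check like this would govern the in-loop saves.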
````markdown
)
```

```{note}
````
Btw unrelated, but when reading the rendered doc I noticed this seems to throw off the formatting (I think the backticks are not closed off at the end). Could also use:

> [!NOTE]
> The `load` parameter should ...
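For comparison, a MyST `{note}` directive only renders correctly when its fence is closed. A fixed version of the snippet would look like this (the sentence body is truncated in the original, so it is left elided here):

````markdown
```{note}
The `load` parameter should ...
```
````

The GitHub `> [!NOTE]` alert syntax suggested above avoids the fence-closing pitfall entirely, but only renders on GitHub, not in Sphinx/MyST-built docs.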
```python
except AttributeError as e:
    raise AttributeError(f"Module '{module_path}' does not have class '{class_name}': {e}") from e
```

```python
manager = custom_manager_class(checkpoint_config)
```
we would need to provide some additional config props to our custom checkpoint managers, which we can do by creating a subtype of CheckpointConfig and using that in the recipe rather than CheckpointConfig. That looks like it should work fine, but it would be great to get some confirmation/guarantee that subclassing will be supported, perhaps with documentation and a test case
@g-husam yes, that is the intended way to customize this. I will update the documentation to make that clearer
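A minimal sketch of that subclassing pattern, assuming a dataclass-style config (the stand-in base class and the extra field names here are made up for illustration):

```python
# Sketch: carry extra props for a custom checkpoint manager by subclassing
# the checkpoint config. Field names are hypothetical.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointConfig:  # simplified stand-in for the real config class
    save: Optional[str] = None
    load: Optional[str] = None
    save_interval: int = 0


@dataclass
class MyCheckpointConfig(CheckpointConfig):
    # Extra props consumed only by the custom manager.
    remote_bucket: str = ""
    upload_threads: int = 4
```

The recipe would then construct `MyCheckpointConfig` in place of `CheckpointConfig`, and the custom manager receives the subtype through its constructor unchanged.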
👋 Hey @ananthsub , just checking in on my prior questions and to see when this could be merged. Thanks!

Hi @ananthsub, just checking in again on this. Thanks!

/ok to test 61c17d7
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
/ok to test 7695c03

/claude review
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 3e86223
```python
CheckpointSaveContext(
    state=state,
    model=model,
    optimizer=optimizer,
    opt_param_scheduler=scheduler,
    num_floating_point_operations_so_far=int(state.train_state.floating_point_operations_so_far),
    train_data_iterator=train_data_iterator,
)
)
```

```python
else:
    print_rank_0("skipping training ...")
```
Bug: `train()` already calls `checkpoint_manager.finalize_async_saves(terminate=True)` before returning (train.py L626), which shuts down the async save worker. This subsequent `save()` may fail or behave unexpectedly if async checkpointing is enabled.
Also, if the final training step happens to align with `save_interval`, this duplicates the checkpoint that `train()` already saved.
```python
mock_init_ctx.return_value = {"context": "data"}
manager = DefaultCheckpointManager(config)
manager.save(ctx)
```
Missing test: The TypeError branch in create_checkpoint_manager (when a custom class doesn't implement the CheckpointManager protocol) has no test coverage. Consider adding a test with a class that's missing one of the required methods (e.g., save) to verify the error is raised.
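A sketch of the suggested test, using a simplified stand-in protocol and factory (the real `create_checkpoint_manager` and protocol have more methods; only the shape of the check matters here):

```python
# Sketch: a class missing save() should trip the protocol check and raise
# TypeError. Protocol and factory below are simplified stand-ins.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CheckpointManager(Protocol):
    def save(self, ctx: Any) -> None: ...
    def load(self, ctx: Any) -> None: ...


def create_checkpoint_manager(manager_cls: type, config: Any) -> CheckpointManager:
    manager = manager_cls(config)
    # runtime_checkable isinstance() verifies the required methods exist
    # (presence only, not signatures).
    if not isinstance(manager, CheckpointManager):
        raise TypeError(
            f"{manager_cls.__name__} does not implement the CheckpointManager protocol"
        )
    return manager


class MissingSaveManager:
    """Deliberately omits save() to exercise the TypeError branch."""

    def __init__(self, config: Any) -> None:
        pass

    def load(self, ctx: Any) -> None:
        pass
```

In a pytest suite this would become a `pytest.raises(TypeError)` assertion against the real factory.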
- Add exception chaining (`from err`) to ValueError in `create_checkpoint_manager`
- Rename unused `ctx` -> `_ctx` in `CustomTestManager` test stub
- Add `pytestmark = pytest.mark.unit` to `test_train.py`
- Remove redundant final checkpoint save in `_pretrain`: `train()` already saves the final checkpoint and terminates the async worker before returning; the duplicate save in `pretrain.py` was both redundant and unsafe for async ckpt

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 26e6b08

…checkpoint_manager context

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

/ok to test da25dec
What does this PR do?
Currently, checkpointing save/load is implemented purely through functional calls, making customization by users impossible without direct modification/forking.
This PR introduces a CheckpointManager protocol to allow users to inject custom save/load implementations without needing to fork the code.
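As a sketch of how such injection could work under the hood, here is an illustrative import-by-dotted-path helper mirroring the factory behavior described in this PR (the real function name and signature may differ):

```python
import importlib


def load_manager_class(class_path: str) -> type:
    """Resolve a dotted path like 'pkg.module.ClassName' to a class object.

    Illustrative only; mirrors the AttributeError handling visible in the
    diff snippets in this conversation.
    """
    module_path, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as e:
        raise AttributeError(
            f"Module '{module_path}' does not have class '{class_name}': {e}"
        ) from e
```

The resolved class would then be instantiated with the checkpoint config, e.g. `manager = load_manager_class(cfg.manager_class_path)(cfg)`.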
Protocol overview:
- `CheckpointManager`: the main protocol through which save/load is implemented. Finalizing async calls is also routed through this protocol now, for cleanup.
- `CheckpointConfig` is augmented with a field to specify the path to the custom manager class. If specified, the class is imported and used during training.
- A factory function is added to instantiate the manager to use for training.
- A default implementation is provided, which wraps around the existing functions. The default implementation also stores the checkpoint context map, used for local checkpointing. There are no changes made to these functions in this PR, so the core checkpoint saving/loading logic is untouched.

Integration points:
- `setup.py` is updated to create the checkpoint manager and use it for the initial checkpoint load
- `train.py`: in the main training loop, the checkpoint manager is used for all saves + async finalization
- `pretrain.py`: the checkpoint manager is used for the end-of-train checkpoint save

Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open a "Draft" PR.
Additional Information
Summary by CodeRabbit
New Features
Documentation