Introduce refactored model builder abstractions #2241
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: Maanu Grover <maanug@nvidia.com>
```python
pg_collection: ProcessGroupCollection,
ddp_config: DistributedDataParallelConfig | None = None,
overlap_param_gather_with_optimizer_step: bool = False,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
model_type: ModelType = ModelType.encoder_or_decoder,
```
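For context, the parameters under review correspond to a builder entry point roughly like the sketch below. Only the parameter names come from the diff; the stand-in types, the function name `build_model`, and the placeholder body are assumptions, since the real `ProcessGroupCollection`, `DistributedDataParallelConfig`, and wrapping logic live in Megatron-LM.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional

# Stand-in types so the sketch is self-contained; the real classes live in Megatron-LM.
class ModelType(Enum):
    encoder_or_decoder = "encoder_or_decoder"

@dataclass
class ProcessGroupCollection:
    groups: dict = field(default_factory=dict)

@dataclass
class DistributedDataParallelConfig:
    overlap_grad_reduce: bool = False

def build_model(
    pg_collection: ProcessGroupCollection,
    ddp_config: Optional[DistributedDataParallelConfig] = None,
    overlap_param_gather_with_optimizer_step: bool = False,
    use_megatron_fsdp: bool = False,
    use_torch_fsdp2: bool = False,
    wrap_with_ddp: bool = True,
    data_parallel_random_init: bool = True,
    mixed_precision_wrapper: Optional[Callable[..., Any]] = None,
    model_type: ModelType = ModelType.encoder_or_decoder,
) -> str:
    # Placeholder body: a real builder would construct the module, apply the
    # mixed-precision wrapper, then wrap with DDP or FSDP as requested.
    if use_megatron_fsdp or use_torch_fsdp2:
        return "fsdp"
    return "ddp" if wrap_with_ddp else "bare"

print(build_model(ProcessGroupCollection()))                        # ddp
print(build_model(ProcessGroupCollection(), use_torch_fsdp2=True))  # fsdp
```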
I see that this is the translation of the existing setup, but how do these args generalize for the MIMO use case?
MIMO could take in the pg collection and DDP config as dicts keyed by submodule? I don't want to over-design to fit MIMO.
I think the signature can stay a little flexible until it's integrated into the training loop, which will be in follow-up PR(s).
@yashaswikarnati do you have any feedback on how this signature could be made a bit better for MIMO without changing too much?
Sorry, just coming back to this. I see that in the signature we accept only one pg_collection. Do we plan to have a different signature for multimodal?
I don't think that's over-design; I'd say that's the bare minimum to achieve the functionality.
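To make the dict-of-submodules idea from this thread concrete, here is a hedged sketch of a helper that accepts either a single collection or a per-submodule dict. The function `normalize_pg_collections` and all names in it are hypothetical, not part of this PR's actual API.

```python
from typing import Dict, List, Union

class ProcessGroupCollection:
    """Stand-in for the Megatron-LM type; the real one carries process groups."""

def normalize_pg_collections(
    pg_collection: Union[ProcessGroupCollection, Dict[str, ProcessGroupCollection]],
    submodule_names: List[str],
) -> Dict[str, ProcessGroupCollection]:
    # A dict is passed through after checking every submodule is covered;
    # a single collection is broadcast to all submodules (the unimodal case).
    if isinstance(pg_collection, dict):
        missing = [n for n in submodule_names if n not in pg_collection]
        if missing:
            raise ValueError(f"missing pg_collection for submodules: {missing}")
        return dict(pg_collection)
    return {name: pg_collection for name in submodule_names}

shared = ProcessGroupCollection()
per_module = normalize_pg_collections(shared, ["vision", "language"])
print(sorted(per_module))  # ['language', 'vision']
```

A unimodal caller would never notice the generalization, which may be one way to support MIMO without changing the single-collection signature much.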
What does this PR do?
This PR refactors the `ModelProviderMixin` to split up responsibilities. The refactor preserves the existing level of model customizability as well as conversion compatibility, while separating model configuration and initialization/building into separate components. This architecture was determined to be more comprehensible based on several sources of feedback.
In particular, a custom model type can define one or both of a model config (describing the model) and a model builder (constructing the model). Base model architectures like GPT and Mamba will follow this design.
These new abstractions are all new classes rather than direct modifications to `ModelProviderMixin`, to ease migration. Additionally, `ModelProviderMixin` relied on inheritance from `TransformerConfig`, modified some config attributes during model building, and contained some unused code. I therefore think we should deprecate `ModelProviderMixin` after integrating the new abstractions across the codebase.

The new classes are:
- `ModelConfig`: has some required attributes, including the import path of the builder. Supports serialization to and from a dictionary.
- `ModelBuilder`: contains stub methods for implementing (distributed) model init. Also keeps the pre-wrap and post-wrap hook registration from `ModelProviderMixin`, which is used by PEFT. Usage of the hooks is left up to the child class.
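The described pair of abstractions could look roughly like the following sketch. Only the class names, the builder-import-path attribute, the dict round-trip, and the hook registration are taken from the PR text; every other name and signature here is an assumption.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Callable, List

@dataclass
class ModelConfig:
    # Required attribute named in the PR: the import path of the builder.
    builder_import_path: str

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "ModelConfig":
        # Ignore unknown keys so serialized configs stay forward-compatible.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in names})

class ModelBuilder:
    def __init__(self) -> None:
        # Pre-/post-wrap hook registration carried over from ModelProviderMixin
        # (used by PEFT); when the hooks run is left to child classes.
        self._pre_wrap_hooks: List[Callable[[Any], Any]] = []
        self._post_wrap_hooks: List[Callable[[Any], Any]] = []

    def register_pre_wrap_hook(self, fn: Callable[[Any], Any]) -> None:
        self._pre_wrap_hooks.append(fn)

    def register_post_wrap_hook(self, fn: Callable[[Any], Any]) -> None:
        self._post_wrap_hooks.append(fn)

    def build(self, config: ModelConfig) -> Any:
        raise NotImplementedError("child classes implement (distributed) model init")

cfg = ModelConfig(builder_import_path="builders.mamba.MambaModelBuilder")
print(ModelConfig.from_dict(cfg.to_dict()) == cfg)  # True
```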
This PR also includes the necessary changes for the Mamba base model to support this new interface; the refactor was easier to demonstrate on Mamba than on GPT. `mamba/mamba_builder.py` contains:
- `MambaModelConfig`, which encapsulates all config settings used to create an MCore Mamba model.
- `MambaModelBuilder`, which defines how a distributed Mamba model should be built.

Since distributed model initialization is identical for GPT and Mamba models, the logic is extracted into a separate function under `unimodal.py`, which can be called by `GPTModelBuilder` in the future. The logic is also split into helpers for readability.
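The shared-init split described above might look like this hedged sketch: one public build function delegating to small helpers. The helper names and the dict "modules" standing in for real torch modules are assumptions, not the PR's actual code.

```python
from typing import Dict

# Toy dict-based "modules" stand in for real torch modules.
def _create_module(config: Dict) -> Dict:
    return {"layers": config.get("num_layers", 1)}

def _apply_mixed_precision(model: Dict, fp16: bool) -> Dict:
    return {**model, "fp16": fp16}

def _wrap_distributed(model: Dict, wrap_with_ddp: bool) -> Dict:
    return {**model, "ddp": wrap_with_ddp}

def build_unimodal_model(config: Dict, fp16: bool = True, wrap_with_ddp: bool = True) -> Dict:
    # The same three steps for GPT and Mamba: create, cast, wrap.
    model = _create_module(config)
    model = _apply_mixed_precision(model, fp16)
    return _wrap_distributed(model, wrap_with_ddp)

print(build_unimodal_model({"num_layers": 4}))
# {'layers': 4, 'fp16': True, 'ddp': True}
```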
I plan to upstream all new files under src/ in this PR to Megatron-LM in the near future. It may also be worthwhile to eventually push the base model configs, e.g. `MambaModelConfig`, to MCore for use directly with the MCore model. Additionally, if `TransformerConfig` is broken up in the future, the model-specific config can remain at the top level of the hierarchy.

Some additional notes:
- Other than `unimodal.py`, these abstractions should appear as a simple layer on top of existing MLM code like `get_model()` and `mamba_builders.py`.

GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information