Introduce refactored model builder abstractions#2241

Merged
maanug-nv merged 33 commits into main from maanug/provider-refactor-mamba-simpler
Feb 24, 2026

Conversation

@maanug-nv
Contributor

@maanug-nv maanug-nv commented Feb 5, 2026

What does this PR do ?

This PR refactors the ModelProviderMixin to split up responsibilities.

This refactor preserves the existing level of model customizability as well as conversion compatibility, while separating model configuration and initialization/building into separate components. This architecture was determined to be more comprehensible based on several sources of feedback.
In particular, a custom model type can define one or both of:

  1. A serializable configuration object encapsulating all config settings for the model. This may be a shallow dataclass hierarchy.
  2. A builder object that determines how the model is initialized for distributed training or inference using the aforementioned config.

Base model architectures like GPT and Mamba will follow this design.

These new abstractions are all new classes rather than direct modifications to ModelProviderMixin, to ease migration. Additionally, ModelProviderMixin relied on inheritance from TransformerConfig, modified some config attributes during model building, and contained some unused code. Therefore, I think we should deprecate ModelProviderMixin after integrating the new abstractions across the codebase.

The new classes are:

  1. ModelConfig - Has some required attributes, including import path of builder. Supports serialization to and from dictionary.
  2. ModelBuilder - Contains stub methods for implementation of (distributed) model init. Also maintains pre-wrap and post-wrap hook registration from ModelProviderMixin which is used by PEFT. Usage of the hooks is left up to the child class.
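A minimal sketch of how these two classes could look. All names and fields here (builder_path, register_pre_wrap_hook, etc.) are illustrative assumptions, not the exact API introduced by this PR:

```python
from dataclasses import asdict, dataclass
from typing import Any, Callable

# Hypothetical sketch of the two new abstractions described above.

@dataclass
class ModelConfig:
    # Required attribute: import path of the builder class for this model.
    builder_path: str = "my_pkg.builders.MyModelBuilder"

    def to_dict(self) -> dict:
        # Serialization to a plain dictionary.
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "ModelConfig":
        # Deserialization from a plain dictionary.
        return cls(**d)

class ModelBuilder:
    def __init__(self, config: ModelConfig) -> None:
        self.config = config
        # Pre-/post-wrap hook registration carried over from
        # ModelProviderMixin; usage is left to child classes (e.g. PEFT).
        self.pre_wrap_hooks: list[Callable] = []
        self.post_wrap_hooks: list[Callable] = []

    def register_pre_wrap_hook(self, hook: Callable) -> None:
        self.pre_wrap_hooks.append(hook)

    def build(self) -> Any:
        # Stub: subclasses implement (distributed) model initialization.
        raise NotImplementedError
```

The round-trip `ModelConfig.from_dict(cfg.to_dict())` is the property that makes the config checkpoint-friendly.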

This PR also includes the necessary changes for the Mamba base model to support this new interface. The refactor was easier to demonstrate on Mamba than on GPT. 'mamba/mamba_builder.py' contains:

  1. MambaModelConfig which encapsulates all config settings used to create an MCore Mamba model.
  2. MambaModelBuilder which defines how a distributed Mamba model should be built.
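Since the config carries the builder's import path, the builder class can be resolved dynamically. A sketch of that resolution step, assuming a dotted-path convention (the helper name is hypothetical):

```python
import importlib

def resolve_builder(builder_path: str):
    """Resolve a builder class from its dotted import path,
    e.g. 'mamba.mamba_builder.MambaModelBuilder'."""
    module_name, class_name = builder_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# Assumed usage: cfg = MambaModelConfig(...)
# BuilderCls = resolve_builder(cfg.builder_path)
# model = BuilderCls(cfg).build(...)
```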

Since distributed model initialization is identical for GPT and Mamba models, the logic is extracted into a separate function under 'unimodal.py', which can be called by GPTModelBuilder in the future. The logic is also split into helpers for readability.

I plan to upstream all new files under src/ in this PR to MegatronLM in the near future. It may also be worthwhile to eventually push the base model configs, e.g. MambaModelConfig, to MCore for use directly with the MCore model. Additionally, if TransformerConfig is broken up in the future, the model-specific config can remain at the top level of the hierarchy.

Some additional notes:

  • Advanced use cases may need to customize how the model wrapping/distributed model initialization is done, e.g. MIMO. Therefore, it makes more sense for that functionality to be defined in the ModelBuilder, so that both the single-stage building logic and distributed building logic can be overridden from the same place.
  • With the functional helpers in unimodal.py, these abstractions should appear as a simple layer on top of existing MLM code like get_model() and mamba_builders.py.
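To make the hook mechanism concrete, here is a sketch of how a pre-wrap hook could transform a model before DDP wrapping, in the spirit of the PEFT usage mentioned above. The function and hook names are hypothetical:

```python
# Assumed hook-running convention: each hook may return a replacement model,
# e.g. with adapters injected, before the wrapping step runs.

def run_pre_wrap_hooks(model, hooks):
    for hook in hooks:
        model = hook(model)
    return model

def add_peft_adapters(model):
    # Hypothetical PEFT-style hook: wrap the base model with adapter metadata.
    return {"base": model, "adapters": ["lora"]}

adapted = run_pre_wrap_hooks("base-model", [add_peft_adapters])
```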

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced a model configuration and building framework supporting distributed training with DDP/FSDP options.
    • Added Mamba model builder with configurable stack specifications.
    • Added hook system for customizing model wrapping behavior.
    • Enabled virtual pipeline staging support for models.
  • Refactor

    • Consolidated distributed model utilities into shared module.

@copy-pr-bot

copy-pr-bot bot commented Feb 5, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@maanug-nv maanug-nv changed the base branch from maanug/provider-refactor-mamba to main February 10, 2026 18:08
Signed-off-by: Maanu Grover <maanug@nvidia.com>
@maanug-nv maanug-nv force-pushed the maanug/provider-refactor-mamba-simpler branch from f561800 to f5707f5 Compare February 18, 2026 08:09
@maanug-nv maanug-nv changed the title Alternate impl of model provider refactor Introduce refactored model builder abstractions Feb 18, 2026
Phlip79
Phlip79 previously approved these changes Feb 20, 2026
Comment on lines +220 to +228
```python
pg_collection: ProcessGroupCollection,
ddp_config: DistributedDataParallelConfig | None = None,
overlap_param_gather_with_optimizer_step: bool = False,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
model_type: ModelType = ModelType.encoder_or_decoder,
```
Contributor


I see that this is the translation for the existing setup, but how do these args generalize for the MIMO use case?

Contributor

@yaoyu-33 yaoyu-33 Feb 20, 2026


MIMO can take in pg_collection and ddp_config as dicts for submodules? I don't want to over-design to fit MIMO.

Contributor Author


I think the signature can stay a little flexible until it's integrated into the training loop, which will be in follow-up PR(s).

@yashaswikarnati do you have any feedback on how this signature can be a bit better for MIMO without changing too much?

Contributor


Sorry, just coming back to this. I see that in the signature we accept only one pg_collection. Do we plan to have a different signature for multimodal?

I don't think that's over-design; I would say that's the bare minimum to achieve the functionality.


Labels

feature New capabilities, enhancements, or enablement work

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement refactored Model Builder interface in Megatron Bridge

5 participants