-
Notifications
You must be signed in to change notification settings - Fork 432
Merge mbridge distillation for any_model #1036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e82164f
2099df3
eb5cf8a
c9de41c
3c1bc1f
8357136
6cc2194
ee4e1e3
449b523
fb27bba
b350f82
fafe5a3
e988248
c717852
030f126
8dcdfbf
70df0df
bb56662
ecd953e
ee8f538
c9b76a1
6e3af61
0ad6d92
995eb1a
34081c9
ed5c00f
993b5ec
6e9f03b
e8b7a7d
47414d5
a8305d8
68421a5
d6b8028
ecd2341
f9d845d
d171b01
722da90
934ab2f
0f14ec3
dcb9e02
0c9ea5d
5b310e2
4f82b1c
176a435
02e2c9b
92c4419
aa1eb3e
2b84a96
fb838c0
13378ff
47ca0e3
96112f7
cb6b182
670bb34
0e1b591
ca845ec
be825bc
7fd1afa
7d7b609
249af9d
b80583c
88b1b13
c0da9c0
1dd742e
4a6ebbe
585f0ed
7fb5d9a
75d3d69
0e5722d
2dd9735
3561de5
27866de
a012fe6
52922a4
c234fb4
53dcd10
69d9648
4a692dc
e795f0c
631306c
dc77be2
b76e0ef
109b185
5cadc65
151081c
36daa6d
960b8ce
854d96b
b47f846
13f5edc
f2c1578
3592eec
f06cb20
ad31b09
0505916
7016857
7ede076
24ba700
44186c7
81f6d4e
a5b715f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Megatron-Bridge adapters for Puzzletron AnyModel checkpoints. | ||
|
|
||
| This module provides bridges for converting Puzzletron AnyModel checkpoints | ||
| (heterogeneous layer architectures) to Megatron-Core format via Megatron-Bridge. | ||
| """ | ||
|
|
||
| # Import to register bridges (side effect) | ||
| from modelopt.torch.puzzletron.export.mbridge.base import HeterogeneousBridgeMixin | ||
| from modelopt.torch.puzzletron.export.mbridge.llama import ( # noqa: F401 | ||
| PuzzletronLlamaAnyModelBridge, | ||
| ) | ||
| from modelopt.torch.puzzletron.export.mbridge.qwen3 import ( # noqa: F401 | ||
| PuzzletronQwen3AnyModelBridge, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "HeterogeneousBridgeMixin", | ||
| "PuzzletronLlamaAnyModelBridge", | ||
| "PuzzletronQwen3AnyModelBridge", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| #!/usr/bin/env python3 | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Mixin class for bridges that support heterogeneous layer architectures. | ||
|
|
||
| This module provides a mixin class for converting models with block_configs | ||
| (heterogeneous layer configurations) to Megatron-Core format via Megatron-Bridge. | ||
| """ | ||
|
|
||
| import dataclasses | ||
| import json | ||
| from collections.abc import Callable | ||
| from dataclasses import dataclass, fields | ||
|
|
||
| from megatron.bridge.models.gpt_provider import GPTModelProvider | ||
| from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM | ||
| from megatron.bridge.models.transformer_config import HeterogeneousTransformerConfig | ||
| from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( | ||
| get_gpt_heterogeneous_layer_spec, | ||
| ) | ||
| from megatron.core.transformer.spec_utils import ModuleSpec | ||
|
|
||
|
|
||
| def heterogeneous_layer_spec(config) -> ModuleSpec: | ||
| """Get GPT heterogeneous layer spec using Transformer Engine.""" | ||
| return get_gpt_heterogeneous_layer_spec(config, use_te=True) | ||
|
|
||
|
|
||
| @dataclass | ||
| class GenericHeterogeneousProvider(GPTModelProvider, HeterogeneousTransformerConfig): | ||
| """Generic provider for AnyModel checkpoints with block_configs.""" | ||
|
|
||
| # Heterogeneous configuration fields | ||
| heterogeneous_layers_config_path: str | None = None | ||
| heterogeneous_layers_config_encoded_json: str = "" | ||
| transformer_layer_spec: ModuleSpec | Callable = heterogeneous_layer_spec | ||
|
Comment on lines
+47
to
+50
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't let the parent provider overwrite the heterogeneous layer spec. Line 50 sets Proposed fix- provider_kwargs = dataclasses.asdict(parent_provider)
+ provider_kwargs = {
+ field.name: getattr(parent_provider, field.name)
+ for field in fields(parent_provider)
+ if field.init
+ }
@@
# Only keep kwargs that are valid fields
provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields}
+ provider_kwargs["transformer_layer_spec"] = heterogeneous_layer_spec
provider_kwargs["heterogeneous_layers_config_encoded_json"] = (
self._build_heterogeneous_config_json(hf_pretrained.config)
)Also applies to: 91-113 🤖 Prompt for AI Agents |
||
|
|
||
| def __getattr__(self, name: str): | ||
| """Handle missing attributes for OmegaConf compatibility. | ||
|
|
||
| Returns empty list for per_block_parameters if not yet initialized (before finalize()). | ||
| This allows OmegaConf to serialize/deserialize configs without errors. Actual usage | ||
| should call finalize() first to set per_block_parameters as a real attribute. | ||
| """ | ||
| if name == "per_block_parameters": | ||
| # Return existing attribute if set, otherwise [] for OmegaConf compatibility | ||
| try: | ||
| return object.__getattribute__(self, name) | ||
| except AttributeError: | ||
| return [] | ||
| raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") | ||
|
|
||
|
|
||
| class HeterogeneousBridgeMixin: | ||
| """Mixin for bridges supporting heterogeneous layer architectures (block_configs). | ||
|
|
||
| Must be used with multiple inheritance alongside a model-specific bridge. | ||
| Example: class PuzzletronLlamaAnyModelBridge(HeterogeneousBridgeMixin, LlamaBridge) | ||
| """ | ||
|
|
||
| def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: | ||
| """Convert HF AnyModel config to Megatron GPTModelProvider. | ||
|
|
||
| This method: | ||
| 1. Calls the parent bridge's provider_bridge() to get a GPTModelProvider with all | ||
| model-specific settings (e.g., LlamaBridge sets normalization="RMSNorm", etc.) | ||
| 2. Converts the provider to a dict and filters to only fields accepted by | ||
| GenericHeterogeneousProvider (which inherits from GPTModelProvider, so all valid | ||
| GPTModelProvider fields are preserved) | ||
| 3. Adds heterogeneous configuration and returns GenericHeterogeneousProvider | ||
|
|
||
| All parameters from the parent bridge (e.g., LlamaBridge) are maintained because | ||
| GenericHeterogeneousProvider inherits from GPTModelProvider, which includes all | ||
| the fields that the parent bridge sets. | ||
| """ | ||
|
|
||
| parent_provider = super().provider_bridge(hf_pretrained) # type: ignore[misc] | ||
|
|
||
| provider_kwargs = dataclasses.asdict(parent_provider) | ||
|
|
||
| # Filter to only fields that GenericHeterogeneousProvider accepts. | ||
| # GenericHeterogeneousProvider inherits from GPTModelProvider, so it includes all | ||
| # GPTModelProvider fields. Model-specific fields from subclasses (e.g., MistralModelProvider, | ||
| # GPTOSSModelProvider) are filtered out because GenericHeterogeneousProvider only inherits | ||
| # from GPTModelProvider, not from model-specific subclasses. | ||
| # | ||
| # Note: This logic may not work for bridges like MistralBridge or GPTOSSBridge if they | ||
| # use model-specific parameters not supported by GenericHeterogeneousProvider (e.g., | ||
| # scale_factor, yarn_rotary_scaling_factor, moe_* parameters). In such cases, create a | ||
| # model-specific heterogeneous provider that inherits from the model-specific provider. | ||
| valid_fields = {f.name for f in fields(GenericHeterogeneousProvider)} | ||
|
|
||
| # Only keep kwargs that are valid fields | ||
| provider_kwargs = {k: v for k, v in provider_kwargs.items() if k in valid_fields} | ||
|
|
||
| provider_kwargs["heterogeneous_layers_config_encoded_json"] = ( | ||
| self._build_heterogeneous_config_json(hf_pretrained.config) | ||
| ) | ||
| return GenericHeterogeneousProvider(**provider_kwargs) | ||
|
|
||
| def _build_heterogeneous_config_json(self, hf_config) -> str: | ||
| """Build heterogeneous layers config JSON from HF config.""" | ||
|
|
||
| hf_config_dict = json.loads(hf_config.to_json_string()) | ||
|
|
||
| mcore_block_configs = [ | ||
| self._convert_block_config(block) for block in hf_config_dict["block_configs"] | ||
| ] | ||
| return json.dumps({"block_configs": mcore_block_configs}, ensure_ascii=False) | ||
|
|
||
| def _convert_block_config(self, block: dict) -> dict: | ||
| """Convert a single block config from HF format to MCore format.""" | ||
| return { | ||
| "attention": self._convert_attention_config(block["attention"]), | ||
| "ffn": self._convert_ffn_config(block["ffn"]), | ||
| } | ||
|
|
||
| def _convert_attention_config(self, attention_config: dict) -> dict: | ||
| """Convert attention config from HF format to MCore format.""" | ||
| attention_config = attention_config.copy() | ||
| attention_config["num_query_groups"] = attention_config.pop("num_key_value_heads") | ||
| return attention_config | ||
|
|
||
| def _convert_ffn_config(self, ffn_config: dict) -> dict: | ||
| """Convert FFN/MLP config from HF format to MCore format.""" | ||
| ffn_config = ffn_config.copy() | ||
| ffn_config["ffn_hidden_size"] = ffn_config.pop("intermediate_size") | ||
| return ffn_config | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty much everything in this PR seems like we should instead merge to M-Bridge. Are we confident enough to upstream these changes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are not confident, e.g., we would need to talk to mbrdige/megatron-lm people on that first, align with their plans for heterogenous support. Let's think about it once puzzletron is in main.
We also have to do support for gpt-oss and mamba, so it is not the best time to merge it to mcore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nemo:26.04 container code freeze is in 2 weeks. Lets make sure we raise a PR for required changes to M-Bridge before that so we can see what can and cannot be upstreamed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unlikely have time for it in next 2 weeks