Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
986c5fb
Split bypass prerequisites
Separius May 12, 2026
12086fb
Address CodeRabbit feedback for bypass integration
Separius May 12, 2026
b9c00ba
Address additional MR1 review feedback
Separius May 12, 2026
bb4217c
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 13, 2026
d052cce
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 19, 2026
2a10f7f
Address puzzletron review feedback
Separius May 19, 2026
4b3e381
Handle empty chat message lists
Separius May 19, 2026
29c5981
Address CodeRabbit test nitpicks
Separius May 19, 2026
aeac849
Fix no-template dataloader test tokenizer
Separius May 19, 2026
1e7f9a7
Use descriptor for pruning LM config
Separius May 20, 2026
4f69204
Add pruning descriptor coverage
Separius May 20, 2026
a38b8b9
Apply pre-commit formatting
Separius May 20, 2026
33d2d6d
Add targeted puzzletron bypass tests
Separius May 22, 2026
4ad3b56
Apply puzzletron test formatting
Separius May 22, 2026
369b450
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 22, 2026
1475f41
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 26, 2026
3fe86e8
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 27, 2026
e5b0d4b
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 28, 2026
e567f57
Prune redundant Puzzletron tests
Separius May 28, 2026
3084194
Apply Puzzletron test formatting
Separius May 28, 2026
dccf464
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
kevalmorabia97 May 28, 2026
eada923
Disable async save for Megatron Bridge distill export
Separius May 29, 2026
11c1eea
Merge branch 'main' into ssameni/puzzletron-bypass-1-prereqs
Separius May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/megatron_bridge/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def _build_model_provider(hf_path):
load=checkpoint_dir, # Resume from this directory (if exists)
most_recent_k=5, # Keeps 5 most recent checkpoints (not metric-based)
ckpt_format="torch_dist",
async_save=True,
async_save=False,
fully_parallel_save=True,
),
rng=RNGConfig(seed=args.seed),
Expand Down
13 changes: 13 additions & 0 deletions modelopt/torch/puzzletron/anymodel/model_descriptor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,19 @@ def uses_autocast() -> bool:
"""
return True

@staticmethod
def pruning_mixins() -> Dict[str, Any]:
    """List the pruning mixins this descriptor offers for bypass distillation.

    The mapping is keyed by mixin name. Subclasses override this hook to
    register model-specific mixins, for example
    ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``.

    The base implementation registers nothing: descriptors that need no
    model-specific weight slicing (e.g. Llama with standard FFN truncation)
    rely on the generic ``create_child_state_dict`` fallback path instead.

    Returns:
        A new, empty dictionary on every call.
    """
    no_mixins: Dict[str, Any] = dict()
    return no_mixins

@staticmethod
def get_language_model_config(config):
"""Get the language model config from a PretrainedConfig.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn

# Expert removal is supported for unquantized models (test models).
# Production models use MXFP4 quantized MoE with combined tensors
Expand All @@ -37,7 +38,11 @@
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size

__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"]
__all__ = [
"GptOssExpertRemovalLayerDescriptor",
"GptOssKVHeadsLayerDescriptor",
"GptOssModelDescriptor",
]


@ModelDescriptorFactory.register_decorator("gpt_oss")
Expand Down Expand Up @@ -173,7 +178,29 @@ def pruning_mixins() -> Dict[str, PruningMixIn]:
Note: Expert removal works for unquantized models (test models).
Production models use MXFP4 quantization which is not yet supported.
"""
return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())}
# Single instance shared between the canonical key and the legacy alias
# so resolve_pruning_mixin returns the same object regardless of which
# name a caller uses.
expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())
return {
"experts_removal": expert_mixin,
# Backward-compat alias: this key was "expert_removal" before the
# bypass branch standardised on "experts_removal" (matching the
# NemotronH descriptor). Kept so external scripts that still call
# `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)`
# continue to work. Remove after a deprecation cycle.
"expert_removal": expert_mixin,
"kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()),
}


@dataclass
class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
    """GPT-OSS attention parameter naming handed to ``KVHeadsPruningMixIn``."""

    # Path of the attention output projection, written relative to a decoder layer.
    # NOTE(review): how the mixin consumes this field is not visible here — confirm.
    o_proj_name: str = "self_attn.o_proj"
    # Attention-module prefix template; ``{layer_idx}`` is a placeholder for the
    # layer index (presumably filled in by the mixin — verify against the caller).
    attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
    # Projection names that the mixin joins onto the attention prefix to address
    # the q/k/v/o parameters. A factory is used so each instance gets its own list.
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,16 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same

__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"]
__all__ = [
"NemotronHExpertRemovalLayerDescriptor",
"NemotronHKVHeadsLayerDescriptor",
"NemotronHModelDescriptor",
]


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
Expand All @@ -51,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
return matches


@dataclass
class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
    """Nemotron-H attention parameter naming handed to ``KVHeadsPruningMixIn``."""

    # Path of the attention output projection, written relative to a backbone layer.
    # NOTE(review): how the mixin consumes this field is not visible here — confirm.
    o_proj_name: str = "mixer.o_proj"
    # Attention-module prefix template; ``{layer_idx}`` is a placeholder for the
    # layer index (presumably filled in by the mixin — verify against the caller).
    attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
    # Projection names that the mixin joins onto the attention prefix to address
    # the q/k/v/o parameters. A factory is used so each instance gets its own list.
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )


@dataclass
class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
target_name: str = "mixer.gate"
Expand Down Expand Up @@ -251,4 +265,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]:
def pruning_mixins() -> Dict[str, PruningMixIn]:
    """Return the Nemotron-H pruning mixins, keyed by mixin name.

    Two entries are registered: ``"experts_removal"`` and ``"kv_heads"``.
    Each call builds fresh mixin and layer-descriptor instances.
    """
    experts_mixin = ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor())
    kv_heads_mixin = KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor())
    return {"experts_removal": experts_mixin, "kv_heads": kv_heads_mixin}
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@
FFNIntermediateLayerDescriptor,
FFNIntermediatePruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same

__all__ = [
"NemotronHV2FFNIntermediateLayerDescriptor",
"NemotronHV2KVHeadsLayerDescriptor",
"NemotronHV2ExpertRemovalLayerDescriptor",
"NemotronHV2ModelDescriptor",
]
Expand Down Expand Up @@ -109,6 +111,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"])


@dataclass
class NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
    """Nemotron-H v2 attention parameter naming handed to ``KVHeadsPruningMixIn``."""

    # Path of the attention output projection, written relative to a backbone layer.
    # NOTE(review): how the mixin consumes this field is not visible here — confirm.
    o_proj_name: str = "mixer.o_proj"
    # Attention-module prefix template; ``{layer_idx}`` is a placeholder for the
    # layer index (presumably filled in by the mixin — verify against the caller).
    attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
    # Projection names that the mixin joins onto the attention prefix to address
    # the q/k/v/o parameters. A factory is used so each instance gets its own list.
    qkvo_weight_names: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )


@ModelDescriptorFactory.register_decorator("nemotron_h_v2")
class NemotronHV2ModelDescriptor(ModelDescriptor):
_DECODER_LAYER_CLS: Type[nn.Module] = None
Expand Down Expand Up @@ -291,5 +302,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]:
"ffn_intermediate": FFNIntermediatePruningMixIn(
NemotronHV2FFNIntermediateLayerDescriptor()
),
"kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()),
# TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@
)

from ....block_config import BlockConfig
from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor
from ....pruning.expert_removal_pruning_mixin import (
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size

Expand Down Expand Up @@ -56,6 +60,13 @@ def get_language_model_config(config):
"""Qwen3-VL has nested text_config for language model parameters."""
return config.text_config if hasattr(config, "text_config") else config

@staticmethod
def pruning_mixins() -> Dict[str, PruningMixIn]:
    """Return the Qwen3-VL pruning mixins, keyed by mixin name.

    Two entries are registered: ``"experts_removal"`` and ``"kv_heads"``.
    Each call builds fresh mixin and layer-descriptor instances.
    """
    experts_mixin = ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor())
    kv_heads_mixin = KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor())
    return {"experts_removal": experts_mixin, "kv_heads": kv_heads_mixin}

@staticmethod
def decoder_layer_cls():
return Qwen3VLMoeTextDecoderLayer
Expand Down
12 changes: 10 additions & 2 deletions modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
)

from .pruning_mixin import LayerDescriptor, PruningMixIn
from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights
from .pruning_utils import (
GQAInitMode,
_init_attention_biases,
_init_attention_weights,
_lm_head_dim,
)

__all__ = [
"KVHeadsLayerDescriptor",
Expand Down Expand Up @@ -60,6 +65,7 @@ def prune_single_layer(
new_state_dict: dict,
original_config: PretrainedConfig,
new_config: PretrainedConfig,
descriptor,
gqa_init_mode: GQAInitMode,
mlp_init_config: Optional[dict[str, Any]],
is_original_mha: bool,
Expand All @@ -74,7 +80,7 @@ def prune_single_layer(
f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names
]

head_size = new_config.head_dim
head_size = _lm_head_dim(new_config, descriptor)
for part in ["weight", "bias"]:
attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]]
q_key, k_key, v_key, o_key = attn_keys
Expand All @@ -94,6 +100,7 @@ def prune_single_layer(
layer_idx=layer_idx,
new_state_dict=new_state_dict,
new_config=new_config,
descriptor=descriptor,
original_state_dict=parent_state_dict,
original_config=original_config,
q_key=q_key,
Expand All @@ -112,6 +119,7 @@ def prune_single_layer(
layer_idx=layer_idx,
new_state_dict=new_state_dict,
new_config=new_config,
descriptor=descriptor,
original_state_dict=parent_state_dict,
original_config=original_config,
q_key=q_key,
Expand Down
46 changes: 36 additions & 10 deletions modelopt/torch/puzzletron/pruning/pruning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class MlpInitMode(Enum):
PruneByActivationsLog = "PruneByActivationsLog"
ExpertRemoval = "ExpertRemoval"
ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN"
MoEChannelPruning = "MoEChannelPruning"


class LinearInitMode(Enum):
Expand All @@ -66,6 +67,14 @@ class HiddenSizeInitMode(Enum):
CopyAsIs = "CopyAsIs"


def _lm_head_dim(config, descriptor: Type[ModelDescriptor]) -> int:
    """Return the per-head attention dimension of the language-model config.

    The language-model config is obtained with
    ``descriptor.get_language_model_config(config)``.

    Args:
        config: Model config passed through to the descriptor.
        descriptor: Object exposing ``get_language_model_config``.

    Returns:
        The config's ``head_dim`` when that attribute exists and is not
        ``None``; otherwise ``hidden_size // num_attention_heads``.
    """
    text_config = descriptor.get_language_model_config(config)
    explicit_head_dim = getattr(text_config, "head_dim", None)
    if explicit_head_dim is None:
        # No usable head_dim on the config: derive it from the model width.
        return text_config.hidden_size // text_config.num_attention_heads
    return explicit_head_dim


def resolve_pruning_mixin(
pruning_mixin, descriptor: Type[ModelDescriptor]
) -> PruningMixIn | List[PruningMixIn]:
Expand Down Expand Up @@ -214,6 +223,7 @@ def _init_attention_weights(
layer_idx,
new_state_dict,
new_config,
descriptor,
original_state_dict,
q_key,
k_key,
Expand All @@ -224,10 +234,13 @@ def _init_attention_weights(
head_size,
mlp_init_config,
):
assert new_config.num_attention_heads == original_config.num_attention_heads, (
f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
new_lm = descriptor.get_language_model_config(new_config)
orig_lm = descriptor.get_language_model_config(original_config)
assert new_lm.num_attention_heads == orig_lm.num_attention_heads, (
f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})"
)
num_q_heads = new_config.num_attention_heads
num_q_heads = new_lm.num_attention_heads
# block_configs lives on the outer puzzletron-converted config, not on text_config.
num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads

Expand Down Expand Up @@ -362,6 +375,7 @@ def _init_attention_biases(
layer_idx,
new_state_dict,
new_config,
descriptor,
original_state_dict,
q_key,
k_key,
Expand All @@ -372,17 +386,29 @@ def _init_attention_biases(
head_size,
mlp_init_config,
):
assert new_config.num_attention_heads == original_config.num_attention_heads, (
f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
new_lm = descriptor.get_language_model_config(new_config)
orig_lm = descriptor.get_language_model_config(original_config)
assert new_lm.num_attention_heads == orig_lm.num_attention_heads, (
f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})"
)
num_q_heads = new_config.num_attention_heads
num_q_heads = new_lm.num_attention_heads
# block_configs lives on the outer puzzletron-converted config, not on text_config.
num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads
n_heads_in_group = num_q_heads // num_kv_heads
orig_n_heads_in_group = num_q_heads // orig_num_kv_heads

o_proj_bias = new_config.o_proj_bias
attention_bias = new_config.attention_bias
# Some HF native configs (e.g. GptOssConfig) don't expose o_proj_bias / attention_bias as
# top-level attributes the way puzzletron's DeciLM-style configs do. Fall back to probing
# the new state dict for the actual bias keys only when the attribute is omitted.
# KVHeadsPruningMixIn only calls this helper after filtering to keys present in
# new_state_dict, so the probe mirrors the caller's already-selected bias tensors.
o_proj_bias = getattr(new_config, "o_proj_bias", None)
if o_proj_bias is None:
o_proj_bias = o_key in new_state_dict
attention_bias = getattr(new_config, "attention_bias", None)
if attention_bias is None:
attention_bias = any(key in new_state_dict for key in (q_key, k_key, v_key))

# If no biases
if not (o_proj_bias or attention_bias):
Expand Down Expand Up @@ -438,8 +464,8 @@ def _init_attention_biases(
assert not is_original_mha, (
"Degrouping can only be done on original models that are GQA themselves."
)
n_groups = new_config.num_attention_heads // n_heads_in_group
orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group
n_groups = new_lm.num_attention_heads // n_heads_in_group
orig_n_groups = orig_lm.num_attention_heads // orig_n_heads_in_group
assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}"
n_repeats = n_groups // orig_n_groups
if n_repeats > 1:
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/puzzletron/sewing_kit/passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"PassageOutput",
"Predicate",
"always_false_predicate",
"always_true_predicate",
"Passage",
"patch_module",
]
Expand Down
Loading
Loading