diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index 5c4224c03dd..16a0a85f842 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -343,7 +343,7 @@ def _build_model_provider(hf_path): load=checkpoint_dir, # Resume from this directory (if exists) most_recent_k=5, # Keeps 5 most recent checkpoints (not metric-based) ckpt_format="torch_dist", - async_save=True, + async_save=False, fully_parallel_save=True, ), rng=RNGConfig(seed=args.seed), diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py index 3c1749d46ec..58b045bd21c 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py @@ -169,6 +169,19 @@ def uses_autocast() -> bool: """ return True + @staticmethod + def pruning_mixins() -> Dict[str, Any]: + """Return available pruning mixins for bypass distillation. + + Override in subclasses to provide model-specific pruning mixins, e.g. + ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``. + + Returns an empty dict by default so that descriptors that do not need + model-specific weight-slicing (e.g. Llama with standard FFN truncation) + can rely on the generic ``create_child_state_dict`` fallback path. + """ + return {} + @staticmethod def get_language_model_config(config): """Get the language model config from a PretrainedConfig. 
diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py index c8fd86b4bb6..1abecdec0c2 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -28,6 +28,7 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn # Expert removal is supported for unquantized models (test models). # Production models use MXFP4 quantized MoE with combined tensors @@ -37,7 +38,11 @@ from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size -__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"] +__all__ = [ + "GptOssExpertRemovalLayerDescriptor", + "GptOssKVHeadsLayerDescriptor", + "GptOssModelDescriptor", +] @ModelDescriptorFactory.register_decorator("gpt_oss") @@ -173,7 +178,29 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: Note: Expert removal works for unquantized models (test models). Production models use MXFP4 quantization which is not yet supported. """ - return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + # Single instance shared between the canonical key and the legacy alias + # so resolve_pruning_mixin returns the same object regardless of which + # name a caller uses. + expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor()) + return { + "experts_removal": expert_mixin, + # Backward-compat alias: this key was "expert_removal" before the + # bypass branch standardised on "experts_removal" (matching the + # NemotronH descriptor). Kept so external scripts that still call + # `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)` + # continue to work. 
Remove after a deprecation cycle. + "expert_removal": expert_mixin, + "kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()), + } + + +@dataclass +class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) @dataclass diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 1c5706d1944..b3f33887367 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -29,11 +29,16 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"] +__all__ = [ + "NemotronHExpertRemovalLayerDescriptor", + "NemotronHKVHeadsLayerDescriptor", + "NemotronHModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -51,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: return matches +@dataclass +class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @dataclass class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mixer.gate" @@ -251,4 +265,5 @@ 
def build_attention_predicates() -> Dict[str, re.Pattern]: def pruning_mixins() -> Dict[str, PruningMixIn]: return { "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py index 8290f961936..6b82459b5a8 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -33,12 +33,14 @@ FFNIntermediateLayerDescriptor, FFNIntermediatePruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same __all__ = [ "NemotronHV2FFNIntermediateLayerDescriptor", + "NemotronHV2KVHeadsLayerDescriptor", "NemotronHV2ExpertRemovalLayerDescriptor", "NemotronHV2ModelDescriptor", ] @@ -109,6 +111,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"]) +@dataclass +class NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @ModelDescriptorFactory.register_decorator("nemotron_h_v2") class NemotronHV2ModelDescriptor(ModelDescriptor): _DECODER_LAYER_CLS: Type[nn.Module] = None @@ -291,5 +302,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: "ffn_intermediate": FFNIntermediatePruningMixIn( 
NemotronHV2FFNIntermediateLayerDescriptor() ), + "kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()), # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated } diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py index aeedd419923..a0f9c95c6ce 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -26,9 +26,13 @@ ) from ....block_config import BlockConfig -from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor +from ....pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor -from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn +from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size @@ -56,6 +60,13 @@ def get_language_model_config(config): """Qwen3-VL has nested text_config for language model parameters.""" return config.text_config if hasattr(config, "text_config") else config + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "experts_removal": ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor()), + } + @staticmethod def decoder_layer_cls(): return Qwen3VLMoeTextDecoderLayer diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py index 740d1fada3c..523112843b2 100644 --- 
a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -24,7 +24,12 @@ ) from .pruning_mixin import LayerDescriptor, PruningMixIn -from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights +from .pruning_utils import ( + GQAInitMode, + _init_attention_biases, + _init_attention_weights, + _lm_head_dim, +) __all__ = [ "KVHeadsLayerDescriptor", @@ -60,6 +65,7 @@ def prune_single_layer( new_state_dict: dict, original_config: PretrainedConfig, new_config: PretrainedConfig, + descriptor, gqa_init_mode: GQAInitMode, mlp_init_config: Optional[dict[str, Any]], is_original_mha: bool, @@ -74,7 +80,7 @@ def prune_single_layer( f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names ] - head_size = new_config.head_dim + head_size = _lm_head_dim(new_config, descriptor) for part in ["weight", "bias"]: attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] q_key, k_key, v_key, o_key = attn_keys @@ -94,6 +100,7 @@ def prune_single_layer( layer_idx=layer_idx, new_state_dict=new_state_dict, new_config=new_config, + descriptor=descriptor, original_state_dict=parent_state_dict, original_config=original_config, q_key=q_key, @@ -112,6 +119,7 @@ def prune_single_layer( layer_idx=layer_idx, new_state_dict=new_state_dict, new_config=new_config, + descriptor=descriptor, original_state_dict=parent_state_dict, original_config=original_config, q_key=q_key, diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index c600e119cfa..38ab7a2e0be 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -52,6 +52,7 @@ class MlpInitMode(Enum): PruneByActivationsLog = "PruneByActivationsLog" ExpertRemoval = "ExpertRemoval" ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + MoEChannelPruning = "MoEChannelPruning" 
class LinearInitMode(Enum): @@ -66,6 +67,14 @@ class HiddenSizeInitMode(Enum): CopyAsIs = "CopyAsIs" +def _lm_head_dim(config, descriptor: Type[ModelDescriptor]) -> int: + lm_config = descriptor.get_language_model_config(config) + head_dim = getattr(lm_config, "head_dim", None) + if head_dim is not None: + return head_dim + return lm_config.hidden_size // lm_config.num_attention_heads + + def resolve_pruning_mixin( pruning_mixin, descriptor: Type[ModelDescriptor] ) -> PruningMixIn | List[PruningMixIn]: @@ -214,6 +223,7 @@ def _init_attention_weights( layer_idx, new_state_dict, new_config, + descriptor, original_state_dict, q_key, k_key, @@ -224,10 +234,13 @@ def _init_attention_weights( head_size, mlp_init_config, ): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + new_lm = descriptor.get_language_model_config(new_config) + orig_lm = descriptor.get_language_model_config(original_config) + assert new_lm.num_attention_heads == orig_lm.num_attention_heads, ( + f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})" ) - num_q_heads = new_config.num_attention_heads + num_q_heads = new_lm.num_attention_heads + # block_configs lives on the outer puzzletron-converted config, not on text_config. 
num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads @@ -362,6 +375,7 @@ def _init_attention_biases( layer_idx, new_state_dict, new_config, + descriptor, original_state_dict, q_key, k_key, @@ -372,17 +386,29 @@ def _init_attention_biases( head_size, mlp_init_config, ): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + new_lm = descriptor.get_language_model_config(new_config) + orig_lm = descriptor.get_language_model_config(original_config) + assert new_lm.num_attention_heads == orig_lm.num_attention_heads, ( + f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})" ) - num_q_heads = new_config.num_attention_heads + num_q_heads = new_lm.num_attention_heads + # block_configs lives on the outer puzzletron-converted config, not on text_config. num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads n_heads_in_group = num_q_heads // num_kv_heads orig_n_heads_in_group = num_q_heads // orig_num_kv_heads - o_proj_bias = new_config.o_proj_bias - attention_bias = new_config.attention_bias + # Some HF native configs (e.g. GptOssConfig) don't expose o_proj_bias / attention_bias as + # top-level attributes the way puzzletron's DeciLM-style configs do. Fall back to probing + # the new state dict for the actual bias keys only when the attribute is omitted. + # KVHeadsPruningMixIn only calls this helper after filtering to keys present in + # new_state_dict, so the probe mirrors the caller's already-selected bias tensors. 
+ o_proj_bias = getattr(new_config, "o_proj_bias", None) + if o_proj_bias is None: + o_proj_bias = o_key in new_state_dict + attention_bias = getattr(new_config, "attention_bias", None) + if attention_bias is None: + attention_bias = any(key in new_state_dict for key in (q_key, k_key, v_key)) # If no biases if not (o_proj_bias or attention_bias): @@ -438,8 +464,8 @@ def _init_attention_biases( assert not is_original_mha, ( "Degrouping can only be done on original models that are GQA themselves." ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + n_groups = new_lm.num_attention_heads // n_heads_in_group + orig_n_groups = orig_lm.num_attention_heads // orig_n_heads_in_group assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" n_repeats = n_groups // orig_n_groups if n_repeats > 1: diff --git a/modelopt/torch/puzzletron/sewing_kit/passage.py b/modelopt/torch/puzzletron/sewing_kit/passage.py index d8fa1f51cf9..c77b9dd41cd 100644 --- a/modelopt/torch/puzzletron/sewing_kit/passage.py +++ b/modelopt/torch/puzzletron/sewing_kit/passage.py @@ -45,6 +45,7 @@ "PassageOutput", "Predicate", "always_false_predicate", + "always_true_predicate", "Passage", "patch_module", ] diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 3db63f60013..106b0b3e4c3 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -16,6 +16,7 @@ from __future__ import annotations import inspect +import operator from contextlib import contextmanager from typing import ( TYPE_CHECKING, @@ -451,3 +452,95 @@ def _get_group_kwarg_if_necessary() -> dict: torch.distributed.distributed_c10d._object_to_tensor ).parameters.keys() return dict(group=None) if "group" in arg_names else dict() + + +# 
────────────────────────────────────────────────────────────────────────────── +# Loss functions for bypass distillation (blockwise local knowledge distillation) +# ────────────────────────────────────────────────────────────────────────────── + +# `normalized_mse_loss` already lives in tools.kd_model — re-export it here so +# bypass-distillation imports stay co-located with the per-vector / per-batch +# variants below, without duplicating the implementation. The `as +# normalized_mse_loss` form is PEP 484's explicit re-export (mypy treats +# `from X import Y` as a private import otherwise). +from modelopt.torch.puzzletron.tools.kd_model import ( # noqa: E402 + normalized_mse_loss as normalized_mse_loss, +) + + +def vectorwise_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, +) -> torch.Tensor: + """Like normalized_mse_loss, but normalization is done per-vector (last dim), then averaged.""" + return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) + + +def batched_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, + batch_dims: Sequence[int] = (0,), +) -> torch.Tensor: + """Per-batch-element relative-L2 loss. + + For each batch element, computes ``||input - target||^2 / (||target||^2 + eps)`` + over the non-batch dims, then averages across batch elements. The additive + ``epsilon`` in the denominator handles all-zero target slices without a hard + clamp and makes the loss scale-invariant when ``||target||^2 >> eps``. 
+ """ + input_shape = tuple(input.shape) + target_shape = tuple(target.shape) + + if epsilon <= 0: + raise ValueError(f"epsilon must be strictly positive, got {epsilon!r}") + + try: + raw_batch_dims = tuple(operator.index(dim) for dim in batch_dims) + except TypeError as exc: + raise ValueError( + f"batch_dims must be an iterable of integer dimensions; got {batch_dims!r} " + f"for input shape {input_shape} and target shape {target_shape}" + ) from exc + + resolved_batch_dims = [] + for dim in raw_batch_dims: + if dim < -input.ndim or dim >= input.ndim: + raise ValueError( + f"batch_dims contains invalid dimension {dim} for input.ndim={input.ndim}; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={raw_batch_dims}, norm_dims=None" + ) + resolved_batch_dims.append(dim % input.ndim) + + if len(set(resolved_batch_dims)) != len(resolved_batch_dims): + raise ValueError( + f"batch_dims contains duplicate dimensions after normalization; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims=None" + ) + + norm_dims = tuple(d for d in range(input.ndim) if d not in set(resolved_batch_dims)) + + if input.ndim != target.ndim: + raise ValueError( + f"input and target must have the same number of dimensions; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims={norm_dims}" + ) + if input_shape != target_shape: + mismatched_dims = tuple( + dim + for dim, (input_size, target_size) in enumerate(zip(input_shape, target_shape)) + if input_size != target_size + ) + raise ValueError( + f"input and target shapes must match exactly; mismatched_dims={mismatched_dims}, " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims={norm_dims}" + ) + + num = ((input - target) ** 2).sum(dim=norm_dims) + den = (target**2).sum(dim=norm_dims) + epsilon + return (num / den).mean() diff 
--git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index b242c7d48ac..3979f305261 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -22,6 +22,8 @@ import os import re import time +from collections import ChainMap +from collections.abc import Iterator, MutableMapping from copy import deepcopy from functools import partial from pathlib import Path @@ -52,6 +54,49 @@ default_ignore_fn: IgnoreFn = lambda _: False +class _PerLayerKeysView(MutableMapping[str, str]): + def __init__(self, base: dict[str, str]) -> None: + self._base = base + self._overrides: dict[str, str] = {} + self._removed: dict[str, str] = {} + + def __getitem__(self, key: str) -> str: + if key in self._removed: + raise KeyError(key) + if key in self._overrides: + return self._overrides[key] + return self._base[key] + + def __setitem__(self, key: str, value: str) -> None: + self._removed.pop(key, None) + self._overrides[key] = value + + def __delitem__(self, key: str) -> None: + if key in self._removed: + raise KeyError(key) + if key in self._overrides: + self._removed[key] = self._overrides.pop(key) + elif key in self._base: + self._removed[key] = self._base[key] + else: + raise KeyError(key) + + def __iter__(self) -> Iterator[str]: + yield from self._overrides.keys() + for key in self._base: + if key not in self._overrides and key not in self._removed: + yield key + + def __len__(self) -> int: + return sum(1 for _ in self) + + def __contains__(self, key: object) -> bool: + return key not in self._removed and (key in self._overrides or key in self._base) + + def removed_items(self) -> dict[str, str]: + return dict(self._removed) + + class Printer: @staticmethod def print(s: str) -> None: @@ -83,27 +128,41 @@ def _process_single_layer( keys_to_remove = {} layer_out_state_dict = {} - # Delegate to pruning_mixin if 
available + # Delegate to pruning_mixin if available (supports a single mixin or a list of mixins). + # Mixins run sequentially. Each mixin sees the state dict produced by earlier mixins, + # which lets independent pruning methods compose on the same tensor (for example one + # pruning FFN channels and another pruning hidden-size dimensions). if pruning_mixin is not None: - _layer_out = pruning_mixin.prune_single_layer( - layer_idx=layer_idx, - parent_state_dict=parent_state_dict, - new_state_dict=new_state_dict, - original_config=original_config, - new_config=new_config, - gqa_init_mode=gqa_init_mode, - mlp_init_mode=mlp_init_mode, - mlp_init_config=mlp_init_config, - linear_init_mode=linear_init_mode, - ignored_keys=ignored_keys, - keys=keys, - is_original_mha=is_original_mha, - head_size=head_size, - hidden_size=hidden_size, - keys_to_remove=keys_to_remove, - ) - layer_out_state_dict.update(_layer_out) - return layer_out_state_dict, keys_to_remove + _mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin] + merged_keys_to_remove = {} + parent_layer_updates = {} + current_parent_state_dict = ChainMap(parent_layer_updates, parent_state_dict) + current_keys = _PerLayerKeysView(keys) + for _mixin in _mixins: + mixin_keys_to_remove = {} + _layer_out = _mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=current_parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + descriptor=descriptor, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=current_keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=mixin_keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + parent_layer_updates.update(_layer_out) + merged_keys_to_remove.update(current_keys.removed_items()) + 
merged_keys_to_remove.update(mixin_keys_to_remove) + return layer_out_state_dict, merged_keys_to_remove # Legacy inline processing (fallback when no pruning_mixin) @@ -134,6 +193,7 @@ def _process_single_layer( layer_idx=layer_idx, new_state_dict=new_state_dict, new_config=new_config, + descriptor=descriptor, original_state_dict=parent_state_dict, original_config=original_config, q_key=q_key, @@ -152,6 +212,7 @@ def _process_single_layer( layer_idx=layer_idx, new_state_dict=new_state_dict, new_config=new_config, + descriptor=descriptor, original_state_dict=parent_state_dict, original_config=original_config, q_key=q_key, @@ -791,7 +852,10 @@ def update_model_config( def override(item, item_overrides): if item_overrides is None: - return item_overrides + # Hydra/OmegaConf ``null`` means "leave this field unchanged" in + # model_config_overrides. This lets compact overrides update only one + # sibling field without clearing the rest of the dataclass. + return item if dataclasses.is_dataclass(item): assert isinstance(item_overrides, dict) return dataclass_override(item, item_overrides) diff --git a/modelopt/torch/puzzletron/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py index c30be4efde8..91adff0f076 100644 --- a/modelopt/torch/puzzletron/tools/hydra_utils.py +++ b/modelopt/torch/puzzletron/tools/hydra_utils.py @@ -32,16 +32,64 @@ ] -def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: +def warmup_steps(tokens: int, block: int, mbs: int, grad_accum: int = 1, pct: float = 0.05) -> int: """ - Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. - Used as a resolver in hydra configs. + Calculate warmup steps in optimizer-step units. + + total_iters = tokens / (block * mbs) gives micro-batches; one optimizer step + consumes ``grad_accum`` micro-batches, so total optimizer steps = total_iters + / grad_accum. 
The LR scheduler in ``_get_lr`` is indexed by ``step_num`` + (optimizer steps), so warmup must be in the same units. """ - steps = (int(tokens) // int(block)) // int(mbs) + try: + tokens = int(tokens) + block = int(block) + mbs = int(mbs) + grad_accum = int(grad_accum) + except (TypeError, ValueError) as exc: + raise ValueError( + "tokens, block, mbs, and grad_accum must be integers or castable to int; " + f"got tokens={tokens!r}, block={block!r}, mbs={mbs!r}, grad_accum={grad_accum!r}" + ) from exc + + try: + pct = float(pct) + except (TypeError, ValueError) as exc: + raise ValueError(f"pct must be a float or castable to float, got {pct!r}") from exc + + if tokens < 0: + raise ValueError(f"tokens must be >= 0, got {tokens!r}") + if block <= 0: + raise ValueError(f"block must be > 0, got {block!r}") + if mbs <= 0: + raise ValueError(f"mbs must be > 0, got {mbs!r}") + if grad_accum < 1: + raise ValueError(f"grad_accum must be >= 1, got {grad_accum!r}") + if not 0.0 <= pct <= 1.0: + raise ValueError(f"pct must be between 0.0 and 1.0 inclusive, got {pct!r}") + + iters = (tokens // block) // mbs + steps = max(1, iters // grad_accum) w = pct * steps return max(1, round(w)) +def _warmup_steps_resolver(*args): + if len(args) == 3: + return warmup_steps(*args) + if len(args) == 4: + tokens, block, mbs, pct = args + return warmup_steps(tokens, block, mbs, pct=pct) + if len(args) == 5: + return warmup_steps(*args) + raise ValueError( + "warmup_steps resolver expects 3, 4, or 5 arguments: " + "(tokens, block, micro_batch_size), " + "(tokens, block, micro_batch_size, warmup_ratio), or " + "(tokens, block, micro_batch_size, grad_accumulation_steps, warmup_ratio)" + ) + + def register_hydra_resolvers(): OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) OmegaConf.register_new_resolver( @@ -50,7 +98,7 @@ def register_hydra_resolvers(): OmegaConf.register_new_resolver( "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None ) - 
OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("warmup_steps", _warmup_steps_resolver) OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index f4046531491..3d8b94c82cc 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -31,7 +31,7 @@ from ...tools.logger import mprint from .dataset import ConstantLengthDataset -__all__ = ["create_validation_dataloader", "create_padded_tensor"] +__all__ = ["create_train_dataloader", "create_validation_dataloader", "create_padded_tensor"] def collate_none_fn( @@ -73,6 +73,74 @@ def load_streaming_fn( return dataset +def create_train_dataloader( + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "train", + keep_in_memory: bool = False, + shuffle_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + num_workers: int = 0, +) -> DataLoader: + """Create an infinite training DataLoader over ConstantLengthDataset.""" + # ConstantLengthDataset.__iter__ does not consult torch.utils.data.get_worker_info() + # to shard work across DataLoader workers, so num_workers > 0 would have every + # worker iterate the full dataset and emit duplicate samples. Reject explicitly + # until ConstantLengthDataset gains worker-aware iteration; the guard can then + # be removed. + if num_workers > 0: + raise ValueError( + f"create_train_dataloader: num_workers={num_workers} is not supported " + f"because ConstantLengthDataset.__iter__ does not shard via " + f"torch.utils.data.get_worker_info(). 
Use num_workers=0 (the default) " + f"or add worker-aware sharding to ConstantLengthDataset.__iter__." + ) + + if isinstance(dataset_path, str): + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory) + else: + dataset = dataset_path + + train_data = dataset[dataset_name] + if shuffle_seed is not None: + # `keep_in_memory` is only valid on map-style HF Datasets; streaming + # `IterableDataset.shuffle()` only accepts `seed` (and an optional + # `buffer_size`). Branch on the dataset type so streaming users + # (`load_from_disk: false`) don't crash on this call. + if isinstance(train_data, datasets.IterableDataset): + train_data = train_data.shuffle(seed=shuffle_seed) + else: + train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=keep_in_memory) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + ) + + return DataLoader( + train_dataset, + batch_size=micro_batch_size, + pin_memory=True, + num_workers=num_workers, + ) + + def create_validation_dataloader( accelerator: Accelerator | None, seed: int, diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py index f88e44a234b..4b4a97c38ca 100644 --- a/modelopt/torch/puzzletron/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -14,6 +14,7 @@ # limitations under the License. 
# mypy: ignore-errors import functools +import warnings from collections.abc import Sequence import numpy as np @@ -33,6 +34,32 @@ FIM_TOKEN_CONNECTOR_SANTA = "-" # nosec B105 FIM_TOKEN_END_LIST = ["prefix>", "middle>", "suffix>", "pad>"] CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] +_CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED = False + + +def _message_content_to_text(content) -> str: + if isinstance(content, str): + return content + if isinstance(content, dict): + if "text" in content: + return str(content["text"]) + raise ValueError( + f"Unsupported structured message content without a text field: {content!r}" + ) + return str(content) + + +def _format_messages_without_chat_template(messages) -> str: + global _CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED + if not _CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED: + warnings.warn( + "Tokenizer has no chat_template; formatting messages as role-tagged plain text.", + stacklevel=2, + ) + _CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED = True + return "\n".join( + f"{message['role']}: {_message_content_to_text(message['content'])}" for message in messages + ) class ConstantLengthDataset(IterableDataset): @@ -122,15 +149,21 @@ def __iter__(self) -> dict[str, torch.Tensor]: continue if not self.is_dataset_already_tokenized: sample = sample[self.content_field] - if ( - isinstance(sample, list) - and isinstance(sample[0], dict) - and {"content", "role"}.issubset(sample[0]) - ): - if len(sample) > 1: - sample = self.tokenizer.apply_chat_template(sample, tokenize=False) - else: - sample = sample[0]["content"] + if isinstance(sample, list): + if len(sample) == 0: + sample = "" + elif isinstance(sample[0], dict) and {"content", "role"}.issubset( + sample[0] + ): + if len(sample) > 1: + if getattr(self.tokenizer, "chat_template", None) is not None: + sample = self.tokenizer.apply_chat_template( + sample, tokenize=False + ) + else: + sample = _format_messages_without_chat_template(sample) + else: + sample = _message_content_to_text(sample[0]["content"]) 
else: sample = sample[self.tokens_field] sample = sample[: self.max_sample_length] diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py index 149563b4321..acaac0a344e 100644 --- a/modelopt/torch/puzzletron/utils/parsing.py +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -24,6 +24,7 @@ # mypy: ignore-errors import json +import math from pathlib import Path from typing import Any @@ -116,7 +117,7 @@ def format_block_configs(config) -> str: ╭─────────────────────── Model Architecture ────────────────────────╮ │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ - │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + │ Layer 3 │ Attention: no_op │ FFN: no_op │ ╰────────────────────────────────────────────────────────────────────╯ """ if not hasattr(config, "block_configs") or not config.block_configs: @@ -158,7 +159,7 @@ def _format_attention_config(attention_config) -> str: num_kv_heads = attention_config.num_key_value_heads if num_kv_heads is not None: - return f"{num_kv_heads} kv heads" + return f"🐙 {num_kv_heads} kv heads" if attention_config.replace_with_linear: return "linear replacement" @@ -192,12 +193,12 @@ def _format_ffn_config(ffn_config) -> str: ffn_intermediate = ffn_config.intermediate_size if ffn_intermediate is not None: - return f"ffn_intermediate = {ffn_intermediate}" + return f"🧱 ffn_dim = {ffn_intermediate}" # Check for MoE configuration moe_config = ffn_config.moe if moe_config: - return "MoE" + return "🔀 MoE" if ffn_config.sparsify: return "sparse" @@ -287,7 +288,7 @@ def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0 # Regular key-value pair indent_str = " " * (indent + 1) value_str = _format_value(value).replace(" " * 0, "").strip() - line = f"│ {indent_str} {key}: {value_str}" + line = f"│ {indent_str} • {key}: {value_str}" # Pad to box width if len(line) >= box_width - 1: # Truncate long lines @@ 
-310,6 +311,8 @@ def format_stitched_losses( losses_dict: dict[str, float], best_steps_dict: dict[str, int] | None = None, best_values_dict: dict[str, float] | None = None, + initial_values_dict: dict[str, float] | None = None, + not_trainable_names: set[str] | None = None, step_number: int | None = None, title: str = "Stitched Module Losses", ) -> str: @@ -320,6 +323,9 @@ def format_stitched_losses( losses_dict: Dictionary with block names as keys and current loss values as floats best_steps_dict: Optional dictionary with block names as keys and best step numbers as values best_values_dict: Optional dictionary with block names as keys and best loss values as floats + initial_values_dict: Optional dictionary with block names as keys and initial loss values + (from the first log chunk) as floats. Used to render the "Δ from initial" column as + a per-block training-progress signal. step_number: Optional current step number to include in summary title: Title to display at the top of the formatted output @@ -328,23 +334,39 @@ def format_stitched_losses( Example output: ╭─────────────────── Stitched Module Losses ──────────────────╮ - │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ - │───────┼────────────┼───────────┼────────────┼──────────────────│ - │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ - │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ - │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + │ Block │ Loss Value │ Δ from initial │ Best Value │ Best Step │ + │───────┼────────────┼──────────────────┼────────────┼───────────│ + │ 00 │ 6.21e-03 │ ↓ -3.2e-04 (-5%) │ 5.95e-03 │ Step 5 │ + │ 01 │ 5.14e-04 │ ↓ -1.8e-03 (-78%)│ 5.14e-04 │ Step 12 │ + │ 02 │ 9.84e-05 │ ↓ -4.1e-04 (-81%)│ 9.84e-05 │ Step 15 │ ╰──────────────────────────────────────────────────────────────╯ """ if not losses_dict: + if not_trainable_names: + return ( + "No trainable losses found; " + f"skipped {len(not_trainable_names)} non-trainable blocks" + ) return "❌ No 
losses found" + if best_steps_dict: + best_steps_dict = {k: v for k, v in best_steps_dict.items() if k in losses_dict} + if best_values_dict: + best_values_dict = {k: v for k, v in best_values_dict.items() if k in losses_dict} + if initial_values_dict: + initial_values_dict = {k: v for k, v in initial_values_dict.items() if k in losses_dict} + lines = [] # Calculate statistics loss_values = list(losses_dict.values()) - max_loss = max(loss_values) - min_loss = min(loss_values) - avg_loss = sum(loss_values) / len(loss_values) + finite_loss_values = [value for value in loss_values if math.isfinite(value)] + if finite_loss_values: + max_loss = max(finite_loss_values) + min_loss = min(finite_loss_values) + avg_loss = sum(finite_loss_values) / len(finite_loss_values) + else: + max_loss = min_loss = avg_loss = float("nan") # Calculate box width for new layout (removed Bar column) box_width = 74 @@ -356,10 +378,10 @@ def format_stitched_losses( f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" ) separator = ( - f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " - f"{'Best Value':<12} │ {'Change from avg':<18} │" + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Δ from initial':<18} │ " + f"{'Best Value':<12} │ {'Best Step':<10} │" ) - divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 20}┼{'─' * 14}┼{'─' * 12}│" lines.extend([header, title_line, separator, divider]) @@ -382,26 +404,35 @@ def format_stitched_losses( best_value = loss_value # Assume current is best if no history best_value_str = f"{best_value:.2e}" - # Calculate change from average - change_from_avg = loss_value - avg_loss - if abs(change_from_avg) > 1e-8: # Only show if meaningful - change_str = f"{abs(change_from_avg):.1e}" - if change_from_avg > 0: - # Current is above average (worse for loss) - change_display = f"↑ +{change_str}" - else: - # Current is below average (better for loss) - change_display = 
f"↓ -{change_str}" + # Calculate change from initial: current loss minus the block's loss in the + # first log chunk we saw. Per-block training-progress signal — answers "is + # bypass distillation actually reducing this block's loss?" and stays + # apples-to-apples even when blocks have very different intrinsic loss scales. + if not initial_values_dict or block_name not in initial_values_dict: + # No baseline supplied (callers may omit initial_values_dict). + change_display = " --" + elif not math.isfinite(loss_value) or not math.isfinite(initial_values_dict[block_name]): + change_display = "non-finite" else: - # At average value - change_display = "↔ 0.0e+00" + initial_value = initial_values_dict[block_name] + delta = loss_value - initial_value + if abs(delta) > 1e-8: + pct = (delta / initial_value * 100.0) if initial_value != 0.0 else 0.0 + # Clamp percentage display to keep the cell within the 18-char column + # even on pathological divergence (e.g. a block whose loss 10x'd). + pct_clamped = max(-999.0, min(999.0, pct)) + arrow = "↓" if delta < 0 else "↑" + sign = "-" if delta < 0 else "+" + change_display = f"{arrow} {sign}{abs(delta):.1e} ({pct_clamped:+.0f}%)" + else: + change_display = "↔ 0.0e+00" # Format the line block_display = block_name.replace("block_", "").zfill(2) line = ( - f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ " - f"{best_value_str:<12} │ {change_display:<18} │" + f"│ {block_display:<5} │ {loss_str:<12} │ {change_display:<18} │ " + f"{best_value_str:<12} │ {best_step_str:<10} │" ) lines.append(line) @@ -413,6 +444,8 @@ def format_stitched_losses( if step_number is not None: summary_parts.append(f"Step {step_number}") summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"]) + if not_trainable_names: + summary_parts.append(f"Skipped={len(not_trainable_names)}") summary_text = ", ".join(summary_parts) summary = f"│ Summary: {summary_text}" @@ -436,7 +469,9 @@ def format_stitched_losses( 
best_step_values = [] for block_name, best_step in best_steps_dict.items(): if best_step == modal_best_step and block_name in best_values_dict: - best_step_values.append(best_values_dict[block_name]) + best_value = best_values_dict[block_name] + if math.isfinite(best_value): + best_step_values.append(best_value) if best_step_values: best_step_avg = sum(best_step_values) / len(best_step_values) diff --git a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py new file mode 100644 index 00000000000..17057019154 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for bypass-distillation dataloader behavior added by this PR.""" + +from types import SimpleNamespace + +import pytest +import torch + +import modelopt.torch.puzzletron.utils.data.dataloaders as dl +import modelopt.torch.puzzletron.utils.data.dataset as dataset_module +from modelopt.torch.puzzletron.utils.data.dataloaders import create_train_dataloader +from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset + + +def test_create_train_dataloader_rejects_num_workers_gt_zero(): + """ConstantLengthDataset doesn't shard work via ``get_worker_info`` — every + worker would emit the same samples. 
The guard fires before tokenizer or + dataset are touched, so bare-bones args are enough.""" + with pytest.raises(ValueError, match="num_workers"): + create_train_dataloader( + seed=0, + tokenizer=None, + block_size=8, + dataset_path={"train": []}, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + micro_batch_size=1, + num_workers=2, + ) + + +class _FakeTrainConstantLengthDataset: + last_args = None + last_kwargs = None + + def __init__(self, *args, **kwargs): + type(self).last_args = args + type(self).last_kwargs = kwargs + + +class _FakeTrainSplit: + def __init__(self): + self.shuffle_calls = [] + + def shuffle(self, **kwargs): + self.shuffle_calls.append(kwargs) + return self + + +@pytest.fixture +def patched_train_dataloader(monkeypatch): + captured = {} + + def fake_dataloader(dataset, batch_size, pin_memory, num_workers): + captured["dataset"] = dataset + captured["batch_size"] = batch_size + captured["pin_memory"] = pin_memory + captured["num_workers"] = num_workers + return SimpleNamespace(dataset=dataset) + + _FakeTrainConstantLengthDataset.last_args = None + _FakeTrainConstantLengthDataset.last_kwargs = None + monkeypatch.setattr(dl, "ConstantLengthDataset", _FakeTrainConstantLengthDataset) + monkeypatch.setattr(dl, "DataLoader", fake_dataloader) + return captured + + +def test_create_train_dataloader_builds_constant_length_dataset_from_loaded_split( + patched_train_dataloader, +): + train_split = _FakeTrainSplit() + load_calls = [] + + def fake_load_dataset(dataset_path, content_field, keep_in_memory): + load_calls.append((dataset_path, content_field, keep_in_memory)) + return {"custom_train": train_split} + + tokenizer = object() + out = create_train_dataloader( + seed=7, + tokenizer=tokenizer, + block_size=16, + dataset_path="/tmp/train", + content_field="conversation", + fim_rate=0.25, + fim_spm_rate=0.75, + micro_batch_size=3, + load_dataset_fn=fake_load_dataset, + dataset_name="custom_train", + keep_in_memory=True, + shuffle_seed=123, + 
source_datasets_to_discard=("bad-source",), + bos_rate=0.5, + ) + + assert out.dataset is patched_train_dataloader["dataset"] + assert load_calls == [("/tmp/train", "conversation", True)] + assert train_split.shuffle_calls == [{"seed": 123, "keep_in_memory": True}] + assert _FakeTrainConstantLengthDataset.last_args == (tokenizer, train_split) + assert _FakeTrainConstantLengthDataset.last_kwargs == { + "infinite": True, + "seq_length": 16, + "content_field": "conversation", + "fim_rate": 0.25, + "fim_spm_rate": 0.75, + "seed": 7, + "source_datasets_to_discard": ("bad-source",), + "bos_rate": 0.5, + } + assert isinstance(patched_train_dataloader["dataset"], _FakeTrainConstantLengthDataset) + assert patched_train_dataloader["batch_size"] == 3 + assert patched_train_dataloader["pin_memory"] is True + assert patched_train_dataloader["num_workers"] == 0 + + +def test_create_train_dataloader_streaming_shuffle_omits_keep_in_memory( + monkeypatch, + patched_train_dataloader, +): + class FakeStreamingDataset: + def __init__(self): + self.shuffle_seed = None + + def shuffle(self, seed): + self.shuffle_seed = seed + return self + + monkeypatch.setattr(dl.datasets, "IterableDataset", FakeStreamingDataset) + train_split = FakeStreamingDataset() + + create_train_dataloader( + seed=0, + tokenizer=object(), + block_size=8, + dataset_path={"train": train_split}, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + micro_batch_size=1, + load_dataset_fn=lambda *args, **kwargs: pytest.fail("dataset mapping should not load"), + shuffle_seed=99, + keep_in_memory=True, + ) + + assert train_split.shuffle_seed == 99 + assert _FakeTrainConstantLengthDataset.last_args[1] is train_split + assert isinstance(patched_train_dataloader["dataset"], _FakeTrainConstantLengthDataset) + + +class _NoChatTemplateTokenizer: + eos_token_id = 1 + bos_token_id = None + + def __init__(self): + self.seen_texts = None + self.vocab = {} # Required by ConstantLengthDataset.get_fim_token_ids. 
+ + def __call__(self, texts, truncation=False): + self.seen_texts = texts + return {"input_ids": [[0] for _ in texts]} + + +class _ChatTemplateTokenizer(_NoChatTemplateTokenizer): + chat_template = "template" + + def __init__(self): + super().__init__() + self.template_messages = None + + def apply_chat_template(self, messages, tokenize=False): + self.template_messages = messages + return "templated chat" + + +class _ConversationDataset: + column_names = ("text",) + + def __iter__(self): + yield { + "text": [ + {"role": "user", "content": {"text": "hello"}}, + {"role": "assistant", "content": "world"}, + ] + } + + +class _EmptyConversationDataset: + column_names = ("text",) + + def __iter__(self): + yield {"text": []} + + +def test_constant_length_dataset_no_chat_template_adds_role_tags_and_warns_once(monkeypatch): + monkeypatch.setattr(dataset_module, "_CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED", False) + tokenizer = _NoChatTemplateTokenizer() + dataset = ConstantLengthDataset( + tokenizer, + _ConversationDataset(), + infinite=False, + seq_length=2, + num_of_sequences=1, + chars_per_token=100, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + label_shift=False, + ) + + with pytest.warns(UserWarning, match="no chat_template"): + realized = list(dataset) + + assert tokenizer.seen_texts == ["user: hello\nassistant: world"] + assert len(realized) == 1 + assert torch.equal(realized[0]["input_ids"], torch.tensor([0, 1])) + assert torch.equal(realized[0]["targets"], torch.tensor([0, 1])) + + +def test_constant_length_dataset_uses_tokenizer_chat_template_when_available(monkeypatch): + monkeypatch.setattr(dataset_module, "_CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED", False) + tokenizer = _ChatTemplateTokenizer() + dataset = ConstantLengthDataset( + tokenizer, + _ConversationDataset(), + infinite=False, + seq_length=2, + num_of_sequences=1, + chars_per_token=100, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + label_shift=False, + ) + + realized = 
list(dataset) + + assert tokenizer.template_messages == [ + {"role": "user", "content": {"text": "hello"}}, + {"role": "assistant", "content": "world"}, + ] + assert tokenizer.seen_texts == ["templated chat"] + assert len(realized) == 1 + + +def test_constant_length_dataset_handles_empty_message_list(): + tokenizer = _NoChatTemplateTokenizer() + dataset = ConstantLengthDataset( + tokenizer, + _EmptyConversationDataset(), + infinite=False, + seq_length=2, + num_of_sequences=1, + chars_per_token=100, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + label_shift=False, + ) + + realized = list(dataset) + + assert tokenizer.seen_texts == [""] + assert len(realized) == 1 diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py new file mode 100644 index 00000000000..85c9490abf2 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for bypass-distillation loss and loss-log formatting behavior.""" + +import pytest +import torch + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.utils.parsing import format_stitched_losses + + +def test_vectorwise_normalized_mse_loss_matches_batched_last_dim(): + input_ = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + target = input_ + 1.0 + + vectorwise = vectorwise_normalized_mse_loss(input_, target) + batched = batched_normalized_mse_loss(input_, target, batch_dims=(0, 1)) + + torch.testing.assert_close(vectorwise, batched) + + +def test_batched_normalized_mse_loss_matches_manual_relative_l2(): + input_ = torch.tensor([[[1.0, 2.0], [3.0, 5.0]], [[2.0, 4.0], [6.0, 8.0]]]) + target = torch.tensor([[[1.0, 1.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]]]) + + loss = batched_normalized_mse_loss(input_, target, epsilon=1e-6, batch_dims=(0, 1)) + expected = (((input_ - target) ** 2).sum(dim=2) / ((target**2).sum(dim=2) + 1e-6)).mean() + + torch.testing.assert_close(loss, expected) + + +def test_batched_normalized_mse_loss_zero_target_is_finite(): + """All-zero target slice must not produce NaN/Inf. + + With the relative-L2 formula ``sum((x-t)^2) / (sum(t^2) + eps)``, an all-zero + target reduces the denominator to exactly ``eps`` — finite, no division by + zero — so the loss equals ``||input||^2 / eps``. The numeric value is large + by construction (that's what zero-magnitude targets mean), but the test + pins the property we actually care about: finiteness, not magnitude. 
+ """ + input_ = torch.full((1, 8), 1.0) + target = torch.zeros(1, 8) + loss = batched_normalized_mse_loss(input_, target) + assert torch.isfinite(loss) + assert not torch.isnan(loss) + + +def test_batched_normalized_mse_loss_zero_input_and_target(): + """Both zero should give exactly 0.0 — numerator is zero, denominator is eps.""" + input_ = torch.zeros(2, 4) + target = torch.zeros(2, 4) + loss = batched_normalized_mse_loss(input_, target) + assert loss.item() == 0.0 + + +def test_batched_normalized_mse_loss_rejects_shape_mismatch(): + input_ = torch.randn(2, 3) + target = torch.randn(2, 1) + + with pytest.raises(ValueError, match="input and target shapes must match exactly"): + batched_normalized_mse_loss(input_, target) + + +def test_batched_normalized_mse_loss_rejects_invalid_batch_dim(): + input_ = torch.randn(2, 3) + target = torch.randn(2, 3) + + with pytest.raises(ValueError, match="batch_dims contains invalid dimension"): + batched_normalized_mse_loss(input_, target, batch_dims=(2,)) + + +def test_batched_normalized_mse_loss_rejects_invalid_options(): + input_ = torch.randn(2, 3) + target = torch.randn(2, 3) + + with pytest.raises(ValueError, match="epsilon must be strictly positive"): + batched_normalized_mse_loss(input_, target, epsilon=0.0) + + +def test_format_stitched_losses_keeps_trainable_nan_visible(): + out = format_stitched_losses( + {"block_0": float("nan"), "block_1": 1.0}, + initial_values_dict={"block_0": 0.5, "block_1": 2.0}, + not_trainable_names={"block_2"}, + step_number=3, + ) + + assert "nan" in out + assert "non-finite" in out + assert "Skipped=1" in out + assert "No trainable blocks found" not in out + + +def test_format_stitched_losses_empty_trainable_reports_skipped_blocks(): + out = format_stitched_losses({}, not_trainable_names={"block_0", "block_1"}) + + assert out == "No trainable losses found; skipped 2 non-trainable blocks" + + +def test_format_stitched_losses_reports_delta_from_initial_and_filters_stale_history(): + out = 
format_stitched_losses( + {"block_0": 1.0, "block_1": 3.0}, + best_steps_dict={"block_0": 5, "block_9": 99}, + best_values_dict={"block_0": 0.5, "block_9": 9.0}, + initial_values_dict={"block_0": 2.0, "block_1": 3.0, "block_9": 9.0}, + not_trainable_names={"block_2"}, + step_number=8, + ) + + assert "↓ -1.0e+00 (-50%)" in out + assert "↔ 0.0e+00" in out + assert "Step 5" in out + assert "Step 99" not in out + assert "Skipped=1" in out + assert "Avg=2.00e+00" in out diff --git a/tests/unit/torch/puzzletron/test_child_init_mixins.py b/tests/unit/torch/puzzletron/test_child_init_mixins.py new file mode 100644 index 00000000000..59ab8950f30 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_child_init_mixins.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from types import SimpleNamespace + +import torch + +from modelopt.torch.puzzletron.block_config import AttentionConfig, BlockConfig, FFNConfig +from modelopt.torch.puzzletron.pruning.pruning_mixin import LayerDescriptor, PruningMixIn +from modelopt.torch.puzzletron.pruning.pruning_utils import resolve_pruning_mixin +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + _process_single_layer, + update_model_config, +) + + +class _AddOneMixin: + def prune_single_layer(self, parent_state_dict, keys_to_remove, **kwargs): + keys_to_remove["w"] = "w" + return {"w": parent_state_dict["w"] + 1} + + +class _TimesTwoMixin: + def prune_single_layer(self, parent_state_dict, keys_to_remove, **kwargs): + keys_to_remove["w"] = "w" + return {"w": parent_state_dict["w"] * 2} + + +class _ConcretePruningMixIn(PruningMixIn): + def supported_hooks(self): + return [] + + +_MAPPED_MIXIN = _ConcretePruningMixIn(LayerDescriptor()) + + +class _DescriptorWithPruningMixins: + @staticmethod + def pruning_mixins(): + return {"mapped": _MAPPED_MIXIN} + + +def _process_with_mixins( + mixins, + keys, + parent_state_dict=None, + new_state_dict=None, +): + return _process_single_layer( + layer_idx=0, + pruning_mixin=mixins, + descriptor=None, + parent_state_dict=parent_state_dict or {"w": torch.tensor([1.0])}, + new_state_dict=new_state_dict or {"w": torch.tensor([0.0])}, + original_config=SimpleNamespace(), + new_config=SimpleNamespace(), + gqa_init_mode=None, + mlp_init_mode=None, + mlp_init_config=None, + linear_init_mode=None, + ignored_keys=set(), + keys=keys, + is_original_mha=False, + head_size=1, + hidden_size=1, + ) + + +def test_pruning_mixins_compose_overlapping_outputs_sequentially(): + layer_state_dict, keys_to_remove = _process_with_mixins( + [_AddOneMixin(), _TimesTwoMixin()], {"w": "w"} + ) + + assert torch.equal(layer_state_dict["w"], torch.tensor([4.0])) + assert keys_to_remove == {"w": "w"} + + +def 
test_resolve_pruning_mixin_accepts_names_instances_and_lists(): + existing = _ConcretePruningMixIn(LayerDescriptor()) + + assert resolve_pruning_mixin("mapped", _DescriptorWithPruningMixins) is _MAPPED_MIXIN + assert resolve_pruning_mixin(existing, _DescriptorWithPruningMixins) is existing + assert resolve_pruning_mixin(["mapped", existing], _DescriptorWithPruningMixins) == [ + _MAPPED_MIXIN, + existing, + ] + + +def test_update_model_config_treats_null_overrides_as_leave_unchanged(): + config = SimpleNamespace( + num_hidden_layers=1, + block_configs=[ + BlockConfig( + attention=AttentionConfig(num_key_value_heads=8), + ffn=FFNConfig(intermediate_size=32), + ) + ], + ) + + updated = update_model_config( + config, + [ + { + "attention": {"num_key_value_heads": 4}, + "ffn": None, + } + ], + ) + + assert updated is not config + assert updated.block_configs[0].attention.num_key_value_heads == 4 + assert updated.block_configs[0].ffn == config.block_configs[0].ffn + assert config.block_configs[0].attention.num_key_value_heads == 8 diff --git a/tests/unit/torch/puzzletron/test_hydra_utils.py b/tests/unit/torch/puzzletron/test_hydra_utils.py new file mode 100644 index 00000000000..c33f8e6ac8b --- /dev/null +++ b/tests/unit/torch/puzzletron/test_hydra_utils.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from modelopt.torch.puzzletron.tools.hydra_utils import _warmup_steps_resolver, warmup_steps + + +def test_warmup_steps_casts_inputs_before_computing(): + assert warmup_steps("100", "10", "2", "5", "0.5") == 1 + + +def test_warmup_steps_preserves_legacy_defaults(): + assert warmup_steps("1000", "10", "2") == 2 + assert _warmup_steps_resolver("1000", "10", "2") == 2 + assert _warmup_steps_resolver("1000", "10", "2", "0.5") == 25 + assert _warmup_steps_resolver("1000", "10", "2", "5", "0.5") == 5 + + +def test_warmup_steps_resolver_rejects_unknown_arity(): + with pytest.raises(ValueError, match="expects 3, 4, or 5 arguments"): + _warmup_steps_resolver("1000", "10") + + +def test_warmup_steps_rejects_non_castable_inputs(): + with pytest.raises(ValueError, match="castable to int"): + warmup_steps("not-int", "10", "2") + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ({"tokens": -1, "block": 1, "mbs": 1, "grad_accum": 1, "pct": 0.1}, "tokens"), + ({"tokens": 1, "block": 0, "mbs": 1, "grad_accum": 1, "pct": 0.1}, "block"), + ({"tokens": 1, "block": 1, "mbs": 0, "grad_accum": 1, "pct": 0.1}, "mbs"), + ({"tokens": 1, "block": 1, "mbs": 1, "grad_accum": 0, "pct": 0.1}, "grad_accum"), + ({"tokens": 1, "block": 1, "mbs": 1, "grad_accum": 1, "pct": 1.1}, "pct"), + ], +) +def test_warmup_steps_rejects_invalid_inputs(kwargs, message): + with pytest.raises(ValueError, match=message): + warmup_steps(**kwargs) diff --git a/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py b/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py new file mode 100644 index 00000000000..ce721a4b65d --- /dev/null +++ b/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from types import SimpleNamespace
+
+import torch
+
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import (
+    KVHeadsLayerDescriptor,
+    KVHeadsPruningMixIn,
+)
+from modelopt.torch.puzzletron.pruning.pruning_utils import GQAInitMode, LinearInitMode, MlpInitMode
+from modelopt.torch.puzzletron.tools.bypassed_training.child_init import _process_single_layer
+
+ATTN_PREFIX = "model.layers.0.self_attn"
+QKVO_NAMES = ["q_proj", "k_proj", "v_proj", "o_proj"]
+
+
+class DecoderConfigDescriptor:
+    """Descriptor stub exposing ``config.decoder_config`` as the language-model config."""
+    @staticmethod
+    def get_language_model_config(config):
+        return config.decoder_config
+
+
+def _make_config():
+    """Build a single-block config whose attention sizes live under ``decoder_config``."""
+    decoder = SimpleNamespace(head_dim=2, hidden_size=4, num_attention_heads=2)
+    block = SimpleNamespace(
+        attention=SimpleNamespace(num_key_value_heads=2), ffn=SimpleNamespace(is_moe=False)
+    )
+    return SimpleNamespace(
+        decoder_config=decoder, block_configs=[block], attention_bias=True, o_proj_bias=True
+    )
+
+
+def _make_attention_state_dict(fill_value: float):
+    """Return q/k/v/o weights and biases filled with ``fill_value`` plus the projection index."""
+    tensors = {}
+    for offset, proj in enumerate(QKVO_NAMES):
+        value = fill_value + offset
+        tensors[f"{ATTN_PREFIX}.{proj}.weight"] = torch.full((4, 4), value)
+        tensors[f"{ATTN_PREFIX}.{proj}.bias"] = torch.full((4,), value)
+    return tensors
+
+
+def _assert_attention_state_dict_matches(actual, expected):
+    """Assert both state dicts share the same keys and element-wise close tensors."""
+    assert set(actual) == set(expected)
+    for name, wanted in expected.items():
+        torch.testing.assert_close(actual[name], wanted)
+
+
+def test_kv_heads_pruning_mixin_uses_descriptor_selected_config_for_attention_init():
+    """KV-heads mixin with ``CopyAsIs`` must hand back the parent attention tensors."""
+    parent_config, child_config = _make_config(), _make_config()
+    parent_weights = _make_attention_state_dict(fill_value=1.0)
+    child_weights = _make_attention_state_dict(fill_value=10.0)
+    removed_keys = {}
+
+    # Templated form of ATTN_PREFIX; layer_idx=0 below is expected to resolve to it.
+    layer_descriptor = KVHeadsLayerDescriptor(
+        o_proj_name="o_proj",
+        attn_prefix_name="model.layers.{layer_idx}.self_attn",
+        qkvo_weight_names=QKVO_NAMES,
+    )
+    mixin = KVHeadsPruningMixIn(layer_descriptor)
+
+    pruned_layer = mixin.prune_single_layer(
+        layer_idx=0,
+        parent_state_dict=parent_weights,
+        new_state_dict=child_weights,
+        original_config=parent_config,
+        new_config=child_config,
+        descriptor=DecoderConfigDescriptor,
+        gqa_init_mode=GQAInitMode.CopyAsIs,
+        mlp_init_config=None,
+        is_original_mha=True,
+        keys={name: name for name in parent_weights},
+        keys_to_remove=removed_keys,
+    )
+
+    # Every parent key is expected to come back through ``keys_to_remove``.
+    _assert_attention_state_dict_matches(pruned_layer, parent_weights)
+    assert removed_keys == {name: name for name in parent_weights}
+
+
+def test_legacy_process_single_layer_uses_descriptor_selected_config_for_attention_init():
+    """Mixin-less ``_process_single_layer`` must hand back the parent attention tensors."""
+    parent_config, child_config = _make_config(), _make_config()
+    parent_weights = _make_attention_state_dict(fill_value=1.0)
+    child_weights = _make_attention_state_dict(fill_value=10.0)
+
+    pruned_layer, removed_keys = _process_single_layer(
+        layer_idx=0,
+        pruning_mixin=None,
+        descriptor=DecoderConfigDescriptor,
+        parent_state_dict=parent_weights,
+        new_state_dict=child_weights,
+        original_config=parent_config,
+        new_config=child_config,
+        gqa_init_mode=GQAInitMode.CopyAsIs,
+        mlp_init_mode=MlpInitMode.CopyAsIs,
+        mlp_init_config=None,
+        linear_init_mode=LinearInitMode.Random,
+        ignored_keys=set(),
+        keys={name: name for name in parent_weights},
+        is_original_mha=True,
+        head_size=2,
+        hidden_size=4,
+    )
+
+    _assert_attention_state_dict_matches(pruned_layer, parent_weights)
+    assert removed_keys == {name: name for name in parent_weights}
diff --git a/tests/unit/torch/puzzletron/test_pruning_descriptor_mixins.py b/tests/unit/torch/puzzletron/test_pruning_descriptor_mixins.py
new file mode 100644
index 00000000000..f17f1b8acd9
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_pruning_descriptor_mixins.py
@@ -0,0 +1,91 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for model-descriptor pruning mixin registries.
+
+Bypass child initialization resolves pruning behavior by descriptor key. These
+tests pin the public keys and layer prefixes that external configs depend on,
+without instantiating full transformer models.
+"""
+
+from importlib import import_module
+
+import pytest
+
+from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ExpertRemovalPruningMixIn
+from modelopt.torch.puzzletron.pruning.ffn_intermediate_pruning_mixin import (
+    FFNIntermediatePruningMixIn,
+)
+from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsPruningMixIn
+
+
+def test_gpt_oss_descriptor_exposes_canonical_alias_and_kv_heads_mixins():
+    """GPT-OSS exposes one expert-removal mixin under both key spellings, plus ``kv_heads``."""
+    pytest.importorskip("transformers.models.gpt_oss.modeling_gpt_oss")
+
+    from modelopt.torch.puzzletron.anymodel.models.gpt_oss.gpt_oss_model_descriptor import (
+        GptOssModelDescriptor,
+    )
+
+    registry = GptOssModelDescriptor.pruning_mixins()
+    assert set(registry) == {"experts_removal", "expert_removal", "kv_heads"}
+    assert registry["experts_removal"] is registry["expert_removal"]
+    assert isinstance(registry["experts_removal"], ExpertRemovalPruningMixIn)
+    assert isinstance(registry["kv_heads"], KVHeadsPruningMixIn)
+    assert registry["kv_heads"].layer_descriptor.attn_prefix(3) == "model.layers.3.self_attn"
+
+
+def test_nemotron_h_descriptor_exposes_expert_removal_and_kv_heads_mixins():
+    """Nemotron-H exposes ``experts_removal`` and ``kv_heads`` with a ``mixer`` prefix."""
+    from modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor import (
+        NemotronHModelDescriptor,
+    )
+
+    registry = NemotronHModelDescriptor.pruning_mixins()
+    assert set(registry) == {"experts_removal", "kv_heads"}
+    assert isinstance(registry["experts_removal"], ExpertRemovalPruningMixIn)
+    assert isinstance(registry["kv_heads"], KVHeadsPruningMixIn)
+    assert registry["kv_heads"].layer_descriptor.attn_prefix(2) == "backbone.layers.2.mixer"
+
+
+def test_nemotron_h_v2_descriptor_exposes_ffn_and_kv_heads_mixins():
+    """Nemotron-H v2 exposes ``ffn_intermediate`` and ``kv_heads`` with a ``mixer`` prefix."""
+    descriptor = import_module(
+        "modelopt.torch.puzzletron.anymodel.models.nemotron_h_v2.nemotron_h_v2_model_descriptor"
+    ).NemotronHV2ModelDescriptor
+
+    registry = descriptor.pruning_mixins()
+    assert set(registry) == {"ffn_intermediate", "kv_heads"}
+    assert isinstance(registry["ffn_intermediate"], FFNIntermediatePruningMixIn)
+    assert isinstance(registry["kv_heads"], KVHeadsPruningMixIn)
+    assert registry["kv_heads"].layer_descriptor.attn_prefix(2) == "backbone.layers.2.mixer"
+
+
+def test_qwen3_vl_descriptor_exposes_expert_removal_and_kv_heads_mixins():
+    """Qwen3-VL exposes ``experts_removal`` and ``kv_heads`` under the language-model prefix."""
+    pytest.importorskip("transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe")
+
+    from modelopt.torch.puzzletron.anymodel.models.qwen3_vl.qwen3_vl_model_descriptor import (
+        Qwen3VLModelDescriptor,
+    )
+
+    registry = Qwen3VLModelDescriptor.pruning_mixins()
+
+    assert set(registry) == {"experts_removal", "kv_heads"}
+    assert isinstance(registry["experts_removal"], ExpertRemovalPruningMixIn)
+    assert isinstance(registry["kv_heads"], KVHeadsPruningMixIn)
+
+    kv_prefix = registry["kv_heads"].layer_descriptor.attn_prefix(2)
+    assert kv_prefix == "model.language_model.layers.2.self_attn"