Skip to content
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e82164f
Add anymodel directories to feature/puzzletron
danielkorzekwa Mar 4, 2026
2099df3
Make any_model conversion working.
danielkorzekwa Mar 5, 2026
eb5cf8a
Update child_init.py with anymodel version
danielkorzekwa Mar 5, 2026
c9de41c
fix attention pruning
danielkorzekwa Mar 5, 2026
3c1bc1f
Add trust_remote_code to load_model_config (default to false)
danielkorzekwa Mar 5, 2026
8357136
Make activation scoring working
danielkorzekwa Mar 5, 2026
6cc2194
Comment all tested models aside of llama_3_1_8b_instruct
danielkorzekwa Mar 5, 2026
ee4e1e3
Delete not needed decilm test
danielkorzekwa Mar 5, 2026
449b523
Fix broken tests
danielkorzekwa Mar 5, 2026
fb27bba
Update puzzletron_nas_pluging to any_model version
danielkorzekwa Mar 5, 2026
b350f82
Correct test resources used by tests.
danielkorzekwa Mar 5, 2026
fafe5a3
Disable puzzletron tests (will be enabled after all any_model logic i…
danielkorzekwa Mar 5, 2026
e988248
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
c717852
Comment out not implemented models.
danielkorzekwa Mar 6, 2026
030f126
format python docs
danielkorzekwa Mar 6, 2026
8dcdfbf
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
70df0df
Use trust_remote_code in force_cache_dynamic_modules()
danielkorzekwa Mar 6, 2026
bb56662
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
ecd953e
Fix anymodel pruning
danielkorzekwa Mar 6, 2026
ee8f538
Fix buid docs issue.
danielkorzekwa Mar 6, 2026
c9b76a1
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 6, 2026
6e3af61
Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
danielkorzekwa Mar 6, 2026
0ad6d92
Merging build_library_and_stats
danielkorzekwa Mar 6, 2026
47414d5
Clarify readme and avoid reusing the same reference in llama_converter.
danielkorzekwa Mar 9, 2026
a8305d8
Fix tied-embedding handling before writing the safetensors index.
danielkorzekwa Mar 9, 2026
68421a5
Fix NaN ranking currently selects NaNs as “best” experts by default.
danielkorzekwa Mar 9, 2026
d6b8028
Code clean up.
danielkorzekwa Mar 9, 2026
ecd2341
Code clean up.
danielkorzekwa Mar 10, 2026
f9d845d
code clean up
danielkorzekwa Mar 10, 2026
d171b01
Merge branch 'dkorzekwa/anymodel_core' into dkorzekwa/anymodel_activa…
danielkorzekwa Mar 10, 2026
722da90
Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
danielkorzekwa Mar 10, 2026
934ab2f
code clean up
danielkorzekwa Mar 10, 2026
0f14ec3
Merge branch 'dkorzekwa/anymodel_pruning' into dkorzekwa/anymodel_bui…
danielkorzekwa Mar 10, 2026
dcb9e02
remove not needed comment
danielkorzekwa Mar 10, 2026
176a435
Fix a broken test_puzzletron test on 2 gpus.
danielkorzekwa Mar 10, 2026
02e2c9b
Merge branch 'dkorzekwa/anymodel_activation_scoring' into dkorzekwa/a…
danielkorzekwa Mar 10, 2026
92c4419
Merge branch 'dkorzekwa/anymodel_pruning' into dkorzekwa/anymodel_bui…
danielkorzekwa Mar 10, 2026
361014f
Merge branch 'feature/puzzletron' into dkorzekwa/anymodel_build_libra…
danielkorzekwa Mar 12, 2026
6b74412
Uncomment build_library_and_stats step
danielkorzekwa Mar 12, 2026
d4b8da6
Fixing build docs ( tox -e build-docs)
danielkorzekwa Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
380 changes: 379 additions & 1 deletion modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py

Large diffs are not rendered by default.

121 changes: 47 additions & 74 deletions modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,84 +15,57 @@
# mypy: ignore-errors

"""Provides a function to register activation hooks for a model.
Activation hooks are used to compute activation scores for pruning.
"""
Activation hooks are used to compute activation scores for pruning."""

import re
from typing import Type

from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import (
ForwardHook,
IndependentChannelContributionHook,
IndependentKvHeadContributionHook,
IterativeChannelContributionHook,
LayerNormContributionHook,
)
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook
from modelopt.torch.puzzletron.tools.logger import aprint


def register_activation_hooks(
model: DeciLMForCausalLM, activation_hooks_kwargs: dict
) -> tuple[dict[str, ForwardHook], type[ForwardHook]]:
hook_class_map = {
"mlp.down_proj": {
"independent": IndependentChannelContributionHook,
"iterative": IterativeChannelContributionHook,
},
"self_attn.o_proj": {
"independent_kv_head_contribution": IndependentKvHeadContributionHook,
},
r"regex:experts\.\d+\.down_proj$": { # For MoE
"independent": IndependentChannelContributionHook,
},
# TODO: maybe this is too generic, and we should have it specifically for
# input_layernorm and post_attention_layernorm; now it might select qk_norms
"layernorm": {
"layer_norm_contribution": LayerNormContributionHook,
},
}

activation_hooks = {}
target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj")

if target_layer.startswith("regex:"):
target_layer_regex = target_layer[len("regex:") :]
pattern = re.compile(target_layer_regex)

def match_predicate(module_name, module):
return pattern.search(module_name)
else:

def match_predicate(module_name, module):
return module_name.endswith(target_layer)

target_layer_hooks_map = hook_class_map.get(target_layer)
if target_layer_hooks_map is None:
raise ValueError(f"no hook classes found for: {target_layer}")

hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"])
if hook_class is None:
raise ValueError(f"Unknown hook class: {hook_class}")

if target_layer == "block":
pattern = re.compile(r"^transformer\.h\.\d+$")

def match_predicate(module_name, module):
return pattern.match(module_name)

model,
activation_hooks_kwargs: dict,
pruning_mixin,
hook_class: Type[ActivationsHook],
) -> dict[str, ActivationsHook]:
"""Register activation hooks using the pruning mixin approach.

Args:
model: The model to register hooks on.
activation_hooks_kwargs: Keyword arguments passed to hook constructors.
pruning_mixin: The pruning mixin that defines which modules to hook.
hook_class: The hook class to instantiate for each module.

Returns:
Dictionary mapping module names to hook instances.
"""
activation_hooks_kwargs["model"] = model
for module_name, module in model.named_modules():
if match_predicate(module_name, module):
block_config = None
if block_idx_match := re.search(r"\.(\d+)\.", module_name):
block_idx = int(block_idx_match.group(1))
block_config = model.config.block_configs[block_idx]
curr_activation_hooks_kwargs = {
**activation_hooks_kwargs,
"block_config": block_config,
}

hook = hook_class(module, curr_activation_hooks_kwargs)
module.register_forward_hook(hook)
activation_hooks[module_name] = hook

return activation_hooks, hook_class
if hook_class not in pruning_mixin.supported_hooks():
raise ValueError(
f"Hook class not supported for {pruning_mixin.__class__.__name__}, "
f"must be in {pruning_mixin.supported_hooks()}"
)

module_names_to_hook = pruning_mixin.get_module_names_to_hook(model)
activation_hooks = dict()
for block_idx, module_name in module_names_to_hook:
block_config = None
if block_idx is not None:
block_config = model.config.block_configs[block_idx]
curr_activation_hooks_kwargs = {
**activation_hooks_kwargs,
"block_config": block_config,
}

module = model.get_submodule(module_name)
hook = hook_class(module, curr_activation_hooks_kwargs)
module.register_forward_hook(hook)
activation_hooks[module_name] = hook

if len(activation_hooks) == 0:
raise ValueError("couldn't find any hooks")

aprint(f"Found the following hooks: {activation_hooks.keys()}")
return activation_hooks
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,4 @@ def launch_score_activations(cfg: DictConfig):
mprint("Starting pruning activation scoring...")

# The checkpoint manager inside validate_model handles all progress tracking
validate_model(args=cfg.pruning, pipeline_parallel=True)
validate_model(args=cfg.pruning)
204 changes: 204 additions & 0 deletions modelopt/torch/puzzletron/anymodel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
# AnyModel Guide

This guide explains how to add support for new models in the Puzzletron pipeline.

## Convert model

Convert a HuggingFace model to Puzzletron format.

Step 1: Create Model Descriptor

Extend `ModelDescriptor` and implement `layer_name_predicates()` to define regex patterns for grouping weights into subblocks (embeddings, lm_head, block_N_ffn, block_N_attention).

Key points:

- Find weight names on the model's HuggingFace page → click "Files info" to see the safetensors structure with all tensor names (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json))

See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py)

Step 2: Create Converter

Extend `Converter` and implement `create_block_configs_from_main_config()` to create per-layer BlockConfigs from the HuggingFace config.

Key points:

- Import correct HuggingFace config class (e.g., `MistralConfig`, `LlamaConfig`, `Qwen2Config`). Find it in the transformers source: `github.com/huggingface/transformers/tree/main/src/transformers/models/<model_type>/configuration_<model_type>.py`

See example: [llama_converter.py](models/llama/llama_converter.py)

Step 3: Create `models/<model_name>/__init__.py`

Export descriptor and converter classes:

```python
from models.<model_name>.<model_name>_model_descriptor import MyModelDescriptor
from models.<model_name>.<model_name>_converter import MyConverter
```

Step 4: Register in `models/__init__.py`

Add import to trigger factory registration:

```python
from models.<model_name> import *
```

## Usage

```python
from modelopt.torch.puzzletron.anymodel import convert_model

convert_model(
input_dir="path/to/hf_checkpoint",
output_dir="path/to/puzzletron_checkpoint",
converter="model_name",
)
```

## Compress model

Run pruning and compression on a Puzzletron model.

Step 1: Implement ModelDescriptor methods for compression

Add to your `ModelDescriptor`:

- `decoder_layer_cls()` - return the decoder layer class(es) to patch for heterogeneous config support
- `block_config_to_layer_overrides()` - map BlockConfig to layer override dict (see [details](#implementing-block_config_to_layer_overrides))
- `init_rotary_embedding()` - reinitialize rotary embeddings after model loading (see [details](#implementing-init_rotary_embedding))
- `input_embedding_name()` - return the name of the input embedding layer (see [details](#implementing-path-based-methods))
- `output_embedding_name()` - return the name of the output embedding layer (see [details](#implementing-path-based-methods))
- `layer_block_name()` - return the name pattern for decoder layers (see [details](#implementing-path-based-methods))
- `final_norm_name()` - return the name of the final normalization layer (see [details](#implementing-path-based-methods))
- `attn_no_op_post_init()` - replace attention sublayers with no-op modules
- `mlp_no_op_post_init()` - replace MLP sublayers with no-op modules

Step 2: Create FFN Layer Descriptor

Extend `FFNIntermediateLayerDescriptor` to define model-specific paths for FFN pruning hooks (`down_proj_name`, `ffn_prefix_name`, `linear_weight_names`). Derive values from your model's weight names in `layer_name_predicates()`.

See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py) → `LlamaFFNIntermediateLayerDescriptor`

Step 3: Configure YAML files

Update the main model config YAML:

- Set `descriptor` to match the name used in `@ModelDescriptorFactory.register_decorator("your_model_name")`
- See example: [llama_3_1_8b_instruct.yaml](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/llama_3_1_8b_instruct.yaml)

Update pruning YAML files (`ffn_pruning.yaml`, `expert_pruning.yaml`, etc.):

- Set `pruning_mixin._target_` to the appropriate mixin class
- Set `layer_descriptor._target_` to your layer descriptor class
- Set `hook_class` to the activation hook for scoring
- Set `target_layer` in `activation_hooks_kwargs` to the layer name for hook attachment
- See examples in [configs/llama_3_1_8b_instruct/pruning/](../../../../tests/gpu/torch/puzzletron/resources/configs/llama_3_1_8b_instruct/pruning/)

## End-to-end example

See [test_puzzletron.py](../../../../tests/gpu/torch/puzzletron/test_puzzletron.py) for a complete example that runs both convert and compression steps.

---

## Advanced Topics

## Pruning Configuration

### Pruning YAML Structure

Each pruning type has a YAML config with these key fields:

```yaml
pruning_mixin:
_target_: pruning.<type>_pruning_mixin.<MixinClass>
layer_descriptor:
_target_: models.<model>.<descriptor_class>

hook_class: ${get_object:utils.activation_hooks.hooks.<HookClass>}
activation_hooks_kwargs:
method: <method_name>
target_layer: "<layer.name>" # e.g., "mlp.down_proj", "self_attn.o_proj"
```

| Field | Description |
|-------|-------------|
| `pruning_mixin._target_` | Mixin class that orchestrates this pruning type |
| `layer_descriptor._target_` | Model-specific class defining layer paths for hooks |
| `hook_class` | Activation hook class for importance scoring |
| `target_layer` | Layer name (relative to decoder block) where hooks attach |

### Adding a New Hook Class

1. **Implement the hook** in `modelopt/torch/nas/plugins/megatron_hooks/base_hooks.py`:
- Extend an existing hook base class (e.g., `RemoveExpertsIndependentHook`)
- Implement required methods (e.g., `get_router_logits_and_routed_experts`)

2. **Register the hook** in the appropriate pruning mixin's `supported_hooks()`:

For FFN pruning (`pruning/ffn_intermediate_pruning_mixin.py`):

```python
def supported_hooks(self) -> List[Type[ActivationsHook]]:
return [IndependentChannelContributionHook, IterativeChannelContributionHook, YourNewHook]
```

For expert removal (`pruning/expert_removal_pruning_mixin.py`):

```python
def supported_hooks(self) -> List[Type[ActivationsHook]]:
return [RankedChoiceVotingHook, ..., YourNewHook]
```

3. **Reference in YAML**:

```yaml
hook_class: ${get_object:utils.activation_hooks.hooks.YourNewHook}
```

### Pruning Types Reference

| Type | Mixin | Example Hooks |
|------|-------|---------------|
| FFN intermediate | [`FFNIntermediatePruningMixIn`](../pruning/ffn_intermediate_pruning_mixin.py) | [`IterativeChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`IndependentChannelContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) |
| Expert removal | [`ExpertRemovalPruningMixIn`](../pruning/expert_removal_pruning_mixin.py) | [`NemotronHRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py), [`Qwen3VLRemoveExpertsIndependentHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) |
| KV heads | [`KVHeadsPruningMixIn`](../pruning/kv_heads_pruning_mixin.py) | [`IndependentKvHeadContributionHook`](../../../nas/plugins/megatron_hooks/base_hooks.py) |

## Implementing `block_config_to_layer_overrides`

Maps Puzzletron's [`BlockConfig`](../decilm/deci_lm_hf_code/block_config.py) fields to HuggingFace config attribute names. Only override attributes that change during pruning:

| BlockConfig Field | HuggingFace Attribute (check `config.json`) |
|-------------------|---------------------------------------------|
| `attention.num_key_value_heads` | `num_key_value_heads` |
| `ffn.intermediate_size` | `intermediate_size` |
| `ffn.moe.num_local_experts` | `num_experts` or `n_routed_experts` (model-specific) |
| `ffn.moe.expert_intermediate_dim` | `moe_intermediate_size` |

**Tip**: Check the model's `config.json` for exact attribute names - they vary between models.

See examples: [qwen3_vl](models/qwen3_vl/qwen3_vl_model_descriptor.py), [nemotron_h](models/nemotron_h/nemotron_h_model_descriptor.py)

---

## Implementing path-based methods

These methods return paths derived from the model's weight names:

- `input_embedding_name()`, `output_embedding_name()`, `layer_block_name()`, `final_norm_name()`

Find them on the model's HuggingFace page → "Files info" → safetensors structure (example: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct?show_file_info=model.safetensors.index.json)).

See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py)

---

## Implementing `init_rotary_embedding`

Rotary embeddings are computed modules (not saved weights). After model sharding, they need re-initialization on the correct device/dtype.

Look in `github.com/huggingface/transformers/tree/main/src/transformers/models/<model_type>/modeling_<model_type>.py` for:

- `class.*Rotary` — the rotary embedding class name and constructor arguments
- `self.rotary_emb` — the attribute path

See example: [llama_model_descriptor.py](models/llama/llama_model_descriptor.py)
Loading
Loading