diff --git a/deepseek_dequantize.py b/deepseek_dequantize.py
new file mode 100644
index 0000000000..fc98e52bc2
--- /dev/null
+++ b/deepseek_dequantize.py
@@ -0,0 +1,16 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Imported for its registration side effects.
+from llmcompressor.modeling.moe.linearize import linearize_moe_model  # noqa: F401
+
+model = AutoModelForCausalLM.from_pretrained(
+    "deepseek-ai/DeepSeek-V4-Flash",
+    torch_dtype="auto",
+    device_map="cpu",
+)
+# Kluge: drop the checkpoint conversion mapping so that weights are saved with
+# their in-memory key names.
+delattr(model, "_weight_conversions")
+tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V4-Flash")
+
+save_dir = "DeepSeek-V4-Flash-bf16"
+# Dequantization is currently left disabled; weights are saved as loaded.
+# model.dequantize(torch.bfloat16)
+model.save_pretrained(save_dir)
+tokenizer.save_pretrained(save_dir)
diff --git a/examples/quantizing_moe/deepseek_v4_example.py b/examples/quantizing_moe/deepseek_v4_example.py
new file mode 100644
index 0000000000..4506be3fd2
--- /dev/null
+++ b/examples/quantizing_moe/deepseek_v4_example.py
@@ -0,0 +1,145 @@
+from compressed_tensors.offload import load_offloaded_model
+from compressed_tensors.quantization.quant_scheme import (
+    FP8_BLOCK,
+    NVFP4,
+    QuantizationScheme,
+)
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modeling.moe.linearize import linearize_moe_model
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+# Select model and load it.
+MODEL_ID = "RedHatAI/DeepSeek-V4-Flash-BF16"
+
+with load_offloaded_model():
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        torch_dtype="auto",
+        device_map="cpu",
+        # Alternatively, offload weights to disk:
+        # device_map="auto_offload",
+        # max_memory={"cpu": 3e10},
+        # offload_folder="offload_folder",
+    )
+
+linearize_moe_model(model)
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. A small number (64) keeps this example fast;
+# increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 64
+MAX_SEQUENCE_LENGTH = 512
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
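+
+
+# For reference: the manual encoding implemented in `preprocess` below renders
+# a single-turn exchange roughly as (sketch):
+#
+#   <|begin▁of▁sentence|>{system}<|User|>{question}<|Assistant|>{answer}<|end▁of▁sentence|>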
+
+
+def preprocess(example):
+    # DeepSeek-V4 does not have a traditional chat template.
+    # Encode manually per
+    # https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/tree/main/encoding
+    BOS = "<|begin▁of▁sentence|>"
+    EOS = "<|end▁of▁sentence|>"
+    text = BOS
+    for message in example["messages"]:
+        role = message["role"]
+        content = message["content"]
+        if role == "system":
+            text += content
+        elif role == "user":
+            text += f"<|User|>{content}"
+        elif role == "assistant":
+            text += f"<|Assistant|>{content}{EOS}"
+
+    return {"text": text}
+
+
+ds = ds.map(preprocess)
+
+
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run:
+#   * quantize mlp/expert weights to NVFP4
+#   * quantize attention projection weights to FP8_BLOCK
+#
+# Linearized attention module names (e.g. model.layers.0.self_attn.q_a_proj)
+# map to the original checkpoint names as follows:
+#
+#   checkpoint | module
+#   -----------+----------
+#   wq_a       | q_a_proj
+#   wq_b       | q_b_proj
+#   wkv        | kv_proj
+#   wo_a       | o_a_proj
+#   wo_b       | o_b_proj
+recipe = QuantizationModifier(
+    config_groups={
+        "attention": QuantizationScheme(
+            targets=[
+                r"re:.*attn\.(q_a_proj|q_b_proj|kv_proj|o_a_proj|o_b_proj)$",
+                r"re:.*attn\.compressor\.indexer\.q_b_proj$",
+            ],
+            **FP8_BLOCK,
+        ),
+        "experts": QuantizationScheme(
+            targets=[
+                r"re:.*mlp\.experts.*(gate|up|down)_proj$",
+                r"re:.*mlp\.shared_experts.*(gate|up|down)_proj$",
+            ],
+            **NVFP4,
+        ),
+    },
+    # Candidates to exclude if needed, e.g.:
+    #   model.layers.4.self_attn.compressor.indexer.weights_proj
+    #   model.layers.3.ffn_hc
+    ignore=[],
+)
+
+# Apply algorithms.
+# Due to the large size of DeepSeek-V4, we specify sequential targets such that
+# only one decoder layer is loaded into GPU memory at a time.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    sequential_targets=["DeepseekV4DecoderLayer"],
+    batch_size=1,
+    shuffle_calibration_samples=True,
+)
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4-FP8-BLOCK"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/fix_checkpoint_keys.py b/fix_checkpoint_keys.py
new file mode 100644
index 0000000000..af291eb01e
--- /dev/null
+++ b/fix_checkpoint_keys.py
@@ -0,0 +1,183 @@
+"""Resave a DeepSeek-V4-Flash NVFP4 checkpoint with key names matching the BF16
+checkpoint structure. Quantization parameter suffixes (weight_packed,
+weight_scale, input_global_scale, weight_global_scale) are preserved; only
+prefixes and module names are changed."""
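+
+# Illustrative renamings produced by rename_key() below (hypothetical keys,
+# derived from the mapping tables):
+#
+#   model.layers.0.self_attn.q_a_proj.weight_packed
+#       -> layers.0.attn.wq_a.weight_packed
+#   model.layers.0.mlp.experts.7.gate_proj.weight_scale
+#       -> layers.0.ffn.experts.7.w1.weight_scale
+#   model.embed_tokens.weight -> embed.weight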
+
+import argparse
+import json
+import re
+import shutil
+from pathlib import Path
+
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+
+def rename_key(key: str) -> str:
+    if key == "head.weight":
+        return key
+
+    if key.startswith("model."):
+        key = key[len("model."):]
+
+    top_level = {
+        "embed_tokens.weight": "embed.weight",
+        "norm.weight": "norm.weight",
+        "hc_head.hc_base": "hc_head_base",
+        "hc_head.hc_fn": "hc_head_fn",
+        "hc_head.hc_scale": "hc_head_scale",
+    }
+    if key in top_level:
+        return top_level[key]
+
+    m = re.match(r"(layers\.\d+\.)(.*)", key)
+    if not m:
+        raise ValueError(f"Unrecognized key: {key}")
+
+    prefix = m.group(1)
+    rest = m.group(2)
+
+    # --- layer norms ---
+    if rest == "input_layernorm.weight":
+        return prefix + "attn_norm.weight"
+    if rest == "post_attention_layernorm.weight":
+        return prefix + "ffn_norm.weight"
+
+    # --- hardware counters ---
+    hc_map = {
+        "attn_hc.base": "hc_attn_base",
+        "attn_hc.fn": "hc_attn_fn",
+        "attn_hc.scale": "hc_attn_scale",
+        "ffn_hc.base": "hc_ffn_base",
+        "ffn_hc.fn": "hc_ffn_fn",
+        "ffn_hc.scale": "hc_ffn_scale",
+    }
+    if rest in hc_map:
+        return prefix + hc_map[rest]
+
+    # --- compressor.indexer (most specific first) ---
+    ci_exact = {
+        "self_attn.compressor.indexer.gate_proj.weight": "attn.indexer.compressor.wgate.weight",
+        "self_attn.compressor.indexer.kv_norm.weight": "attn.indexer.compressor.norm.weight",
+        "self_attn.compressor.indexer.kv_proj.weight": "attn.indexer.compressor.wkv.weight",
+        "self_attn.compressor.indexer.position_bias": "attn.indexer.compressor.ape",
+        "self_attn.compressor.indexer.weights_proj.weight": "attn.indexer.weights_proj.weight",
+    }
+    if rest in ci_exact:
+        return prefix + ci_exact[rest]
+    m2 = re.match(r"self_attn\.compressor\.indexer\.q_b_proj\.(.*)", rest)
+    if m2:
+        return prefix + "attn.indexer.wq_b." + m2.group(1)
+
+    # --- compressor (without indexer) ---
+    c_exact = {
+        "self_attn.compressor.gate_proj.weight": "attn.compressor.wgate.weight",
+        "self_attn.compressor.kv_norm.weight": "attn.compressor.norm.weight",
+        "self_attn.compressor.kv_proj.weight": "attn.compressor.wkv.weight",
+        "self_attn.compressor.position_bias": "attn.compressor.ape",
+    }
+    if rest in c_exact:
+        return prefix + c_exact[rest]
+
+    # --- self-attention (exact matches) ---
+    attn_exact = {
+        "self_attn.sinks": "attn.attn_sink",
+        "self_attn.kv_norm.weight": "attn.kv_norm.weight",
+        "self_attn.q_a_norm.weight": "attn.q_norm.weight",
+    }
+    if rest in attn_exact:
+        return prefix + attn_exact[rest]
+
+    # --- self-attention projections (with possible quant suffixes) ---
+    attn_proj_map = {
+        "self_attn.kv_proj": "attn.wkv",
+        "self_attn.o_a_proj": "attn.wo_a",
+        "self_attn.o_b_proj": "attn.wo_b",
+        "self_attn.q_a_proj": "attn.wq_a",
+        "self_attn.q_b_proj": "attn.wq_b",
+    }
+    for old, new in attn_proj_map.items():
+        m2 = re.match(rf"{re.escape(old)}\.(.*)", rest)
+        if m2:
+            return prefix + new + "." + m2.group(1)
+
+    # --- MLP gate ---
+    gate_map = {
+        "mlp.gate.weight": "ffn.gate.weight",
+        "mlp.gate.tid2eid": "ffn.gate.tid2eid",
+        "mlp.gate.e_score_correction_bias": "ffn.gate.bias",
+    }
+    if rest in gate_map:
+        return prefix + gate_map[rest]
+
+    # --- MLP experts ---
+    proj_map = {"gate_proj": "w1", "down_proj": "w2", "up_proj": "w3"}
+    m2 = re.match(r"mlp\.experts\.(\d+)\.(gate_proj|down_proj|up_proj)\.(.*)", rest)
+    if m2:
+        eid, proj, suffix = m2.group(1), m2.group(2), m2.group(3)
+        return prefix + f"ffn.experts.{eid}.{proj_map[proj]}.{suffix}"
+
+    # --- MLP shared experts ---
+    m2 = re.match(r"mlp\.shared_experts\.(gate_proj|down_proj|up_proj)\.(.*)", rest)
+    if m2:
+        proj, suffix = m2.group(1), m2.group(2)
+        return prefix + f"ffn.shared_experts.{proj_map[proj]}.{suffix}"
+
+    raise ValueError(f"Unrecognized key: layers.*.{rest}")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Resave NVFP4 checkpoint with BF16-style key names"
+    )
+    parser.add_argument("input_dir", type=Path)
+    parser.add_argument("output_dir", type=Path)
+    args = parser.parse_args()
+
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    index_path = args.input_dir / "model.safetensors.index.json"
+    with open(index_path) as f:
+        index = json.load(f)
+
+    shard_names = sorted(set(index["weight_map"].values()))
+
+    new_weight_map = {}
+    for old_key, shard_name in index["weight_map"].items():
+        new_weight_map[rename_key(old_key)] = shard_name
+
+    for i, shard_name in enumerate(shard_names):
+        src = args.input_dir / shard_name
+        dst = args.output_dir / shard_name
+        print(f"[{i + 1}/{len(shard_names)}] Processing {shard_name} ...")
+
+        tensors = {}
+        with safe_open(str(src), framework="pt") as f:
+            for key in f.keys():
+                tensors[rename_key(key)] = f.get_tensor(key)
+
+        save_file(tensors, str(dst))
+        del tensors
+        print(f"  Saved {dst}")
+
+    new_index = {
+        "metadata": index.get("metadata", {}),
+        "weight_map": new_weight_map,
+    }
+    out_index = args.output_dir / "model.safetensors.index.json"
+    with open(out_index, "w") as f:
+        json.dump(new_index, f, indent=2, sort_keys=False)
+    print(f"Saved {out_index}")
+
+    for name in (
+        "config.json",
+        "generation_config.json",
+        "tokenizer.json",
+        "tokenizer_config.json",
+    ):
+        src = args.input_dir / name
+        if src.exists():
+            shutil.copy2(src, args.output_dir / name)
+            print(f"Copied {name}")
+
+    print("Done.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index caf1c1714b..3ec7d7b90d 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -22,7 +22,6 @@
 from llmcompressor.core.session_functions import active_session
 from llmcompressor.datasets import get_calibration_dataloader
 from llmcompressor.entrypoints.utils import post_process, pre_process
-from llmcompressor.modeling.moe_context import moe_calibration_context
 from llmcompressor.modeling.offset_norm import norm_calibration_context
 from llmcompressor.pipelines import CalibrationPipeline
 
@@ -219,10 +218,7 @@ def apply_recipe_modifiers(
     # (Helen INFERENG-661): validate recipe modifiers before initialization
 
     # Apply calibration contexts for the entire calibration process
-    with norm_calibration_context(self.model), moe_calibration_context(
-        self.model,
-        calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
-    ):
+    with norm_calibration_context(self.model):
         session.initialize(
             model=self.model,
             start=-1,
diff --git a/src/llmcompressor/modeling/__init__.py 
b/src/llmcompressor/modeling/__init__.py index 2bdc98e196..39ee181ec2 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -10,18 +10,7 @@ """ # trigger registration -from .afmoe import CalibrationAfmoeMoE # noqa: F401 -from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401 -from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401 -from .glm4_moe_lite import CalibrationGlm4MoeLiteMoE # noqa: F401 -from .glm_moe_dsa import CalibrationGlmMoeDsaMoE # noqa: F401 -from .llama4 import SequentialLlama4TextMoe # noqa: F401 -from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401 -from .qwen3_5_moe import CalibrationQwen3_5MoeSparseMoeBlock -from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401 -from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401 from .offset_norm import CalibrationOffsetNorm # noqa: F401 -from .gemma4 import SequentialGemma4TextExperts # noqa: F401 # TODO: add granite4 from .fuse import * diff --git a/src/llmcompressor/modeling/afmoe.py b/src/llmcompressor/modeling/afmoe.py deleted file mode 100644 index 32e47052bc..0000000000 --- a/src/llmcompressor/modeling/afmoe.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch - -from llmcompressor.modeling.moe_context import MoECalibrationModule - - -@MoECalibrationModule.register("AfmoeMoE") -class CalibrationAfmoeMoE(MoECalibrationModule): - """ - Calibration version of AfmoeMoE that sends all tokens to all experts. - - During calibration, when calibrate_all_experts=True, all tokens are sent to - all experts to ensure proper quantization statistics are collected for every - expert, not just those activated by the calibration data routing. - - The Afmoe architecture uses: - - Token-choice top-K routing with sigmoid/softmax scoring - - Optional shared experts processed on all tokens - - Learnable expert bias for routing control - - Note: AfmoeMoE is loaded dynamically from the model hub via trust_remote_code=True. - The original module is passed as a parameter. - """ - - is_permanent = False - - def __init__( - self, - original: torch.nn.Module, - config, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.config = config - self.router = original.router - self.experts = original.experts - self.shared_experts = original.shared_experts - self.expert_bias = original.expert_bias - self.calibrate_all_experts = calibrate_all_experts - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Forward pass with optional calibration mode. 
- - When calibrate_all_experts=True: - - All tokens are sent to all experts for calibration - - Routing weights are still used for final output combination - - This ensures all experts see calibration data - When calibrate_all_experts=False: - - Normal MoE routing behavior (only routed tokens go to each expert) - """ - batch_size, seq_len, hidden_dim = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, hidden_dim) - - # Step 1: Get routing decisions - top_scores, selected_experts = self.router(hidden_states, self.expert_bias) - - # Step 2: Process through shared experts - if self.shared_experts is not None: - shared_output = self.shared_experts(hidden_states_flat) - else: - shared_output = torch.zeros_like(hidden_states_flat) - - # Step 3: Create expert mask for routing - which tokens - # were selected - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.config.num_experts - ).permute(2, 1, 0) # (num_experts, top_k, batch_size * seq_len) - - # Step 4: Process routed experts - routed_output = torch.zeros_like( - hidden_states_flat, dtype=hidden_states.dtype, device=hidden_states.device - ) - - for expert_idx, expert in enumerate(self.experts): - # Get the indices of tokens routed to this expert - idx, token_idx = torch.where(expert_mask[expert_idx]) - - if self.calibrate_all_experts: - # Pass all tokens through the expert but only outputs - # for the selected tokens are extracted (i.e if this - # expert was selected) - expert_output = expert(hidden_states_flat)[token_idx] - else: - # Only pass routed tokens through the expert - expert_output = expert(hidden_states_flat[token_idx]) - - # If any tokens were routed to this expert, add their contribution - if len(token_idx) > 0: - weighted_output = expert_output * top_scores[token_idx, idx, None] - # add weighted output to the final output for the routed tokens - routed_output.index_add_( - 0, token_idx, weighted_output.to(hidden_states.dtype) - ) - - # Step 5: Combine shared and routed expert output - output = shared_output.to(hidden_states.dtype) + routed_output.to( - hidden_states.dtype - ) - return output.view(batch_size, seq_len, hidden_dim) - - def restore(self, original: torch.nn.Module) -> torch.nn.Module: - """ - Restore the original module structure. - - Since is_permanent=False, this method is called when exiting - the calibration context to restore the original MoE module. - """ - return original diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py deleted file mode 100644 index a4d791ac39..0000000000 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config -from transformers.models.deepseek_v3.modeling_deepseek_v3 import ( - DeepseekV3MoE as OriginalDeepseekV3MoE, -) - -from llmcompressor.modeling.moe_context import MoECalibrationModule - - -@MoECalibrationModule.register("DeepseekV3MoE") -class CalibrationDeepseekV3MoE(MoECalibrationModule): - """ - Calibration version of DeepseekV3MoE that sends all tokens to all experts. 
- """ - - is_permanent = True - - def __init__( - self, - original: OriginalDeepseekV3MoE, - config: DeepseekV3Config, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.config = config - self.experts = original.experts - self.gate = original.gate - self.shared_experts = original.shared_experts - self.calibrate_all_experts = calibrate_all_experts - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - # Begin MoE - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot( - topk_indices, num_classes=len(self.experts) - ) - expert_mask = expert_mask.permute(2, 0, 1) - - for expert_idx, expert in enumerate(self.experts): - token_indices, weight_indices = torch.where(expert_mask[expert_idx]) - has_tokens = token_indices.numel() > 0 - - if self.calibrate_all_experts: - expert_input = hidden_states - expert_output = expert(expert_input) - - if has_tokens: - expert_weights = topk_weights[token_indices, weight_indices] - routed_output = expert_output[ - token_indices - ] * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, routed_output) - else: - # Normal MoE: only process tokens routed to this expert - if has_tokens: - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - expert_weights = topk_weights[token_indices, weight_indices] - routed_output = expert_output * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, routed_output) - # End MoE - - hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states diff --git a/src/llmcompressor/modeling/gemma4.py b/src/llmcompressor/modeling/gemma4.py deleted file mode 100644 index d7b8f370cd..0000000000 --- a/src/llmcompressor/modeling/gemma4.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch - -from llmcompressor.modeling.moe_context import MoECalibrationModule -from llmcompressor.utils.dev import skip_weights_initialize - -if TYPE_CHECKING: - from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig - from transformers.models.gemma4.modeling_gemma4 import ( - Gemma4Config, - Gemma4TextExperts, - ) - - -@MoECalibrationModule.register("Gemma4TextExperts") -class SequentialGemma4TextExperts(MoECalibrationModule): - """ - Calibration version of Gemma4TextExperts that unpacks experts. - - This module unpacks the packed expert weights (3D -> 2D) for calibration and - stays in unpacked form (permanent) for vLLM compatibility. 
- """ - - is_permanent = True - - def __init__( - self, - original: Gemma4TextExperts, - config: Gemma4Config, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.num_experts = original.num_experts - self.hidden_dim = original.hidden_dim - self.intermediate_dim = original.intermediate_dim - self.calibrate_all_experts = calibrate_all_experts - - # Unpack the 3D expert weights into individual MLP modules - # Register experts directly as numbered children to avoid double nesting - # (HF has layers[i].experts, so we want layers[i].experts.0, - # not layers[i].experts.experts.0) - expert_list = Gemma4TextExpertsList(config.text_config, original) - for i, expert in enumerate(expert_list): - self.add_module(str(i), expert) - - def forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot( - top_k_index, num_classes=self.num_experts - ) - expert_mask = expert_mask.permute(2, 1, 0) - - for expert_idx in range(self.num_experts): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - expert_layer = getattr(self, str(expert_idx)) - - if self.calibrate_all_experts: - # Pass all tokens through expert, then select routed outputs - expert_out = expert_layer(hidden_states)[token_idx] - else: - # Only pass routed tokens through expert - expert_out = expert_layer(hidden_states[token_idx]) - - if len(token_idx) > 0: - current_hidden_states = ( - expert_out * top_k_weights[token_idx, top_k_pos, None] - ) - final_hidden_states.index_add_( - 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) - ) - - return final_hidden_states - - -class Gemma4TextExpertsList(torch.nn.ModuleList): - """ - Unpacks 3D expert parameter tensors into individual Gemma4TextMLP modules - so that each expert's weights are nn.Linear and can be targeted by - quantization with targets="Linear". 
- """ - - def __init__(self, config: Gemma4TextConfig, original: Gemma4TextExperts): - from transformers.models.gemma4.modeling_gemma4 import Gemma4TextMLP - - self.num_experts = config.num_experts - intermediate_size = config.moe_intermediate_size - - with skip_weights_initialize(): - super().__init__( - [Gemma4TextMLP(config, layer_idx=0) for _ in range(self.num_experts)] - ) - - gate_up_data = original.gate_up_proj.data # [num_experts, 2*inter, hidden] - down_data = original.down_proj.data # [num_experts, hidden, inter] - - for i in range(self.num_experts): - gate_up = gate_up_data[i] # [2*intermediate, hidden] - down = down_data[i] # [hidden, intermediate] - - # gate_up_proj stores [gate; up] stacked along dim 0 - # nn.Linear weight is [out_features, in_features] - self[i].gate_proj.weight.data = ( - gate_up[:intermediate_size, :].clone().contiguous() - ) - self[i].up_proj.weight.data = ( - gate_up[intermediate_size:, :].clone().contiguous() - ) - self[i].down_proj.weight.data = down.clone().contiguous() diff --git a/src/llmcompressor/modeling/glm4_moe.py b/src/llmcompressor/modeling/glm4_moe.py deleted file mode 100644 index 4f4e470d50..0000000000 --- a/src/llmcompressor/modeling/glm4_moe.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig -from transformers.models.glm4_moe.modeling_glm4_moe import ( - Glm4MoeMoE as OriginalGlm4MoeMoE, -) - -from llmcompressor.modeling.moe_context import MoECalibrationModule - - -@MoECalibrationModule.register("Glm4MoeMoE") -class CalibrationGlm4MoeMoE(MoECalibrationModule): - """ - Calibration version of Glm4MoeMoE that sends all tokens to all experts. - During calibration, when calibrate_all_experts=True, all tokens are sent to - all experts to ensure proper quantization statistics are collected for every - expert, not just those activated by the calibration data routing. - """ - - is_permanent = False - - def __init__( - self, - original: OriginalGlm4MoeMoE, - config: Glm4MoeConfig, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.config = config - self.experts = original.experts - self.gate = original.gate - self.shared_experts = original.shared_experts - self.calibrate_all_experts = calibrate_all_experts - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ - Forward pass with optional calibration mode. - When calibrate_all_experts=True: - - All tokens are sent to all experts for calibration - - Routing weights are still used for final output combination - - This ensures all experts see calibration data - When calibrate_all_experts=False: - - Normal MoE routing behavior (only routed tokens go to each expert) - """ - residuals = hidden_states - orig_shape = hidden_states.shape - topk_indices, topk_weights = self.gate(hidden_states) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - # Begin MoE - inline the moe() method logic with calibration support - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - expert_mask = torch.nn.functional.one_hot( - topk_indices, num_classes=len(self.experts) - ) - expert_mask = expert_mask.permute(2, 0, 1) - - for expert_idx, expert in enumerate(self.experts): - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) - has_tokens = token_indices.numel() > 0 - - if self.calibrate_all_experts: - # When calibrating, run all tokens through the expert to gather stats. - # The output is still calculated using only the routed tokens. 
- expert_output_full = expert(hidden_states) - if not has_tokens: - # No tokens routed to this expert, but stats were gathered. - continue - expert_output = expert_output_full[token_indices] - else: - # Standard MoE behavior: only process tokens routed to this expert. - if not has_tokens: - continue - expert_output = expert(hidden_states[token_indices]) - - # Common logic for combining expert outputs - expert_weights = topk_weights[token_indices, weight_indices] - weighted_output = expert_output * expert_weights.unsqueeze(-1) - final_hidden_states.index_add_(0, token_indices, weighted_output) - # End MoE - - hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states - - def restore(self, original: torch.nn.Module) -> torch.nn.Module: - """ - Restore the original module structure. - - Since is_permanent=False, this method is called when exiting - the calibration context to restore the original MoE module. - """ - return original diff --git a/src/llmcompressor/modeling/glm4_moe_lite.py b/src/llmcompressor/modeling/glm4_moe_lite.py deleted file mode 100644 index 7896e7bf3f..0000000000 --- a/src/llmcompressor/modeling/glm4_moe_lite.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch - -from llmcompressor.utils.dev import skip_weights_initialize - -from .glm_moe_dsa import CalibrationGlmMoeDsaMoE -from .moe_context import MoECalibrationModule - -if TYPE_CHECKING: - from transformers.models.glm4_moe_lite.configuration_glm4_moe_lite import ( - Glm4MoeLiteConfig, - ) - from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import ( - Glm4MoeLiteNaiveMoe, - ) - - -@MoECalibrationModule.register("Glm4MoeLiteMoE") -class CalibrationGlm4MoeLiteMoE(CalibrationGlmMoeDsaMoE): - """ - Calibration version of Glm4MoeLiteMoE that unfuses 3D expert parameters into - individual MLP modules (nn.Linear) so they can be quantized. - - GLM-4.7-Flash Lite stores routed experts in a ``Glm4MoeLiteNaiveMoe`` module - using 3D parameters (``gate_up_proj``, ``down_proj``) instead of ``nn.Linear`` - submodules. Since llm-compressor targets ``Linear`` modules, the original - routed experts are invisible to quantization and remain BF16 unless they are - unpacked. - - Inherits routing logic (:meth:`route_tokens_to_experts`) and forward pass - from :class:`CalibrationGlmMoeDsaMoE`, overriding only expert creation to - use ``Glm4MoeLiteMLP`` modules. - """ - - def _get_num_experts(self, config) -> int: - return config.n_routed_experts - - def _make_experts(self, config, original_experts) -> torch.nn.ModuleList: - return SequentialGlm4MoeLiteExperts(config, original_experts) - - -class SequentialGlm4MoeLiteExperts(torch.nn.ModuleList): - """ - Unpacks 3D expert parameter tensors into individual Glm4MoeLiteMLP modules so - each routed expert has standard ``nn.Linear`` projections visible to - ``targets="Linear"``. 
- """ - - def __init__(self, config: Glm4MoeLiteConfig, original: Glm4MoeLiteNaiveMoe): - from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import ( - Glm4MoeLiteMLP, - ) - - self.num_experts = config.n_routed_experts - intermediate_size = config.moe_intermediate_size - - with skip_weights_initialize(): - super().__init__( - [ - Glm4MoeLiteMLP(config, intermediate_size=intermediate_size) - for _ in range(self.num_experts) - ] - ) - - with torch.no_grad(): - for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] - gate_proj, up_proj = gate_up.chunk(2, dim=0) - - self[i].gate_proj.weight.copy_(gate_proj.contiguous()) - self[i].up_proj.weight.copy_(up_proj.contiguous()) - self[i].down_proj.weight.copy_(down.contiguous()) diff --git a/src/llmcompressor/modeling/glm_moe_dsa.py b/src/llmcompressor/modeling/glm_moe_dsa.py deleted file mode 100644 index c4ddcca5aa..0000000000 --- a/src/llmcompressor/modeling/glm_moe_dsa.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch - -from llmcompressor.modeling.moe_context import MoECalibrationModule - -if TYPE_CHECKING: - from transformers.models.glm_moe_dsa.configuration_glm_moe_dsa import ( - GlmMoeDsaConfig, - ) - from transformers.models.glm_moe_dsa.modeling_glm_moe_dsa import ( - GlmMoeDsaMoE, - GlmMoeDsaNaiveMoe, - ) - -from llmcompressor.utils.dev import skip_weights_initialize - - -@MoECalibrationModule.register("GlmMoeDsaMoE") -class CalibrationGlmMoeDsaMoE(MoECalibrationModule): - """ - Calibration version of GlmMoeDsaMoE that unpacks experts for sequential - processing. - - This module: - 1. Unpacks the packed expert weights (3D -> 2D) for calibration - 2. Optionally sends all tokens to all experts during calibration - 3. Stays in unpacked form (permanent) for vLLM compatibility - - Subclasses (e.g. :class:`CalibrationGlm4MoeLiteMoE`) override - :meth:`_get_num_experts` and :meth:`_make_experts` to handle - model-specific config fields and MLP classes, while inheriting the - shared routing and forward logic. - """ - - is_permanent = True - - def _get_num_experts(self, config) -> int: - """Return the number of routed experts from the model config. - - Override in subclasses whose config stores the expert count under a - different attribute name (e.g. ``n_routed_experts``). - """ - return config.num_local_experts - - def _make_experts(self, config, original_experts) -> torch.nn.ModuleList: - """Create the sequential (unpacked) expert module list. - - Override in subclasses that need a different MLP class for unpacking - (e.g. ``Glm4MoeLiteMLP`` instead of ``GlmMoeDsaMLP``). 
- """ - return SequentialGlmMoeDsaExperts(config, original_experts) - - def __init__( - self, - original: GlmMoeDsaMoE, - config: GlmMoeDsaConfig, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.top_k = config.num_experts_per_tok - self.num_experts = self._get_num_experts(config) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - - self.experts = self._make_experts(config, original.experts) - self.gate = original.gate - self.shared_experts = original.shared_experts - self.calibrate_all_experts = calibrate_all_experts - - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view( - -1, self.n_group, self.n_routed_experts // self.n_group - ) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill( - ~score_mask.bool(), 0.0 - ) - topk_indices = torch.topk( - scores_for_choice, k=self.top_k, dim=-1, sorted=False - )[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - residuals = hidden_states - orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - topk_indices, num_classes=self.num_experts - ) - expert_mask = expert_mask.permute(2, 1, 0) - - for i in range(self.num_experts): - top_k_pos, token_idx = torch.where(expert_mask[i]) - has_tokens = token_idx.numel() > 0 - - if self.calibrate_all_experts: - expert_out_all = self.experts[i](hidden_states) - if not has_tokens: - continue - expert_out = expert_out_all[token_idx] - else: - if not has_tokens: - continue - expert_out = self.experts[i](hidden_states[token_idx]) - - weighted_output = expert_out * topk_weights[token_idx, top_k_pos, None] - final_hidden_states.index_add_( - 0, token_idx, weighted_output.to(final_hidden_states.dtype) - ) - - hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states - - -class SequentialGlmMoeDsaExperts(torch.nn.ModuleList): - def __init__(self, config: GlmMoeDsaConfig, original: GlmMoeDsaNaiveMoe): - from transformers.models.glm_moe_dsa.modeling_glm_moe_dsa import GlmMoeDsaMLP - - self.num_experts = config.num_local_experts - with skip_weights_initialize(): - super().__init__( - [ - GlmMoeDsaMLP(config, intermediate_size=config.moe_intermediate_size) - for _ in range(self.num_experts) - ] - ) 
- - with torch.no_grad(): - for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] - - gate_proj, up_proj = gate_up.chunk(2, dim=0) - - self[i].gate_proj.weight.copy_(gate_proj.contiguous()) - self[i].up_proj.weight.copy_(up_proj.contiguous()) - self[i].down_proj.weight.copy_(down.contiguous()) diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py deleted file mode 100644 index 44258a3927..0000000000 --- a/src/llmcompressor/modeling/gpt_oss.py +++ /dev/null @@ -1,259 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, Optional - -import torch -import torch.nn as nn - - -class LinearExpert(nn.Module): - """ - One MoE expert with separate gate / up / down projections. - - This mirrors the GPT-OSS expert behavior: - gate = clamp(gate_proj(x)) - up = clamp(up_proj(x)) - glu = gate * sigmoid(alpha * gate) - y = down_proj((up + 1) * glu) - """ - - def __init__( - self, hidden_size: int, intermediate_size: int, alpha: float, limit: float - ): - super().__init__() - self.alpha = alpha - self.limit = limit - - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - gate = self.gate_proj(x) - up = self.up_proj(x) - - gate = gate.clamp(max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - - glu = gate * torch.sigmoid(self.alpha * gate) - act = (up + 1) * glu - return self.down_proj(act) - - -class LinearExperts(nn.Module): - """ - Container of multiple LinearExpert modules, driven by - router_indices / routing_weights. - - This is the "separate gate/up" layout. - It is meant to replace the original GPT-OSS `experts` submodule. - """ - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - num_experts: int, - alpha: float = 1.702, - limit: float = 7.0, - ): - super().__init__() - self.hidden_size = hidden_size - self.expert_dim = intermediate_size - self.num_experts = num_experts - self.alpha = alpha - self.limit = limit - - self.experts = nn.ModuleList( - [ - LinearExpert(hidden_size, intermediate_size, alpha, limit) - for _ in range(num_experts) - ] - ) - - @torch.no_grad() - def copy_from_fused_weights( - self, - legacy_gate_up_W: torch.Tensor, # [E, H, 2D] - legacy_gate_up_b: torch.Tensor, # [E, 2D] - legacy_down_W: torch.Tensor, # [E, D, H] - legacy_down_b: torch.Tensor, # [E, H] - ) -> None: - """ - De-interleave fused gate_up weights/bias and copy into separate gate/up experts. 
- """ - E, H, twoD = legacy_gate_up_W.shape - assert E == self.num_experts - D = twoD // 2 - assert D == self.expert_dim - - for i in range(E): - Wi = legacy_gate_up_W[i] # [H, 2D] - bi = legacy_gate_up_b[i] # [2D] - - Wg = Wi[:, 0::2].contiguous() # [H, D] - Wu = Wi[:, 1::2].contiguous() # [H, D] - bg = bi[0::2].contiguous() # [D] - bu = bi[1::2].contiguous() # [D] - - expert = self.experts[i] - expert.gate_proj.weight.copy_(Wg.t()) - expert.gate_proj.bias.copy_(bg) - expert.up_proj.weight.copy_(Wu.t()) - expert.up_proj.bias.copy_(bu) - - expert.down_proj.weight.copy_(legacy_down_W[i].t()) - expert.down_proj.bias.copy_(legacy_down_b[i]) - - def forward( - self, - hidden_states: torch.Tensor, # [B, T, H] - router_indices: Optional[ - torch.Tensor - ] = None, # [B, T, top_k] or [tokens, top_k] - routing_weights: Optional[torch.Tensor] = None, # [B, T, E] or [tokens, E] - ) -> torch.Tensor: - """ - Implements the MoE computation using the router outputs. - - This is compatible with the GPT-OSS MoE call pattern: - experts(hidden_states, router_indices, routing_weights) - """ - assert ( - routing_weights is not None and router_indices is not None - ), "router inputs required" - - # Normalize shapes to [tokens, H], [tokens, top_k], [tokens, E] - if hidden_states.dim() == 3: - B, T, H = hidden_states.shape - x = hidden_states.reshape(-1, H) - else: - # Already flattened - B, _ = 1, hidden_states.shape[0] - H = hidden_states.shape[-1] - x = hidden_states - - if router_indices.dim() == 3: - router_indices = router_indices.reshape(-1, router_indices.shape[-1]) - if routing_weights.dim() == 3: - routing_weights = routing_weights.reshape(-1, routing_weights.shape[-1]) - - num_experts_plus_dummy = routing_weights.shape[1] - out = torch.zeros_like(x) - - # GPT-OSS router uses an extra "no expert" bucket at index E - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=num_experts_plus_dummy - ).permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for idx in expert_hit: - e = idx[0].item() - if e == self.num_experts: - # Skip "no expert" bucket - continue - - _, token_idx = torch.where(expert_mask[e]) - xi = x[token_idx] - - expert = self.experts[e] - yi = expert(xi) - - w = routing_weights[token_idx, e, None] - out.index_add_(0, token_idx, (yi * w).to(out.dtype)) - - return out.view(B, -1, H) - - -@dataclass -class ExpertMeta: - path: str - hidden_size: int - intermediate_size: int - num_experts: int - device: torch.device - dtype: torch.dtype - - -def get_module_by_path(root: nn.Module, dotpath: str) -> nn.Module: - m: nn.Module = root - if not dotpath: - return root - for p in dotpath.split("."): - m = getattr(m, p) - return m - - -def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> None: - parts = dotpath.split(".") - parent = get_module_by_path(root, ".".join(parts[:-1])) - setattr(parent, parts[-1], new_module) - - -def find_experts(model: nn.Module) -> List[ExpertMeta]: - """ - Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts. 
- """ - metas: List[ExpertMeta] = [] - for li, layer in enumerate(model.model.layers): - experts = layer.mlp.experts - device = next(experts.parameters(), torch.zeros(())).device - dtype = next(experts.parameters(), torch.zeros(())).dtype - intermediate = getattr(experts, "expert_dim", None) - if intermediate is None: - intermediate = getattr(experts, "intermediate_size") - - metas.append( - ExpertMeta( - path=f"model.layers.{li}.mlp.experts", - hidden_size=experts.hidden_size, - intermediate_size=intermediate, - num_experts=experts.num_experts, - device=device, - dtype=dtype, - ) - ) - return metas - - -def convert_model_for_quantization_gptoss(model: nn.Module) -> None: - """ - In-place conversion of a GPT-OSS model: - - - Finds all fused MoE expert blocks (with gate_up_proj/down_proj). - - Replaces them with LinearExperts that expose plain nn.Linear - parameters (gate_proj, up_proj, down_proj), which play nicely - with LLM Compressor W4A8 quantization. - """ - metas = find_experts(model) - for meta in metas: - legacy = get_module_by_path(model, meta.path) - - # Sanity check that this is the fused layout we expect. - if not all( - hasattr(legacy, attr) - for attr in [ - "gate_up_proj", - "gate_up_proj_bias", - "down_proj", - "down_proj_bias", - ] - ): - continue - - new_exp = LinearExperts( - hidden_size=meta.hidden_size, - intermediate_size=meta.intermediate_size, - num_experts=meta.num_experts, - ).to(device=meta.device, dtype=meta.dtype) - - new_exp.copy_from_fused_weights( - legacy_gate_up_W=legacy.gate_up_proj, - legacy_gate_up_b=legacy.gate_up_proj_bias, - legacy_down_W=legacy.down_proj, - legacy_down_b=legacy.down_proj_bias, - ) - - set_module_by_path(model, meta.path, new_exp) diff --git a/src/llmcompressor/modeling/granite4.py b/src/llmcompressor/modeling/granite4.py deleted file mode 100644 index 09b2d391cd..0000000000 --- a/src/llmcompressor/modeling/granite4.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -from transformers.models.granitemoehybrid.modeling_granitemoehybrid import ( - GraniteMoeHybridParallelExperts, -) - - -class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear): - def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: - """Use a real Linear so that llmcompressor and vllm can handle it easier. - 1. Change .weight from 3D [num_experts, output_size, input_size] to 2D - [num_experts * output_size, input_size] before calling llm-compressor - 2. Change it back to 3D before saving ckpt - """ - super().__init__( - input_size, output_size * num_experts, bias=False, device="meta" - ) - self.num_experts = num_experts - self.input_size = input_size - self.output_size = output_size - self.is_2d: bool = True - - @classmethod - def from_3d_expert(cls, original: GraniteMoeHybridParallelExperts): - """Reshape weights of GraniteMoeHybridParallelExperts module into 2D and store - them as weights of this "Linear" module. 
- """ - newMoeLin = cls(original.num_experts, original.input_size, original.output_size) - newMoeLin.weight = torch.nn.Parameter( - original.weight.view(-1, original.input_size).clone(), - requires_grad=False, - ) - original.to("cpu") - newMoeLin.is_2d = True - return newMoeLin - - def to_3d_expert(self) -> None: - """Convert weights and quantization parameters from 2D to 3D shape.""" - # Calculate all shapes up front - packed_input_size = self.weight.shape[1] - pack_factor = self.input_size // packed_input_size - - assert hasattr(self, "weight_scale"), "weight_scale not found" - grouped_output = self.weight_scale.shape[0] // self.num_experts - grouped_input = self.weight_scale.shape[1] - - expected_packed_weight_shape = torch.Size( - (self.num_experts * self.output_size, packed_input_size) - ) - final_packed_weight_shape = torch.Size( - (self.num_experts, self.output_size, packed_input_size) - ) - - expected_packed_weight_scale_shape = torch.Size( - (self.num_experts * grouped_output, grouped_input) - ) - final_packed_weight_scale_shape = torch.Size( - (self.num_experts, grouped_output, grouped_input) - ) - - # Assert shapes match expectations - assert self.weight.shape == expected_packed_weight_shape, ( - f"weight shape {self.weight.shape} != " - f"expected {expected_packed_weight_shape}" - ) - - assert self.weight_scale.shape == expected_packed_weight_scale_shape, ( - f"weight_scale shape {self.weight_scale.shape} != " - f"expected {expected_packed_weight_scale_shape}" - ) - - # Reshape to 3D - self.weight = torch.nn.Parameter( - self.weight.view(final_packed_weight_shape).clone(), - requires_grad=False, - ) - self.weight_scale = torch.nn.Parameter( - self.weight_scale.view(final_packed_weight_scale_shape).clone(), - requires_grad=False, - ) - - if hasattr(self, "weight_zero_point"): - expected_packed_zp_shape = torch.Size( - (self.num_experts * grouped_output // pack_factor, grouped_input) - ) - final_packed_zp_shape = torch.Size( - (self.num_experts, grouped_output // pack_factor, grouped_input) - ) - assert self.weight_zero_point.shape == expected_packed_zp_shape, ( - f"weight_zero_point shape {self.weight_zero_point.shape} != " - f"expected {expected_packed_zp_shape}" - ) - self.weight_zero_point = torch.nn.Parameter( - self.weight_zero_point.view(final_packed_zp_shape).clone(), - requires_grad=False, - ) - - self.is_2d = False - - def forward(self, inputs, expert_size): - """Modified from original forward()""" - - input_list = inputs.split(expert_size, dim=0) - - weight_3d = self.weight.view( - self.num_experts, self.output_size, self.input_size - ) - output_list = [] - for i in range(self.num_experts): - output_list.append(torch.nn.functional.linear(input_list[i], weight_3d[i])) - - results = torch.cat(output_list, dim=0) - return results - - def __repr__(self): - if self.is_2d: - sizes_str = f"(out={self.weight.shape[0]},in={self.weight.shape[1]})" - else: - sizes_str = ( - f"(exp={self.weight.shape[0]},out={self.weight.shape[1]}," - f"in={self.weight.shape[2]})" - ) - return f"{self.__class__.__name__}{sizes_str}" diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py deleted file mode 100644 index 9145c66a60..0000000000 --- a/src/llmcompressor/modeling/llama4.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Tuple - -import torch -from transformers.models.llama4.configuration_llama4 import ( - Llama4Config, - Llama4TextConfig, -) -from transformers.models.llama4.modeling_llama4 import ( - Llama4TextExperts, - Llama4TextMLP, - Llama4TextMoe, 
)
-
-from llmcompressor.modeling.moe_context import MoECalibrationModule
-from llmcompressor.utils.dev import skip_weights_initialize
-
-
-@MoECalibrationModule.register("Llama4TextMoe")
-class SequentialLlama4TextMoe(MoECalibrationModule):
-    """
-    Calibration version of Llama4TextMoe that unpacks experts for sequential processing.
-
-    This module:
-    1. Unpacks the packed expert weights (3D -> 2D) for calibration
-    2. Optionally sends all tokens to all experts during calibration
-    3. Stays in unpacked form (permanent) for vLLM compatibility
-    """
-
-    is_permanent = True
-
-    def __init__(
-        self,
-        original: Llama4TextMoe,
-        config: Llama4Config,
-        calibrate_all_experts: bool = True,
-    ):
-        super().__init__()
-        # Extract text config from multimodal config
-        text_config: Llama4TextConfig = config.get_text_config()
-        self.top_k = text_config.num_experts_per_tok
-        self.hidden_dim = text_config.hidden_size
-        self.num_experts = text_config.num_local_experts
-
-        self.experts = SequentialLlama4TextExperts(text_config, original.experts)
-        self.router = original.router
-        self.shared_expert = original.shared_expert
-        self.calibrate_all_experts = calibrate_all_experts
-
-    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_scores, router_logits = self.router(hidden_states)
-        out = self.shared_expert(hidden_states)
-
-        _, router_indices = torch.topk(router_logits, self.top_k, dim=1)
-        expert_mask = torch.nn.functional.one_hot(
-            router_indices, num_classes=self.num_experts
-        ).permute(2, 1, 0)  # (num_experts, top_k, batch_size * sequence_length)
-
-        for i in range(self.num_experts):
-            # fetch relevant token indices for this expert
-            token_idx = torch.where(expert_mask[i].squeeze(0))
-
-            # Original Llama4 definition - apply score to hidden states
-            # before applying to expert this results in NaNs during calibration
-            # routed_in = hidden_states * router_scores[:, i].reshape(-1, 1)
-
-            if self.calibrate_all_experts:
-                # all tokens for this expert
-                expert_out = self.experts[i](hidden_states)[token_idx]
-            else:
-                # only relevant tokens for this expert
-                expert_out = self.experts[i](hidden_states[token_idx])
-
-            if len(token_idx) > 0:
-                # Deviation from original Llama4 definition to avoid NaNs
-                # NaNs during calibration
-                weighted_output = expert_out * router_scores[:, i][token_idx].reshape(
-                    -1, 1
-                )
-                out[token_idx] += weighted_output
-
-        return out, router_logits
-
-
-class SequentialLlama4TextExperts(torch.nn.ModuleList):
-    def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
-        self.num_experts = original.gate_up_proj.shape[0]
-        with skip_weights_initialize():
-            super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-
-        for i in range(self.num_experts):
-            gate_up = original.gate_up_proj[i]
-            down = original.down_proj[i]
-
-            gate_proj, up_proj = gate_up.chunk(2, dim=-1)
-
-            self[i].gate_proj.weight.data = gate_proj.t().contiguous()
-            self[i].up_proj.weight.data = up_proj.t().contiguous()
-            self[i].down_proj.weight.data = down.t().contiguous()
diff --git a/src/llmcompressor/modeling/moe/context.py b/src/llmcompressor/modeling/moe/context.py
new file mode 100644
index 0000000000..54c8ab92b3
--- /dev/null
+++ b/src/llmcompressor/modeling/moe/context.py
@@ -0,0 +1,21 @@
+import contextlib
+
+from transformers import PreTrainedModel
+
+from . import linearize
+
+CALIBRATE_ALL_EXPERTS = False
+
+
+@contextlib.contextmanager
+def moe_calibration_context(model: PreTrainedModel, calibrate_all_experts: bool):
+    """Linearize the model's MoE experts and, for the duration of the context,
+    set the module-level CALIBRATE_ALL_EXPERTS flag read by LinearExperts."""
+    global CALIBRATE_ALL_EXPERTS
+
+    linearize.linearize_moe_model(model)
+
+    restore_value = CALIBRATE_ALL_EXPERTS
+    CALIBRATE_ALL_EXPERTS = calibrate_all_experts
+    try:
+        yield
+    finally:
+        CALIBRATE_ALL_EXPERTS = restore_value
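+
+
+# Usage sketch (hypothetical driver code; `run_calibration` stands in for the
+# caller's calibration loop):
+#
+#     with moe_calibration_context(model, calibrate_all_experts=True):
+#         run_calibration(model, dataloader)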
+ """ + weight_conversions = extract_weight_conversions_for_model(model) + if weight_conversions is None: + weight_conversions = [] + + model._weight_conversions = [ + conv + for conv in weight_conversions + if not _is_fused_experts_converter(conv) + ] + + +def _is_fused_experts_converter(converter) -> bool: + return isinstance(converter, WeightConverter) and converter.target_patterns in ( + ["mlp.experts.gate_up_proj"], + ["mlp.experts.down_proj"], + ) diff --git a/src/llmcompressor/modeling/moe/gpt_oss.py b/src/llmcompressor/modeling/moe/gpt_oss.py new file mode 100644 index 0000000000..9b523ca2a0 --- /dev/null +++ b/src/llmcompressor/modeling/moe/gpt_oss.py @@ -0,0 +1,46 @@ +class GptOssExpertMLP(ExpertMLPWithGate): + @classmethod + def from_experts( + cls, + experts: FusedExpertsModule, + expert_index: int, + moe_intermediate_size: int, + hidden_dim: int, + ): + assert experts.has_gate + if experts.__class__._apply_gate is not _default_apply_gate: + # assume that if a `_apply_gate` is implemented, then the weight is not valid for quantization (for example, might be interleaved) + raise NotImplementedError( + f"Linearization for {experts.__class__.__name__} has not been implemented yet" + ) + + with skip_weights_initialize(): + instance = cls( + hidden_dim, moe_intermediate_size, experts.has_bias, experts._apply_gate + ) + + # load weights + gate_weight = experts.gate_up_proj[expert_index, :moe_intermediate_size] + up_weight = experts.gate_up_proj[expert_index, moe_intermediate_size:] + down_weight = experts.down_proj[expert_index] + + if experts.is_transposed: + gate_weight = gate_weight.T + up_weight = up_weight.T + down_weight = down_weight.T + + instance.gate_proj.weight.copy_(gate_weight) + instance.up_proj.weight.copy_(up_weight) + instance.down_proj.weight.copy_(down_weight) + + # load biases + if experts.has_bias: + gate_bias = experts.gate_up_proj_bias[expert_index, :moe_intermediate_size] + up_bias = experts.gate_up_proj_bias[expert_index, moe_intermediate_size:] + down_bias = experts.down_proj_bias[expert_index] + + instance.gate_proj.bias.copy_(gate_bias) + instance.up_proj.bias.copy_(up_bias) + instance.down_proj.bias.copy_(down_bias) + + return instance diff --git a/src/llmcompressor/modeling/moe/helpers.py b/src/llmcompressor/modeling/moe/helpers.py new file mode 100644 index 0000000000..a329030567 --- /dev/null +++ b/src/llmcompressor/modeling/moe/helpers.py @@ -0,0 +1,77 @@ +import ast +import inspect +from abc import ABC +from typing import Callable, ClassVar, Optional + +import torch +from transformers import PreTrainedConfig +from transformers.core_model_loading import ( + WeightConverter, + WeightTransform, +) + + +class FusedExpertsModule(torch.nn.Module, ABC): + """ + Fake Typing Class + + """ + + config: PreTrainedConfig + has_gate: bool + has_bias: bool + is_transposed: bool + _apply_gate: ClassVar[Callable] + + gate_up_proj: torch.nn.Parameter + down_proj: torch.nn.Parameter + act_fn: torch.nn.Module + + up_proj: Optional[torch.nn.Parameter] # not has_gate + up_proj_bias: Optional[torch.nn.Parameter] # not has_gate, has_bias + gate_up_proj_bias: Optional[torch.nn.Parameter] # has_bias + down_proj_bias: Optional[torch.nn.Parameter] # has_bias + + +def _is_moe_experts_module(module) -> bool: + """Detect modules whose class is decorated with + ``@use_experts_implementation`` by inspecting the class source AST.""" + try: + source = inspect.getsource(type(module)) + tree = ast.parse(source) + except (OSError, TypeError): + return False + + for node in 
+        if not isinstance(node, ast.ClassDef):
+            continue
+        for decorator in node.decorator_list:
+            if isinstance(decorator, ast.Name):
+                name = decorator.id
+            elif isinstance(decorator, ast.Call) and isinstance(
+                decorator.func, ast.Name
+            ):
+                name = decorator.func.id
+            else:
+                continue
+            if name == "use_experts_implementation":
+                return True
+
+    return False
+
+
+def _is_moe_experts_converter(converter: WeightTransform) -> bool:
+    return isinstance(converter, WeightConverter) and converter.target_patterns in (
+        ".experts.gate_up_proj",
+        ".experts.down_proj",
+    )
+
+
+def _get_moe_shapes(experts: FusedExpertsModule) -> tuple[int, int, int]:
+    # TODO: infer shapes from the down_proj weight itself; that would be more
+    # reliable than reading them from the config
+    return (
+        experts.config.n_routed_experts,
+        experts.config.moe_intermediate_size,
+        experts.config.hidden_size,
+    )
diff --git a/src/llmcompressor/modeling/moe/linear_experts.py b/src/llmcompressor/modeling/moe/linear_experts.py
new file mode 100644
index 0000000000..da88ab2598
--- /dev/null
+++ b/src/llmcompressor/modeling/moe/linear_experts.py
@@ -0,0 +1,165 @@
+from abc import ABC
+from typing import Callable
+
+import torch
+from compressed_tensors.offload import get_cache_init_kwargs, offload_module
+
+from llmcompressor.utils.dev import skip_weights_initialize
+
+from .helpers import FusedExpertsModule, _get_moe_shapes
+
+
+# probably only need to register this class
+class ExpertMLP(torch.nn.Module, ABC):
+    pass
+
+
+class ExpertMLPWithGate(ExpertMLP):
+    up_proj: torch.nn.Linear
+    gate_proj: torch.nn.Linear
+    down_proj: torch.nn.Linear
+    _apply_gate: Callable[[torch.Tensor], torch.Tensor]
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        moe_intermediate_size: int,
+        has_bias: bool,
+        _apply_gate: Callable[[torch.Tensor], torch.Tensor],
+    ):
+        super().__init__()
+        self.up_proj = torch.nn.Linear(hidden_dim, moe_intermediate_size, bias=has_bias)
+        self.gate_proj = torch.nn.Linear(
+            hidden_dim, moe_intermediate_size, bias=has_bias
+        )
+        self.down_proj = torch.nn.Linear(
+            moe_intermediate_size, hidden_dim, bias=has_bias
+        )
+        self._apply_gate = _apply_gate
+
+    @classmethod
+    def from_experts(
+        cls,
+        experts: FusedExpertsModule,
+        expert_index: int,
+        moe_intermediate_size: int,
+        hidden_dim: int,
+    ):
+        assert experts.has_gate
+        # if experts.__class__._apply_gate is not _default_apply_gate:
+        #     # assume that if a `_apply_gate` is implemented, then the weight
+        #     # is not valid for quantization (for example, might be interleaved)
+        #     raise NotImplementedError(
+        #         f"Linearization for {experts.__class__.__name__} "
+        #         "has not been implemented yet"
+        #     )
+
+        with skip_weights_initialize():
+            instance = cls(
+                hidden_dim, moe_intermediate_size, experts.has_bias, experts._apply_gate
+            )
+
+        for module in instance.modules():
+            offload_module(module, **get_cache_init_kwargs(experts))
+
+        # load weights
+        gate_weight = experts.gate_up_proj[expert_index, :moe_intermediate_size]
+        up_weight = experts.gate_up_proj[expert_index, moe_intermediate_size:]
+        down_weight = experts.down_proj[expert_index]
+
+        if experts.is_transposed:
+            gate_weight = gate_weight.T
+            up_weight = up_weight.T
+            down_weight = down_weight.T
+
+        instance.gate_proj.weight.copy_(gate_weight)
+        instance.up_proj.weight.copy_(up_weight)
+        instance.down_proj.weight.copy_(down_weight)
+
+        # load biases
+        if experts.has_bias:
+            gate_bias = experts.gate_up_proj_bias[expert_index, :moe_intermediate_size]
+            up_bias = experts.gate_up_proj_bias[expert_index, moe_intermediate_size:]
+            down_bias = experts.down_proj_bias[expert_index]
+
+            instance.gate_proj.bias.copy_(gate_bias)
+            instance.up_proj.bias.copy_(up_bias)
+            instance.down_proj.bias.copy_(down_bias)
+
+        return instance
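+
+    # Hypothetical usage sketch (an expert unfused from a fused experts module
+    # `experts`, with shapes as returned by `_get_moe_shapes`):
+    #
+    #   num_experts, inter_size, hidden_dim = _get_moe_shapes(experts)
+    #   mlp = ExpertMLPWithGate.from_experts(experts, 0, inter_size, hidden_dim)
+    #   out = mlp(torch.randn(4, hidden_dim))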
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(
+            self._apply_gate(
+                torch.cat(
+                    [self.gate_proj(hidden_states), self.up_proj(hidden_states)],
+                    dim=-1,
+                )
+            )
+        )
+
+
+class ExpertMLPWithoutGate(ExpertMLP):
+    up_proj: torch.nn.Linear
+    down_proj: torch.nn.Linear
+    act_fn: torch.nn.Module
+
+    def __init__(
+        self,
+        hidden_dim: int,
+        moe_intermediate_size: int,
+        has_bias: bool,
+        act_fn: torch.nn.Module,
+    ):
+        # NOTE: callers are responsible for constructing this module under the
+        # desired dtype context; no model config is in scope here
+        super().__init__()
+
+        self.up_proj = torch.nn.Linear(hidden_dim, moe_intermediate_size, bias=has_bias)
+        self.down_proj = torch.nn.Linear(
+            moe_intermediate_size, hidden_dim, bias=has_bias
+        )
+        self.act_fn = act_fn
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act_fn(self.up_proj(hidden_states)))
+
+
+class LinearExperts(torch.nn.ModuleList):
+    @classmethod
+    def from_experts(cls, experts: FusedExpertsModule) -> "LinearExperts":
+        num_experts, moe_intermediate_size, hidden_dim = _get_moe_shapes(experts)
+
+        # TODO: add registry
+        experts_cls = ExpertMLPWithGate if experts.has_gate else ExpertMLPWithoutGate
+
+        # each expert class is expected to implement `from_experts`, which
+        # unfuses and copies that expert's weights into nn.Linear modules
+        return cls(
+            [
+                experts_cls.from_experts(
+                    experts, index, moe_intermediate_size, hidden_dim
+                )
+                for index in range(num_experts)
+            ]
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        top_k_index: torch.Tensor,
+        top_k_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        final_hidden_states = torch.zeros_like(hidden_states)
+        num_experts = len(self)
+
+        # create tokens mask
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+
+        for expert_idx in range(num_experts):
+            # select tokens for this expert
+            top_k_pos, token_indices = torch.where(expert_mask[expert_idx])
+
+            # apply expert, maybe pass all tokens to the expert
+            expert = self[expert_idx]
+            # TODO: fully integrate moe calibration context
+            if True:  # context.CALIBRATE_ALL_EXPERTS
+                expert_output = expert(hidden_states)[token_indices]
+            else:
+                expert_output = expert(hidden_states[token_indices])
+
+            # apply weighting to outputs
+            expert_weights = top_k_weights[token_indices, top_k_pos, None]
+            weighted_output = expert_output * expert_weights
+
+            # accumulate the selected tokens
+            final_hidden_states.index_add_(
+                0, token_indices, weighted_output.to(final_hidden_states.dtype)
+            )  # TODO: check why float
+
+        return final_hidden_states
diff --git a/src/llmcompressor/modeling/moe/linearize.py b/src/llmcompressor/modeling/moe/linearize.py
new file mode 100644
index 0000000000..34459006ce
--- /dev/null
+++ b/src/llmcompressor/modeling/moe/linearize.py
@@ -0,0 +1,69 @@
+import contextlib
+from abc import ABC
+from typing import Callable, Type
+
+import torch
+import torch.distributed as dist
+import tqdm
+from compressed_tensors.distributed import is_distributed
+from compressed_tensors.offload import get_cache_init_kwargs, offload_module
+from compressed_tensors.utils import patch_attr
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+from transformers.conversion_mapping import (
+    WeightConverter,
+    extract_weight_conversions_for_model,
+    register_checkpoint_conversion_mapping,
+)
+from transformers.integrations.moe import _default_apply_gate
+from transformers.modeling_utils import local_torch_dtype
+from transformers.monkey_patching import register_patch_mapping
+
+from llmcompressor.utils.dev import skip_weights_initialize
+
+from .helpers import (
+    FusedExpertsModule,
+    _get_moe_shapes,
+    _is_moe_experts_converter,
+    _is_moe_experts_module,
+)
+from .linear_experts import LinearExperts
+
+
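+# Usage sketch (the model id follows the example scripts in this PR; any
+# architecture present in `ARCH_TO_EXPERTS_MODULE_CLS` should work):
+#
+#   with load_linearized_moe():
+#       model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V4-Flash")
+#   # fused experts modules are constructed as `LinearExperts` at load time
+
+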
+# TODO: in the future, can probably match using regex
+ARCH_TO_EXPERTS_MODULE_CLS = {
+    "deepseek_v4": "DeepseekV4Experts",
+}
+
+
+def get_linear_conversion_mapping():
+    # TODO: build the forward/backward checkpoint conversion mappings for
+    # linearized experts
+    pass
+
+
+@contextlib.contextmanager
+def load_linearized_moe(model_cls: Type[PreTrainedModel] = AutoModelForCausalLM):
+    """
+    Context manager under which MoE models are loaded with their fused experts
+    modules replaced by `LinearExperts` (one `nn.Linear` per expert projection).
+    """
+    original_from_pretrained = model_cls.from_pretrained
+
+    @classmethod
+    def patched(cls, *args, **kwargs):
+        config = AutoConfig.from_pretrained(*args, **kwargs)
+        model_type = config.model_type
+
+        # `ARCH_TO_EXPERTS_MODULE_CLS` maps to class *names* (strings), so the
+        # lookup result is used directly as the patch-mapping key
+        experts_cls_name = ARCH_TO_EXPERTS_MODULE_CLS[model_type]
+        # forward_mapping, backward_mapping = get_linear_conversion_mapping(model_type)
+
+        register_patch_mapping({experts_cls_name: LinearExperts})
+        # register_checkpoint_conversion_mapping(
+        #     model_type=model_type, mapping=forward_mapping, overwrite=True
+        # )
+
+        # `original_from_pretrained` is already bound to the class
+        model: PreTrainedModel = original_from_pretrained(*args, **kwargs)
+        # model._conversion_mapping = backward_mapping
+        return model
+
+    with patch_attr(model_cls, "from_pretrained", patched):
+        yield
diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py
deleted file mode 100644
index a78dcad523..0000000000
--- a/src/llmcompressor/modeling/moe_context.py
+++ /dev/null
@@ -1,201 +0,0 @@
-"""
-Simplified interface for MoE model calibration.
-
-MoE (Mixture of Experts) models route tokens to different expert networks.
-During calibration for quantization/compression, we need to ensure ALL experts
-see data, not just the ones selected by the router. This module provides the
-infrastructure to temporarily modify MoE modules for proper calibration.
-
-Key components:
-- MoECalibrationModule: Abstract base class for calibration modules
-- moe_calibration_context: Context manager that applies calibration to a model
-"""
-
-import contextlib
-from abc import ABC
-
-import torch
-import torch.distributed as dist
-from compressed_tensors.offload import (
-    get_cache_init_kwargs,
-    is_distributed,
-)
-from compressed_tensors.offload.cache import OffloadCache
-from compressed_tensors.offload.module import offload_module
-from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
-from loguru import logger
-from tqdm import tqdm
-from transformers import PreTrainedModel
-
-__all__ = [
-    "MoECalibrationModule",
-    "moe_calibration_context",
-]
-
-
-class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
-    """
-    Abstract base class for MoE calibration modules.
-
-    Calibration modules replace original MoE modules during the calibration
-    phase to ensure all experts receive data for proper quantization statistics.
-
-    Subclasses must:
-    1. Implement `__init__()` with signature:
-       (self, original, config, calibrate_all_experts=True)
-    2. Set `is_permanent` to indicate if module should stay in calibration form
-    3. Optionally implement `restore()` if is_permanent=False
-    """
-
-    is_permanent: bool = False
-
-    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
-        """
-        Restore the original module structure.
-
-        Only needed if is_permanent=False. For permanent modules, this is a no-op.
- - Returns: - The original module (or self if permanent) - """ - if self.is_permanent: - return self - raise NotImplementedError( - f"{self.__class__.__name__} has is_permanent=False but doesn't " - "implement restore()" - ) - - -@contextlib.contextmanager -def moe_calibration_context( - model: PreTrainedModel, - calibrate_all_experts: bool = True, -): - """ - Context manager that applies MoE calibration to a model. - - This scans all modules in the model and replaces any MoE modules with their - calibration equivalents. After the context exits, non-permanent modules are - restored to their original form. - - The model is modified in-place, so the same model object should be used - within the context. - - Args: - model: The model to apply MoE calibration to (modified in-place) - calibrate_all_experts: If True, all experts see all tokens during calibration. - If False, use normal routing (useful for some techniques) - - Example: - with moe_calibration_context(model): - # Run calibration - all experts will see data - for batch in dataloader: - model(**batch) - # Model is now restored (unless permanent) - """ - - replaced = {} - - # Step 1: Collect all MoE modules that need replacement - logger.debug("Entering MoE calibration context") - modules_to_replace = [] - for name, module in model.named_modules(): - class_name = module.__class__.__name__ - if _is_registered(class_name, MoECalibrationModule): - modules_to_replace.append((name, module, class_name)) - - # Step 2: Replace modules with progress bar - if modules_to_replace: - logger.info(f"Found {len(modules_to_replace)} MoE modules to replace") - for name, module, class_name in tqdm( - modules_to_replace, desc="Replacing MoE modules for calibration" - ): - replacement = MoECalibrationModule.load_from_registry( - class_name, - original=module, - config=model.config, - calibrate_all_experts=calibrate_all_experts, - ) - # Apply the same offloading settings as the original module - _apply_offloading_to_replacement(module, replacement) - - model.set_submodule(name, replacement) - - # Only store original if we need to restore it later - if replacement.is_permanent: - replaced[name] = (None, replacement) - del module - else: - replaced[name] = (module, replacement) - - if is_distributed(): - dist.barrier() - - # Log what was replaced - if replaced: - logger.info(f"Replaced {len(replaced)} MoE modules for calibration") - permanent_count = sum( - 1 for _, (_, repl) in replaced.items() if repl.is_permanent - ) - if permanent_count > 0: - logger.info( - f"{permanent_count}/{len(replaced)} modules will remain in " - "calibration form (permanent)" - ) - if permanent_count < len(replaced): - logger.info( - f"{len(replaced) - permanent_count}/{len(replaced)} modules will " - "be restored after calibration" - ) - - try: - yield - finally: - # Step 2: Restore non-permanent modules - for name, (original, replacement) in replaced.items(): - if not replacement.is_permanent: - restored = replacement.restore(original) - model.set_submodule(name, restored) - - -def _is_registered(name: str, subclass: RegistryMixin): - return standardize_lookup_name(name) in subclass.registered_names() - - -def _find_ancestor_with_offload_cache(module): - if isinstance(module._parameters, OffloadCache): - return module - - for child in module.children(): - child_val = _find_ancestor_with_offload_cache(child) - if child_val is not None: - return child_val - return None - - -def _apply_offloading_to_replacement( - original: torch.nn.Module, replacement: torch.nn.Module -): - """ - 
Apply the same offloading configuration from original to replacement module. - - If the original module or ANY of its children use OffloadCache, this recursively - applies the same offloading settings to all submodules of the replacement that - have parameters. - """ - - module_with_cache = _find_ancestor_with_offload_cache(original) - if module_with_cache is None: - return - - kwargs = get_cache_init_kwargs(module_with_cache) - - # Apply offloading to all submodules that have parameters - # and are not already offloaded - for module in replacement.modules(): - if isinstance(module._parameters, OffloadCache): - continue - if len(list(module.parameters(recurse=False))) == 0: - continue - - offload_module(module, **kwargs) diff --git a/src/llmcompressor/modeling/qwen3_5_moe.py b/src/llmcompressor/modeling/qwen3_5_moe.py deleted file mode 100644 index adbd8fbd01..0000000000 --- a/src/llmcompressor/modeling/qwen3_5_moe.py +++ /dev/null @@ -1,157 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch -import torch.nn.functional as F - -from llmcompressor.modeling.moe_context import MoECalibrationModule -from llmcompressor.utils.dev import skip_weights_initialize - -if TYPE_CHECKING: - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( - Qwen3_5MoeSparseMoeBlock, - ) - - -@MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock") -class CalibrationQwen3_5MoeSparseMoeBlock(MoECalibrationModule): - """ - Calibration version of Qwen3_5MoeSparseMoeBlock that unfuses 3D expert - parameters into individual MLP modules (nn.Linear) so they can be - individually quantized. Sends all tokens to all experts during calibration. - - is_permanent = True because the unfused structure must persist for - quantization to target the individual nn.Linear expert weights. 
- """ - - is_permanent = True - - def __init__( - self, - original: Qwen3_5MoeSparseMoeBlock, - config, - calibrate_all_experts: bool = True, - ): - super().__init__() - text_config = getattr(config, "text_config", config) - - self.calibrate_all_experts = calibrate_all_experts - - # Use plain Linear for gate so module_type() returns "Linear" - # This ensures gates appear in the ignore list when config is saved - original_weight = original.gate.weight.data - self.gate = torch.nn.Linear( - text_config.hidden_size, text_config.num_experts, bias=False - ) - self.gate.weight.data = self.gate.weight.data.to(original_weight.dtype) - self.gate.weight.data.copy_(original_weight) - - # Store routing parameters needed for forward pass - self.top_k = text_config.num_experts_per_tok - self.num_experts = text_config.num_experts - self.hidden_dim = text_config.hidden_size - self.hidden_size = text_config.hidden_size - - self.shared_expert = original.shared_expert - self.shared_expert_gate = original.shared_expert_gate - self.experts = SequentialQwen3_5MoeExperts(text_config, original.experts) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - - # Perform routing (previously in Qwen3VLMoeTextTopKRouter.forward) - router_logits = F.linear(hidden_states_reshaped, self.gate.weight) - router_logits = F.softmax(router_logits, dtype=torch.float, dim=-1) - routing_weights, selected_experts = torch.topk( - router_logits, self.top_k, dim=-1 - ) - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - - # expert mask: (num_experts, top_k, num_tokens) - expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute( - 2, 1, 0 - ) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - for expert_idx, expert_layer in enumerate(self.experts): - idx, token_idx = torch.where(expert_mask[expert_idx]) - - if self.calibrate_all_experts: - expert_out = expert_layer(hidden_states_reshaped)[token_idx] - else: - expert_out = expert_layer(hidden_states_reshaped[token_idx]) - - if len(token_idx) > 0: - current_hidden_states = ( - expert_out * routing_weights[token_idx, idx, None] - ) - final_hidden_states.index_add_( - 0, - token_idx, - current_hidden_states.to(hidden_states.dtype), - ) - - # shared expert - shared_expert_output = self.shared_expert(hidden_states_reshaped) - shared_expert_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) - * shared_expert_output - ) - final_hidden_states = final_hidden_states + shared_expert_output - - final_hidden_states = final_hidden_states.reshape( - batch_size, sequence_length, hidden_dim - ) - return final_hidden_states - - def restore(self, original: torch.nn.Module) -> torch.nn.Module: - return self - - -class SequentialQwen3_5MoeExperts(torch.nn.ModuleList): - """ - Unfuses 3D expert parameter tensors into individual Qwen3_5MoeMLP modules - so that each expert's weights are nn.Linear and can be targeted by - quantization with targets="Linear". 
- """ - - def __init__(self, config, original): - from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( - Qwen3_5MoeMLP, - ) - - self.num_experts = config.num_experts - intermediate_size = config.moe_intermediate_size - - with skip_weights_initialize(): - super().__init__( - [ - Qwen3_5MoeMLP(config, intermediate_size=intermediate_size) - for _ in range(self.num_experts) - ] - ) - - gate_up_data = original.gate_up_proj.data # [num_experts, 2*inter, hidden] - down_data = original.down_proj.data # [num_experts, hidden, inter] - - for i in range(self.num_experts): - gate_up = gate_up_data[i] # [2*intermediate, hidden] - down = down_data[i] # [hidden, intermediate] - - # gate_up_proj stores [gate; up] stacked along dim 0 - # nn.Linear weight is [out_features, in_features] - self[i].gate_proj.weight.data = ( - gate_up[:intermediate_size, :].clone().contiguous() - ) - self[i].up_proj.weight.data = ( - gate_up[intermediate_size:, :].clone().contiguous() - ) - self[i].down_proj.weight.data = down.clone().contiguous() diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py deleted file mode 100644 index 890ac32c98..0000000000 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ /dev/null @@ -1,99 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from transformers.models import Qwen3MoeConfig -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, -) - -from llmcompressor.modeling.moe_context import MoECalibrationModule - - -@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock") -class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule): - """ - Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts. - """ - - is_permanent = False - - def __init__( - self, - original: OriginalQwen3MoeSparseMoeBlock, - config: Qwen3MoeConfig, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - self.calibrate_all_experts = calibrate_all_experts - self.gate = original.gate - self.experts = original.experts - - def forward(self, hidden_states: torch.Tensor): - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = torch.nn.functional.softmax( - router_logits, dim=1, dtype=torch.float - ) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - if self.norm_topk_prob: # only diff with mixtral sparse moe block! 
- routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_experts - ).permute(2, 1, 0) - - for expert_idx, expert_layer in enumerate(self.experts): - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - if self.calibrate_all_experts: - expert_out = expert_layer(hidden_states)[top_x] - else: - expert_out = expert_layer(hidden_states[top_x]) - - # TODO: double check - if len(top_x) > 0: - current_hidden_states = expert_out * routing_weights[top_x, idx, None] - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) - - final_hidden_states = final_hidden_states.reshape( - batch_size, sequence_length, hidden_dim - ) - return final_hidden_states, router_logits - - def restore(self, original: torch.nn.Module) -> torch.nn.Module: - return original diff --git a/src/llmcompressor/modeling/qwen3_next_moe.py b/src/llmcompressor/modeling/qwen3_next_moe.py deleted file mode 100644 index d74141c012..0000000000 --- a/src/llmcompressor/modeling/qwen3_next_moe.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -# coding=utf-8 -# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING - -import torch - -from llmcompressor.modeling.moe_context import MoECalibrationModule - -if TYPE_CHECKING: - from transformers import Qwen3NextConfig - from transformers.models.qwen3_next.modeling_qwen3_next import ( - Qwen3NextSparseMoeBlock, - ) - - -@MoECalibrationModule.register("Qwen3NextSparseMoeBlock") -class CalibrationQwen3NextSparseMoeBlock(MoECalibrationModule): - """ - Calibration version of Qwen3NextSparseMoeBlock that sends all tokens to all experts. 
- """ - - is_permanent = False - - def __init__( - self, - original: Qwen3NextSparseMoeBlock, - config: Qwen3NextConfig, - calibrate_all_experts: bool = True, - ): - super().__init__() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - # gating - self.calibrate_all_experts = calibrate_all_experts - self.gate = original.gate - self.experts = original.experts - - self.shared_expert = original.shared_expert - self.shared_expert_gate = original.shared_expert_gate - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = torch.nn.functional.softmax( - router_logits, dim=-1, dtype=torch.float - ) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be - # sollicitated - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_experts - ).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the - # computation on each expert - # expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx, expert_layer in enumerate(self.experts): - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - if self.calibrate_all_experts: - expert_out = expert_layer(hidden_states)[top_x] - else: - expert_out = expert_layer(hidden_states[top_x]) - - # Index the correct hidden states and compute the expert hidden - # state for the current expert. 
We need to make sure to multiply - # the output hidden states by `routing_weights` on the - # corresponding tokens (top-1 and top-2) - if len(top_x) > 0: - current_hidden_states = expert_out * routing_weights[top_x, idx, None] - final_hidden_states.index_add_( - 0, - top_x, - current_hidden_states.to(hidden_states.dtype), - ) - - shared_expert_output = self.shared_expert(hidden_states) - shared_expert_output = ( - torch.nn.functional.sigmoid(self.shared_expert_gate(hidden_states)) - * shared_expert_output - ) - - final_hidden_states = final_hidden_states + shared_expert_output - final_hidden_states = final_hidden_states.reshape( - batch_size, sequence_length, hidden_dim - ) - return final_hidden_states, router_logits - - def restore(self, original: torch.nn.Module) -> torch.nn.Module: - return original diff --git a/src/llmcompressor/modeling/qwen3_vl_moe.py b/src/llmcompressor/modeling/qwen3_vl_moe.py deleted file mode 100644 index da941a3b86..0000000000 --- a/src/llmcompressor/modeling/qwen3_vl_moe.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -from llmcompressor.modeling.moe_context import MoECalibrationModule -from llmcompressor.utils.dev import skip_weights_initialize - -if TYPE_CHECKING: - from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( - Qwen3VLMoeTextSparseMoeBlock, - ) - - -@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock") -class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule): - """ - Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all - experts. - """ - - is_permanent = True - - def __init__( - self, - original: "Qwen3VLMoeTextSparseMoeBlock", - config: "Qwen3VLMoeConfig", - calibrate_all_experts: bool, - ): - super().__init__() - text_config: "Qwen3VLMoeTextConfig" = config.get_text_config() - - self.hidden_size = text_config.hidden_size - self.num_experts = text_config.num_experts - self.top_k = original.top_k - # Note: gate was changed to be a Linear layer in transformers==4.57.0 - # https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee - self.gate = original.gate - self.calibrate_all_experts = calibrate_all_experts - self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_dim) - - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - routing_weights = torch.nn.functional.softmax( - router_logits, dim=1, dtype=torch.float - ) - # get topk experts per token - # routing_weight: (num_tokens, top_k) - # routing_indices: (num_tokens, top_k) - routing_weights, router_indices = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(hidden_states.dtype) - - next_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # convert router indices into OHE list - # reshape to be (num_experts, top_k, batch_size * sequence_length) - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ).permute(2, 1, 0) - - for expert_idx, expert_layer in enumerate(self.experts): - idx, token_idx = 
torch.where(expert_mask[expert_idx].squeeze(0)) - - if self.calibrate_all_experts: - expert_out = expert_layer(hidden_states)[token_idx] - else: - expert_out = expert_layer(hidden_states[token_idx]) - - if len(token_idx) > 0: - # if there are tokens meant for this expert, further scale the expert - # output by the score - weighted_output = expert_out * routing_weights[token_idx, idx, None] - next_states.index_add_( - 0, token_idx, weighted_output.to(hidden_states.dtype) - ) - - next_states = next_states.reshape(batch_size, sequence_length, hidden_dim) - return next_states, router_logits - - def restore(self, original: torch.nn.Module) -> torch.nn.Module: - return original - - -class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): - def __init__(self, config, original): - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( - Qwen3VLMoeTextMLP, - ) - - self.num_experts = original.gate_up_proj.shape[0] - with skip_weights_initialize(): - super().__init__( - [Qwen3VLMoeTextMLP(config) for _ in range(self.num_experts)] - ) - - intermediate_size = original.down_proj.shape[1] - - for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] - - gate_proj = gate_up[:, :intermediate_size] - up_proj = gate_up[:, intermediate_size:] - - self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous() - self[i].up_proj.weight.data = up_proj.t().clone().contiguous() - self[i].down_proj.weight.data = down.t().clone().contiguous() diff --git a/src/llmcompressor/modifiers/gptq/base.py b/src/llmcompressor/modifiers/gptq/base.py index 489e2f0c84..97a380d262 100644 --- a/src/llmcompressor/modifiers/gptq/base.py +++ b/src/llmcompressor/modifiers/gptq/base.py @@ -318,7 +318,7 @@ def compress_module_list(self, module_list): comp_logger.set_results(name="GPTQ", loss=loss) for attr, val in q_param_dict.items(): - update_offload_parameter(module, attr, val) + update_offload_parameter(module, attr, val, source_rank=dist.get_rank()) def _reduce_hessian_to_target_rank(self, module_list, module_to_rank): rank = dist.get_rank() diff --git a/src/llmcompressor/modifiers/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/gptq/gptq_quantize.py index 5e3f721589..12e4bcd53f 100644 --- a/src/llmcompressor/modifiers/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/gptq/gptq_quantize.py @@ -23,6 +23,9 @@ def make_empty_hessian( module: torch.nn.Module, device: torch.device | None = None ) -> torch.Tensor: + if not isinstance(module, torch.nn.Linear): + raise ValueError(f"Cannot quantize layer type `{module.__class__.__name__}`") + weight = module.weight num_columns = weight.shape[1] device = device if device is not None else weight.device @@ -36,29 +39,18 @@ def accumulate_hessian( num_samples: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: inp = inp.to(device=H.device) - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - num_added = inp.shape[0] - - match module: - case torch.nn.Linear() | transformers.Conv1D(): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - case torch.nn.Conv2d(): - unfold = torch.nn.Unfold( - module.kernel_size, - dilation=module.dilation, - padding=module.padding, - stride=module.stride, - ) - inp = unfold(inp) - inp = inp.permute([1, 0, 2]) - inp = inp.flatten(1) + # ensure batch and sequence dims are populated + while inp.ndim < 3: + inp = inp.unsqueeze(0) + # count samples from batch length (not sequence length) + num_added = sum(inp.shape[:-2]) num_samples += num_added + inp = 
inp.flatten(0, -2) + inp = inp.t() + inp = inp.to(dtype=GPTQ_PRECISION) inp = math.sqrt(2) * inp H += inp.matmul(inp.t()) diff --git a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py index f327a4c34d..b0c05fdaf0 100644 --- a/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py +++ b/src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py @@ -24,18 +24,16 @@ def accumulate_hessian( num_samples: int, ) -> Tuple[torch.Tensor, int]: inp = inp.to(device=H.device) - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - - num_added = inp.shape[0] # note this is the number of dataset samples, not - # multiplied by the sequence length if isinstance(module, (torch.nn.Linear, transformers.Conv1D)): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) + num_added = inp[..., 0].numel() + inp = inp.reshape(-1, inp.shape[-1]) inp = inp.t() if isinstance(module, torch.nn.Conv2d): + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + num_added = inp.shape[0] unfold = torch.nn.Unfold( module.kernel_size, dilation=module.dilation, diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index bbae001c5a..4e2dbfcb2e 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -199,10 +199,17 @@ def __init__(self, ancestors: set[Module], offloaded: set[Module]): def create_arg(self, a: Any) -> Argument: # special extension allows models which depend on config values to be traced + # whenever the config instance is referenced, create a new config on the fly + # with the same class and kwargs as the original if isinstance(a, PretrainedConfig): kwargs = {k: self.create_arg(v) for k, v in a.to_dict().items()} return self.create_node("call_function", a.__class__, (), kwargs) + # special extension allows models which pass kv caches to be traced + # usually use of `past_key_values` is disabled by `calibration_forward_context` + # however, some models such as deepseekv4 require their use + # if isinstance(a, ) + else: return super().create_arg(a) diff --git a/src/llmcompressor/pipelines/sequential/transformers_helpers.py b/src/llmcompressor/pipelines/sequential/transformers_helpers.py index 2c4f4f4a10..f8ddc2cafa 100644 --- a/src/llmcompressor/pipelines/sequential/transformers_helpers.py +++ b/src/llmcompressor/pipelines/sequential/transformers_helpers.py @@ -31,6 +31,7 @@ import torch.utils._pytree as pytree from torch import nn from torch.fx import Graph, GraphModule, Node, Proxy, Tracer +from torch.fx.node import Argument from torch.fx._compatibility import compatibility from torch.fx._symbolic_trace import is_fx_tracing from torch.fx.proxy import ParameterProxy @@ -804,9 +805,17 @@ def check_proxy(a): return wrapper -class HFProxyableClassMeta(type): +class HFProxyableCacheMeta(type): """ Metaclass that creates a class with its main methods wrapped to be proxyable. + + In the same way that objects can be replaced with `Proxy`s during trace-time, + classes can also be replaced with `ProxyableClass`es during trace-time. + This meta class acts as factory for creating `ProxyableClass`es of `Cache` classes. + + At trace-time, all references to `Cache` classes are monkeypatched with references + to `ProxyableCache` classes. Whenever this class is used to construct a new instance + of a `Cache`, an instance of `HFCacheProxy` is generated instead. 
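+
+    For example, ``ProxyableDynamicCache`` below is created as
+    ``HFProxyableCacheMeta("ProxyableDynamicCache", (DynamicCache,), {},
+    proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache))``; calling
+    ``DynamicCache()`` while tracing then yields an ``HFCacheProxy`` node
+    instead of a real cache instance.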
""" def __new__( @@ -816,19 +825,34 @@ def __new__( attrs: dict[str, Any], proxy_factory_fn: Callable[[Node], Proxy] | None = None, ): - instance = super().__new__(cls, name, bases, attrs) + if len(bases) != 1: + raise ValueError( + "`HFProxyableCacheMeta` only supports creating proxies" + " with single class inheritance. Compose your classes directly " + "before creating the class with this meta" + ) + + instance = super().__new__(cls, name, bases, attrs) # "instance" of Type[Cache] for attr_name in dir(instance): attr = getattr(instance, attr_name, None) if attr is None: continue + if attr_name == "__new__": + setattr( + instance, + attr_name, + cls.create__new__wrapper(bases[0], proxy_factory_fn), + ) + continue + if attr_name == "__init__": op_type = "call_function" elif attr_name.startswith("__"): op_type = None elif inspect.ismethod(attr): - op_type = "call_function" - elif inspect.isfunction(attr): op_type = "call_method" + elif inspect.isfunction(attr): + op_type = "call_function" else: op_type = None if op_type is not None: @@ -839,6 +863,53 @@ def __new__( ) return instance + def create__new__wrapper( + orig_cache_cls: type[Cache], + proxy_factory_fn: Callable[[Node], Proxy], + ): + """ + Mirrors `create_wrapper`, but only used to override the `__new__` method. + Whenever this class is used to construct a new instance of a `Cache`, an + instance of `HFCacheProxy` is generated instead. + + The `HFCacheProxy` class allows caches to be traced through the fx graph. + + :param orig_cache_cls: `Cache` class being proxied + :param proxy_factory_fn: function which converts an instance of `Cache` to + an instance of `HFCacheProxy` + :return: wrapper function used to replace the `__new__` method of + `HFProxyableClass` + """ + + def wrapper(*args, **kwargs): + if not isinstance(_CURRENT_TRACER, HFTracer): + raise RuntimeError( + "Cannot create HFCacheProxy because " + "there is no HFTracer currently tracing." + ) + + # unfortunately, there is no way easy way to call just `__new__` + # without also calling `__init__`. Calling as a method means that the `type` + # will end up in the fx graph, and fx graphs cannot encode `type`s. Calling + # as `orig_cache_cls.__init__` yields an error when `_find_module_of_method` + # attempts to find the original module to create a qualified name for graph + + # kind: tell node to call target + # target: class constructor for original cache class + # args: positional arguments to class constructor + # kwargs: keyword arguments to class constructor + # proxy_factory_fn: converts instance of `orig_cache_cls` + # into an instance of `HFCacheProxy` + return _CURRENT_TRACER.create_proxy( + kind="call_function", + target=orig_cache_cls, + args=args[1:], + kwargs=kwargs, # typically {"config": PreTrainedConfig(...)} + proxy_factory_fn=proxy_factory_fn, + ) + + return wrapper + def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]: """ @@ -875,19 +946,19 @@ def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. 
-ProxyableCache = HFProxyableClassMeta( +ProxyableCache = HFProxyableCacheMeta( "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache), ) -ProxyableDynamicCache = HFProxyableClassMeta( +ProxyableDynamicCache = HFProxyableCacheMeta( "ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache), ) -ProxyableStaticCache = HFProxyableClassMeta( +ProxyableStaticCache = HFProxyableCacheMeta( "ProxyableStaticCache", (StaticCache,), {}, @@ -1325,17 +1396,21 @@ def patch_for_tracing(self, root: torch.nn.Module | Callable[..., Any]): # Patching classes patched = [] - module_of_model = inspect.getmodule(root) - for name, mod in sys.modules.items(): - if module_of_model is not None and mod is not module_of_model: - continue - if not name.startswith("transformers"): - continue + if isinstance(root, torch.nn.Module): + forward_fns = set(module.forward for module in root.modules()) + else: + forward_fns = set([root]) + logger.warning( + "Cannot patch all classes for tracing, " + "please pass model to HFTracer.trace()" + ) + + for forward_fn in forward_fns: for orig_cls, patched_cls in self._CLASSES_TO_PATCH.items(): - for attr_name, attr in mod.__dict__.items(): + for attr_name, attr in forward_fn.__globals__.items(): if attr is orig_cls: - patched.append((mod, attr_name, orig_cls)) - setattr(mod, attr_name, patched_cls) + patched.append((forward_fn.__globals__, attr_name, orig_cls)) + forward_fn.__globals__[attr_name] = patched_cls yield @@ -1345,8 +1420,8 @@ def patch_for_tracing(self, root: torch.nn.Module | Callable[..., Any]): self.patched_torch_methods = {} self.orig_fns = set() - for mod, attr_name, orig_cls in patched: - setattr(mod, attr_name, orig_cls) + for forward_fn_globals, attr_name, orig_cls in patched: + forward_fn_globals[attr_name] = orig_cls def trace( self, diff --git a/src/llmcompressor/utils/dev.py b/src/llmcompressor/utils/dev.py index c948e9c3bf..3291d261e2 100644 --- a/src/llmcompressor/utils/dev.py +++ b/src/llmcompressor/utils/dev.py @@ -31,7 +31,7 @@ @contextlib.contextmanager -def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM): +def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM, skip_init: bool = True): """ Context manager under which models are initialized without having to download the model weight files. 
 This differs from `init_empty_weights` in that weights are
@@ -39,6 +39,8 @@ def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausa
     device
 
     :param model_class: class to patch, defaults to `AutoModelForCausalLM`
+    :param skip_init: skip weight initialization, which can be costly when
+        allocating a large number of weights
     """
     original_fn = model_class.from_pretrained
     weights_files = [
@@ -80,7 +82,7 @@ def patched(cls, *args, **kwargs):
     with (
         tempfile.TemporaryDirectory() as tmp_dir,
         patch_attr(model_class, "from_pretrained", patched),
-        skip_weights_initialize(),
+        skip_weights_initialize() if skip_init else contextlib.nullcontext(),
         patch_transformers_logger_level(),
     ):
         yield
diff --git a/tests/llmcompressor/modeling/test_linearize.py b/tests/llmcompressor/modeling/test_linearize.py
new file mode 100644
index 0000000000..45924d8d2c
--- /dev/null
+++ b/tests/llmcompressor/modeling/test_linearize.py
@@ -0,0 +1,35 @@
+from transformers import AutoModelForCausalLM
+
+from llmcompressor.modeling.moe.linearize import load_linearized_moe
+from llmcompressor.utils.dev import skip_weights_download
+
+# def test_linearize_moe_model(tmp_dir):
+#     input_ids = ...  # random 2048 tokens
+
+#     with skip_weights_download(skip_init=False):
+#         model = AutoModelForCausalLM.from_pretrained(
+#             "deepseek-ai/DeepSeek-V4-Flash",
+#             num_hidden_layers=3,
+#             num_hash_layers=1,
+#             num_nextn_predict_layers=1,
+#         )
+
+#     true_outputs = model(input_ids)
+#     model.save_pretrained(tmp_dir)
+#     del model
+
+#     with load_linearized_moe():
+#         linearized_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
+
+#     # check calibrate_all_experts=True
+#     with ...:
+#         linearized_outputs = linearized_model(input_ids)
+#     assert linearized_outputs == true_outputs
+
+#     # check calibrate_all_experts=False
+#     with ...:
+#         linearized_outputs = linearized_model(input_ids)
+#     assert linearized_outputs == true_outputs
+
+#     linearized_model.save_pretrained(tmp_dir)
+#     # check checkpoint keys for experts.N.up_proj structure
+
+
+def test_load_truncated_model_skeleton():
+    # smoke test: build a truncated DeepSeek-V4-Flash without downloading weights
+    with skip_weights_download(skip_init=False):
+        model = AutoModelForCausalLM.from_pretrained(
+            "deepseek-ai/DeepSeek-V4-Flash",
+            num_hidden_layers=3,
+            num_hash_layers=1,
+            num_nextn_predict_layers=1,
+        )
+
+    assert model is not None
diff --git a/tests/llmcompressor/pipelines/sequential/test_cache_tracing.py b/tests/llmcompressor/pipelines/sequential/test_cache_tracing.py
new file mode 100644
index 0000000000..edc90c32e5
--- /dev/null
+++ b/tests/llmcompressor/pipelines/sequential/test_cache_tracing.py
@@ -0,0 +1,83 @@
+import torch
+from transformers.cache_utils import DynamicCache
+
+from llmcompressor.pipelines.sequential.transformers_helpers import HFTracer
+
+
+class DummyModelWithCache(torch.nn.Module):
+    """
+    Generates the following table after tracing and `graph.print_tabular()`:
+
+    opcode         name              target          args
+    -------------  ----------------  --------------  ------------------------------------
+    placeholder    x                 x               ()
+    call_function  dynamic_cache                     ()
+    call_module    linear            linear          (x,)
+    call_module    linear_1          linear          (x,)
+    call_method    update            update          (dynamic_cache, linear, linear_1, 0)
+    call_function  getitem                           (update, 0)
+    call_function  getitem_1                         (update, 1)
+    call_method    get_seq_length    get_seq_length  (dynamic_cache, 0)
+    call_function  add                               (getitem, get_seq_length)
+    call_function  getattr_1                         (dynamic_cache, 'layers')
+    call_function  getitem_2                         (getattr_1, 0)
+    call_method    get_seq_length_1  get_seq_length  (getitem_2,)
+    call_function  add_1                             (getitem_1, get_seq_length_1)
+    output         output            output          ((add, add_1, dynamic_cache),)
+    """  # noqa: E501
+
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 10)
+
+    def forward(self, x):
+        cache = DynamicCache()
+        key_states = self.linear(x)
+        value_states = self.linear(x)
+
+        if torch.fx._symbolic_trace.is_fx_tracing():
+            # `DynamicCache` has been monkeypatched to ProxyableDynamicCache
+            # but still betrays itself via the `__name__` attribute
+            assert DynamicCache.__name__ == "ProxyableDynamicCache"
+
+            # `cache` has its `__class__` attribute overwritten to look like
+            # `DynamicCache`, but betrays itself by implementing the below method
+            assert hasattr(cache, "install_orig_cache_cls")
+
+        return self._dummy_cache_ops(key_states, value_states, cache)
+
+    def _dummy_cache_ops(self, key_states, value_states, cache: DynamicCache):
+        key_states, value_states = cache.update(key_states, value_states, 0)
+        key_states += cache.get_seq_length(0)
+        value_states += cache.layers[0].get_seq_length()
+
+        return key_states, value_states, cache
+
+
+def test_dynamic_cache_produces_hf_cache_proxy_node():
+    model = DummyModelWithCache()
+    tracer = HFTracer()
+    graph = tracer.trace(model, dummy_inputs={"x": torch.randn(1, 10)})
+
+    nodes = {node.name: node for node in graph.nodes}
+
+    # DynamicCache is traced as a call_function node
+    assert "dynamic_cache" in nodes
+    cache_node = nodes["dynamic_cache"]
+    assert cache_node.op == "call_function"
+    assert cache_node.target is DynamicCache
+
+    # cache.update() is traced as a call_method on the cache proxy
+    assert "update" in nodes
+    update_node = nodes["update"]
+    assert update_node.op == "call_method"
+    assert update_node.args[0] is cache_node
+
+    # cache.get_seq_length() is traced as a call_method on the cache proxy
+    assert "get_seq_length" in nodes
+    assert nodes["get_seq_length"].op == "call_method"
+    assert nodes["get_seq_length"].args[0] is cache_node
+
+    # the cache proxy is passed through to the output
+    output_node = nodes["output"]
+    assert cache_node in output_node.args[0]
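+
+
+def test_traced_cache_graph_lints():
+    # minimal sketch of an extra sanity check: `torch.fx.Graph.lint` validates
+    # node ordering and argument references of the traced graph
+    model = DummyModelWithCache()
+    tracer = HFTracer()
+    graph = tracer.trace(model, dummy_inputs={"x": torch.randn(1, 10)})
+    graph.lint()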