diff --git a/examples/quantization_w4a8/gpt_oss_20b.py b/examples/quantization_w4a8/gpt_oss_20b.py new file mode 100644 index 0000000000..fd7998c7e2 --- /dev/null +++ b/examples/quantization_w4a8/gpt_oss_20b.py @@ -0,0 +1,322 @@ +""" +GPT-OSS Model Quantization Example + +This script demonstrates quantizing GPT-OSS models using various quantization +algorithms: W4A8, AWQ, and GPTQ. + +Usage: + # Basic W4A8 quantization + python gpt_oss_20b.py --algorithm w4a8 + + # AWQ quantization + python gpt_oss_20b.py --algorithm awq + + # GPTQ quantization + python gpt_oss_20b.py --algorithm gptq + + # Custom options + python gpt_oss_20b.py \ + --algorithm gptq \ + --model openai/gpt-oss-20b \ + --num-samples 512 \ + --max-seq-length 2048 \ + --output-dir my-quantized-model +""" + +import argparse +from enum import Enum + +import torch +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + QuantizationType, +) +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modeling.gpt_oss import ( + convert_model_for_quantization_gptoss, +) +from llmcompressor.modifiers.awq import AWQModifier +from llmcompressor.modifiers.quantization import ( + GPTQModifier, + QuantizationModifier, +) +from llmcompressor.utils import dispatch_for_generation + + +class QuantizationAlgorithm(str, Enum): + """Supported quantization algorithms for GPT-OSS.""" + + W4A8 = "w4a8" + AWQ = "awq" + GPTQ = "gptq" + + +def create_recipe(algorithm): + """Create quantization recipe based on algorithm.""" + + # Shared weights configuration for all algorithms + weights_args = QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.CHANNEL, + symmetric=True, + dynamic=False, + ) + + if algorithm == QuantizationAlgorithm.W4A8: + # W4A8 is unique - includes 8-bit activation quantization + activations_args = QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.TOKEN, + symmetric=False, + dynamic=True, + observer=None, + ) + + scheme = QuantizationScheme( + targets=["Linear"], + weights=weights_args, + input_activations=activations_args, + ) + + return QuantizationModifier( + config_groups={"group_0": scheme}, + ignore=["lm_head"], + ) + + # AWQ and GPTQ share the same config_groups pattern + config_groups = { + "group_0": { + "targets": ["Linear"], + "weights": weights_args, + } + } + + if algorithm == QuantizationAlgorithm.AWQ: + return AWQModifier( + targets=["Linear"], + ignore=["lm_head", "re:.*router$"], + config_groups=config_groups, + ) + + elif algorithm == QuantizationAlgorithm.GPTQ: + return GPTQModifier( + targets=["Linear"], + ignore=["lm_head", "re:.*router$"], + config_groups=config_groups, + ) + + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Quantize GPT-OSS models with various algorithms" + ) + parser.add_argument( + "--algorithm", + type=QuantizationAlgorithm, + choices=list(QuantizationAlgorithm), + default=QuantizationAlgorithm.W4A8, + help="Quantization algorithm to use (default: w4a8)", + ) + parser.add_argument( + "--model", + type=str, + default="openai/gpt-oss-20b", + help="Model ID from HuggingFace Hub (default: openai/gpt-oss-20b)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory (default: {model_name}-{algorithm})", + ) + parser.add_argument( + "--num-samples", + type=int, + default=256, + help="Number of calibration samples (default: 256)", + ) + parser.add_argument( + "--max-seq-length", + type=int, + default=2048, + help="Maximum sequence length (default: 2048)", + ) + parser.add_argument( + "--dataset", + type=str, + default="HuggingFaceH4/ultrachat_200k", + help="Calibration dataset ID (default: HuggingFaceH4/ultrachat_200k)", + ) + parser.add_argument( + "--dataset-split", + type=str, + default="train_sft", + help="Dataset split to use (default: train_sft)", + ) + parser.add_argument( + "--no-calibrate-all-experts", + action="store_true", + help="Disable calibrate_all_experts mode (not recommended)", + ) + parser.add_argument( + "--skip-generation-test", + action="store_true", + help="Skip generation test after quantization", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Use sensible defaults if not provided + num_samples = args.num_samples + max_seq_length = args.max_seq_length + + # Set output directory + base_name = args.model.rstrip("/").split("/")[-1] + output_dir = args.output_dir or f"{base_name}-{args.algorithm.value}" + + print("=" * 70) + print(f"GPT-OSS {args.algorithm.value.upper()} Quantization") + print("=" * 70) + print(f"Model: {args.model}") + print(f"Algorithm: {args.algorithm.value.upper()}") + print(f"Calibration samples: {num_samples}") + print(f"Max sequence length: {max_seq_length}") + print(f"Output directory: {output_dir}") + print( + f"Calibrate all experts: {not args.no_calibrate_all_experts} (recommended)" + ) + print("=" * 70) + + print(f"\n[1/6] Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained( + args.model, trust_remote_code=True + ) + print("Model loaded successfully") + + print("\n[2/6] Converting MoE experts for quantization...") + print( + " This linearizes fused expert weights into separate projections" + ) + convert_model_for_quantization_gptoss( + model, calibrate_all_experts=not args.no_calibrate_all_experts + ) + print("Conversion completed") + + print(f"\n[3/6] Loading calibration dataset: {args.dataset}") + ds = load_dataset( + args.dataset, split=f"{args.dataset_split}[:{num_samples}]" + ) + ds = ds.shuffle(seed=42) + + def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + ds = ds.map(preprocess) + + # Tokenize for GPTQ (required for GPTQ, optional for others) + if args.algorithm == QuantizationAlgorithm.GPTQ: + + def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=max_seq_length, + truncation=True, + add_special_tokens=False, + ) + + ds = ds.map(tokenize, remove_columns=ds.column_names) + + print(f"Loaded {len(ds)} calibration samples") + + algo_name = args.algorithm.value.upper() + print(f"\n[4/6] Creating {algo_name} quantization recipe...") + recipe = create_recipe(args.algorithm) + print("Recipe created") + + print(f"\n[5/6] Running {algo_name} quantization...") + print(" This will calibrate all experts for optimal quantization") + if args.algorithm == QuantizationAlgorithm.GPTQ: + print( + " GPTQ uses layer-wise reconstruction (this may take a while)" + ) + elif args.algorithm == QuantizationAlgorithm.AWQ: + print(" AWQ analyzes activation patterns for optimal scales") + + # GPTQ requires pre-tokenized dataset, so we pass None for tokenizer + use_tokenizer = ( + None if args.algorithm == QuantizationAlgorithm.GPTQ else tokenizer + ) + + oneshot( + model=model, + dataset=ds, + recipe=recipe, + tokenizer=use_tokenizer, + max_seq_length=max_seq_length, + num_calibration_samples=num_samples, + save_compressed=False, + output_dir=output_dir, + ) + print("Quantization completed") + + if not args.skip_generation_test: + print("\n[6/6] Testing generation with quantized model...") + dispatch_for_generation(model) + test_prompt = "Hello, my name is" + inputs = tokenizer(test_prompt, return_tensors="pt") + inputs = {k: v.to(model.device) for k, v in inputs.items()} + output = model.generate(**inputs, max_new_tokens=50) + generated_text = tokenizer.decode(output[0]) + print(f" Prompt: {test_prompt}") + print(f" Generated: {generated_text}") + print("Generation test passed") + else: + print("\n[6/6] Skipping generation test") + + print(f"\nSaving quantized model to: {output_dir}") + print("Model saved successfully") + + # ---- Display vLLM Instructions ---- + print("\n" + "=" * 70) + print("Quantization Complete!") + print("=" * 70) + print(f"Quantized model saved to: {output_dir}") + print("\nTo run inference with vLLM:") + print("-" * 70) + print("from vllm import LLM, SamplingParams\n") + print(f'model = LLM(model="{output_dir}", trust_remote_code=True)') + print('prompts = ["Hello, my name is"]') + print("sampling_params = SamplingParams(temperature=0.7, max_tokens=100)") + print("outputs = model.generate(prompts, sampling_params)\n") + print("for output in outputs:") + print(" print(output.outputs[0].text)") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/examples/quantization_w4a8/gpt_oss_20b_example.py b/examples/quantization_w4a8/gpt_oss_20b_example.py deleted file mode 100644 index 60a090572f..0000000000 --- a/examples/quantization_w4a8/gpt_oss_20b_example.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -from compressed_tensors.quantization import QuantizationScheme -from compressed_tensors.quantization.quant_args import ( - QuantizationArgs, - QuantizationStrategy, - QuantizationType, -) -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor import oneshot -from llmcompressor.modeling.gpt_oss import convert_model_for_quantization_gptoss -from llmcompressor.modifiers.quantization import QuantizationModifier - - -def main(): - MODEL_ID = "openai/gpt-oss-20b" - BASE_NAME = MODEL_ID.rstrip("/").split("/")[-1] - OUTPUT_DIR = f"{BASE_NAME}-w4a8-channelwise" - - print(f"[GPT-OSS] Loading model: {MODEL_ID}") - model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, - torch_dtype=torch.bfloat16, - device_map="auto", - trust_remote_code=True, - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - - # ---- GPT-OSS MoE → linear experts conversion ---- - print("[GPT-OSS] Converting fused MoE experts to LinearExperts for quantization...") - convert_model_for_quantization_gptoss(model) - print("[GPT-OSS] Conversion completed.") - - # ---- Quantization config: W4A8 (int4 weights, int8 activations) ---- - - # Weights: 4-bit, channelwise, symmetric, static - weights_args = QuantizationArgs( - num_bits=4, - type=QuantizationType.INT, - strategy=QuantizationStrategy.CHANNEL, - symmetric=True, - dynamic=False, - ) - - # Activations: 8-bit, per-token, asymmetric, dynamic - activations_args = QuantizationArgs( - num_bits=8, - type=QuantizationType.INT, - strategy=QuantizationStrategy.TOKEN, - symmetric=False, - dynamic=True, - observer=None, - ) - - # Apply to all Linear layers, excluding lm_head - scheme = QuantizationScheme( - targets=["Linear"], - weights=weights_args, - input_activations=activations_args, - ) - - recipe = QuantizationModifier( - config_groups={"group_0": scheme}, - ignore=["lm_head"], - ) - - print(f"[GPT-OSS] Starting oneshot quantization → {OUTPUT_DIR}") - oneshot( - model=model, - recipe=recipe, - tokenizer=tokenizer, - output_dir=OUTPUT_DIR, - trust_remote_code_model=True, - ) - print(f"[GPT-OSS] Quantization finished. Quantized model written to: {OUTPUT_DIR}") - - -if __name__ == "__main__": - main() diff --git a/src/llmcompressor/modeling/gpt_oss.py b/src/llmcompressor/modeling/gpt_oss.py index 44258a3927..d69acce02d 100644 --- a/src/llmcompressor/modeling/gpt_oss.py +++ b/src/llmcompressor/modeling/gpt_oss.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn +from llmcompressor.modeling.moe_context import MoECalibrationModule + class LinearExpert(nn.Module): """ @@ -19,7 +21,11 @@ class LinearExpert(nn.Module): """ def __init__( - self, hidden_size: int, intermediate_size: int, alpha: float, limit: float + self, + hidden_size: int, + intermediate_size: int, + alpha: float, + limit: float, ): super().__init__() self.alpha = alpha @@ -106,24 +112,13 @@ def copy_from_fused_weights( expert.down_proj.weight.copy_(legacy_down_W[i].t()) expert.down_proj.bias.copy_(legacy_down_b[i]) - def forward( + def _normalize_shapes( self, - hidden_states: torch.Tensor, # [B, T, H] - router_indices: Optional[ - torch.Tensor - ] = None, # [B, T, top_k] or [tokens, top_k] - routing_weights: Optional[torch.Tensor] = None, # [B, T, E] or [tokens, E] - ) -> torch.Tensor: - """ - Implements the MoE computation using the router outputs. - - This is compatible with the GPT-OSS MoE call pattern: - experts(hidden_states, router_indices, routing_weights) - """ - assert ( - routing_weights is not None and router_indices is not None - ), "router inputs required" - + hidden_states: torch.Tensor, + router_indices: torch.Tensor, + routing_weights: torch.Tensor, + ): + """Normalize input shapes to 2D format for processing.""" # Normalize shapes to [tokens, H], [tokens, top_k], [tokens, E] if hidden_states.dim() == 3: B, T, H = hidden_states.shape @@ -135,10 +130,24 @@ def forward( x = hidden_states if router_indices.dim() == 3: - router_indices = router_indices.reshape(-1, router_indices.shape[-1]) + router_indices = router_indices.reshape( + -1, router_indices.shape[-1] + ) if routing_weights.dim() == 3: - routing_weights = routing_weights.reshape(-1, routing_weights.shape[-1]) + routing_weights = routing_weights.reshape( + -1, routing_weights.shape[-1] + ) + + return x, router_indices, routing_weights, B, H + def _route_and_compute( + self, + x: torch.Tensor, + router_indices: torch.Tensor, + routing_weights: torch.Tensor, + calibrate_all: bool = False, + ) -> torch.Tensor: + """Shared routing logic for expert computation.""" num_experts_plus_dummy = routing_weights.shape[1] out = torch.zeros_like(x) @@ -147,7 +156,9 @@ def forward( expert_mask = torch.nn.functional.one_hot( router_indices, num_classes=num_experts_plus_dummy ).permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + expert_hit = torch.greater( + expert_mask.sum(dim=(-1, -2)), 0 + ).nonzero() for idx in expert_hit: e = idx[0].item() @@ -156,14 +167,101 @@ def forward( continue _, token_idx = torch.where(expert_mask[e]) - xi = x[token_idx] - expert = self.experts[e] - yi = expert(xi) + + if calibrate_all: + # Process all tokens through expert for calibration + yi = expert(x)[token_idx] + else: + # Normal routing: only process assigned tokens + xi = x[token_idx] + yi = expert(xi) w = routing_weights[token_idx, e, None] out.index_add_(0, token_idx, (yi * w).to(out.dtype)) + return out + + def forward( + self, + hidden_states: torch.Tensor, + router_indices: Optional[torch.Tensor] = None, + routing_weights: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Implements the MoE computation using the router outputs. + + This is compatible with the GPT-OSS MoE call pattern: + experts(hidden_states, router_indices, routing_weights) + """ + assert ( + routing_weights is not None and router_indices is not None + ), "router inputs required" + + x, router_indices, routing_weights, B, H = self._normalize_shapes( + hidden_states, router_indices, routing_weights + ) + + out = self._route_and_compute(x, router_indices, routing_weights) + return out.view(B, -1, H) + + +@MoECalibrationModule.register("GptOssExperts") +class CalibrationLinearExperts(LinearExperts, MoECalibrationModule): + """ + Calibration version of LinearExperts that sends all tokens to all experts. + + This module wraps the already-linearized LinearExperts to provide + calibration support during quantization. Since LinearExperts already has + the correct structure (separate gate/up/down projections), just add the + calibrate_all_experts functionality. + """ + + is_permanent = True + + def __init__( + self, + original: LinearExperts, + config, + calibrate_all_experts: bool = True, + ): + # Don't call LinearExperts.__init__, just copy attributes + nn.Module.__init__(self) + self.hidden_size = original.hidden_size + self.expert_dim = original.expert_dim + self.num_experts = original.num_experts + self.alpha = original.alpha + self.limit = original.limit + self.experts = original.experts + self.calibrate_all_experts = calibrate_all_experts + + def forward( + self, + hidden_states: torch.Tensor, + router_indices: Optional[torch.Tensor] = None, + routing_weights: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Implements the MoE computation using the router outputs. + + This is compatible with the GPT-OSS MoE call pattern: + experts(hidden_states, router_indices, routing_weights) + + When calibrate_all_experts=True, all experts process all tokens + to ensure proper calibration statistics. Enables activations + through all expert paths. + """ + assert ( + routing_weights is not None and router_indices is not None + ), "router inputs required" + + x, router_indices, routing_weights, B, H = self._normalize_shapes( + hidden_states, router_indices, routing_weights + ) + + out = self._route_and_compute( + x, router_indices, routing_weights, self.calibrate_all_experts + ) return out.view(B, -1, H) @@ -186,7 +284,9 @@ def get_module_by_path(root: nn.Module, dotpath: str) -> nn.Module: return m -def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> None: +def set_module_by_path( + root: nn.Module, dotpath: str, new_module: nn.Module +) -> None: parts = dotpath.split(".") parent = get_module_by_path(root, ".".join(parts[:-1])) setattr(parent, parts[-1], new_module) @@ -218,14 +318,23 @@ def find_experts(model: nn.Module) -> List[ExpertMeta]: return metas -def convert_model_for_quantization_gptoss(model: nn.Module) -> None: +def convert_model_for_quantization_gptoss( + model: nn.Module, calibrate_all_experts: bool = True +) -> None: """ - In-place conversion of a GPT-OSS model: - - - Finds all fused MoE expert blocks (with gate_up_proj/down_proj). - - Replaces them with LinearExperts that expose plain nn.Linear - parameters (gate_proj, up_proj, down_proj), which play nicely - with LLM Compressor W4A8 quantization. + In-place conversion of a GPT-OSS model for quantization. + + This function performs two key transformations: + 1. Linearizes fused MoE expert blocks (gate_up_proj/down_proj) into + separate nn.Linear parameters (gate_proj, up_proj, down_proj) + 2. Wraps them with CalibrationLinearExperts for proper calibration + + Args: + model: The GPT-OSS model to convert (modified in-place) + calibrate_all_experts: If True, all experts will see all tokens + during calibration. This is the recommended setting for proper + quantization statistics. Set to False only if you want normal + routing behavior during calibration. """ metas = find_experts(model) for meta in metas: @@ -243,17 +352,25 @@ def convert_model_for_quantization_gptoss(model: nn.Module) -> None: ): continue - new_exp = LinearExperts( + # Step 1: Create LinearExperts with separate gate/up/down projections + linear_experts = LinearExperts( hidden_size=meta.hidden_size, intermediate_size=meta.intermediate_size, num_experts=meta.num_experts, ).to(device=meta.device, dtype=meta.dtype) - new_exp.copy_from_fused_weights( + linear_experts.copy_from_fused_weights( legacy_gate_up_W=legacy.gate_up_proj, legacy_gate_up_b=legacy.gate_up_proj_bias, legacy_down_W=legacy.down_proj, legacy_down_b=legacy.down_proj_bias, ) - set_module_by_path(model, meta.path, new_exp) + # Step 2: Wrap with CalibrationLinearExperts for MoE calibration + calibration_experts = CalibrationLinearExperts( + original=linear_experts, + config=model.config, + calibrate_all_experts=calibrate_all_experts, + ) + + set_module_by_path(model, meta.path, calibration_experts) diff --git a/tests/llmcompressor/modeling/test_calib_gpt_oss.py b/tests/llmcompressor/modeling/test_calib_gpt_oss.py new file mode 100644 index 0000000000..8828663172 --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_gpt_oss.py @@ -0,0 +1,210 @@ +import contextlib +from functools import partial + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.modeling.gpt_oss import ( + CalibrationLinearExperts, + LinearExperts, + convert_model_for_quantization_gptoss, +) +from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.helpers import calibration_forward_context +from tests.testing_utils import requires_cadence, requires_gpu + + +@requires_cadence("weekly") +@pytest.mark.parametrize("model_stub", ["openai/gpt-oss-20b"]) +def test_convert_model_for_quantization_gptoss(model_stub): + """Test convert_model_for_quantization_gptoss correctly replaces modules.""" + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained( + model_stub, trust_remote_code=True + ) + + # convert model for quantization + convert_model_for_quantization_gptoss(model, calibrate_all_experts=True) + + # find CalibrationLinearExperts layer + calib_layer = None + for _, module in model.named_modules(): + if isinstance(module, CalibrationLinearExperts): + calib_layer = module + break + + assert ( + calib_layer is not None + ), "No CalibrationLinearExperts found in model" + assert calib_layer.calibrate_all_experts is True + assert hasattr(calib_layer, "experts") + assert len(calib_layer.experts) > 0 + + +@requires_cadence("weekly") +@pytest.mark.parametrize("model_stub", ["openai/gpt-oss-20b"]) +def test_calib_replace_gptoss_all_experts(model_stub): + """Test all experts are triggered when calibrate_all_experts=True.""" + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained( + model_stub, trust_remote_code=True + ) + + # convert model with calibrate_all_experts enabled + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + convert_model_for_quantization_gptoss( + model, calibrate_all_experts=True + ) + + # find a CalibrationLinearExperts layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, CalibrationLinearExperts): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = moe_layer.hidden_size + batch, seq_len = 2, 16 + num_experts_plus_dummy = moe_layer.num_experts + 1 + + # Create sample input + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Create router outputs (indices and weights) + # Simulate router selecting 2 experts per token + top_k = 2 + router_indices = torch.randint( + 0, moe_layer.num_experts, (batch, seq_len, top_k) + ) + routing_weights = torch.randn(batch, seq_len, num_experts_plus_dummy) + routing_weights = torch.softmax(routing_weights, dim=-1) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample, router_indices, routing_weights) + + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" + + +@requires_gpu +def test_calib_linear_experts_module(): + """Test correctness of CalibrationLinearExperts""" + # Create a LinearExperts module + hidden_size = 768 + intermediate_size = 2880 + num_experts = 8 + batch, seq_len = 2, 16 + top_k = 2 + + with torch.device("cuda"): + original = LinearExperts( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=num_experts, + ).eval() + + # Create dummy input + sample = torch.randn(batch, seq_len, hidden_size, device="cuda") + + # Create router outputs + num_experts_plus_dummy = num_experts + 1 + router_indices = torch.randint(0, num_experts, (batch, seq_len, top_k)).to( + "cuda" + ) + routing_weights = torch.randn(batch, seq_len, num_experts_plus_dummy).to( + "cuda" + ) + routing_weights = torch.softmax(routing_weights, dim=-1) + + # Get original output + with calibration_forward_context(original): + true_output = original(sample, router_indices, routing_weights) + + # Test with calibrate_all_experts=True + class MockConfig: + pass + + config = MockConfig() + module = CalibrationLinearExperts( + original, config, calibrate_all_experts=True + ) + with calibration_forward_context(module): + output = module(sample, router_indices, routing_weights) + assert torch.allclose(true_output, output, atol=1e-5) + + # Test with calibrate_all_experts=False + module = CalibrationLinearExperts( + original, config, calibrate_all_experts=False + ) + with calibration_forward_context(module): + output = module(sample, router_indices, routing_weights) + assert torch.allclose(true_output, output, atol=1e-5) + + +@requires_gpu +def test_linear_experts_shape_normalization(): + """Test that _normalize_shapes work correctly.""" + hidden_size = 768 + intermediate_size = 2880 + num_experts = 8 + + with torch.device("cuda"): + module = LinearExperts( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=num_experts, + ).eval() + + # Test 3D input + batch, seq_len = 4, 32 + sample_3d = torch.randn(batch, seq_len, hidden_size, device="cuda") + router_indices_3d = torch.randint(0, num_experts, (batch, seq_len, 2)).to( + "cuda" + ) + routing_weights_3d = torch.randn(batch, seq_len, num_experts + 1).to( + "cuda" + ) + + x, indices, weights, B, H = module._normalize_shapes( + sample_3d, router_indices_3d, routing_weights_3d + ) + + assert x.shape == (batch * seq_len, hidden_size) + assert indices.shape == (batch * seq_len, 2) + assert weights.shape == (batch * seq_len, num_experts + 1) + assert B == batch + assert H == hidden_size + + # Test 2D input (already flattened) + tokens = batch * seq_len + sample_2d = torch.randn(tokens, hidden_size, device="cuda") + router_indices_2d = torch.randint(0, num_experts, (tokens, 2)).to("cuda") + routing_weights_2d = torch.randn(tokens, num_experts + 1).to("cuda") + + x, indices, weights, B, H = module._normalize_shapes( + sample_2d, router_indices_2d, routing_weights_2d + ) + + assert x.shape == (tokens, hidden_size) + assert indices.shape == (tokens, 2) + assert weights.shape == (tokens, num_experts + 1) + assert B == 1 + assert H == hidden_size