From 9d7c8c0ff695cba06eb0cd3e3f5e3b23ad82146d Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 2 Apr 2026 18:06:38 +0000 Subject: [PATCH 1/3] Refactor Qwen3.5 MOE support for HF 5.0 Signed-off-by: Chenjie Luo --- modelopt/torch/export/layer_utils.py | 5 +- modelopt/torch/export/moe_utils.py | 117 ++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 13 ++ .../torch/quantization/plugins/huggingface.py | 132 ++++++------------ 4 files changed, 175 insertions(+), 92 deletions(-) diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index 96ecf91e5b4..b6ccd426430 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -971,7 +971,6 @@ def module_match_name_list(module, name_list): "Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "Qwen3NextSparseMoeBlock", - "Qwen3_5MoeSparseMoeBlock", "DeepseekMoE", ], ): @@ -980,8 +979,8 @@ def module_match_name_list(module, name_list): return ["linear_fc1", "linear_fc2"] elif module_match_name_list(module, ["DBRXMoeSparseMoeBlock"]): return ["w1_linear", "w2_linear", "v1_linear"] - elif module_match_name_list(module, ["GptOssMoE"]): - # GPT-OSS MoE modules use gate_up_proj and down_proj + elif module_match_name_list(module, ["GptOssMoE", "Qwen3_5MoeSparseMoeBlock"]): + # These MoE modules use fused gate_up_proj and down_proj return ["gate_up_proj", "down_proj"] else: # assuming w1, w2, w3 by default diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index dc357486893..b5317474d0d 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -17,9 +17,126 @@ from pathlib import Path +import torch import torch.nn as nn +def _export_qwen35_experts(module: nn.Module, dtype: torch.dtype) -> None: + """Split fused Qwen3.5 MoE expert weights and export per-expert quantization scales. + + The quantized ``Qwen3_5MoeExperts`` keeps fused 3D ``gate_up_proj`` and ``down_proj`` + parameters with per-expert quantizer ``ModuleList`` s at runtime. This function: + + 1. Handles amax fallback for uncalibrated expert quantizers. + 2. Splits the fused 3D weights into per-expert 2D projections. + 3. Calls ``_export_quantized_weight`` on each projection to compute scales and + quantize weights in the format expected by downstream consumers. + 4. Registers the results under the standard per-expert naming convention:: + + {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... + {E}.up_proj.weight, {E}.up_proj.weight_scale, ... + {E}.down_proj.weight, {E}.down_proj.weight_scale, ... + """ + import copy + + from modelopt.torch.export.layer_utils import set_expert_quantizer_amax + from modelopt.torch.export.unified_export_hf import _export_quantized_weight + + n = module.num_experts + expert_dim = module.intermediate_dim + + # 1. Amax fallback for uncalibrated expert input quantizers. + # Input amax depends on activations seen during calibration and can't be + # recomputed from weights, so borrow from calibrated peers. + # Weight quantizer amax is handled by _export_quantized_weight directly. + for quantizer_list in [ + module.gate_up_proj_input_quantizers, + module.down_proj_input_quantizers, + ]: + wrappers = [] + for q in quantizer_list: + w = nn.Module() + w.input_quantizer = q + wrappers.append(w) + set_expert_quantizer_amax(modules=wrappers, quantizer_attrs=["input_quantizer"]) + + gate_up = module.gate_up_proj.data + down = module.down_proj.data + + # 2-3. Split weights, export per-expert projections + # Each projection is (name, weight_slice, fused_start, fused_dim0). + # fused_start/fused_dim0 are used to proportionally slice per-channel amax + # when gate/up share a weight quantizer from the fused gate_up_proj. + fused_dim0 = gate_up.shape[1] # 2 * expert_dim + + for idx in range(n): + expert = nn.Module() + + projections = [ + ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0), + ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0), + ("down_proj", down[idx], None, None), + ] + + for proj_name, weight_slice, fused_start, fused_total in projections: + is_down = proj_name == "down_proj" + w_quantizer_src = ( + module.down_proj_weight_quantizers[idx] + if is_down + else module.gate_up_proj_weight_quantizers[idx] + ) + i_quantizer = ( + module.down_proj_input_quantizers[idx] + if is_down + else module.gate_up_proj_input_quantizers[idx] + ) + + # Clone weight quantizer so gate/up each get independent amax + w_quantizer = copy.deepcopy(w_quantizer_src) + + # For shared gate_up quantizers with per-channel amax (dim >= 1), + # proportionally slice dim0 to match the split weight. + if fused_start is not None and hasattr(w_quantizer, "_amax"): + amax = w_quantizer._amax + if amax.dim() >= 1: + amax_dim0 = amax.shape[0] + if fused_total % amax_dim0 != 0: + raise ValueError( + f"Fused weight dim0 ({fused_total}) is not divisible by " + f"amax dim0 ({amax_dim0})." + ) + slice_start = fused_start * amax_dim0 // fused_total + slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total + w_quantizer._amax = amax[slice_start:slice_end].contiguous() + + # Build a wrapper module that _export_quantized_weight understands + wrapper = nn.Module() + wrapper.weight = nn.Parameter(weight_slice.contiguous(), requires_grad=False) + wrapper.weight_quantizer = w_quantizer + wrapper.input_quantizer = i_quantizer + + _export_quantized_weight(wrapper, dtype) + + # Collect results into the per-expert submodule + proj = nn.Module() + proj.weight = wrapper.weight + for attr in ("weight_scale", "weight_scale_2", "input_scale"): + if hasattr(wrapper, attr): + proj.register_buffer(attr, getattr(wrapper, attr)) + + expert.add_module(proj_name, proj) + + module.add_module(str(idx), expert) + + # 4. Remove fused params and quantizer lists — replaced by per-expert submodules + delattr(module, "gate_up_proj") + delattr(module, "down_proj") + delattr(module, "gate_up_proj_weight_quantizers") + delattr(module, "gate_up_proj_input_quantizers") + delattr(module, "down_proj_weight_quantizers") + delattr(module, "down_proj_input_quantizers") + + def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): """Collect expert_token_count from all quantized MoE layers and save as an HTML table. diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 14a12bcdf3a..c9f8559a50d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -671,6 +671,14 @@ def _process_quantized_modules( with fsdp2_aware_weight_update(model, sub_module, reshard=False): for weight_name in ["gate_up_proj", "down_proj"]: _export_quantized_weight(sub_module, dtype, weight_name) + elif "Qwen3_5MoeExperts" in type(sub_module).__name__: + # Qwen3.5 MoE uses fused 3D params with per-expert quantizer + # ModuleLists at runtime. Split into per-expert modules and + # export each projection individually. + from modelopt.torch.export.moe_utils import _export_qwen35_experts + + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_qwen35_experts(sub_module, dtype) def _export_transformers_checkpoint( @@ -715,6 +723,11 @@ def _export_transformers_checkpoint( modules=list(linear_modulelist), quantizer_attrs=["input_quantizer"], ) + elif "QuantQwen3_5MoeExperts" in type(sub_module.experts).__name__: + # Qwen3.5 MoE uses per-expert quantizer ModuleLists. + # Amax fallback and scale export are handled by + # _export_qwen35_experts in _process_quantized_modules. + break elif "QuantGptOssExperts" in type(sub_module.experts).__name__: # Handle GPT-OSS experts specifically # GPT-OSS experts use gate_up_proj and down_proj diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 0d02716a6e9..1e17f5f5031 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -767,105 +767,59 @@ def forward( return next_states -class _Qwen35MoeExpertModule(nn.Module): - """Container for a single Qwen3.5 MoE expert's linear layers. +class _QuantQwen35MoeExperts(_QuantFunctionalMixin): + """Quantized wrapper for ``transformers.Qwen3_5MoeExperts``. - Produces the naming pattern: experts.{id}.gate_proj.weight - (consistent with standard Qwen3 MoE per-expert module structure). - """ + Keeps the original fused 3D ``gate_up_proj`` / ``down_proj`` parameters and + the unmodified HF forward (single ``F.linear`` + chunk per expert). - def __init__(self, hidden_dim: int, expert_dim: int): - super().__init__() - self.gate_proj = nn.Linear(hidden_dim, expert_dim, bias=False) - self.up_proj = nn.Linear(hidden_dim, expert_dim, bias=False) - self.down_proj = nn.Linear(expert_dim, hidden_dim, bias=False) + Per-expert quantization is achieved by intercepting ``F.linear`` and recovering + the expert index from the weight tensor's storage offset into the 3D parameter. + Each expert gets its own weight and input quantizers (``ModuleList``), so + calibration granularity matches the per-expert decomposition approach. + """ + def _get_expert_idx(self, weight: torch.Tensor) -> int: + """Recover expert index from a weight slice's storage offset.""" + base_offset = self.gate_up_proj.storage_offset() + stride = self.gate_up_proj.stride(0) + if stride == 0: + return 0 + return (weight.storage_offset() - base_offset) // stride -class _QuantQwen35MoeExperts(QuantModule): def _setup(self): - """Modify the Qwen3_5MoeExperts by using per-expert nn.Module containers. - - This produces the naming pattern: experts.{id}.gate_proj.weight - (consistent with standard Qwen3 MoE). - """ - from accelerate import init_empty_weights - - dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device - - def _copy_weight(module, weight): - module.to_empty(device=device) - with torch.no_grad(): - module.weight.data = weight.detach().data.to(dtype=dtype, device=device) - - expert_dim = self.intermediate_dim + n = self.num_experts + self.gate_up_proj_input_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) + self.gate_up_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) + self.down_proj_input_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) + self.down_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) - with init_empty_weights(): - expert_modules = nn.ModuleList( - [ - _Qwen35MoeExpertModule(self.hidden_dim, expert_dim) - for _ in range(self.num_experts) - ] - ) - - for idx in range(self.num_experts): - # gate_up_proj shape: (num_experts, 2*intermediate_dim, hidden_dim) - # Already in (out_features, in_features) format, no transpose needed - _copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :]) - _copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :]) - # down_proj shape: (num_experts, hidden_dim, intermediate_dim) - # Already in (out_features, in_features) format - _copy_weight(expert_modules[idx].down_proj, self.down_proj[idx]) - - delattr(self, "gate_up_proj") - delattr(self, "down_proj") - # Register expert modules directly as numbered children (like nn.ModuleList) - # so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting) - for idx in range(self.num_experts): - self.add_module(str(idx), expert_modules[idx]) + self._register_temp_attribute("_down_proj_linear", False) - def __len__(self): - """Support len() so the module is iterable like standard MoE experts.""" - return self.num_experts + @property + def functionals_to_replace(self): + _orig_linear = torch.nn.functional.linear - def __iter__(self): - """Support iteration over expert modules.""" - for idx in range(self.num_experts): - yield getattr(self, str(idx)) + def _quantized_linear(input, weight, bias=None): + if self._down_proj_linear: + expert_idx = self._current_expert_idx + input = self.down_proj_input_quantizers[expert_idx](input) + weight = self.down_proj_weight_quantizers[expert_idx](weight) + else: + expert_idx = self._get_expert_idx(weight) + self._current_expert_idx = expert_idx + input = self.gate_up_proj_input_quantizers[expert_idx](input) + weight = self.gate_up_proj_weight_quantizers[expert_idx](weight) + self._down_proj_linear = not self._down_proj_linear + return _orig_linear(input, weight, bias) - def __getitem__(self, idx): - """Support indexing to get individual expert modules.""" - return getattr(self, str(int(idx))) + return [ + (torch.nn.functional, "linear", _quantized_linear), + ] - def forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - expert_idx = expert_idx[0] - if expert_idx == self.num_experts: - continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - expert = self[expert_idx] - gate = expert.gate_proj(current_state) - up = expert.up_proj(current_state) - current_hidden_states = self.act_fn(gate) * up - current_hidden_states = expert.down_proj(current_hidden_states) - current_hidden_states = ( - current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - ) - final_hidden_states.index_add_( - 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) - ) - return final_hidden_states + def forward(self, *args, **kwargs): + self._down_proj_linear = False + return super().forward(*args, **kwargs) class _QuantDbrxFFN(_QuantSparseMoe): From d3f23c5972d4c8fa8c0fc224516ab7e18153bf0d Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 2 Apr 2026 18:16:27 +0000 Subject: [PATCH 2/3] Reviewer1 Signed-off-by: Chenjie Luo --- modelopt/torch/export/layer_utils.py | 2 +- modelopt/torch/export/moe_utils.py | 64 +++++++++++-------- modelopt/torch/export/unified_export_hf.py | 2 +- .../torch/quantization/plugins/huggingface.py | 7 +- 4 files changed, 44 insertions(+), 31 deletions(-) diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index b6ccd426430..8cd382449b6 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -980,7 +980,7 @@ def module_match_name_list(module, name_list): elif module_match_name_list(module, ["DBRXMoeSparseMoeBlock"]): return ["w1_linear", "w2_linear", "v1_linear"] elif module_match_name_list(module, ["GptOssMoE", "Qwen3_5MoeSparseMoeBlock"]): - # These MoE modules use fused gate_up_proj and down_proj + # GptOssMoE and Qwen3_5MoeSparseMoeBlock use fused gate_up_proj and down_proj return ["gate_up_proj", "down_proj"] else: # assuming w1, w2, w3 by default diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index b5317474d0d..086daecf7db 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -73,41 +73,53 @@ def _export_qwen35_experts(module: nn.Module, dtype: torch.dtype) -> None: expert = nn.Module() projections = [ - ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0), - ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0), - ("down_proj", down[idx], None, None), + ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True), + ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True), + ("down_proj", down[idx], 0, down.shape[1], False), ] - for proj_name, weight_slice, fused_start, fused_total in projections: - is_down = proj_name == "down_proj" + for proj_name, weight_slice, fused_start, fused_total, is_gate_up in projections: w_quantizer_src = ( - module.down_proj_weight_quantizers[idx] - if is_down - else module.gate_up_proj_weight_quantizers[idx] + module.gate_up_proj_weight_quantizers[idx] + if is_gate_up + else module.down_proj_weight_quantizers[idx] ) i_quantizer = ( - module.down_proj_input_quantizers[idx] - if is_down - else module.gate_up_proj_input_quantizers[idx] + module.gate_up_proj_input_quantizers[idx] + if is_gate_up + else module.down_proj_input_quantizers[idx] ) - # Clone weight quantizer so gate/up each get independent amax - w_quantizer = copy.deepcopy(w_quantizer_src) + # gate/up share a weight quantizer — clone so each gets independent amax. + # down_proj has its own quantizer and uses the full range, no clone needed. + w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src - # For shared gate_up quantizers with per-channel amax (dim >= 1), - # proportionally slice dim0 to match the split weight. - if fused_start is not None and hasattr(w_quantizer, "_amax"): + # For per-channel amax (dim >= 1), proportionally slice dim0 + # to match the split weight. + if hasattr(w_quantizer, "_amax") and w_quantizer._amax.dim() >= 1: amax = w_quantizer._amax - if amax.dim() >= 1: - amax_dim0 = amax.shape[0] - if fused_total % amax_dim0 != 0: - raise ValueError( - f"Fused weight dim0 ({fused_total}) is not divisible by " - f"amax dim0 ({amax_dim0})." - ) - slice_start = fused_start * amax_dim0 // fused_total - slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total - w_quantizer._amax = amax[slice_start:slice_end].contiguous() + amax_dim0 = amax.shape[0] + if fused_total % amax_dim0 != 0: + raise ValueError( + f"Fused weight dim0 ({fused_total}) is not divisible by " + f"amax dim0 ({amax_dim0})." + ) + slice_start = fused_start * amax_dim0 // fused_total + slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total + w_quantizer._amax = amax[slice_start:slice_end].contiguous() + + # If the weight quantizer was never calibrated (expert received no + # tokens), compute amax directly from the weight data. + if ( + hasattr(w_quantizer, "is_enabled") + and w_quantizer.is_enabled + and ( + not hasattr(w_quantizer, "_amax") + or w_quantizer._amax is None + or torch.all(w_quantizer._amax == 0) + ) + ): + w_quantizer.amax = weight_slice.abs().amax().to(torch.float32) # Build a wrapper module that _export_quantized_weight understands wrapper = nn.Module() diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c9f8559a50d..5a78b1059c9 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -727,7 +727,7 @@ def _export_transformers_checkpoint( # Qwen3.5 MoE uses per-expert quantizer ModuleLists. # Amax fallback and scale export are handled by # _export_qwen35_experts in _process_quantized_modules. - break + break # exits the inner `for linear_name` loop; type check prevents re-entry elif "QuantGptOssExperts" in type(sub_module.experts).__name__: # Handle GPT-OSS experts specifically # GPT-OSS experts use gate_up_proj and down_proj diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 1e17f5f5031..b76ed9fb326 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -779,8 +779,8 @@ class _QuantQwen35MoeExperts(_QuantFunctionalMixin): calibration granularity matches the per-expert decomposition approach. """ - def _get_expert_idx(self, weight: torch.Tensor) -> int: - """Recover expert index from a weight slice's storage offset.""" + def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int: + """Recover expert index from a ``gate_up_proj`` weight slice's storage offset.""" base_offset = self.gate_up_proj.storage_offset() stride = self.gate_up_proj.stride(0) if stride == 0: @@ -795,6 +795,7 @@ def _setup(self): self.down_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) self._register_temp_attribute("_down_proj_linear", False) + self._register_temp_attribute("_current_expert_idx", 0) @property def functionals_to_replace(self): @@ -806,7 +807,7 @@ def _quantized_linear(input, weight, bias=None): input = self.down_proj_input_quantizers[expert_idx](input) weight = self.down_proj_weight_quantizers[expert_idx](weight) else: - expert_idx = self._get_expert_idx(weight) + expert_idx = self._get_expert_idx_from_gate_up(weight) self._current_expert_idx = expert_idx input = self.gate_up_proj_input_quantizers[expert_idx](input) weight = self.gate_up_proj_weight_quantizers[expert_idx](weight) From 59d10d9da1933ead0f313ad3ceba238d2277dd5e Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 2 Apr 2026 18:23:15 +0000 Subject: [PATCH 3/3] Reviewer 2 Signed-off-by: Chenjie Luo --- modelopt/torch/export/moe_utils.py | 10 +++++++++- modelopt/torch/quantization/plugins/huggingface.py | 11 ++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 086daecf7db..51ce5e9acd1 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -43,7 +43,15 @@ def _export_qwen35_experts(module: nn.Module, dtype: torch.dtype) -> None: from modelopt.torch.export.unified_export_hf import _export_quantized_weight n = module.num_experts - expert_dim = module.intermediate_dim + + # The attribute name was changed from `intermediate_size` to `intermediate_dim` in + # https://github.com/huggingface/transformers/commit/0642963ba13f2dae0596fe489415569e1d91fbda + if hasattr(module, "intermediate_size"): + expert_dim = module.intermediate_size + elif hasattr(module, "intermediate_dim"): + expert_dim = module.intermediate_dim + else: + raise AttributeError("Could not find intermediate dimension size in module") # 1. Amax fallback for uncalibrated expert input quantizers. # Input amax depends on activations seen during calibration and can't be diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index b76ed9fb326..d1c9d351865 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -780,7 +780,12 @@ class _QuantQwen35MoeExperts(_QuantFunctionalMixin): """ def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int: - """Recover expert index from a ``gate_up_proj`` weight slice's storage offset.""" + """Recover expert index from a ``gate_up_proj`` weight slice's storage offset. + + This relies on ``gate_up_proj[idx]`` returning a view into contiguous storage + (standard PyTorch indexing behaviour). The invariant breaks if the tensor is + ``.contiguous()``-copied or redistributed by certain distributed wrappers. + """ base_offset = self.gate_up_proj.storage_offset() stride = self.gate_up_proj.stride(0) if stride == 0: @@ -801,6 +806,10 @@ def _setup(self): def functionals_to_replace(self): _orig_linear = torch.nn.functional.linear + # The toggle assumes the HF forward calls F.linear exactly twice per expert + # in strict alternation: first for gate_up_proj, then for down_proj. + # forward() resets the toggle before each call to super().forward(), so a + # stale state from a prior exception does not carry over across forward passes. def _quantized_linear(input, weight, bias=None): if self._down_proj_linear: expert_idx = self._current_expert_idx