From 2aff4a8879e2c028b17b13add4eebb2fd1d9c9ee Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 12 Jan 2026 09:46:21 +0100 Subject: [PATCH 01/28] experts impl gpt oss --- src/transformers/conversion_mapping.py | 15 +++ src/transformers/integrations/moe.py | 14 ++- .../models/gpt_oss/modeling_gpt_oss.py | 114 +++++++---------- .../models/gpt_oss/modular_gpt_oss.py | 116 +++++++----------- 4 files changed, 116 insertions(+), 143 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index e56db36874b6..b73936898935 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -163,6 +163,21 @@ def _build_checkpoint_conversion_mapping(): operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)], ), ], + # we have to create a temorary name to avoid conflicts during reversed conversion + "gpt_oss": [ + WeightRenaming("mlp.gate_up_proj", "mlp.gate_up_proj_tmp"), + WeightRenaming("mlp.down_proj", "mlp.down_proj_tmp"), + WeightConverter( + source_patterns="mlp.gate_up_proj_tmp", + target_patterns="mlp.gate_up_proj", + operations=[Transpose(dim0=0, dim1=1)], + ), + WeightConverter( + source_patterns="mlp.down_proj_tmp", + target_patterns="mlp.down_proj", + operations=[Transpose(dim0=0, dim1=1)], + ), + ], "jamba": [ WeightConverter( source_patterns=[ diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 87302c656be0..5c78f67d864e 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -107,7 +107,12 @@ def batched_mm_experts_forward( gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim) # Apply activation - hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) + if hasattr(self, "_apply_gate"): + # Applies custom handling of the gating mechanism if defined + hidden_after_activation = self._apply_gate(gate, up) # (S, intermediate_dim) + else: + # Default gating mechanism + hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (batched) --- out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) @@ -185,7 +190,12 @@ def grouped_mm_experts_forward( gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim) # Apply activation - hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) + if hasattr(self, "_apply_gate"): + # Applies custom handling of the gating mechanism if defined + hidden_after_activation = self._apply_gate(gate, up) # (S, intermediate_dim) + else: + # Default gating mechanism + hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (grouped_mm) --- out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 311b1baf15b5..a5b93d248246 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -27,7 +27,7 @@ from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernelized_func +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, @@ -64,87 +64,59 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_experts_implementation class GptOssExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size)) self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.hidden_size, self.expert_dim))) self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - """ - When training it is more efficient to just loop over the experts and compute the output for each expert - as otherwise the memory would explode. - - For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. + def _apply_gate(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + return gated_output - Args: - hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) - selected_experts (torch.Tensor): (batch_size * seq_len, top_k) - routing_weights (torch.Tensor): (batch_size * seq_len, top_k) - Returns: - torch.Tensor - """ - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - if hidden_states.device.type == "cpu" or self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ) # masking is also a class - expert_mask = expert_mask.permute(2, 1, 0) - # we sum on the top_k and on the sequence length to get which experts - # are hit this time around - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit[:]: - # expert_idx only have 1 element, so we can use scale for fast indexing - expert_idx = expert_idx[0] - # skip masking index - if expert_idx == self.num_experts: - continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, top_k_pos, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - else: - num_tokens = hidden_states.shape[0] - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) - next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size) - - full_routing_weights = torch.zeros( - num_tokens, self.num_experts, device=routing_weights.device, dtype=routing_weights.dtype + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear( + current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_bias[expert_idx] + ).chunk(2, dim=-1) + current_hidden_states = self._apply_gate(gate, up) + current_hidden_states = nn.functional.linear( + current_hidden_states, self.down_proj[expert_idx], self.down_proj_bias[expert_idx] ) - full_routing_weights.scatter_(1, router_indices, routing_weights) - full_routing_weights = full_routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1, 1) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) - next_states = next_states * full_routing_weights - next_states = next_states.sum(dim=0) - return next_states + return final_hidden_states class GptOssTopKRouter(nn.Module): @@ -157,7 +129,6 @@ def __init__(self, config): self.bias = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) @@ -173,9 +144,12 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) _, router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices, router_scores) - return routed_out, router_scores + hidden_states = self.experts(hidden_states, router_indices, router_scores) + hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) + return hidden_states, router_scores class GptOssRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 2db276752010..2699d7ab5969 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -19,7 +19,7 @@ from ... import initialization as init from ...cache_utils import Cache, DynamicCache -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_outputs import ( MoeModelOutputWithPast, @@ -61,87 +61,59 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama +@use_experts_implementation class GptOssExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size)) self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.hidden_size, self.expert_dim))) self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - """ - When training it is more efficient to just loop over the experts and compute the output for each expert - as otherwise the memory would explode. - - For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. - - Args: - hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) - selected_experts (torch.Tensor): (batch_size * seq_len, top_k) - routing_weights (torch.Tensor): (batch_size * seq_len, top_k) - Returns: - torch.Tensor - """ - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - if hidden_states.device.type == "cpu" or self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ) # masking is also a class - expert_mask = expert_mask.permute(2, 1, 0) - # we sum on the top_k and on the sequence length to get which experts - # are hit this time around - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit[:]: - # expert_idx only have 1 element, so we can use scale for fast indexing - expert_idx = expert_idx[0] - # skip masking index - if expert_idx == self.num_experts: - continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, top_k_pos, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - else: - num_tokens = hidden_states.shape[0] - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) - next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size) - - full_routing_weights = torch.zeros( - num_tokens, self.num_experts, device=routing_weights.device, dtype=routing_weights.dtype + def _apply_gate(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + return gated_output + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear( + current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_bias[expert_idx] + ).chunk(2, dim=-1) + current_hidden_states = self._apply_gate(gate, up) + current_hidden_states = nn.functional.linear( + current_hidden_states, self.down_proj[expert_idx], self.down_proj_bias[expert_idx] ) - full_routing_weights.scatter_(1, router_indices, routing_weights) - full_routing_weights = full_routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1, 1) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) - next_states = next_states * full_routing_weights - next_states = next_states.sum(dim=0) - return next_states + return final_hidden_states class GptOssTopKRouter(nn.Module): @@ -154,7 +126,6 @@ def __init__(self, config): self.bias = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) @@ -170,9 +141,12 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) _, router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices, router_scores) - return routed_out, router_scores + hidden_states = self.experts(hidden_states, router_indices, router_scores) + hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) + return hidden_states, router_scores class GptOssRotaryEmbedding(Qwen2RotaryEmbedding): From 9958efbac91f14a0645cd0b305d1fd0e7e47b915 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 12 Jan 2026 12:52:28 +0100 Subject: [PATCH 02/28] no need to transpose dequantized experts --- src/transformers/integrations/mxfp4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index f0a8de789f48..b3245d2a0a06 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -270,7 +270,7 @@ def convert_moe_packed_tensors( out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) del blocks, scales, lut - return out.transpose(1, 2).contiguous() + return out.contiguous() class Mxfp4GptOssExperts(nn.Module): From b23e1ffad3fe567ba741091c18d77b563bb4fd10 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 12 Jan 2026 13:38:50 +0100 Subject: [PATCH 03/28] skip test_reverse_loading_mapping --- src/transformers/conversion_mapping.py | 2 +- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index b73936898935..8440720f92e1 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -166,12 +166,12 @@ def _build_checkpoint_conversion_mapping(): # we have to create a temorary name to avoid conflicts during reversed conversion "gpt_oss": [ WeightRenaming("mlp.gate_up_proj", "mlp.gate_up_proj_tmp"), - WeightRenaming("mlp.down_proj", "mlp.down_proj_tmp"), WeightConverter( source_patterns="mlp.gate_up_proj_tmp", target_patterns="mlp.gate_up_proj", operations=[Transpose(dim0=0, dim1=1)], ), + WeightRenaming("mlp.down_proj", "mlp.down_proj_tmp"), WeightConverter( source_patterns="mlp.down_proj_tmp", target_patterns="mlp.down_proj", diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 55fa9b98b74f..961aa49dd7e2 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -121,6 +121,10 @@ def test_flex_attention_with_grads(self): def test_generate_compile_model_forward_fullgraph(self): return super().test_generate_compile_model_forward_fullgraph() + @unittest.skip("GptOss's conversion mapping does not rename weights, only transposes them") + def test_reverse_loading_mapping(self): + pass + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" From e28f1555c4fc60f40e7bb80ad2ecb82733207ccf Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 13 Jan 2026 08:41:31 +0100 Subject: [PATCH 04/28] fix custom gating --- src/transformers/integrations/moe.py | 20 +++++++---------- .../models/gpt_oss/modeling_gpt_oss.py | 11 +++++----- .../models/gpt_oss/modular_gpt_oss.py | 11 +++++----- tests/models/gpt_oss/test_modeling_gpt_oss.py | 22 +++++++++++++++---- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 5c78f67d864e..74e549ec50f4 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -103,15 +103,13 @@ def batched_mm_experts_forward( if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids] - # Split into gate and up components - gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim) - - # Apply activation + # Apply gating if hasattr(self, "_apply_gate"): - # Applies custom handling of the gating mechanism if defined - hidden_after_activation = self._apply_gate(gate, up) # (S, intermediate_dim) + # Custom gating if defined + hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) else: # Default gating mechanism + gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (batched) --- @@ -186,15 +184,13 @@ def grouped_mm_experts_forward( # we should be able to pass bias to the grouped_mm call, but it's still not fully supported gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids_g] - # Split into gate and up components - gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim) - - # Apply activation + # Apply gating if hasattr(self, "_apply_gate"): - # Applies custom handling of the gating mechanism if defined - hidden_after_activation = self._apply_gate(gate, up) # (S, intermediate_dim) + # Custom gating if defined + hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) else: # Default gating mechanism + gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (grouped_mm) --- diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index a5b93d248246..d93fc00e8ece 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -76,12 +76,13 @@ def __init__(self, config): self.expert_dim = self.intermediate_size self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size)) self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.hidden_size, self.expert_dim))) + self.down_proj = nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, self.expert_dim))) self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 - def _apply_gate(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate, up = gate_up[..., ::2], gate_up[..., 1::2] gate = gate.clamp(min=None, max=self.limit) up = up.clamp(min=-self.limit, max=self.limit) glu = gate * torch.sigmoid(gate * self.alpha) @@ -106,10 +107,10 @@ def forward( continue top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - gate, up = nn.functional.linear( + gate_up = nn.functional.linear( current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_bias[expert_idx] - ).chunk(2, dim=-1) - current_hidden_states = self._apply_gate(gate, up) + ) + current_hidden_states = self._apply_gate(gate_up) current_hidden_states = nn.functional.linear( current_hidden_states, self.down_proj[expert_idx], self.down_proj_bias[expert_idx] ) diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 2699d7ab5969..b0730172e3dc 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -73,12 +73,13 @@ def __init__(self, config): self.expert_dim = self.intermediate_size self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size)) self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.hidden_size, self.expert_dim))) + self.down_proj = nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, self.expert_dim))) self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 - def _apply_gate(self, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate, up = gate_up[..., ::2], gate_up[..., 1::2] gate = gate.clamp(min=None, max=self.limit) up = up.clamp(min=-self.limit, max=self.limit) glu = gate * torch.sigmoid(gate * self.alpha) @@ -103,10 +104,10 @@ def forward( continue top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - gate, up = nn.functional.linear( + gate_up = nn.functional.linear( current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_bias[expert_idx] - ).chunk(2, dim=-1) - current_hidden_states = self._apply_gate(gate, up) + ) + current_hidden_states = self._apply_gate(gate_up) current_hidden_states = nn.functional.linear( current_hidden_states, self.down_proj[expert_idx], self.down_proj_bias[expert_idx] ) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 961aa49dd7e2..fa6c8a1ce705 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -48,9 +48,8 @@ if is_torch_available(): import torch - from transformers import ( - GptOssModel, - ) + from transformers import GptOssModel + from transformers.utils.quantization_config import Mxfp4Config NUM_GPUS = torch.cuda.device_count() @@ -349,6 +348,13 @@ def run_distributed_test(quantized, model, kernels, attn_impl, mode): @parameterized.expand(PARAMETERS) @require_read_token def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): + additional_kwargs = {} + if not quantized: + additional_kwargs = { + "quantization_config": Mxfp4Config(dequantize=True), + "experts_implementation": "eager", + } + model_id = f"openai/gpt-oss-{model}" output_texts = self.load_and_forward( model_id, @@ -356,6 +362,7 @@ def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): self.input_text, mode=mode, use_kernels=kernels, + **additional_kwargs, ) # Generate key to look up expected outputs @@ -426,14 +433,21 @@ def test_training_step(self, quantized, model, kernels, attn_impl, mode): if quantized: self.skipTest("Training test for quantized models is not supported.") - model_id = f"openai/gpt-oss-{model}" + additional_kwargs = {} + if not quantized: + additional_kwargs = { + "quantization_config": Mxfp4Config(dequantize=True), + "experts_implementation": "eager", + } + model_id = f"openai/gpt-oss-{model}" model_obj = AutoModelForCausalLM.from_pretrained( model_id, dtype=torch.bfloat16, device_map="auto", attn_implementation=attn_impl, use_kernels=kernels, + **additional_kwargs, ) model_obj.train() From be08fe4811e2dc59225d010c82c40f01e19dd403 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 13 Jan 2026 16:45:57 +0100 Subject: [PATCH 05/28] revert transposition and simply support transposed experts to avoid modifying eager --- src/transformers/conversion_mapping.py | 15 ------ src/transformers/integrations/moe.py | 28 +++++++++-- src/transformers/integrations/mxfp4.py | 2 +- .../models/gpt_oss/modeling_gpt_oss.py | 49 +++++++++---------- .../models/gpt_oss/modular_gpt_oss.py | 49 +++++++++---------- tests/models/gpt_oss/test_modeling_gpt_oss.py | 21 -------- 6 files changed, 69 insertions(+), 95 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 8440720f92e1..e56db36874b6 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -163,21 +163,6 @@ def _build_checkpoint_conversion_mapping(): operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)], ), ], - # we have to create a temorary name to avoid conflicts during reversed conversion - "gpt_oss": [ - WeightRenaming("mlp.gate_up_proj", "mlp.gate_up_proj_tmp"), - WeightConverter( - source_patterns="mlp.gate_up_proj_tmp", - target_patterns="mlp.gate_up_proj", - operations=[Transpose(dim0=0, dim1=1)], - ), - WeightRenaming("mlp.down_proj", "mlp.down_proj_tmp"), - WeightConverter( - source_patterns="mlp.down_proj_tmp", - target_patterns="mlp.down_proj", - operations=[Transpose(dim0=0, dim1=1)], - ), - ], "jamba": [ WeightConverter( source_patterns=[ diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 74e549ec50f4..85de8e756b5a 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -69,9 +69,11 @@ def batched_mm_experts_forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: device = hidden_states.device + hidden_dim = self.hidden_size num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) num_experts = self.gate_up_proj.size(0) + intermediate_dim = self.intermediate_size final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample @@ -99,7 +101,11 @@ def batched_mm_experts_forward( selected_down = self.down_proj[expert_ids] # (S, hidden_dim, intermediate_dim) # --- Up projection per expert (batched) --- - gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1) + if selected_gate_up.shape == (num_tokens * num_top_k, hidden_dim, 2 * intermediate_dim): + gate_up_out = torch.bmm(current_hidden_states.unsqueeze(1), selected_gate_up).squeeze(1) + else: + gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1) + if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids] @@ -113,7 +119,11 @@ def batched_mm_experts_forward( hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (batched) --- - out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) + if selected_down.shape == (num_tokens * num_top_k, hidden_dim, intermediate_dim): + out_per_sample = torch.bmm(hidden_after_activation.unsqueeze(1), selected_down).squeeze(1) + else: + out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) + if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: out_per_sample = out_per_sample + self.down_proj_bias[expert_ids] @@ -138,9 +148,11 @@ def grouped_mm_experts_forward( ) device = hidden_states.device + hidden_dim = self.hidden_size num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) num_experts = self.gate_up_proj.size(0) + intermediate_dim = self.intermediate_size final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample @@ -179,7 +191,11 @@ def grouped_mm_experts_forward( offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # --- Up projection per expert (grouped_mm) --- - gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets) + if self.gate_up_proj.shape == (num_experts, hidden_dim, 2 * intermediate_dim): + gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj, offs=offsets) + else: + gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets) + if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: # we should be able to pass bias to the grouped_mm call, but it's still not fully supported gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids_g] @@ -194,7 +210,11 @@ def grouped_mm_experts_forward( hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (grouped_mm) --- - out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) + if self.down_proj.shape == (num_experts, hidden_dim, intermediate_dim): + out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj, offs=offsets) + else: + out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) + if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: # we should be able to pass bias to the grouped_mm call, but it's still not fully supported out_per_sample_g = out_per_sample_g + self.down_proj_bias[expert_ids_g] diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index b3245d2a0a06..f0a8de789f48 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -270,7 +270,7 @@ def convert_moe_packed_tensors( out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) del blocks, scales, lut - return out.contiguous() + return out.transpose(1, 2).contiguous() class Mxfp4GptOssExperts(nn.Module): diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index d93fc00e8ece..a2920a5b84f5 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -66,18 +66,16 @@ def extra_repr(self): @use_experts_implementation class GptOssExperts(nn.Module): - """Collection of expert weights stored as 3D tensors.""" - def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size)) - self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, self.expert_dim))) - self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 @@ -89,35 +87,32 @@ def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: gated_output = (up + 1) * glu return gated_output - def forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ) # masking is also a class expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit: + for expert_idx in expert_hit[:]: + # expert_idx only have 1 element, so we can use scale for fast indexing expert_idx = expert_idx[0] + # skip masking index if expert_idx == self.num_experts: continue - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + with torch.no_grad(): + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - gate_up = nn.functional.linear( - current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_bias[expert_idx] - ) - current_hidden_states = self._apply_gate(gate_up) - current_hidden_states = nn.functional.linear( - current_hidden_states, self.down_proj[expert_idx], self.down_proj_bias[expert_idx] - ) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gated_output = self._apply_gate(gate_up) + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, top_k_pos, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - return final_hidden_states + return next_states class GptOssTopKRouter(nn.Module): diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index b0730172e3dc..3a91e5ebc916 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -63,18 +63,16 @@ def forward(self, hidden_states): @use_experts_implementation class GptOssExperts(nn.Module): - """Collection of expert weights stored as 3D tensors.""" - def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, self.hidden_size)) - self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.zeros((self.num_experts, self.hidden_size, self.expert_dim))) - self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 @@ -86,35 +84,32 @@ def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: gated_output = (up + 1) * glu return gated_output - def forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ) # masking is also a class expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit: + for expert_idx in expert_hit[:]: + # expert_idx only have 1 element, so we can use scale for fast indexing expert_idx = expert_idx[0] + # skip masking index if expert_idx == self.num_experts: continue - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + with torch.no_grad(): + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] - gate_up = nn.functional.linear( - current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_bias[expert_idx] - ) - current_hidden_states = self._apply_gate(gate_up) - current_hidden_states = nn.functional.linear( - current_hidden_states, self.down_proj[expert_idx], self.down_proj_bias[expert_idx] - ) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gated_output = self._apply_gate(gate_up) + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, top_k_pos, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - return final_hidden_states + return next_states class GptOssTopKRouter(nn.Module): diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index c68a9dea0efa..0591ce62e276 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -48,7 +48,6 @@ import torch from transformers import GptOssModel - from transformers.utils.quantization_config import Mxfp4Config NUM_GPUS = torch.cuda.device_count() @@ -119,10 +118,6 @@ def test_flex_attention_with_grads(self): def test_generate_compile_model_forward_fullgraph(self): return super().test_generate_compile_model_forward_fullgraph() - @unittest.skip("GptOss's conversion mapping does not rename weights, only transposes them") - def test_reverse_loading_mapping(self): - pass - RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" @@ -346,13 +341,6 @@ def run_distributed_test(quantized, model, kernels, attn_impl, mode): # ------------------------ @parameterized.expand(PARAMETERS) def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): - additional_kwargs = {} - if not quantized: - additional_kwargs = { - "quantization_config": Mxfp4Config(dequantize=True), - "experts_implementation": "eager", - } - model_id = f"openai/gpt-oss-{model}" output_texts = self.load_and_forward( model_id, @@ -360,7 +348,6 @@ def test_model_outputs(self, quantized, model, kernels, attn_impl, mode): self.input_text, mode=mode, use_kernels=kernels, - **additional_kwargs, ) # Generate key to look up expected outputs @@ -429,13 +416,6 @@ def test_training_step(self, quantized, model, kernels, attn_impl, mode): if quantized: self.skipTest("Training test for quantized models is not supported.") - additional_kwargs = {} - if not quantized: - additional_kwargs = { - "quantization_config": Mxfp4Config(dequantize=True), - "experts_implementation": "eager", - } - model_id = f"openai/gpt-oss-{model}" model_obj = AutoModelForCausalLM.from_pretrained( model_id, @@ -443,7 +423,6 @@ def test_training_step(self, quantized, model, kernels, attn_impl, mode): device_map="auto", attn_implementation=attn_impl, use_kernels=kernels, - **additional_kwargs, ) model_obj.train() From e1dba4d3352905d05027107231d9395eee73c40c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 13 Jan 2026 16:50:36 +0100 Subject: [PATCH 06/28] style --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 2 +- src/transformers/models/gpt_oss/modular_gpt_oss.py | 2 +- tests/models/gpt_oss/test_modeling_gpt_oss.py | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index a2920a5b84f5..5fb9ff2a7f96 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -97,7 +97,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig # we sum on the top_k and on the sequence length to get which experts # are hit this time around expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit[:]: + for expert_idx in expert_hit: # expert_idx only have 1 element, so we can use scale for fast indexing expert_idx = expert_idx[0] # skip masking index diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 3a91e5ebc916..1dd1a728051a 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -94,7 +94,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig # we sum on the top_k and on the sequence length to get which experts # are hit this time around expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit[:]: + for expert_idx in expert_hit: # expert_idx only have 1 element, so we can use scale for fast indexing expert_idx = expert_idx[0] # skip masking index diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 0591ce62e276..772fd7f03013 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -47,7 +47,9 @@ if is_torch_available(): import torch - from transformers import GptOssModel + from transformers import ( + GptOssModel, + ) NUM_GPUS = torch.cuda.device_count() @@ -417,6 +419,7 @@ def test_training_step(self, quantized, model, kernels, attn_impl, mode): self.skipTest("Training test for quantized models is not supported.") model_id = f"openai/gpt-oss-{model}" + model_obj = AutoModelForCausalLM.from_pretrained( model_id, dtype=torch.bfloat16, From 0261a467116de4a9103b58200f9d82f8fb5a0e86 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 13 Jan 2026 17:17:53 +0100 Subject: [PATCH 07/28] don't rely on weight shapes as they can be square matrices --- src/transformers/integrations/moe.py | 12 ++++-------- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 8 ++++---- src/transformers/models/gpt_oss/modular_gpt_oss.py | 8 ++++---- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 85de8e756b5a..3f0e1f706896 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -69,11 +69,9 @@ def batched_mm_experts_forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: device = hidden_states.device - hidden_dim = self.hidden_size num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) num_experts = self.gate_up_proj.size(0) - intermediate_dim = self.intermediate_size final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample @@ -101,7 +99,7 @@ def batched_mm_experts_forward( selected_down = self.down_proj[expert_ids] # (S, hidden_dim, intermediate_dim) # --- Up projection per expert (batched) --- - if selected_gate_up.shape == (num_tokens * num_top_k, hidden_dim, 2 * intermediate_dim): + if getattr(self, "experts_are_transposed", False): gate_up_out = torch.bmm(current_hidden_states.unsqueeze(1), selected_gate_up).squeeze(1) else: gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1) @@ -119,7 +117,7 @@ def batched_mm_experts_forward( hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (batched) --- - if selected_down.shape == (num_tokens * num_top_k, hidden_dim, intermediate_dim): + if getattr(self, "experts_are_transposed", False): out_per_sample = torch.bmm(hidden_after_activation.unsqueeze(1), selected_down).squeeze(1) else: out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) @@ -148,11 +146,9 @@ def grouped_mm_experts_forward( ) device = hidden_states.device - hidden_dim = self.hidden_size num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) num_experts = self.gate_up_proj.size(0) - intermediate_dim = self.intermediate_size final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample @@ -191,7 +187,7 @@ def grouped_mm_experts_forward( offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # --- Up projection per expert (grouped_mm) --- - if self.gate_up_proj.shape == (num_experts, hidden_dim, 2 * intermediate_dim): + if getattr(self, "experts_are_transposed", False): gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj, offs=offsets) else: gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets) @@ -210,7 +206,7 @@ def grouped_mm_experts_forward( hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (grouped_mm) --- - if self.down_proj.shape == (num_experts, hidden_dim, intermediate_dim): + if getattr(self, "experts_are_transposed", False): out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj, offs=offsets) else: out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 5fb9ff2a7f96..93e91d6d40f6 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -71,10 +71,10 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size - self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.experts_are_transposed = True + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.intermediate_size)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_size)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.intermediate_size, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 1dd1a728051a..76e0cd587363 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -68,10 +68,10 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size - self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) + self.experts_are_transposed = True + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.intermediate_size)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_size)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.intermediate_size, self.hidden_size))) self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 From 5bd25c75ce6067c51169543ce1ef1ced0e6d3b8a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 14 Jan 2026 11:55:11 +0100 Subject: [PATCH 08/28] no need to relaod --- tests/test_modeling_common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b58061da766b..3b7f1e6d22cc 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -560,14 +560,9 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() set_config_for_less_flaky_test(config) - model = model_class(config) + model = model_class(config).eval().to(torch_device).to(dtype) set_model_for_less_flaky_test(model) - # Load with dtype - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, dtype=dtype).eval().to(torch_device) - with torch.no_grad(): inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} prepared_inputs = self._prepare_for_class(inputs_dict, model_class) @@ -840,6 +835,7 @@ def test_save_load(self): model = model_class.from_pretrained(tmpdirname) model.to(torch_device) + model.eval() with torch.no_grad(): second = model(**self._prepare_for_class(inputs_dict, model_class))[0] From 846adcad9df3254d7058f163077f9821ef920f6e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 14 Jan 2026 13:01:44 +0100 Subject: [PATCH 09/28] fallback to eager --- src/transformers/integrations/moe.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 3f0e1f706896..296ec46041bc 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -16,11 +16,13 @@ from ..utils.generic import GeneralInterface from ..utils.import_utils import is_torch_available +from ..utils.logging import get_logger if is_torch_available(): import torch +logger = get_logger(__name__) # Examples of experts class with its eager mm implementation # class Experts(nn.Module): @@ -252,6 +254,15 @@ def __init__(self, config, *args, **kwargs): def forward(self, *args, **kwargs): experts_forward = original_forward + if self.config._experts_implementation == "grouped_mm": + if self.gate_up_proj.data_ptr() % 16 != 0 or self.down_proj.data_ptr() % 16 != 0: + logger.warning( + "'grouped_mm' experts implementation requires 16-byte aligned expert weights. " + "We will fall back to 'eager' implementation to avoid potential crashes. " + "Please re-initialize the expert weights with 16-byte alignment." + ) + self.config._experts_implementation = "eager" + if self.config._experts_implementation != "eager": experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] From b1a71a79db9738f5d99af0ff5adbb9e9a810c22b Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:17:10 +0100 Subject: [PATCH 10/28] Update src/transformers/models/gpt_oss/modeling_gpt_oss.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 93e91d6d40f6..6ae6771828ab 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -64,7 +64,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_experts_implementation +@use_experts_implementation(transposed_experts=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() From 9dbed89b75c4514308f73daf0262cd0314c6c316 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 15 Jan 2026 09:01:11 +0100 Subject: [PATCH 11/28] fix --- src/transformers/integrations/moe.py | 77 ++++++++++++------- .../models/gpt_oss/modeling_gpt_oss.py | 6 +- .../models/gpt_oss/modular_gpt_oss.py | 6 +- 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 296ec46041bc..188147a88ed2 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -101,7 +101,7 @@ def batched_mm_experts_forward( selected_down = self.down_proj[expert_ids] # (S, hidden_dim, intermediate_dim) # --- Up projection per expert (batched) --- - if getattr(self, "experts_are_transposed", False): + if getattr(self, "transposed_weights", False): gate_up_out = torch.bmm(current_hidden_states.unsqueeze(1), selected_gate_up).squeeze(1) else: gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1) @@ -119,7 +119,7 @@ def batched_mm_experts_forward( hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (batched) --- - if getattr(self, "experts_are_transposed", False): + if getattr(self, "transposed_weights", False): out_per_sample = torch.bmm(hidden_after_activation.unsqueeze(1), selected_down).squeeze(1) else: out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) @@ -189,7 +189,7 @@ def grouped_mm_experts_forward( offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # --- Up projection per expert (grouped_mm) --- - if getattr(self, "experts_are_transposed", False): + if getattr(self, "transposed_weights", False): gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj, offs=offsets) else: gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets) @@ -208,7 +208,7 @@ def grouped_mm_experts_forward( hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) # --- Down projection per expert (grouped_mm) --- - if getattr(self, "experts_are_transposed", False): + if getattr(self, "transposed_weights", False): out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj, offs=offsets) else: out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) @@ -241,33 +241,54 @@ class ExpertsInterface(GeneralInterface): ALL_EXPERTS_FUNCTIONS = ExpertsInterface() -def use_experts_implementation(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]: - original_init = experts_class.__init__ - original_forward = experts_class.forward +def use_experts_implementation( + experts_class: type[torch.nn.Module] | None = None, *, transposed_weights: bool = False +) -> type[torch.nn.Module]: + """Decorator to modify experts class to support different experts implementations. - @wraps(original_init) - def __init__(self, config, *args, **kwargs): - original_init(self, config, *args, **kwargs) - self.config = config + Args: + experts_class (`type[torch.nn.Module]`, *optional*): + The experts class to modify. If not provided, returns a decorator that can be applied to the class. + transposed_weights (`bool`, *optional*, defaults to `False`): + Whether the expert weights are stored in transposed format. - @wraps(original_forward) - def forward(self, *args, **kwargs): - experts_forward = original_forward + Returns: + `type[torch.nn.Module]`: The modified experts class. + """ - if self.config._experts_implementation == "grouped_mm": - if self.gate_up_proj.data_ptr() % 16 != 0 or self.down_proj.data_ptr() % 16 != 0: - logger.warning( - "'grouped_mm' experts implementation requires 16-byte aligned expert weights. " - "We will fall back to 'eager' implementation to avoid potential crashes. " - "Please re-initialize the expert weights with 16-byte alignment." - ) - self.config._experts_implementation = "eager" + def wrapper(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]: + original_init = experts_class.__init__ + original_forward = experts_class.forward - if self.config._experts_implementation != "eager": - experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] + @wraps(original_init) + def __init__(self, config, *args, **kwargs): + original_init(self, config, *args, **kwargs) + self.transposed_weights = transposed_weights + self.config = config - return experts_forward(self, *args, **kwargs) + @wraps(original_forward) + def forward(self, *args, **kwargs): + experts_forward = original_forward - experts_class.__init__ = __init__ - experts_class.forward = forward - return experts_class + if self.config._experts_implementation == "grouped_mm": + if self.gate_up_proj.data_ptr() % 16 != 0 or self.down_proj.data_ptr() % 16 != 0: + logger.warning( + "'grouped_mm' experts implementation requires 16-byte aligned expert weights. " + "We will fall back to 'eager' implementation to avoid potential crashes. " + "Please re-initialize the expert weights with 16-byte alignment." + ) + self.config._experts_implementation = "eager" + + if self.config._experts_implementation != "eager": + experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] + + return experts_forward(self, *args, **kwargs) + + experts_class.__init__ = __init__ + experts_class.forward = forward + return experts_class + + if experts_class is not None: + return wrapper(experts_class) + + return wrapper diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 6ae6771828ab..bc21b77ed791 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -64,14 +64,13 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_experts_implementation(transposed_experts=True) +@use_experts_implementation(transposed_weights=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size - self.experts_are_transposed = True self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.intermediate_size)) self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_size)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.intermediate_size, self.hidden_size))) @@ -103,8 +102,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig # skip masking index if expert_idx == self.num_experts: continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gated_output = self._apply_gate(gate_up) diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 76e0cd587363..8b547a35dc7b 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -61,14 +61,13 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama -@use_experts_implementation +@use_experts_implementation(transposed_weights=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size - self.experts_are_transposed = True self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.intermediate_size)) self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_size)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.intermediate_size, self.hidden_size))) @@ -100,8 +99,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig # skip masking index if expert_idx == self.num_experts: continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gated_output = self._apply_gate(gate_up) From 2f3fd11c7b88234e2217dd7354f5d536636b4ed5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 15 Jan 2026 09:50:59 +0100 Subject: [PATCH 12/28] force 16 bytes alignmenet during weight loading --- src/transformers/conversion_mapping.py | 13 ++++++++ src/transformers/core_model_loading.py | 33 +++++++++++++++++++ src/transformers/integrations/moe.py | 9 ----- tests/models/gpt_oss/test_modeling_gpt_oss.py | 4 +++ 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index e56db36874b6..7acb9f0707dd 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -21,6 +21,7 @@ Chunk, Concatenate, ErnieFuseAndSplitTextVisionExperts, + Force16BytesAlignment, MergeModulelist, Transpose, WeightConverter, @@ -40,6 +41,18 @@ def _build_checkpoint_conversion_mapping(): mapping = { + "gpt_oss": [ + WeightConverter( + source_patterns="mlp.experts.gate_up_proj", + target_patterns="mlp.experts.gate_up_proj", + operations=[Force16BytesAlignment()], + ), + WeightConverter( + source_patterns="mlp.experts.down_proj", + target_patterns="mlp.experts.down_proj", + operations=[Force16BytesAlignment()], + ), + ], "mixtral": [ WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"), WeightConverter( diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 9e43baf498b1..6ff8da9167fc 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -439,6 +439,39 @@ def reverse_op(self) -> ConversionOps: return Transpose(dim0=self.dim1, dim1=self.dim0) +class Force16BytesAlignment(ConversionOps): + """ + Ensures that the given tensor is 16-bytes aligned in memory. + """ + + @torch.no_grad() + def convert( + self, + input_dict: dict[str, list[torch.Tensor]], + source_patterns: list[str], + target_patterns: list[str], + config, + **kwargs, + ) -> dict[str, list[torch.Tensor]]: + if len(input_dict) != len(target_patterns): + raise ValueError( + f"Force16BytesAlignment conversion can only happen on each key ({len(input_dict)}) " + f"and should match exact one target ({len(target_patterns)})." + ) + + output: dict[str, list[torch.Tensor]] = {} + for key, target_pattern in zip(input_dict.keys(), target_patterns): + tensor = input_dict.get(key, []) + if len(tensor) != 1: + raise ValueError(f"Force16BytesAlignment conversion requires exactly one tensor, found {len(tensor)}.") + output[target_pattern] = tensor[0].clone() if tensor[0].data_ptr() % 16 == 0 else tensor[0].clone() + return output + + @property + def reverse_op(self) -> ConversionOps: + return deepcopy(self) + + @dataclass(slots=True) class WeightTransform: source_patterns: str | list[str] = field(init=True) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 188147a88ed2..a473be0ce369 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -270,15 +270,6 @@ def __init__(self, config, *args, **kwargs): def forward(self, *args, **kwargs): experts_forward = original_forward - if self.config._experts_implementation == "grouped_mm": - if self.gate_up_proj.data_ptr() % 16 != 0 or self.down_proj.data_ptr() % 16 != 0: - logger.warning( - "'grouped_mm' experts implementation requires 16-byte aligned expert weights. " - "We will fall back to 'eager' implementation to avoid potential crashes. " - "Please re-initialize the expert weights with 16-byte alignment." - ) - self.config._experts_implementation = "eager" - if self.config._experts_implementation != "eager": experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 772fd7f03013..9c12f5e5f7c2 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -120,6 +120,10 @@ def test_flex_attention_with_grads(self): def test_generate_compile_model_forward_fullgraph(self): return super().test_generate_compile_model_forward_fullgraph() + @unittest.skip("GptOss does not rename any weights in its conversion mapping") + def test_reverse_loading_mapping(self): + pass + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" From dd377e19f9fba4e5c87f7ed6775edba07993758b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 15 Jan 2026 12:16:16 +0100 Subject: [PATCH 13/28] simplify logic --- src/transformers/integrations/moe.py | 176 ++++++++++++------ .../models/gpt_oss/modeling_gpt_oss.py | 2 +- .../models/gpt_oss/modular_gpt_oss.py | 2 +- tests/test_modeling_common.py | 5 + 4 files changed, 125 insertions(+), 60 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index a473be0ce369..c719e8690664 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -64,6 +64,43 @@ # return final_hidden_states +def _batched_linear( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + is_transposed: bool = False, +) -> torch.Tensor: + """Batched linear layer supporting optional bias and transposed weights. + + Args: + input (`torch.Tensor`): + Input tensor of shape (batch_size, input_dim). + weight (`torch.Tensor`): + Weight tensor of shape (batch_size, output_dim, input_dim) if transposed is `False`, + else of shape (batch_size, input_dim, output_dim). + bias (`torch.Tensor`, *optional*): + Bias tensor of shape (batch_size, output_dim). Default is `None`. + is_transposed (`bool`, *optional*, defaults to `False`): + Whether the weight tensor is transposed. + Returns: + `torch.Tensor`: Output tensor of shape (batch_size, output_dim). + """ + if bias is not None: + if is_transposed: + # (batch_size, 1, output_dim) + (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim) + return torch.baddbmm(bias.unsqueeze(1), input.unsqueeze(1), weight).squeeze(1) + else: + # (batch_size, output_dim, 1) + (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim) + return torch.baddbmm(bias.unsqueeze(-1), weight, input.unsqueeze(-1)).squeeze(-1) + else: + if is_transposed: + # (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim) + return torch.bmm(input.unsqueeze(1), weight).squeeze(1) + else: + # (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim) + return torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) + + def batched_mm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -94,38 +131,26 @@ def batched_mm_experts_forward( ) # Get current hidden states for selected samples - current_hidden_states = hidden_states[token_idx] # (S, hidden_dim) + selected_hidden_states = hidden_states[token_idx] - # Select projection matrices for selected experts - selected_gate_up = self.gate_up_proj[expert_ids] # (S, hidden_dim, 2 * intermediate_dim) - selected_down = self.down_proj[expert_ids] # (S, hidden_dim, intermediate_dim) + # Select expert weights and biases for selected samples + selected_gate_up = self.gate_up_proj[expert_ids] + selected_down = self.down_proj[expert_ids] + selected_gate_up_bias = self.gate_up_proj_bias[expert_ids] if self.has_bias else None + selected_down_bias = self.down_proj_bias[expert_ids] if self.has_bias else None # --- Up projection per expert (batched) --- - if getattr(self, "transposed_weights", False): - gate_up_out = torch.bmm(current_hidden_states.unsqueeze(1), selected_gate_up).squeeze(1) - else: - gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1) - - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids] + gate_up_out = _batched_linear( + selected_hidden_states, selected_gate_up, bias=selected_gate_up_bias, is_transposed=self.is_transposed + ) # (S, 2 * intermediate_dim) # Apply gating - if hasattr(self, "_apply_gate"): - # Custom gating if defined - hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) - else: - # Default gating mechanism - gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) - hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) + hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) # --- Down projection per expert (batched) --- - if getattr(self, "transposed_weights", False): - out_per_sample = torch.bmm(hidden_after_activation.unsqueeze(1), selected_down).squeeze(1) - else: - out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) - - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - out_per_sample = out_per_sample + self.down_proj_bias[expert_ids] + out_per_sample = _batched_linear( + hidden_after_activation, selected_down, bias=selected_down_bias, is_transposed=self.is_transposed + ) # (S, hidden_dim) # Apply routing weights out_per_sample = out_per_sample * sample_weights.unsqueeze(-1) # (S, hidden_dim) @@ -136,6 +161,44 @@ def batched_mm_experts_forward( return final_hidden_states +def _grouped_linear( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + offs: torch.Tensor | None = None, + is_transposed: bool = False, +) -> torch.Tensor: + """Grouped linear layer supporting optional bias and transposed weights. + + Args: + input (`torch.Tensor`): + Input tensor of shape (S, input_dim). + weight (`torch.Tensor`): + Weight tensor of shape (num_experts, output_dim, input_dim) if transposed is `False`, + else of shape (num_experts, input_dim, output_dim). + bias (`torch.Tensor`, *optional*): + Bias tensor of shape (num_experts, output_dim). Default is `None`. + offs (`torch.Tensor`, *optional*): + Offsets tensor indicating the boundaries of each group in the input tensor. + is_transposed (`bool`, *optional*, defaults to `False`): + Whether the weight tensor is transposed. + Returns: + `torch.Tensor`: Output tensor of shape (S, output_dim). + """ + if is_transposed: + # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim) + out = torch._grouped_mm(input, weight, offs=offs) + else: + # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim) + out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs) + + if bias is not None: + # We should be able to pass bias to the grouped_mm call, but it's not yet supported. + out = out + bias + + return out + + def grouped_mm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -157,10 +220,6 @@ def grouped_mm_experts_forward( expert_ids = top_k_index.reshape(-1) token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) - # Get permutation to group by expert - perm = torch.argsort(expert_ids, stable=True) - inv_perm = torch.argsort(perm, stable=True) - # Resolve routing weights per selected sample, allowing top_k_weights to be either: # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style @@ -175,47 +234,37 @@ def grouped_mm_experts_forward( ) # Get current hidden states for selected samples - current_hidden_states = hidden_states[token_idx] # (S, hidden_dim) + current_hidden_states = hidden_states[token_idx] + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids, stable=True) + inv_perm = torch.argsort(perm, stable=True) - # Group by expert for grouped_mm + # Group by expert expert_ids_g = expert_ids[perm] sample_weights_g = sample_weights[perm] current_states_g = current_hidden_states[perm] + selected_gate_up_bias = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None + selected_down_bias = self.down_proj_bias[expert_ids_g] if self.has_bias else None # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues # (grouped_mm_experts_forward still fails with cuda graphs but because of _grouped_mm internals) num_tokens_per_expert = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1) - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + offs = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # --- Up projection per expert (grouped_mm) --- - if getattr(self, "transposed_weights", False): - gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj, offs=offsets) - else: - gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets) - - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - # we should be able to pass bias to the grouped_mm call, but it's still not fully supported - gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids_g] + gate_up_out = _grouped_linear( + current_states_g, self.gate_up_proj, bias=selected_gate_up_bias, is_transposed=self.is_transposed, offs=offs + ) # (S, 2 * intermediate_dim) # Apply gating - if hasattr(self, "_apply_gate"): - # Custom gating if defined - hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) - else: - # Default gating mechanism - gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) - hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) + hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) # --- Down projection per expert (grouped_mm) --- - if getattr(self, "transposed_weights", False): - out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj, offs=offsets) - else: - out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) - - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - # we should be able to pass bias to the grouped_mm call, but it's still not fully supported - out_per_sample_g = out_per_sample_g + self.down_proj_bias[expert_ids_g] + out_per_sample_g = _grouped_linear( + hidden_after_activation, self.down_proj, bias=selected_down_bias, is_transposed=self.is_transposed, offs=offs + ) # (S, hidden_dim) # Apply routing weights out_per_sample_g = out_per_sample_g * sample_weights_g.unsqueeze(-1) @@ -242,15 +291,17 @@ class ExpertsInterface(GeneralInterface): def use_experts_implementation( - experts_class: type[torch.nn.Module] | None = None, *, transposed_weights: bool = False + experts_class: type[torch.nn.Module] | None = None, *, is_transposed: bool = False, has_bias: bool = False ) -> type[torch.nn.Module]: """Decorator to modify experts class to support different experts implementations. Args: experts_class (`type[torch.nn.Module]`, *optional*): The experts class to modify. If not provided, returns a decorator that can be applied to the class. - transposed_weights (`bool`, *optional*, defaults to `False`): + is_transposed (`bool`, *optional*, defaults to `False`): Whether the expert weights are stored in transposed format. + has_bias (`bool`, *optional*, defaults to `False`): + Whether the expert layers include bias terms. Returns: `type[torch.nn.Module]`: The modified experts class. @@ -263,8 +314,9 @@ def wrapper(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]: @wraps(original_init) def __init__(self, config, *args, **kwargs): original_init(self, config, *args, **kwargs) - self.transposed_weights = transposed_weights self.config = config + self.has_bias = has_bias + self.is_transposed = is_transposed @wraps(original_forward) def forward(self, *args, **kwargs): @@ -275,6 +327,14 @@ def forward(self, *args, **kwargs): return experts_forward(self, *args, **kwargs) + if not hasattr(experts_class, "_apply_gate"): + + def _apply_gate(self, gate_up_out: torch.Tensor) -> torch.Tensor: + gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) + return self.act_fn(gate) * up # (S, intermediate_dim) + + experts_class._apply_gate = _apply_gate + experts_class.__init__ = __init__ experts_class.forward = forward return experts_class diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index bc21b77ed791..1feeef2c9334 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -64,7 +64,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_experts_implementation(transposed_weights=True) +@use_experts_implementation(is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 8b547a35dc7b..eb39af4a09c2 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -61,7 +61,7 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama -@use_experts_implementation(transposed_weights=True) +@use_experts_implementation(is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3b7f1e6d22cc..dc4138fdaad8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2690,6 +2690,11 @@ def test_multi_gpu_data_parallel_forward(self): model.to(0) model.eval() + if model.config._experts_implementation == "grouped_mm": + # DataParallel does not respect buffer alignment when replicating the model on + # multiple GPUs, which can cause errors in grouped_mm experts implementation. + model.set_experts_implementation("eager") + # Wrap model in nn.DataParallel model = nn.DataParallel(model) torch.cuda.synchronize() # otherwise the transfer might not be complete From 52e077869c9f1e6eaaeb57e8ffdc5fcf99ff753e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 15 Jan 2026 12:37:19 +0100 Subject: [PATCH 14/28] quantization conversions should be applied first --- src/transformers/conversion_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 7acb9f0707dd..c52e7a4008b2 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -345,6 +345,6 @@ def get_model_conversion_mapping( # Add the ones from the quantizer as well if provided if hf_quantizer is not None: - weight_conversions.extend(hf_quantizer.get_weight_conversions()) + weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions return weight_conversions From 1c491124719e9c380420e458f137d8ba78f98f40 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 15 Jan 2026 12:59:49 +0100 Subject: [PATCH 15/28] avoid baddbmm as it is less performant / less optimizable by max-autotune --- src/transformers/integrations/moe.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index c719e8690664..8b23c6bdf50c 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -85,20 +85,17 @@ def _batched_linear( Returns: `torch.Tensor`: Output tensor of shape (batch_size, output_dim). """ - if bias is not None: - if is_transposed: - # (batch_size, 1, output_dim) + (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim) - return torch.baddbmm(bias.unsqueeze(1), input.unsqueeze(1), weight).squeeze(1) - else: - # (batch_size, output_dim, 1) + (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim) - return torch.baddbmm(bias.unsqueeze(-1), weight, input.unsqueeze(-1)).squeeze(-1) + if is_transposed: + # (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim) + out = torch.bmm(input.unsqueeze(1), weight).squeeze(1) else: - if is_transposed: - # (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim) - return torch.bmm(input.unsqueeze(1), weight).squeeze(1) - else: - # (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim) - return torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) + # (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim) + out = torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) + + if bias is not None: + out = out + bias + + return out def batched_mm_experts_forward( From 4b0323ce4e81948b1f9f11bb60a7300aa4aacc8c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 15 Jan 2026 13:00:21 +0100 Subject: [PATCH 16/28] no need for logger --- src/transformers/integrations/moe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 8b23c6bdf50c..5cf750915463 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -16,14 +16,11 @@ from ..utils.generic import GeneralInterface from ..utils.import_utils import is_torch_available -from ..utils.logging import get_logger if is_torch_available(): import torch -logger = get_logger(__name__) - # Examples of experts class with its eager mm implementation # class Experts(nn.Module): # """Collection of expert weights stored as 3D tensors.""" From f094c319321827e81290a0878806a6d0f1a0eba4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 16 Jan 2026 09:04:58 +0100 Subject: [PATCH 17/28] add comment explaining limitation --- src/transformers/conversion_mapping.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index ac58f9f89730..5ff1d6bc8dd2 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -346,6 +346,10 @@ def get_model_conversion_mapping( # Add the ones from the quantizer as well if provided if hf_quantizer is not None: + # NOTE: Since get_weight_conversions() only serves to dequantize, we need to put them first in the list. + # However, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequantization converter + # and 1 model-specific converter. Which means that if a model that has model-specific conversions and is being + # dequantized, the model-specific conversion that has patterns matching the dequantization patterns will be ignored. weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions return weight_conversions From 221f9bdae0bbb386af55cf9b0bf569446532d3d1 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 16 Jan 2026 11:37:37 +0100 Subject: [PATCH 18/28] standarize operations and only reshape when needed --- src/transformers/integrations/moe.py | 55 ++++++++++++++++------------ 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 5cf750915463..53236c14e45f 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -108,16 +108,18 @@ def batched_mm_experts_forward( final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample - expert_ids = top_k_index.reshape(-1) + expert_ids = top_k_index.view(-1) token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # Resolve routing weights per selected sample, allowing top_k_weights to be either: # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style if top_k_weights.shape == (num_tokens, num_top_k): - sample_weights = top_k_weights.reshape(-1) # (S,) + sample_weights = top_k_weights.view(-1, 1) # (S, 1) elif top_k_weights.shape == (num_tokens, num_experts): - sample_weights = top_k_weights[token_idx, expert_ids] # (S,) + # TODO: routers that output full expert distribution + # should probably be corrected to output only top_k weights + sample_weights = top_k_weights[token_idx, expert_ids].view(-1, 1) # (S, 1) else: raise ValueError( f"top_k_weights has an invalid/unsupported shape. It should be either (num_tokens, num_top_k)({num_tokens}, {num_top_k}) " @@ -135,19 +137,19 @@ def batched_mm_experts_forward( # --- Up projection per expert (batched) --- gate_up_out = _batched_linear( - selected_hidden_states, selected_gate_up, bias=selected_gate_up_bias, is_transposed=self.is_transposed + selected_hidden_states, selected_gate_up, selected_gate_up_bias, is_transposed=self.is_transposed ) # (S, 2 * intermediate_dim) # Apply gating - hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) + gated_out = self._apply_gate(gate_up_out) # (S, intermediate_dim) # --- Down projection per expert (batched) --- out_per_sample = _batched_linear( - hidden_after_activation, selected_down, bias=selected_down_bias, is_transposed=self.is_transposed + gated_out, selected_down, selected_down_bias, is_transposed=self.is_transposed ) # (S, hidden_dim) # Apply routing weights - out_per_sample = out_per_sample * sample_weights.unsqueeze(-1) # (S, hidden_dim) + out_per_sample = out_per_sample * sample_weights # (S, hidden_dim) # Accumulate results back to the final_hidden_states using original token indices final_hidden_states.index_add_(0, token_idx, out_per_sample.to(final_hidden_states.dtype)) @@ -211,16 +213,18 @@ def grouped_mm_experts_forward( final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample - expert_ids = top_k_index.reshape(-1) + expert_ids = top_k_index.view(-1) token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # Resolve routing weights per selected sample, allowing top_k_weights to be either: # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style if top_k_weights.shape == (num_tokens, num_top_k): - sample_weights = top_k_weights.reshape(-1) # (S,) + sample_weights = top_k_weights.view(-1, 1) # (S, 1) elif top_k_weights.shape == (num_tokens, num_experts): - sample_weights = top_k_weights[token_idx, expert_ids] # (S,) + # TODO: routers that output full expert distribution + # should probably be corrected to output only top_k weights + sample_weights = top_k_weights[token_idx, expert_ids].view(-1, 1) # (S, 1) else: raise ValueError( f"top_k_weights has an invalid/unsupported shape. It should be either (num_tokens, num_top_k)({num_tokens}, {num_top_k}) " @@ -228,40 +232,45 @@ def grouped_mm_experts_forward( ) # Get current hidden states for selected samples - current_hidden_states = hidden_states[token_idx] + selected_hidden_states = hidden_states[token_idx] # Sort by expert for grouped processing - perm = torch.argsort(expert_ids, stable=True) - inv_perm = torch.argsort(perm, stable=True) - - # Group by expert + perm = torch.argsort(expert_ids) + inv_perm = torch.argsort(perm) expert_ids_g = expert_ids[perm] sample_weights_g = sample_weights[perm] - current_states_g = current_hidden_states[perm] + selected_hidden_states_g = selected_hidden_states[perm] + + # Select expert weights and biases for selected samples + # NOTE: We keep all experts here and rely on offsets to target the active ones. + # I have already implemented a version that only passes the active experts, but + # to do so I had to use torch.unique which breaks the graph capture (data-dependent). + # Also there were no speedup gains from it in my experiments, even in eager mode. + selected_gate_up = self.gate_up_proj + selected_down = self.down_proj selected_gate_up_bias = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None selected_down_bias = self.down_proj_bias[expert_ids_g] if self.has_bias else None # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues - # (grouped_mm_experts_forward still fails with cuda graphs but because of _grouped_mm internals) num_tokens_per_expert = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1) offs = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # --- Up projection per expert (grouped_mm) --- + # --- Up projection per expert (grouped) --- gate_up_out = _grouped_linear( - current_states_g, self.gate_up_proj, bias=selected_gate_up_bias, is_transposed=self.is_transposed, offs=offs + selected_hidden_states_g, selected_gate_up, selected_gate_up_bias, offs, is_transposed=self.is_transposed ) # (S, 2 * intermediate_dim) # Apply gating - hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim) + gated_out = self._apply_gate(gate_up_out) # (S, intermediate_dim) - # --- Down projection per expert (grouped_mm) --- + # --- Down projection per expert (grouped) --- out_per_sample_g = _grouped_linear( - hidden_after_activation, self.down_proj, bias=selected_down_bias, is_transposed=self.is_transposed, offs=offs + gated_out, selected_down, selected_down_bias, offs, is_transposed=self.is_transposed ) # (S, hidden_dim) # Apply routing weights - out_per_sample_g = out_per_sample_g * sample_weights_g.unsqueeze(-1) + out_per_sample_g = out_per_sample_g * sample_weights_g # Restore original order out_per_sample = out_per_sample_g[inv_perm] From 1fc01dc3dbdb539920eb9ca3cb5589ddfd8fbb1f Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 16 Jan 2026 15:18:24 +0100 Subject: [PATCH 19/28] fixup conversion and test --- src/transformers/core_model_loading.py | 41 ++++++++++--------- tests/models/gpt_oss/test_modeling_gpt_oss.py | 5 +-- tests/test_modeling_common.py | 12 ++++-- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 6a6f1cc90c87..fbb399512939 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -447,30 +447,31 @@ class Force16BytesAlignment(ConversionOps): @torch.no_grad() def convert( - self, - input_dict: dict[str, list[torch.Tensor]], - source_patterns: list[str], - target_patterns: list[str], - config, - **kwargs, - ) -> dict[str, list[torch.Tensor]]: - if len(input_dict) != len(target_patterns): - raise ValueError( - f"Force16BytesAlignment conversion can only happen on each key ({len(input_dict)}) " - f"and should match exact one target ({len(target_patterns)})." - ) + self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs + ) -> dict[str, torch.Tensor]: + target_pattern = self.get_target_pattern(input_dict, source_patterns, target_patterns) + tensors = next(iter(input_dict.values())) + tensor = tensors[0] if isinstance(tensors, list) else tensors + return {target_pattern: tensor.clone()} - output: dict[str, list[torch.Tensor]] = {} - for key, target_pattern in zip(input_dict.keys(), target_patterns): - tensor = input_dict.get(key, []) - if len(tensor) != 1: - raise ValueError(f"Force16BytesAlignment conversion requires exactly one tensor, found {len(tensor)}.") - output[target_pattern] = tensor[0].clone() if tensor[0].data_ptr() % 16 == 0 else tensor[0].clone() - return output + def get_target_pattern( + self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str] + ) -> str: + if len(input_dict) != 1: + raise ValueError("Undefined Operation encountered!") + # Here it's the first operation of a chain, so return the source + if len(target_patterns) > 1: + if len(source_patterns) == 1: + return source_patterns[0] + else: + raise ValueError("Undefined Operation encountered!") + # Here it's the only operation, or the last operation in a chain, so we return the target + else: + return target_patterns[0] @property def reverse_op(self) -> ConversionOps: - return deepcopy(self) + return Force16BytesAlignment() @dataclass(slots=True) diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 9c12f5e5f7c2..23c92d900e35 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -120,9 +120,8 @@ def test_flex_attention_with_grads(self): def test_generate_compile_model_forward_fullgraph(self): return super().test_generate_compile_model_forward_fullgraph() - @unittest.skip("GptOss does not rename any weights in its conversion mapping") - def test_reverse_loading_mapping(self): - pass + def test_reverse_loading_mapping(self, check_keys_were_modified=False): + super().test_reverse_loading_mapping(check_keys_were_modified) RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3cb7d3dbae3b..a4c6c3d14988 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4461,7 +4461,7 @@ def test_tp_plan_matches_params(self): len(unused_entries) == 0, f"The following entries of the TP-plan are not valid: {unused_entries}" ) - def test_reverse_loading_mapping(self): + def test_reverse_loading_mapping(self, check_keys_were_modified=True): """Make sure we can load and save correctly the models having any weight renaming mapping or weight conversion mapping. Note that this test would be better if we could start from the serialized keys, and check that the model @@ -4470,6 +4470,11 @@ def test_reverse_loading_mapping(self): reverse the conversion and then check that those converted keys match correctly the conversions. However, all the checks performed here should ensure everything is going as it should. + + Args: + check_keys_were_modified (`bool`, *optional*, defaults to `True`): + Whether to expect keys being modified or not. In some cases, models do not change keys but + their weights, e.g. via transpose, memory alignment, etc. """ config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -4500,8 +4505,9 @@ def test_reverse_loading_mapping(self): # Get all the serialized keys that we just saved according to the reverse mapping serialized_keys = list(state_dict.keys()) - # They should be different, otherwise we did not perform any mapping - self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") + if check_keys_were_modified: + # They should be different, otherwise we did not perform any mapping + self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") # Check that for each conversion entry, we at least map to one key for conversion in conversions: From d820713890348a864f94f5f7392097095958a01a Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Fri, 16 Jan 2026 15:39:47 +0100 Subject: [PATCH 20/28] Update src/transformers/conversion_mapping.py Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> --- src/transformers/conversion_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 2e19769f266e..551d7c4afdf6 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -362,7 +362,7 @@ def get_model_conversion_mapping( if hf_quantizer is not None: # NOTE: Since get_weight_conversions() only serves to dequantize, we need to put them first in the list. # However, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequantization converter - # and 1 model-specific converter. Which means that if a model that has model-specific conversions and is being + # and 1 model-specific converter). Which means that if a model that has model-specific conversions and is being # dequantized, the model-specific conversion that has patterns matching the dequantization patterns will be ignored. weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions From 71fdb18c1817176cec73d5a26553f2522587704e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 16 Jan 2026 15:49:36 +0100 Subject: [PATCH 21/28] force alignment docstring --- src/transformers/core_model_loading.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index fbb399512939..e04244603286 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -442,7 +442,8 @@ def reverse_op(self) -> ConversionOps: class Force16BytesAlignment(ConversionOps): """ - Ensures that the given tensor is 16-bytes aligned in memory. + Ensures that the given tensor is 16-bytes aligned in memory and clones it if not. + This garantees 16-bytes alignmenet for kernels / implementations that use TMA or SIMD instructions like torch._grouped_mm. """ @torch.no_grad() @@ -452,7 +453,8 @@ def convert( target_pattern = self.get_target_pattern(input_dict, source_patterns, target_patterns) tensors = next(iter(input_dict.values())) tensor = tensors[0] if isinstance(tensors, list) else tensors - return {target_pattern: tensor.clone()} + tensor = tensor.clone() if tensor.data_ptr() % 16 != 0 else tensor + return {target_pattern: tensor} def get_target_pattern( self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str] From e852cbb02791055a2f3b5c3bfdbd2b7fe670d053 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 16 Jan 2026 15:51:28 +0100 Subject: [PATCH 22/28] move default apply gate --- src/transformers/integrations/moe.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 53236c14e45f..a49b3203056c 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -293,6 +293,20 @@ class ExpertsInterface(GeneralInterface): ALL_EXPERTS_FUNCTIONS = ExpertsInterface() +def _default_apply_gate(self, gate_up_out: torch.Tensor) -> torch.Tensor: + """ + Default gating mechanism: splits the gate_up_out into gate and up parts, + applies the activation function to the gate part, and multiplies it with the up part. + Args: + gate_up_out (`torch.Tensor`): + The output tensor from the gate and up projection of shape (S, 2 * intermediate_dim). + Returns: + `torch.Tensor`: The gated output tensor of shape (S, intermediate_dim). + """ + gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) + return self.act_fn(gate) * up # (S, intermediate_dim) + + def use_experts_implementation( experts_class: type[torch.nn.Module] | None = None, *, is_transposed: bool = False, has_bias: bool = False ) -> type[torch.nn.Module]: @@ -331,13 +345,7 @@ def forward(self, *args, **kwargs): return experts_forward(self, *args, **kwargs) if not hasattr(experts_class, "_apply_gate"): - - def _apply_gate(self, gate_up_out: torch.Tensor) -> torch.Tensor: - gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) - return self.act_fn(gate) * up # (S, intermediate_dim) - - experts_class._apply_gate = _apply_gate - + experts_class._apply_gate = _default_apply_gate experts_class.__init__ = __init__ experts_class.forward = forward return experts_class From d698dcb433f6a664610f014243dc6fc25c912c3e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 16 Jan 2026 15:52:16 +0100 Subject: [PATCH 23/28] offsets --- src/transformers/integrations/moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index a49b3203056c..cdd94cfdf4ee 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -254,11 +254,11 @@ def grouped_mm_experts_forward( # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues num_tokens_per_expert = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1) - offs = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # --- Up projection per expert (grouped) --- gate_up_out = _grouped_linear( - selected_hidden_states_g, selected_gate_up, selected_gate_up_bias, offs, is_transposed=self.is_transposed + selected_hidden_states_g, selected_gate_up, selected_gate_up_bias, offsets, is_transposed=self.is_transposed ) # (S, 2 * intermediate_dim) # Apply gating @@ -266,7 +266,7 @@ def grouped_mm_experts_forward( # --- Down projection per expert (grouped) --- out_per_sample_g = _grouped_linear( - gated_out, selected_down, selected_down_bias, offs, is_transposed=self.is_transposed + gated_out, selected_down, selected_down_bias, offsets, is_transposed=self.is_transposed ) # (S, hidden_dim) # Apply routing weights From d6631bbafe5e3a4eec4f891494375dca6c0c6bec Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 19 Jan 2026 08:18:43 +0100 Subject: [PATCH 24/28] add docs and make kernel_config optional --- src/transformers/modeling_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 623f20f4c8cb..bab3a2a09220 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3586,7 +3586,15 @@ def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_ca return init_contexts - def set_use_kernels(self, use_kernels, kernel_config): + def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None): + """ + Set whether or not to use the `kernels` library to kernelize some layers of the model. + Args: + use_kernels (`bool`): + Whether or not to use the `kernels` library to kernelize some layers of the model. + kernel_config (`KernelConfig`, *optional*): + The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used. + """ if use_kernels: if not is_kernels_available(): raise ValueError( From 4f7226d8fa4a8ce28ce6f2e1ec6a6825f5fb231e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 19 Jan 2026 09:42:30 +0100 Subject: [PATCH 25/28] use reshapes as they are equivalent to views when memory is contiguous --- src/transformers/integrations/moe.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index cdd94cfdf4ee..e0236bedc4b8 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -108,23 +108,24 @@ def batched_mm_experts_forward( final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample - expert_ids = top_k_index.view(-1) + expert_ids = top_k_index.reshape(-1) token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # Resolve routing weights per selected sample, allowing top_k_weights to be either: # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style if top_k_weights.shape == (num_tokens, num_top_k): - sample_weights = top_k_weights.view(-1, 1) # (S, 1) + sample_weights = top_k_weights elif top_k_weights.shape == (num_tokens, num_experts): # TODO: routers that output full expert distribution # should probably be corrected to output only top_k weights - sample_weights = top_k_weights[token_idx, expert_ids].view(-1, 1) # (S, 1) + sample_weights = top_k_weights[token_idx, expert_ids] else: raise ValueError( f"top_k_weights has an invalid/unsupported shape. It should be either (num_tokens, num_top_k)({num_tokens}, {num_top_k}) " f"or (num_tokens, num_experts)({num_tokens}, {num_experts}), but got {top_k_weights.shape}." ) + sample_weights = sample_weights.reshape(-1, 1) # (S, 1) # Get current hidden states for selected samples selected_hidden_states = hidden_states[token_idx] @@ -213,23 +214,24 @@ def grouped_mm_experts_forward( final_hidden_states = torch.zeros_like(hidden_states) # Flatten top_k_index to get expert_ids per selected sample - expert_ids = top_k_index.view(-1) + expert_ids = top_k_index.reshape(-1) token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # Resolve routing weights per selected sample, allowing top_k_weights to be either: # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style if top_k_weights.shape == (num_tokens, num_top_k): - sample_weights = top_k_weights.view(-1, 1) # (S, 1) + sample_weights = top_k_weights elif top_k_weights.shape == (num_tokens, num_experts): # TODO: routers that output full expert distribution # should probably be corrected to output only top_k weights - sample_weights = top_k_weights[token_idx, expert_ids].view(-1, 1) # (S, 1) + sample_weights = top_k_weights[token_idx, expert_ids] else: raise ValueError( f"top_k_weights has an invalid/unsupported shape. It should be either (num_tokens, num_top_k)({num_tokens}, {num_top_k}) " f"or (num_tokens, num_experts)({num_tokens}, {num_experts}), but got {top_k_weights.shape}." ) + sample_weights = sample_weights.reshape(-1, 1) # (S, 1) # Get current hidden states for selected samples selected_hidden_states = hidden_states[token_idx] From 21173033b2dad03beb5b34fb745854fc5b04d215 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 19 Jan 2026 12:53:58 +0100 Subject: [PATCH 26/28] fix and better notes --- src/transformers/conversion_mapping.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 551d7c4afdf6..362841dc00b0 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -42,13 +42,18 @@ def _build_checkpoint_conversion_mapping(): mapping = { "gpt_oss": [ + # NOTE: These converters are only applied if the model is being loaded from pre-dequantized checkpoint. + # If you are dequantizing the model on the fly, these converters will be ignored because the tensors + # that match these patterns are only created after dequantization. + # That's not an issue for now since the dequantization converters already ensure 16 bytes alignment + # by enforcing contiguity. WeightConverter( - source_patterns="mlp.experts.gate_up_proj", + source_patterns="mlp.experts.gate_up_proj$", target_patterns="mlp.experts.gate_up_proj", operations=[Force16BytesAlignment()], ), WeightConverter( - source_patterns="mlp.experts.down_proj", + source_patterns="mlp.experts.down_proj$", target_patterns="mlp.experts.down_proj", operations=[Force16BytesAlignment()], ), @@ -360,10 +365,13 @@ def get_model_conversion_mapping( # Add the ones from the quantizer as well if provided if hf_quantizer is not None: - # NOTE: Since get_weight_conversions() only serves to dequantize, we need to put them first in the list. - # However, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequantization converter - # and 1 model-specific converter). Which means that if a model that has model-specific conversions and is being - # dequantized, the model-specific conversion that has patterns matching the dequantization patterns will be ignored. - weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions + # NOTE: Since get_weight_conversions() only serve to dequantize, we would normally want to apply them first. + # However, for now it's not possible to cascade converters (i.e., applying model-specific conversions on top + # of tensors created by the dequantization conversions) + # This means that if a model has model-specific conversions and is being dequantized, the model-specific conversion + # that relies on tensors created by dequantization conversions will not be applied. + # GptOss example: with Mxfp4Config(dequantize=True), Force16BytesAlignment converters are ignored because the tensors + # "mlp.experts.gate_up_proj$" and "mlp.experts.down_proj$" are only created after dequantization conversions are applied. + weight_conversions.extend(hf_quantizer.get_weight_conversions()) return weight_conversions From 944a0ecaac01a2713121a702e9ac8e2782415ac9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 19 Jan 2026 19:09:57 +0100 Subject: [PATCH 27/28] reshapes instead of views --- src/transformers/models/gpt_oss/modeling_gpt_oss.py | 7 +++---- src/transformers/models/gpt_oss/modular_gpt_oss.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 1feeef2c9334..56e894119b33 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -125,8 +125,7 @@ def __init__(self, config): def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = router_top_value + router_scores = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_logits, router_scores, router_indices @@ -139,10 +138,10 @@ def __init__(self, config): def forward(self, hidden_states): batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states = hidden_states.reshape(-1, hidden_dim) _, router_scores, router_indices = self.router(hidden_states) hidden_states = self.experts(hidden_states, router_indices, router_scores) - hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) + hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states, router_scores diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index eb39af4a09c2..267d42bd9e20 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -122,8 +122,7 @@ def __init__(self, config): def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = router_top_value + router_scores = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_logits, router_scores, router_indices @@ -136,10 +135,10 @@ def __init__(self, config): def forward(self, hidden_states): batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states = hidden_states.reshape(-1, hidden_dim) _, router_scores, router_indices = self.router(hidden_states) hidden_states = self.experts(hidden_states, router_indices, router_scores) - hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim) + hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states, router_scores From 16e653660b97ea30f6fe5486126926badae35870 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 20 Jan 2026 15:37:17 +0100 Subject: [PATCH 28/28] keep model saving and reloading in grouped_mm test to catch misalignment issues --- tests/test_modeling_common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a4c6c3d14988..0ea261430b0d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -562,6 +562,11 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): model = model_class(config).eval().to(torch_device).to(dtype) set_model_for_less_flaky_test(model) + # Reload to find any buffer misalignments after saving/loading + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname).eval().to(torch_device).to(dtype) + with torch.no_grad(): inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} prepared_inputs = self._prepare_for_class(inputs_dict, model_class)