Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2aff4a8
experts impl gpt oss
IlyasMoutawwakil Jan 12, 2026
9958efb
no need to transpose dequantized experts
IlyasMoutawwakil Jan 12, 2026
b23e1ff
skip test_reverse_loading_mapping
IlyasMoutawwakil Jan 12, 2026
e28f155
fix custom gating
IlyasMoutawwakil Jan 13, 2026
e57d0a8
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 13, 2026
be08fe4
revert transposition and simply support transposed experts to avoid m…
IlyasMoutawwakil Jan 13, 2026
e1dba4d
style
IlyasMoutawwakil Jan 13, 2026
0261a46
don't rely on weight shapes as they can be square matrices
IlyasMoutawwakil Jan 13, 2026
5bd25c7
no need to relaod
IlyasMoutawwakil Jan 14, 2026
846adca
fallback to eager
IlyasMoutawwakil Jan 14, 2026
b1a71a7
Update src/transformers/models/gpt_oss/modeling_gpt_oss.py
IlyasMoutawwakil Jan 14, 2026
9dbed89
fix
IlyasMoutawwakil Jan 15, 2026
2f3fd11
force 16 bytes alignmenet during weight loading
IlyasMoutawwakil Jan 15, 2026
dd377e1
simplify logic
IlyasMoutawwakil Jan 15, 2026
52e0778
quantization conversions should be applied first
IlyasMoutawwakil Jan 15, 2026
1c49112
avoid baddbmm as it is less performant / less optimizable by max-auto…
IlyasMoutawwakil Jan 15, 2026
4b0323c
no need for logger
IlyasMoutawwakil Jan 15, 2026
aa34996
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 15, 2026
f094c31
add comment explaining limitation
IlyasMoutawwakil Jan 16, 2026
221f9bd
standarize operations and only reshape when needed
IlyasMoutawwakil Jan 16, 2026
944afb5
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 16, 2026
1fc01dc
fixup conversion and test
vasqu Jan 16, 2026
d820713
Update src/transformers/conversion_mapping.py
IlyasMoutawwakil Jan 16, 2026
71fdb18
force alignment docstring
IlyasMoutawwakil Jan 16, 2026
e852cbb
move default apply gate
IlyasMoutawwakil Jan 16, 2026
d698dcb
offsets
IlyasMoutawwakil Jan 16, 2026
5c2ca3c
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 18, 2026
d6631bb
add docs and make kernel_config optional
IlyasMoutawwakil Jan 19, 2026
4f7226d
use reshapes as they are equivalent to views when memory is contiguous
IlyasMoutawwakil Jan 19, 2026
2117303
fix and better notes
IlyasMoutawwakil Jan 19, 2026
944a0ec
reshapes instead of views
IlyasMoutawwakil Jan 19, 2026
1a0ea12
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 20, 2026
16e6536
keep model saving and reloading in grouped_mm test to catch misalignm…
IlyasMoutawwakil Jan 20, 2026
75ab275
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 21, 2026
711a652
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Chunk,
Concatenate,
ErnieFuseAndSplitTextVisionExperts,
Force16BytesAlignment,
MergeModulelist,
Transpose,
WeightConverter,
Expand All @@ -40,6 +41,18 @@

def _build_checkpoint_conversion_mapping():
mapping = {
"gpt_oss": [
WeightConverter(
source_patterns="mlp.experts.gate_up_proj",
target_patterns="mlp.experts.gate_up_proj",
operations=[Force16BytesAlignment()],
),
WeightConverter(
source_patterns="mlp.experts.down_proj",
target_patterns="mlp.experts.down_proj",
operations=[Force16BytesAlignment()],
),
],
"mixtral": [
WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
WeightConverter(
Expand Down Expand Up @@ -332,6 +345,6 @@ def get_model_conversion_mapping(

# Add the ones from the quantizer as well if provided
if hf_quantizer is not None:
weight_conversions.extend(hf_quantizer.get_weight_conversions())
weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the order of operations matters when both quantization conversions and weight conversions are defined. E.g., with gpt_oss, if I want to apply the Force16BytesAlignment conversion, it has to be placed after the dequantization; otherwise the loader is confused (see output below). @vasqu @ArthurZucker

device = torch.device("cuda:1")
model_id = "openai/gpt-oss-20b"
model = GptOssForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    dtype=torch.bfloat16,
    quantization_config=Mxfp4Config(dequantize=True),
).eval()

results in:

Loading weights: 100%|█████████████████████████████████████████████| 363/363 [00:00<00:00, 1003.90it/s, Materializing param=model.norm.weight]
GptOssForCausalLM LOAD REPORT from: openai/gpt-oss-20b
Key                                                   | Status     | 
------------------------------------------------------+------------+-
model.layers.{0...23}.mlp.experts.down_proj_scales    | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.down_proj_blocks    | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj_scales | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj_blocks | UNEXPECTED | 
model.layers.{0...23}.mlp.experts.gate_up_proj        | MISSING    | 
model.layers.{0...23}.mlp.experts.down_proj           | MISSING    | 

Notes:
- UNEXPECTED    :can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING       :those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Since get_weight_conversions is only to DEQUANTIZE it makes sense that it's first. However, as discussed offline, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequant converter, and 1 from the hardcoded mapping). So it means that any model with a mapping registered cannot be dequantized 😭 So in theory you're right it should come first, but as it does not work anyway currently it does not make a difference 😭 Let's still keep this change, adding a comment on all that please!

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @Cyrilvallez added a comment mostly rewording your explanation !


return weight_conversions
33 changes: 33 additions & 0 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,39 @@ def reverse_op(self) -> ConversionOps:
return Transpose(dim0=self.dim1, dim1=self.dim0)


class Force16BytesAlignment(ConversionOps):
"""
Ensures that the given tensor is 16-bytes aligned in memory.
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated
"""
Comment on lines +444 to +447

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice 🫡


@torch.no_grad()
def convert(
    self,
    input_dict: dict[str, list[torch.Tensor]],
    source_patterns: list[str],
    target_patterns: list[str],
    config,
    **kwargs,
) -> dict[str, torch.Tensor]:
    """Map every input tensor to its target key, making sure it is 16-byte aligned.

    A tensor whose data pointer is already a multiple of 16 is passed through
    unchanged (no copy). A misaligned tensor is cloned so that it lives in a
    freshly allocated buffer.

    Args:
        input_dict: Mapping from a source key to a list holding exactly one tensor.
        source_patterns: Unused here; accepted for the common conversion signature.
        target_patterns: One target key per entry of `input_dict`, paired by position.
        config: Unused here; accepted for the common conversion signature.

    Returns:
        A dict mapping each target pattern to the (possibly cloned) tensor.

    Raises:
        ValueError: If the number of inputs differs from the number of targets,
            or if an input entry does not hold exactly one tensor.
    """
    if len(input_dict) != len(target_patterns):
        raise ValueError(
            f"Force16BytesAlignment conversion can only happen on each key ({len(input_dict)}) "
            f"and should match exact one target ({len(target_patterns)})."
        )

    output: dict[str, torch.Tensor] = {}
    for tensors, target_pattern in zip(input_dict.values(), target_patterns):
        if len(tensors) != 1:
            raise ValueError(f"Force16BytesAlignment conversion requires exactly one tensor, found {len(tensors)}.")
        tensor = tensors[0]
        # Only pay for a copy when the pointer is actually misaligned. clone() returns
        # a new allocation, which is expected to satisfy the 16-byte requirement
        # (NOTE(review): relies on the allocator's alignment guarantees).
        output[target_pattern] = tensor if tensor.data_ptr() % 16 == 0 else tensor.clone()
    return output

@property
def reverse_op(self) -> ConversionOps:
    """Return the reverse operation, which is a deep copy of this operation."""
    mirrored = deepcopy(self)
    return mirrored


@dataclass(slots=True)
class WeightTransform:
source_patterns: str | list[str] = field(init=True)
Expand Down
207 changes: 156 additions & 51 deletions src/transformers/integrations/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

from ..utils.generic import GeneralInterface
from ..utils.import_utils import is_torch_available
from ..utils.logging import get_logger


if is_torch_available():
import torch

logger = get_logger(__name__)

# Examples of experts class with its eager mm implementation
# class Experts(nn.Module):
Expand Down Expand Up @@ -62,6 +64,43 @@
# return final_hidden_states


def _batched_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
is_transposed: bool = False,
) -> torch.Tensor:
"""Batched linear layer supporting optional bias and transposed weights.

Args:
input (`torch.Tensor`):
Input tensor of shape (batch_size, input_dim).
weight (`torch.Tensor`):
Weight tensor of shape (batch_size, output_dim, input_dim) if transposed is `False`,
else of shape (batch_size, input_dim, output_dim).
bias (`torch.Tensor`, *optional*):
Bias tensor of shape (batch_size, output_dim). Default is `None`.
is_transposed (`bool`, *optional*, defaults to `False`):
Whether the weight tensor is transposed.
Returns:
`torch.Tensor`: Output tensor of shape (batch_size, output_dim).
"""
if bias is not None:
if is_transposed:
# (batch_size, 1, output_dim) + (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim)
return torch.baddbmm(bias.unsqueeze(1), input.unsqueeze(1), weight).squeeze(1)
else:
# (batch_size, output_dim, 1) + (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim)
return torch.baddbmm(bias.unsqueeze(-1), weight, input.unsqueeze(-1)).squeeze(-1)
else:
if is_transposed:
# (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim)
return torch.bmm(input.unsqueeze(1), weight).squeeze(1)
else:
# (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim)
return torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1)


def batched_mm_experts_forward(
self: torch.nn.Module,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -92,27 +131,26 @@ def batched_mm_experts_forward(
)

# Get current hidden states for selected samples
current_hidden_states = hidden_states[token_idx] # (S, hidden_dim)
selected_hidden_states = hidden_states[token_idx]

# Select projection matrices for selected experts
selected_gate_up = self.gate_up_proj[expert_ids] # (S, hidden_dim, 2 * intermediate_dim)
selected_down = self.down_proj[expert_ids] # (S, hidden_dim, intermediate_dim)
# Select expert weights and biases for selected samples
selected_gate_up = self.gate_up_proj[expert_ids]
selected_down = self.down_proj[expert_ids]
selected_gate_up_bias = self.gate_up_proj_bias[expert_ids] if self.has_bias else None
selected_down_bias = self.down_proj_bias[expert_ids] if self.has_bias else None

# --- Up projection per expert (batched) ---
gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1)
if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids]
gate_up_out = _batched_linear(
selected_hidden_states, selected_gate_up, bias=selected_gate_up_bias, is_transposed=self.is_transposed
) # (S, 2 * intermediate_dim)

# Split into gate and up components
gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim)

# Apply activation
hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim)
# Apply gating
hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim)

# --- Down projection per expert (batched) ---
out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1)
if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
out_per_sample = out_per_sample + self.down_proj_bias[expert_ids]
out_per_sample = _batched_linear(
hidden_after_activation, selected_down, bias=selected_down_bias, is_transposed=self.is_transposed
) # (S, hidden_dim)

# Apply routing weights
out_per_sample = out_per_sample * sample_weights.unsqueeze(-1) # (S, hidden_dim)
Expand All @@ -123,6 +161,44 @@ def batched_mm_experts_forward(
return final_hidden_states


def _grouped_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
offs: torch.Tensor | None = None,
is_transposed: bool = False,
) -> torch.Tensor:
"""Grouped linear layer supporting optional bias and transposed weights.

Args:
input (`torch.Tensor`):
Input tensor of shape (S, input_dim).
weight (`torch.Tensor`):
Weight tensor of shape (num_experts, output_dim, input_dim) if transposed is `False`,
else of shape (num_experts, input_dim, output_dim).
bias (`torch.Tensor`, *optional*):
Bias tensor of shape (num_experts, output_dim). Default is `None`.
offs (`torch.Tensor`, *optional*):
Offsets tensor indicating the boundaries of each group in the input tensor.
is_transposed (`bool`, *optional*, defaults to `False`):
Whether the weight tensor is transposed.
Returns:
`torch.Tensor`: Output tensor of shape (S, output_dim).
"""
if is_transposed:
# (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
out = torch._grouped_mm(input, weight, offs=offs)
else:
# (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs)

if bias is not None:
# We should be able to pass bias to the grouped_mm call, but it's not yet supported.
out = out + bias

return out


def grouped_mm_experts_forward(
self: torch.nn.Module,
hidden_states: torch.Tensor,
Expand All @@ -144,10 +220,6 @@ def grouped_mm_experts_forward(
expert_ids = top_k_index.reshape(-1)
token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1)

# Get permutation to group by expert
perm = torch.argsort(expert_ids, stable=True)
inv_perm = torch.argsort(perm, stable=True)

# Resolve routing weights per selected sample, allowing top_k_weights to be either:
# - (num_tokens, num_top_k) Qwen2MoE style
# - (num_tokens, num_experts) DeepseekV2 style
Expand All @@ -162,36 +234,37 @@ def grouped_mm_experts_forward(
)

# Get current hidden states for selected samples
current_hidden_states = hidden_states[token_idx] # (S, hidden_dim)
current_hidden_states = hidden_states[token_idx]

# Sort by expert for grouped processing
perm = torch.argsort(expert_ids, stable=True)
inv_perm = torch.argsort(perm, stable=True)

# Group by expert for grouped_mm
# Group by expert
expert_ids_g = expert_ids[perm]
sample_weights_g = sample_weights[perm]
current_states_g = current_hidden_states[perm]
selected_gate_up_bias = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None
selected_down_bias = self.down_proj_bias[expert_ids_g] if self.has_bias else None

# Compute offsets for grouped_mm
# using histc instead of bincount to avoid cuda graph issues
# (grouped_mm_experts_forward still fails with cuda graphs but because of _grouped_mm internals)
num_tokens_per_expert = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1)
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
offs = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated

# --- Up projection per expert (grouped_mm) ---
gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets)
if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
# we should be able to pass bias to the grouped_mm call, but it's still not fully supported
gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids_g]
gate_up_out = _grouped_linear(
current_states_g, self.gate_up_proj, bias=selected_gate_up_bias, is_transposed=self.is_transposed, offs=offs
) # (S, 2 * intermediate_dim)

# Split into gate and up components
gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim)

# Apply activation
hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim)
# Apply gating
hidden_after_activation = self._apply_gate(gate_up_out) # (S, intermediate_dim)

# --- Down projection per expert (grouped_mm) ---
out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets)
if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
# we should be able to pass bias to the grouped_mm call, but it's still not fully supported
out_per_sample_g = out_per_sample_g + self.down_proj_bias[expert_ids_g]
out_per_sample_g = _grouped_linear(
hidden_after_activation, self.down_proj, bias=selected_down_bias, is_transposed=self.is_transposed, offs=offs
) # (S, hidden_dim)

# Apply routing weights
out_per_sample_g = out_per_sample_g * sample_weights_g.unsqueeze(-1)
Expand All @@ -217,24 +290,56 @@ class ExpertsInterface(GeneralInterface):
ALL_EXPERTS_FUNCTIONS = ExpertsInterface()


def use_experts_implementation(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]:
original_init = experts_class.__init__
original_forward = experts_class.forward
def use_experts_implementation(
experts_class: type[torch.nn.Module] | None = None, *, is_transposed: bool = False, has_bias: bool = False
) -> type[torch.nn.Module]:
"""Decorator to modify experts class to support different experts implementations.

Args:
experts_class (`type[torch.nn.Module]`, *optional*):
The experts class to modify. If not provided, returns a decorator that can be applied to the class.
is_transposed (`bool`, *optional*, defaults to `False`):
Whether the expert weights are stored in transposed format.
has_bias (`bool`, *optional*, defaults to `False`):
Whether the expert layers include bias terms.

Returns:
`type[torch.nn.Module]`: The modified experts class.
"""

def wrapper(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]:
original_init = experts_class.__init__
original_forward = experts_class.forward

@wraps(original_init)
def __init__(self, config, *args, **kwargs):
original_init(self, config, *args, **kwargs)
self.config = config
self.has_bias = has_bias
self.is_transposed = is_transposed

@wraps(original_forward)
def forward(self, *args, **kwargs):
experts_forward = original_forward

if self.config._experts_implementation != "eager":
experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation]

return experts_forward(self, *args, **kwargs)

if not hasattr(experts_class, "_apply_gate"):

@wraps(original_init)
def __init__(self, config, *args, **kwargs):
original_init(self, config, *args, **kwargs)
self.config = config
def _apply_gate(self, gate_up_out: torch.Tensor) -> torch.Tensor:
gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim)
return self.act_fn(gate) * up # (S, intermediate_dim)
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated

@wraps(original_forward)
def forward(self, *args, **kwargs):
experts_forward = original_forward
experts_class._apply_gate = _apply_gate

if self.config._experts_implementation != "eager":
experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation]
experts_class.__init__ = __init__
experts_class.forward = forward
return experts_class

return experts_forward(self, *args, **kwargs)
if experts_class is not None:
return wrapper(experts_class)

experts_class.__init__ = __init__
experts_class.forward = forward
return experts_class
return wrapper
Loading