Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2aff4a8
experts impl gpt oss
IlyasMoutawwakil Jan 12, 2026
9958efb
no need to transpose dequantized experts
IlyasMoutawwakil Jan 12, 2026
b23e1ff
skip test_reverse_loading_mapping
IlyasMoutawwakil Jan 12, 2026
e28f155
fix custom gating
IlyasMoutawwakil Jan 13, 2026
e57d0a8
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 13, 2026
be08fe4
revert transposition and simply support transposed experts to avoid m…
IlyasMoutawwakil Jan 13, 2026
e1dba4d
style
IlyasMoutawwakil Jan 13, 2026
0261a46
don't rely on weight shapes as they can be square matrices
IlyasMoutawwakil Jan 13, 2026
5bd25c7
no need to relaod
IlyasMoutawwakil Jan 14, 2026
846adca
fallback to eager
IlyasMoutawwakil Jan 14, 2026
b1a71a7
Update src/transformers/models/gpt_oss/modeling_gpt_oss.py
IlyasMoutawwakil Jan 14, 2026
9dbed89
fix
IlyasMoutawwakil Jan 15, 2026
2f3fd11
force 16 bytes alignmenet during weight loading
IlyasMoutawwakil Jan 15, 2026
dd377e1
simplify logic
IlyasMoutawwakil Jan 15, 2026
52e0778
quantization conversions should be applied first
IlyasMoutawwakil Jan 15, 2026
1c49112
avoid baddbmm as it is less performant / less optimizable by max-auto…
IlyasMoutawwakil Jan 15, 2026
4b0323c
no need for logger
IlyasMoutawwakil Jan 15, 2026
aa34996
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 15, 2026
f094c31
add comment explaining limitation
IlyasMoutawwakil Jan 16, 2026
221f9bd
standarize operations and only reshape when needed
IlyasMoutawwakil Jan 16, 2026
944afb5
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 16, 2026
1fc01dc
fixup conversion and test
vasqu Jan 16, 2026
d820713
Update src/transformers/conversion_mapping.py
IlyasMoutawwakil Jan 16, 2026
71fdb18
force alignment docstring
IlyasMoutawwakil Jan 16, 2026
e852cbb
move default apply gate
IlyasMoutawwakil Jan 16, 2026
d698dcb
offsets
IlyasMoutawwakil Jan 16, 2026
5c2ca3c
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 18, 2026
d6631bb
add docs and make kernel_config optional
IlyasMoutawwakil Jan 19, 2026
4f7226d
use reshapes as they are equivalent to views when memory is contiguous
IlyasMoutawwakil Jan 19, 2026
2117303
fix and better notes
IlyasMoutawwakil Jan 19, 2026
944a0ec
reshapes instead of views
IlyasMoutawwakil Jan 19, 2026
1a0ea12
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 20, 2026
16e6536
keep model saving and reloading in grouped_mm test to catch misalignm…
IlyasMoutawwakil Jan 20, 2026
75ab275
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 21, 2026
711a652
Merge branch 'main' into gpt-oss-experts-impl
IlyasMoutawwakil Jan 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Chunk,
Concatenate,
ErnieFuseAndSplitTextVisionExperts,
Force16BytesAlignment,
MergeModulelist,
Transpose,
WeightConverter,
Expand All @@ -40,6 +41,18 @@

def _build_checkpoint_conversion_mapping():
mapping = {
"gpt_oss": [
WeightConverter(
source_patterns="mlp.experts.gate_up_proj",
target_patterns="mlp.experts.gate_up_proj",
operations=[Force16BytesAlignment()],
),
WeightConverter(
source_patterns="mlp.experts.down_proj",
target_patterns="mlp.experts.down_proj",
operations=[Force16BytesAlignment()],
),
],
"mixtral": [
WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
WeightConverter(
Expand Down Expand Up @@ -347,6 +360,10 @@ def get_model_conversion_mapping(

# Add the ones from the quantizer as well if provided
if hf_quantizer is not None:
weight_conversions.extend(hf_quantizer.get_weight_conversions())
# NOTE: Since get_weight_conversions() only serves to dequantize, we need to put them first in the list.
# However, for now it's not possible to match 1 param with 2 converters (i.e. 1 dequantization converter
# and 1 model-specific converter). This means that if a model has model-specific conversions and is being
# dequantized, any model-specific conversion whose patterns match the dequantization patterns will be ignored.
weight_conversions = hf_quantizer.get_weight_conversions() + weight_conversions

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure it makes a difference at all, no? Because the operations are ordered by length of collected tensors, I think.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah now that I'm digging deeper into the weight loader I think the reason I got the error above is because it's not possible to cascade converters (i.e., applying model-specific conversions on top of tensors created by the dequantization conversions). Not because you can't match one tensor with two converters (that's a valid limitation, but not the one happening here in gpt oss).

I added $ to the end of Force16BytesAlignment source pattern to fix my error without changing this order. Basically making sure that the sources of the mxfp dequant converter and Force16BytesAlignment are exclusive.
I will revert the line change and make this comment instead:

        # NOTE: Since get_weight_conversions() only serve to dequantize, we normally want to apply them first.
        # However, for now it's not possible to cascade converters (i.e., applying model-specific conversions on top
        # of tensors created by the dequantization conversions)
        # This means that if a model has model-specific conversions and is being dequantized, the model-specific conversion
        # that relies on tensors created by dequantization conversions will not be applied.
        # GptOss example: with Mxfp4Config(dequantize=True), Force16BytesAlignment converters are ignored because the tensors
        # "mlp.experts.gate_up_proj$" and "mlp.experts.down_proj$" are only created after dequantization conversions are applied.
        weight_conversions.extend(hf_quantizer.get_weight_conversions())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool, as we talked offline let's add more comments about how to handle the weight converter


return weight_conversions
36 changes: 36 additions & 0 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,42 @@ def reverse_op(self) -> ConversionOps:
return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)


class Force16BytesAlignment(ConversionOps):
"""
Ensures that the given tensor is 16-bytes aligned in memory and clones it if not.
This guarantees 16-bytes alignment for kernels / implementations that use TMA or SIMD instructions like torch._grouped_mm.
"""
Comment on lines +444 to +447

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice 🫡


@torch.no_grad()
def convert(
    self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
) -> dict[str, torch.Tensor]:
    """
    Return the single collected tensor under its resolved key, cloning it when it is not 16-bytes aligned.

    Args:
        input_dict: Mapping holding the collected value; only its first value is read. The value may be
            a tensor or a list of tensors, in which case only the first element of the list is used.
        source_patterns: Source patterns, forwarded to `get_target_pattern`.
        target_patterns: Target patterns, forwarded to `get_target_pattern`.
        **kwargs: Accepted but not used by this operation.

    Returns:
        A one-entry dict mapping the key returned by `get_target_pattern` to the tensor. The tensor is a
        clone when its data pointer is not a multiple of 16, otherwise it is the input tensor itself.
    """
    # Resolve the output key first so that any validation error surfaces before touching the tensor.
    output_key = self.get_target_pattern(input_dict, source_patterns, target_patterns)

    collected = next(iter(input_dict.values()))
    if isinstance(collected, list):
        collected = collected[0]

    # Only copy when the storage address is misaligned; an aligned tensor is passed through untouched.
    is_misaligned = collected.data_ptr() % 16 != 0
    if is_misaligned:
        collected = collected.clone()

    return {output_key: collected}

def get_target_pattern(
    self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str]
) -> str:
    """
    Pick the key under which the converted tensor should be returned.

    Args:
        input_dict: Collected inputs; must contain exactly one entry.
        source_patterns: Source patterns of the converter.
        target_patterns: Target patterns of the converter.

    Returns:
        `target_patterns[0]` when there is at most one target pattern, otherwise the single source pattern.

    Raises:
        ValueError: If `input_dict` does not hold exactly one entry, or if there are several target
            patterns but not exactly one source pattern.
    """
    if len(input_dict) != 1:
        raise ValueError("Undefined Operation encountered!")

    # Only operation, or last operation in a chain: the result is keyed by the target.
    if len(target_patterns) <= 1:
        return target_patterns[0]

    # Several targets means this is the first operation of a chain, which is keyed by the source;
    # that is only well defined when there is a single source pattern.
    if len(source_patterns) != 1:
        raise ValueError("Undefined Operation encountered!")
    return source_patterns[0]

@property
def reverse_op(self) -> ConversionOps:
    """Return the operation used for the reverse conversion: a fresh `Force16BytesAlignment`."""
    # The forward op only clones or passes the tensor through, so the reverse is the same operation.
    inverse = Force16BytesAlignment()
    return inverse


@dataclass(slots=True)
class WeightTransform:
source_patterns: str | list[str] = field(init=True)
Expand Down
Loading