From 48aeb069d41d0ed90bb2b60ea7f3f9f6cdec97c4 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 6 Feb 2026 15:10:43 +0000 Subject: [PATCH 01/26] Working GPT OSS --- unsloth_zoo/compiler.py | 4 + unsloth_zoo/temporary_patches/gpt_oss.py | 429 +++++++++++++++++++++-- 2 files changed, 412 insertions(+), 21 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b905a1e74..756135705 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -811,6 +811,10 @@ def create_new_function( imports += "import torch\n" imports += "import torch.nn as nn\n" imports += "from torch.nn import functional as F\n" + if "torch_compile" in new_source: + imports += "from unsloth_zoo.temporary_patches.common import torch_compile\n" + if "KWARGS_TYPE" in new_source: + imports += "from unsloth_zoo.temporary_patches.utils import KWARGS_TYPE\n" imports += ( "from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable\n" ) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 3a3cbcc5f..818e01aa1 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -18,6 +18,7 @@ import os import torch import torch.nn as nn +import torch.nn.init as init import torch.nn.functional as F import inspect from .common import ( @@ -44,6 +45,32 @@ from ..hf_utils import dtype_from_config torch_cuda_device = torch.cuda.device +# MXFP4 configuration +# Set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 to keep MXFP4 weights quantized (requires triton_kernels) +# Otherwise, MXFP4 weights will be dequantized to bf16 for LoRA training +UNSLOTH_MXFP4_NO_DEQUANTIZE = os.environ.get("UNSLOTH_MXFP4_NO_DEQUANTIZE", "0") == "1" + + +def _check_triton_kernels_available(): + """Check if OpenAI's triton_kernels package is available for MXFP4.""" + try: + from triton_kernels import matmul_ogs, swiglu + + return True + except ImportError: + return False + + +_TRITON_KERNELS_AVAILABLE = None + + +def is_triton_kernels_available(): + """Cached check for triton_kernels availability.""" + global _TRITON_KERNELS_AVAILABLE + if _TRITON_KERNELS_AVAILABLE is None: + _TRITON_KERNELS_AVAILABLE = _check_triton_kernels_available() + return _TRITON_KERNELS_AVAILABLE + @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_forward(a, alpha, limit, dtype = None): @@ -87,14 +114,24 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): def patch_gpt_oss(): try: import triton_kernels + + HAS_TRITON_KERNELS = True except Exception as e: - return raise_error("Please install triton_kernels", e) + HAS_TRITON_KERNELS = False + # return raise_error("Please install triton_kernels", e) try: import transformers.quantizers.quantizer_mxfp4 - def is_kernels_available(): return True - transformers.quantizers.quantizer_mxfp4.is_kernels_available = is_kernels_available - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True + + def is_kernels_available(): + return True + + transformers.quantizers.quantizer_mxfp4.is_kernels_available = ( + is_kernels_available + ) + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( + lambda *args, **kwargs: True + ) except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) @@ -106,16 +143,21 @@ def is_kernels_available(): return True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) - try: - from triton_kernels import matmul_ogs, swiglu - FnSpecs, FusedActivation, matmul_ogs = ( - matmul_ogs.FnSpecs, - matmul_ogs.FusedActivation, - matmul_ogs.matmul_ogs, - ) - swiglu_fn = swiglu.swiglu_fn - except Exception as e: - return raise_error("triton_kernels", e) + if HAS_TRITON_KERNELS: + try: + from triton_kernels import matmul_ogs, swiglu + + FnSpecs, FusedActivation, matmul_ogs = ( + matmul_ogs.FnSpecs, + matmul_ogs.FusedActivation, + matmul_ogs.matmul_ogs, + ) + swiglu_fn = swiglu.swiglu_fn + except Exception as e: + return raise_error("triton_kernels", e) + else: + # Skip MXFP4 patches when triton_kernels not available + return try: import transformers.integrations.mxfp4 @@ -205,8 +247,8 @@ def forward( @staticmethod def backward(ctx, grad_token): raise NotImplementedError( - "Backwards pass using MXFP4 is still under construction!\n"\ - "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n"\ + "Backwards pass using MXFP4 is still under construction!\n" + "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n" "Or, use `load_in_4bit = True` which allows finetuning." ) (pre_act, gamma, gather_src, gather_dst, scatter_src, scatter_dst,) = ctx.saved_tensors @@ -243,12 +285,13 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size + # MXFP4 quantized format (blocks + scales) self.gate_up_proj_blocks = nn.Parameter( torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( @@ -256,11 +299,11 @@ def __init__(self, config): ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16,), dtype=torch.uint8), requires_grad=False, ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.down_proj_bias = nn.Parameter( @@ -271,7 +314,10 @@ def __init__(self, config): self.gate_up_proj_precision_config = None self.down_proj_precision_config = None - def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: + @property + def gate_up_proj(selforward( + self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx + ) -> torch.Tensor: with torch_cuda_device(hidden_states.device): if not hasattr(self, "act"): self.act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) @@ -520,7 +566,12 @@ def __init__(self, config): self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.linear = nn.Linear(self.hidden_dim, self.num_experts, dtype=dtype_from_config(config)) + self.linear = nn.Linear( + self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) + ) + # Expose weight/bias for HF init compatibility + self.weight = self.linear.weight + self.bias = self.linear.bias @torch_compile(dynamic = True, fullgraph = True) def forward(self, hidden_states): @@ -657,6 +708,10 @@ def patch_gpt_oss_linearized(): except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + _patch_gpt_oss_init_weights_for_modulelist( + transformers.models.gpt_oss.modeling_gpt_oss + ) + # We find down_proj overflows in GPT OSS if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": def forward( @@ -736,6 +791,338 @@ def forward( TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) +def _patch_gpt_oss_init_weights_for_modulelist(transformers_module): + GptOssPreTrainedModel = transformers_module.GptOssPreTrainedModel + GptOssExperts = transformers_module.GptOssExperts + if getattr(GptOssPreTrainedModel, "_unsloth_init_weights_patched", False): + return + _original_init_weights = GptOssPreTrainedModel._init_weights + + def _patched_init_weights(self, module): + _original_init_weights(self, module) + if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): + std = self.config.initializer_range + for up in getattr(module, "gate_up_projs", []): + init.normal_(up.weight, mean=0.0, std=std) + if up.bias is not None: + init.zeros_(up.bias) + for down in getattr(module, "down_projs", []): + init.normal_(down.weight, mean=0.0, std=std) + if down.bias is not None: + init.zeros_(down.bias) + + patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) + GptOssPreTrainedModel._unsloth_init_weights_patched = True + + +def patch_gpt_oss_bf16_split_lora(): + """ + Patch GPT-OSS BF16 model to use split LoRA with grouped GEMM. + + This patch applies to BF16 models loaded via unsloth/gpt-oss-20b-BF16 or similar. + It creates stacked expert weights and uses moe_utils.py's forward_native_grouped_mm + for efficient training with split LoRA. + + Key differences from Qwen3 MoE: + 1. GPT-OSS uses TRANSPOSED weight layout: + - gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size) + - down_proj: (num_experts, intermediate_size, hidden_size) + 2. GPT-OSS gate_up has interleaved layout (::2 for gate, 1::2 for up) + 3. GPT-OSS has biases: gate_up_proj_bias, down_proj_bias + + 4-bit BNB models use gate_up_projs/down_projs ModuleList and are NOT affected. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + # Skip 4-bit models - they use ModuleList and are handled by patch_gpt_oss_linearized + if "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + from .moe_utils import ( + forward_native_grouped_mm, + select_moe_backend, + patch_param_wrapper_for_moe, + _has_lora_adapters, + _extract_lora_from_wrapper, + _should_use_separated_lora, + _is_moe_experts_module, + ) + + # Patch ParamWrapper.forward for MoE separated LoRA + patch_param_wrapper_for_moe() + + # LoRA Extractor function for GPT-OSS + # GPT-OSS weight layout is TRANSPOSED: (E, hidden, output) instead of (E, output, hidden) + def _gpt_oss_lora_extractor( + self, wrapper, weight_A, weight_B, scaling, num_experts + ): + """ + GPT-OSS LoRA extractor for transposed weight layout. + + GPT-OSS weights: + gate_up_proj: (E, H, 2*I) - transposed layout (in_dim, out_dim) + down_proj: (E, I, H) - transposed layout (in_dim, out_dim) + + For grouped_mm: X @ W where W is (E, in_dim, out_dim) + + PEFT creates: + lora_A: (E*R, in_dim) - projects input to rank space + lora_B: (out_dim, E*R) - projects rank to output + + For transposed format, the LoRA dimensions are already correct: + - We want X @ (E, in, R) @ (E, R, out) + """ + # This Unsloth Zoo code section is licensed under AGPL3 + + total_rank = weight_A.shape[0] + rank_per_expert = total_rank // num_experts + dim_A = weight_A.shape[ + 1 + ] # in_dim (hidden_dim for gate_up, intermediate for down) + dim_B = weight_B.shape[ + 0 + ] # out_dim (2*intermediate for gate_up, hidden_dim for down) + + # Get model dimensions from the experts module + hidden_dim = None + intermediate_dim = None + current = wrapper + while hasattr(current, "base_layer"): + current = current.base_layer + if hasattr(current, "hidden_size"): + hidden_dim = current.hidden_size + if hasattr(current, "intermediate_size"): + intermediate_dim = current.intermediate_size + + # Get parameter name + param_name = getattr(wrapper, "parameter_name", None) + + # GPT-OSS uses TRANSPOSED layout, so LoRA dimensions map directly: + # Input projection: X @ (E, in_dim, R) + # Output projection: result @ (E, R, out_dim) + + if ( + param_name == "down_proj" + and intermediate_dim is not None + and hidden_dim is not None + ): + # down_proj: input=intermediate_dim, output=hidden_dim + # Weight shape: (E, I, H) - transposed + # lora_A: (E*R, H) from PEFT (swapped due to 3D param handling) + # lora_B: (I, E*R) from PEFT (swapped) + # For X @ first @ second: first is (E, I, R), second is (E, R, H) + + # first_weight from B (has intermediate_dim) + first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, I, R) + + # second_weight from A (has hidden_dim) + second_weight = weight_A.view( + num_experts, rank_per_expert, dim_A + ) # (E, R, H) + + return first_weight, second_weight, scaling, num_experts + + elif param_name == "gate_up_proj" and hidden_dim is not None: + # gate_up_proj: input=hidden_dim, output=2*intermediate_dim + # Weight shape: (E, H, 2*I) - transposed + # lora_A: (E*R, 2*I) from PEFT + # lora_B: (H, E*R) from PEFT + # For X @ first @ second: first is (E, H, R), second is (E, R, 2*I) + + # first_weight from B (has hidden_dim) + first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, H, R) + + # second_weight from A (has 2*intermediate_dim) + second_weight = weight_A.view( + num_experts, rank_per_expert, dim_A + ) # (E, R, 2*I) + + return first_weight, second_weight, scaling, num_experts + + # Fallback: dimension-based detection + if hidden_dim is not None: + if dim_B == hidden_dim: + # B connects to hidden_dim (transposed case) + first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + first_weight = first_weight.permute(1, 0, 2).contiguous() + second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) + return first_weight, second_weight, scaling, num_experts + elif dim_A == hidden_dim: + # A connects to hidden_dim + first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) + first_weight = first_weight.permute(0, 2, 1).contiguous() + second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + second_weight = second_weight.permute(1, 2, 0).contiguous() + return first_weight, second_weight, scaling, num_experts + + # Final fallback + first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) + first_weight = first_weight.permute(0, 2, 1).contiguous() + second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + second_weight = second_weight.permute(1, 2, 0).contiguous() + return first_weight, second_weight, scaling, num_experts + + # Patch GptOssExperts.forward to use grouped GEMM + split LoRA (BF16 only) + GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + _original_gpt_oss_experts_forward = GptOssExperts.forward + + def _bf16_split_lora_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + # Fallback to original for 4-bit ModuleList or missing routing data + if ( + router_indices is None + or routing_weights is None + or hasattr(self, "gate_up_projs") + or not hasattr(self, "gate_up_proj") + or not hasattr(self, "down_proj") + ): + return _original_gpt_oss_experts_forward( + self, hidden_states, router_indices, routing_weights + ) + + gate_up_param = getattr(self, "gate_up_proj", None) + down_param = getattr(self, "down_proj", None) + if ( + not isinstance(gate_up_param, nn.Parameter) + or gate_up_param.ndim != 3 + or not isinstance(down_param, nn.Parameter) + or down_param.ndim != 3 + ): + return _original_gpt_oss_experts_forward( + self, hidden_states, router_indices, routing_weights + ) + + if not hasattr(self, "_unsloth_model_type"): + self._unsloth_model_type = "gpt_oss" + + return forward_native_grouped_mm( + self, + hidden_states, + router_indices, # top_k_index + routing_weights, # top_k_weights + ) + + patch_function(GptOssExperts, "forward", _bf16_split_lora_forward) + + _patch_gpt_oss_init_weights_for_modulelist( + transformers.models.gpt_oss.modeling_gpt_oss + ) + + # BF16 Experts class with stacked weights (for split LoRA via moe_utils) + class GptOssExpertsBF16Stacked(nn.Module): + """ + GPT-OSS BF16 Experts with stacked weights for grouped GEMM. + + Uses moe_utils.forward_native_grouped_mm for efficient MoE computation + with separated LoRA support. + """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.alpha = 1.702 + self.limit = getattr(config, "swiglu_limit", 7.0) + self.dtype = dtype_from_config(config) + + # Stacked weights in transposed format (E, in_dim, out_dim) for grouped_mm + self.gate_up_proj = nn.Parameter( + torch.empty( + self.num_experts, + self.hidden_size, + 2 * self.intermediate_size, + dtype=self.dtype, + ) + ) + self.gate_up_proj_bias = nn.Parameter( + torch.empty( + self.num_experts, 2 * self.intermediate_size, dtype=self.dtype + ) + ) + self.down_proj = nn.Parameter( + torch.empty( + self.num_experts, + self.intermediate_size, + self.hidden_size, + dtype=self.dtype, + ) + ) + self.down_proj_bias = nn.Parameter( + torch.empty(self.num_experts, self.hidden_size, dtype=self.dtype) + ) + + # Register LoRA extractor + self._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor + + @property + def hidden_dim(self): + return self.hidden_size + + @property + def intermediate_dim(self): + return self.intermediate_size + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """ + Forward pass using grouped GEMM from moe_utils. + Uses forward_native_grouped_mm which handles split LoRA automatically. + """ + # Call moe_utils grouped MM forward which handles everything + return forward_native_grouped_mm( + self, + hidden_states, + router_indices, # top_k_index + routing_weights, # top_k_weights + ) + + # MLP wrapper that works with stacked experts + class GptOssMLP_BF16(nn.Module): + def __init__(self, config): + super().__init__() + self.router = GptOssTopKRouter(config) + self.experts = GptOssExpertsBF16Stacked(config) + + def forward(self, hidden_states): + bsz, qlen, hd = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hd) + + # Get router scores + router_scores, router_indices = self.router(hidden_states_flat) + + # Run experts with stacked weights + routed_out = self.experts( + hidden_states_flat, + router_indices=router_indices, + routing_weights=router_scores, + ) + + routed_out = routed_out.view(bsz, qlen, hd) + return routed_out, router_scores + + # Patch transformers module + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor + + # Check if model has stacked weights (BF16) vs ModuleList (4-bit) + # This is done at load time by checking the model structure + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched GPT-OSS for BF16 split LoRA with grouped GEMM") + + +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_bf16_split_lora) + + def patch_GptOssAttention(): if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return From 2eb8ce3c7a941a8bcc2ff36db11d00a8918c1265 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 6 Feb 2026 16:57:34 +0000 Subject: [PATCH 02/26] Working GPT OSS --- unsloth_zoo/temporary_patches/gpt_oss.py | 2066 ++++++++++++++------ unsloth_zoo/temporary_patches/moe_utils.py | 3 + 2 files changed, 1470 insertions(+), 599 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 818e01aa1..c60c64f21 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -30,6 +30,7 @@ ) from importlib.metadata import version as importlib_version from ..utils import Version + transformers_version = Version(importlib_version("transformers")) has_static_cache = transformers_version >= Version("4.56.0.dev0") from .utils import ( @@ -43,6 +44,7 @@ process_return, ) from ..hf_utils import dtype_from_config + torch_cuda_device = torch.cuda.device # MXFP4 configuration @@ -84,8 +86,11 @@ def swiglu_torch_forward(a, alpha, limit, dtype = None): out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) out = out_gelu * (a_linear + 1) return out.to(a.dtype if dtype is None else dtype) + + pass + @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_backward(pre_act, alpha, limit, g1): g, l = pre_act[..., ::2].to(torch.float32), pre_act[..., 1::2].to(torch.float32) @@ -93,21 +98,23 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): if limit is not None: mask_g = g <= limit mask_l = l.abs() <= limit - ḡ = torch.where(mask_g, g, limit) + ḡ = torch.where(mask_g, g, limit) l̄ = torch.where(mask_l, l, l.sign() * limit) - else: # no clipping + else: # no clipping mask_g = mask_l = torch.ones_like(g, dtype=bool) - ḡ, l̄ = g, l + ḡ, l̄ = g, l - σ = torch.sigmoid(alpha * ḡ) - dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) - dl = ḡ * σ - dg = torch.where(mask_g, dg, 0.) # clamp-grad - dl = torch.where(mask_l, dl, 0.) + σ = torch.sigmoid(alpha * ḡ) + dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) + dl = ḡ * σ + dg = torch.where(mask_g, dg, 0.0) # clamp-grad + dl = torch.where(mask_l, dl, 0.0) grad = torch.empty_like(pre_act) grad[..., ::2], grad[..., 1::2] = dg, dl return g1 * grad.to(g1.dtype) + + pass @@ -133,15 +140,25 @@ def is_kernels_available(): lambda *args, **kwargs: True ) except Exception as e: - return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) + return raise_error( + "transformers.quantizers.quantizer_mxfp4.is_kernels_available", e + ) - if hasattr(transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels"): - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = lambda *args, **kwargs: triton_kernels + if hasattr( + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels" + ): + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = ( + lambda *args, **kwargs: triton_kernels + ) try: - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( + lambda *args, **kwargs: True + ) except Exception as e: - return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) + return raise_error( + "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e + ) if HAS_TRITON_KERNELS: try: @@ -166,6 +183,7 @@ def is_kernels_available(): def swizzle_mxfp4(w, w_scale, *args, **kwargs): from triton_kernels import tensor, tensor_details + FP4, convert_layout, wrap_torch_tensor = ( tensor.FP4, tensor.convert_layout, @@ -174,8 +192,12 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): layout = tensor_details.layout StridedLayout = tensor_details.layout.StridedLayout - value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) - w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1 + ) + w = convert_layout( + wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts + ) # TODO : add that when we are actually sure that it works on B200 # if torch.cuda.get_device_capability()[0] == 10: # constraints = { @@ -190,7 +212,13 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level = "relaxed") + + patch_function( + transformers.integrations.mxfp4, + "swizzle_mxfp4", + swizzle_mxfp4, + match_level="relaxed", + ) class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -203,7 +231,9 @@ def forward( scatter_idx, ): pre_activation = matmul_ogs( - hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to( + torch.bfloat16 + ), # tl.dot_scaled upcasts to BF16 for old hardware self_class.gate_up_proj, self_class.gate_up_proj_bias, routing_data, @@ -242,6 +272,7 @@ def forward( ctx.scatter_idx = scatter_idx ctx.routing_data = routing_data return out + pass @staticmethod @@ -274,7 +305,9 @@ def backward(ctx, grad_token): dx_token = torch.zeros_like(grad_token) dx_token.index_add_(0, gather_dst, dx_exp) return (dx_token, None, None, None, None,) + pass + pass class Mxfp4GptOssExperts(nn.Module): @@ -287,7 +320,13 @@ def __init__(self, config): # MXFP4 quantized format (blocks + scales) self.gate_up_proj_blocks = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + torch.zeros( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size // 32, + 16, + dtype=torch.uint8, + ), requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( @@ -295,7 +334,8 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), + requires_grad=False, ) self.down_proj_blocks = nn.Parameter( @@ -307,15 +347,91 @@ def __init__(self, config): requires_grad=False, ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), + requires_grad=False, ) + self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.gate_up_proj_precision_config = None self.down_proj_precision_config = None @property - def gate_up_proj(selforward( + def gate_up_proj(self): + """Return gate_up_proj tensor (created from blocks/scales or stored directly).""" + # Check if already set as an attribute (from checkpoint loading or previous dequantization) + if "_gate_up_proj" in self.__dict__: + return self.__dict__["_gate_up_proj"] + + # Check if MXFP4 weights are present (blocks and scales are not all zeros) + blocks_valid = ( + self.gate_up_proj_blocks.device.type != "meta" + and self.gate_up_proj_blocks.numel() > 0 + and self.gate_up_proj_blocks.any() + ) + + if not blocks_valid: + # No MXFP4 weights and no regular weights loaded + raise AttributeError( + f"Mxfp4GptOssExperts.gate_up_proj: No weights loaded. " + f"Try 'openai/gpt-oss-20b' with load_in_4bit=True instead." + ) + + # MXFP4 weights present - dequantize them + try: + from transformers.integrations.mxfp4 import dequantize + # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) + dequantized = dequantize(self.gate_up_proj_blocks, self.gate_up_proj_scales) + # Cache for future accesses + self.__dict__["_gate_up_proj"] = dequantized + return dequantized + except Exception as e: + raise RuntimeError( + f"Failed to dequantize MXFP4 gate_up_proj: {e}. " + f"Ensure transformers.integrations.mxfp4.dequantize is available." + ) + + @gate_up_proj.setter + def gate_up_proj(self, value): + """Set gate_up_proj tensor (called during checkpoint loading).""" + self.__dict__["_gate_up_proj"] = value + + @property + def down_proj(self): + """Return down_proj tensor (created from blocks/scales or stored directly).""" + if "_down_proj" in self.__dict__: + return self.__dict__["_down_proj"] + + blocks_valid = ( + self.down_proj_blocks.device.type != "meta" + and self.down_proj_blocks.numel() > 0 + and self.down_proj_blocks.any() + ) + + if not blocks_valid: + raise AttributeError( + f"Mxfp4GptOssExperts.down_proj: No weights loaded." + ) + + # MXFP4 weights present - dequantize them + try: + from transformers.integrations.mxfp4 import dequantize + # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) + dequantized = dequantize(self.down_proj_blocks, self.down_proj_scales) + # Cache for future accesses + self.__dict__["_down_proj"] = dequantized + return dequantized + except Exception as e: + raise RuntimeError( + f"Failed to dequantize MXFP4 down_proj: {e}" + ) + + @down_proj.setter + def down_proj(self, value): + """Set down_proj tensor (called during checkpoint loading).""" + self.__dict__["_down_proj"] = value + + def forward( self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx ) -> torch.Tensor: with torch_cuda_device(hidden_states.device): @@ -323,7 +439,7 @@ def gate_up_proj(selforward( self.act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) if not hidden_states.requires_grad: intermediate_cache1 = matmul_ogs( - hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware self.gate_up_proj, self.gate_up_proj_bias, routing_data, @@ -350,43 +466,50 @@ def gate_up_proj(selforward( scatter_idx, ) return intermediate_cache3 + pass + patch_function(transformers.integrations.mxfp4, "Mxfp4GptOssExperts", Mxfp4GptOssExperts) - try: - routing = triton_kernels.routing.routing - routing = torch.compiler.disable(routing) - except Exception as e: - return raise_error("triton_kernels.routing.routing", e) + if HAS_TRITON_KERNELS: + try: + routing = triton_kernels.routing.routing + routing = torch.compiler.disable(routing) + except Exception as e: + return raise_error("triton_kernels.routing.routing", e) - def mlp_forward(self, hidden_states): - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) - router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) + def mlp_forward(self, hidden_states): + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - with torch_cuda_device(router_logits.device): - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + with torch_cuda_device(router_logits.device): + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) - routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) - routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) - return routed_out, router_logits - patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) + routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) + return routed_out, router_logits - try: - PrecisionConfig, FlexCtx, InFlexData = ( - triton_kernels.matmul_ogs.PrecisionConfig, - triton_kernels.matmul_ogs.FlexCtx, - triton_kernels.matmul_ogs.InFlexData, - ) - except Exception as e: - return raise_error("triton_kernels.matmul_ogs", e) + patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + + if HAS_TRITON_KERNELS: + try: + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels.matmul_ogs.PrecisionConfig, + triton_kernels.matmul_ogs.FlexCtx, + triton_kernels.matmul_ogs.InFlexData, + ) + except Exception as e: + return raise_error("triton_kernels.matmul_ogs", e) try: from transformers.integrations.tensor_parallel import shard_and_distribute_module except Exception as e: return raise_error("transformers.integrations.tensor_parallel.shard_and_distribute_module", e) - def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args, **kwargs): + def load_and_swizzle_mxfp4( + module, param_name, param_value, target_device, *args, **kwargs + ): model = kwargs.get("model", None) empty_param = kwargs.get("empty_param", None) casting_dtype = kwargs.get("casting_dtype", None) @@ -397,17 +520,22 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args for proj in ["gate_up_proj", "down_proj"]: if proj in param_name: if device_mesh is not None: - shard_and_distribute_module( - model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh - ) + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) else: setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" blocks = getattr(module, blocks_attr) scales = getattr(module, scales_attr) - # Check if both blocks and scales both not on on meta device - if blocks.device.type != "meta" and scales.device.type != "meta": + # Check if both blocks and scales both not on meta device AND not all zeros + # (if blocks are all zeros, they're from initialization, not from checkpoint) + blocks_valid = ( + blocks.device.type != "meta" + and scales.device.type != "meta" + and blocks.numel() > 0 + and blocks.any() # At least some non-zero values + ) + if blocks_valid: # need it for ep local_experts = blocks.size(0) if proj == "gate_up_proj": @@ -448,13 +576,16 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args delattr(module, blocks_attr) # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) del blocks + pass - patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level = "relaxed") + patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level="relaxed") try: from transformers.integrations.mxfp4 import _replace_with_mxfp4_linear except Exception as e: - return raise_error("transformers.integrations.mxfp4._replace_with_mxfp4_linear", e) + return raise_error( + "transformers.integrations.mxfp4._replace_with_mxfp4_linear", e + ) def replace_with_mxfp4_linear( model, @@ -464,17 +595,11 @@ def replace_with_mxfp4_linear( config=None, ): if quantization_config.dequantize: return model - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + modules_to_not_convert = (["lm_head"] if modules_to_not_convert is None else modules_to_not_convert) if quantization_config.modules_to_not_convert is not None: modules_to_not_convert.extend(quantization_config.modules_to_not_convert) modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_mxfp4_linear( - model, - modules_to_not_convert, - current_key_name, - quantization_config, - config=config, - ) + model, has_been_replaced = _replace_with_mxfp4_linear(model, modules_to_not_convert, current_key_name, quantization_config, config=config) if not has_been_replaced: logger.warning_once( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." @@ -483,206 +608,765 @@ def replace_with_mxfp4_linear( ) return model - patch_function(transformers.integrations.mxfp4, "replace_with_mxfp4_linear", replace_with_mxfp4_linear) + + patch_function( + transformers.integrations.mxfp4, + "replace_with_mxfp4_linear", + replace_with_mxfp4_linear, + ) + + pass TEMPORARY_PATCHES.append(patch_gpt_oss) +class ParameterModule(nn.Linear): + """ + A module that wraps a parameter to look like a Linear layer for PEFT. + It inherits from nn.Linear but manages 3D <-> 2D weight conversion. + Unsloth grouped_mm requires 3D weights: + - gate_up: (E, H, 2I) + - down: (E, I, H) + PEFT Linear requires 2D weights: (Out, In). + We store the weight as 2D for PEFT, and reshape for Unsloth via get_param(). + """ + + def __init__( + self, in_features, out_features, shape_3d, permute_to_2d, permute_to_3d + ): + # Initialize nn.Linear (creates weight of size (out, in)) + super().__init__(in_features, out_features, bias=False) + self.shape_3d = shape_3d + self.permute_to_2d = permute_to_2d + self.permute_to_3d = permute_to_3d + + # We expect the caller to set the weight content correctly. + # nn.Linear initialized it randomly. We will overwrite it. + + def extra_repr(self): + return f"in_features={self.in_features}, out_features={self.out_features}, shape_3d={self.shape_3d}" + + def get_param(self): + """Restores the 3D weight for Unsloth computation.""" + # 2D weight (Out, In) -> View (Unflattened 2D) -> Permute -> 3D(E, ...) + # We need to know the unflattened shape. + # gate_up: 2D (E*2I, H). View (E, 2I, H). Permute(0,2,1) -> (E, H, 2I). + # We store (E*2I, H). + # To restore: View (E, 2I, H)? + # (E, 2I, H) is shape_3d permuted by permute_to_2d. + + unflattened_shape = [self.shape_3d[i] for i in self.permute_to_2d] + return ( + self.weight.view(*unflattened_shape) + .permute(*self.permute_to_3d) + .contiguous() + ) + + def set_weight_from_3d(self, weight_3d): + """Sets the 2D weight from a 3D tensor.""" + # 3D -> Permute -> Flatten + weight_2d = weight_3d.permute(*self.permute_to_2d).reshape( + self.out_features, self.in_features + ) + self.weight.data.copy_(weight_2d) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # 'gate_up_proj' in checkpoint is 3D. + # We need to load it into 'gate_up_proj.weight' (2D). + # key is '...gate_up_proj' + key = prefix[:-1] + + if key in state_dict: + # Found the parameter (likely from original model structure where it was a Param) + val = state_dict[key] + # Convert 3D val to 2D + val_2d = val.permute(*self.permute_to_2d).reshape( + self.out_features, self.in_features + ) + + # Put into 'weight' key + state_dict[prefix + "weight"] = val_2d + del state_dict[key] + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + class GptOssExperts(nn.Module): + """ + GPT OSS MoE Experts layer with 3D stacked parameters. + Compatible with transformers' _init_weights and supports grouped_mm with split LoRA. + + Uses the same structure as the original transformers GptOssExperts: + - gate_up_proj: (num_experts, hidden_size, 2 * expert_dim) + - gate_up_proj_bias: (num_experts, 2 * expert_dim) + - down_proj: (num_experts, expert_dim, hidden_size) + - down_proj_bias: (num_experts, hidden_size) + """ + + def __init__(self, config): + super().__init__() + + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.expert_dim = config.intermediate_size + self.intermediate_size = config.intermediate_size # Alias for compatibility + self.alpha = 1.702 + self.limit = getattr(config, "swiglu_limit", 7.0) + self.dtype = dtype_from_config(config) + + # gate_up_proj: 3D (E, H, 2I). Target 2D (E*2I, H). + # Permute (0, 2, 1) -> (E, 2I, H). Reverse (0, 2, 1). + self.gate_up_proj = ParameterModule( + in_features=self.hidden_size, + out_features=self.num_experts * 2 * self.expert_dim, + shape_3d=(self.num_experts, self.hidden_size, 2 * self.expert_dim), + permute_to_2d=(0, 2, 1), + permute_to_3d=(0, 2, 1), + ) + # Initialize 3D zero tensor and set 2D weight + self.gate_up_proj.set_weight_from_3d( + torch.zeros( + self.num_experts, + self.hidden_size, + 2 * self.expert_dim, + dtype=self.dtype, + ) + ) + + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=self.dtype) + ) + + # down_proj: 3D (E, I, H). Target 2D (H, E*I). + # Permute (2, 0, 1) -> (H, E, I). Reverse (1, 2, 0) + self.down_proj = ParameterModule( + in_features=self.num_experts * self.expert_dim, + out_features=self.hidden_size, + shape_3d=(self.num_experts, self.expert_dim, self.hidden_size), + permute_to_2d=(2, 0, 1), + permute_to_3d=(1, 2, 0), + ) + self.down_proj.set_weight_from_3d( + torch.empty( + self.num_experts, self.expert_dim, self.hidden_size, dtype=self.dtype + ) + ) + + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, dtype=self.dtype) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Override to handle loading 3D tensors (gate_up_proj, down_proj) from original checkpoints. + The original checkpoint has these as nn.Parameter (3D tensors), but we now have + them as ParameterModule (nn.Linear subclass) which expects .weight as 2D tensors. + + This method intercepts the 3D tensors and converts them to the 2D format + that ParameterModule expects. + """ + # Handle gate_up_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight + gate_up_key = prefix + "gate_up_proj" + gate_up_weight_key = prefix + "gate_up_proj.weight" + if gate_up_key in state_dict and gate_up_weight_key not in state_dict: + val_3d = state_dict.pop(gate_up_key) + # gate_up_proj: 3D (E, H, 2I) -> permute (0, 2, 1) -> (E, 2I, H) -> reshape (E*2I, H) + val_2d = val_3d.permute(0, 2, 1).reshape( + self.num_experts * 2 * self.expert_dim, # out_features + self.hidden_size, # in_features + ) + state_dict[gate_up_weight_key] = val_2d + + # Handle down_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight + down_key = prefix + "down_proj" + down_weight_key = prefix + "down_proj.weight" + if down_key in state_dict and down_weight_key not in state_dict: + val_3d = state_dict.pop(down_key) + # down_proj: 3D (E, I, H) -> permute (2, 0, 1) -> (H, E, I) -> reshape (H, E*I) + val_2d = val_3d.permute(2, 0, 1).reshape( + self.hidden_size, # out_features + self.num_experts * self.expert_dim, # in_features + ) + state_dict[down_weight_key] = val_2d + + # Call parent implementation + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """Forward using grouped_mm or loop fallback with LoRA support.""" + # Use optimized grouped_mm if available + if _check_torch_grouped_mm_supported(): + return forward_native_grouped_mm( + self, hidden_states, router_indices, routing_weights + ) + + # Fallback to loop-based implementation + return forward_native_moe_loop( + self, hidden_states, router_indices, routing_weights + ) + + +pass + + +class _RouterLinearParams(nn.Module): + """ + Simple parameter container that stores weight/bias like nn.Linear + but is NOT nn.Linear itself, so BitsAndBytes will NOT quantize it. + State dict keys: linear.weight, linear.bias (matching BnB 4-bit checkpoints + where the router was saved via an nn.Linear submodule). + """ + def __init__(self, in_features, out_features, dtype): + super().__init__() + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype)) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + + +class GptOssTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + # Use _RouterLinearParams (not nn.Linear) to avoid BnB 4-bit quantization. + # State dict keys are router.linear.weight / router.linear.bias, matching + # the BnB 4-bit checkpoint format where router was stored via nn.Linear. + self.linear = _RouterLinearParams( + self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) + ) + + # Properties for compatibility with transformers' _init_weights which expects .weight and .bias + @property + def weight(self): + return self.linear.weight + + @weight.setter + def weight(self, value): + self.linear.weight = value + + @property + def bias(self): + return self.linear.bias + + @bias.setter + def bias(self, value): + self.linear.bias = value + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = self.linear( + hidden_states.to(self.linear.weight.dtype) + ) # (batch_size * seq_len, num_experts) + router_top_value, router_indices = torch.topk( + router_logits, self.top_k, dim=-1 + ) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax( + router_top_value, dim=1, dtype=router_top_value.dtype + ) + router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( + 1, router_indices, router_top_value + ) + return router_logits, router_scores, router_indices + + +pass + + +# BitsAndBytes 4bit compatible classes for loading pre-quantized models +class GptOssExpertsBnb4bit(nn.Module): + """ + GPT OSS MoE Experts using nn.Linear layers for BitsAndBytes 4bit compatibility. + This version uses gate_up_projs and down_projs as ModuleLists of Linear layers, + which can be properly quantized with BitsAndBytes. + """ + def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = config.intermediate_size + self.intermediate_size = config.intermediate_size self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.dtype = dtype_from_config(config) self.gate_up_projs = nn.ModuleList([ - nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=self.dtype) + nn.Linear(self.hidden_size, 2 * self.expert_dim, bias=True, dtype=self.dtype) for _ in range(self.num_experts) ]) self.down_projs = nn.ModuleList([ - nn.Linear(self.expert_dim, self.hidden_size, dtype=self.dtype) + nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=self.dtype) for _ in range(self.num_experts) ]) - def forward( - self, - hidden_states: torch.Tensor, - router_indices = None, - routing_weights = None - ) -> torch.Tensor: + # Provide minimal tensors for transformers _init_weights compatibility. + # Keep them empty to avoid allocating large bf16 weights in 4-bit mode. + self.register_buffer( + "gate_up_proj", torch.empty(0, dtype=self.dtype), persistent=False + ) + self.register_buffer( + "gate_up_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False + ) + self.register_buffer( + "down_proj", torch.empty(0, dtype=self.dtype), persistent=False + ) + self.register_buffer( + "down_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False + ) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) num_experts = routing_weights.shape[1] + if self.training: - next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) - # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - # for expert_idx in expert_hitted[:]: - for expert_idx in range(num_experts): + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hitted[:]: with torch.no_grad(): - # _, token_idx = torch.where(expert_mask[expert_idx[0]]) - token_idx, _ = torch.where(router_indices == expert_idx) + _, token_idx = torch.where(expert_mask[expert_idx[0]]) current_state = hidden_states[token_idx] gate_up = self.gate_up_projs[expert_idx](current_state) - gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit) - # gate, up = gate_up[..., ::2], gate_up[..., 1::2] - # gate = gate.clamp(min=None, max=self.limit) - # up = up.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # gated_output = (up + 1) * glu + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu out = self.down_projs[expert_idx](gated_output) - weighted_output = out * routing_weights[token_idx, expert_idx, None].to(torch.float32) - next_states.index_add_(0, token_idx, weighted_output) + weighted_output = out * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states.to(hidden_states.dtype) + return next_states else: X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] gate_up = torch.stack(gate_up_list, dim=0) - fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = X_rep.dtype) - # gate = gate_up[..., ::2] - # up_h = gate_up[..., 1::2] - # gate = gate.clamp(max=self.limit) - # up_h = up_h.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # fused = (up_h + 1) * glu + gate = gate_up[..., ::2] + up_h = gate_up[..., 1::2] + gate = gate.clamp(max=self.limit) + up_h = up_h.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + fused = (up_h + 1) * glu out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)] outs = torch.stack(out_list, dim=0) rw = routing_weights.transpose(0, 1).unsqueeze(-1) - mixed = (outs.to(torch.float32) * rw.to(torch.float32)).sum(dim=0) - return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) + mixed = (outs * rw).sum(dim=0) + return mixed.view(batch_size, -1, self.hidden_size) + + pass -class GptOssTopKRouter(nn.Module): + +class GptOssTopKRouterBnb4bit(nn.Module): + """ + GPT OSS Router using direct parameters for BitsAndBytes 4bit compatibility. + This version uses weight/bias as direct nn.Parameter instead of nested Linear. + """ + def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.linear = nn.Linear( - self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) + self.dtype = dtype_from_config(config) + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, dtype=self.dtype)) + self.bias = nn.Parameter(torch.zeros(self.num_experts, dtype=self.dtype)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Accept checkpoints that used a Linear router (linear.weight/bias) + linear_weight_key = prefix + "linear.weight" + linear_bias_key = prefix + "linear.bias" + weight_key = prefix + "weight" + bias_key = prefix + "bias" + moved_weight = False + moved_bias = False + if linear_weight_key in state_dict and weight_key not in state_dict: + state_dict[weight_key] = state_dict.pop(linear_weight_key) + moved_weight = True + if linear_bias_key in state_dict and bias_key not in state_dict: + state_dict[bias_key] = state_dict.pop(linear_bias_key) + moved_bias = True + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, ) - # Expose weight/bias for HF init compatibility - self.weight = self.linear.weight - self.bias = self.linear.bias + if moved_weight: + if linear_weight_key in unexpected_keys: + unexpected_keys.remove(linear_weight_key) + if weight_key in missing_keys: + missing_keys.remove(weight_key) + if moved_bias: + if linear_bias_key in unexpected_keys: + unexpected_keys.remove(linear_bias_key) + if bias_key in missing_keys: + missing_keys.remove(bias_key) + if self.weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + self.weight.data = self.weight.data.to(self.dtype) + if self.bias is not None and self.bias.dtype not in ( + torch.float16, torch.bfloat16, torch.float32 + ): + self.bias.data = self.bias.data.to(self.dtype) - @torch_compile(dynamic = True, fullgraph = True) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_logits = torch.nn.functional.linear( + hidden_states.to(self.weight.dtype), self.weight, self.bias + ) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) - router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) + return router_logits, router_scores, router_indices + + pass +def patch_gpt_oss_bnb4bit(): + """ + Patch transformers to use BnB 4bit compatible classes when loading pre-quantized models. + This should be called before loading models that were saved with BitsAndBytes quantization + using the linear-based expert structure. + + Usage: + from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss_bnb4bit + patch_gpt_oss_bnb4bit() # Call before loading the model + model = FastLanguageModel.from_pretrained(...) + """ + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + # Store original classes for potential restoration + if not hasattr(transformers.models.gpt_oss.modeling_gpt_oss, '_original_GptOssExperts'): + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssExperts = \ + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssTopKRouter = \ + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter + + # Replace with BnB 4bit compatible versions + # Preserve original symbol names for compiler-generated modules. + GptOssExpertsBnb4bit.__name__ = "GptOssExperts" + GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit + # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. + # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. + # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict + # override to remap keys, but transformers v5 bypasses _load_from_state_dict + # (uses accelerate's set_module_tensor_to_device), so the remapping never ran + # and router weights were randomly initialized - causing high loss (~4-5). + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + + logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") + os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" + return True + + +pass + + +def restore_gpt_oss_original(): + """ + Restore original GPT-OSS classes (undo BnB 4bit patch). + """ + try: + import transformers.models.gpt_oss.modeling_gpt_oss + if hasattr(transformers.models.gpt_oss.modeling_gpt_oss, '_original_GptOssExperts'): + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = \ + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssExperts + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = \ + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssTopKRouter + logger.info("Unsloth: Restored original GPT OSS classes") + return True + except Exception: + pass + return False + +def patch_gpt_oss_bnb4bit_auto(): + """ + Auto-patch GPT-OSS for BnB 4-bit when load_in_4bit is active. + Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to opt out. + """ + if not _should_use_gpt_oss_bnb4bit(): + return + # Avoid compiler-generated modules missing BnB helpers + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + patch_gpt_oss_bnb4bit() + # Ensure inference path avoids torch.compile for 4-bit + try: + global moe_forward_inference + moe_forward_inference = torch.compiler.disable(moe_forward_inference) + except Exception: + pass + + +TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) + + # Combo kernels uses too much VRAM for low memory GPUs from ..device_type import DEVICE_TYPE + if DEVICE_TYPE == "xpu": device_memory = torch.xpu.memory.mem_get_info(0)[-1] else: device_memory = torch.cuda.memory.mem_get_info(0)[-1] -use_combo_kernels = False if device_memory/1024/1024/1024 <= 40 else True +use_combo_kernels = False if device_memory / 1024 / 1024 / 1024 <= 40 else True fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion = True, - max_autotune = False, # Too slow - shape_padding = True, - cudagraphs = True, - coordinate_descent_tuning = use_combo_kernels, # Very slow! - combo_kernels = use_combo_kernels, - memory_planning = True, - multi_kernel = False, # Fails on torch 2.10 nightly - use_block_ptr = True, - logging = UNSLOTH_ENABLE_LOGGING, + epilogue_fusion=True, + max_autotune=False, # Too slow + shape_padding=True, + cudagraphs=True, + coordinate_descent_tuning=use_combo_kernels, # Very slow! + combo_kernels=use_combo_kernels, + memory_planning=True, + multi_kernel=False, # Fails on torch 2.10 nightly + use_block_ptr=True, + logging=UNSLOTH_ENABLE_LOGGING, ) no_combo_fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion = True, - max_autotune = False, # Too slow - shape_padding = True, - cudagraphs = True, - coordinate_descent_tuning = use_combo_kernels, # Very slow! - combo_kernels = False, # Breaks on attention - memory_planning = True, - multi_kernel = False, # Fails on torch 2.10 nightly - use_block_ptr = True, - logging = UNSLOTH_ENABLE_LOGGING, + epilogue_fusion=True, + max_autotune=False, # Too slow + shape_padding=True, + cudagraphs=True, + coordinate_descent_tuning=use_combo_kernels, # Very slow! + combo_kernels=False, # Breaks on attention + memory_planning=True, + multi_kernel=False, # Fails on torch 2.10 nightly + use_block_ptr=True, + logging=UNSLOTH_ENABLE_LOGGING, ) -@_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) + +@_torch_compile(dynamic=None, fullgraph=True, options=fused_torch_compile_options) def moe_forward_inference(self, hidden_states): """Torch compile for forward inference path only with CUDAGraphs""" # Router - router_scores, router_indices = self.router(hidden_states) + router_out = self.router(hidden_states) + if isinstance(router_out, tuple) and len(router_out) == 3: + _, router_scores, router_indices = router_out + else: + router_scores, router_indices = router_out routing_weights = router_scores moe = self.experts batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, moe.hidden_size) num_experts = routing_weights.shape[1] - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - # Gate up projection - gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] - gate_up = torch.stack(gate_up_list, dim = 0) - dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype - fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype = dtype) + # Check if using ModuleList (old style) or 3D parameters (new style) + if hasattr(moe, "gate_up_projs"): + # ModuleList style + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) + gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] + gate_up = torch.stack(gate_up_list, dim=0) + dtype = ( + torch.float32 + if hidden_states.dtype != torch.bfloat16 + else hidden_states.dtype + ) + fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) - # Down projection must be done in float32 if not bfloat16 otherwise infinites - fused = fused.to(dtype) - device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out_list = [down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs)] - outs = torch.stack(out_list, dim=0) + fused = fused.to(dtype) + device_type = ( + fused.device.type + if isinstance(fused.device.type, str) and fused.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + out_list = [ + down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs) + ] + outs = torch.stack(out_list, dim=0) + else: + # 3D parameter style (compatible with transformers) + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) + # gate_up_proj: (E, hidden_size, 2*expert_dim) - bmm: (E, N, H) @ (E, H, 2I) -> (E, N, 2I) + gate_up = ( + torch.bmm(X_rep, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] + ) + dtype = ( + torch.float32 + if hidden_states.dtype != torch.bfloat16 + else hidden_states.dtype + ) + fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) + + fused = fused.to(dtype) + device_type = ( + fused.device.type + if isinstance(fused.device.type, str) and fused.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): + # down_proj: (E, expert_dim, hidden_size) - bmm: (E, N, I) @ (E, I, H) -> (E, N, H) + outs = ( + torch.bmm(fused.to(dtype), moe.down_proj) + + moe.down_proj_bias[..., None, :] + ) rw = routing_weights.to(dtype).transpose(0, 1).unsqueeze(-1) mixed = (outs * rw).sum(dim=0) return mixed.view(batch_size, -1, moe.hidden_size).to(hidden_states.dtype) + + pass -@torch_compile(dynamic = True, fullgraph = True) + +@torch_compile(dynamic=True, fullgraph=True) def moe_router_forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) - router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) + router_logits = F.linear( + hidden_states.to(self.weight.dtype), self.weight, self.bias + ) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk( + router_logits, self.top_k, dim=-1 + ) # (seq_len, top_k) + dtype = ( + torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype + ) + router_top_value = torch.nn.functional.softmax( + router_top_value, dim=1, dtype=torch.float32 + ).to(dtype) + router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_( + 1, router_indices, router_top_value + ) return router_scores, router_indices + + pass -# Combo Kernels errors with InductorError: AttributeError: 'NullKernelHandler' object has no attribute 'index_to_str' -@_torch_compile(dynamic = None, fullgraph = True, options = no_combo_fused_torch_compile_options) -def moe_forward_inference_bf16(self, hidden_states): - router_scores, router_indices = moe_router_forward(self.router, hidden_states) - routing_weights = router_scores - moe = self.experts +# Combo Kernels errors with InductorError: AttributeError: 'NullKernelHandler' object has no attribute 'index_to_str' +@_torch_compile( + dynamic=None, fullgraph=True, options=no_combo_fused_torch_compile_options +) +def _moe_forward_inference_bf16_kernel( + hidden_states, routing_weights, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias, limit, alpha, hidden_size +): + """Inner compiled kernel for BF16 MoE inference - works with raw tensors only.""" batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, moe.hidden_size) + hidden_states = hidden_states.reshape(-1, hidden_size) num_experts = routing_weights.shape[1] hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, moe.hidden_size) - gate_up = torch.bmm(hidden_states, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] + hidden_states = hidden_states.view(num_experts, -1, hidden_size) + + gate_up = ( + torch.bmm(hidden_states, gate_up_proj) + gate_up_proj_bias[..., None, :] + ) gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=moe.limit) - up = up.clamp(min=-moe.limit, max=moe.limit) - glu = gate * torch.sigmoid(gate.to(torch.float32) * moe.alpha).to(gate.dtype) - next_states = torch.bmm(((up + 1) * glu), moe.down_proj) - next_states = next_states + moe.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, moe.hidden_size) - next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate.to(torch.float32) * alpha).to(gate.dtype) + next_states = torch.bmm(((up + 1) * glu), down_proj) + next_states = next_states + down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, hidden_size) + next_states = ( + next_states + * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + ) next_states = next_states.sum(dim=0) return next_states -pass + + +def _unwrap_peft_experts(module): + """Unwrap PEFT ParamWrapper chain to get the actual experts module.""" + while hasattr(module, 'base_layer'): + module = module.base_layer + return module + + +def moe_forward_inference_bf16(self, hidden_states): + """Wrapper that extracts weights from ParameterModule before calling the compiled kernel.""" + router_scores, router_indices = moe_router_forward(self.router, hidden_states) + routing_weights = router_scores + + moe = _unwrap_peft_experts(self.experts) + + # Handle ParameterModule (which wraps nn.Linear) vs direct 3D tensor + # Extract weights BEFORE the compiled region + gate_up_proj = moe.gate_up_proj + if hasattr(gate_up_proj, "get_param"): + gate_up_proj = gate_up_proj.get_param() + elif hasattr(gate_up_proj, "weight"): + gate_up_proj = gate_up_proj.weight + + down_proj = moe.down_proj + if hasattr(down_proj, "get_param"): + down_proj = down_proj.get_param() + elif hasattr(down_proj, "weight"): + down_proj = down_proj.weight + + return _moe_forward_inference_bf16_kernel( + hidden_states, + routing_weights, + gate_up_proj, + moe.gate_up_proj_bias, + down_proj, + moe.down_proj_bias, + moe.limit, + moe.alpha, + moe.hidden_size, + ) + + class GptOssMLP(nn.Module): @@ -695,432 +1379,472 @@ def forward(self, hidden_states): bsz, qlen, hd = hidden_states.shape if qlen == 1 and not self.training: return moe_forward_inference(self, hidden_states), None - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores -pass - -def patch_gpt_oss_linearized(): - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - try: - import transformers.models.gpt_oss.modeling_gpt_oss - except Exception as e: - return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + router_out = self.router(hidden_states) + if isinstance(router_out, tuple) and len(router_out) == 3: + _, router_scores, router_indices = router_out + else: + router_scores, router_indices = router_out + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out, router_scores - _patch_gpt_oss_init_weights_for_modulelist( - transformers.models.gpt_oss.modeling_gpt_oss - ) - # We find down_proj overflows in GPT OSS - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - def forward( - self, - hidden_states: torch.Tensor, - router_indices = None, - routing_weights = None - ) -> torch.Tensor: +pass - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - if self.training: - next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) - # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - # for expert_idx in expert_hitted[:]: - for expert_idx in range(num_experts): - with torch.no_grad(): - # _, token_idx = torch.where(expert_mask[expert_idx[0]]) - token_idx, _ = torch.where(router_indices == expert_idx) - current_state = hidden_states[token_idx] - gate_up = self.gate_up_projs[expert_idx](current_state) - down_proj = self.down_projs[expert_idx] - gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = torch.float32) - # gate, up = gate_up[..., ::2], gate_up[..., 1::2] - # gate = gate.clamp(min=None, max=self.limit) - # up = up.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # gated_output = (up + 1) * glu - - # Force float32 matrix multiply on some down projection modules - gated_output = gated_output.to(torch.float32) - device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out = down_proj(gated_output) - weighted_output = out.to(torch.float32) * routing_weights[token_idx, expert_idx, None].to(torch.float32) - next_states.index_add_(0, token_idx, weighted_output) - next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states.to(torch.float32) - else: - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] - gate_up = torch.stack(gate_up_list, dim=0) - dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype - fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = dtype) - # gate = gate_up[..., ::2] - # up_h = gate_up[..., 1::2] - # gate = gate.clamp(max=self.limit) - # up_h = up_h.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # fused = (up_h + 1) * glu - - # Force float32 matrix multiply on down projection only - device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out_list = [ - down_l(fused[e].to(dtype)) - for e, down_l in enumerate(self.down_projs) - ] - outs = torch.stack(out_list, dim=0) - rw = routing_weights.transpose(0, 1).unsqueeze(-1) - mixed = (outs.to(dtype) * rw.to(dtype)).sum(dim=0) - return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) - pass - pass - GptOssExperts.forward = forward - pass - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExperts - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - transformers.models.gpt_oss.modeling_gpt_oss.GptOssMLP = GptOssMLP - return -pass -TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) +# ============================================================================ +# GPT OSS MoE LoRA Support using grouped GEMM kernels +# ============================================================================ + +# IMPORTS FROM MOE UTILS +from .moe_utils import ( + _check_grouped_gemm_available, + _TORCH_GROUPED_MM_AVAILABLE, + _check_torch_grouped_mm_supported, + native_moe_grouped_mm, + _get_moe_lora_weights, + _apply_lora_grouped_mm, + _get_lora_wrapper_for_param, + select_moe_backend, + patch_param_wrapper_for_moe, + forward_native_grouped_mm, + forward_native_moe_loop, +) -def _patch_gpt_oss_init_weights_for_modulelist(transformers_module): - GptOssPreTrainedModel = transformers_module.GptOssPreTrainedModel - GptOssExperts = transformers_module.GptOssExperts - if getattr(GptOssPreTrainedModel, "_unsloth_init_weights_patched", False): - return - _original_init_weights = GptOssPreTrainedModel._init_weights +def _should_use_gpt_oss_bnb4bit() -> bool: + """ + Decide if GPT-OSS should use BnB-compatible 4-bit experts. + Default: True when load_in_4bit is active. + Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return False + if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return False + return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" - def _patched_init_weights(self, module): - _original_init_weights(self, module) - if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): - std = self.config.initializer_range - for up in getattr(module, "gate_up_projs", []): - init.normal_(up.weight, mean=0.0, std=std) - if up.bias is not None: - init.zeros_(up.bias) - for down in getattr(module, "down_projs", []): - init.normal_(down.weight, mean=0.0, std=std) - if down.bias is not None: - init.zeros_(down.bias) - patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) - GptOssPreTrainedModel._unsloth_init_weights_patched = True +def _is_gpt_oss_4bit_load() -> bool: + return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") -def patch_gpt_oss_bf16_split_lora(): - """ - Patch GPT-OSS BF16 model to use split LoRA with grouped GEMM. +def _is_transformers_v5() -> bool: + return transformers_version >= Version("5.0.0.dev0") - This patch applies to BF16 models loaded via unsloth/gpt-oss-20b-BF16 or similar. - It creates stacked expert weights and uses moe_utils.py's forward_native_grouped_mm - for efficient training with split LoRA. - Key differences from Qwen3 MoE: - 1. GPT-OSS uses TRANSPOSED weight layout: - - gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size) - - down_proj: (num_experts, intermediate_size, hidden_size) - 2. GPT-OSS gate_up has interleaved layout (::2 for gate, 1::2 for up) - 3. GPT-OSS has biases: gate_up_proj_bias, down_proj_bias +def patch_gpt_oss_moe_for_lora(): + """ + Patch GPT OSS MoE experts for LoRA training with grouped GEMM support. + This patches the original GptOssExperts class (with 3D parameter tensors) + to use optimized grouped GEMM kernels with LoRA support. - 4-bit BNB models use gate_up_projs/down_projs ModuleList and are NOT affected. + IMPORTANT: We only patch the forward method, NOT replace the entire class. + This preserves the original class structure so weights load correctly. """ if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - # Skip 4-bit models - they use ModuleList and are handled by patch_gpt_oss_linearized - if "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): + # 4-bit loads should keep quantized weights and use default PEFT LoRA. + return + if not _is_transformers_v5(): + # Split-LoRA grouped_mm path is only needed for transformers v5+ return + # First patch ParamWrapper for MoE separated LoRA + patch_param_wrapper_for_moe() try: import transformers.models.gpt_oss.modeling_gpt_oss + + # Get the ORIGINAL class - don't replace it! + GptOssExpertsClass = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts except Exception as e: - return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch GPT OSS MoE for LoRA: {e}") + return - from .moe_utils import ( - forward_native_grouped_mm, - select_moe_backend, - patch_param_wrapper_for_moe, - _has_lora_adapters, - _extract_lora_from_wrapper, - _should_use_separated_lora, - _is_moe_experts_module, - ) + # Check if already patched + if hasattr(GptOssExpertsClass, "_unsloth_lora_patched"): + return - # Patch ParamWrapper.forward for MoE separated LoRA - patch_param_wrapper_for_moe() + # Select backend + backend = select_moe_backend() - # LoRA Extractor function for GPT-OSS - # GPT-OSS weight layout is TRANSPOSED: (E, hidden, output) instead of (E, output, hidden) - def _gpt_oss_lora_extractor( - self, wrapper, weight_A, weight_B, scaling, num_experts - ): - """ - GPT-OSS LoRA extractor for transposed weight layout. + if backend == "grouped_mm": + forward = forward_native_grouped_mm + else: + forward = forward_native_moe_loop - GPT-OSS weights: - gate_up_proj: (E, H, 2*I) - transposed layout (in_dim, out_dim) - down_proj: (E, I, H) - transposed layout (in_dim, out_dim) + # Store original forward and patch - but DON'T replace the class! + GptOssExpertsClass._original_forward = GptOssExpertsClass.forward + GptOssExpertsClass.forward = forward + GptOssExpertsClass._unsloth_lora_patched = True - For grouped_mm: X @ W where W is (E, in_dim, out_dim) + if UNSLOTH_ENABLE_LOGGING: + backend_desc = { + "grouped_mm": "torch._grouped_mm (batched, fastest)", + "unsloth_triton": "Triton kernels", + "native_torch": "loop fallback (slower)", + }.get(backend, backend) + logger.info( + f"Unsloth: Patched GPT OSS MoE for LoRA training using {backend_desc}" + ) - PEFT creates: - lora_A: (E*R, in_dim) - projects input to rank space - lora_B: (out_dim, E*R) - projects rank to output - For transposed format, the LoRA dimensions are already correct: - - We want X @ (E, in, R) @ (E, R, out) - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - total_rank = weight_A.shape[0] - rank_per_expert = total_rank // num_experts - dim_A = weight_A.shape[ - 1 - ] # in_dim (hidden_dim for gate_up, intermediate for down) - dim_B = weight_B.shape[ - 0 - ] # out_dim (2*intermediate for gate_up, hidden_dim for down) - - # Get model dimensions from the experts module - hidden_dim = None - intermediate_dim = None - current = wrapper - while hasattr(current, "base_layer"): - current = current.base_layer - if hasattr(current, "hidden_size"): - hidden_dim = current.hidden_size - if hasattr(current, "intermediate_size"): - intermediate_dim = current.intermediate_size - - # Get parameter name - param_name = getattr(wrapper, "parameter_name", None) - - # GPT-OSS uses TRANSPOSED layout, so LoRA dimensions map directly: - # Input projection: X @ (E, in_dim, R) - # Output projection: result @ (E, R, out_dim) - - if ( - param_name == "down_proj" - and intermediate_dim is not None - and hidden_dim is not None - ): - # down_proj: input=intermediate_dim, output=hidden_dim - # Weight shape: (E, I, H) - transposed - # lora_A: (E*R, H) from PEFT (swapped due to 3D param handling) - # lora_B: (I, E*R) from PEFT (swapped) - # For X @ first @ second: first is (E, I, R), second is (E, R, H) - - # first_weight from B (has intermediate_dim) - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, I, R) - - # second_weight from A (has hidden_dim) - second_weight = weight_A.view( - num_experts, rank_per_expert, dim_A - ) # (E, R, H) - - return first_weight, second_weight, scaling, num_experts - - elif param_name == "gate_up_proj" and hidden_dim is not None: - # gate_up_proj: input=hidden_dim, output=2*intermediate_dim - # Weight shape: (E, H, 2*I) - transposed - # lora_A: (E*R, 2*I) from PEFT - # lora_B: (H, E*R) from PEFT - # For X @ first @ second: first is (E, H, R), second is (E, R, 2*I) - - # first_weight from B (has hidden_dim) - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, H, R) - - # second_weight from A (has 2*intermediate_dim) - second_weight = weight_A.view( - num_experts, rank_per_expert, dim_A - ) # (E, R, 2*I) - - return first_weight, second_weight, scaling, num_experts - - # Fallback: dimension-based detection - if hidden_dim is not None: - if dim_B == hidden_dim: - # B connects to hidden_dim (transposed case) - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() - second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - return first_weight, second_weight, scaling, num_experts - elif dim_A == hidden_dim: - # A connects to hidden_dim - first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - first_weight = first_weight.permute(0, 2, 1).contiguous() - second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - second_weight = second_weight.permute(1, 2, 0).contiguous() - return first_weight, second_weight, scaling, num_experts - - # Final fallback - first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - first_weight = first_weight.permute(0, 2, 1).contiguous() - second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - second_weight = second_weight.permute(1, 2, 0).contiguous() - return first_weight, second_weight, scaling, num_experts - - # Patch GptOssExperts.forward to use grouped GEMM + split LoRA (BF16 only) - GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts - _original_gpt_oss_experts_forward = GptOssExperts.forward +TEMPORARY_PATCHES.append(patch_gpt_oss_moe_for_lora) - def _bf16_split_lora_forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None - ) -> torch.Tensor: - # Fallback to original for 4-bit ModuleList or missing routing data - if ( - router_indices is None - or routing_weights is None - or hasattr(self, "gate_up_projs") - or not hasattr(self, "gate_up_proj") - or not hasattr(self, "down_proj") - ): - return _original_gpt_oss_experts_forward( - self, hidden_states, router_indices, routing_weights - ) - gate_up_param = getattr(self, "gate_up_proj", None) - down_param = getattr(self, "down_proj", None) - if ( - not isinstance(gate_up_param, nn.Parameter) - or gate_up_param.ndim != 3 - or not isinstance(down_param, nn.Parameter) - or down_param.ndim != 3 - ): - return _original_gpt_oss_experts_forward( - self, hidden_states, router_indices, routing_weights - ) +# ============================================================================ +# MXFP4 (4-bit) GPT OSS MoE LoRA Support +# ============================================================================ - if not hasattr(self, "_unsloth_model_type"): - self._unsloth_model_type = "gpt_oss" +_MXFP4_LORA_PATH_LOGGED = False - return forward_native_grouped_mm( - self, - hidden_states, - router_indices, # top_k_index - routing_weights, # top_k_weights +@torch.compiler.disable +def forward_mxfp4_gpt_oss_with_lora( + self, + hidden_states: torch.Tensor, + routing_data, + gather_idx, + scatter_idx, +) -> torch.Tensor: + """ + MXFP4 GPT OSS MoE forward pass with LoRA support. + + For LoRA training with MXFP4: + - Base MXFP4 matmul is computed without gradients (frozen weights) + - LoRA delta is computed with gradients + - Output = base_output + lora_delta + + This allows finetuning MXFP4 quantized models with LoRA. + + Requires triton_kernels for native MXFP4 matmul. + If triton_kernels is not available, model should be loaded with + Mxfp4Config(dequantize=True) to convert to bf16. + """ + if not is_triton_kernels_available(): + raise RuntimeError( + "triton_kernels is required for native MXFP4 GPT OSS forward pass. " + "Either:\n" + " 1. Install triton_kernels from OpenAI, OR\n" + " 2. Load model with dequantization: Mxfp4Config(dequantize=True), OR\n" + " 3. Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'\n" + "Set UNSLOTH_MXFP4_NO_DEQUANTIZE=0 (default) to auto-dequantize." ) - patch_function(GptOssExperts, "forward", _bf16_split_lora_forward) + from triton_kernels import matmul_ogs, swiglu - _patch_gpt_oss_init_weights_for_modulelist( - transformers.models.gpt_oss.modeling_gpt_oss - ) + matmul_ogs_fn = matmul_ogs.matmul_ogs + FnSpecs = matmul_ogs.FnSpecs + FusedActivation = matmul_ogs.FusedActivation + swiglu_fn = swiglu.swiglu_fn - # BF16 Experts class with stacked weights (for split LoRA via moe_utils) - class GptOssExpertsBF16Stacked(nn.Module): - """ - GPT-OSS BF16 Experts with stacked weights for grouped GEMM. + # Get LoRA wrappers + gate_up_wrapper = _get_lora_wrapper_for_param(self, "gate_up_proj") + down_wrapper = _get_lora_wrapper_for_param(self, "down_proj") - Uses moe_utils.forward_native_grouped_mm for efficient MoE computation - with separated LoRA support. - """ + has_lora = gate_up_wrapper is not None or down_wrapper is not None - def __init__(self, config): - super().__init__() - self.num_experts = config.num_local_experts - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.alpha = 1.702 - self.limit = getattr(config, "swiglu_limit", 7.0) - self.dtype = dtype_from_config(config) + # Log path info once + global _MXFP4_LORA_PATH_LOGGED + if not _MXFP4_LORA_PATH_LOGGED: + _MXFP4_LORA_PATH_LOGGED = True + logger.warning_once( + f"Unsloth: GPT-OSS MoE training path: MXFP4 + triton_kernels. " + f"LoRA={has_lora}, experts={self.num_experts}. " + f"Tip: Increase batch_size for better GPU utilization." + ) - # Stacked weights in transposed format (E, in_dim, out_dim) for grouped_mm - self.gate_up_proj = nn.Parameter( - torch.empty( - self.num_experts, - self.hidden_size, - 2 * self.intermediate_size, - dtype=self.dtype, + # If no LoRA, use the original MXFP4 forward (inference path) + if not has_lora: + with torch_cuda_device(hidden_states.device): + if not hasattr(self, "act"): + self.act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (self.alpha, self.limit), + 2, ) + intermediate_cache1 = matmul_ogs_fn( + hidden_states.to(torch.bfloat16), + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=self.act, ) - self.gate_up_proj_bias = nn.Parameter( - torch.empty( - self.num_experts, 2 * self.intermediate_size, dtype=self.dtype + intermediate_cache3 = matmul_ogs_fn( + intermediate_cache1, + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=routing_data.gate_scal if routing_data else None, + ) + return intermediate_cache3 + + # With LoRA: compute base MXFP4 output (no grad) + LoRA delta (with grad) + with torch_cuda_device(hidden_states.device): + # 1. Compute MXFP4 base output (detached, no gradients for base weights) + with torch.no_grad(): + if not hasattr(self, "act"): + self.act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (self.alpha, self.limit), + 2, ) + + # Gate-up projection (MXFP4) + pre_activation = matmul_ogs_fn( + hidden_states.to(torch.bfloat16), + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=None, # Don't fuse activation so we can add LoRA before ) - self.down_proj = nn.Parameter( - torch.empty( - self.num_experts, - self.intermediate_size, - self.hidden_size, - dtype=self.dtype, + + # 2. Add LoRA for gate_up_proj using grouped_mm + if gate_up_wrapper is not None: + lora_data = _get_moe_lora_weights(gate_up_wrapper) + if lora_data is not None: + lora_A, lora_B, scaling, num_experts = lora_data + + # Convert triton_kernels routing format to grouped_mm offsets + # routing_data.exp_indx contains expert index for each token + gather_src = gather_idx.src_indx + permuted_input = hidden_states[gather_src].to(torch.bfloat16) + + # Compute offsets from expert indices (cumsum of token counts) + expert_ids = routing_data.exp_indx + num_tokens_per_expert = torch.bincount( + expert_ids.int(), minlength=num_experts + ).int() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + + # Use grouped_mm for LoRA computation (much faster than loop!) + lora_delta = _apply_lora_grouped_mm( + permuted_input, + lora_A, + lora_B, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, ) + pre_activation = pre_activation + lora_delta + + # 3. Apply SwiGLU activation + alpha = getattr(self, "alpha", 1.702) + limit = getattr(self, "limit", 7.0) + swiglu_output = swiglu_torch_forward(pre_activation, alpha, limit) + + # 4. Down projection (MXFP4) + with torch.no_grad(): + base_output = matmul_ogs_fn( + swiglu_output.detach(), # Detach to prevent grad through MXFP4 + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=routing_data.gate_scal if routing_data else None, ) - self.down_proj_bias = nn.Parameter( - torch.empty(self.num_experts, self.hidden_size, dtype=self.dtype) + + # 5. Add LoRA for down_proj using grouped_mm + if down_wrapper is not None: + lora_data = _get_moe_lora_weights(down_wrapper) + if lora_data is not None: + lora_A, lora_B, scaling, num_experts = lora_data + + # Compute offsets from expert indices (reuse if available) + expert_ids = routing_data.exp_indx + num_tokens_per_expert = torch.bincount( + expert_ids.int(), minlength=num_experts + ).int() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + + # Use grouped_mm for LoRA computation + lora_delta = _apply_lora_grouped_mm( + swiglu_output, + lora_A, + lora_B, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, + ) + + # Scatter LoRA delta back to output (apply routing weights) + scatter_dst = scatter_idx.dst_indx + gamma = routing_data.gate_scal.unsqueeze(-1) + base_output.index_add_( + 0, scatter_dst, (lora_delta * gamma).to(base_output.dtype) + ) + + return base_output + + +def patch_mxfp4_gpt_oss_for_lora(): + """ + Patch MXFP4 GPT OSS experts for LoRA training. + + This enables finetuning the MXFP4 quantized model (unsloth/gpt-oss-20b) + with LoRA. + + IMPORTANT: Requires triton_kernels for native MXFP4 matmul. + If triton_kernels is not available, users must either: + - Use Mxfp4Config(dequantize=True) when loading, OR + - Use the BF16 model: 'unsloth/gpt-oss-20b-BF16' + """ + # First patch ParamWrapper for MoE separated LoRA (v5 only) + if _is_transformers_v5(): + patch_param_wrapper_for_moe() + + try: + import transformers.integrations.mxfp4 + + Mxfp4GptOssExpertsClass = getattr( + transformers.integrations.mxfp4, "Mxfp4GptOssExperts", None + ) + if Mxfp4GptOssExpertsClass is None: + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + "Unsloth: Mxfp4GptOssExperts not found in transformers.integrations.mxfp4" + ) + return + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch MXFP4 GPT OSS for LoRA: {e}") + return + + # Check if already patched + if hasattr(Mxfp4GptOssExpertsClass, "_unsloth_mxfp4_lora_patched"): + return + + # Only patch if triton_kernels is available + # Without triton_kernels, MXFP4 weights cannot be used directly + # Users must use dequantization or BF16 model + if is_triton_kernels_available(): + # Use native MXFP4 + LoRA (keeps weights quantized) + Mxfp4GptOssExpertsClass._original_forward = Mxfp4GptOssExpertsClass.forward + Mxfp4GptOssExpertsClass.forward = forward_mxfp4_gpt_oss_with_lora + Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched MXFP4 GPT OSS MoE for LoRA training") + else: + # triton_kernels NOT available - do NOT patch + # The model will fail with a helpful error if user tries to use MXFP4 without dequantization + Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + "Unsloth: triton_kernels is not installed. MXFP4 GPT OSS will NOT be patched for LoRA.\n" + "To train GPT OSS with LoRA, either:\n" + " 1. Install triton_kernels from OpenAI (for native MXFP4), OR\n" + " 2. Use Mxfp4Config(dequantize=True) when loading (dequantizes to bf16), OR\n" + " 3. Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'" ) - # Register LoRA extractor - self._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor - @property - def hidden_dim(self): - return self.hidden_size +TEMPORARY_PATCHES.append(patch_mxfp4_gpt_oss_for_lora) - @property - def intermediate_dim(self): - return self.intermediate_size - def forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None - ) -> torch.Tensor: - """ - Forward pass using grouped GEMM from moe_utils. - Uses forward_native_grouped_mm which handles split LoRA automatically. - """ - # Call moe_utils grouped MM forward which handles everything - return forward_native_grouped_mm( - self, - hidden_states, - router_indices, # top_k_index - routing_weights, # top_k_weights +_MXFP4_DEQUANT_WARNED = False + + +def should_dequantize_mxfp4(): + """ + Check if MXFP4 should be dequantized to bf16. + + Returns True if: + - UNSLOTH_MXFP4_NO_DEQUANTIZE is not set or "0", OR + - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" but triton_kernels is not available + + Returns False if: + - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" AND triton_kernels is available + + MEMORY IMPACT: + - MXFP4 quantized: ~10GB for GPT-OSS 20B + - MXFP4 dequantized to bf16: ~40GB for GPT-OSS 20B + + To keep MXFP4 quantized (requires triton_kernels): + export UNSLOTH_MXFP4_NO_DEQUANTIZE=1 + """ + global _MXFP4_DEQUANT_WARNED + + if not UNSLOTH_MXFP4_NO_DEQUANTIZE: + # Default: dequantize to bf16 + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.warning( + "Unsloth: MXFP4 will be dequantized to bf16 (~4x memory increase). " + "To keep 4-bit: set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 and install triton_kernels." ) + return True - # MLP wrapper that works with stacked experts - class GptOssMLP_BF16(nn.Module): - def __init__(self, config): - super().__init__() - self.router = GptOssTopKRouter(config) - self.experts = GptOssExpertsBF16Stacked(config) + if not is_triton_kernels_available(): + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.warning( + "Unsloth: UNSLOTH_MXFP4_NO_DEQUANTIZE=1 but triton_kernels not available. " + "Will dequantize MXFP4 to bf16 (~4x memory increase). " + "Install triton_kernels to keep 4-bit quantized weights." + ) + return True # triton_kernels required for native MXFP4 + + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.info("Unsloth: Keeping MXFP4 quantized (triton_kernels available)") + return False # Keep MXFP4 quantized + + +def patch_gpt_oss_linearized(): + """ + Patch GPT OSS for 4bit loading with grouped_mm support. + Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if _should_use_gpt_oss_bnb4bit(): return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) - def forward(self, hidden_states): - bsz, qlen, hd = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, hd) + # Don't replace classes - just patch the forward method of GptOssExperts + # This keeps the original class structure which properly handles 4-bit weight loading + backend = select_moe_backend() - # Get router scores - router_scores, router_indices = self.router(hidden_states_flat) + if backend == "grouped_mm": - # Run experts with stacked weights - routed_out = self.experts( - hidden_states_flat, - router_indices=router_indices, - routing_weights=router_scores, + def experts_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + return forward_native_grouped_mm( + self, hidden_states, router_indices, routing_weights ) + else: - routed_out = routed_out.view(bsz, qlen, hd) - return routed_out, router_scores + def experts_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + return forward_native_moe_loop( + self, hidden_states, router_indices, routing_weights + ) - # Patch transformers module - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor + # Patch the original transformers class forward method + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward - # Check if model has stacked weights (BF16) vs ModuleList (4-bit) - # This is done at load time by checking the model structure if UNSLOTH_ENABLE_LOGGING: - logger.info("Unsloth: Patched GPT-OSS for BF16 split LoRA with grouped GEMM") + logger.info( + f"Unsloth: Patched GPT OSS MoE for 4bit loading (backend: {backend})" + ) + return pass -TEMPORARY_PATCHES.append(patch_gpt_oss_bf16_split_lora) +TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) def patch_GptOssAttention(): @@ -1133,17 +1857,20 @@ def patch_GptOssAttention(): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) + assert flex_attention_with_sink is not None except Exception as e: return raise_error("flex_attention_with_sink", e) try: import transformers.models.gpt_oss.modeling_gpt_oss + transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention", e) torch._dynamo.config.cache_size_limit = 256 + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -1158,6 +1885,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: F_softmax = torch.nn.functional.softmax F_dropout = nn.functional.dropout matmul = torch.matmul + def inplace_eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -1173,9 +1901,13 @@ def inplace_eager_attention_forward( bsz, n_heads, qlen, _ = query.shape bsz, n_heads, kvlen, _ = key_states.shape - combined_logits = key_states.new_empty((bsz, n_heads, qlen, kvlen+1)) + out_dtype = torch.result_type(query, key_states) + combined_logits = key_states.new_empty( + (bsz, n_heads, qlen, kvlen + 1), + dtype=out_dtype, + ) - attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) + attn_weights = matmul(query, key_states.transpose(2, 3), out=combined_logits[:, :, :, :kvlen]) attn_weights *= scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -1185,16 +1917,16 @@ def inplace_eager_attention_forward( # combined_logits = torch.cat([attn_weights, sinks], dim=-1) combined_logits[:, :, :, -1] = module.sinks.reshape(1, -1, 1) - # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 - # when training with bsz>1 we clamp max values. - # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) - attn_output = matmul(attn_weights, value_states, out = query) + attn_weights = F_dropout(scores, p=dropout, training=module.training) + attn_weights = attn_weights.to(value_states.dtype) + attn_output = matmul(attn_weights, value_states, out=query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None + pass def eager_attention_forward( @@ -1218,21 +1950,21 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 - # when training with bsz>1 we clamp max values. - # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) - attn_output = matmul(attn_weights, value_states, out = query) + attn_weights = F_dropout(scores, p=dropout, training=module.training) + attn_weights = attn_weights.to(value_states.dtype) + attn_output = matmul(attn_weights, value_states, out=query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None + pass apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) - if False: # Version(torch.__version__) >= Version("2.10.0"): - eager_attention_forward = torch_compile(eager_attention_forward, dynamic = None, fullgraph = True) + if False: # Version(torch.__version__) >= Version("2.10.0"): + eager_attention_forward = torch_compile(eager_attention_forward, dynamic=None, fullgraph=True) else: # Too many recompilation failures on 2.8.0, 2.9.0 eager_attention_forward = inplace_eager_attention_forward @@ -1246,7 +1978,7 @@ def forward_function( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -1254,8 +1986,24 @@ def forward_function( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: + cache_dtype = getattr(past_key_value, "dtype", None) + if cache_dtype is None and hasattr(past_key_value, "layers"): + try: + cache_layer = past_key_value.layers[self.layer_idx] + if hasattr(cache_layer, "keys") and cache_layer.keys is not None: + cache_dtype = cache_layer.keys.dtype + except Exception: + cache_dtype = None + if cache_dtype is not None and key_states.dtype != cache_dtype: + key_states = key_states.to(cache_dtype) + value_states = value_states.to(cache_dtype) cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + if key_states.dtype != query_states.dtype: + key_states = key_states.to(query_states.dtype) + value_states = value_states.to(query_states.dtype) # flex_attention_with_sink only works for training since KV cache is wrong # switch to flex_attention_with_sink which allows all to work @@ -1307,9 +2055,11 @@ def forward_function( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + pass functions = [] + def forward( self, hidden_states: torch.Tensor, @@ -1320,7 +2070,9 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs) + functions.append(forward) + def forward( self, hidden_states: torch.Tensor, @@ -1331,10 +2083,13 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs) + functions.append(forward) patch_function_past_key_values(transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention, "forward", functions) # Set env variable for padding purposes os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "1" + + pass TEMPORARY_PATCHES.append(patch_GptOssAttention) @@ -1344,6 +2099,7 @@ def patch_GptOssModel(): if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return try: import transformers.models.gpt_oss.modeling_gpt_oss + transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel from transformers.models.gpt_oss.modeling_gpt_oss import MoeModelOutputWithPast from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb @@ -1360,6 +2116,7 @@ def patch_GptOssModel(): # Disable mask creations since we don't need them for GPT-OSS import transformers.masking_utils import transformers.generation.utils + def wrap(f): def return_attention_mask(*args, **kwargs): if kwargs["input_embeds"].requires_grad: @@ -1372,7 +2129,9 @@ def return_attention_mask(*args, **kwargs): # Eager return f(*args, **kwargs) pass + return return_attention_mask + pass create_causal_mask = getattr( transformers.masking_utils, @@ -1389,8 +2148,8 @@ def return_attention_mask(*args, **kwargs): if create_sliding_window_causal_mask is None: return raise_error("transformers.masking_utils.create_sliding_window_causal_mask") if not hasattr(transformers.masking_utils, "__patched_causal_mask__"): - transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph = False, dynamic = True) - transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph = False, dynamic = True) + transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph=False, dynamic=True) + transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph=False, dynamic=True) transformers.masking_utils.create_causal_mask = wrap(create_causal_mask) transformers.masking_utils.create_sliding_window_causal_mask = wrap(create_sliding_window_causal_mask) transformers.models.gpt_oss.modeling_gpt_oss.create_causal_mask = transformers.masking_utils.create_causal_mask @@ -1405,6 +2164,7 @@ def return_attention_mask(*args, **kwargs): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) + apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) try: from transformers.integrations.mxfp4 import mlp_forward @@ -1420,7 +2180,7 @@ def pre_attention_decoding( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ): - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -1431,7 +2191,9 @@ def pre_attention_decoding( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) return query_states, key_states, value_states, input_shape + pass + # Do flex_attention_with_sink_decoding with cannot be compiled # attn_output, logsumexp = flex_attention_with_sink_decoding( # self, @@ -1444,6 +2206,7 @@ def post_attention_decoding(self_attn, attn_output, logsumexp, input_shape): attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self_attn.o_proj(attn_output) return attn_output + pass # RMSNorm forward @@ -1455,6 +2218,7 @@ def rms_layernorm_forward(self, hidden_states): hidden_states *= torch.rsqrt_(variance) hidden_states *= self.weight.to(hidden_states.device).to(torch.float32) return hidden_states.to(input_dtype) # main diff with Llama + pass # Re-compiling for each new sequence length which is NOT ideal @@ -1482,19 +2246,21 @@ def pre_forward( position_embeddings=position_embeddings, ) return query_states, key_states, value_states, input_shape + pass fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion = True, - max_autotune = False, # Too slow - shape_padding = True, - cudagraphs = True, - coordinate_descent_tuning = False, - combo_kernels = False, - memory_planning = True, - multi_kernel = False, # Fails on torch 2.10 nightly - use_block_ptr = True, - logging = UNSLOTH_ENABLE_LOGGING, + epilogue_fusion=True, + max_autotune=False, # Too slow + shape_padding=True, + cudagraphs=True, + coordinate_descent_tuning=False, + combo_kernels=False, + memory_planning=True, + multi_kernel=False, # Fails on torch 2.10 nightly + use_block_ptr=True, + logging=UNSLOTH_ENABLE_LOGGING, ) + @_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) def post_forward( self, @@ -1503,13 +2269,16 @@ def post_forward( logsumexp: torch.Tensor, input_shape, ): - hidden_states = post_attention_decoding(self.self_attn, attn_output, logsumexp, input_shape) + hidden_states = post_attention_decoding( + self.self_attn, attn_output, logsumexp, input_shape + ) hidden_states += residual # Fully Connected residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual + pass def inference_forward( @@ -1542,6 +2311,7 @@ def inference_forward( residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual + pass # if has_static_cache and Version(torch.__version__) >= Version("2.10.0"): # # torch 2.9.0 has excessive compilations @@ -1567,24 +2337,24 @@ def forward( if inputs_embeds is None: # Account for CPU offloaded embed_tokens embed_device = self.embed_tokens.weight.device - inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking = True)).to(input_ids.device) + inputs_embeds = self.embed_tokens( + input_ids.to(embed_device, non_blocking=True) + ).to(input_ids.device) if not self.training: inputs_embeds.requires_grad_(False) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) try: - torch._dynamo.mark_static (hidden_states, 0) + torch._dynamo.mark_static(hidden_states, 0) torch._dynamo.mark_dynamic(hidden_states, 1) - torch._dynamo.mark_static (hidden_states, 2) + torch._dynamo.mark_static(hidden_states, 2) except: pass @@ -1608,6 +2378,8 @@ def forward( # Add hack since residuals need to clone outside of the torch.compile region?? # This forces it to free past residuals torch.compiler.cudagraph_mark_step_begin() + # Initialize for common return path + all_hidden_states = None for decoder_layer in self.layers: hidden_states, residual = inference_forward( decoder_layer, @@ -1620,9 +2392,14 @@ def forward( position_embeddings, **kwargs, ) - if hasattr(decoder_layer.mlp.experts, "gate_up_projs"): - hidden_states = moe_forward_inference(decoder_layer.mlp, hidden_states) - elif decoder_layer.mlp.experts.__class__.__name__ == "Mxfp4GptOssExperts": + _actual_experts = _unwrap_peft_experts(decoder_layer.mlp.experts) + if hasattr(_actual_experts, "gate_up_projs"): + hidden_states = moe_forward_inference( + decoder_layer.mlp, hidden_states + ) + elif ( + _actual_experts.__class__.__name__ == "Mxfp4GptOssExperts" + ): if mlp_forward is None: raise RuntimeError("Unsloth: MXFP4 forward is not found") hidden_states, _ = mlp_forward(decoder_layer.mlp, hidden_states) @@ -1632,8 +2409,36 @@ def forward( pass hidden_states = rms_layernorm_forward(self.norm, hidden_states) else: + # Fix 2D attention_mask being passed to eager_attention_forward which expects 4D + if ( + self.training + and attention_mask is not None + and attention_mask.dim() == 2 + ): + bsz, seq_len = attention_mask.shape + min_dtype = torch.finfo(inputs_embeds.dtype).min + # 1. Expand padding mask to (B, 1, 1, S) + expanded_mask = attention_mask[:, None, None, :].to( + dtype=inputs_embeds.dtype + ) + expanded_mask = (1.0 - expanded_mask) * min_dtype + + # 2. Causal mask (1, 1, S, S) + causal_mask = torch.full((seq_len, seq_len), min_dtype, device=attention_mask.device, dtype=inputs_embeds.dtype) + causal_mask = torch.triu(causal_mask, diagonal=1) + attention_mask = causal_mask[None, None, :, :] + expanded_mask + + # Accumulate hidden states if requested + output_hidden_states = kwargs.get( + "output_hidden_states", self.config.output_hidden_states + ) + all_hidden_states = () if output_hidden_states else None + for decoder_layer in self.layers: - mask = attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + if output_hidden_states: + all_hidden_states += (hidden_states,) + + mask = (attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask) hidden_states = decoder_layer( hidden_states, attention_mask=mask, @@ -1646,13 +2451,24 @@ def forward( ) pass hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + # Fix float16 / float32 mismatching hidden_states = hidden_states.to(inputs_embeds.dtype) - return process_return(MoeModelOutputWithPast, { - "last_hidden_state" : hidden_states, - "past_key_values" : past_key_values, - }) - patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") + return process_return( + MoeModelOutputWithPast, + { + "last_hidden_state": hidden_states, + "past_key_values": past_key_values, + "hidden_states": all_hidden_states, + }, + ) + + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") + + pass TEMPORARY_PATCHES.append(patch_GptOssModel) @@ -1667,11 +2483,14 @@ def forward( SystemContent, ToolDescription, load_harmony_encoding, - ReasoningEffort + ReasoningEffort, ) + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) except: pass + + def encode_conversations_with_harmony( messages, reasoning_effort = "medium", @@ -1688,22 +2507,24 @@ def encode_conversations_with_harmony( assert reasoning_effort in ("low", "medium", "high") match reasoning_effort: - case "low": harmony_reasoning = ReasoningEffort.LOW + case "low": harmony_reasoning = ReasoningEffort.LOW case "medium": harmony_reasoning = ReasoningEffort.MEDIUM - case "high": harmony_reasoning = ReasoningEffort.HIGH + case "high": harmony_reasoning = ReasoningEffort.HIGH convos = [] # Create system message import datetime + today = datetime.datetime.today().strftime("%Y-%m-%d") - system = Message.from_role_and_content(Role.SYSTEM, + system = Message.from_role_and_content( + Role.SYSTEM, SystemContent.new() .with_model_identity(model_identity) .with_reasoning_effort(harmony_reasoning) .with_conversation_start_date(today) .with_knowledge_cutoff("2024-06") - .with_required_channels(["analysis", "commentary", "final"]) + .with_required_channels(["analysis", "commentary", "final"]), ) convos.append(system) @@ -1727,27 +2548,21 @@ def encode_conversations_with_harmony( for message in messages: if message["role"] == "user": - convos.append( - Message.from_role_and_content(Role.USER, message["content"]) - ) + convos.append(Message.from_role_and_content(Role.USER, message["content"])) elif message["role"] == "assistant": if "thinking" in message: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("analysis") elif "tool_calls" in message: - x = Message.from_role_and_content(Role.ASSISTANT, message['tool_calls'][0]["arguments"]) - x = x.with_channel("commentary")\ - .with_recipient(f"functions.{message['tool_calls'][0]['name']}")\ - .with_content_type("json") + x = Message.from_role_and_content(Role.ASSISTANT, message["tool_calls"][0]["arguments"]) + x = x.with_channel("commentary").with_recipient(f"functions.{message['tool_calls'][0]['name']}").with_content_type("json") else: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("final") convos.append(x) elif message["role"] == "tool": - x = Message.from_author_and_content( - Author.new(Role.TOOL, f"functions.{message['name']}"), - message["content"], - ).with_recipient("assistant").with_channel("commentary") + x = Message.from_author_and_content(Author.new(Role.TOOL, f"functions.{message['name']}"), message["content"]) + x = x.with_recipient("assistant").with_channel("commentary") convos.append(x) pass @@ -1759,6 +2574,8 @@ def encode_conversations_with_harmony( harmony_input_ids = encoding.render_conversation(convos) harmony_decoded_text = encoding.decode(harmony_input_ids) return harmony_decoded_text, harmony_input_ids + + pass @@ -1767,8 +2584,10 @@ def encode_conversations_with_harmony( # AutoConfig error: 'GptOssConfig' object has no attribute 'max_position_embeddings' try: from transformers.configuration_utils import layer_type_validation + try: from transformers.configuration_utils import PreTrainedConfig + PretrainedConfig = PreTrainedConfig except: from transformers.configuration_utils import PretrainedConfig @@ -1850,9 +2669,7 @@ def __init__( self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads self.layer_types = layer_types if self.layer_types is None: - self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) - ] + self.layer_types = ["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)] layer_type_validation(self.layer_types) # Validate the correctness of rotary position embeddings parameters @@ -1915,7 +2732,13 @@ def __init__( initializer_range: float = 0.02, max_position_embeddings=131072, rms_norm_eps: float = 1e-5, - rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, + rope_scaling={ + "rope_type": "yarn", + "factor": 32.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "truncate": False, + }, attention_dropout: float = 0.0, num_experts_per_tok=4, router_aux_loss_coef: float = 0.9, @@ -1943,11 +2766,16 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout - self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.head_dim = ( + head_dim + if head_dim is not None + else self.hidden_size // self.num_attention_heads + ) self.layer_types = layer_types if self.layer_types is None: self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(self.num_hidden_layers) ] layer_type_validation(self.layer_types) self.attention_bias = True @@ -1975,6 +2803,7 @@ def __init__( def patch_gpt_oss_config(): try: import transformers.models.gpt_oss.configuration_gpt_oss + transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) @@ -1988,7 +2817,46 @@ def patch_gpt_oss_config(): patch_function(transformers.models.gpt_oss.configuration_gpt_oss, "GptOssConfig", GptOssConfig) except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) + pass TEMPORARY_PATCHES.append(patch_gpt_oss_config) except Exception as e: raise_error("transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig", e) + + +def patch_gpt_oss_init_weights_modulelist_fix(): + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + GptOssPreTrainedModel = ( + transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel + ) + GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + if getattr(GptOssPreTrainedModel, "_unsloth_init_weights_fixed", False): + return + _original_init_weights = GptOssPreTrainedModel._init_weights + + def _patched_init_weights(self, module): + if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): + std = self.config.initializer_range + for up in getattr(module, "gate_up_projs", []): + init.normal_(up.weight, mean=0.0, std=std) + if up.bias is not None: + init.zeros_(up.bias) + for down in getattr(module, "down_projs", []): + init.normal_(down.weight, mean=0.0, std=std) + if down.bias is not None: + init.zeros_(down.bias) + return + _original_init_weights(self, module) + + patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) + GptOssPreTrainedModel._unsloth_init_weights_fixed = True + + +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 5760778a2..89171e387 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -19,6 +19,7 @@ import shutil from typing import Optional, Tuple from torch.autograd import Function +from .utils import logger # Get compile location UNSLOTH_COMPILE_LOCATION = os.environ.get( @@ -711,6 +712,8 @@ def forward_native_grouped_mm( """ # This Unsloth Zoo code section is licensed under AGPL3 + logger.info(f"[DEBUG]Using torch._grouped_mm for MoE forward pass") + # Runtime safety check - defense in depth if not _check_torch_grouped_mm_supported(): major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) From 27af91c64b3613592fada9fafcf95ea28f5b924e Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Fri, 6 Feb 2026 14:35:55 -0500 Subject: [PATCH 03/26] Update gpt_oss.py, make it transformers 4 compatible --- unsloth_zoo/temporary_patches/gpt_oss.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index c60c64f21..b135fbbe2 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -905,8 +905,10 @@ def forward(self, hidden_states): router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( 1, router_indices, router_top_value ) - return router_logits, router_scores, router_indices - + if _is_transformers_v5(): + return router_logits, router_scores, router_indices + else: + return router_scores, router_indices pass @@ -1073,7 +1075,10 @@ def forward(self, hidden_states): dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) - return router_logits, router_scores, router_indices + if _is_transformers_v5(): + return router_logits, router_scores, router_indices + else: + return router_scores, router_indices pass From b032d2248e0da733efdea20f95080f9af23d3c66 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Fri, 6 Feb 2026 15:01:31 -0500 Subject: [PATCH 04/26] Update gpt_oss.py, needed if statement --- unsloth_zoo/temporary_patches/gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index b135fbbe2..727455df0 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -905,7 +905,7 @@ def forward(self, hidden_states): router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( 1, router_indices, router_top_value ) - if _is_transformers_v5(): + if transformers_version >= Version("5.0.0"): return router_logits, router_scores, router_indices else: return router_scores, router_indices @@ -1075,7 +1075,7 @@ def forward(self, hidden_states): dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) - if _is_transformers_v5(): + if transformers_version >= Version("5.0.0"): return router_logits, router_scores, router_indices else: return router_scores, router_indices From b1bd22a3c6bdfaca58bd0198a2b56cb317e55d28 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 6 Feb 2026 17:14:48 +0000 Subject: [PATCH 05/26] undo spacing --- unsloth_zoo/temporary_patches/gpt_oss.py | 117 +++++++++-------------- 1 file changed, 47 insertions(+), 70 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 727455df0..6cfddd559 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -183,7 +183,6 @@ def is_kernels_available(): def swizzle_mxfp4(w, w_scale, *args, **kwargs): from triton_kernels import tensor, tensor_details - FP4, convert_layout, wrap_torch_tensor = ( tensor.FP4, tensor.convert_layout, @@ -192,12 +191,8 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): layout = tensor_details.layout StridedLayout = tensor_details.layout.StridedLayout - value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( - mx_axis=1 - ) - w = convert_layout( - wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts - ) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) # TODO : add that when we are actually sure that it works on B200 # if torch.cuda.get_device_capability()[0] == 10: # constraints = { @@ -231,9 +226,7 @@ def forward( scatter_idx, ): pre_activation = matmul_ogs( - hidden_states.to( - torch.bfloat16 - ), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware self_class.gate_up_proj, self_class.gate_up_proj_bias, routing_data, @@ -272,7 +265,6 @@ def forward( ctx.scatter_idx = scatter_idx ctx.routing_data = routing_data return out - pass @staticmethod @@ -305,7 +297,6 @@ def backward(ctx, grad_token): dx_token = torch.zeros_like(grad_token) dx_token.index_add_(0, gather_dst, dx_exp) return (dx_token, None, None, None, None,) - pass pass @@ -893,12 +884,8 @@ def bias(self, value): def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.linear( - hidden_states.to(self.linear.weight.dtype) - ) # (batch_size * seq_len, num_experts) - router_top_value, router_indices = torch.topk( - router_logits, self.top_k, dim=-1 - ) # (seq_len, top_k) + router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax( router_top_value, dim=1, dtype=router_top_value.dtype ) @@ -1173,30 +1160,30 @@ def patch_gpt_oss_bnb4bit_auto(): device_memory = torch.xpu.memory.mem_get_info(0)[-1] else: device_memory = torch.cuda.memory.mem_get_info(0)[-1] -use_combo_kernels = False if device_memory / 1024 / 1024 / 1024 <= 40 else True +use_combo_kernels = False if device_memory/1024/1024/1024 <= 40 else True fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion=True, - max_autotune=False, # Too slow - shape_padding=True, - cudagraphs=True, - coordinate_descent_tuning=use_combo_kernels, # Very slow! - combo_kernels=use_combo_kernels, - memory_planning=True, - multi_kernel=False, # Fails on torch 2.10 nightly - use_block_ptr=True, - logging=UNSLOTH_ENABLE_LOGGING, + epilogue_fusion = True, + max_autotune = False, # Too slow + shape_padding = True, + cudagraphs = True, + coordinate_descent_tuning = use_combo_kernels, # Very slow! + combo_kernels = use_combo_kernels, + memory_planning = True, + multi_kernel = False, # Fails on torch 2.10 nightly + use_block_ptr = True, + logging = UNSLOTH_ENABLE_LOGGING, ) no_combo_fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion=True, - max_autotune=False, # Too slow - shape_padding=True, - cudagraphs=True, - coordinate_descent_tuning=use_combo_kernels, # Very slow! - combo_kernels=False, # Breaks on attention - memory_planning=True, - multi_kernel=False, # Fails on torch 2.10 nightly - use_block_ptr=True, - logging=UNSLOTH_ENABLE_LOGGING, + epilogue_fusion = True, + max_autotune = False, # Too slow + shape_padding = True, + cudagraphs = True, + coordinate_descent_tuning = use_combo_kernels, # Very slow! + combo_kernels = False, # Breaks on attention + memory_planning = True, + multi_kernel = False, # Fails on torch 2.10 nightly + use_block_ptr = True, + logging = UNSLOTH_ENABLE_LOGGING, ) @@ -1278,21 +1265,13 @@ def moe_forward_inference(self, hidden_states): @torch_compile(dynamic=True, fullgraph=True) def moe_router_forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear( - hidden_states.to(self.weight.dtype), self.weight, self.bias - ) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk( - router_logits, self.top_k, dim=-1 - ) # (seq_len, top_k) + router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) dtype = ( torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype ) - router_top_value = torch.nn.functional.softmax( - router_top_value, dim=1, dtype=torch.float32 - ).to(dtype) - router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_( - 1, router_indices, router_top_value - ) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) + router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) return router_scores, router_indices @@ -1912,7 +1891,7 @@ def inplace_eager_attention_forward( dtype=out_dtype, ) - attn_weights = matmul(query, key_states.transpose(2, 3), out=combined_logits[:, :, :, :kvlen]) + attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) attn_weights *= scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -1928,7 +1907,7 @@ def inplace_eager_attention_forward( scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training) attn_weights = attn_weights.to(value_states.dtype) - attn_output = matmul(attn_weights, value_states, out=query) + attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None @@ -1961,7 +1940,7 @@ def eager_attention_forward( scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training) attn_weights = attn_weights.to(value_states.dtype) - attn_output = matmul(attn_weights, value_states, out=query) + attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None @@ -1983,7 +1962,7 @@ def forward_function( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -2134,7 +2113,6 @@ def return_attention_mask(*args, **kwargs): # Eager return f(*args, **kwargs) pass - return return_attention_mask pass @@ -2185,7 +2163,7 @@ def pre_attention_decoding( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ): - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -2254,16 +2232,16 @@ def pre_forward( pass fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion=True, - max_autotune=False, # Too slow - shape_padding=True, - cudagraphs=True, - coordinate_descent_tuning=False, - combo_kernels=False, - memory_planning=True, - multi_kernel=False, # Fails on torch 2.10 nightly - use_block_ptr=True, - logging=UNSLOTH_ENABLE_LOGGING, + epilogue_fusion = True, + max_autotune = False, # Too slow + shape_padding = True, + cudagraphs = True, + coordinate_descent_tuning = False, + combo_kernels = False, + memory_planning = True, + multi_kernel = False, # Fails on torch 2.10 nightly + use_block_ptr = True, + logging = UNSLOTH_ENABLE_LOGGING, ) @_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) @@ -2274,9 +2252,7 @@ def post_forward( logsumexp: torch.Tensor, input_shape, ): - hidden_states = post_attention_decoding( - self.self_attn, attn_output, logsumexp, input_shape - ) + hidden_states = post_attention_decoding(self.self_attn, attn_output, logsumexp, input_shape) hidden_states += residual # Fully Connected @@ -2865,3 +2841,4 @@ def _patched_init_weights(self, module): pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) + From 46376a0b218d945267c49e8d2f2b676e677a2f62 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sat, 7 Feb 2026 11:23:02 +0000 Subject: [PATCH 06/26] remove logger --- unsloth_zoo/temporary_patches/moe_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 89171e387..5760778a2 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -19,7 +19,6 @@ import shutil from typing import Optional, Tuple from torch.autograd import Function -from .utils import logger # Get compile location UNSLOTH_COMPILE_LOCATION = os.environ.get( @@ -712,8 +711,6 @@ def forward_native_grouped_mm( """ # This Unsloth Zoo code section is licensed under AGPL3 - logger.info(f"[DEBUG]Using torch._grouped_mm for MoE forward pass") - # Runtime safety check - defense in depth if not _check_torch_grouped_mm_supported(): major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) From 4e2db2d978db44fb0ffb7ab6ee3e7f765934838e Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 8 Feb 2026 10:27:07 +0000 Subject: [PATCH 07/26] dtype cast --- unsloth_zoo/temporary_patches/gpt_oss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 6cfddd559..68dd79dd2 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1985,7 +1985,7 @@ def forward_function( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) - if key_states.dtype != query_states.dtype: + if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: key_states = key_states.to(query_states.dtype) value_states = value_states.to(query_states.dtype) @@ -2841,4 +2841,3 @@ def _patched_init_weights(self, module): pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) - From 2004c5871a8f3d8493c7f6f27c95889aee7e99d1 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 8 Feb 2026 10:46:26 +0000 Subject: [PATCH 08/26] dtype cast --- unsloth_zoo/temporary_patches/gpt_oss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 68dd79dd2..d8fae2708 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1985,9 +1985,9 @@ def forward_function( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) - if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: - key_states = key_states.to(query_states.dtype) - value_states = value_states.to(query_states.dtype) + if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: + key_states = key_states.to(query_states.dtype) + value_states = value_states.to(query_states.dtype) # flex_attention_with_sink only works for training since KV cache is wrong # switch to flex_attention_with_sink which allows all to work From e22008be1e276efdfe85fd7e5b95cb2de883765a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 8 Feb 2026 11:03:20 +0000 Subject: [PATCH 09/26] fix gpt oss grpo --- unsloth_zoo/temporary_patches/gpt_oss.py | 104 +++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index d8fae2708..8d43dcfc0 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -2841,3 +2841,107 @@ def _patched_init_weights(self, module): pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) + + +# ============================================================================ +# Patch GptOssForCausalLM.forward for GRPO training +# When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits +# ============================================================================ +def patch_gpt_oss_for_grpo(): + """ + Patch GptOssForCausalLM.forward for GRPO training. + When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. + This fixes the matrix multiplication dimension mismatch issue in GRPO training. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + + try: + import transformers.models.gpt_oss.modeling_gpt_oss + from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssForCausalLM, + MoeCausalLMOutputWithPast, + ) + + _original_causal_lm_forward = GptOssForCausalLM.forward + + def _patched_causal_lm_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + logits_to_keep=0, + **kwargs, + ): + # This Unsloth Zoo code section is licensed under AGPL3 + + RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1" + + if not RETURN_HIDDEN_STATES: + # Normal forward pass + return _original_causal_lm_forward( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # RETURN_HIDDEN_STATES mode - return hidden_states instead of logits + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Apply slice_indices to hidden_states (same indexing as for logits) + if logits_to_keep != 0: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + hidden_states = hidden_states[:, slice_indices, :] + + # Return hidden_states as "logits" for GRPO to use + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=getattr(outputs, 'aux_loss', None), + logits=hidden_states, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=getattr(outputs, 'router_logits', None), + ) + + GptOssForCausalLM.forward = _patched_causal_lm_forward + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched GptOssForCausalLM.forward for GRPO hidden states.") + + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch GptOssForCausalLM.forward: {e}") + + +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_for_grpo) From d432149e312af6d9aeb4558b516d57794a06519e Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 07:19:36 +0000 Subject: [PATCH 10/26] patch for v5 only --- unsloth_zoo/temporary_patches/gpt_oss.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 8d43dcfc0..0a3036fc6 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1098,17 +1098,18 @@ def patch_gpt_oss_bnb4bit(): # Preserve original symbol names for compiler-generated modules. GptOssExpertsBnb4bit.__name__ = "GptOssExperts" GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit - # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. - # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. - # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict - # override to remap keys, but transformers v5 bypasses _load_from_state_dict - # (uses accelerate's set_module_tensor_to_device), so the remapping never ran - # and router weights were randomly initialized - causing high loss (~4-5). - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - - logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") - os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" + if _is_transformers_v5(): + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit + # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. + # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. + # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict + # override to remap keys, but transformers v5 bypasses _load_from_state_dict + # (uses accelerate's set_module_tensor_to_device), so the remapping never ran + # and router weights were randomly initialized - causing high loss (~4-5). + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + + logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") + os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" return True From fc51c93d8faaba706c9bf2a4df396cb4ad24db90 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 07:33:50 +0000 Subject: [PATCH 11/26] patch for v5 only --- unsloth_zoo/temporary_patches/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 0a3036fc6..b430dac2f 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1151,7 +1151,7 @@ def patch_gpt_oss_bnb4bit_auto(): pass -TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) +# TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) # Combo kernels uses too much VRAM for low memory GPUs From 3d8a6443ff8d2cd6a72c2f700720c1549c970a97 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 07:40:28 +0000 Subject: [PATCH 12/26] remove unnecessary ops --- unsloth_zoo/temporary_patches/gpt_oss.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index b430dac2f..0df6a641e 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -950,10 +950,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig if self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # with torch.no_grad(): + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hitted[:]: with torch.no_grad(): _, token_idx = torch.where(expert_mask[expert_idx[0]]) @@ -965,7 +965,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig glu = gate * torch.sigmoid(gate * self.alpha) gated_output = (up + 1) * glu out = self.down_projs[expert_idx](gated_output) - weighted_output = out * routing_weights[token_idx, expert_idx, None] + weighted_output = out * routing_weights[token_idx, expert_idx, None].to(torch.float32) next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) return next_states @@ -982,8 +982,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)] outs = torch.stack(out_list, dim=0) rw = routing_weights.transpose(0, 1).unsqueeze(-1) - mixed = (outs * rw).sum(dim=0) - return mixed.view(batch_size, -1, self.hidden_size) + mixed = (outs.to(torch.float32) * rw.to(torch.float32)).sum(dim=0) + return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) pass @@ -1151,7 +1151,7 @@ def patch_gpt_oss_bnb4bit_auto(): pass -# TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) +TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) # Combo kernels uses too much VRAM for low memory GPUs @@ -1939,7 +1939,7 @@ def eager_attention_forward( combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training) + attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) attn_weights = attn_weights.to(value_states.dtype) attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() From efe50060080c4f3a616cf615a2813f3e0caba34a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 07:46:46 +0000 Subject: [PATCH 13/26] fix patch --- unsloth_zoo/temporary_patches/gpt_oss.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 0df6a641e..57b061ed4 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1098,18 +1098,17 @@ def patch_gpt_oss_bnb4bit(): # Preserve original symbol names for compiler-generated modules. GptOssExpertsBnb4bit.__name__ = "GptOssExperts" GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" - if _is_transformers_v5(): - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit - # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. - # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. - # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict - # override to remap keys, but transformers v5 bypasses _load_from_state_dict - # (uses accelerate's set_module_tensor_to_device), so the remapping never ran - # and router weights were randomly initialized - causing high loss (~4-5). - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - - logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") - os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit + # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. + # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. + # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict + # override to remap keys, but transformers v5 bypasses _load_from_state_dict + # (uses accelerate's set_module_tensor_to_device), so the remapping never ran + # and router weights were randomly initialized - causing high loss (~4-5). + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + + logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") + os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" return True From ff5c749312ac0dff70b730ff7aa6317f397001cc Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 07:54:32 +0000 Subject: [PATCH 14/26] undo changes --- unsloth_zoo/temporary_patches/gpt_oss.py | 66 +++++++++++++----------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 57b061ed4..6ed0f466c 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -949,43 +949,46 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig num_experts = routing_weights.shape[1] if self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted[:]: + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted[:]: + for expert_idx in range(num_experts): with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx[0]]) + # _, token_idx = torch.where(expert_mask[expert_idx[0]]) + token_idx, _ = torch.where(router_indices == expert_idx) current_state = hidden_states[token_idx] gate_up = self.gate_up_projs[expert_idx](current_state) - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu + gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit) + # gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # gate = gate.clamp(min=None, max=self.limit) + # up = up.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # gated_output = (up + 1) * glu out = self.down_projs[expert_idx](gated_output) weighted_output = out * routing_weights[token_idx, expert_idx, None].to(torch.float32) - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states.index_add_(0, token_idx, weighted_output) next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states + return next_states.to(hidden_states.dtype) else: X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] gate_up = torch.stack(gate_up_list, dim=0) - gate = gate_up[..., ::2] - up_h = gate_up[..., 1::2] - gate = gate.clamp(max=self.limit) - up_h = up_h.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - fused = (up_h + 1) * glu + fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = X_rep.dtype) + # gate = gate_up[..., ::2] + # up_h = gate_up[..., 1::2] + # gate = gate.clamp(max=self.limit) + # up_h = up_h.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # fused = (up_h + 1) * glu out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)] outs = torch.stack(out_list, dim=0) rw = routing_weights.transpose(0, 1).unsqueeze(-1) mixed = (outs.to(torch.float32) * rw.to(torch.float32)).sum(dim=0) return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) - pass @@ -1098,17 +1101,18 @@ def patch_gpt_oss_bnb4bit(): # Preserve original symbol names for compiler-generated modules. GptOssExpertsBnb4bit.__name__ = "GptOssExperts" GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit - # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. - # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. - # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict - # override to remap keys, but transformers v5 bypasses _load_from_state_dict - # (uses accelerate's set_module_tensor_to_device), so the remapping never ran - # and router weights were randomly initialized - causing high loss (~4-5). - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - - logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") - os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" + if _is_transformers_v5(): + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit + # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. + # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. + # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict + # override to remap keys, but transformers v5 bypasses _load_from_state_dict + # (uses accelerate's set_module_tensor_to_device), so the remapping never ran + # and router weights were randomly initialized - causing high loss (~4-5). + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + + logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") + os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" return True From ed563f127370d9018016506f0e4c5a34fac9393d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 08:08:00 +0000 Subject: [PATCH 15/26] old code --- unsloth_zoo/temporary_patches/gpt_oss.py | 42 +++++++++++++----------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 57b061ed4..a3e2c265c 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -949,36 +949,40 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig num_experts = routing_weights.shape[1] if self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted[:]: + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted[:]: + for expert_idx in range(num_experts): with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx[0]]) + # _, token_idx = torch.where(expert_mask[expert_idx[0]]) + token_idx, _ = torch.where(router_indices == expert_idx) current_state = hidden_states[token_idx] gate_up = self.gate_up_projs[expert_idx](current_state) - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu + gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit) + # gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # gate = gate.clamp(min=None, max=self.limit) + # up = up.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # gated_output = (up + 1) * glu out = self.down_projs[expert_idx](gated_output) weighted_output = out * routing_weights[token_idx, expert_idx, None].to(torch.float32) - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states.index_add_(0, token_idx, weighted_output) next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states + return next_states.to(hidden_states.dtype) else: X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] gate_up = torch.stack(gate_up_list, dim=0) - gate = gate_up[..., ::2] - up_h = gate_up[..., 1::2] - gate = gate.clamp(max=self.limit) - up_h = up_h.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - fused = (up_h + 1) * glu + fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = X_rep.dtype) + # gate = gate_up[..., ::2] + # up_h = gate_up[..., 1::2] + # gate = gate.clamp(max=self.limit) + # up_h = up_h.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # fused = (up_h + 1) * glu out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)] outs = torch.stack(out_list, dim=0) rw = routing_weights.transpose(0, 1).unsqueeze(-1) From ff3728845dc2dbfa8f5552ecb99d1ea6216b01b3 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 08:22:00 +0000 Subject: [PATCH 16/26] further fixes for loop --- unsloth_zoo/temporary_patches/gpt_oss.py | 79 ++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 6ed0f466c..66c9f9dfe 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -920,11 +920,11 @@ def __init__(self, config): self.dtype = dtype_from_config(config) self.gate_up_projs = nn.ModuleList([ - nn.Linear(self.hidden_size, 2 * self.expert_dim, bias=True, dtype=self.dtype) + nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=self.dtype) for _ in range(self.num_experts) ]) self.down_projs = nn.ModuleList([ - nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=self.dtype) + nn.Linear(self.expert_dim, self.hidden_size, dtype=self.dtype) for _ in range(self.num_experts) ]) @@ -1787,6 +1787,73 @@ def should_dequantize_mxfp4(): return False # Keep MXFP4 quantized +def torch_native_forward( + self, + hidden_states: torch.Tensor, + router_indices = None, + routing_weights = None +) -> torch.Tensor: + + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + num_experts = routing_weights.shape[1] + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) + # with torch.no_grad(): + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted[:]: + for expert_idx in range(num_experts): + with torch.no_grad(): + # _, token_idx = torch.where(expert_mask[expert_idx[0]]) + token_idx, _ = torch.where(router_indices == expert_idx) + current_state = hidden_states[token_idx] + gate_up = self.gate_up_projs[expert_idx](current_state) + down_proj = self.down_projs[expert_idx] + gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = torch.float32) + # gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # gate = gate.clamp(min=None, max=self.limit) + # up = up.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # gated_output = (up + 1) * glu + + # Force float32 matrix multiply on some down projection modules + gated_output = gated_output.to(torch.float32) + device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + out = down_proj(gated_output) + weighted_output = out.to(torch.float32) * routing_weights[token_idx, expert_idx, None].to(torch.float32) + next_states.index_add_(0, token_idx, weighted_output) + next_states = next_states.view(batch_size, -1, self.hidden_size) + return next_states.to(torch.float32) + else: + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) + gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] + gate_up = torch.stack(gate_up_list, dim=0) + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype + fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = dtype) + # gate = gate_up[..., ::2] + # up_h = gate_up[..., 1::2] + # gate = gate.clamp(max=self.limit) + # up_h = up_h.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # fused = (up_h + 1) * glu + + # Force float32 matrix multiply on down projection only + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + out_list = [ + down_l(fused[e].to(dtype)) + for e, down_l in enumerate(self.down_projs) + ] + outs = torch.stack(out_list, dim=0) + rw = routing_weights.transpose(0, 1).unsqueeze(-1) + mixed = (outs.to(dtype) * rw.to(dtype)).sum(dim=0) + return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) + pass +pass + def patch_gpt_oss_linearized(): """ Patch GPT OSS for 4bit loading with grouped_mm support. @@ -1812,17 +1879,19 @@ def experts_forward( return forward_native_grouped_mm( self, hidden_states, router_indices, routing_weights ) + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward else: def experts_forward( self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None ) -> torch.Tensor: - return forward_native_moe_loop( + return torch_native_forward( self, hidden_states, router_indices, routing_weights ) - # Patch the original transformers class forward method - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward + if UNSLOTH_ENABLE_LOGGING: logger.info( From a8aa29a7e4dace4dd424083104e53d14f978ce67 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 08:27:45 +0000 Subject: [PATCH 17/26] Use gpt oss specific loop --- unsloth_zoo/temporary_patches/gpt_oss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 66c9f9dfe..9859f2232 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -828,7 +828,7 @@ def forward( ) # Fallback to loop-based implementation - return forward_native_moe_loop( + return torch_native_forward( self, hidden_states, router_indices, routing_weights ) @@ -1395,7 +1395,7 @@ def forward(self, hidden_states): select_moe_backend, patch_param_wrapper_for_moe, forward_native_grouped_mm, - forward_native_moe_loop, + torch_native_forward, ) @@ -1460,7 +1460,7 @@ def patch_gpt_oss_moe_for_lora(): if backend == "grouped_mm": forward = forward_native_grouped_mm else: - forward = forward_native_moe_loop + forward = torch_native_forward # Store original forward and patch - but DON'T replace the class! GptOssExpertsClass._original_forward = GptOssExpertsClass.forward From 733d1a7dd658d2656f17989f85eb7c1751af6a57 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 08:29:49 +0000 Subject: [PATCH 18/26] Fix import --- unsloth_zoo/temporary_patches/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 9859f2232..43f247415 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1395,7 +1395,7 @@ def forward(self, hidden_states): select_moe_backend, patch_param_wrapper_for_moe, forward_native_grouped_mm, - torch_native_forward, + # torch_native_forward, ) From b01530ea922ebb7c8a694eda51724ba043abcbca Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 08:35:33 +0000 Subject: [PATCH 19/26] Patch bnb --- unsloth_zoo/temporary_patches/gpt_oss.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 43f247415..289288ffc 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1101,18 +1101,18 @@ def patch_gpt_oss_bnb4bit(): # Preserve original symbol names for compiler-generated modules. GptOssExpertsBnb4bit.__name__ = "GptOssExperts" GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" - if _is_transformers_v5(): - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit - # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. - # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. - # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict - # override to remap keys, but transformers v5 bypasses _load_from_state_dict - # (uses accelerate's set_module_tensor_to_device), so the remapping never ran - # and router weights were randomly initialized - causing high loss (~4-5). - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - - logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") - os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" + + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit + # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. + # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. + # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict + # override to remap keys, but transformers v5 bypasses _load_from_state_dict + # (uses accelerate's set_module_tensor_to_device), so the remapping never ran + # and router weights were randomly initialized - causing high loss (~4-5). + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + + logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") + os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" return True From 16fad112544f82f4715cf41653942c85a236e643 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 08:59:18 +0000 Subject: [PATCH 20/26] Undo scale shape --- unsloth_zoo/temporary_patches/gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 289288ffc..048b9612f 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -321,7 +321,7 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( @@ -334,7 +334,7 @@ def __init__(self, config): requires_grad=False, ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, 16, dtype=torch.uint8), + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False, ) self.down_proj_bias = nn.Parameter( From 74873d9110aac8c5963e1d12e9b99aed21be482e Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 09:27:14 +0000 Subject: [PATCH 21/26] cleanup spaces --- unsloth_zoo/temporary_patches/gpt_oss.py | 230 ++++++----------------- 1 file changed, 61 insertions(+), 169 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 048b9612f..0265079c9 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -30,7 +30,6 @@ ) from importlib.metadata import version as importlib_version from ..utils import Version - transformers_version = Version(importlib_version("transformers")) has_static_cache = transformers_version >= Version("4.56.0.dev0") from .utils import ( @@ -44,7 +43,6 @@ process_return, ) from ..hf_utils import dtype_from_config - torch_cuda_device = torch.cuda.device # MXFP4 configuration @@ -86,8 +84,6 @@ def swiglu_torch_forward(a, alpha, limit, dtype = None): out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) out = out_gelu * (a_linear + 1) return out.to(a.dtype if dtype is None else dtype) - - pass @@ -98,15 +94,15 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): if limit is not None: mask_g = g <= limit mask_l = l.abs() <= limit - ḡ = torch.where(mask_g, g, limit) + ḡ = torch.where(mask_g, g, limit) l̄ = torch.where(mask_l, l, l.sign() * limit) else: # no clipping mask_g = mask_l = torch.ones_like(g, dtype=bool) - ḡ, l̄ = g, l + ḡ, l̄ = g, l σ = torch.sigmoid(alpha * ḡ) dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) - dl = ḡ * σ + dl = ḡ * σ dg = torch.where(mask_g, dg, 0.0) # clamp-grad dl = torch.where(mask_l, dl, 0.0) @@ -126,23 +122,16 @@ def patch_gpt_oss(): except Exception as e: HAS_TRITON_KERNELS = False # return raise_error("Please install triton_kernels", e) - try: import transformers.quantizers.quantizer_mxfp4 def is_kernels_available(): return True - transformers.quantizers.quantizer_mxfp4.is_kernels_available = ( - is_kernels_available - ) - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( - lambda *args, **kwargs: True - ) + transformers.quantizers.quantizer_mxfp4.is_kernels_available = is_kernels_available + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = (lambda *args, **kwargs: True) except Exception as e: - return raise_error( - "transformers.quantizers.quantizer_mxfp4.is_kernels_available", e - ) + return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) if hasattr( transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels" @@ -152,13 +141,9 @@ def is_kernels_available(): ) try: - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( - lambda *args, **kwargs: True - ) + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = (lambda *args, **kwargs: True) except Exception as e: - return raise_error( - "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e - ) + return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) if HAS_TRITON_KERNELS: try: @@ -208,12 +193,7 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function( - transformers.integrations.mxfp4, - "swizzle_mxfp4", - swizzle_mxfp4, - match_level="relaxed", - ) + patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level="relaxed") class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -569,7 +549,7 @@ def load_and_swizzle_mxfp4( del blocks pass - patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level="relaxed") + patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level = "relaxed") try: from transformers.integrations.mxfp4 import _replace_with_mxfp4_linear @@ -600,11 +580,7 @@ def replace_with_mxfp4_linear( return model - patch_function( - transformers.integrations.mxfp4, - "replace_with_mxfp4_linear", - replace_with_mxfp4_linear, - ) + patch_function(transformers.integrations.mxfp4, "replace_with_mxfp4_linear", replace_with_mxfp4_linear) pass @@ -647,18 +623,12 @@ def get_param(self): # (E, 2I, H) is shape_3d permuted by permute_to_2d. unflattened_shape = [self.shape_3d[i] for i in self.permute_to_2d] - return ( - self.weight.view(*unflattened_shape) - .permute(*self.permute_to_3d) - .contiguous() - ) + return self.weight.view(*unflattened_shape).permute(*self.permute_to_3d).contiguous() def set_weight_from_3d(self, weight_3d): """Sets the 2D weight from a 3D tensor.""" # 3D -> Permute -> Flatten - weight_2d = weight_3d.permute(*self.permute_to_2d).reshape( - self.out_features, self.in_features - ) + weight_2d = weight_3d.permute(*self.permute_to_2d).reshape(self.out_features, self.in_features) self.weight.data.copy_(weight_2d) def _load_from_state_dict( @@ -680,9 +650,7 @@ def _load_from_state_dict( # Found the parameter (likely from original model structure where it was a Param) val = state_dict[key] # Convert 3D val to 2D - val_2d = val.permute(*self.permute_to_2d).reshape( - self.out_features, self.in_features - ) + val_2d = val.permute(*self.permute_to_2d).reshape(self.out_features, self.in_features) # Put into 'weight' key state_dict[prefix + "weight"] = val_2d @@ -741,9 +709,7 @@ def __init__(self, config): ) ) - self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=self.dtype) - ) + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=self.dtype)) # down_proj: 3D (E, I, H). Target 2D (H, E*I). # Permute (2, 0, 1) -> (H, E, I). Reverse (1, 2, 0) @@ -760,9 +726,7 @@ def __init__(self, config): ) ) - self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=self.dtype) - ) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, dtype=self.dtype)) def _load_from_state_dict( self, @@ -788,10 +752,7 @@ def _load_from_state_dict( if gate_up_key in state_dict and gate_up_weight_key not in state_dict: val_3d = state_dict.pop(gate_up_key) # gate_up_proj: 3D (E, H, 2I) -> permute (0, 2, 1) -> (E, 2I, H) -> reshape (E*2I, H) - val_2d = val_3d.permute(0, 2, 1).reshape( - self.num_experts * 2 * self.expert_dim, # out_features - self.hidden_size, # in_features - ) + val_2d = val_3d.permute(0, 2, 1).reshape(self.num_experts * 2 * self.expert_dim, self.hidden_size) state_dict[gate_up_weight_key] = val_2d # Handle down_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight @@ -800,10 +761,7 @@ def _load_from_state_dict( if down_key in state_dict and down_weight_key not in state_dict: val_3d = state_dict.pop(down_key) # down_proj: 3D (E, I, H) -> permute (2, 0, 1) -> (H, E, I) -> reshape (H, E*I) - val_2d = val_3d.permute(2, 0, 1).reshape( - self.hidden_size, # out_features - self.num_experts * self.expert_dim, # in_features - ) + val_2d = val_3d.permute(2, 0, 1).reshape(self.hidden_size, self.num_experts * self.expert_dim) state_dict[down_weight_key] = val_2d # Call parent implementation @@ -823,14 +781,10 @@ def forward( """Forward using grouped_mm or loop fallback with LoRA support.""" # Use optimized grouped_mm if available if _check_torch_grouped_mm_supported(): - return forward_native_grouped_mm( - self, hidden_states, router_indices, routing_weights - ) + return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights) # Fallback to loop-based implementation - return torch_native_forward( - self, hidden_states, router_indices, routing_weights - ) + return torch_native_forward(self, hidden_states, router_indices, routing_weights) pass @@ -861,9 +815,7 @@ def __init__(self, config): # Use _RouterLinearParams (not nn.Linear) to avoid BnB 4-bit quantization. # State dict keys are router.linear.weight / router.linear.bias, matching # the BnB 4-bit checkpoint format where router was stored via nn.Linear. - self.linear = _RouterLinearParams( - self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) - ) + self.linear = _RouterLinearParams(self.hidden_dim, self.num_experts, dtype=dtype_from_config(config)) # Properties for compatibility with transformers' _init_weights which expects .weight and .bias @property @@ -886,12 +838,8 @@ def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax( - router_top_value, dim=1, dtype=router_top_value.dtype - ) - router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( - 1, router_indices, router_top_value - ) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_(1, router_indices, router_top_value) if transformers_version >= Version("5.0.0"): return router_logits, router_scores, router_indices else: @@ -1058,9 +1006,7 @@ def _load_from_state_dict( def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = torch.nn.functional.linear( - hidden_states.to(self.weight.dtype), self.weight, self.bias - ) + router_logits = torch.nn.functional.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) @@ -1206,57 +1152,33 @@ def moe_forward_inference(self, hidden_states): hidden_states = hidden_states.reshape(-1, moe.hidden_size) num_experts = routing_weights.shape[1] + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) # Check if using ModuleList (old style) or 3D parameters (new style) if hasattr(moe, "gate_up_projs"): # ModuleList style - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] gate_up = torch.stack(gate_up_list, dim=0) - dtype = ( - torch.float32 - if hidden_states.dtype != torch.bfloat16 - else hidden_states.dtype - ) + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) fused = fused.to(dtype) - device_type = ( - fused.device.type - if isinstance(fused.device.type, str) and fused.device.type != "mps" - else "cpu" - ) + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - out_list = [ - down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs) - ] + out_list = [down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs)] outs = torch.stack(out_list, dim=0) else: # 3D parameter style (compatible with transformers) - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) # gate_up_proj: (E, hidden_size, 2*expert_dim) - bmm: (E, N, H) @ (E, H, 2I) -> (E, N, 2I) - gate_up = ( - torch.bmm(X_rep, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] - ) - dtype = ( - torch.float32 - if hidden_states.dtype != torch.bfloat16 - else hidden_states.dtype - ) + gate_up = torch.bmm(X_rep, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) fused = fused.to(dtype) - device_type = ( - fused.device.type - if isinstance(fused.device.type, str) and fused.device.type != "mps" - else "cpu" - ) + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # down_proj: (E, expert_dim, hidden_size) - bmm: (E, N, I) @ (E, I, H) -> (E, N, H) - outs = ( - torch.bmm(fused.to(dtype), moe.down_proj) - + moe.down_proj_bias[..., None, :] - ) + outs = torch.bmm(fused.to(dtype), moe.down_proj) + moe.down_proj_bias[..., None, :] rw = routing_weights.to(dtype).transpose(0, 1).unsqueeze(-1) mixed = (outs * rw).sum(dim=0) @@ -1271,9 +1193,7 @@ def moe_router_forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - dtype = ( - torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype - ) + dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) return router_scores, router_indices @@ -1876,27 +1796,19 @@ def patch_gpt_oss_linearized(): def experts_forward( self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None ) -> torch.Tensor: - return forward_native_grouped_mm( - self, hidden_states, router_indices, routing_weights - ) + return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights) transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward else: def experts_forward( self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None ) -> torch.Tensor: - return torch_native_forward( - self, hidden_states, router_indices, routing_weights - ) + return torch_native_forward(self, hidden_states, router_indices, routing_weights) if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward - - if UNSLOTH_ENABLE_LOGGING: - logger.info( - f"Unsloth: Patched GPT OSS MoE for 4bit loading (backend: {backend})" - ) + if UNSLOTH_ENABLE_LOGGING: logger.info(f"Unsloth: Patched GPT OSS MoE for 4bit loading (backend: {backend})") return @@ -1920,7 +1832,6 @@ def patch_GptOssAttention(): return raise_error("flex_attention_with_sink", e) try: import transformers.models.gpt_oss.modeling_gpt_oss - transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb except Exception as e: @@ -1942,7 +1853,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: F_softmax = torch.nn.functional.softmax F_dropout = nn.functional.dropout matmul = torch.matmul - def inplace_eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -1959,10 +1869,7 @@ def inplace_eager_attention_forward( bsz, n_heads, qlen, _ = query.shape bsz, n_heads, kvlen, _ = key_states.shape out_dtype = torch.result_type(query, key_states) - combined_logits = key_states.new_empty( - (bsz, n_heads, qlen, kvlen + 1), - dtype=out_dtype, - ) + combined_logits = key_states.new_empty((bsz, n_heads, qlen, kvlen + 1), dtype=out_dtype) attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) attn_weights *= scaling @@ -1974,11 +1881,13 @@ def inplace_eager_attention_forward( # combined_logits = torch.cat([attn_weights, sinks], dim=-1) combined_logits[:, :, :, -1] = module.sinks.reshape(1, -1, 1) - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training) + attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) attn_weights = attn_weights.to(value_states.dtype) attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() @@ -2007,7 +1916,10 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) combined_logits = torch.cat([attn_weights, sinks], dim=-1) - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here @@ -2156,7 +2068,6 @@ def patch_GptOssModel(): if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return try: import transformers.models.gpt_oss.modeling_gpt_oss - transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel from transformers.models.gpt_oss.modeling_gpt_oss import MoeModelOutputWithPast from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb @@ -2173,7 +2084,6 @@ def patch_GptOssModel(): # Disable mask creations since we don't need them for GPT-OSS import transformers.masking_utils import transformers.generation.utils - def wrap(f): def return_attention_mask(*args, **kwargs): if kwargs["input_embeds"].requires_grad: @@ -2187,7 +2097,6 @@ def return_attention_mask(*args, **kwargs): return f(*args, **kwargs) pass return return_attention_mask - pass create_causal_mask = getattr( transformers.masking_utils, @@ -2204,8 +2113,8 @@ def return_attention_mask(*args, **kwargs): if create_sliding_window_causal_mask is None: return raise_error("transformers.masking_utils.create_sliding_window_causal_mask") if not hasattr(transformers.masking_utils, "__patched_causal_mask__"): - transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph=False, dynamic=True) - transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph=False, dynamic=True) + transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph = False, dynamic = True) + transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph = False, dynamic = True) transformers.masking_utils.create_causal_mask = wrap(create_causal_mask) transformers.masking_utils.create_sliding_window_causal_mask = wrap(create_sliding_window_causal_mask) transformers.models.gpt_oss.modeling_gpt_oss.create_causal_mask = transformers.masking_utils.create_causal_mask @@ -2391,9 +2300,7 @@ def forward( if inputs_embeds is None: # Account for CPU offloaded embed_tokens embed_device = self.embed_tokens.weight.device - inputs_embeds = self.embed_tokens( - input_ids.to(embed_device, non_blocking=True) - ).to(input_ids.device) + inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking=True)).to(input_ids.device) if not self.training: inputs_embeds.requires_grad_(False) @@ -2406,9 +2313,9 @@ def forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) try: - torch._dynamo.mark_static(hidden_states, 0) + torch._dynamo.mark_static (hidden_states, 0) torch._dynamo.mark_dynamic(hidden_states, 1) - torch._dynamo.mark_static(hidden_states, 2) + torch._dynamo.mark_static (hidden_states, 2) except: pass @@ -2511,14 +2418,11 @@ def forward( # Fix float16 / float32 mismatching hidden_states = hidden_states.to(inputs_embeds.dtype) - return process_return( - MoeModelOutputWithPast, - { + return process_return(MoeModelOutputWithPast, { "last_hidden_state": hidden_states, "past_key_values": past_key_values, "hidden_states": all_hidden_states, - }, - ) + }) patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") @@ -2537,7 +2441,7 @@ def forward( SystemContent, ToolDescription, load_harmony_encoding, - ReasoningEffort, + ReasoningEffort ) encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) @@ -2561,9 +2465,9 @@ def encode_conversations_with_harmony( assert reasoning_effort in ("low", "medium", "high") match reasoning_effort: - case "low": harmony_reasoning = ReasoningEffort.LOW - case "medium": harmony_reasoning = ReasoningEffort.MEDIUM - case "high": harmony_reasoning = ReasoningEffort.HIGH + case "low": harmony_reasoning = ReasoningEffort.LOW + case "medium": harmony_reasoning = ReasoningEffort.MEDIUM + case "high": harmony_reasoning = ReasoningEffort.HIGH convos = [] @@ -2571,8 +2475,7 @@ def encode_conversations_with_harmony( import datetime today = datetime.datetime.today().strftime("%Y-%m-%d") - system = Message.from_role_and_content( - Role.SYSTEM, + system = Message.from_role_and_content(Role.SYSTEM, SystemContent.new() .with_model_identity(model_identity) .with_reasoning_effort(harmony_reasoning) @@ -2786,13 +2689,7 @@ def __init__( initializer_range: float = 0.02, max_position_embeddings=131072, rms_norm_eps: float = 1e-5, - rope_scaling={ - "rope_type": "yarn", - "factor": 32.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "truncate": False, - }, + rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, attention_dropout: float = 0.0, num_experts_per_tok=4, router_aux_loss_coef: float = 0.9, @@ -2820,16 +2717,11 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout - self.head_dim = ( - head_dim - if head_dim is not None - else self.hidden_size // self.num_attention_heads - ) + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads self.layer_types = layer_types if self.layer_types is None: self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" - for i in range(self.num_hidden_layers) + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) ] layer_type_validation(self.layer_types) self.attention_bias = True From e45b828eba75f42541ec8289ec9656827b5bfb35 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 09:36:30 +0000 Subject: [PATCH 22/26] More space changes --- unsloth_zoo/temporary_patches/gpt_oss.py | 61 +++++------------------- 1 file changed, 12 insertions(+), 49 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 0265079c9..a33f38f8b 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -109,8 +109,6 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): grad = torch.empty_like(pre_act) grad[..., ::2], grad[..., 1::2] = dg, dl return g1 * grad.to(g1.dtype) - - pass @@ -125,23 +123,18 @@ def patch_gpt_oss(): try: import transformers.quantizers.quantizer_mxfp4 - def is_kernels_available(): - return True + def is_kernels_available(): return True transformers.quantizers.quantizer_mxfp4.is_kernels_available = is_kernels_available - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = (lambda *args, **kwargs: True) + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) - if hasattr( - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels" - ): - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = ( - lambda *args, **kwargs: triton_kernels - ) + if hasattr(transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels"): + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = lambda *args, **kwargs: triton_kernels try: - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = (lambda *args, **kwargs: True) + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) @@ -192,7 +185,6 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level="relaxed") class Mxfp4GptOssExperts_Training(torch.autograd.Function): @@ -305,21 +297,17 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), - requires_grad=False, + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32),requires_grad=False, ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16,), dtype=torch.uint8), - requires_grad=False, + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),requires_grad=False ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), - requires_grad=False, + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),requires_grad=False ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), - requires_grad=False, + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32),requires_grad=False ) self.alpha = 1.702 @@ -478,9 +466,7 @@ def mlp_forward(self, hidden_states): except Exception as e: return raise_error("transformers.integrations.tensor_parallel.shard_and_distribute_module", e) - def load_and_swizzle_mxfp4( - module, param_name, param_value, target_device, *args, **kwargs - ): + def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args, **kwargs): model = kwargs.get("model", None) empty_param = kwargs.get("empty_param", None) casting_dtype = kwargs.get("casting_dtype", None) @@ -554,9 +540,7 @@ def load_and_swizzle_mxfp4( try: from transformers.integrations.mxfp4 import _replace_with_mxfp4_linear except Exception as e: - return raise_error( - "transformers.integrations.mxfp4._replace_with_mxfp4_linear", e - ) + return raise_error("transformers.integrations.mxfp4._replace_with_mxfp4_linear", e) def replace_with_mxfp4_linear( model, @@ -581,8 +565,6 @@ def replace_with_mxfp4_linear( return model patch_function(transformers.integrations.mxfp4, "replace_with_mxfp4_linear", replace_with_mxfp4_linear) - - pass TEMPORARY_PATCHES.append(patch_gpt_oss) @@ -1920,7 +1902,6 @@ def eager_attention_forward( # when training with bsz>1 we clamp max values. # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) @@ -1967,9 +1948,7 @@ def forward_function( key_states = key_states.to(cache_dtype) value_states = value_states.to(cache_dtype) cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: key_states = key_states.to(query_states.dtype) value_states = value_states.to(query_states.dtype) @@ -2024,7 +2003,6 @@ def forward_function( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights - pass functions = [] @@ -2057,8 +2035,6 @@ def forward( patch_function_past_key_values(transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention, "forward", functions) # Set env variable for padding purposes os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "1" - - pass TEMPORARY_PATCHES.append(patch_GptOssAttention) @@ -2156,7 +2132,6 @@ def pre_attention_decoding( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) return query_states, key_states, value_states, input_shape - pass # Do flex_attention_with_sink_decoding with cannot be compiled @@ -2183,7 +2158,6 @@ def rms_layernorm_forward(self, hidden_states): hidden_states *= torch.rsqrt_(variance) hidden_states *= self.weight.to(hidden_states.device).to(torch.float32) return hidden_states.to(input_dtype) # main diff with Llama - pass # Re-compiling for each new sequence length which is NOT ideal @@ -2211,7 +2185,6 @@ def pre_forward( position_embeddings=position_embeddings, ) return query_states, key_states, value_states, input_shape - pass fused_torch_compile_options = get_torch_compile_options( epilogue_fusion = True, @@ -2241,7 +2214,6 @@ def post_forward( residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual - pass def inference_forward( @@ -2274,7 +2246,6 @@ def inference_forward( residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual - pass # if has_static_cache and Version(torch.__version__) >= Version("2.10.0"): # # torch 2.9.0 has excessive compilations @@ -2425,8 +2396,6 @@ def forward( }) patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") - - pass TEMPORARY_PATCHES.append(patch_GptOssModel) @@ -2531,8 +2500,6 @@ def encode_conversations_with_harmony( harmony_input_ids = encoding.render_conversation(convos) harmony_decoded_text = encoding.decode(harmony_input_ids) return harmony_decoded_text, harmony_input_ids - - pass @@ -2802,8 +2769,6 @@ def _patched_init_weights(self, module): patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) GptOssPreTrainedModel._unsloth_init_weights_fixed = True - - pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) @@ -2906,7 +2871,5 @@ def _patched_causal_lm_forward( except Exception as e: if UNSLOTH_ENABLE_LOGGING: logger.warning(f"Unsloth: Could not patch GptOssForCausalLM.forward: {e}") - - pass TEMPORARY_PATCHES.append(patch_gpt_oss_for_grpo) From 35b70eb872a1938e93cf476ad6616ca2798080f0 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 12:48:04 +0000 Subject: [PATCH 23/26] Fix GPT-OSS MXFP4 patch gating + swiglu backward --- unsloth_zoo/temporary_patches/gpt_oss.py | 49 +++++++++++------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index a33f38f8b..35e9a81a1 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -100,8 +100,8 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): mask_g = mask_l = torch.ones_like(g, dtype=bool) ḡ, l̄ = g, l - σ = torch.sigmoid(alpha * ḡ) - dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) + σ = torch.sigmoid(alpha * ḡ) + dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) dl = ḡ * σ dg = torch.where(mask_g, dg, 0.0) # clamp-grad dl = torch.where(mask_l, dl, 0.0) @@ -115,11 +115,12 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): def patch_gpt_oss(): try: import triton_kernels - - HAS_TRITON_KERNELS = True except Exception as e: - HAS_TRITON_KERNELS = False - # return raise_error("Please install triton_kernels", e) + if UNSLOTH_ENABLE_LOGGING: + logger.warning_once( + "Unsloth: `triton_kernels` is not installed, skipping GPT-OSS MXFP4 patches." + ) + return try: import transformers.quantizers.quantizer_mxfp4 @@ -138,21 +139,17 @@ def is_kernels_available(): return True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) - if HAS_TRITON_KERNELS: - try: - from triton_kernels import matmul_ogs, swiglu + try: + from triton_kernels import matmul_ogs, swiglu - FnSpecs, FusedActivation, matmul_ogs = ( - matmul_ogs.FnSpecs, - matmul_ogs.FusedActivation, - matmul_ogs.matmul_ogs, - ) - swiglu_fn = swiglu.swiglu_fn - except Exception as e: - return raise_error("triton_kernels", e) - else: - # Skip MXFP4 patches when triton_kernels not available - return + FnSpecs, FusedActivation, matmul_ogs = ( + matmul_ogs.FnSpecs, + matmul_ogs.FusedActivation, + matmul_ogs.matmul_ogs, + ) + swiglu_fn = swiglu.swiglu_fn + except Exception as e: + return raise_error("triton_kernels", e) try: import transformers.integrations.mxfp4 @@ -185,7 +182,7 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level="relaxed") + patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level = "relaxed") class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -297,17 +294,17 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32),requires_grad=False, + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False, ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),requires_grad=False + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), requires_grad=False ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32),requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False ) self.alpha = 1.702 @@ -2395,7 +2392,7 @@ def forward( "hidden_states": all_hidden_states, }) - patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") pass TEMPORARY_PATCHES.append(patch_GptOssModel) From dae00042ed36092e5dd99012bbb3613069b0fe6b Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 12:58:06 +0000 Subject: [PATCH 24/26] Revert "Fix GPT-OSS MXFP4 patch gating + swiglu backward" This reverts commit 35b70eb872a1938e93cf476ad6616ca2798080f0. --- unsloth_zoo/temporary_patches/gpt_oss.py | 49 +++++++++++++----------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 35e9a81a1..a33f38f8b 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -100,8 +100,8 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): mask_g = mask_l = torch.ones_like(g, dtype=bool) ḡ, l̄ = g, l - σ = torch.sigmoid(alpha * ḡ) - dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) + σ = torch.sigmoid(alpha * ḡ) + dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) dl = ḡ * σ dg = torch.where(mask_g, dg, 0.0) # clamp-grad dl = torch.where(mask_l, dl, 0.0) @@ -115,12 +115,11 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): def patch_gpt_oss(): try: import triton_kernels + + HAS_TRITON_KERNELS = True except Exception as e: - if UNSLOTH_ENABLE_LOGGING: - logger.warning_once( - "Unsloth: `triton_kernels` is not installed, skipping GPT-OSS MXFP4 patches." - ) - return + HAS_TRITON_KERNELS = False + # return raise_error("Please install triton_kernels", e) try: import transformers.quantizers.quantizer_mxfp4 @@ -139,17 +138,21 @@ def is_kernels_available(): return True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) - try: - from triton_kernels import matmul_ogs, swiglu + if HAS_TRITON_KERNELS: + try: + from triton_kernels import matmul_ogs, swiglu - FnSpecs, FusedActivation, matmul_ogs = ( - matmul_ogs.FnSpecs, - matmul_ogs.FusedActivation, - matmul_ogs.matmul_ogs, - ) - swiglu_fn = swiglu.swiglu_fn - except Exception as e: - return raise_error("triton_kernels", e) + FnSpecs, FusedActivation, matmul_ogs = ( + matmul_ogs.FnSpecs, + matmul_ogs.FusedActivation, + matmul_ogs.matmul_ogs, + ) + swiglu_fn = swiglu.swiglu_fn + except Exception as e: + return raise_error("triton_kernels", e) + else: + # Skip MXFP4 patches when triton_kernels not available + return try: import transformers.integrations.mxfp4 @@ -182,7 +185,7 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level = "relaxed") + patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level="relaxed") class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -294,17 +297,17 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False, + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32),requires_grad=False, ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), requires_grad=False + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),requires_grad=False ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),requires_grad=False ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32),requires_grad=False ) self.alpha = 1.702 @@ -2392,7 +2395,7 @@ def forward( "hidden_states": all_hidden_states, }) - patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") pass TEMPORARY_PATCHES.append(patch_GptOssModel) From 2e12afebf3ba958fe6b28dcb5435aca85de0a70c Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 12:57:17 +0000 Subject: [PATCH 25/26] Style: fix spacing in GPT-OSS patches --- unsloth_zoo/temporary_patches/gpt_oss.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index a33f38f8b..365464d30 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -185,7 +185,7 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level="relaxed") + patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level = "relaxed") class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -297,17 +297,17 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32),requires_grad=False, + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False, ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),requires_grad=False + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), requires_grad=False ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32),requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False ) self.alpha = 1.702 @@ -2395,7 +2395,7 @@ def forward( "hidden_states": all_hidden_states, }) - patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") pass TEMPORARY_PATCHES.append(patch_GptOssModel) From b8e170883424be8c9a83c3c520eca0aec8d925c8 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 9 Feb 2026 13:10:33 +0000 Subject: [PATCH 26/26] Copy swiglu_torch_backward from main --- unsloth_zoo/temporary_patches/gpt_oss.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 365464d30..cf7c2a157 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -96,22 +96,21 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): mask_l = l.abs() <= limit ḡ = torch.where(mask_g, g, limit) l̄ = torch.where(mask_l, l, l.sign() * limit) - else: # no clipping + else: # no clipping mask_g = mask_l = torch.ones_like(g, dtype=bool) ḡ, l̄ = g, l - σ = torch.sigmoid(alpha * ḡ) - dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) - dl = ḡ * σ - dg = torch.where(mask_g, dg, 0.0) # clamp-grad - dl = torch.where(mask_l, dl, 0.0) + σ = torch.sigmoid(alpha * ḡ) + dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) + dl = ḡ * σ + dg = torch.where(mask_g, dg, 0.) # clamp-grad + dl = torch.where(mask_l, dl, 0.) grad = torch.empty_like(pre_act) grad[..., ::2], grad[..., 1::2] = dg, dl return g1 * grad.to(g1.dtype) pass - def patch_gpt_oss(): try: import triton_kernels