From 48aeb069d41d0ed90bb2b60ea7f3f9f6cdec97c4 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 6 Feb 2026 15:10:43 +0000 Subject: [PATCH 01/17] Working GPT OSS --- unsloth_zoo/compiler.py | 4 + unsloth_zoo/temporary_patches/gpt_oss.py | 429 +++++++++++++++++++++-- 2 files changed, 412 insertions(+), 21 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b905a1e74..756135705 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -811,6 +811,10 @@ def create_new_function( imports += "import torch\n" imports += "import torch.nn as nn\n" imports += "from torch.nn import functional as F\n" + if "torch_compile" in new_source: + imports += "from unsloth_zoo.temporary_patches.common import torch_compile\n" + if "KWARGS_TYPE" in new_source: + imports += "from unsloth_zoo.temporary_patches.utils import KWARGS_TYPE\n" imports += ( "from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable\n" ) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 3a3cbcc5f..818e01aa1 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -18,6 +18,7 @@ import os import torch import torch.nn as nn +import torch.nn.init as init import torch.nn.functional as F import inspect from .common import ( @@ -44,6 +45,32 @@ from ..hf_utils import dtype_from_config torch_cuda_device = torch.cuda.device +# MXFP4 configuration +# Set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 to keep MXFP4 weights quantized (requires triton_kernels) +# Otherwise, MXFP4 weights will be dequantized to bf16 for LoRA training +UNSLOTH_MXFP4_NO_DEQUANTIZE = os.environ.get("UNSLOTH_MXFP4_NO_DEQUANTIZE", "0") == "1" + + +def _check_triton_kernels_available(): + """Check if OpenAI's triton_kernels package is available for MXFP4.""" + try: + from triton_kernels import matmul_ogs, swiglu + + return True + except ImportError: + return False + + +_TRITON_KERNELS_AVAILABLE = None + + 
+def is_triton_kernels_available(): + """Cached check for triton_kernels availability.""" + global _TRITON_KERNELS_AVAILABLE + if _TRITON_KERNELS_AVAILABLE is None: + _TRITON_KERNELS_AVAILABLE = _check_triton_kernels_available() + return _TRITON_KERNELS_AVAILABLE + @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_forward(a, alpha, limit, dtype = None): @@ -87,14 +114,24 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): def patch_gpt_oss(): try: import triton_kernels + + HAS_TRITON_KERNELS = True except Exception as e: - return raise_error("Please install triton_kernels", e) + HAS_TRITON_KERNELS = False + # return raise_error("Please install triton_kernels", e) try: import transformers.quantizers.quantizer_mxfp4 - def is_kernels_available(): return True - transformers.quantizers.quantizer_mxfp4.is_kernels_available = is_kernels_available - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True + + def is_kernels_available(): + return True + + transformers.quantizers.quantizer_mxfp4.is_kernels_available = ( + is_kernels_available + ) + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( + lambda *args, **kwargs: True + ) except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) @@ -106,16 +143,21 @@ def is_kernels_available(): return True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) - try: - from triton_kernels import matmul_ogs, swiglu - FnSpecs, FusedActivation, matmul_ogs = ( - matmul_ogs.FnSpecs, - matmul_ogs.FusedActivation, - matmul_ogs.matmul_ogs, - ) - swiglu_fn = swiglu.swiglu_fn - except Exception as e: - return raise_error("triton_kernels", e) + if HAS_TRITON_KERNELS: + try: + from triton_kernels import matmul_ogs, swiglu + + FnSpecs, FusedActivation, matmul_ogs = ( + matmul_ogs.FnSpecs, + matmul_ogs.FusedActivation, + matmul_ogs.matmul_ogs, + ) + swiglu_fn 
= swiglu.swiglu_fn + except Exception as e: + return raise_error("triton_kernels", e) + else: + # Skip MXFP4 patches when triton_kernels not available + return try: import transformers.integrations.mxfp4 @@ -205,8 +247,8 @@ def forward( @staticmethod def backward(ctx, grad_token): raise NotImplementedError( - "Backwards pass using MXFP4 is still under construction!\n"\ - "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n"\ + "Backwards pass using MXFP4 is still under construction!\n" + "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n" "Or, use `load_in_4bit = True` which allows finetuning." ) (pre_act, gamma, gather_src, gather_dst, scatter_src, scatter_dst,) = ctx.saved_tensors @@ -243,12 +285,13 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size + # MXFP4 quantized format (blocks + scales) self.gate_up_proj_blocks = nn.Parameter( torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( @@ -256,11 +299,11 @@ def __init__(self, config): ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16,), dtype=torch.uint8), requires_grad=False, ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, 16, 
dtype=torch.uint8), requires_grad=False, ) self.down_proj_bias = nn.Parameter( @@ -271,7 +314,10 @@ def __init__(self, config): self.gate_up_proj_precision_config = None self.down_proj_precision_config = None - def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: + @property + def gate_up_proj(selforward( + self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx + ) -> torch.Tensor: with torch_cuda_device(hidden_states.device): if not hasattr(self, "act"): self.act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) @@ -520,7 +566,12 @@ def __init__(self, config): self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.linear = nn.Linear(self.hidden_dim, self.num_experts, dtype=dtype_from_config(config)) + self.linear = nn.Linear( + self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) + ) + # Expose weight/bias for HF init compatibility + self.weight = self.linear.weight + self.bias = self.linear.bias @torch_compile(dynamic = True, fullgraph = True) def forward(self, hidden_states): @@ -657,6 +708,10 @@ def patch_gpt_oss_linearized(): except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + _patch_gpt_oss_init_weights_for_modulelist( + transformers.models.gpt_oss.modeling_gpt_oss + ) + # We find down_proj overflows in GPT OSS if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": def forward( @@ -736,6 +791,338 @@ def forward( TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) +def _patch_gpt_oss_init_weights_for_modulelist(transformers_module): + GptOssPreTrainedModel = transformers_module.GptOssPreTrainedModel + GptOssExperts = transformers_module.GptOssExperts + if getattr(GptOssPreTrainedModel, "_unsloth_init_weights_patched", False): + return + _original_init_weights = GptOssPreTrainedModel._init_weights + + def 
_patched_init_weights(self, module): + _original_init_weights(self, module) + if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): + std = self.config.initializer_range + for up in getattr(module, "gate_up_projs", []): + init.normal_(up.weight, mean=0.0, std=std) + if up.bias is not None: + init.zeros_(up.bias) + for down in getattr(module, "down_projs", []): + init.normal_(down.weight, mean=0.0, std=std) + if down.bias is not None: + init.zeros_(down.bias) + + patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) + GptOssPreTrainedModel._unsloth_init_weights_patched = True + + +def patch_gpt_oss_bf16_split_lora(): + """ + Patch GPT-OSS BF16 model to use split LoRA with grouped GEMM. + + This patch applies to BF16 models loaded via unsloth/gpt-oss-20b-BF16 or similar. + It creates stacked expert weights and uses moe_utils.py's forward_native_grouped_mm + for efficient training with split LoRA. + + Key differences from Qwen3 MoE: + 1. GPT-OSS uses TRANSPOSED weight layout: + - gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size) + - down_proj: (num_experts, intermediate_size, hidden_size) + 2. GPT-OSS gate_up has interleaved layout (::2 for gate, 1::2 for up) + 3. GPT-OSS has biases: gate_up_proj_bias, down_proj_bias + + 4-bit BNB models use gate_up_projs/down_projs ModuleList and are NOT affected. 
+ """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + # Skip 4-bit models - they use ModuleList and are handled by patch_gpt_oss_linearized + if "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + from .moe_utils import ( + forward_native_grouped_mm, + select_moe_backend, + patch_param_wrapper_for_moe, + _has_lora_adapters, + _extract_lora_from_wrapper, + _should_use_separated_lora, + _is_moe_experts_module, + ) + + # Patch ParamWrapper.forward for MoE separated LoRA + patch_param_wrapper_for_moe() + + # LoRA Extractor function for GPT-OSS + # GPT-OSS weight layout is TRANSPOSED: (E, hidden, output) instead of (E, output, hidden) + def _gpt_oss_lora_extractor( + self, wrapper, weight_A, weight_B, scaling, num_experts + ): + """ + GPT-OSS LoRA extractor for transposed weight layout. 
+ + GPT-OSS weights: + gate_up_proj: (E, H, 2*I) - transposed layout (in_dim, out_dim) + down_proj: (E, I, H) - transposed layout (in_dim, out_dim) + + For grouped_mm: X @ W where W is (E, in_dim, out_dim) + + PEFT creates: + lora_A: (E*R, in_dim) - projects input to rank space + lora_B: (out_dim, E*R) - projects rank to output + + For transposed format, the LoRA dimensions are already correct: + - We want X @ (E, in, R) @ (E, R, out) + """ + # This Unsloth Zoo code section is licensed under AGPL3 + + total_rank = weight_A.shape[0] + rank_per_expert = total_rank // num_experts + dim_A = weight_A.shape[ + 1 + ] # in_dim (hidden_dim for gate_up, intermediate for down) + dim_B = weight_B.shape[ + 0 + ] # out_dim (2*intermediate for gate_up, hidden_dim for down) + + # Get model dimensions from the experts module + hidden_dim = None + intermediate_dim = None + current = wrapper + while hasattr(current, "base_layer"): + current = current.base_layer + if hasattr(current, "hidden_size"): + hidden_dim = current.hidden_size + if hasattr(current, "intermediate_size"): + intermediate_dim = current.intermediate_size + + # Get parameter name + param_name = getattr(wrapper, "parameter_name", None) + + # GPT-OSS uses TRANSPOSED layout, so LoRA dimensions map directly: + # Input projection: X @ (E, in_dim, R) + # Output projection: result @ (E, R, out_dim) + + if ( + param_name == "down_proj" + and intermediate_dim is not None + and hidden_dim is not None + ): + # down_proj: input=intermediate_dim, output=hidden_dim + # Weight shape: (E, I, H) - transposed + # lora_A: (E*R, H) from PEFT (swapped due to 3D param handling) + # lora_B: (I, E*R) from PEFT (swapped) + # For X @ first @ second: first is (E, I, R), second is (E, R, H) + + # first_weight from B (has intermediate_dim) + first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, I, R) + + # second_weight from A (has hidden_dim) + second_weight = 
weight_A.view( + num_experts, rank_per_expert, dim_A + ) # (E, R, H) + + return first_weight, second_weight, scaling, num_experts + + elif param_name == "gate_up_proj" and hidden_dim is not None: + # gate_up_proj: input=hidden_dim, output=2*intermediate_dim + # Weight shape: (E, H, 2*I) - transposed + # lora_A: (E*R, 2*I) from PEFT + # lora_B: (H, E*R) from PEFT + # For X @ first @ second: first is (E, H, R), second is (E, R, 2*I) + + # first_weight from B (has hidden_dim) + first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, H, R) + + # second_weight from A (has 2*intermediate_dim) + second_weight = weight_A.view( + num_experts, rank_per_expert, dim_A + ) # (E, R, 2*I) + + return first_weight, second_weight, scaling, num_experts + + # Fallback: dimension-based detection + if hidden_dim is not None: + if dim_B == hidden_dim: + # B connects to hidden_dim (transposed case) + first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + first_weight = first_weight.permute(1, 0, 2).contiguous() + second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) + return first_weight, second_weight, scaling, num_experts + elif dim_A == hidden_dim: + # A connects to hidden_dim + first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) + first_weight = first_weight.permute(0, 2, 1).contiguous() + second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + second_weight = second_weight.permute(1, 2, 0).contiguous() + return first_weight, second_weight, scaling, num_experts + + # Final fallback + first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) + first_weight = first_weight.permute(0, 2, 1).contiguous() + second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) + second_weight = second_weight.permute(1, 2, 0).contiguous() + return first_weight, second_weight, scaling, num_experts + + # Patch GptOssExperts.forward to use grouped GEMM + split 
LoRA (BF16 only) + GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + _original_gpt_oss_experts_forward = GptOssExperts.forward + + def _bf16_split_lora_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + # Fallback to original for 4-bit ModuleList or missing routing data + if ( + router_indices is None + or routing_weights is None + or hasattr(self, "gate_up_projs") + or not hasattr(self, "gate_up_proj") + or not hasattr(self, "down_proj") + ): + return _original_gpt_oss_experts_forward( + self, hidden_states, router_indices, routing_weights + ) + + gate_up_param = getattr(self, "gate_up_proj", None) + down_param = getattr(self, "down_proj", None) + if ( + not isinstance(gate_up_param, nn.Parameter) + or gate_up_param.ndim != 3 + or not isinstance(down_param, nn.Parameter) + or down_param.ndim != 3 + ): + return _original_gpt_oss_experts_forward( + self, hidden_states, router_indices, routing_weights + ) + + if not hasattr(self, "_unsloth_model_type"): + self._unsloth_model_type = "gpt_oss" + + return forward_native_grouped_mm( + self, + hidden_states, + router_indices, # top_k_index + routing_weights, # top_k_weights + ) + + patch_function(GptOssExperts, "forward", _bf16_split_lora_forward) + + _patch_gpt_oss_init_weights_for_modulelist( + transformers.models.gpt_oss.modeling_gpt_oss + ) + + # BF16 Experts class with stacked weights (for split LoRA via moe_utils) + class GptOssExpertsBF16Stacked(nn.Module): + """ + GPT-OSS BF16 Experts with stacked weights for grouped GEMM. + + Uses moe_utils.forward_native_grouped_mm for efficient MoE computation + with separated LoRA support. 
+ """ + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.alpha = 1.702 + self.limit = getattr(config, "swiglu_limit", 7.0) + self.dtype = dtype_from_config(config) + + # Stacked weights in transposed format (E, in_dim, out_dim) for grouped_mm + self.gate_up_proj = nn.Parameter( + torch.empty( + self.num_experts, + self.hidden_size, + 2 * self.intermediate_size, + dtype=self.dtype, + ) + ) + self.gate_up_proj_bias = nn.Parameter( + torch.empty( + self.num_experts, 2 * self.intermediate_size, dtype=self.dtype + ) + ) + self.down_proj = nn.Parameter( + torch.empty( + self.num_experts, + self.intermediate_size, + self.hidden_size, + dtype=self.dtype, + ) + ) + self.down_proj_bias = nn.Parameter( + torch.empty(self.num_experts, self.hidden_size, dtype=self.dtype) + ) + + # Register LoRA extractor + self._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor + + @property + def hidden_dim(self): + return self.hidden_size + + @property + def intermediate_dim(self): + return self.intermediate_size + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """ + Forward pass using grouped GEMM from moe_utils. + Uses forward_native_grouped_mm which handles split LoRA automatically. 
+ """ + # Call moe_utils grouped MM forward which handles everything + return forward_native_grouped_mm( + self, + hidden_states, + router_indices, # top_k_index + routing_weights, # top_k_weights + ) + + # MLP wrapper that works with stacked experts + class GptOssMLP_BF16(nn.Module): + def __init__(self, config): + super().__init__() + self.router = GptOssTopKRouter(config) + self.experts = GptOssExpertsBF16Stacked(config) + + def forward(self, hidden_states): + bsz, qlen, hd = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hd) + + # Get router scores + router_scores, router_indices = self.router(hidden_states_flat) + + # Run experts with stacked weights + routed_out = self.experts( + hidden_states_flat, + router_indices=router_indices, + routing_weights=router_scores, + ) + + routed_out = routed_out.view(bsz, qlen, hd) + return routed_out, router_scores + + # Patch transformers module + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor + + # Check if model has stacked weights (BF16) vs ModuleList (4-bit) + # This is done at load time by checking the model structure + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched GPT-OSS for BF16 split LoRA with grouped GEMM") + + +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_bf16_split_lora) + + def patch_GptOssAttention(): if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return From 2eb8ce3c7a941a8bcc2ff36db11d00a8918c1265 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 6 Feb 2026 16:57:34 +0000 Subject: [PATCH 02/17] Working GPT OSS --- unsloth_zoo/temporary_patches/gpt_oss.py | 2066 ++++++++++++++------ unsloth_zoo/temporary_patches/moe_utils.py | 3 + 2 files changed, 1470 insertions(+), 599 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 818e01aa1..c60c64f21 100644 --- 
a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -30,6 +30,7 @@ ) from importlib.metadata import version as importlib_version from ..utils import Version + transformers_version = Version(importlib_version("transformers")) has_static_cache = transformers_version >= Version("4.56.0.dev0") from .utils import ( @@ -43,6 +44,7 @@ process_return, ) from ..hf_utils import dtype_from_config + torch_cuda_device = torch.cuda.device # MXFP4 configuration @@ -84,8 +86,11 @@ def swiglu_torch_forward(a, alpha, limit, dtype = None): out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) out = out_gelu * (a_linear + 1) return out.to(a.dtype if dtype is None else dtype) + + pass + @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_backward(pre_act, alpha, limit, g1): g, l = pre_act[..., ::2].to(torch.float32), pre_act[..., 1::2].to(torch.float32) @@ -93,21 +98,23 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): if limit is not None: mask_g = g <= limit mask_l = l.abs() <= limit - ḡ = torch.where(mask_g, g, limit) + ḡ = torch.where(mask_g, g, limit) l̄ = torch.where(mask_l, l, l.sign() * limit) - else: # no clipping + else: # no clipping mask_g = mask_l = torch.ones_like(g, dtype=bool) - ḡ, l̄ = g, l + ḡ, l̄ = g, l - σ = torch.sigmoid(alpha * ḡ) - dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) - dl = ḡ * σ - dg = torch.where(mask_g, dg, 0.) # clamp-grad - dl = torch.where(mask_l, dl, 0.) 
+ σ = torch.sigmoid(alpha * ḡ) + dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) + dl = ḡ * σ + dg = torch.where(mask_g, dg, 0.0) # clamp-grad + dl = torch.where(mask_l, dl, 0.0) grad = torch.empty_like(pre_act) grad[..., ::2], grad[..., 1::2] = dg, dl return g1 * grad.to(g1.dtype) + + pass @@ -133,15 +140,25 @@ def is_kernels_available(): lambda *args, **kwargs: True ) except Exception as e: - return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) + return raise_error( + "transformers.quantizers.quantizer_mxfp4.is_kernels_available", e + ) - if hasattr(transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels"): - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = lambda *args, **kwargs: triton_kernels + if hasattr( + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels" + ): + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = ( + lambda *args, **kwargs: triton_kernels + ) try: - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( + lambda *args, **kwargs: True + ) except Exception as e: - return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) + return raise_error( + "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e + ) if HAS_TRITON_KERNELS: try: @@ -166,6 +183,7 @@ def is_kernels_available(): def swizzle_mxfp4(w, w_scale, *args, **kwargs): from triton_kernels import tensor, tensor_details + FP4, convert_layout, wrap_torch_tensor = ( tensor.FP4, tensor.convert_layout, @@ -174,8 +192,12 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): layout = tensor_details.layout StridedLayout = tensor_details.layout.StridedLayout - value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) - w = convert_layout(wrap_torch_tensor(w, dtype=FP4), 
value_layout, **value_layout_opts) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1 + ) + w = convert_layout( + wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts + ) # TODO : add that when we are actually sure that it works on B200 # if torch.cuda.get_device_capability()[0] == 10: # constraints = { @@ -190,7 +212,13 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, match_level = "relaxed") + + patch_function( + transformers.integrations.mxfp4, + "swizzle_mxfp4", + swizzle_mxfp4, + match_level="relaxed", + ) class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -203,7 +231,9 @@ def forward( scatter_idx, ): pre_activation = matmul_ogs( - hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to( + torch.bfloat16 + ), # tl.dot_scaled upcasts to BF16 for old hardware self_class.gate_up_proj, self_class.gate_up_proj_bias, routing_data, @@ -242,6 +272,7 @@ def forward( ctx.scatter_idx = scatter_idx ctx.routing_data = routing_data return out + pass @staticmethod @@ -274,7 +305,9 @@ def backward(ctx, grad_token): dx_token = torch.zeros_like(grad_token) dx_token.index_add_(0, gather_dst, dx_exp) return (dx_token, None, None, None, None,) + pass + pass class Mxfp4GptOssExperts(nn.Module): @@ -287,7 +320,13 @@ def __init__(self, config): # MXFP4 quantized format (blocks + scales) self.gate_up_proj_blocks = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + torch.zeros( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size // 32, + 16, + dtype=torch.uint8, + ), requires_grad=False, ) self.gate_up_proj_scales = 
nn.Parameter( @@ -295,7 +334,8 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), + requires_grad=False, ) self.down_proj_blocks = nn.Parameter( @@ -307,15 +347,91 @@ def __init__(self, config): requires_grad=False, ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), + requires_grad=False, ) + self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.gate_up_proj_precision_config = None self.down_proj_precision_config = None @property - def gate_up_proj(selforward( + def gate_up_proj(self): + """Return gate_up_proj tensor (created from blocks/scales or stored directly).""" + # Check if already set as an attribute (from checkpoint loading or previous dequantization) + if "_gate_up_proj" in self.__dict__: + return self.__dict__["_gate_up_proj"] + + # Check if MXFP4 weights are present (blocks and scales are not all zeros) + blocks_valid = ( + self.gate_up_proj_blocks.device.type != "meta" + and self.gate_up_proj_blocks.numel() > 0 + and self.gate_up_proj_blocks.any() + ) + + if not blocks_valid: + # No MXFP4 weights and no regular weights loaded + raise AttributeError( + f"Mxfp4GptOssExperts.gate_up_proj: No weights loaded. " + f"Try 'openai/gpt-oss-20b' with load_in_4bit=True instead." 
+ ) + + # MXFP4 weights present - dequantize them + try: + from transformers.integrations.mxfp4 import dequantize + # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) + dequantized = dequantize(self.gate_up_proj_blocks, self.gate_up_proj_scales) + # Cache for future accesses + self.__dict__["_gate_up_proj"] = dequantized + return dequantized + except Exception as e: + raise RuntimeError( + f"Failed to dequantize MXFP4 gate_up_proj: {e}. " + f"Ensure transformers.integrations.mxfp4.dequantize is available." + ) + + @gate_up_proj.setter + def gate_up_proj(self, value): + """Set gate_up_proj tensor (called during checkpoint loading).""" + self.__dict__["_gate_up_proj"] = value + + @property + def down_proj(self): + """Return down_proj tensor (created from blocks/scales or stored directly).""" + if "_down_proj" in self.__dict__: + return self.__dict__["_down_proj"] + + blocks_valid = ( + self.down_proj_blocks.device.type != "meta" + and self.down_proj_blocks.numel() > 0 + and self.down_proj_blocks.any() + ) + + if not blocks_valid: + raise AttributeError( + f"Mxfp4GptOssExperts.down_proj: No weights loaded." 
+ ) + + # MXFP4 weights present - dequantize them + try: + from transformers.integrations.mxfp4 import dequantize + # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) + dequantized = dequantize(self.down_proj_blocks, self.down_proj_scales) + # Cache for future accesses + self.__dict__["_down_proj"] = dequantized + return dequantized + except Exception as e: + raise RuntimeError( + f"Failed to dequantize MXFP4 down_proj: {e}" + ) + + @down_proj.setter + def down_proj(self, value): + """Set down_proj tensor (called during checkpoint loading).""" + self.__dict__["_down_proj"] = value + + def forward( self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx ) -> torch.Tensor: with torch_cuda_device(hidden_states.device): @@ -323,7 +439,7 @@ def gate_up_proj(selforward( self.act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) if not hidden_states.requires_grad: intermediate_cache1 = matmul_ogs( - hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware self.gate_up_proj, self.gate_up_proj_bias, routing_data, @@ -350,43 +466,50 @@ def gate_up_proj(selforward( scatter_idx, ) return intermediate_cache3 + pass + patch_function(transformers.integrations.mxfp4, "Mxfp4GptOssExperts", Mxfp4GptOssExperts) - try: - routing = triton_kernels.routing.routing - routing = torch.compiler.disable(routing) - except Exception as e: - return raise_error("triton_kernels.routing.routing", e) + if HAS_TRITON_KERNELS: + try: + routing = triton_kernels.routing.routing + routing = torch.compiler.disable(routing) + except Exception as e: + return raise_error("triton_kernels.routing.routing", e) - def mlp_forward(self, hidden_states): - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) - router_logits = nn.functional.linear(hidden_states, 
self.router.weight, self.router.bias) + def mlp_forward(self, hidden_states): + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - with torch_cuda_device(router_logits.device): - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + with torch_cuda_device(router_logits.device): + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) - routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) - routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) - return routed_out, router_logits - patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) + routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) + return routed_out, router_logits - try: - PrecisionConfig, FlexCtx, InFlexData = ( - triton_kernels.matmul_ogs.PrecisionConfig, - triton_kernels.matmul_ogs.FlexCtx, - triton_kernels.matmul_ogs.InFlexData, - ) - except Exception as e: - return raise_error("triton_kernels.matmul_ogs", e) + patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + + if HAS_TRITON_KERNELS: + try: + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels.matmul_ogs.PrecisionConfig, + triton_kernels.matmul_ogs.FlexCtx, + triton_kernels.matmul_ogs.InFlexData, + ) + except Exception as e: + return raise_error("triton_kernels.matmul_ogs", e) try: from transformers.integrations.tensor_parallel import shard_and_distribute_module except Exception as e: return raise_error("transformers.integrations.tensor_parallel.shard_and_distribute_module", e) - def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args, **kwargs): + def load_and_swizzle_mxfp4( + module, param_name, param_value, target_device, *args, 
**kwargs + ): model = kwargs.get("model", None) empty_param = kwargs.get("empty_param", None) casting_dtype = kwargs.get("casting_dtype", None) @@ -397,17 +520,22 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args for proj in ["gate_up_proj", "down_proj"]: if proj in param_name: if device_mesh is not None: - shard_and_distribute_module( - model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh - ) + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) else: setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" blocks = getattr(module, blocks_attr) scales = getattr(module, scales_attr) - # Check if both blocks and scales both not on on meta device - if blocks.device.type != "meta" and scales.device.type != "meta": + # Check if both blocks and scales both not on meta device AND not all zeros + # (if blocks are all zeros, they're from initialization, not from checkpoint) + blocks_valid = ( + blocks.device.type != "meta" + and scales.device.type != "meta" + and blocks.numel() > 0 + and blocks.any() # At least some non-zero values + ) + if blocks_valid: # need it for ep local_experts = blocks.size(0) if proj == "gate_up_proj": @@ -448,13 +576,16 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args delattr(module, blocks_attr) # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) del blocks + pass - patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level = "relaxed") + patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level="relaxed") try: from transformers.integrations.mxfp4 import _replace_with_mxfp4_linear except Exception as e: - return 
class ParameterModule(nn.Linear):
    """Expose one stacked 3D expert weight as the 2D weight of an ``nn.Linear``.

    PEFT only knows how to wrap ``nn.Linear`` layers, whose weight is 2D with
    shape ``(out_features, in_features)``.  The grouped-matmul expert kernels
    need the stacked 3D layout instead:

    - gate_up: ``(E, H, 2I)``
    - down:    ``(E, I, H)``

    The weight is therefore stored in 2D form, and ``get_param()`` rebuilds
    the 3D tensor on demand.

    Args:
        in_features: Number of columns of the 2D weight.
        out_features: Number of rows of the 2D weight.
        shape_3d: Shape of the stacked 3D weight.
        permute_to_2d: Axis order applied to the 3D weight before it is
            flattened to ``(out_features, in_features)``.
        permute_to_3d: Axis order that undoes ``permute_to_2d``.
        dtype: Optional dtype for the weight.  ``None`` keeps the default
            dtype, exactly as before this argument existed.
        device: Optional device for the weight.  ``None`` keeps the default.
    """

    def __init__(
        self,
        in_features,
        out_features,
        shape_3d,
        permute_to_2d,
        permute_to_3d,
        dtype=None,
        device=None,
    ):
        # nn.Linear allocates and randomly initialises an (out, in) weight.
        # Callers are expected to overwrite it, so passing dtype/device here
        # avoids building a large default-dtype tensor that is then thrown away.
        super().__init__(
            in_features, out_features, bias=False, device=device, dtype=dtype
        )
        self.shape_3d = shape_3d
        self.permute_to_2d = permute_to_2d
        self.permute_to_3d = permute_to_3d

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}, shape_3d={self.shape_3d}"

    def _to_2d(self, weight_3d):
        """Return the 2D form of ``weight_3d``, or raise ValueError on a bad shape.

        Only the element count is checked by ``reshape``, so a tensor with the
        right number of elements but the wrong axis layout would otherwise be
        accepted and silently scramble the expert weights.
        """
        if tuple(weight_3d.shape) != tuple(self.shape_3d):
            raise ValueError(
                f"ParameterModule expected a 3D weight of shape {tuple(self.shape_3d)}, "
                f"got {tuple(weight_3d.shape)}"
            )
        return weight_3d.permute(*self.permute_to_2d).reshape(
            self.out_features, self.in_features
        )

    def get_param(self):
        """Rebuild and return the 3D weight as a contiguous tensor."""
        # The 2D weight is the 3D weight permuted by permute_to_2d and then
        # flattened, so first un-flatten to that permuted shape ...
        unflattened_shape = [self.shape_3d[i] for i in self.permute_to_2d]
        # ... then apply the inverse permutation to recover shape_3d.
        return (
            self.weight.view(*unflattened_shape)
            .permute(*self.permute_to_3d)
            .contiguous()
        )

    def set_weight_from_3d(self, weight_3d):
        """Copy a 3D weight of shape ``shape_3d`` into the 2D weight.

        Raises:
            ValueError: If ``weight_3d`` does not have shape ``shape_3d``.
        """
        weight_2d = self._to_2d(weight_3d)
        with torch.no_grad():
            # copy_ casts to the weight's dtype and device as needed.
            self.weight.copy_(weight_2d)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Checkpoints written when this attribute was a bare parameter store
        # the 3D tensor under the module's own name: prefix without its
        # trailing dot.
        legacy_key = prefix[:-1]
        weight_key = prefix + "weight"

        # An explicit 2D weight entry takes precedence over the legacy one.
        if legacy_key in state_dict and weight_key not in state_dict:
            try:
                state_dict[weight_key] = self._to_2d(state_dict[legacy_key])
            except ValueError as e:
                # Report the problem the way load_state_dict expects rather
                # than loading a mis-laid-out weight.
                error_msgs.append(f"While loading '{legacy_key}': {e}")
            else:
                del state_dict[legacy_key]

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
+ self.gate_up_proj = ParameterModule( + in_features=self.hidden_size, + out_features=self.num_experts * 2 * self.expert_dim, + shape_3d=(self.num_experts, self.hidden_size, 2 * self.expert_dim), + permute_to_2d=(0, 2, 1), + permute_to_3d=(0, 2, 1), + ) + # Initialize 3D zero tensor and set 2D weight + self.gate_up_proj.set_weight_from_3d( + torch.zeros( + self.num_experts, + self.hidden_size, + 2 * self.expert_dim, + dtype=self.dtype, + ) + ) + + self.gate_up_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=self.dtype) + ) + + # down_proj: 3D (E, I, H). Target 2D (H, E*I). + # Permute (2, 0, 1) -> (H, E, I). Reverse (1, 2, 0) + self.down_proj = ParameterModule( + in_features=self.num_experts * self.expert_dim, + out_features=self.hidden_size, + shape_3d=(self.num_experts, self.expert_dim, self.hidden_size), + permute_to_2d=(2, 0, 1), + permute_to_3d=(1, 2, 0), + ) + self.down_proj.set_weight_from_3d( + torch.empty( + self.num_experts, self.expert_dim, self.hidden_size, dtype=self.dtype + ) + ) + + self.down_proj_bias = nn.Parameter( + torch.zeros(self.num_experts, self.hidden_size, dtype=self.dtype) + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Override to handle loading 3D tensors (gate_up_proj, down_proj) from original checkpoints. + The original checkpoint has these as nn.Parameter (3D tensors), but we now have + them as ParameterModule (nn.Linear subclass) which expects .weight as 2D tensors. + + This method intercepts the 3D tensors and converts them to the 2D format + that ParameterModule expects. 
+ """ + # Handle gate_up_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight + gate_up_key = prefix + "gate_up_proj" + gate_up_weight_key = prefix + "gate_up_proj.weight" + if gate_up_key in state_dict and gate_up_weight_key not in state_dict: + val_3d = state_dict.pop(gate_up_key) + # gate_up_proj: 3D (E, H, 2I) -> permute (0, 2, 1) -> (E, 2I, H) -> reshape (E*2I, H) + val_2d = val_3d.permute(0, 2, 1).reshape( + self.num_experts * 2 * self.expert_dim, # out_features + self.hidden_size, # in_features + ) + state_dict[gate_up_weight_key] = val_2d + + # Handle down_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight + down_key = prefix + "down_proj" + down_weight_key = prefix + "down_proj.weight" + if down_key in state_dict and down_weight_key not in state_dict: + val_3d = state_dict.pop(down_key) + # down_proj: 3D (E, I, H) -> permute (2, 0, 1) -> (H, E, I) -> reshape (H, E*I) + val_2d = val_3d.permute(2, 0, 1).reshape( + self.hidden_size, # out_features + self.num_experts * self.expert_dim, # in_features + ) + state_dict[down_weight_key] = val_2d + + # Call parent implementation + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """Forward using grouped_mm or loop fallback with LoRA support.""" + # Use optimized grouped_mm if available + if _check_torch_grouped_mm_supported(): + return forward_native_grouped_mm( + self, hidden_states, router_indices, routing_weights + ) + + # Fallback to loop-based implementation + return forward_native_moe_loop( + self, hidden_states, router_indices, routing_weights + ) + + +pass + + +class _RouterLinearParams(nn.Module): + """ + Simple parameter container that stores weight/bias like nn.Linear + but is NOT nn.Linear itself, so BitsAndBytes will NOT quantize it. 
class GptOssTopKRouter(nn.Module):
    """Top-k router that scores every token against every expert.

    The projection lives in a ``_RouterLinearParams`` submodule named
    ``linear``, so its state-dict keys are ``linear.weight`` / ``linear.bias``.
    That container is not an ``nn.Linear``; the comment in the original code
    says this is to keep BitsAndBytes from quantizing the router.
    """

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        router_dtype = dtype_from_config(config)
        self.linear = _RouterLinearParams(
            self.hidden_dim, self.num_experts, dtype=router_dtype
        )

    # ``weight`` / ``bias`` forward to the inner container so that code which
    # expects them directly on the router (e.g. weight initialisation) works.
    @property
    def weight(self):
        return self.linear.weight

    @weight.setter
    def weight(self, value):
        self.linear.weight = value

    @property
    def bias(self):
        return self.linear.bias

    @bias.setter
    def bias(self, value):
        self.linear.bias = value

    def forward(self, hidden_states):
        """Route tokens to experts.

        Returns:
            A tuple ``(router_logits, router_scores, router_indices)``:
            ``router_logits`` is ``(tokens, num_experts)``; ``router_scores``
            has the same shape, holding the softmax over the selected top-k
            logits at the selected positions and zero elsewhere;
            ``router_indices`` is ``(tokens, top_k)``.
        """
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        tokens = tokens.to(self.linear.weight.dtype)
        router_logits = self.linear(tokens)  # (tokens, num_experts)

        top_logits, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        # Softmax runs over the selected logits only, in their own dtype.
        top_probs = torch.nn.functional.softmax(
            top_logits, dim=1, dtype=top_logits.dtype
        )

        router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype)
        router_scores.scatter_(1, router_indices, top_probs)
        return router_logits, router_scores, router_indices
class GptOssExpertsBnb4bit(nn.Module):
    """GPT OSS experts built from per-expert ``nn.Linear`` layers.

    Each expert owns one ``gate_up_projs[e]`` and one ``down_projs[e]`` linear
    layer.  Keeping them as ordinary ``nn.Linear`` modules is what lets
    BitsAndBytes quantize them.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = config.intermediate_size
        self.intermediate_size = config.intermediate_size
        self.alpha = 1.702
        self.limit = getattr(config, "swiglu_limit", 7.0)
        self.dtype = dtype_from_config(config)

        self.gate_up_projs = nn.ModuleList(
            nn.Linear(self.hidden_size, 2 * self.expert_dim, bias=True, dtype=self.dtype)
            for _ in range(self.num_experts)
        )
        self.down_projs = nn.ModuleList(
            nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=self.dtype)
            for _ in range(self.num_experts)
        )

        # Zero-sized, non-persistent placeholders under the attribute names of
        # the stacked-parameter layout.  They stay empty so no large tensors
        # are allocated and they are never written to the state dict.
        for placeholder in (
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
        ):
            self.register_buffer(
                placeholder, torch.empty(0, dtype=self.dtype), persistent=False
            )

    def _swiglu(self, gate_up):
        """Apply the clamped SwiGLU to interleaved gate/up columns."""
        # Even columns are the gate, odd columns the linear ("up") part.
        gate = gate_up[..., ::2].clamp(min=None, max=self.limit)
        up = gate_up[..., 1::2].clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        return (up + 1) * glu

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Mix expert outputs according to ``routing_weights``.

        In training mode only experts that received at least one token are
        evaluated, each on its own tokens.  In eval mode every expert is
        evaluated on every token and the results are weighted by
        ``routing_weights``.
        """
        batch_size = hidden_states.shape[0]
        flat = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        if not self.training:
            expanded = flat.unsqueeze(0).expand(num_experts, -1, -1)
            gate_up = torch.stack(
                [proj(expanded[e]) for e, proj in enumerate(self.gate_up_projs)],
                dim=0,
            )
            fused = self._swiglu(gate_up)
            outs = torch.stack(
                [proj(fused[e]) for e, proj in enumerate(self.down_projs)],
                dim=0,
            )
            per_expert_weight = routing_weights.transpose(0, 1).unsqueeze(-1)
            mixed = (outs * per_expert_weight).sum(dim=0)
            return mixed.view(batch_size, -1, self.hidden_size)

        accumulated = torch.zeros_like(flat, dtype=flat.dtype, device=flat.device)
        with torch.no_grad():
            # (tokens, top_k, E) -> (E, top_k, tokens)
            expert_mask = torch.nn.functional.one_hot(
                router_indices, num_classes=num_experts
            ).permute(2, 1, 0)
            # Indices of the experts that were selected by at least one token.
            active_experts = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in active_experts:
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx[0]])
            expert_in = flat[token_idx]
            activated = self._swiglu(self.gate_up_projs[expert_idx](expert_in))
            expert_out = self.down_projs[expert_idx](activated)
            weighted = expert_out * routing_weights[token_idx, expert_idx, None]
            accumulated.index_add_(0, token_idx, weighted.to(flat.dtype))

        return accumulated.view(batch_size, -1, self.hidden_size)
parameters for BitsAndBytes 4bit compatibility. + This version uses weight/bias as direct nn.Parameter instead of nested Linear. + """ + def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.linear = nn.Linear( - self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) + self.dtype = dtype_from_config(config) + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, dtype=self.dtype)) + self.bias = nn.Parameter(torch.zeros(self.num_experts, dtype=self.dtype)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Accept checkpoints that used a Linear router (linear.weight/bias) + linear_weight_key = prefix + "linear.weight" + linear_bias_key = prefix + "linear.bias" + weight_key = prefix + "weight" + bias_key = prefix + "bias" + moved_weight = False + moved_bias = False + if linear_weight_key in state_dict and weight_key not in state_dict: + state_dict[weight_key] = state_dict.pop(linear_weight_key) + moved_weight = True + if linear_bias_key in state_dict and bias_key not in state_dict: + state_dict[bias_key] = state_dict.pop(linear_bias_key) + moved_bias = True + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, ) - # Expose weight/bias for HF init compatibility - self.weight = self.linear.weight - self.bias = self.linear.bias + if moved_weight: + if linear_weight_key in unexpected_keys: + unexpected_keys.remove(linear_weight_key) + if weight_key in missing_keys: + missing_keys.remove(weight_key) + if moved_bias: + if linear_bias_key in unexpected_keys: + unexpected_keys.remove(linear_bias_key) + if bias_key in missing_keys: + missing_keys.remove(bias_key) + if self.weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + 
def patch_gpt_oss_bnb4bit():
    """Swap transformers' GPT-OSS expert/router classes for BnB 4-bit versions.

    Call this before loading a model that was saved with BitsAndBytes
    quantization using the linear-based expert structure.

    Usage:
        from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss_bnb4bit
        patch_gpt_oss_bnb4bit()  # Call before loading the model
        model = FastLanguageModel.from_pretrained(...)

    Returns:
        ``True`` once the classes are swapped.  If the transformers GPT-OSS
        module cannot be imported, whatever ``raise_error`` returns.
    """
    try:
        import transformers.models.gpt_oss.modeling_gpt_oss as modeling
    except Exception as e:
        return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e)

    # Stash the originals only once, so patching twice cannot overwrite the
    # stash with already-patched classes.
    if not hasattr(modeling, "_original_GptOssExperts"):
        modeling._original_GptOssExperts = modeling.GptOssExperts
        modeling._original_GptOssTopKRouter = modeling.GptOssTopKRouter

    # Preserve original symbol names for compiler-generated modules.
    GptOssExpertsBnb4bit.__name__ = "GptOssExperts"
    GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts"
    modeling.GptOssExperts = GptOssExpertsBnb4bit

    # The router is this file's GptOssTopKRouter, which keeps its parameters
    # in a `linear` submodule (keys router.linear.weight / router.linear.bias).
    # GptOssTopKRouterBnb4bit is not used here because it relies on a
    # _load_from_state_dict override to remap those keys.  Per the original
    # author's note, transformers v5 bypasses that hook, which left the router
    # randomly initialised.
    modeling.GptOssTopKRouter = GptOssTopKRouter

    logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes")
    os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1"
    return True


def restore_gpt_oss_original():
    """Undo ``patch_gpt_oss_bnb4bit`` by reinstating the stashed classes.

    Returns:
        ``True`` if both original classes were restored, ``False`` if there
        was nothing to restore or the transformers module could not be
        imported.
    """
    try:
        import transformers.models.gpt_oss.modeling_gpt_oss as modeling
    except Exception as e:
        # Best effort: restoring must never crash the caller, but the reason
        # should not be lost either.
        logger.warning(f"Unsloth: Could not restore original GPT OSS classes: {e}")
        return False

    original_experts = getattr(modeling, "_original_GptOssExperts", None)
    original_router = getattr(modeling, "_original_GptOssTopKRouter", None)
    # Require both stashed classes so the module is never left half-restored.
    if original_experts is None or original_router is None:
        return False

    modeling.GptOssExperts = original_experts
    modeling.GptOssTopKRouter = original_router
    # Clear the marker set by patch_gpt_oss_bnb4bit so it does not keep
    # reporting a patch that is no longer applied.
    os.environ.pop("UNSLOTH_GPT_OSS_BNB4BIT_PATCHED", None)
    logger.info("Unsloth: Restored original GPT OSS classes")
    return True


def patch_gpt_oss_bnb4bit_auto():
    """Apply the BnB 4-bit GPT-OSS patch when ``_should_use_gpt_oss_bnb4bit()`` says so.

    Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to opt out.
    """
    global moe_forward_inference

    if not _should_use_gpt_oss_bnb4bit():
        return
    # Avoid compiler-generated modules missing BnB helpers
    os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
    patch_gpt_oss_bnb4bit()

    # Ensure the inference path avoids torch.compile for 4-bit.  This stays
    # best effort, but a failure is now logged instead of silently dropped.
    try:
        # Skip if a previous call already wrapped the function, so repeated
        # calls do not stack wrappers.
        if not getattr(moe_forward_inference, "_unsloth_compile_disabled", False):
            moe_forward_inference = torch.compiler.disable(moe_forward_inference)
            try:
                moe_forward_inference._unsloth_compile_disabled = True
            except AttributeError:
                # The marker only prevents re-wrapping; the wrapper works
                # without it.
                pass
    except Exception as e:
        logger.warning(
            f"Unsloth: Could not disable torch.compile for GPT OSS inference: {e}"
        )
- combo_kernels = use_combo_kernels, - memory_planning = True, - multi_kernel = False, # Fails on torch 2.10 nightly - use_block_ptr = True, - logging = UNSLOTH_ENABLE_LOGGING, + epilogue_fusion=True, + max_autotune=False, # Too slow + shape_padding=True, + cudagraphs=True, + coordinate_descent_tuning=use_combo_kernels, # Very slow! + combo_kernels=use_combo_kernels, + memory_planning=True, + multi_kernel=False, # Fails on torch 2.10 nightly + use_block_ptr=True, + logging=UNSLOTH_ENABLE_LOGGING, ) no_combo_fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion = True, - max_autotune = False, # Too slow - shape_padding = True, - cudagraphs = True, - coordinate_descent_tuning = use_combo_kernels, # Very slow! - combo_kernels = False, # Breaks on attention - memory_planning = True, - multi_kernel = False, # Fails on torch 2.10 nightly - use_block_ptr = True, - logging = UNSLOTH_ENABLE_LOGGING, + epilogue_fusion=True, + max_autotune=False, # Too slow + shape_padding=True, + cudagraphs=True, + coordinate_descent_tuning=use_combo_kernels, # Very slow! 
@_torch_compile(dynamic=None, fullgraph=True, options=fused_torch_compile_options)
def moe_forward_inference(self, hidden_states):
    """Compiled inference-only MoE forward: every expert runs on every token.

    Handles experts stored as per-expert module lists (``gate_up_projs`` /
    ``down_projs``) as well as experts stored as stacked 3D tensors
    (``gate_up_proj`` / ``down_proj`` plus biases).
    """
    # The router returns either (scores, indices) or (logits, scores, indices).
    router_out = self.router(hidden_states)
    if isinstance(router_out, tuple) and len(router_out) == 3:
        routing_weights, router_indices = router_out[1], router_out[2]
    else:
        routing_weights, router_indices = router_out

    moe = self.experts
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, moe.hidden_size)
    num_experts = routing_weights.shape[1]
    uses_module_list = hasattr(moe, "gate_up_projs")

    X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1)

    # Gate/up projection
    if uses_module_list:
        gate_up = torch.stack(
            [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)], dim=0
        )
    else:
        # (E, N, H) @ (E, H, 2I) -> (E, N, 2I)
        gate_up = torch.bmm(X_rep, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :]

    # bfloat16 stays bfloat16; every other input dtype is computed in float32.
    if hidden_states.dtype == torch.bfloat16:
        dtype = hidden_states.dtype
    else:
        dtype = torch.float32
    fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype).to(dtype)

    device_type = fused.device.type
    if not isinstance(device_type, str) or device_type == "mps":
        device_type = "cpu"

    # Down projection runs with autocast off so it stays in `dtype`.
    with torch.autocast(device_type=device_type, enabled=False):
        if uses_module_list:
            outs = torch.stack(
                [down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs)],
                dim=0,
            )
        else:
            # (E, N, I) @ (E, I, H) -> (E, N, H)
            outs = torch.bmm(fused.to(dtype), moe.down_proj) + moe.down_proj_bias[..., None, :]

    rw = routing_weights.to(dtype).transpose(0, 1).unsqueeze(-1)
    mixed = (outs * rw).sum(dim=0)
    return mixed.view(batch_size, -1, moe.hidden_size).to(hidden_states.dtype)
# Combo Kernels errors with InductorError: AttributeError: 'NullKernelHandler' object has no attribute 'index_to_str'
@_torch_compile(
    dynamic=None, fullgraph=True, options=no_combo_fused_torch_compile_options
)
def _moe_forward_inference_bf16_kernel(
    hidden_states, routing_weights, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias, limit, alpha, hidden_size
):
    """Compiled BF16 MoE inference kernel that takes raw tensors only.

    Every expert processes every token; the per-expert outputs are scaled by
    ``routing_weights`` and summed over the expert axis.
    """
    batch_size = hidden_states.shape[0]
    num_experts = routing_weights.shape[1]

    # One copy of the flattened tokens per expert: (E, tokens, hidden_size)
    tokens = hidden_states.reshape(-1, hidden_size)
    per_expert_in = tokens.repeat(num_experts, 1).view(num_experts, -1, hidden_size)

    gate_up = torch.bmm(per_expert_in, gate_up_proj) + gate_up_proj_bias[..., None, :]
    # Gate and up values are interleaved along the last axis.
    gate = gate_up[..., ::2].clamp(min=None, max=limit)
    up = gate_up[..., 1::2].clamp(min=-limit, max=limit)
    # Sigmoid is evaluated in float32, then cast back to the gate dtype.
    gate_sigmoid = torch.sigmoid(gate.to(torch.float32) * alpha).to(gate.dtype)
    glu = gate * gate_sigmoid

    expert_out = torch.bmm(((up + 1) * glu), down_proj)
    expert_out = expert_out + down_proj_bias[..., None, :]
    expert_out = expert_out.view(num_experts, batch_size, -1, hidden_size)

    per_expert_weight = routing_weights.transpose(0, 1).view(
        num_experts, batch_size, -1
    )[..., None]
    return (expert_out * per_expert_weight).sum(dim=0)
ParameterModule (which wraps nn.Linear) vs direct 3D tensor + # Extract weights BEFORE the compiled region + gate_up_proj = moe.gate_up_proj + if hasattr(gate_up_proj, "get_param"): + gate_up_proj = gate_up_proj.get_param() + elif hasattr(gate_up_proj, "weight"): + gate_up_proj = gate_up_proj.weight + + down_proj = moe.down_proj + if hasattr(down_proj, "get_param"): + down_proj = down_proj.get_param() + elif hasattr(down_proj, "weight"): + down_proj = down_proj.weight + + return _moe_forward_inference_bf16_kernel( + hidden_states, + routing_weights, + gate_up_proj, + moe.gate_up_proj_bias, + down_proj, + moe.down_proj_bias, + moe.limit, + moe.alpha, + moe.hidden_size, + ) + + class GptOssMLP(nn.Module): @@ -695,432 +1379,472 @@ def forward(self, hidden_states): bsz, qlen, hd = hidden_states.shape if qlen == 1 and not self.training: return moe_forward_inference(self, hidden_states), None - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores -pass - -def patch_gpt_oss_linearized(): - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - try: - import transformers.models.gpt_oss.modeling_gpt_oss - except Exception as e: - return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + router_out = self.router(hidden_states) + if isinstance(router_out, tuple) and len(router_out) == 3: + _, router_scores, router_indices = router_out + else: + router_scores, router_indices = router_out + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out, router_scores - _patch_gpt_oss_init_weights_for_modulelist( - transformers.models.gpt_oss.modeling_gpt_oss - ) - # We find down_proj overflows in GPT OSS - if 
os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - def forward( - self, - hidden_states: torch.Tensor, - router_indices = None, - routing_weights = None - ) -> torch.Tensor: +pass - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - if self.training: - next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) - # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - # for expert_idx in expert_hitted[:]: - for expert_idx in range(num_experts): - with torch.no_grad(): - # _, token_idx = torch.where(expert_mask[expert_idx[0]]) - token_idx, _ = torch.where(router_indices == expert_idx) - current_state = hidden_states[token_idx] - gate_up = self.gate_up_projs[expert_idx](current_state) - down_proj = self.down_projs[expert_idx] - gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = torch.float32) - # gate, up = gate_up[..., ::2], gate_up[..., 1::2] - # gate = gate.clamp(min=None, max=self.limit) - # up = up.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # gated_output = (up + 1) * glu - - # Force float32 matrix multiply on some down projection modules - gated_output = gated_output.to(torch.float32) - device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out = down_proj(gated_output) - weighted_output = out.to(torch.float32) * routing_weights[token_idx, expert_idx, None].to(torch.float32) - next_states.index_add_(0, token_idx, weighted_output) - next_states = next_states.view(batch_size, -1, self.hidden_size) - return 
next_states.to(torch.float32) - else: - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] - gate_up = torch.stack(gate_up_list, dim=0) - dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype - fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = dtype) - # gate = gate_up[..., ::2] - # up_h = gate_up[..., 1::2] - # gate = gate.clamp(max=self.limit) - # up_h = up_h.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # fused = (up_h + 1) * glu - - # Force float32 matrix multiply on down projection only - device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out_list = [ - down_l(fused[e].to(dtype)) - for e, down_l in enumerate(self.down_projs) - ] - outs = torch.stack(out_list, dim=0) - rw = routing_weights.transpose(0, 1).unsqueeze(-1) - mixed = (outs.to(dtype) * rw.to(dtype)).sum(dim=0) - return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) - pass - pass - GptOssExperts.forward = forward - pass - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExperts - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - transformers.models.gpt_oss.modeling_gpt_oss.GptOssMLP = GptOssMLP - return -pass -TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) +# ============================================================================ +# GPT OSS MoE LoRA Support using grouped GEMM kernels +# ============================================================================ + +# IMPORTS FROM MOE UTILS +from .moe_utils import ( + _check_grouped_gemm_available, + _TORCH_GROUPED_MM_AVAILABLE, + _check_torch_grouped_mm_supported, + native_moe_grouped_mm, + _get_moe_lora_weights, + _apply_lora_grouped_mm, + 
_get_lora_wrapper_for_param, + select_moe_backend, + patch_param_wrapper_for_moe, + forward_native_grouped_mm, + forward_native_moe_loop, +) -def _patch_gpt_oss_init_weights_for_modulelist(transformers_module): - GptOssPreTrainedModel = transformers_module.GptOssPreTrainedModel - GptOssExperts = transformers_module.GptOssExperts - if getattr(GptOssPreTrainedModel, "_unsloth_init_weights_patched", False): - return - _original_init_weights = GptOssPreTrainedModel._init_weights +def _should_use_gpt_oss_bnb4bit() -> bool: + """ + Decide if GPT-OSS should use BnB-compatible 4-bit experts. + Default: True when load_in_4bit is active. + Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return False + if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return False + return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" - def _patched_init_weights(self, module): - _original_init_weights(self, module) - if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): - std = self.config.initializer_range - for up in getattr(module, "gate_up_projs", []): - init.normal_(up.weight, mean=0.0, std=std) - if up.bias is not None: - init.zeros_(up.bias) - for down in getattr(module, "down_projs", []): - init.normal_(down.weight, mean=0.0, std=std) - if down.bias is not None: - init.zeros_(down.bias) - patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) - GptOssPreTrainedModel._unsloth_init_weights_patched = True +def _is_gpt_oss_4bit_load() -> bool: + return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") -def patch_gpt_oss_bf16_split_lora(): - """ - Patch GPT-OSS BF16 model to use split LoRA with grouped GEMM. +def _is_transformers_v5() -> bool: + return transformers_version >= Version("5.0.0.dev0") - This patch applies to BF16 models loaded via unsloth/gpt-oss-20b-BF16 or similar. 
- It creates stacked expert weights and uses moe_utils.py's forward_native_grouped_mm - for efficient training with split LoRA. - Key differences from Qwen3 MoE: - 1. GPT-OSS uses TRANSPOSED weight layout: - - gate_up_proj: (num_experts, hidden_size, 2 * intermediate_size) - - down_proj: (num_experts, intermediate_size, hidden_size) - 2. GPT-OSS gate_up has interleaved layout (::2 for gate, 1::2 for up) - 3. GPT-OSS has biases: gate_up_proj_bias, down_proj_bias +def patch_gpt_oss_moe_for_lora(): + """ + Patch GPT OSS MoE experts for LoRA training with grouped GEMM support. + This patches the original GptOssExperts class (with 3D parameter tensors) + to use optimized grouped GEMM kernels with LoRA support. - 4-bit BNB models use gate_up_projs/down_projs ModuleList and are NOT affected. + IMPORTANT: We only patch the forward method, NOT replace the entire class. + This preserves the original class structure so weights load correctly. """ if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - # Skip 4-bit models - they use ModuleList and are handled by patch_gpt_oss_linearized - if "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): + # 4-bit loads should keep quantized weights and use default PEFT LoRA. + return + if not _is_transformers_v5(): + # Split-LoRA grouped_mm path is only needed for transformers v5+ return + # First patch ParamWrapper for MoE separated LoRA + patch_param_wrapper_for_moe() try: import transformers.models.gpt_oss.modeling_gpt_oss + + # Get the ORIGINAL class - don't replace it! 
+ GptOssExpertsClass = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts except Exception as e: - return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch GPT OSS MoE for LoRA: {e}") + return - from .moe_utils import ( - forward_native_grouped_mm, - select_moe_backend, - patch_param_wrapper_for_moe, - _has_lora_adapters, - _extract_lora_from_wrapper, - _should_use_separated_lora, - _is_moe_experts_module, - ) + # Check if already patched + if hasattr(GptOssExpertsClass, "_unsloth_lora_patched"): + return - # Patch ParamWrapper.forward for MoE separated LoRA - patch_param_wrapper_for_moe() + # Select backend + backend = select_moe_backend() - # LoRA Extractor function for GPT-OSS - # GPT-OSS weight layout is TRANSPOSED: (E, hidden, output) instead of (E, output, hidden) - def _gpt_oss_lora_extractor( - self, wrapper, weight_A, weight_B, scaling, num_experts - ): - """ - GPT-OSS LoRA extractor for transposed weight layout. + if backend == "grouped_mm": + forward = forward_native_grouped_mm + else: + forward = forward_native_moe_loop - GPT-OSS weights: - gate_up_proj: (E, H, 2*I) - transposed layout (in_dim, out_dim) - down_proj: (E, I, H) - transposed layout (in_dim, out_dim) + # Store original forward and patch - but DON'T replace the class! 
+ GptOssExpertsClass._original_forward = GptOssExpertsClass.forward + GptOssExpertsClass.forward = forward + GptOssExpertsClass._unsloth_lora_patched = True - For grouped_mm: X @ W where W is (E, in_dim, out_dim) + if UNSLOTH_ENABLE_LOGGING: + backend_desc = { + "grouped_mm": "torch._grouped_mm (batched, fastest)", + "unsloth_triton": "Triton kernels", + "native_torch": "loop fallback (slower)", + }.get(backend, backend) + logger.info( + f"Unsloth: Patched GPT OSS MoE for LoRA training using {backend_desc}" + ) - PEFT creates: - lora_A: (E*R, in_dim) - projects input to rank space - lora_B: (out_dim, E*R) - projects rank to output - For transposed format, the LoRA dimensions are already correct: - - We want X @ (E, in, R) @ (E, R, out) - """ - # This Unsloth Zoo code section is licensed under AGPL3 - - total_rank = weight_A.shape[0] - rank_per_expert = total_rank // num_experts - dim_A = weight_A.shape[ - 1 - ] # in_dim (hidden_dim for gate_up, intermediate for down) - dim_B = weight_B.shape[ - 0 - ] # out_dim (2*intermediate for gate_up, hidden_dim for down) - - # Get model dimensions from the experts module - hidden_dim = None - intermediate_dim = None - current = wrapper - while hasattr(current, "base_layer"): - current = current.base_layer - if hasattr(current, "hidden_size"): - hidden_dim = current.hidden_size - if hasattr(current, "intermediate_size"): - intermediate_dim = current.intermediate_size - - # Get parameter name - param_name = getattr(wrapper, "parameter_name", None) - - # GPT-OSS uses TRANSPOSED layout, so LoRA dimensions map directly: - # Input projection: X @ (E, in_dim, R) - # Output projection: result @ (E, R, out_dim) - - if ( - param_name == "down_proj" - and intermediate_dim is not None - and hidden_dim is not None - ): - # down_proj: input=intermediate_dim, output=hidden_dim - # Weight shape: (E, I, H) - transposed - # lora_A: (E*R, H) from PEFT (swapped due to 3D param handling) - # lora_B: (I, E*R) from PEFT (swapped) - # For X @ first @ 
second: first is (E, I, R), second is (E, R, H) - - # first_weight from B (has intermediate_dim) - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, I, R) - - # second_weight from A (has hidden_dim) - second_weight = weight_A.view( - num_experts, rank_per_expert, dim_A - ) # (E, R, H) - - return first_weight, second_weight, scaling, num_experts - - elif param_name == "gate_up_proj" and hidden_dim is not None: - # gate_up_proj: input=hidden_dim, output=2*intermediate_dim - # Weight shape: (E, H, 2*I) - transposed - # lora_A: (E*R, 2*I) from PEFT - # lora_B: (H, E*R) from PEFT - # For X @ first @ second: first is (E, H, R), second is (E, R, 2*I) - - # first_weight from B (has hidden_dim) - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() # (E, H, R) - - # second_weight from A (has 2*intermediate_dim) - second_weight = weight_A.view( - num_experts, rank_per_expert, dim_A - ) # (E, R, 2*I) - - return first_weight, second_weight, scaling, num_experts - - # Fallback: dimension-based detection - if hidden_dim is not None: - if dim_B == hidden_dim: - # B connects to hidden_dim (transposed case) - first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - first_weight = first_weight.permute(1, 0, 2).contiguous() - second_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - return first_weight, second_weight, scaling, num_experts - elif dim_A == hidden_dim: - # A connects to hidden_dim - first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - first_weight = first_weight.permute(0, 2, 1).contiguous() - second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - second_weight = second_weight.permute(1, 2, 0).contiguous() - return first_weight, second_weight, scaling, num_experts - - # Final fallback - first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) - first_weight = 
first_weight.permute(0, 2, 1).contiguous() - second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) - second_weight = second_weight.permute(1, 2, 0).contiguous() - return first_weight, second_weight, scaling, num_experts - - # Patch GptOssExperts.forward to use grouped GEMM + split LoRA (BF16 only) - GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts - _original_gpt_oss_experts_forward = GptOssExperts.forward +TEMPORARY_PATCHES.append(patch_gpt_oss_moe_for_lora) - def _bf16_split_lora_forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None - ) -> torch.Tensor: - # Fallback to original for 4-bit ModuleList or missing routing data - if ( - router_indices is None - or routing_weights is None - or hasattr(self, "gate_up_projs") - or not hasattr(self, "gate_up_proj") - or not hasattr(self, "down_proj") - ): - return _original_gpt_oss_experts_forward( - self, hidden_states, router_indices, routing_weights - ) - gate_up_param = getattr(self, "gate_up_proj", None) - down_param = getattr(self, "down_proj", None) - if ( - not isinstance(gate_up_param, nn.Parameter) - or gate_up_param.ndim != 3 - or not isinstance(down_param, nn.Parameter) - or down_param.ndim != 3 - ): - return _original_gpt_oss_experts_forward( - self, hidden_states, router_indices, routing_weights - ) +# ============================================================================ +# MXFP4 (4-bit) GPT OSS MoE LoRA Support +# ============================================================================ - if not hasattr(self, "_unsloth_model_type"): - self._unsloth_model_type = "gpt_oss" +_MXFP4_LORA_PATH_LOGGED = False - return forward_native_grouped_mm( - self, - hidden_states, - router_indices, # top_k_index - routing_weights, # top_k_weights +@torch.compiler.disable +def forward_mxfp4_gpt_oss_with_lora( + self, + hidden_states: torch.Tensor, + routing_data, + gather_idx, + scatter_idx, +) -> torch.Tensor: + """ + MXFP4 GPT OSS MoE forward 
pass with LoRA support. + + For LoRA training with MXFP4: + - Base MXFP4 matmul is computed without gradients (frozen weights) + - LoRA delta is computed with gradients + - Output = base_output + lora_delta + + This allows finetuning MXFP4 quantized models with LoRA. + + Requires triton_kernels for native MXFP4 matmul. + If triton_kernels is not available, model should be loaded with + Mxfp4Config(dequantize=True) to convert to bf16. + """ + if not is_triton_kernels_available(): + raise RuntimeError( + "triton_kernels is required for native MXFP4 GPT OSS forward pass. " + "Either:\n" + " 1. Install triton_kernels from OpenAI, OR\n" + " 2. Load model with dequantization: Mxfp4Config(dequantize=True), OR\n" + " 3. Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'\n" + "Set UNSLOTH_MXFP4_NO_DEQUANTIZE=0 (default) to auto-dequantize." ) - patch_function(GptOssExperts, "forward", _bf16_split_lora_forward) + from triton_kernels import matmul_ogs, swiglu - _patch_gpt_oss_init_weights_for_modulelist( - transformers.models.gpt_oss.modeling_gpt_oss - ) + matmul_ogs_fn = matmul_ogs.matmul_ogs + FnSpecs = matmul_ogs.FnSpecs + FusedActivation = matmul_ogs.FusedActivation + swiglu_fn = swiglu.swiglu_fn - # BF16 Experts class with stacked weights (for split LoRA via moe_utils) - class GptOssExpertsBF16Stacked(nn.Module): - """ - GPT-OSS BF16 Experts with stacked weights for grouped GEMM. + # Get LoRA wrappers + gate_up_wrapper = _get_lora_wrapper_for_param(self, "gate_up_proj") + down_wrapper = _get_lora_wrapper_for_param(self, "down_proj") - Uses moe_utils.forward_native_grouped_mm for efficient MoE computation - with separated LoRA support. 
- """ + has_lora = gate_up_wrapper is not None or down_wrapper is not None - def __init__(self, config): - super().__init__() - self.num_experts = config.num_local_experts - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.alpha = 1.702 - self.limit = getattr(config, "swiglu_limit", 7.0) - self.dtype = dtype_from_config(config) + # Log path info once + global _MXFP4_LORA_PATH_LOGGED + if not _MXFP4_LORA_PATH_LOGGED: + _MXFP4_LORA_PATH_LOGGED = True + logger.warning_once( + f"Unsloth: GPT-OSS MoE training path: MXFP4 + triton_kernels. " + f"LoRA={has_lora}, experts={self.num_experts}. " + f"Tip: Increase batch_size for better GPU utilization." + ) - # Stacked weights in transposed format (E, in_dim, out_dim) for grouped_mm - self.gate_up_proj = nn.Parameter( - torch.empty( - self.num_experts, - self.hidden_size, - 2 * self.intermediate_size, - dtype=self.dtype, + # If no LoRA, use the original MXFP4 forward (inference path) + if not has_lora: + with torch_cuda_device(hidden_states.device): + if not hasattr(self, "act"): + self.act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (self.alpha, self.limit), + 2, ) + intermediate_cache1 = matmul_ogs_fn( + hidden_states.to(torch.bfloat16), + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=self.act, ) - self.gate_up_proj_bias = nn.Parameter( - torch.empty( - self.num_experts, 2 * self.intermediate_size, dtype=self.dtype + intermediate_cache3 = matmul_ogs_fn( + intermediate_cache1, + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=routing_data.gate_scal if routing_data else None, + ) + return intermediate_cache3 + + # With LoRA: compute base MXFP4 output (no grad) + LoRA delta (with grad) + with 
torch_cuda_device(hidden_states.device): + # 1. Compute MXFP4 base output (detached, no gradients for base weights) + with torch.no_grad(): + if not hasattr(self, "act"): + self.act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (self.alpha, self.limit), + 2, ) + + # Gate-up projection (MXFP4) + pre_activation = matmul_ogs_fn( + hidden_states.to(torch.bfloat16), + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=None, # Don't fuse activation so we can add LoRA before ) - self.down_proj = nn.Parameter( - torch.empty( - self.num_experts, - self.intermediate_size, - self.hidden_size, - dtype=self.dtype, + + # 2. Add LoRA for gate_up_proj using grouped_mm + if gate_up_wrapper is not None: + lora_data = _get_moe_lora_weights(gate_up_wrapper) + if lora_data is not None: + lora_A, lora_B, scaling, num_experts = lora_data + + # Convert triton_kernels routing format to grouped_mm offsets + # routing_data.exp_indx contains expert index for each token + gather_src = gather_idx.src_indx + permuted_input = hidden_states[gather_src].to(torch.bfloat16) + + # Compute offsets from expert indices (cumsum of token counts) + expert_ids = routing_data.exp_indx + num_tokens_per_expert = torch.bincount( + expert_ids.int(), minlength=num_experts + ).int() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + + # Use grouped_mm for LoRA computation (much faster than loop!) + lora_delta = _apply_lora_grouped_mm( + permuted_input, + lora_A, + lora_B, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, ) + pre_activation = pre_activation + lora_delta + + # 3. Apply SwiGLU activation + alpha = getattr(self, "alpha", 1.702) + limit = getattr(self, "limit", 7.0) + swiglu_output = swiglu_torch_forward(pre_activation, alpha, limit) + + # 4. 
Down projection (MXFP4) + with torch.no_grad(): + base_output = matmul_ogs_fn( + swiglu_output.detach(), # Detach to prevent grad through MXFP4 + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=routing_data.gate_scal if routing_data else None, ) - self.down_proj_bias = nn.Parameter( - torch.empty(self.num_experts, self.hidden_size, dtype=self.dtype) + + # 5. Add LoRA for down_proj using grouped_mm + if down_wrapper is not None: + lora_data = _get_moe_lora_weights(down_wrapper) + if lora_data is not None: + lora_A, lora_B, scaling, num_experts = lora_data + + # Compute offsets from expert indices (reuse if available) + expert_ids = routing_data.exp_indx + num_tokens_per_expert = torch.bincount( + expert_ids.int(), minlength=num_experts + ).int() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + + # Use grouped_mm for LoRA computation + lora_delta = _apply_lora_grouped_mm( + swiglu_output, + lora_A, + lora_B, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, + ) + + # Scatter LoRA delta back to output (apply routing weights) + scatter_dst = scatter_idx.dst_indx + gamma = routing_data.gate_scal.unsqueeze(-1) + base_output.index_add_( + 0, scatter_dst, (lora_delta * gamma).to(base_output.dtype) + ) + + return base_output + + +def patch_mxfp4_gpt_oss_for_lora(): + """ + Patch MXFP4 GPT OSS experts for LoRA training. + + This enables finetuning the MXFP4 quantized model (unsloth/gpt-oss-20b) + with LoRA. + + IMPORTANT: Requires triton_kernels for native MXFP4 matmul. 
+ If triton_kernels is not available, users must either: + - Use Mxfp4Config(dequantize=True) when loading, OR + - Use the BF16 model: 'unsloth/gpt-oss-20b-BF16' + """ + # First patch ParamWrapper for MoE separated LoRA (v5 only) + if _is_transformers_v5(): + patch_param_wrapper_for_moe() + + try: + import transformers.integrations.mxfp4 + + Mxfp4GptOssExpertsClass = getattr( + transformers.integrations.mxfp4, "Mxfp4GptOssExperts", None + ) + if Mxfp4GptOssExpertsClass is None: + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + "Unsloth: Mxfp4GptOssExperts not found in transformers.integrations.mxfp4" + ) + return + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch MXFP4 GPT OSS for LoRA: {e}") + return + + # Check if already patched + if hasattr(Mxfp4GptOssExpertsClass, "_unsloth_mxfp4_lora_patched"): + return + + # Only patch if triton_kernels is available + # Without triton_kernels, MXFP4 weights cannot be used directly + # Users must use dequantization or BF16 model + if is_triton_kernels_available(): + # Use native MXFP4 + LoRA (keeps weights quantized) + Mxfp4GptOssExpertsClass._original_forward = Mxfp4GptOssExpertsClass.forward + Mxfp4GptOssExpertsClass.forward = forward_mxfp4_gpt_oss_with_lora + Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched MXFP4 GPT OSS MoE for LoRA training") + else: + # triton_kernels NOT available - do NOT patch + # The model will fail with a helpful error if user tries to use MXFP4 without dequantization + Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + "Unsloth: triton_kernels is not installed. MXFP4 GPT OSS will NOT be patched for LoRA.\n" + "To train GPT OSS with LoRA, either:\n" + " 1. Install triton_kernels from OpenAI (for native MXFP4), OR\n" + " 2. Use Mxfp4Config(dequantize=True) when loading (dequantizes to bf16), OR\n" + " 3. 
Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'" ) - # Register LoRA extractor - self._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor - @property - def hidden_dim(self): - return self.hidden_size +TEMPORARY_PATCHES.append(patch_mxfp4_gpt_oss_for_lora) - @property - def intermediate_dim(self): - return self.intermediate_size - def forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None - ) -> torch.Tensor: - """ - Forward pass using grouped GEMM from moe_utils. - Uses forward_native_grouped_mm which handles split LoRA automatically. - """ - # Call moe_utils grouped MM forward which handles everything - return forward_native_grouped_mm( - self, - hidden_states, - router_indices, # top_k_index - routing_weights, # top_k_weights +_MXFP4_DEQUANT_WARNED = False + + +def should_dequantize_mxfp4(): + """ + Check if MXFP4 should be dequantized to bf16. + + Returns True if: + - UNSLOTH_MXFP4_NO_DEQUANTIZE is not set or "0", OR + - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" but triton_kernels is not available + + Returns False if: + - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" AND triton_kernels is available + + MEMORY IMPACT: + - MXFP4 quantized: ~10GB for GPT-OSS 20B + - MXFP4 dequantized to bf16: ~40GB for GPT-OSS 20B + + To keep MXFP4 quantized (requires triton_kernels): + export UNSLOTH_MXFP4_NO_DEQUANTIZE=1 + """ + global _MXFP4_DEQUANT_WARNED + + if not UNSLOTH_MXFP4_NO_DEQUANTIZE: + # Default: dequantize to bf16 + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.warning( + "Unsloth: MXFP4 will be dequantized to bf16 (~4x memory increase). " + "To keep 4-bit: set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 and install triton_kernels." 
) + return True - # MLP wrapper that works with stacked experts - class GptOssMLP_BF16(nn.Module): - def __init__(self, config): - super().__init__() - self.router = GptOssTopKRouter(config) - self.experts = GptOssExpertsBF16Stacked(config) + if not is_triton_kernels_available(): + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.warning( + "Unsloth: UNSLOTH_MXFP4_NO_DEQUANTIZE=1 but triton_kernels not available. " + "Will dequantize MXFP4 to bf16 (~4x memory increase). " + "Install triton_kernels to keep 4-bit quantized weights." + ) + return True # triton_kernels required for native MXFP4 + + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.info("Unsloth: Keeping MXFP4 quantized (triton_kernels available)") + return False # Keep MXFP4 quantized + + +def patch_gpt_oss_linearized(): + """ + Patch GPT OSS for 4bit loading with grouped_mm support. + Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. 
+ """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if _should_use_gpt_oss_bnb4bit(): return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) - def forward(self, hidden_states): - bsz, qlen, hd = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, hd) + # Don't replace classes - just patch the forward method of GptOssExperts + # This keeps the original class structure which properly handles 4-bit weight loading + backend = select_moe_backend() - # Get router scores - router_scores, router_indices = self.router(hidden_states_flat) + if backend == "grouped_mm": - # Run experts with stacked weights - routed_out = self.experts( - hidden_states_flat, - router_indices=router_indices, - routing_weights=router_scores, + def experts_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + return forward_native_grouped_mm( + self, hidden_states, router_indices, routing_weights ) + else: - routed_out = routed_out.view(bsz, qlen, hd) - return routed_out, router_scores + def experts_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + return forward_native_moe_loop( + self, hidden_states, router_indices, routing_weights + ) - # Patch transformers module - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts._unsloth_lora_extractor_fn = _gpt_oss_lora_extractor + # Patch the original transformers class forward method + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward - # Check if model has stacked weights (BF16) vs ModuleList (4-bit) - # This is done at load time by checking the model structure if UNSLOTH_ENABLE_LOGGING: - logger.info("Unsloth: Patched GPT-OSS for BF16 split LoRA with grouped GEMM") + 
logger.info( + f"Unsloth: Patched GPT OSS MoE for 4bit loading (backend: {backend})" + ) + return pass -TEMPORARY_PATCHES.append(patch_gpt_oss_bf16_split_lora) +TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) def patch_GptOssAttention(): @@ -1133,17 +1857,20 @@ def patch_GptOssAttention(): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) + assert flex_attention_with_sink is not None except Exception as e: return raise_error("flex_attention_with_sink", e) try: import transformers.models.gpt_oss.modeling_gpt_oss + transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention", e) torch._dynamo.config.cache_size_limit = 256 + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -1158,6 +1885,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: F_softmax = torch.nn.functional.softmax F_dropout = nn.functional.dropout matmul = torch.matmul + def inplace_eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -1173,9 +1901,13 @@ def inplace_eager_attention_forward( bsz, n_heads, qlen, _ = query.shape bsz, n_heads, kvlen, _ = key_states.shape - combined_logits = key_states.new_empty((bsz, n_heads, qlen, kvlen+1)) + out_dtype = torch.result_type(query, key_states) + combined_logits = key_states.new_empty( + (bsz, n_heads, qlen, kvlen + 1), + dtype=out_dtype, + ) - attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) + attn_weights = matmul(query, key_states.transpose(2, 3), out=combined_logits[:, :, :, :kvlen]) attn_weights *= scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -1185,16 +1917,16 @@ def 
inplace_eager_attention_forward( # combined_logits = torch.cat([attn_weights, sinks], dim=-1) combined_logits[:, :, :, -1] = module.sinks.reshape(1, -1, 1) - # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 - # when training with bsz>1 we clamp max values. - # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) - attn_output = matmul(attn_weights, value_states, out = query) + attn_weights = F_dropout(scores, p=dropout, training=module.training) + attn_weights = attn_weights.to(value_states.dtype) + attn_output = matmul(attn_weights, value_states, out=query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None + pass def eager_attention_forward( @@ -1218,21 +1950,21 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) combined_logits = torch.cat([attn_weights, sinks], dim=-1) - # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 - # when training with bsz>1 we clamp max values. 
- # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) - attn_output = matmul(attn_weights, value_states, out = query) + attn_weights = F_dropout(scores, p=dropout, training=module.training) + attn_weights = attn_weights.to(value_states.dtype) + attn_output = matmul(attn_weights, value_states, out=query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None + pass apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) - if False: # Version(torch.__version__) >= Version("2.10.0"): - eager_attention_forward = torch_compile(eager_attention_forward, dynamic = None, fullgraph = True) + if False: # Version(torch.__version__) >= Version("2.10.0"): + eager_attention_forward = torch_compile(eager_attention_forward, dynamic=None, fullgraph=True) else: # Too many recompilation failures on 2.8.0, 2.9.0 eager_attention_forward = inplace_eager_attention_forward @@ -1246,7 +1978,7 @@ def forward_function( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -1254,8 +1986,24 @@ def forward_function( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: + cache_dtype = 
getattr(past_key_value, "dtype", None) + if cache_dtype is None and hasattr(past_key_value, "layers"): + try: + cache_layer = past_key_value.layers[self.layer_idx] + if hasattr(cache_layer, "keys") and cache_layer.keys is not None: + cache_dtype = cache_layer.keys.dtype + except Exception: + cache_dtype = None + if cache_dtype is not None and key_states.dtype != cache_dtype: + key_states = key_states.to(cache_dtype) + value_states = value_states.to(cache_dtype) cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + if key_states.dtype != query_states.dtype: + key_states = key_states.to(query_states.dtype) + value_states = value_states.to(query_states.dtype) # flex_attention_with_sink only works for training since KV cache is wrong # switch to flex_attention_with_sink which allows all to work @@ -1307,9 +2055,11 @@ def forward_function( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + pass functions = [] + def forward( self, hidden_states: torch.Tensor, @@ -1320,7 +2070,9 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs) + functions.append(forward) + def forward( self, hidden_states: torch.Tensor, @@ -1331,10 +2083,13 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs) + functions.append(forward) 
patch_function_past_key_values(transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention, "forward", functions) # Set env variable for padding purposes os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "1" + + pass TEMPORARY_PATCHES.append(patch_GptOssAttention) @@ -1344,6 +2099,7 @@ def patch_GptOssModel(): if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return try: import transformers.models.gpt_oss.modeling_gpt_oss + transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel from transformers.models.gpt_oss.modeling_gpt_oss import MoeModelOutputWithPast from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb @@ -1360,6 +2116,7 @@ def patch_GptOssModel(): # Disable mask creations since we don't need them for GPT-OSS import transformers.masking_utils import transformers.generation.utils + def wrap(f): def return_attention_mask(*args, **kwargs): if kwargs["input_embeds"].requires_grad: @@ -1372,7 +2129,9 @@ def return_attention_mask(*args, **kwargs): # Eager return f(*args, **kwargs) pass + return return_attention_mask + pass create_causal_mask = getattr( transformers.masking_utils, @@ -1389,8 +2148,8 @@ def return_attention_mask(*args, **kwargs): if create_sliding_window_causal_mask is None: return raise_error("transformers.masking_utils.create_sliding_window_causal_mask") if not hasattr(transformers.masking_utils, "__patched_causal_mask__"): - transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph = False, dynamic = True) - transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph = False, dynamic = True) + transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph=False, dynamic=True) + transformers.masking_utils._old_create_sliding_window_causal_mask = 
_torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph=False, dynamic=True) transformers.masking_utils.create_causal_mask = wrap(create_causal_mask) transformers.masking_utils.create_sliding_window_causal_mask = wrap(create_sliding_window_causal_mask) transformers.models.gpt_oss.modeling_gpt_oss.create_causal_mask = transformers.masking_utils.create_causal_mask @@ -1405,6 +2164,7 @@ def return_attention_mask(*args, **kwargs): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) + apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) try: from transformers.integrations.mxfp4 import mlp_forward @@ -1420,7 +2180,7 @@ def pre_attention_decoding( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ): - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -1431,7 +2191,9 @@ def pre_attention_decoding( cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) return query_states, key_states, value_states, input_shape + pass + # Do flex_attention_with_sink_decoding with cannot be compiled # attn_output, logsumexp = flex_attention_with_sink_decoding( # self, @@ -1444,6 +2206,7 @@ def post_attention_decoding(self_attn, attn_output, logsumexp, input_shape): attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self_attn.o_proj(attn_output) return attn_output + pass # RMSNorm forward @@ -1455,6 +2218,7 @@ def rms_layernorm_forward(self, hidden_states): hidden_states *= torch.rsqrt_(variance) hidden_states *= self.weight.to(hidden_states.device).to(torch.float32) return hidden_states.to(input_dtype) # main diff with Llama + pass # Re-compiling for each new 
sequence length which is NOT ideal @@ -1482,19 +2246,21 @@ def pre_forward( position_embeddings=position_embeddings, ) return query_states, key_states, value_states, input_shape + pass fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion = True, - max_autotune = False, # Too slow - shape_padding = True, - cudagraphs = True, - coordinate_descent_tuning = False, - combo_kernels = False, - memory_planning = True, - multi_kernel = False, # Fails on torch 2.10 nightly - use_block_ptr = True, - logging = UNSLOTH_ENABLE_LOGGING, + epilogue_fusion=True, + max_autotune=False, # Too slow + shape_padding=True, + cudagraphs=True, + coordinate_descent_tuning=False, + combo_kernels=False, + memory_planning=True, + multi_kernel=False, # Fails on torch 2.10 nightly + use_block_ptr=True, + logging=UNSLOTH_ENABLE_LOGGING, ) + @_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) def post_forward( self, @@ -1503,13 +2269,16 @@ def post_forward( logsumexp: torch.Tensor, input_shape, ): - hidden_states = post_attention_decoding(self.self_attn, attn_output, logsumexp, input_shape) + hidden_states = post_attention_decoding( + self.self_attn, attn_output, logsumexp, input_shape + ) hidden_states += residual # Fully Connected residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual + pass def inference_forward( @@ -1542,6 +2311,7 @@ def inference_forward( residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual + pass # if has_static_cache and Version(torch.__version__) >= Version("2.10.0"): # # torch 2.9.0 has excessive compilations @@ -1567,24 +2337,24 @@ def forward( if inputs_embeds is None: # Account for CPU offloaded embed_tokens embed_device = self.embed_tokens.weight.device - inputs_embeds = self.embed_tokens(input_ids.to(embed_device, 
non_blocking = True)).to(input_ids.device) + inputs_embeds = self.embed_tokens( + input_ids.to(embed_device, non_blocking=True) + ).to(input_ids.device) if not self.training: inputs_embeds.requires_grad_(False) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) try: - torch._dynamo.mark_static (hidden_states, 0) + torch._dynamo.mark_static(hidden_states, 0) torch._dynamo.mark_dynamic(hidden_states, 1) - torch._dynamo.mark_static (hidden_states, 2) + torch._dynamo.mark_static(hidden_states, 2) except: pass @@ -1608,6 +2378,8 @@ def forward( # Add hack since residuals need to clone outside of the torch.compile region?? 
# This forces it to free past residuals torch.compiler.cudagraph_mark_step_begin() + # Initialize for common return path + all_hidden_states = None for decoder_layer in self.layers: hidden_states, residual = inference_forward( decoder_layer, @@ -1620,9 +2392,14 @@ def forward( position_embeddings, **kwargs, ) - if hasattr(decoder_layer.mlp.experts, "gate_up_projs"): - hidden_states = moe_forward_inference(decoder_layer.mlp, hidden_states) - elif decoder_layer.mlp.experts.__class__.__name__ == "Mxfp4GptOssExperts": + _actual_experts = _unwrap_peft_experts(decoder_layer.mlp.experts) + if hasattr(_actual_experts, "gate_up_projs"): + hidden_states = moe_forward_inference( + decoder_layer.mlp, hidden_states + ) + elif ( + _actual_experts.__class__.__name__ == "Mxfp4GptOssExperts" + ): if mlp_forward is None: raise RuntimeError("Unsloth: MXFP4 forward is not found") hidden_states, _ = mlp_forward(decoder_layer.mlp, hidden_states) @@ -1632,8 +2409,36 @@ def forward( pass hidden_states = rms_layernorm_forward(self.norm, hidden_states) else: + # Fix 2D attention_mask being passed to eager_attention_forward which expects 4D + if ( + self.training + and attention_mask is not None + and attention_mask.dim() == 2 + ): + bsz, seq_len = attention_mask.shape + min_dtype = torch.finfo(inputs_embeds.dtype).min + # 1. Expand padding mask to (B, 1, 1, S) + expanded_mask = attention_mask[:, None, None, :].to( + dtype=inputs_embeds.dtype + ) + expanded_mask = (1.0 - expanded_mask) * min_dtype + + # 2. 
Causal mask (1, 1, S, S) + causal_mask = torch.full((seq_len, seq_len), min_dtype, device=attention_mask.device, dtype=inputs_embeds.dtype) + causal_mask = torch.triu(causal_mask, diagonal=1) + attention_mask = causal_mask[None, None, :, :] + expanded_mask + + # Accumulate hidden states if requested + output_hidden_states = kwargs.get( + "output_hidden_states", self.config.output_hidden_states + ) + all_hidden_states = () if output_hidden_states else None + for decoder_layer in self.layers: - mask = attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + if output_hidden_states: + all_hidden_states += (hidden_states,) + + mask = (attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask) hidden_states = decoder_layer( hidden_states, attention_mask=mask, @@ -1646,13 +2451,24 @@ def forward( ) pass hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + # Fix float16 / float32 mismatching hidden_states = hidden_states.to(inputs_embeds.dtype) - return process_return(MoeModelOutputWithPast, { - "last_hidden_state" : hidden_states, - "past_key_values" : past_key_values, - }) - patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") + return process_return( + MoeModelOutputWithPast, + { + "last_hidden_state": hidden_states, + "past_key_values": past_key_values, + "hidden_states": all_hidden_states, + }, + ) + + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") + + pass TEMPORARY_PATCHES.append(patch_GptOssModel) @@ -1667,11 +2483,14 @@ def forward( SystemContent, ToolDescription, load_harmony_encoding, - ReasoningEffort + ReasoningEffort, ) + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) except: pass + + def encode_conversations_with_harmony( messages, reasoning_effort = 
"medium", @@ -1688,22 +2507,24 @@ def encode_conversations_with_harmony( assert reasoning_effort in ("low", "medium", "high") match reasoning_effort: - case "low": harmony_reasoning = ReasoningEffort.LOW + case "low": harmony_reasoning = ReasoningEffort.LOW case "medium": harmony_reasoning = ReasoningEffort.MEDIUM - case "high": harmony_reasoning = ReasoningEffort.HIGH + case "high": harmony_reasoning = ReasoningEffort.HIGH convos = [] # Create system message import datetime + today = datetime.datetime.today().strftime("%Y-%m-%d") - system = Message.from_role_and_content(Role.SYSTEM, + system = Message.from_role_and_content( + Role.SYSTEM, SystemContent.new() .with_model_identity(model_identity) .with_reasoning_effort(harmony_reasoning) .with_conversation_start_date(today) .with_knowledge_cutoff("2024-06") - .with_required_channels(["analysis", "commentary", "final"]) + .with_required_channels(["analysis", "commentary", "final"]), ) convos.append(system) @@ -1727,27 +2548,21 @@ def encode_conversations_with_harmony( for message in messages: if message["role"] == "user": - convos.append( - Message.from_role_and_content(Role.USER, message["content"]) - ) + convos.append(Message.from_role_and_content(Role.USER, message["content"])) elif message["role"] == "assistant": if "thinking" in message: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("analysis") elif "tool_calls" in message: - x = Message.from_role_and_content(Role.ASSISTANT, message['tool_calls'][0]["arguments"]) - x = x.with_channel("commentary")\ - .with_recipient(f"functions.{message['tool_calls'][0]['name']}")\ - .with_content_type("json") + x = Message.from_role_and_content(Role.ASSISTANT, message["tool_calls"][0]["arguments"]) + x = x.with_channel("commentary").with_recipient(f"functions.{message['tool_calls'][0]['name']}").with_content_type("json") else: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("final") 
convos.append(x) elif message["role"] == "tool": - x = Message.from_author_and_content( - Author.new(Role.TOOL, f"functions.{message['name']}"), - message["content"], - ).with_recipient("assistant").with_channel("commentary") + x = Message.from_author_and_content(Author.new(Role.TOOL, f"functions.{message['name']}"), message["content"]) + x = x.with_recipient("assistant").with_channel("commentary") convos.append(x) pass @@ -1759,6 +2574,8 @@ def encode_conversations_with_harmony( harmony_input_ids = encoding.render_conversation(convos) harmony_decoded_text = encoding.decode(harmony_input_ids) return harmony_decoded_text, harmony_input_ids + + pass @@ -1767,8 +2584,10 @@ def encode_conversations_with_harmony( # AutoConfig error: 'GptOssConfig' object has no attribute 'max_position_embeddings' try: from transformers.configuration_utils import layer_type_validation + try: from transformers.configuration_utils import PreTrainedConfig + PretrainedConfig = PreTrainedConfig except: from transformers.configuration_utils import PretrainedConfig @@ -1850,9 +2669,7 @@ def __init__( self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads self.layer_types = layer_types if self.layer_types is None: - self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) - ] + self.layer_types = ["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)] layer_type_validation(self.layer_types) # Validate the correctness of rotary position embeddings parameters @@ -1915,7 +2732,13 @@ def __init__( initializer_range: float = 0.02, max_position_embeddings=131072, rms_norm_eps: float = 1e-5, - rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, + rope_scaling={ + "rope_type": "yarn", + "factor": 32.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "truncate": False, + }, attention_dropout: 
float = 0.0, num_experts_per_tok=4, router_aux_loss_coef: float = 0.9, @@ -1943,11 +2766,16 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout - self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.head_dim = ( + head_dim + if head_dim is not None + else self.hidden_size // self.num_attention_heads + ) self.layer_types = layer_types if self.layer_types is None: self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + "sliding_attention" if bool((i + 1) % 2) else "full_attention" + for i in range(self.num_hidden_layers) ] layer_type_validation(self.layer_types) self.attention_bias = True @@ -1975,6 +2803,7 @@ def __init__( def patch_gpt_oss_config(): try: import transformers.models.gpt_oss.configuration_gpt_oss + transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) @@ -1988,7 +2817,46 @@ def patch_gpt_oss_config(): patch_function(transformers.models.gpt_oss.configuration_gpt_oss, "GptOssConfig", GptOssConfig) except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) + pass TEMPORARY_PATCHES.append(patch_gpt_oss_config) except Exception as e: raise_error("transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig", e) + + +def patch_gpt_oss_init_weights_modulelist_fix(): + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + GptOssPreTrainedModel = ( + transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel + ) + GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + if getattr(GptOssPreTrainedModel, 
"_unsloth_init_weights_fixed", False): + return + _original_init_weights = GptOssPreTrainedModel._init_weights + + def _patched_init_weights(self, module): + if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): + std = self.config.initializer_range + for up in getattr(module, "gate_up_projs", []): + init.normal_(up.weight, mean=0.0, std=std) + if up.bias is not None: + init.zeros_(up.bias) + for down in getattr(module, "down_projs", []): + init.normal_(down.weight, mean=0.0, std=std) + if down.bias is not None: + init.zeros_(down.bias) + return + _original_init_weights(self, module) + + patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) + GptOssPreTrainedModel._unsloth_init_weights_fixed = True + + +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 5760778a2..89171e387 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -19,6 +19,7 @@ import shutil from typing import Optional, Tuple from torch.autograd import Function +from .utils import logger # Get compile location UNSLOTH_COMPILE_LOCATION = os.environ.get( @@ -711,6 +712,8 @@ def forward_native_grouped_mm( """ # This Unsloth Zoo code section is licensed under AGPL3 + logger.info(f"[DEBUG]Using torch._grouped_mm for MoE forward pass") + # Runtime safety check - defense in depth if not _check_torch_grouped_mm_supported(): major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) From 27af91c64b3613592fada9fafcf95ea28f5b924e Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Fri, 6 Feb 2026 14:35:55 -0500 Subject: [PATCH 03/17] Update gpt_oss.py, make it transformers 4 compatible --- unsloth_zoo/temporary_patches/gpt_oss.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git 
a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index c60c64f21..b135fbbe2 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -905,8 +905,10 @@ def forward(self, hidden_states): router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( 1, router_indices, router_top_value ) - return router_logits, router_scores, router_indices - + if _is_transformers_v5(): + return router_logits, router_scores, router_indices + else: + return router_scores, router_indices pass @@ -1073,7 +1075,10 @@ def forward(self, hidden_states): dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) - return router_logits, router_scores, router_indices + if _is_transformers_v5(): + return router_logits, router_scores, router_indices + else: + return router_scores, router_indices pass From b032d2248e0da733efdea20f95080f9af23d3c66 Mon Sep 17 00:00:00 2001 From: pluesclues <136766175+pluesclues@users.noreply.github.com> Date: Fri, 6 Feb 2026 15:01:31 -0500 Subject: [PATCH 04/17] Update gpt_oss.py, needed if statement --- unsloth_zoo/temporary_patches/gpt_oss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index b135fbbe2..727455df0 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -905,7 +905,7 @@ def forward(self, hidden_states): router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( 1, router_indices, router_top_value ) - if _is_transformers_v5(): + if transformers_version >= Version("5.0.0"): return router_logits, router_scores, router_indices else: 
return router_scores, router_indices @@ -1075,7 +1075,7 @@ def forward(self, hidden_states): dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) - if _is_transformers_v5(): + if transformers_version >= Version("5.0.0"): return router_logits, router_scores, router_indices else: return router_scores, router_indices From b1bd22a3c6bdfaca58bd0198a2b56cb317e55d28 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 6 Feb 2026 17:14:48 +0000 Subject: [PATCH 05/17] undo spacing --- unsloth_zoo/temporary_patches/gpt_oss.py | 117 +++++++++-------------- 1 file changed, 47 insertions(+), 70 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 727455df0..6cfddd559 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -183,7 +183,6 @@ def is_kernels_available(): def swizzle_mxfp4(w, w_scale, *args, **kwargs): from triton_kernels import tensor, tensor_details - FP4, convert_layout, wrap_torch_tensor = ( tensor.FP4, tensor.convert_layout, @@ -192,12 +191,8 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): layout = tensor_details.layout StridedLayout = tensor_details.layout.StridedLayout - value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( - mx_axis=1 - ) - w = convert_layout( - wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts - ) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) # TODO : add that when we are actually sure that it works on B200 # if torch.cuda.get_device_capability()[0] == 10: # constraints = { @@ -231,9 +226,7 @@ def forward( 
scatter_idx, ): pre_activation = matmul_ogs( - hidden_states.to( - torch.bfloat16 - ), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware self_class.gate_up_proj, self_class.gate_up_proj_bias, routing_data, @@ -272,7 +265,6 @@ def forward( ctx.scatter_idx = scatter_idx ctx.routing_data = routing_data return out - pass @staticmethod @@ -305,7 +297,6 @@ def backward(ctx, grad_token): dx_token = torch.zeros_like(grad_token) dx_token.index_add_(0, gather_dst, dx_exp) return (dx_token, None, None, None, None,) - pass pass @@ -893,12 +884,8 @@ def bias(self, value): def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.linear( - hidden_states.to(self.linear.weight.dtype) - ) # (batch_size * seq_len, num_experts) - router_top_value, router_indices = torch.topk( - router_logits, self.top_k, dim=-1 - ) # (seq_len, top_k) + router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax( router_top_value, dim=1, dtype=router_top_value.dtype ) @@ -1173,30 +1160,30 @@ def patch_gpt_oss_bnb4bit_auto(): device_memory = torch.xpu.memory.mem_get_info(0)[-1] else: device_memory = torch.cuda.memory.mem_get_info(0)[-1] -use_combo_kernels = False if device_memory / 1024 / 1024 / 1024 <= 40 else True +use_combo_kernels = False if device_memory/1024/1024/1024 <= 40 else True fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion=True, - max_autotune=False, # Too slow - shape_padding=True, - cudagraphs=True, - coordinate_descent_tuning=use_combo_kernels, # Very slow! 
- combo_kernels=use_combo_kernels, - memory_planning=True, - multi_kernel=False, # Fails on torch 2.10 nightly - use_block_ptr=True, - logging=UNSLOTH_ENABLE_LOGGING, + epilogue_fusion = True, + max_autotune = False, # Too slow + shape_padding = True, + cudagraphs = True, + coordinate_descent_tuning = use_combo_kernels, # Very slow! + combo_kernels = use_combo_kernels, + memory_planning = True, + multi_kernel = False, # Fails on torch 2.10 nightly + use_block_ptr = True, + logging = UNSLOTH_ENABLE_LOGGING, ) no_combo_fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion=True, - max_autotune=False, # Too slow - shape_padding=True, - cudagraphs=True, - coordinate_descent_tuning=use_combo_kernels, # Very slow! - combo_kernels=False, # Breaks on attention - memory_planning=True, - multi_kernel=False, # Fails on torch 2.10 nightly - use_block_ptr=True, - logging=UNSLOTH_ENABLE_LOGGING, + epilogue_fusion = True, + max_autotune = False, # Too slow + shape_padding = True, + cudagraphs = True, + coordinate_descent_tuning = use_combo_kernels, # Very slow! 
+ combo_kernels = False, # Breaks on attention + memory_planning = True, + multi_kernel = False, # Fails on torch 2.10 nightly + use_block_ptr = True, + logging = UNSLOTH_ENABLE_LOGGING, ) @@ -1278,21 +1265,13 @@ def moe_forward_inference(self, hidden_states): @torch_compile(dynamic=True, fullgraph=True) def moe_router_forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear( - hidden_states.to(self.weight.dtype), self.weight, self.bias - ) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk( - router_logits, self.top_k, dim=-1 - ) # (seq_len, top_k) + router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) # (seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) dtype = ( torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype ) - router_top_value = torch.nn.functional.softmax( - router_top_value, dim=1, dtype=torch.float32 - ).to(dtype) - router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_( - 1, router_indices, router_top_value - ) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) + router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) return router_scores, router_indices @@ -1912,7 +1891,7 @@ def inplace_eager_attention_forward( dtype=out_dtype, ) - attn_weights = matmul(query, key_states.transpose(2, 3), out=combined_logits[:, :, :, :kvlen]) + attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) attn_weights *= scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] @@ -1928,7 +1907,7 @@ def inplace_eager_attention_forward( scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training) 
attn_weights = attn_weights.to(value_states.dtype) - attn_output = matmul(attn_weights, value_states, out=query) + attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None @@ -1961,7 +1940,7 @@ def eager_attention_forward( scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training) attn_weights = attn_weights.to(value_states.dtype) - attn_output = matmul(attn_weights, value_states, out=query) + attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None @@ -1983,7 +1962,7 @@ def forward_function( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, torch.Tensor]: - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -2134,7 +2113,6 @@ def return_attention_mask(*args, **kwargs): # Eager return f(*args, **kwargs) pass - return return_attention_mask pass @@ -2185,7 +2163,7 @@ def pre_attention_decoding( cache_position: Optional[torch.LongTensor] = None, **kwargs: KWARGS_TYPE, ): - input_shape = hidden_states.shape[:-1] + input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -2254,16 +2232,16 @@ def pre_forward( pass fused_torch_compile_options = get_torch_compile_options( - epilogue_fusion=True, - max_autotune=False, # Too slow - shape_padding=True, - cudagraphs=True, - coordinate_descent_tuning=False, - combo_kernels=False, - memory_planning=True, - multi_kernel=False, # Fails on torch 
2.10 nightly - use_block_ptr=True, - logging=UNSLOTH_ENABLE_LOGGING, + epilogue_fusion = True, + max_autotune = False, # Too slow + shape_padding = True, + cudagraphs = True, + coordinate_descent_tuning = False, + combo_kernels = False, + memory_planning = True, + multi_kernel = False, # Fails on torch 2.10 nightly + use_block_ptr = True, + logging = UNSLOTH_ENABLE_LOGGING, ) @_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) @@ -2274,9 +2252,7 @@ def post_forward( logsumexp: torch.Tensor, input_shape, ): - hidden_states = post_attention_decoding( - self.self_attn, attn_output, logsumexp, input_shape - ) + hidden_states = post_attention_decoding(self.self_attn, attn_output, logsumexp, input_shape) hidden_states += residual # Fully Connected @@ -2865,3 +2841,4 @@ def _patched_init_weights(self, module): pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) + From 46376a0b218d945267c49e8d2f2b676e677a2f62 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sat, 7 Feb 2026 11:23:02 +0000 Subject: [PATCH 06/17] remove logger --- unsloth_zoo/temporary_patches/moe_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index 89171e387..5760778a2 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -19,7 +19,6 @@ import shutil from typing import Optional, Tuple from torch.autograd import Function -from .utils import logger # Get compile location UNSLOTH_COMPILE_LOCATION = os.environ.get( @@ -712,8 +711,6 @@ def forward_native_grouped_mm( """ # This Unsloth Zoo code section is licensed under AGPL3 - logger.info(f"[DEBUG]Using torch._grouped_mm for MoE forward pass") - # Runtime safety check - defense in depth if not _check_torch_grouped_mm_supported(): major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) From 
d27d339225ba88b5448272c6c9cc9d2694e3cc6b Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Sat, 7 Feb 2026 21:08:21 +0000 Subject: [PATCH 07/17] Fix tokenizer guard, ModernBERT attention, gpt_oss MoE unwrap Fixes notebook failures for transformers 4.57.6 + TRL 0.22-0.27. Tokenizer None guard (tokenizer_utils.py): - Return early from patch_tokenizer if tokenizer is None (some VLM processors like ERNIE VL may have None tokenizer during loading) - Guard inner tokenizer unwrap when processor.tokenizer is None ModernBERT attention mask fix (temporary_patches/misc.py): - Add patch_modernbert_attention_mask() to fix stride alignment issues in SDPA backward pass with torch.compile - The _update_attention_mask uses .expand() which creates non-contiguous strides not aligned to multiples of 4, causing reinterpret_tensor errors in the inductor backward graph - Fix: make masks contiguous before they enter compiled regions gpt_oss ParamWrapper unwrap (temporary_patches/gpt_oss.py): - Unwrap PEFT ParamWrapper from MoE experts before accessing hidden_size attribute in both GptOssMLP.forward() and the model inference forward path - ParamWrapper (from peft.tuners.lora.layer) wraps nn.Parameter via base_layer attribute; check base_layer, module, _module in order Tested with all 125 notebooks: no regressions on TRL 0.22.2 or 0.27.1. 
--- unsloth_zoo/temporary_patches/gpt_oss.py | 17 +++++++++++ unsloth_zoo/temporary_patches/misc.py | 37 ++++++++++++++++++++++++ unsloth_zoo/tokenizer_utils.py | 12 +++++++- 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 3a3cbcc5f..41c5f4206 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -641,6 +641,15 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): + # Unwrap ParamWrapper from experts if needed (PEFT LoRA wraps modules) + if not hasattr(self.experts, "hidden_size"): + _e = self.experts + for _attr in ("base_layer", "module", "_module"): + while not hasattr(_e, "hidden_size") and hasattr(_e, _attr): + _e = getattr(_e, _attr) + if _e is not self.experts: + self.experts = _e + bsz, qlen, hd = hidden_states.shape if qlen == 1 and not self.training: return moe_forward_inference(self, hidden_states), None @@ -1233,6 +1242,14 @@ def forward( position_embeddings, **kwargs, ) + # Unwrap ParamWrapper from experts if needed (PEFT LoRA wraps modules) + _experts = decoder_layer.mlp.experts + for _attr in ("base_layer", "module", "_module"): + while not hasattr(_experts, "hidden_size") and hasattr(_experts, _attr): + _experts = getattr(_experts, _attr) + if _experts is not decoder_layer.mlp.experts: + decoder_layer.mlp.experts = _experts + if hasattr(decoder_layer.mlp.experts, "gate_up_projs"): hidden_states = moe_forward_inference(decoder_layer.mlp, hidden_states) elif decoder_layer.mlp.experts.__class__.__name__ == "Mxfp4GptOssExperts": diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index 806252cd3..a9b0db796 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -489,6 +489,43 @@ def return_attention_mask(*args, **kwargs): TEMPORARY_PATCHES.append(patch_transformers_masks) 
+def patch_modernbert_attention_mask(): + """Fix ModernBERT attn_bias stride alignment for SDPA backward pass. + + The attention mask created by _prepare_4d_attention_mask uses .expand() + which creates non-contiguous strides. The SDPA compiled backward kernel + requires strides to be multiples of 4. Fix: patch _update_attention_mask + on ModernBertModel to return contiguous masks BEFORE they enter + torch.compile regions, so the inductor backward graph uses aligned strides. + """ + try: + import transformers.models.modernbert.modeling_modernbert as modernbert_module + except Exception: + return # ModernBERT not available, skip + + ModernBertModel = getattr(modernbert_module, "ModernBertModel", None) + if ModernBertModel is None: + return + + original_update = getattr(ModernBertModel, "_update_attention_mask", None) + if original_update is None: + return + + def _update_attention_mask_contiguous(self, attention_mask, output_attentions=False): + global_attention_mask, sliding_window_mask = original_update(self, attention_mask, output_attentions=output_attentions) + # Make masks contiguous so SDPA backward (including compiled graphs) + # gets strides that are multiples of 4 + if global_attention_mask is not None and not global_attention_mask.is_contiguous(): + global_attention_mask = global_attention_mask.contiguous() + if sliding_window_mask is not None and not sliding_window_mask.is_contiguous(): + sliding_window_mask = sliding_window_mask.contiguous() + return global_attention_mask, sliding_window_mask + + ModernBertModel._update_attention_mask = _update_attention_mask_contiguous +pass +TEMPORARY_PATCHES.append(patch_modernbert_attention_mask) + + def patch_CsmForConditionalGeneration_merge(): try: import transformers.models.csm.modeling_csm diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 9a05cc9ad..f097b7bcb 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -482,6 +482,11 @@ def 
patch_tokenizer(model, tokenizer): Fixes https://github.com/unslothai/unsloth/issues/5 """ # All Unsloth Zoo code licensed under LGPLv3 + + # Guard against None tokenizer (e.g., some VLM processors without tokenizer) + if tokenizer is None: + return model, tokenizer + joiner = "\1\0=+=\0\1" number_repetitions = 3 - 1 # Number of reserved tokens needed @@ -492,7 +497,12 @@ def patch_tokenizer(model, tokenizer): if hasattr(tokenizer, "image_processor") and hasattr(tokenizer, "apply_chat_template"): patch_processor_call(tokenizer) - if hasattr(tokenizer, "tokenizer"): tokenizer = tokenizer.tokenizer + if hasattr(tokenizer, "tokenizer"): + inner = tokenizer.tokenizer + if inner is None: + # Processor exists but inner tokenizer is None - return as-is + return model, original_tokenizer + tokenizer = inner bad_pad_token = False if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None: From 4e2db2d978db44fb0ffb7ab6ee3e7f765934838e Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 8 Feb 2026 10:27:07 +0000 Subject: [PATCH 08/17] dtype cast --- unsloth_zoo/temporary_patches/gpt_oss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 6cfddd559..68dd79dd2 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1985,7 +1985,7 @@ def forward_function( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) - if key_states.dtype != query_states.dtype: + if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: key_states = key_states.to(query_states.dtype) value_states = value_states.to(query_states.dtype) @@ -2841,4 +2841,3 @@ def _patched_init_weights(self, module): pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) - From 2004c5871a8f3d8493c7f6f27c95889aee7e99d1 Mon Sep 17 00:00:00 2001 From: Datta 
Nimmaturi Date: Sun, 8 Feb 2026 10:46:26 +0000 Subject: [PATCH 09/17] dtype cast --- unsloth_zoo/temporary_patches/gpt_oss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 68dd79dd2..d8fae2708 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -1985,9 +1985,9 @@ def forward_function( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) - if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: - key_states = key_states.to(query_states.dtype) - value_states = value_states.to(query_states.dtype) + if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: + key_states = key_states.to(query_states.dtype) + value_states = value_states.to(query_states.dtype) # flex_attention_with_sink only works for training since KV cache is wrong # switch to flex_attention_with_sink which allows all to work From e22008be1e276efdfe85fd7e5b95cb2de883765a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 8 Feb 2026 11:03:20 +0000 Subject: [PATCH 10/17] fix gpt oss grpo --- unsloth_zoo/temporary_patches/gpt_oss.py | 104 +++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index d8fae2708..8d43dcfc0 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -2841,3 +2841,107 @@ def _patched_init_weights(self, module): pass TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) + + +# ============================================================================ +# Patch GptOssForCausalLM.forward for GRPO training +# When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits +# 
============================================================================ +def patch_gpt_oss_for_grpo(): + """ + Patch GptOssForCausalLM.forward for GRPO training. + When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. + This fixes the matrix multiplication dimension mismatch issue in GRPO training. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + + try: + import transformers.models.gpt_oss.modeling_gpt_oss + from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssForCausalLM, + MoeCausalLMOutputWithPast, + ) + + _original_causal_lm_forward = GptOssForCausalLM.forward + + def _patched_causal_lm_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + logits_to_keep=0, + **kwargs, + ): + # This Unsloth Zoo code section is licensed under AGPL3 + + RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1" + + if not RETURN_HIDDEN_STATES: + # Normal forward pass + return _original_causal_lm_forward( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # RETURN_HIDDEN_STATES mode - return hidden_states instead of logits + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state 
+ + # Apply slice_indices to hidden_states (same indexing as for logits) + if logits_to_keep != 0: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + hidden_states = hidden_states[:, slice_indices, :] + + # Return hidden_states as "logits" for GRPO to use + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=getattr(outputs, 'aux_loss', None), + logits=hidden_states, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=getattr(outputs, 'router_logits', None), + ) + + GptOssForCausalLM.forward = _patched_causal_lm_forward + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched GptOssForCausalLM.forward for GRPO hidden states.") + + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch GptOssForCausalLM.forward: {e}") + + +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_for_grpo) From e401797a59ce79eaa60e7ac96928fefde2e3de3d Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Sun, 8 Feb 2026 12:06:14 +0000 Subject: [PATCH 11/17] Add causal_conv1d CUDA probe with dynamic sys.modules scan Probe causal_conv1d CUDA kernels at startup and force the PyTorch slow path when they fail (e.g. sm_100 on B200). Uses identity checks against the original function objects to avoid clobbering vllm's independent Triton-based implementations. Dynamically scans sys.modules instead of hardcoding model module lists, so new models like qwen3_next and mamba_ssm are automatically covered. 
--- unsloth_zoo/temporary_patches/misc.py | 148 ++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index a9b0db796..4aeaf36d2 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -620,6 +620,106 @@ def _merge_input_ids_with_input_values( TEMPORARY_PATCHES.append(patch_CsmForConditionalGeneration_merge) +def patch_causal_conv1d_cuda_probe(): + """Probe causal_conv1d CUDA kernels and force slow path if broken. + + On GPUs whose compute capability is not supported by pre-built causal_conv1d + CUDA kernels (e.g. sm_100 on B200), `import causal_conv1d` succeeds but calling + `causal_conv1d_fn(...)` fails at runtime with "no kernel image is available". + This probe runs a tiny forward pass at startup to detect the failure, then + nullifies causal_conv1d_fn/causal_conv1d_update everywhere so all Mamba-family + models fall back to their pure-PyTorch slow paths. + """ + try: + import causal_conv1d + from causal_conv1d import causal_conv1d_fn + from causal_conv1d import causal_conv1d_update + except ImportError: + return # Package not installed, transformers already handles this + pass + + if causal_conv1d_fn is None: + return # Already nullified + pass + + if not torch.cuda.is_available(): + return + pass + + # Probe: try a tiny CUDA forward pass + try: + device = torch.device("cuda", torch.cuda.current_device()) + x = torch.randn(1, 4, 8, device=device, dtype=torch.float16) + w = torch.randn(4, 4, device=device, dtype=torch.float16) + b = torch.zeros(4, device=device, dtype=torch.float16) + _ = causal_conv1d_fn(x, w, b, activation="silu") + del x, w, b + return # CUDA kernels work fine + except Exception: + pass # Fall through to disable + pass + + print( + "Unsloth: causal_conv1d CUDA kernels not compatible with this GPU. " + "Using PyTorch slow path for Mamba models." + ) + + import sys + + # 1. 
Nullify the package exports themselves + for mod_name in ("causal_conv1d", "causal_conv1d.causal_conv1d_interface"): + mod = sys.modules.get(mod_name) + if mod is not None: + if hasattr(mod, "causal_conv1d_fn"): + mod.causal_conv1d_fn = None + if hasattr(mod, "causal_conv1d_update"): + mod.causal_conv1d_update = None + pass + pass + + # 2. Patch is_causal_conv1d_available to return False + try: + import transformers.utils.import_utils + transformers.utils.import_utils.is_causal_conv1d_available = lambda: False + except Exception: + pass + pass + + # 3. Dynamically scan all loaded modules and nullify broken causal_conv1d + # references. Uses identity checks (is) against the original function objects + # to avoid clobbering vllm's independent Triton-based causal_conv1d_fn/update. + _original_fn = causal_conv1d_fn + _original_update = causal_conv1d_update + + def _disabled_lazy_load(): + return (None, None) + pass + + for mod in list(sys.modules.values()): + if mod is None: + continue + # Only nullify references that point to the causal_conv1d package's functions + touched = False + if getattr(mod, "causal_conv1d_fn", None) is _original_fn: + mod.causal_conv1d_fn = None + touched = True + if getattr(mod, "causal_conv1d_update", None) is _original_update: + mod.causal_conv1d_update = None + touched = True + # is_fast_path_available = all((causal_conv1d_fn, ...)) -- must be False + # Only touch it on modules where we just nullified causal_conv1d refs + if touched and getattr(mod, "is_fast_path_available", False): + mod.is_fast_path_available = False + # Replace lazy load stubs (Pattern B: mamba, falcon_mamba) + if hasattr(mod, "_lazy_load_causal_conv1d"): + mod._lazy_load_causal_conv1d = _disabled_lazy_load + if hasattr(mod, "_causal_conv1d_cache"): + mod._causal_conv1d_cache = (None, None) + pass +pass +TEMPORARY_PATCHES.append(patch_causal_conv1d_cuda_probe) + + def patch_GraniteMoeHybridMambaLayer_cuda_kernels_forward(): try: import 
transformers.models.granitemoehybrid.modeling_granitemoehybrid @@ -976,3 +1076,51 @@ def forward( patch_function(transformers.models.siglip.modeling_siglip.SiglipEncoderLayer, "forward", forward) pass TEMPORARY_PATCHES.append(patch_SiglipEncoderLayer) + + +def patch_Lfm2VlMultiModalProjector(): + """Fix Lfm2VlMultiModalProjector unconditionally creating LayerNorm. + + transformers 4.57.6 ignores config.projector_use_layernorm and always + creates nn.LayerNorm + applies it in forward. The model checkpoint for + LFM2.5-VL-1.6B has projector_use_layernorm=False and ships no layer_norm + weights, so the LayerNorm gets randomly initialized and corrupts features. + Fixed in transformers 5.0.0. This patch backports the fix. + """ + try: + import transformers.models.lfm2_vl.modeling_lfm2_vl as lfm2_vl_module + except Exception: + return + + Projector = getattr(lfm2_vl_module, "Lfm2VlMultiModalProjector", None) + if Projector is None: + return + + # Already patched or already has conditional logic (transformers >= 5.0.0) + if hasattr(Projector, "_unsloth_patched") or "use_layer_norm" in (getattr(Projector.__init__, "__code__", None) and Projector.__init__.__code__.co_varnames or ()): + return + + import torch.nn as nn + original_init = Projector.__init__ + original_forward = Projector.forward + + def patched_init(self, config, *args, **kwargs): + original_init(self, config, *args, **kwargs) + self.use_layer_norm = getattr(config, "projector_use_layernorm", True) + if not self.use_layer_norm: + self.layer_norm = None + + def patched_forward(self, image_features): + image_features = self.pixel_unshuffle(image_features) + if getattr(self, "use_layer_norm", True) and self.layer_norm is not None: + image_features = self.layer_norm(image_features) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + Projector.__init__ = patched_init + Projector.forward = patched_forward 
+ Projector._unsloth_patched = True +pass +TEMPORARY_PATCHES.append(patch_Lfm2VlMultiModalProjector) From 36014eb0a5fa0f4e0e542981592e24bd2a863048 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 00:26:04 +0000 Subject: [PATCH 12/17] Add transformers 5.0 compat: pad_token_id guard, BNB weight check, PEFT dispatch fix, FP8Linear device, VLM dataset tokenization, push_to_hub_token and DPO vision mapping patches - tokenizer_utils.py: Use getattr for pad_token_id to handle missing attr - bitsandbytes.py: Guard fix_4bit_weight on packed weight shape - misc.py: Add patch_peft_dispatch_bnb_4bit for compress_statistics AttributeError - misc.py: Add patch_trl_push_to_hub_token to ensure to_dict() includes it - misc.py: Add patch_trl_vision_model_mapping for DPO on TRL 0.22.x - vllm_utils.py: Version-gate FP8Linear device kwarg - dataset_utils.py: Add _maybe_tokenize_dataset for VLM skip_prepare_dataset --- unsloth_zoo/dataset_utils.py | 32 ++++++- unsloth_zoo/temporary_patches/bitsandbytes.py | 4 +- unsloth_zoo/temporary_patches/misc.py | 93 +++++++++++++++++++ unsloth_zoo/tokenizer_utils.py | 2 +- unsloth_zoo/vllm_utils.py | 7 +- 5 files changed, 134 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 57c7f3969..6bc6721c1 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -340,21 +340,50 @@ def _train_on_responses_only(examples): # Set it to int(memory_gb_left) so 16Gb = 16 num_proc = min(num_proc, int(memory_gb_left)) + # In transformers 5.0+, VLM models skip dataset preparation in SFTTrainer.__init__ + # (skip_prepare_dataset=True when _is_vlm=True). This means the dataset may not be + # tokenized yet. We need to tokenize it before applying _train_on_responses_only. 
+ def _maybe_tokenize_dataset(dataset): + if dataset is None: + return dataset + sample = next(iter(dataset)) + if "input_ids" in sample: + return dataset # Already tokenized + # Need to tokenize - get the processing class from trainer + _tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer + # Get the actual tokenizer (not processor) for tokenization + if hasattr(_tokenizer, "tokenizer"): + _tok = _tokenizer.tokenizer + else: + _tok = _tokenizer + max_length = getattr(trainer.args, "max_length", None) or getattr(trainer.args, "max_seq_length", 2048) + text_field = getattr(trainer.args, "dataset_text_field", "text") + def _tokenize_fn(examples): + texts = examples.get(text_field) or examples.get("text", []) + return _tok(texts, truncation=True, max_length=max_length, padding=False) + _map_kwargs = {"batched": True, "num_proc": num_proc} + if isinstance(dataset, IterableDataset): + _map_kwargs = {"batched": True} + return dataset.map(_tokenize_fn, **_map_kwargs) + pass + if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None: if not hasattr(trainer.train_dataset, "map"): raise TypeError("Unsloth: train_on_responses_only does not work on lists!") + trainer.train_dataset = _maybe_tokenize_dataset(trainer.train_dataset) if isinstance(trainer.train_dataset, IterableDataset): trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batch_size = trainer.train_dataset._ex_iterable.batch_size, batched = True) else: trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass - + if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None: # Eval datasets could be a dict! 
if type(trainer.eval_dataset) is dict: for key, value in trainer.eval_dataset.items(): if not hasattr(value, "map"): raise TypeError("Unsloth: train_on_responses_only does not work on lists!") + value = _maybe_tokenize_dataset(value) if isinstance(value, IterableDataset): trainer.eval_dataset[key] = value.map(_train_on_responses_only, batch_size = value._ex_iterable.batch_size, batched = True) else: @@ -362,6 +391,7 @@ def _train_on_responses_only(examples): else: if not hasattr(trainer.eval_dataset, "map"): raise TypeError("Unsloth: train_on_responses_only does not work on lists!") + trainer.eval_dataset = _maybe_tokenize_dataset(trainer.eval_dataset) if isinstance(trainer.eval_dataset, IterableDataset): trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True) else: diff --git a/unsloth_zoo/temporary_patches/bitsandbytes.py b/unsloth_zoo/temporary_patches/bitsandbytes.py index 38037b707..c83b64dfb 100644 --- a/unsloth_zoo/temporary_patches/bitsandbytes.py +++ b/unsloth_zoo/temporary_patches/bitsandbytes.py @@ -49,7 +49,9 @@ def patch_bitsandbytes_linear4bit_forward(): return raise_error("bitsandbytes.Linear4bit", e) def forward(self, x: torch.Tensor): - fix_4bit_weight_quant_state_from_module(self) + # In transformers 5.0+, weights may not be in packed format yet during init + if self.weight.shape[-1] == 1: + fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index 4aeaf36d2..dfb101dd1 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -1124,3 +1124,96 @@ def patched_forward(self, image_features): Projector._unsloth_patched = True pass TEMPORARY_PATCHES.append(patch_Lfm2VlMultiModalProjector) + + +def patch_peft_dispatch_bnb_4bit(): + """Fix PEFT 
dispatch_bnb_4bit accessing compress_statistics on non-Params4bit weights. + + In transformers 5.0+, BNB quantization loading order changed so weights may still be + nn.Parameter (not Params4bit) when PEFT tries to access .compress_statistics and .quant_type. + This wraps the original dispatch to catch AttributeError and provide defaults. + """ + try: + import peft.tuners.lora.bnb as peft_bnb + original_dispatch = peft_bnb.dispatch_bnb_4bit + except (ImportError, AttributeError): + return + + if hasattr(original_dispatch, "_unsloth_patched"): + return + + def safe_dispatch_bnb_4bit(target, adapter_name, **kwargs): + try: + return original_dispatch(target, adapter_name, **kwargs) + except AttributeError as e: + if "compress_statistics" in str(e) or "quant_type" in str(e): + # Transformers 5.0+: weight not yet quantized as Params4bit + # Retry after ensuring weight has needed attributes + w = target.weight + if not hasattr(w, "compress_statistics"): + w.compress_statistics = getattr( + target, "_bnb_compress_statistics", True + ) + if not hasattr(w, "quant_type"): + w.quant_type = getattr(target, "_bnb_quant_type", "nf4") + return original_dispatch(target, adapter_name, **kwargs) + raise + + safe_dispatch_bnb_4bit._unsloth_patched = True + peft_bnb.dispatch_bnb_4bit = safe_dispatch_bnb_4bit +pass +TEMPORARY_PATCHES.append(patch_peft_dispatch_bnb_4bit) + + +def patch_trl_push_to_hub_token(): + """Ensure to_dict() always includes push_to_hub_token for TRL compat. + + TRL 0.22.x through 0.27.1 do bare dict_args.pop("push_to_hub_token") in + SFTTrainer.__init__ and IterativeSFTTrainer.__init__. On transformers 5.0+, + TrainingArguments.to_dict() no longer includes push_to_hub_token, so the + bare pop raises KeyError. Fix: monkey-patch to_dict() to always include it. 
+ """ + try: + from unsloth_zoo.utils import Version + import transformers + if Version(transformers.__version__) < Version("5.0.0"): + return # Not needed pre-5.0, to_dict() already includes it + from transformers import TrainingArguments + _original_to_dict = TrainingArguments.to_dict + if getattr(_original_to_dict, "_unsloth_patched", False): + return + def _patched_to_dict(self): + d = _original_to_dict(self) + if "push_to_hub_token" not in d: + d["push_to_hub_token"] = None + return d + _patched_to_dict._unsloth_patched = True + TrainingArguments.to_dict = _patched_to_dict + except Exception: + pass +pass +TEMPORARY_PATCHES.append(patch_trl_push_to_hub_token) + + +def patch_trl_vision_model_mapping(): + """Fix DPO vision model detection for TRL 0.22.x + transformers 5.0+. + + TRL 0.22.x uses MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES which was removed in + transformers 5.0.0, replaced by MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES. + When the import fails, the fallback {} makes is_vision_model always False, + silently breaking DPO with vision models. This patch injects the new mapping. 
+ """ + try: + import trl.trainer.dpo_trainer as dpo_mod + except ImportError: + return + current = getattr(dpo_mod, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", None) + if current is not None and len(current) > 0: + return # Already has valid mapping (transformers < 5.0 or TRL >= 0.23) + try: + from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + dpo_mod.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + except ImportError: + pass # Neither mapping available +pass +TEMPORARY_PATCHES.append(patch_trl_vision_model_mapping) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index f097b7bcb..c3a0a99a9 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -602,7 +602,7 @@ def patch_tokenizer(model, tokenizer): model.generation_config.update(pad_token_id = tokenizer.pad_token_id) else: if model is not None: - if model.config.pad_token_id is None: + if getattr(model.config, "pad_token_id", None) is None: model.config.update({"pad_token_id" : tokenizer.pad_token_id}) if getattr(model, "generation_config", None) is not None: model.generation_config.update(pad_token_id = tokenizer.pad_token_id) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index f89df7c19..cc82c81f8 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1329,7 +1329,12 @@ def _override_to(self, *args, **kwargs): layer.quant_method = "fbgemm_fp8" elif fp8_weight_scale.ndim == 2: # This denotes that the model if FP8 dynamic quantized. 
- layer = FP8Linear(in_features = 0, out_features = 0, bias = has_bias, dtype = dtype, block_size = kwargs['block_size'], device = get_target_device(), activation_scheme = kwargs['activation_scheme']) + fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme']) + # transformers 5.0+ removed device param from FP8Linear.__init__ + import transformers as _tfm + if Version(_tfm.__version__) < Version("5.0.0"): + fp8_kwargs["device"] = get_target_device() + layer = FP8Linear(**fp8_kwargs) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] layer.weight = torch.nn.Parameter(weight, requires_grad = False) From b1f31e9801e89f5d7de54b7f145f17db6c797651 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 00:29:20 +0000 Subject: [PATCH 13/17] Fix patch_trl_vision_model_mapping to pre-inject old name into transformers auto module Inject MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES as an alias of MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES into transformers.models.auto.modeling_auto before TRL imports it. This allows TRL 0.22.x's bare import to succeed on transformers 5.0+ without needing to modify installed TRL files. --- unsloth_zoo/temporary_patches/misc.py | 34 ++++++++++++++++++--------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index dfb101dd1..948f52eb4 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -1198,22 +1198,34 @@ def _patched_to_dict(self): def patch_trl_vision_model_mapping(): """Fix DPO vision model detection for TRL 0.22.x + transformers 5.0+. - TRL 0.22.x uses MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES which was removed in - transformers 5.0.0, replaced by MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES. 
- When the import fails, the fallback {} makes is_vision_model always False, - silently breaking DPO with vision models. This patch injects the new mapping. + TRL 0.22.x does a bare import of MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES from + transformers.models.auto.modeling_auto. This name was removed in transformers + 5.0.0, replaced by MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES. The import + failure prevents DPO trainer from loading at all. + + Fix: inject the old name as an alias of the new name into the transformers + auto modeling module BEFORE TRL imports it, so the bare import succeeds. + Also patch already-loaded DPO module if it fell back to empty dict. """ try: - import trl.trainer.dpo_trainer as dpo_mod + import transformers.models.auto.modeling_auto as auto_mod except ImportError: return - current = getattr(dpo_mod, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", None) - if current is not None and len(current) > 0: - return # Already has valid mapping (transformers < 5.0 or TRL >= 0.23) + # If the old name already exists and is populated, nothing to do + existing = getattr(auto_mod, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", None) + if existing is not None and len(existing) > 0: + return + # Inject the old name as alias of the new name + new_mapping = getattr(auto_mod, "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", None) + if new_mapping is not None: + auto_mod.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new_mapping + # Also patch already-loaded DPO module if present try: - from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES - dpo_mod.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + import trl.trainer.dpo_trainer as dpo_mod + dpo_current = getattr(dpo_mod, "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", None) + if (dpo_current is None or len(dpo_current) == 0) and new_mapping is not None: + dpo_mod.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new_mapping except ImportError: - pass # Neither mapping available + 
pass pass TEMPORARY_PATCHES.append(patch_trl_vision_model_mapping) From f9a8d3fcbc8a98d84c28f24dc5d35c99be62eb41 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 10:51:15 +0000 Subject: [PATCH 14/17] Fix vLLM TypeError with transformers 5.0 apply_chat_template return_dict change transformers 5.0.0 changed apply_chat_template(tokenize=True) to default return_dict=True, returning BatchEncoding instead of list[int]. vLLM's safe_apply_chat_template doesn't pass return_dict=False, causing TypeError in _validate_model_input when max(BatchEncoding) yields a string key. Patch wraps the original function to inject return_dict=False when tokenize=True. Version-gated to transformers >= 5.0.0, no-op if vLLM is not installed. --- unsloth_zoo/temporary_patches/misc.py | 40 +++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index 948f52eb4..c4ff97629 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -1229,3 +1229,43 @@ def patch_trl_vision_model_mapping(): pass pass TEMPORARY_PATCHES.append(patch_trl_vision_model_mapping) + + +def patch_vllm_safe_apply_chat_template(): + """Fix vLLM safe_apply_chat_template for transformers 5.0+. + + transformers 5.0.0 changed apply_chat_template(tokenize=True) to default + return_dict=True, returning BatchEncoding instead of list[int]. vLLM's + safe_apply_chat_template doesn't pass return_dict=False, causing TypeError + in _validate_model_input when max(BatchEncoding) returns a string key. + + Fix: wrap the original function to inject return_dict=False when tokenize=True. 
+ """ + try: + from unsloth_zoo.utils import Version + import transformers + if Version(transformers.__version__) < Version("5.0.0"): + return + + import vllm.renderers.hf as hf_mod + _original_safe_apply = getattr(hf_mod, "safe_apply_chat_template", None) + if _original_safe_apply is None: + return + if getattr(_original_safe_apply, "_unsloth_patched", False): + return + + def _patched_safe_apply(model_config, tokenizer, conversation, *, + tools=None, chat_template=None, tokenize=True, **kwargs): + if tokenize: + kwargs["return_dict"] = False + return _original_safe_apply( + model_config, tokenizer, conversation, + tools=tools, chat_template=chat_template, tokenize=tokenize, + **kwargs, + ) + _patched_safe_apply._unsloth_patched = True + hf_mod.safe_apply_chat_template = _patched_safe_apply + except Exception: + pass +pass +TEMPORARY_PATCHES.append(patch_vllm_safe_apply_chat_template) From 1b2e47e4f824662b022f9a7a0e6a7aa952505e9a Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 10:54:45 +0000 Subject: [PATCH 15/17] Force recompile on transformers version mismatch, fix BNB Linear4bit without quant_state compiler.py: When UNSLOTH_COMPILE_OVERWRITE=0 is set, check if the cached file's transformers version differs from the current one. If so, force a recompile instead of silently using stale compiled cache. bitsandbytes.py: Guard Linear4bit.forward against layers with no quant_state (not quantized) by falling back to regular F.linear. Use local quant_state variable in the matmul_4bit call. 
--- unsloth_zoo/compiler.py | 17 ++++++++++++++++- unsloth_zoo/temporary_patches/bitsandbytes.py | 10 ++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 756135705..d98217ce1 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -882,7 +882,22 @@ def create_new_function( overwrite = True pass if os.environ.get("UNSLOTH_COMPILE_OVERWRITE", "1") == "0": - overwrite = False + # Even with OVERWRITE disabled, force recompile on transformers version mismatch + if file_source is not None and "__UNSLOTH_VERSIONING__" in file_source: + cached_versions = file_source[:file_source.find("__UNSLOTH_VERSIONING__")] + cached_lines = [l.strip() for l in cached_versions.strip().strip('"').split("\n") if l.strip()] + # Format: [unsloth_zoo_version, unsloth_version, transformers_version, trl_version] + cached_tf_version = cached_lines[2] if len(cached_lines) > 2 else "0" + if cached_tf_version != transformers_version: + logger.warning_once( + f"Unsloth: UNSLOTH_COMPILE_OVERWRITE=0 is set, but transformers version changed " + f"({cached_tf_version} -> {transformers_version}). Forcing recompile of {name}." 
+ ) + # Don't set overwrite = False; keep overwrite = True from version mismatch detection + else: + overwrite = False + else: + overwrite = False # Check location def write_file(function_location, write_new_source): diff --git a/unsloth_zoo/temporary_patches/bitsandbytes.py b/unsloth_zoo/temporary_patches/bitsandbytes.py index c83b64dfb..f88f0a426 100644 --- a/unsloth_zoo/temporary_patches/bitsandbytes.py +++ b/unsloth_zoo/temporary_patches/bitsandbytes.py @@ -53,8 +53,14 @@ def forward(self, x: torch.Tensor): if self.weight.shape[-1] == 1: fix_4bit_weight_quant_state_from_module(self) + # Some layers may not be quantized (no quant_state) - fall back to regular matmul + quant_state = getattr(self.weight, "quant_state", None) + if quant_state is None: + bias = None if self.bias is None else self.bias + return torch.nn.functional.linear(x, self.weight, bias) + # weights are cast automatically as Int8Params, but the bias has to be cast manually - + # ** Errors out in torch.compile so remove it # if self.bias is not None and self.bias.dtype != x.dtype: # self.bias.data = self.bias.data.to(x.dtype) @@ -74,7 +80,7 @@ def forward(self, x: torch.Tensor): # Cannot do .t() on Params4bit, instead do it on torch.Tensor weight = self.weight.data.t() - return bitsandbytes.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) + return bitsandbytes.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype) patch_function(bitsandbytes.nn.modules.Linear4bit, "forward", forward) try: From 6d15abfad64334c8f312fa94903aa0f0589a46b2 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 11:25:20 +0000 Subject: [PATCH 16/17] Minor cleanups: use print for recompile warning, simplify FP8Linear version check compiler.py: Switch logger.warning_once to print for the OVERWRITE=0 version mismatch message. vllm_utils.py: Use Version("transformers") instead of importing the module and reading __version__ manually. 
--- unsloth_zoo/compiler.py | 2 +- unsloth_zoo/vllm_utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d98217ce1..630af4416 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -889,7 +889,7 @@ def create_new_function( # Format: [unsloth_zoo_version, unsloth_version, transformers_version, trl_version] cached_tf_version = cached_lines[2] if len(cached_lines) > 2 else "0" if cached_tf_version != transformers_version: - logger.warning_once( + print( f"Unsloth: UNSLOTH_COMPILE_OVERWRITE=0 is set, but transformers version changed " f"({cached_tf_version} -> {transformers_version}). Forcing recompile of {name}." ) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index cc82c81f8..ea025ac7b 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1331,8 +1331,7 @@ def _override_to(self, *args, **kwargs): # This denotes that the model if FP8 dynamic quantized. fp8_kwargs = dict(in_features=0, out_features=0, bias=has_bias, dtype=dtype, block_size=kwargs['block_size'], activation_scheme=kwargs['activation_scheme']) # transformers 5.0+ removed device param from FP8Linear.__init__ - import transformers as _tfm - if Version(_tfm.__version__) < Version("5.0.0"): + if Version("transformers") < Version("5.0.0"): fp8_kwargs["device"] = get_target_device() layer = FP8Linear(**fp8_kwargs) layer.in_features = weight.shape[1] From cb044f707bb9999ce4c21de2d8f979df54ca21f7 Mon Sep 17 00:00:00 2001 From: Daniel Hanchen Date: Mon, 9 Feb 2026 12:03:28 +0000 Subject: [PATCH 17/17] Remove gpt_oss and compiler import changes that overlap with PR #471 gpt_oss.py: Reset to main since PR #471 (fix_gpt_oss2) handles all MoE fixes. compiler.py: Remove torch_compile/KWARGS_TYPE import hunk (added by #471), keep the OVERWRITE version-mismatch recompile logic which is unique to this PR. Also keep main's data-dependent compile check (.nonzero/.tolist/.item). 
--- unsloth_zoo/compiler.py | 29 +- unsloth_zoo/temporary_patches/gpt_oss.py | 1896 ++++------------------ 2 files changed, 296 insertions(+), 1629 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 630af4416..b434daf0b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -811,10 +811,6 @@ def create_new_function( imports += "import torch\n" imports += "import torch.nn as nn\n" imports += "from torch.nn import functional as F\n" - if "torch_compile" in new_source: - imports += "from unsloth_zoo.temporary_patches.common import torch_compile\n" - if "KWARGS_TYPE" in new_source: - imports += "from unsloth_zoo.temporary_patches.utils import KWARGS_TYPE\n" imports += ( "from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable\n" ) @@ -889,7 +885,7 @@ def create_new_function( # Format: [unsloth_zoo_version, unsloth_version, transformers_version, trl_version] cached_tf_version = cached_lines[2] if len(cached_lines) > 2 else "0" if cached_tf_version != transformers_version: - print( + logger.warning_once( f"Unsloth: UNSLOTH_COMPILE_OVERWRITE=0 is set, but transformers version changed " f"({cached_tf_version} -> {transformers_version}). Forcing recompile of {name}." 
) @@ -3184,15 +3180,26 @@ def replaced_tqdm(*args, **kwargs): bad_torch_modules.add(module) pass - # Check if creating arrays in inside the function - # Error: DataDependentOutputException: aten._local_scalar_dense.default + # Check for data-dependent control flow that breaks torch.compile(fullgraph=True) + # Tier 1: Direct data escapes from tensor to Python + # .nonzero() -> data-dependent output shape (variable-length) + # .tolist() -> materializes tensor values into Python list + # .item() -> materializes tensor scalar into Python + # Tier 2: MoE expert dispatch via torch.where + index_add + # 1-arg torch.where returns data-dependent indices; combined with + # index_add this is the standard MoE routing loop pattern if ( - "torch.arange(" in source - or "torch.zeros(" in source - or "torch.ones(" in source + ".nonzero()" in source + or ".tolist()" in source + or ".item()" in source ): print( - f"Unsloth: Failed compiling function {module} since array creations are done." + f"Unsloth: Will not compile {module} since data-dependent operations are done." + ) + bad_torch_modules.add(module) + elif "torch.where(" in source and ".index_add" in source: + print( + f"Unsloth: Will not compile {module} since data-dependent routing is done." 
) bad_torch_modules.add(module) pass diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 8d43dcfc0..3a3cbcc5f 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -18,7 +18,6 @@ import os import torch import torch.nn as nn -import torch.nn.init as init import torch.nn.functional as F import inspect from .common import ( @@ -30,7 +29,6 @@ ) from importlib.metadata import version as importlib_version from ..utils import Version - transformers_version = Version(importlib_version("transformers")) has_static_cache = transformers_version >= Version("4.56.0.dev0") from .utils import ( @@ -44,35 +42,8 @@ process_return, ) from ..hf_utils import dtype_from_config - torch_cuda_device = torch.cuda.device -# MXFP4 configuration -# Set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 to keep MXFP4 weights quantized (requires triton_kernels) -# Otherwise, MXFP4 weights will be dequantized to bf16 for LoRA training -UNSLOTH_MXFP4_NO_DEQUANTIZE = os.environ.get("UNSLOTH_MXFP4_NO_DEQUANTIZE", "0") == "1" - - -def _check_triton_kernels_available(): - """Check if OpenAI's triton_kernels package is available for MXFP4.""" - try: - from triton_kernels import matmul_ogs, swiglu - - return True - except ImportError: - return False - - -_TRITON_KERNELS_AVAILABLE = None - - -def is_triton_kernels_available(): - """Cached check for triton_kernels availability.""" - global _TRITON_KERNELS_AVAILABLE - if _TRITON_KERNELS_AVAILABLE is None: - _TRITON_KERNELS_AVAILABLE = _check_triton_kernels_available() - return _TRITON_KERNELS_AVAILABLE - @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_forward(a, alpha, limit, dtype = None): @@ -86,11 +57,8 @@ def swiglu_torch_forward(a, alpha, limit, dtype = None): out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) out = out_gelu * (a_linear + 1) return out.to(a.dtype if dtype is None else dtype) - - pass - @torch_compile(dynamic = True, fullgraph = True) def 
swiglu_torch_backward(pre_act, alpha, limit, g1): g, l = pre_act[..., ::2].to(torch.float32), pre_act[..., 1::2].to(torch.float32) @@ -98,83 +66,56 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): if limit is not None: mask_g = g <= limit mask_l = l.abs() <= limit - ḡ = torch.where(mask_g, g, limit) + ḡ = torch.where(mask_g, g, limit) l̄ = torch.where(mask_l, l, l.sign() * limit) - else: # no clipping + else: # no clipping mask_g = mask_l = torch.ones_like(g, dtype=bool) - ḡ, l̄ = g, l + ḡ, l̄ = g, l - σ = torch.sigmoid(alpha * ḡ) - dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) - dl = ḡ * σ - dg = torch.where(mask_g, dg, 0.0) # clamp-grad - dl = torch.where(mask_l, dl, 0.0) + σ = torch.sigmoid(alpha * ḡ) + dg = (σ + alpha * ḡ * σ * (1 - σ)) * (l̄ + 1) + dl = ḡ * σ + dg = torch.where(mask_g, dg, 0.) # clamp-grad + dl = torch.where(mask_l, dl, 0.) grad = torch.empty_like(pre_act) grad[..., ::2], grad[..., 1::2] = dg, dl return g1 * grad.to(g1.dtype) - - pass def patch_gpt_oss(): try: import triton_kernels - - HAS_TRITON_KERNELS = True except Exception as e: - HAS_TRITON_KERNELS = False - # return raise_error("Please install triton_kernels", e) + return raise_error("Please install triton_kernels", e) try: import transformers.quantizers.quantizer_mxfp4 + def is_kernels_available(): return True + transformers.quantizers.quantizer_mxfp4.is_kernels_available = is_kernels_available + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True + except Exception as e: + return raise_error("transformers.quantizers.quantizer_mxfp4.is_kernels_available", e) - def is_kernels_available(): - return True + if hasattr(transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels"): + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = lambda *args, **kwargs: triton_kernels - transformers.quantizers.quantizer_mxfp4.is_kernels_available = ( - is_kernels_available - ) - 
transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( - lambda *args, **kwargs: True - ) + try: + transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True except Exception as e: - return raise_error( - "transformers.quantizers.quantizer_mxfp4.is_kernels_available", e - ) - - if hasattr( - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer, "_lazy_import_kernels" - ): - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer._lazy_import_kernels = ( - lambda *args, **kwargs: triton_kernels - ) + return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) try: - transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = ( - lambda *args, **kwargs: True + from triton_kernels import matmul_ogs, swiglu + FnSpecs, FusedActivation, matmul_ogs = ( + matmul_ogs.FnSpecs, + matmul_ogs.FusedActivation, + matmul_ogs.matmul_ogs, ) + swiglu_fn = swiglu.swiglu_fn except Exception as e: - return raise_error( - "transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e - ) - - if HAS_TRITON_KERNELS: - try: - from triton_kernels import matmul_ogs, swiglu - - FnSpecs, FusedActivation, matmul_ogs = ( - matmul_ogs.FnSpecs, - matmul_ogs.FusedActivation, - matmul_ogs.matmul_ogs, - ) - swiglu_fn = swiglu.swiglu_fn - except Exception as e: - return raise_error("triton_kernels", e) - else: - # Skip MXFP4 patches when triton_kernels not available - return + return raise_error("triton_kernels", e) try: import transformers.integrations.mxfp4 @@ -207,13 +148,7 @@ def swizzle_mxfp4(w, w_scale, *args, **kwargs): # w_scale = convert_layout(wrap_torch_tensor(w_scale), scale_layout, **scale_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - - patch_function( - transformers.integrations.mxfp4, - "swizzle_mxfp4", - swizzle_mxfp4, - match_level="relaxed", - ) + patch_function(transformers.integrations.mxfp4, "swizzle_mxfp4", swizzle_mxfp4, 
match_level = "relaxed") class Mxfp4GptOssExperts_Training(torch.autograd.Function): @staticmethod @@ -270,8 +205,8 @@ def forward( @staticmethod def backward(ctx, grad_token): raise NotImplementedError( - "Backwards pass using MXFP4 is still under construction!\n" - "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n" + "Backwards pass using MXFP4 is still under construction!\n"\ + "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n"\ "Or, use `load_in_4bit = True` which allows finetuning." ) (pre_act, gamma, gather_src, gather_dst, scatter_src, scatter_dst,) = ctx.saved_tensors @@ -298,7 +233,6 @@ def backward(ctx, grad_token): dx_token.index_add_(0, gather_dst, dx_exp) return (dx_token, None, None, None, None,) pass - pass class Mxfp4GptOssExperts(nn.Module): @@ -309,128 +243,41 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size - # MXFP4 quantized format (blocks + scales) self.gate_up_proj_blocks = nn.Parameter( - torch.zeros( - self.num_experts, - 2 * self.intermediate_size, - self.hidden_size // 32, - 16, - dtype=torch.uint8, - ), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), - requires_grad=False, + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16,), dtype=torch.uint8), 
+ torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), requires_grad=False, ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, 16, dtype=torch.uint8), + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False, ) self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), - requires_grad=False, + torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False ) - self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.gate_up_proj_precision_config = None self.down_proj_precision_config = None - @property - def gate_up_proj(self): - """Return gate_up_proj tensor (created from blocks/scales or stored directly).""" - # Check if already set as an attribute (from checkpoint loading or previous dequantization) - if "_gate_up_proj" in self.__dict__: - return self.__dict__["_gate_up_proj"] - - # Check if MXFP4 weights are present (blocks and scales are not all zeros) - blocks_valid = ( - self.gate_up_proj_blocks.device.type != "meta" - and self.gate_up_proj_blocks.numel() > 0 - and self.gate_up_proj_blocks.any() - ) - - if not blocks_valid: - # No MXFP4 weights and no regular weights loaded - raise AttributeError( - f"Mxfp4GptOssExperts.gate_up_proj: No weights loaded. " - f"Try 'openai/gpt-oss-20b' with load_in_4bit=True instead." - ) - - # MXFP4 weights present - dequantize them - try: - from transformers.integrations.mxfp4 import dequantize - # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) - dequantized = dequantize(self.gate_up_proj_blocks, self.gate_up_proj_scales) - # Cache for future accesses - self.__dict__["_gate_up_proj"] = dequantized - return dequantized - except Exception as e: - raise RuntimeError( - f"Failed to dequantize MXFP4 gate_up_proj: {e}. 
" - f"Ensure transformers.integrations.mxfp4.dequantize is available." - ) - - @gate_up_proj.setter - def gate_up_proj(self, value): - """Set gate_up_proj tensor (called during checkpoint loading).""" - self.__dict__["_gate_up_proj"] = value - - @property - def down_proj(self): - """Return down_proj tensor (created from blocks/scales or stored directly).""" - if "_down_proj" in self.__dict__: - return self.__dict__["_down_proj"] - - blocks_valid = ( - self.down_proj_blocks.device.type != "meta" - and self.down_proj_blocks.numel() > 0 - and self.down_proj_blocks.any() - ) - - if not blocks_valid: - raise AttributeError( - f"Mxfp4GptOssExperts.down_proj: No weights loaded." - ) - - # MXFP4 weights present - dequantize them - try: - from transformers.integrations.mxfp4 import dequantize - # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) - dequantized = dequantize(self.down_proj_blocks, self.down_proj_scales) - # Cache for future accesses - self.__dict__["_down_proj"] = dequantized - return dequantized - except Exception as e: - raise RuntimeError( - f"Failed to dequantize MXFP4 down_proj: {e}" - ) - - @down_proj.setter - def down_proj(self, value): - """Set down_proj tensor (called during checkpoint loading).""" - self.__dict__["_down_proj"] = value - - def forward( - self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: with torch_cuda_device(hidden_states.device): if not hasattr(self, "act"): self.act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) if not hidden_states.requires_grad: intermediate_cache1 = matmul_ogs( - hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware self.gate_up_proj, self.gate_up_proj_bias, routing_data, @@ -457,50 +304,43 
@@ def forward( scatter_idx, ) return intermediate_cache3 - pass - patch_function(transformers.integrations.mxfp4, "Mxfp4GptOssExperts", Mxfp4GptOssExperts) - if HAS_TRITON_KERNELS: - try: - routing = triton_kernels.routing.routing - routing = torch.compiler.disable(routing) - except Exception as e: - return raise_error("triton_kernels.routing.routing", e) - - def mlp_forward(self, hidden_states): - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) - router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) + try: + routing = triton_kernels.routing.routing + routing = torch.compiler.disable(routing) + except Exception as e: + return raise_error("triton_kernels.routing.routing", e) - with torch_cuda_device(router_logits.device): - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + def mlp_forward(self, hidden_states): + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) - routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) - return routed_out, router_logits + with torch_cuda_device(router_logits.device): + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) - patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) + routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) + return routed_out, router_logits + patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) - if HAS_TRITON_KERNELS: - try: - PrecisionConfig, FlexCtx, InFlexData = ( - triton_kernels.matmul_ogs.PrecisionConfig, - triton_kernels.matmul_ogs.FlexCtx, - 
triton_kernels.matmul_ogs.InFlexData, - ) - except Exception as e: - return raise_error("triton_kernels.matmul_ogs", e) + try: + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels.matmul_ogs.PrecisionConfig, + triton_kernels.matmul_ogs.FlexCtx, + triton_kernels.matmul_ogs.InFlexData, + ) + except Exception as e: + return raise_error("triton_kernels.matmul_ogs", e) try: from transformers.integrations.tensor_parallel import shard_and_distribute_module except Exception as e: return raise_error("transformers.integrations.tensor_parallel.shard_and_distribute_module", e) - def load_and_swizzle_mxfp4( - module, param_name, param_value, target_device, *args, **kwargs - ): + def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args, **kwargs): model = kwargs.get("model", None) empty_param = kwargs.get("empty_param", None) casting_dtype = kwargs.get("casting_dtype", None) @@ -511,22 +351,17 @@ def load_and_swizzle_mxfp4( for proj in ["gate_up_proj", "down_proj"]: if proj in param_name: if device_mesh is not None: - shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) + shard_and_distribute_module( + model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh + ) else: setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" blocks = getattr(module, blocks_attr) scales = getattr(module, scales_attr) - # Check if both blocks and scales both not on meta device AND not all zeros - # (if blocks are all zeros, they're from initialization, not from checkpoint) - blocks_valid = ( - blocks.device.type != "meta" - and scales.device.type != "meta" - and blocks.numel() > 0 - and blocks.any() # At least some non-zero values - ) - if blocks_valid: + # Check if both blocks and scales both not on on meta device + if blocks.device.type != "meta" and 
scales.device.type != "meta": # need it for ep local_experts = blocks.size(0) if proj == "gate_up_proj": @@ -567,16 +402,13 @@ def load_and_swizzle_mxfp4( delattr(module, blocks_attr) # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) del blocks - pass - patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level="relaxed") + patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level = "relaxed") try: from transformers.integrations.mxfp4 import _replace_with_mxfp4_linear except Exception as e: - return raise_error( - "transformers.integrations.mxfp4._replace_with_mxfp4_linear", e - ) + return raise_error("transformers.integrations.mxfp4._replace_with_mxfp4_linear", e) def replace_with_mxfp4_linear( model, @@ -586,11 +418,17 @@ def replace_with_mxfp4_linear( config=None, ): if quantization_config.dequantize: return model - modules_to_not_convert = (["lm_head"] if modules_to_not_convert is None else modules_to_not_convert) + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert if quantization_config.modules_to_not_convert is not None: modules_to_not_convert.extend(quantization_config.modules_to_not_convert) modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_mxfp4_linear(model, modules_to_not_convert, current_key_name, quantization_config, config=config) + model, has_been_replaced = _replace_with_mxfp4_linear( + model, + modules_to_not_convert, + current_key_name, + quantization_config, + config=config, + ) if not has_been_replaced: logger.warning_once( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." 
@@ -599,563 +437,105 @@ def replace_with_mxfp4_linear( ) return model - - patch_function( - transformers.integrations.mxfp4, - "replace_with_mxfp4_linear", - replace_with_mxfp4_linear, - ) - - + patch_function(transformers.integrations.mxfp4, "replace_with_mxfp4_linear", replace_with_mxfp4_linear) pass TEMPORARY_PATCHES.append(patch_gpt_oss) -class ParameterModule(nn.Linear): - """ - A module that wraps a parameter to look like a Linear layer for PEFT. - It inherits from nn.Linear but manages 3D <-> 2D weight conversion. - Unsloth grouped_mm requires 3D weights: - - gate_up: (E, H, 2I) - - down: (E, I, H) - PEFT Linear requires 2D weights: (Out, In). - We store the weight as 2D for PEFT, and reshape for Unsloth via get_param(). - """ - - def __init__( - self, in_features, out_features, shape_3d, permute_to_2d, permute_to_3d - ): - # Initialize nn.Linear (creates weight of size (out, in)) - super().__init__(in_features, out_features, bias=False) - self.shape_3d = shape_3d - self.permute_to_2d = permute_to_2d - self.permute_to_3d = permute_to_3d - - # We expect the caller to set the weight content correctly. - # nn.Linear initialized it randomly. We will overwrite it. - - def extra_repr(self): - return f"in_features={self.in_features}, out_features={self.out_features}, shape_3d={self.shape_3d}" - - def get_param(self): - """Restores the 3D weight for Unsloth computation.""" - # 2D weight (Out, In) -> View (Unflattened 2D) -> Permute -> 3D(E, ...) - # We need to know the unflattened shape. - # gate_up: 2D (E*2I, H). View (E, 2I, H). Permute(0,2,1) -> (E, H, 2I). - # We store (E*2I, H). - # To restore: View (E, 2I, H)? - # (E, 2I, H) is shape_3d permuted by permute_to_2d. 
- - unflattened_shape = [self.shape_3d[i] for i in self.permute_to_2d] - return ( - self.weight.view(*unflattened_shape) - .permute(*self.permute_to_3d) - .contiguous() - ) - - def set_weight_from_3d(self, weight_3d): - """Sets the 2D weight from a 3D tensor.""" - # 3D -> Permute -> Flatten - weight_2d = weight_3d.permute(*self.permute_to_2d).reshape( - self.out_features, self.in_features - ) - self.weight.data.copy_(weight_2d) - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - # 'gate_up_proj' in checkpoint is 3D. - # We need to load it into 'gate_up_proj.weight' (2D). - # key is '...gate_up_proj' - key = prefix[:-1] - - if key in state_dict: - # Found the parameter (likely from original model structure where it was a Param) - val = state_dict[key] - # Convert 3D val to 2D - val_2d = val.permute(*self.permute_to_2d).reshape( - self.out_features, self.in_features - ) - - # Put into 'weight' key - state_dict[prefix + "weight"] = val_2d - del state_dict[key] - - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - - class GptOssExperts(nn.Module): - """ - GPT OSS MoE Experts layer with 3D stacked parameters. - Compatible with transformers' _init_weights and supports grouped_mm with split LoRA. 
- - Uses the same structure as the original transformers GptOssExperts: - - gate_up_proj: (num_experts, hidden_size, 2 * expert_dim) - - gate_up_proj_bias: (num_experts, 2 * expert_dim) - - down_proj: (num_experts, expert_dim, hidden_size) - - down_proj_bias: (num_experts, hidden_size) - """ - def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = config.intermediate_size - self.intermediate_size = config.intermediate_size # Alias for compatibility - self.alpha = 1.702 - self.limit = getattr(config, "swiglu_limit", 7.0) - self.dtype = dtype_from_config(config) - - # gate_up_proj: 3D (E, H, 2I). Target 2D (E*2I, H). - # Permute (0, 2, 1) -> (E, 2I, H). Reverse (0, 2, 1). - self.gate_up_proj = ParameterModule( - in_features=self.hidden_size, - out_features=self.num_experts * 2 * self.expert_dim, - shape_3d=(self.num_experts, self.hidden_size, 2 * self.expert_dim), - permute_to_2d=(0, 2, 1), - permute_to_3d=(0, 2, 1), - ) - # Initialize 3D zero tensor and set 2D weight - self.gate_up_proj.set_weight_from_3d( - torch.zeros( - self.num_experts, - self.hidden_size, - 2 * self.expert_dim, - dtype=self.dtype, - ) - ) - - self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=self.dtype) - ) - - # down_proj: 3D (E, I, H). Target 2D (H, E*I). - # Permute (2, 0, 1) -> (H, E, I). 
Reverse (1, 2, 0) - self.down_proj = ParameterModule( - in_features=self.num_experts * self.expert_dim, - out_features=self.hidden_size, - shape_3d=(self.num_experts, self.expert_dim, self.hidden_size), - permute_to_2d=(2, 0, 1), - permute_to_3d=(1, 2, 0), - ) - self.down_proj.set_weight_from_3d( - torch.empty( - self.num_experts, self.expert_dim, self.hidden_size, dtype=self.dtype - ) - ) - - self.down_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, dtype=self.dtype) - ) - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - """ - Override to handle loading 3D tensors (gate_up_proj, down_proj) from original checkpoints. - The original checkpoint has these as nn.Parameter (3D tensors), but we now have - them as ParameterModule (nn.Linear subclass) which expects .weight as 2D tensors. - - This method intercepts the 3D tensors and converts them to the 2D format - that ParameterModule expects. 
- """ - # Handle gate_up_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight - gate_up_key = prefix + "gate_up_proj" - gate_up_weight_key = prefix + "gate_up_proj.weight" - if gate_up_key in state_dict and gate_up_weight_key not in state_dict: - val_3d = state_dict.pop(gate_up_key) - # gate_up_proj: 3D (E, H, 2I) -> permute (0, 2, 1) -> (E, 2I, H) -> reshape (E*2I, H) - val_2d = val_3d.permute(0, 2, 1).reshape( - self.num_experts * 2 * self.expert_dim, # out_features - self.hidden_size, # in_features - ) - state_dict[gate_up_weight_key] = val_2d - - # Handle down_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight - down_key = prefix + "down_proj" - down_weight_key = prefix + "down_proj.weight" - if down_key in state_dict and down_weight_key not in state_dict: - val_3d = state_dict.pop(down_key) - # down_proj: 3D (E, I, H) -> permute (2, 0, 1) -> (H, E, I) -> reshape (H, E*I) - val_2d = val_3d.permute(2, 0, 1).reshape( - self.hidden_size, # out_features - self.num_experts * self.expert_dim, # in_features - ) - state_dict[down_weight_key] = val_2d - - # Call parent implementation - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - - def forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None - ) -> torch.Tensor: - """Forward using grouped_mm or loop fallback with LoRA support.""" - # Use optimized grouped_mm if available - if _check_torch_grouped_mm_supported(): - return forward_native_grouped_mm( - self, hidden_states, router_indices, routing_weights - ) - - # Fallback to loop-based implementation - return forward_native_moe_loop( - self, hidden_states, router_indices, routing_weights - ) - - -pass - - -class _RouterLinearParams(nn.Module): - """ - Simple parameter container that stores weight/bias like nn.Linear - but is NOT nn.Linear itself, so BitsAndBytes will NOT quantize it. 
- State dict keys: linear.weight, linear.bias (matching BnB 4-bit checkpoints - where the router was saved via an nn.Linear submodule). - """ - def __init__(self, in_features, out_features, dtype): - super().__init__() - self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) - self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype)) - - def forward(self, input): - return F.linear(input, self.weight, self.bias) - - -class GptOssTopKRouter(nn.Module): - def __init__(self, config): - super().__init__() - self.top_k = config.num_experts_per_tok - self.num_experts = config.num_local_experts - self.hidden_dim = config.hidden_size - # Use _RouterLinearParams (not nn.Linear) to avoid BnB 4-bit quantization. - # State dict keys are router.linear.weight / router.linear.bias, matching - # the BnB 4-bit checkpoint format where router was stored via nn.Linear. - self.linear = _RouterLinearParams( - self.hidden_dim, self.num_experts, dtype=dtype_from_config(config) - ) - - # Properties for compatibility with transformers' _init_weights which expects .weight and .bias - @property - def weight(self): - return self.linear.weight - - @weight.setter - def weight(self, value): - self.linear.weight = value - - @property - def bias(self): - return self.linear.bias - - @bias.setter - def bias(self, value): - self.linear.bias = value - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax( - router_top_value, dim=1, dtype=router_top_value.dtype - ) - router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_( - 1, router_indices, router_top_value - ) - if transformers_version >= Version("5.0.0"): - return router_logits, 
router_scores, router_indices - else: - return router_scores, router_indices - -pass - - -# BitsAndBytes 4bit compatible classes for loading pre-quantized models -class GptOssExpertsBnb4bit(nn.Module): - """ - GPT OSS MoE Experts using nn.Linear layers for BitsAndBytes 4bit compatibility. - This version uses gate_up_projs and down_projs as ModuleLists of Linear layers, - which can be properly quantized with BitsAndBytes. - """ - - def __init__(self, config): - super().__init__() - - self.num_experts = config.num_local_experts - self.hidden_size = config.hidden_size - self.expert_dim = config.intermediate_size - self.intermediate_size = config.intermediate_size self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.dtype = dtype_from_config(config) self.gate_up_projs = nn.ModuleList([ - nn.Linear(self.hidden_size, 2 * self.expert_dim, bias=True, dtype=self.dtype) + nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=self.dtype) for _ in range(self.num_experts) ]) self.down_projs = nn.ModuleList([ - nn.Linear(self.expert_dim, self.hidden_size, bias=True, dtype=self.dtype) + nn.Linear(self.expert_dim, self.hidden_size, dtype=self.dtype) for _ in range(self.num_experts) ]) - # Provide minimal tensors for transformers _init_weights compatibility. - # Keep them empty to avoid allocating large bf16 weights in 4-bit mode. 
- self.register_buffer( - "gate_up_proj", torch.empty(0, dtype=self.dtype), persistent=False - ) - self.register_buffer( - "gate_up_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False - ) - self.register_buffer( - "down_proj", torch.empty(0, dtype=self.dtype), persistent=False - ) - self.register_buffer( - "down_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False - ) + def forward( + self, + hidden_states: torch.Tensor, + router_indices = None, + routing_weights = None + ) -> torch.Tensor: - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) num_experts = routing_weights.shape[1] - if self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hitted[:]: + next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) + # with torch.no_grad(): + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted[:]: + for expert_idx in range(num_experts): with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx[0]]) + # _, token_idx = torch.where(expert_mask[expert_idx[0]]) + token_idx, _ = torch.where(router_indices == expert_idx) current_state = hidden_states[token_idx] gate_up = self.gate_up_projs[expert_idx](current_state) - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, 
max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu + gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit) + # gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # gate = gate.clamp(min=None, max=self.limit) + # up = up.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # gated_output = (up + 1) * glu out = self.down_projs[expert_idx](gated_output) - weighted_output = out * routing_weights[token_idx, expert_idx, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + weighted_output = out * routing_weights[token_idx, expert_idx, None].to(torch.float32) + next_states.index_add_(0, token_idx, weighted_output) next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states + return next_states.to(hidden_states.dtype) else: X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] gate_up = torch.stack(gate_up_list, dim=0) - gate = gate_up[..., ::2] - up_h = gate_up[..., 1::2] - gate = gate.clamp(max=self.limit) - up_h = up_h.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - fused = (up_h + 1) * glu + fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = X_rep.dtype) + # gate = gate_up[..., ::2] + # up_h = gate_up[..., 1::2] + # gate = gate.clamp(max=self.limit) + # up_h = up_h.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # fused = (up_h + 1) * glu out_list = [down_l(fused[e]) for e, down_l in enumerate(self.down_projs)] outs = torch.stack(out_list, dim=0) rw = routing_weights.transpose(0, 1).unsqueeze(-1) - mixed = (outs * rw).sum(dim=0) - return mixed.view(batch_size, -1, self.hidden_size) - - + mixed = (outs.to(torch.float32) * rw.to(torch.float32)).sum(dim=0) + return mixed.view(batch_size, -1, 
self.hidden_size).to(hidden_states.dtype) pass - -class GptOssTopKRouterBnb4bit(nn.Module): - """ - GPT OSS Router using direct parameters for BitsAndBytes 4bit compatibility. - This version uses weight/bias as direct nn.Parameter instead of nested Linear. - """ - +class GptOssTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.dtype = dtype_from_config(config) - self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, dtype=self.dtype)) - self.bias = nn.Parameter(torch.zeros(self.num_experts, dtype=self.dtype)) - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - # Accept checkpoints that used a Linear router (linear.weight/bias) - linear_weight_key = prefix + "linear.weight" - linear_bias_key = prefix + "linear.bias" - weight_key = prefix + "weight" - bias_key = prefix + "bias" - moved_weight = False - moved_bias = False - if linear_weight_key in state_dict and weight_key not in state_dict: - state_dict[weight_key] = state_dict.pop(linear_weight_key) - moved_weight = True - if linear_bias_key in state_dict and bias_key not in state_dict: - state_dict[bias_key] = state_dict.pop(linear_bias_key) - moved_bias = True - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - if moved_weight: - if linear_weight_key in unexpected_keys: - unexpected_keys.remove(linear_weight_key) - if weight_key in missing_keys: - missing_keys.remove(weight_key) - if moved_bias: - if linear_bias_key in unexpected_keys: - unexpected_keys.remove(linear_bias_key) - if bias_key in missing_keys: - missing_keys.remove(bias_key) - if self.weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): - self.weight.data = self.weight.data.to(self.dtype) - if 
self.bias is not None and self.bias.dtype not in ( - torch.float16, torch.bfloat16, torch.float32 - ): - self.bias.data = self.bias.data.to(self.dtype) + self.linear = nn.Linear(self.hidden_dim, self.num_experts, dtype=dtype_from_config(config)) + @torch_compile(dynamic = True, fullgraph = True) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = torch.nn.functional.linear( - hidden_states.to(self.weight.dtype), self.weight, self.bias - ) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) + router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) - router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) - if transformers_version >= Version("5.0.0"): - return router_logits, router_scores, router_indices - else: - return router_scores, router_indices - - -pass - - -def patch_gpt_oss_bnb4bit(): - """ - Patch transformers to use BnB 4bit compatible classes when loading pre-quantized models. - This should be called before loading models that were saved with BitsAndBytes quantization - using the linear-based expert structure. - - Usage: - from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss_bnb4bit - patch_gpt_oss_bnb4bit() # Call before loading the model - model = FastLanguageModel.from_pretrained(...) 
- """ - try: - import transformers.models.gpt_oss.modeling_gpt_oss - except Exception as e: - return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) - - # Store original classes for potential restoration - if not hasattr(transformers.models.gpt_oss.modeling_gpt_oss, '_original_GptOssExperts'): - transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssExperts = \ - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts - transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssTopKRouter = \ - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter - - # Replace with BnB 4bit compatible versions - # Preserve original symbol names for compiler-generated modules. - GptOssExpertsBnb4bit.__name__ = "GptOssExperts" - GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit - # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. - # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. - # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict - # override to remap keys, but transformers v5 bypasses _load_from_state_dict - # (uses accelerate's set_module_tensor_to_device), so the remapping never ran - # and router weights were randomly initialized - causing high loss (~4-5). - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - - logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") - os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" - return True - - + router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices pass -def restore_gpt_oss_original(): - """ - Restore original GPT-OSS classes (undo BnB 4bit patch). 
- """ - try: - import transformers.models.gpt_oss.modeling_gpt_oss - if hasattr(transformers.models.gpt_oss.modeling_gpt_oss, '_original_GptOssExperts'): - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = \ - transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssExperts - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = \ - transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssTopKRouter - logger.info("Unsloth: Restored original GPT OSS classes") - return True - except Exception: - pass - return False - -def patch_gpt_oss_bnb4bit_auto(): - """ - Auto-patch GPT-OSS for BnB 4-bit when load_in_4bit is active. - Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to opt out. - """ - if not _should_use_gpt_oss_bnb4bit(): - return - # Avoid compiler-generated modules missing BnB helpers - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" - patch_gpt_oss_bnb4bit() - # Ensure inference path avoids torch.compile for 4-bit - try: - global moe_forward_inference - moe_forward_inference = torch.compiler.disable(moe_forward_inference) - except Exception: - pass - - -TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) - - # Combo kernels uses too much VRAM for low memory GPUs from ..device_type import DEVICE_TYPE - if DEVICE_TYPE == "xpu": device_memory = torch.xpu.memory.mem_get_info(0)[-1] else: @@ -1186,171 +566,72 @@ def patch_gpt_oss_bnb4bit_auto(): logging = UNSLOTH_ENABLE_LOGGING, ) - -@_torch_compile(dynamic=None, fullgraph=True, options=fused_torch_compile_options) +@_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) def moe_forward_inference(self, hidden_states): """Torch compile for forward inference path only with CUDAGraphs""" # Router - router_out = self.router(hidden_states) - if isinstance(router_out, tuple) and len(router_out) == 3: - _, router_scores, router_indices = router_out - else: - router_scores, router_indices = router_out + router_scores, router_indices = self.router(hidden_states) 
routing_weights = router_scores moe = self.experts batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, moe.hidden_size) num_experts = routing_weights.shape[1] + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - # Check if using ModuleList (old style) or 3D parameters (new style) - if hasattr(moe, "gate_up_projs"): - # ModuleList style - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] - gate_up = torch.stack(gate_up_list, dim=0) - dtype = ( - torch.float32 - if hidden_states.dtype != torch.bfloat16 - else hidden_states.dtype - ) - fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) - - fused = fused.to(dtype) - device_type = ( - fused.device.type - if isinstance(fused.device.type, str) and fused.device.type != "mps" - else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - out_list = [ - down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs) - ] - outs = torch.stack(out_list, dim=0) - else: - # 3D parameter style (compatible with transformers) - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - # gate_up_proj: (E, hidden_size, 2*expert_dim) - bmm: (E, N, H) @ (E, H, 2I) -> (E, N, 2I) - gate_up = ( - torch.bmm(X_rep, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] - ) - dtype = ( - torch.float32 - if hidden_states.dtype != torch.bfloat16 - else hidden_states.dtype - ) - fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) + # Gate up projection + gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] + gate_up = torch.stack(gate_up_list, dim = 0) + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype + fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype = dtype) - fused = fused.to(dtype) - device_type = ( - fused.device.type - if isinstance(fused.device.type, str) 
and fused.device.type != "mps" - else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - # down_proj: (E, expert_dim, hidden_size) - bmm: (E, N, I) @ (E, I, H) -> (E, N, H) - outs = ( - torch.bmm(fused.to(dtype), moe.down_proj) - + moe.down_proj_bias[..., None, :] - ) + # Down projection must be done in float32 if not bfloat16 otherwise infinites + fused = fused.to(dtype) + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + out_list = [down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs)] + outs = torch.stack(out_list, dim=0) rw = routing_weights.to(dtype).transpose(0, 1).unsqueeze(-1) mixed = (outs * rw).sum(dim=0) return mixed.view(batch_size, -1, moe.hidden_size).to(hidden_states.dtype) - - pass - -@torch_compile(dynamic=True, fullgraph=True) +@torch_compile(dynamic = True, fullgraph = True) def moe_router_forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - dtype = ( - torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype - ) + dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) return router_scores, router_indices - - pass - # Combo Kernels errors with InductorError: AttributeError: 'NullKernelHandler' object has no attribute 'index_to_str' -@_torch_compile( - dynamic=None, fullgraph=True, options=no_combo_fused_torch_compile_options -) -def 
_moe_forward_inference_bf16_kernel( - hidden_states, routing_weights, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias, limit, alpha, hidden_size -): - """Inner compiled kernel for BF16 MoE inference - works with raw tensors only.""" +@_torch_compile(dynamic = None, fullgraph = True, options = no_combo_fused_torch_compile_options) +def moe_forward_inference_bf16(self, hidden_states): + router_scores, router_indices = moe_router_forward(self.router, hidden_states) + routing_weights = router_scores + + moe = self.experts batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, hidden_size) + hidden_states = hidden_states.reshape(-1, moe.hidden_size) num_experts = routing_weights.shape[1] hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, hidden_size) - - gate_up = ( - torch.bmm(hidden_states, gate_up_proj) + gate_up_proj_bias[..., None, :] - ) + hidden_states = hidden_states.view(num_experts, -1, moe.hidden_size) + gate_up = torch.bmm(hidden_states, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate.to(torch.float32) * alpha).to(gate.dtype) - next_states = torch.bmm(((up + 1) * glu), down_proj) - next_states = next_states + down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, hidden_size) - next_states = ( - next_states - * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] - ) + gate = gate.clamp(min=None, max=moe.limit) + up = up.clamp(min=-moe.limit, max=moe.limit) + glu = gate * torch.sigmoid(gate.to(torch.float32) * moe.alpha).to(gate.dtype) + next_states = torch.bmm(((up + 1) * glu), moe.down_proj) + next_states = next_states + moe.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, moe.hidden_size) + 
next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] next_states = next_states.sum(dim=0) return next_states - - -def _unwrap_peft_experts(module): - """Unwrap PEFT ParamWrapper chain to get the actual experts module.""" - while hasattr(module, 'base_layer'): - module = module.base_layer - return module - - -def moe_forward_inference_bf16(self, hidden_states): - """Wrapper that extracts weights from ParameterModule before calling the compiled kernel.""" - router_scores, router_indices = moe_router_forward(self.router, hidden_states) - routing_weights = router_scores - - moe = _unwrap_peft_experts(self.experts) - - # Handle ParameterModule (which wraps nn.Linear) vs direct 3D tensor - # Extract weights BEFORE the compiled region - gate_up_proj = moe.gate_up_proj - if hasattr(gate_up_proj, "get_param"): - gate_up_proj = gate_up_proj.get_param() - elif hasattr(gate_up_proj, "weight"): - gate_up_proj = gate_up_proj.weight - - down_proj = moe.down_proj - if hasattr(down_proj, "get_param"): - down_proj = down_proj.get_param() - elif hasattr(down_proj, "weight"): - down_proj = down_proj.weight - - return _moe_forward_inference_bf16_kernel( - hidden_states, - routing_weights, - gate_up_proj, - moe.gate_up_proj_bias, - down_proj, - moe.down_proj_bias, - moe.limit, - moe.alpha, - moe.hidden_size, - ) - - +pass class GptOssMLP(nn.Module): @@ -1363,470 +644,94 @@ def forward(self, hidden_states): bsz, qlen, hd = hidden_states.shape if qlen == 1 and not self.training: return moe_forward_inference(self, hidden_states), None - router_out = self.router(hidden_states) - if isinstance(router_out, tuple) and len(router_out) == 3: - _, router_scores, router_indices = router_out - else: - router_scores, router_indices = router_out + router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) return routed_out, 
router_scores - - pass - -# ============================================================================ -# GPT OSS MoE LoRA Support using grouped GEMM kernels -# ============================================================================ - -# IMPORTS FROM MOE UTILS -from .moe_utils import ( - _check_grouped_gemm_available, - _TORCH_GROUPED_MM_AVAILABLE, - _check_torch_grouped_mm_supported, - native_moe_grouped_mm, - _get_moe_lora_weights, - _apply_lora_grouped_mm, - _get_lora_wrapper_for_param, - select_moe_backend, - patch_param_wrapper_for_moe, - forward_native_grouped_mm, - forward_native_moe_loop, -) - - -def _should_use_gpt_oss_bnb4bit() -> bool: - """ - Decide if GPT-OSS should use BnB-compatible 4-bit experts. - Default: True when load_in_4bit is active. - Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. - """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): - return False - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): - return False - return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" - - -def _is_gpt_oss_4bit_load() -> bool: - return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") - - -def _is_transformers_v5() -> bool: - return transformers_version >= Version("5.0.0.dev0") - - -def patch_gpt_oss_moe_for_lora(): - """ - Patch GPT OSS MoE experts for LoRA training with grouped GEMM support. - This patches the original GptOssExperts class (with 3D parameter tensors) - to use optimized grouped GEMM kernels with LoRA support. - - IMPORTANT: We only patch the forward method, NOT replace the entire class. - This preserves the original class structure so weights load correctly. - """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): - return - if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): - # 4-bit loads should keep quantized weights and use default PEFT LoRA. 
- return - if not _is_transformers_v5(): - # Split-LoRA grouped_mm path is only needed for transformers v5+ - return - # First patch ParamWrapper for MoE separated LoRA - patch_param_wrapper_for_moe() - - try: - import transformers.models.gpt_oss.modeling_gpt_oss - - # Get the ORIGINAL class - don't replace it! - GptOssExpertsClass = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts - except Exception as e: - if UNSLOTH_ENABLE_LOGGING: - logger.warning(f"Unsloth: Could not patch GPT OSS MoE for LoRA: {e}") - return - - # Check if already patched - if hasattr(GptOssExpertsClass, "_unsloth_lora_patched"): - return - - # Select backend - backend = select_moe_backend() - - if backend == "grouped_mm": - forward = forward_native_grouped_mm - else: - forward = forward_native_moe_loop - - # Store original forward and patch - but DON'T replace the class! - GptOssExpertsClass._original_forward = GptOssExpertsClass.forward - GptOssExpertsClass.forward = forward - GptOssExpertsClass._unsloth_lora_patched = True - - if UNSLOTH_ENABLE_LOGGING: - backend_desc = { - "grouped_mm": "torch._grouped_mm (batched, fastest)", - "unsloth_triton": "Triton kernels", - "native_torch": "loop fallback (slower)", - }.get(backend, backend) - logger.info( - f"Unsloth: Patched GPT OSS MoE for LoRA training using {backend_desc}" - ) - - -TEMPORARY_PATCHES.append(patch_gpt_oss_moe_for_lora) - - -# ============================================================================ -# MXFP4 (4-bit) GPT OSS MoE LoRA Support -# ============================================================================ - -_MXFP4_LORA_PATH_LOGGED = False - -@torch.compiler.disable -def forward_mxfp4_gpt_oss_with_lora( - self, - hidden_states: torch.Tensor, - routing_data, - gather_idx, - scatter_idx, -) -> torch.Tensor: - """ - MXFP4 GPT OSS MoE forward pass with LoRA support. 
- - For LoRA training with MXFP4: - - Base MXFP4 matmul is computed without gradients (frozen weights) - - LoRA delta is computed with gradients - - Output = base_output + lora_delta - - This allows finetuning MXFP4 quantized models with LoRA. - - Requires triton_kernels for native MXFP4 matmul. - If triton_kernels is not available, model should be loaded with - Mxfp4Config(dequantize=True) to convert to bf16. - """ - if not is_triton_kernels_available(): - raise RuntimeError( - "triton_kernels is required for native MXFP4 GPT OSS forward pass. " - "Either:\n" - " 1. Install triton_kernels from OpenAI, OR\n" - " 2. Load model with dequantization: Mxfp4Config(dequantize=True), OR\n" - " 3. Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'\n" - "Set UNSLOTH_MXFP4_NO_DEQUANTIZE=0 (default) to auto-dequantize." - ) - - from triton_kernels import matmul_ogs, swiglu - - matmul_ogs_fn = matmul_ogs.matmul_ogs - FnSpecs = matmul_ogs.FnSpecs - FusedActivation = matmul_ogs.FusedActivation - swiglu_fn = swiglu.swiglu_fn - - # Get LoRA wrappers - gate_up_wrapper = _get_lora_wrapper_for_param(self, "gate_up_proj") - down_wrapper = _get_lora_wrapper_for_param(self, "down_proj") - - has_lora = gate_up_wrapper is not None or down_wrapper is not None - - # Log path info once - global _MXFP4_LORA_PATH_LOGGED - if not _MXFP4_LORA_PATH_LOGGED: - _MXFP4_LORA_PATH_LOGGED = True - logger.warning_once( - f"Unsloth: GPT-OSS MoE training path: MXFP4 + triton_kernels. " - f"LoRA={has_lora}, experts={self.num_experts}. " - f"Tip: Increase batch_size for better GPU utilization." 
- ) - - # If no LoRA, use the original MXFP4 forward (inference path) - if not has_lora: - with torch_cuda_device(hidden_states.device): - if not hasattr(self, "act"): - self.act = FusedActivation( - FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), - (self.alpha, self.limit), - 2, - ) - intermediate_cache1 = matmul_ogs_fn( - hidden_states.to(torch.bfloat16), - self.gate_up_proj, - self.gate_up_proj_bias, - routing_data, - gather_indx=gather_idx, - precision_config=self.gate_up_proj_precision_config, - gammas=None, - fused_activation=self.act, - ) - intermediate_cache3 = matmul_ogs_fn( - intermediate_cache1, - self.down_proj, - self.down_proj_bias, - routing_data, - scatter_indx=scatter_idx, - precision_config=self.down_proj_precision_config, - gammas=routing_data.gate_scal if routing_data else None, - ) - return intermediate_cache3 - - # With LoRA: compute base MXFP4 output (no grad) + LoRA delta (with grad) - with torch_cuda_device(hidden_states.device): - # 1. Compute MXFP4 base output (detached, no gradients for base weights) - with torch.no_grad(): - if not hasattr(self, "act"): - self.act = FusedActivation( - FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), - (self.alpha, self.limit), - 2, - ) - - # Gate-up projection (MXFP4) - pre_activation = matmul_ogs_fn( - hidden_states.to(torch.bfloat16), - self.gate_up_proj, - self.gate_up_proj_bias, - routing_data, - gather_indx=gather_idx, - precision_config=self.gate_up_proj_precision_config, - gammas=None, - fused_activation=None, # Don't fuse activation so we can add LoRA before - ) - - # 2. 
Add LoRA for gate_up_proj using grouped_mm - if gate_up_wrapper is not None: - lora_data = _get_moe_lora_weights(gate_up_wrapper) - if lora_data is not None: - lora_A, lora_B, scaling, num_experts = lora_data - - # Convert triton_kernels routing format to grouped_mm offsets - # routing_data.exp_indx contains expert index for each token - gather_src = gather_idx.src_indx - permuted_input = hidden_states[gather_src].to(torch.bfloat16) - - # Compute offsets from expert indices (cumsum of token counts) - expert_ids = routing_data.exp_indx - num_tokens_per_expert = torch.bincount( - expert_ids.int(), minlength=num_experts - ).int() - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - - # Use grouped_mm for LoRA computation (much faster than loop!) - lora_delta = _apply_lora_grouped_mm( - permuted_input, - lora_A, - lora_B, - offsets, - scaling, - grouped_mm_func=native_moe_grouped_mm, - ) - pre_activation = pre_activation + lora_delta - - # 3. Apply SwiGLU activation - alpha = getattr(self, "alpha", 1.702) - limit = getattr(self, "limit", 7.0) - swiglu_output = swiglu_torch_forward(pre_activation, alpha, limit) - - # 4. Down projection (MXFP4) - with torch.no_grad(): - base_output = matmul_ogs_fn( - swiglu_output.detach(), # Detach to prevent grad through MXFP4 - self.down_proj, - self.down_proj_bias, - routing_data, - scatter_indx=scatter_idx, - precision_config=self.down_proj_precision_config, - gammas=routing_data.gate_scal if routing_data else None, - ) - - # 5. 
Add LoRA for down_proj using grouped_mm - if down_wrapper is not None: - lora_data = _get_moe_lora_weights(down_wrapper) - if lora_data is not None: - lora_A, lora_B, scaling, num_experts = lora_data - - # Compute offsets from expert indices (reuse if available) - expert_ids = routing_data.exp_indx - num_tokens_per_expert = torch.bincount( - expert_ids.int(), minlength=num_experts - ).int() - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - - # Use grouped_mm for LoRA computation - lora_delta = _apply_lora_grouped_mm( - swiglu_output, - lora_A, - lora_B, - offsets, - scaling, - grouped_mm_func=native_moe_grouped_mm, - ) - - # Scatter LoRA delta back to output (apply routing weights) - scatter_dst = scatter_idx.dst_indx - gamma = routing_data.gate_scal.unsqueeze(-1) - base_output.index_add_( - 0, scatter_dst, (lora_delta * gamma).to(base_output.dtype) - ) - - return base_output - - -def patch_mxfp4_gpt_oss_for_lora(): - """ - Patch MXFP4 GPT OSS experts for LoRA training. - - This enables finetuning the MXFP4 quantized model (unsloth/gpt-oss-20b) - with LoRA. - - IMPORTANT: Requires triton_kernels for native MXFP4 matmul. 
- If triton_kernels is not available, users must either: - - Use Mxfp4Config(dequantize=True) when loading, OR - - Use the BF16 model: 'unsloth/gpt-oss-20b-BF16' - """ - # First patch ParamWrapper for MoE separated LoRA (v5 only) - if _is_transformers_v5(): - patch_param_wrapper_for_moe() - - try: - import transformers.integrations.mxfp4 - - Mxfp4GptOssExpertsClass = getattr( - transformers.integrations.mxfp4, "Mxfp4GptOssExperts", None - ) - if Mxfp4GptOssExpertsClass is None: - if UNSLOTH_ENABLE_LOGGING: - logger.warning( - "Unsloth: Mxfp4GptOssExperts not found in transformers.integrations.mxfp4" - ) - return - except Exception as e: - if UNSLOTH_ENABLE_LOGGING: - logger.warning(f"Unsloth: Could not patch MXFP4 GPT OSS for LoRA: {e}") - return - - # Check if already patched - if hasattr(Mxfp4GptOssExpertsClass, "_unsloth_mxfp4_lora_patched"): - return - - # Only patch if triton_kernels is available - # Without triton_kernels, MXFP4 weights cannot be used directly - # Users must use dequantization or BF16 model - if is_triton_kernels_available(): - # Use native MXFP4 + LoRA (keeps weights quantized) - Mxfp4GptOssExpertsClass._original_forward = Mxfp4GptOssExpertsClass.forward - Mxfp4GptOssExpertsClass.forward = forward_mxfp4_gpt_oss_with_lora - Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True - if UNSLOTH_ENABLE_LOGGING: - logger.info("Unsloth: Patched MXFP4 GPT OSS MoE for LoRA training") - else: - # triton_kernels NOT available - do NOT patch - # The model will fail with a helpful error if user tries to use MXFP4 without dequantization - Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True - if UNSLOTH_ENABLE_LOGGING: - logger.warning( - "Unsloth: triton_kernels is not installed. MXFP4 GPT OSS will NOT be patched for LoRA.\n" - "To train GPT OSS with LoRA, either:\n" - " 1. Install triton_kernels from OpenAI (for native MXFP4), OR\n" - " 2. Use Mxfp4Config(dequantize=True) when loading (dequantizes to bf16), OR\n" - " 3. 
Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'" - ) - - -TEMPORARY_PATCHES.append(patch_mxfp4_gpt_oss_for_lora) - - -_MXFP4_DEQUANT_WARNED = False - - -def should_dequantize_mxfp4(): - """ - Check if MXFP4 should be dequantized to bf16. - - Returns True if: - - UNSLOTH_MXFP4_NO_DEQUANTIZE is not set or "0", OR - - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" but triton_kernels is not available - - Returns False if: - - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" AND triton_kernels is available - - MEMORY IMPACT: - - MXFP4 quantized: ~10GB for GPT-OSS 20B - - MXFP4 dequantized to bf16: ~40GB for GPT-OSS 20B - - To keep MXFP4 quantized (requires triton_kernels): - export UNSLOTH_MXFP4_NO_DEQUANTIZE=1 - """ - global _MXFP4_DEQUANT_WARNED - - if not UNSLOTH_MXFP4_NO_DEQUANTIZE: - # Default: dequantize to bf16 - if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: - _MXFP4_DEQUANT_WARNED = True - logger.warning( - "Unsloth: MXFP4 will be dequantized to bf16 (~4x memory increase). " - "To keep 4-bit: set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 and install triton_kernels." - ) - return True - - if not is_triton_kernels_available(): - if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: - _MXFP4_DEQUANT_WARNED = True - logger.warning( - "Unsloth: UNSLOTH_MXFP4_NO_DEQUANTIZE=1 but triton_kernels not available. " - "Will dequantize MXFP4 to bf16 (~4x memory increase). " - "Install triton_kernels to keep 4-bit quantized weights." - ) - return True # triton_kernels required for native MXFP4 - - if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: - _MXFP4_DEQUANT_WARNED = True - logger.info("Unsloth: Keeping MXFP4 quantized (triton_kernels available)") - return False # Keep MXFP4 quantized - - def patch_gpt_oss_linearized(): - """ - Patch GPT OSS for 4bit loading with grouped_mm support. - Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. 
- """ if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - if _should_use_gpt_oss_bnb4bit(): return try: import transformers.models.gpt_oss.modeling_gpt_oss except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) - # Don't replace classes - just patch the forward method of GptOssExperts - # This keeps the original class structure which properly handles 4-bit weight loading - backend = select_moe_backend() - - if backend == "grouped_mm": - - def experts_forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None - ) -> torch.Tensor: - return forward_native_grouped_mm( - self, hidden_states, router_indices, routing_weights - ) - else: - - def experts_forward( - self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + # We find down_proj overflows in GPT OSS + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + def forward( + self, + hidden_states: torch.Tensor, + router_indices = None, + routing_weights = None ) -> torch.Tensor: - return forward_native_moe_loop( - self, hidden_states, router_indices, routing_weights - ) - # Patch the original transformers class forward method - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + num_experts = routing_weights.shape[1] + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) + # with torch.no_grad(): + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted[:]: + for expert_idx in range(num_experts): + with torch.no_grad(): + # _, token_idx = 
torch.where(expert_mask[expert_idx[0]]) + token_idx, _ = torch.where(router_indices == expert_idx) + current_state = hidden_states[token_idx] + gate_up = self.gate_up_projs[expert_idx](current_state) + down_proj = self.down_projs[expert_idx] + gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = torch.float32) + # gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # gate = gate.clamp(min=None, max=self.limit) + # up = up.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # gated_output = (up + 1) * glu + + # Force float32 matrix multiply on some down projection modules + gated_output = gated_output.to(torch.float32) + device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + out = down_proj(gated_output) + weighted_output = out.to(torch.float32) * routing_weights[token_idx, expert_idx, None].to(torch.float32) + next_states.index_add_(0, token_idx, weighted_output) + next_states = next_states.view(batch_size, -1, self.hidden_size) + return next_states.to(torch.float32) + else: + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) + gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] + gate_up = torch.stack(gate_up_list, dim=0) + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype + fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = dtype) + # gate = gate_up[..., ::2] + # up_h = gate_up[..., 1::2] + # gate = gate.clamp(max=self.limit) + # up_h = up_h.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # fused = (up_h + 1) * glu + + # Force float32 matrix multiply on down projection only + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" + with 
torch.autocast(device_type=device_type, enabled=False): # Force float32 + out_list = [ + down_l(fused[e].to(dtype)) + for e, down_l in enumerate(self.down_projs) + ] + outs = torch.stack(out_list, dim=0) + rw = routing_weights.transpose(0, 1).unsqueeze(-1) + mixed = (outs.to(dtype) * rw.to(dtype)).sum(dim=0) + return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) + pass + pass + GptOssExperts.forward = forward + pass - if UNSLOTH_ENABLE_LOGGING: - logger.info( - f"Unsloth: Patched GPT OSS MoE for 4bit loading (backend: {backend})" - ) + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExperts + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + transformers.models.gpt_oss.modeling_gpt_oss.GptOssMLP = GptOssMLP return - - pass TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) @@ -1841,20 +746,17 @@ def patch_GptOssAttention(): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) - assert flex_attention_with_sink is not None except Exception as e: return raise_error("flex_attention_with_sink", e) try: import transformers.models.gpt_oss.modeling_gpt_oss - transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention", e) torch._dynamo.config.cache_size_limit = 256 - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -1869,7 +771,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: F_softmax = torch.nn.functional.softmax F_dropout = nn.functional.dropout matmul = torch.matmul - def inplace_eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -1885,11 +786,7 @@ def inplace_eager_attention_forward( bsz, n_heads, qlen, _ = query.shape bsz, n_heads, kvlen, _ = key_states.shape - out_dtype = torch.result_type(query, key_states) - combined_logits = key_states.new_empty( - (bsz, n_heads, qlen, kvlen + 1), - dtype=out_dtype, - ) + combined_logits = key_states.new_empty((bsz, n_heads, qlen, kvlen+1)) attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) attn_weights *= scaling @@ -1901,16 +798,16 @@ def inplace_eager_attention_forward( # combined_logits = torch.cat([attn_weights, sinks], dim=-1) combined_logits[:, :, :, -1] = module.sinks.reshape(1, -1, 1) - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. 
+ # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training) - attn_weights = attn_weights.to(value_states.dtype) + attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None - pass def eager_attention_forward( @@ -1934,21 +831,21 @@ def eager_attention_forward( sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) combined_logits = torch.cat([attn_weights, sinks], dim=-1) - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. 
+ # combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + combined_logits[:] = F_softmax(combined_logits, dim=-1, dtype=torch.float32) probs = combined_logits scores = probs[..., :-1] # we drop the sink here - attn_weights = F_dropout(scores, p=dropout, training=module.training) - attn_weights = attn_weights.to(value_states.dtype) + attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None - pass apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) - if False: # Version(torch.__version__) >= Version("2.10.0"): - eager_attention_forward = torch_compile(eager_attention_forward, dynamic=None, fullgraph=True) + if False: # Version(torch.__version__) >= Version("2.10.0"): + eager_attention_forward = torch_compile(eager_attention_forward, dynamic = None, fullgraph = True) else: # Too many recompilation failures on 2.8.0, 2.9.0 eager_attention_forward = inplace_eager_attention_forward @@ -1970,24 +867,8 @@ def forward_function( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_dtype = getattr(past_key_value, "dtype", None) - if cache_dtype is None and hasattr(past_key_value, "layers"): - try: - cache_layer = past_key_value.layers[self.layer_idx] - if hasattr(cache_layer, "keys") and cache_layer.keys is not None: - cache_dtype = cache_layer.keys.dtype - except Exception: - cache_dtype = None - if cache_dtype is not None and key_states.dtype != cache_dtype: - key_states = key_states.to(cache_dtype) - value_states = value_states.to(cache_dtype) cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - if key_states.dtype != query_states.dtype or value_states.dtype 
!= query_states.dtype: - key_states = key_states.to(query_states.dtype) - value_states = value_states.to(query_states.dtype) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # flex_attention_with_sink only works for training since KV cache is wrong # switch to flex_attention_with_sink which allows all to work @@ -2039,11 +920,9 @@ def forward_function( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights - pass functions = [] - def forward( self, hidden_states: torch.Tensor, @@ -2054,9 +933,7 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs) - functions.append(forward) - def forward( self, hidden_states: torch.Tensor, @@ -2067,13 +944,10 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs) - functions.append(forward) patch_function_past_key_values(transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention, "forward", functions) # Set env variable for padding purposes os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "1" - - pass TEMPORARY_PATCHES.append(patch_GptOssAttention) @@ -2083,7 +957,6 @@ def patch_GptOssModel(): if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return try: import transformers.models.gpt_oss.modeling_gpt_oss - transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel from transformers.models.gpt_oss.modeling_gpt_oss import MoeModelOutputWithPast from transformers.models.gpt_oss.modeling_gpt_oss import apply_rotary_pos_emb @@ -2100,7 +973,6 @@ def patch_GptOssModel(): # Disable mask creations since we don't need 
them for GPT-OSS import transformers.masking_utils import transformers.generation.utils - def wrap(f): def return_attention_mask(*args, **kwargs): if kwargs["input_embeds"].requires_grad: @@ -2114,7 +986,6 @@ def return_attention_mask(*args, **kwargs): return f(*args, **kwargs) pass return return_attention_mask - pass create_causal_mask = getattr( transformers.masking_utils, @@ -2131,8 +1002,8 @@ def return_attention_mask(*args, **kwargs): if create_sliding_window_causal_mask is None: return raise_error("transformers.masking_utils.create_sliding_window_causal_mask") if not hasattr(transformers.masking_utils, "__patched_causal_mask__"): - transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph=False, dynamic=True) - transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph=False, dynamic=True) + transformers.masking_utils._old_create_causal_mask = _torch_compile(transformers.masking_utils.create_causal_mask, fullgraph = False, dynamic = True) + transformers.masking_utils._old_create_sliding_window_causal_mask = _torch_compile(transformers.masking_utils.create_sliding_window_causal_mask, fullgraph = False, dynamic = True) transformers.masking_utils.create_causal_mask = wrap(create_causal_mask) transformers.masking_utils.create_sliding_window_causal_mask = wrap(create_sliding_window_causal_mask) transformers.models.gpt_oss.modeling_gpt_oss.create_causal_mask = transformers.masking_utils.create_causal_mask @@ -2147,7 +1018,6 @@ def return_attention_mask(*args, **kwargs): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) - apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) try: from transformers.integrations.mxfp4 import mlp_forward @@ -2174,9 +1044,7 @@ def pre_attention_decoding( cache_kwargs = {"cache_position": cache_position} key_states, value_states = 
past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) return query_states, key_states, value_states, input_shape - pass - # Do flex_attention_with_sink_decoding with cannot be compiled # attn_output, logsumexp = flex_attention_with_sink_decoding( # self, @@ -2189,7 +1057,6 @@ def post_attention_decoding(self_attn, attn_output, logsumexp, input_shape): attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self_attn.o_proj(attn_output) return attn_output - pass # RMSNorm forward @@ -2201,7 +1068,6 @@ def rms_layernorm_forward(self, hidden_states): hidden_states *= torch.rsqrt_(variance) hidden_states *= self.weight.to(hidden_states.device).to(torch.float32) return hidden_states.to(input_dtype) # main diff with Llama - pass # Re-compiling for each new sequence length which is NOT ideal @@ -2229,7 +1095,6 @@ def pre_forward( position_embeddings=position_embeddings, ) return query_states, key_states, value_states, input_shape - pass fused_torch_compile_options = get_torch_compile_options( epilogue_fusion = True, @@ -2243,7 +1108,6 @@ def pre_forward( use_block_ptr = True, logging = UNSLOTH_ENABLE_LOGGING, ) - @_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) def post_forward( self, @@ -2259,7 +1123,6 @@ def post_forward( residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual - pass def inference_forward( @@ -2292,7 +1155,6 @@ def inference_forward( residual = hidden_states.clone() hidden_states = rms_layernorm_forward(self.post_attention_layernorm, hidden_states) return hidden_states, residual - pass # if has_static_cache and Version(torch.__version__) >= Version("2.10.0"): # # torch 2.9.0 has excessive compilations @@ -2318,24 +1180,24 @@ def forward( if inputs_embeds is None: # Account for CPU offloaded embed_tokens embed_device = self.embed_tokens.weight.device - inputs_embeds = 
self.embed_tokens( - input_ids.to(embed_device, non_blocking=True) - ).to(input_ids.device) + inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking = True)).to(input_ids.device) if not self.training: inputs_embeds.requires_grad_(False) if cache_position is None: - past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0) - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) try: - torch._dynamo.mark_static(hidden_states, 0) + torch._dynamo.mark_static (hidden_states, 0) torch._dynamo.mark_dynamic(hidden_states, 1) - torch._dynamo.mark_static(hidden_states, 2) + torch._dynamo.mark_static (hidden_states, 2) except: pass @@ -2359,8 +1221,6 @@ def forward( # Add hack since residuals need to clone outside of the torch.compile region?? 
# This forces it to free past residuals torch.compiler.cudagraph_mark_step_begin() - # Initialize for common return path - all_hidden_states = None for decoder_layer in self.layers: hidden_states, residual = inference_forward( decoder_layer, @@ -2373,14 +1233,9 @@ def forward( position_embeddings, **kwargs, ) - _actual_experts = _unwrap_peft_experts(decoder_layer.mlp.experts) - if hasattr(_actual_experts, "gate_up_projs"): - hidden_states = moe_forward_inference( - decoder_layer.mlp, hidden_states - ) - elif ( - _actual_experts.__class__.__name__ == "Mxfp4GptOssExperts" - ): + if hasattr(decoder_layer.mlp.experts, "gate_up_projs"): + hidden_states = moe_forward_inference(decoder_layer.mlp, hidden_states) + elif decoder_layer.mlp.experts.__class__.__name__ == "Mxfp4GptOssExperts": if mlp_forward is None: raise RuntimeError("Unsloth: MXFP4 forward is not found") hidden_states, _ = mlp_forward(decoder_layer.mlp, hidden_states) @@ -2390,36 +1245,8 @@ def forward( pass hidden_states = rms_layernorm_forward(self.norm, hidden_states) else: - # Fix 2D attention_mask being passed to eager_attention_forward which expects 4D - if ( - self.training - and attention_mask is not None - and attention_mask.dim() == 2 - ): - bsz, seq_len = attention_mask.shape - min_dtype = torch.finfo(inputs_embeds.dtype).min - # 1. Expand padding mask to (B, 1, 1, S) - expanded_mask = attention_mask[:, None, None, :].to( - dtype=inputs_embeds.dtype - ) - expanded_mask = (1.0 - expanded_mask) * min_dtype - - # 2. 
Causal mask (1, 1, S, S) - causal_mask = torch.full((seq_len, seq_len), min_dtype, device=attention_mask.device, dtype=inputs_embeds.dtype) - causal_mask = torch.triu(causal_mask, diagonal=1) - attention_mask = causal_mask[None, None, :, :] + expanded_mask - - # Accumulate hidden states if requested - output_hidden_states = kwargs.get( - "output_hidden_states", self.config.output_hidden_states - ) - all_hidden_states = () if output_hidden_states else None - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - mask = (attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask) + mask = attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask hidden_states = decoder_layer( hidden_states, attention_mask=mask, @@ -2432,24 +1259,13 @@ def forward( ) pass hidden_states = self.norm(hidden_states) - - if output_hidden_states: - all_hidden_states += (hidden_states,) - # Fix float16 / float32 mismatching hidden_states = hidden_states.to(inputs_embeds.dtype) - return process_return( - MoeModelOutputWithPast, - { - "last_hidden_state": hidden_states, - "past_key_values": past_key_values, - "hidden_states": all_hidden_states, - }, - ) - - patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level="relaxed") - - + return process_return(MoeModelOutputWithPast, { + "last_hidden_state" : hidden_states, + "past_key_values" : past_key_values, + }) + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") pass TEMPORARY_PATCHES.append(patch_GptOssModel) @@ -2464,14 +1280,11 @@ def forward( SystemContent, ToolDescription, load_harmony_encoding, - ReasoningEffort, + ReasoningEffort ) - encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) except: pass - - def encode_conversations_with_harmony( messages, reasoning_effort = 
"medium", @@ -2488,24 +1301,22 @@ def encode_conversations_with_harmony( assert reasoning_effort in ("low", "medium", "high") match reasoning_effort: - case "low": harmony_reasoning = ReasoningEffort.LOW + case "low": harmony_reasoning = ReasoningEffort.LOW case "medium": harmony_reasoning = ReasoningEffort.MEDIUM - case "high": harmony_reasoning = ReasoningEffort.HIGH + case "high": harmony_reasoning = ReasoningEffort.HIGH convos = [] # Create system message import datetime - today = datetime.datetime.today().strftime("%Y-%m-%d") - system = Message.from_role_and_content( - Role.SYSTEM, + system = Message.from_role_and_content(Role.SYSTEM, SystemContent.new() .with_model_identity(model_identity) .with_reasoning_effort(harmony_reasoning) .with_conversation_start_date(today) .with_knowledge_cutoff("2024-06") - .with_required_channels(["analysis", "commentary", "final"]), + .with_required_channels(["analysis", "commentary", "final"]) ) convos.append(system) @@ -2529,21 +1340,27 @@ def encode_conversations_with_harmony( for message in messages: if message["role"] == "user": - convos.append(Message.from_role_and_content(Role.USER, message["content"])) + convos.append( + Message.from_role_and_content(Role.USER, message["content"]) + ) elif message["role"] == "assistant": if "thinking" in message: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("analysis") elif "tool_calls" in message: - x = Message.from_role_and_content(Role.ASSISTANT, message["tool_calls"][0]["arguments"]) - x = x.with_channel("commentary").with_recipient(f"functions.{message['tool_calls'][0]['name']}").with_content_type("json") + x = Message.from_role_and_content(Role.ASSISTANT, message['tool_calls'][0]["arguments"]) + x = x.with_channel("commentary")\ + .with_recipient(f"functions.{message['tool_calls'][0]['name']}")\ + .with_content_type("json") else: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("final") 
convos.append(x) elif message["role"] == "tool": - x = Message.from_author_and_content(Author.new(Role.TOOL, f"functions.{message['name']}"), message["content"]) - x = x.with_recipient("assistant").with_channel("commentary") + x = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{message['name']}"), + message["content"], + ).with_recipient("assistant").with_channel("commentary") convos.append(x) pass @@ -2555,8 +1372,6 @@ def encode_conversations_with_harmony( harmony_input_ids = encoding.render_conversation(convos) harmony_decoded_text = encoding.decode(harmony_input_ids) return harmony_decoded_text, harmony_input_ids - - pass @@ -2565,10 +1380,8 @@ def encode_conversations_with_harmony( # AutoConfig error: 'GptOssConfig' object has no attribute 'max_position_embeddings' try: from transformers.configuration_utils import layer_type_validation - try: from transformers.configuration_utils import PreTrainedConfig - PretrainedConfig = PreTrainedConfig except: from transformers.configuration_utils import PretrainedConfig @@ -2650,7 +1463,9 @@ def __init__( self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads self.layer_types = layer_types if self.layer_types is None: - self.layer_types = ["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)] + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] layer_type_validation(self.layer_types) # Validate the correctness of rotary position embeddings parameters @@ -2713,13 +1528,7 @@ def __init__( initializer_range: float = 0.02, max_position_embeddings=131072, rms_norm_eps: float = 1e-5, - rope_scaling={ - "rope_type": "yarn", - "factor": 32.0, - "beta_fast": 32.0, - "beta_slow": 1.0, - "truncate": False, - }, + rope_scaling={"rope_type": "yarn", "factor": 32.0, "beta_fast": 32.0, "beta_slow": 1.0, "truncate": False}, attention_dropout: 
float = 0.0, num_experts_per_tok=4, router_aux_loss_coef: float = 0.9, @@ -2747,16 +1556,11 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.attention_dropout = attention_dropout - self.head_dim = ( - head_dim - if head_dim is not None - else self.hidden_size // self.num_attention_heads - ) + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads self.layer_types = layer_types if self.layer_types is None: self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" - for i in range(self.num_hidden_layers) + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) ] layer_type_validation(self.layer_types) self.attention_bias = True @@ -2784,7 +1588,6 @@ def __init__( def patch_gpt_oss_config(): try: import transformers.models.gpt_oss.configuration_gpt_oss - transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) @@ -2798,150 +1601,7 @@ def patch_gpt_oss_config(): patch_function(transformers.models.gpt_oss.configuration_gpt_oss, "GptOssConfig", GptOssConfig) except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) - pass TEMPORARY_PATCHES.append(patch_gpt_oss_config) except Exception as e: raise_error("transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig", e) - - -def patch_gpt_oss_init_weights_modulelist_fix(): - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): - return - try: - import transformers.models.gpt_oss.modeling_gpt_oss - except Exception as e: - return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) - - GptOssPreTrainedModel = ( - transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel - ) - GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts - if getattr(GptOssPreTrainedModel, 
"_unsloth_init_weights_fixed", False): - return - _original_init_weights = GptOssPreTrainedModel._init_weights - - def _patched_init_weights(self, module): - if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): - std = self.config.initializer_range - for up in getattr(module, "gate_up_projs", []): - init.normal_(up.weight, mean=0.0, std=std) - if up.bias is not None: - init.zeros_(up.bias) - for down in getattr(module, "down_projs", []): - init.normal_(down.weight, mean=0.0, std=std) - if down.bias is not None: - init.zeros_(down.bias) - return - _original_init_weights(self, module) - - patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) - GptOssPreTrainedModel._unsloth_init_weights_fixed = True - - -pass -TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) - - -# ============================================================================ -# Patch GptOssForCausalLM.forward for GRPO training -# When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits -# ============================================================================ -def patch_gpt_oss_for_grpo(): - """ - Patch GptOssForCausalLM.forward for GRPO training. - When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. - This fixes the matrix multiplication dimension mismatch issue in GRPO training. 
- """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): - return - - try: - import transformers.models.gpt_oss.modeling_gpt_oss - from transformers.models.gpt_oss.modeling_gpt_oss import ( - GptOssForCausalLM, - MoeCausalLMOutputWithPast, - ) - - _original_causal_lm_forward = GptOssForCausalLM.forward - - def _patched_causal_lm_forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - past_key_values=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - cache_position=None, - logits_to_keep=0, - **kwargs, - ): - # This Unsloth Zoo code section is licensed under AGPL3 - - RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1" - - if not RETURN_HIDDEN_STATES: - # Normal forward pass - return _original_causal_lm_forward( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - **kwargs, - ) - - # RETURN_HIDDEN_STATES mode - return hidden_states instead of logits - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - # Apply slice_indices to hidden_states (same indexing as for logits) - if logits_to_keep != 0: - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - hidden_states = hidden_states[:, slice_indices, :] - - # Return hidden_states as "logits" for GRPO to use - return 
MoeCausalLMOutputWithPast( - loss=None, - aux_loss=getattr(outputs, 'aux_loss', None), - logits=hidden_states, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=getattr(outputs, 'router_logits', None), - ) - - GptOssForCausalLM.forward = _patched_causal_lm_forward - if UNSLOTH_ENABLE_LOGGING: - logger.info("Unsloth: Patched GptOssForCausalLM.forward for GRPO hidden states.") - - except Exception as e: - if UNSLOTH_ENABLE_LOGGING: - logger.warning(f"Unsloth: Could not patch GptOssForCausalLM.forward: {e}") - - -pass -TEMPORARY_PATCHES.append(patch_gpt_oss_for_grpo)