diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b905a1e74..756135705 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -811,6 +811,10 @@ def create_new_function( imports += "import torch\n" imports += "import torch.nn as nn\n" imports += "from torch.nn import functional as F\n" + if "torch_compile" in new_source: + imports += "from unsloth_zoo.temporary_patches.common import torch_compile\n" + if "KWARGS_TYPE" in new_source: + imports += "from unsloth_zoo.temporary_patches.utils import KWARGS_TYPE\n" imports += ( "from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable\n" ) diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 3a3cbcc5f..cf7c2a157 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -18,6 +18,7 @@ import os import torch import torch.nn as nn +import torch.nn.init as init import torch.nn.functional as F import inspect from .common import ( @@ -44,6 +45,32 @@ from ..hf_utils import dtype_from_config torch_cuda_device = torch.cuda.device +# MXFP4 configuration +# Set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 to keep MXFP4 weights quantized (requires triton_kernels) +# Otherwise, MXFP4 weights will be dequantized to bf16 for LoRA training +UNSLOTH_MXFP4_NO_DEQUANTIZE = os.environ.get("UNSLOTH_MXFP4_NO_DEQUANTIZE", "0") == "1" + + +def _check_triton_kernels_available(): + """Check if OpenAI's triton_kernels package is available for MXFP4.""" + try: + from triton_kernels import matmul_ogs, swiglu + + return True + except ImportError: + return False + + +_TRITON_KERNELS_AVAILABLE = None + + +def is_triton_kernels_available(): + """Cached check for triton_kernels availability.""" + global _TRITON_KERNELS_AVAILABLE + if _TRITON_KERNELS_AVAILABLE is None: + _TRITON_KERNELS_AVAILABLE = _check_triton_kernels_available() + return _TRITON_KERNELS_AVAILABLE + @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_forward(a, alpha, limit, dtype = None): @@ -59,6 +86,7 @@ def swiglu_torch_forward(a, alpha, limit, dtype = None): return out.to(a.dtype if dtype is None else dtype) pass + @torch_compile(dynamic = True, fullgraph = True) def swiglu_torch_backward(pre_act, alpha, limit, g1): g, l = pre_act[..., ::2].to(torch.float32), pre_act[..., 1::2].to(torch.float32) @@ -83,16 +111,19 @@ def swiglu_torch_backward(pre_act, alpha, limit, g1): return g1 * grad.to(g1.dtype) pass - def patch_gpt_oss(): try: import triton_kernels - except Exception as e: - return raise_error("Please install triton_kernels", e) + HAS_TRITON_KERNELS = True + except Exception as e: + HAS_TRITON_KERNELS = False + # return raise_error("Please install triton_kernels", e) try: import transformers.quantizers.quantizer_mxfp4 + def is_kernels_available(): return True + transformers.quantizers.quantizer_mxfp4.is_kernels_available = is_kernels_available transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer.is_trainable = lambda *args, **kwargs: True except Exception as e: @@ -106,16 +137,21 @@ def is_kernels_available(): return True except Exception as e: return raise_error("transformers.quantizers.quantizer_mxfp4.Mxfp4HfQuantizer", e) - try: - from triton_kernels import matmul_ogs, swiglu - FnSpecs, FusedActivation, matmul_ogs = ( - matmul_ogs.FnSpecs, - matmul_ogs.FusedActivation, - matmul_ogs.matmul_ogs, - ) - swiglu_fn = swiglu.swiglu_fn - except Exception as e: - return raise_error("triton_kernels", e) + if HAS_TRITON_KERNELS: + try: + from triton_kernels import matmul_ogs, swiglu + + FnSpecs, FusedActivation, matmul_ogs = ( + matmul_ogs.FnSpecs, + matmul_ogs.FusedActivation, + matmul_ogs.matmul_ogs, + ) + swiglu_fn = swiglu.swiglu_fn + except Exception as e: + return raise_error("triton_kernels", e) + else: + # Skip MXFP4 patches when triton_kernels not available + return try: import transformers.integrations.mxfp4 @@ -205,8 +241,8 @@ def forward( @staticmethod def backward(ctx, grad_token): raise NotImplementedError( - "Backwards pass using MXFP4 is still under construction!\n"\ - "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n"\ + "Backwards pass using MXFP4 is still under construction!\n" + "Instead, use `unsloth/gpt-oss-20b-BF16` for bfloat16 training which will work for LoRA.\n" "Or, use `load_in_4bit = True` which allows finetuning." ) (pre_act, gamma, gather_src, gather_dst, scatter_src, scatter_dst,) = ctx.saved_tensors @@ -233,6 +269,7 @@ def backward(ctx, grad_token): dx_token.index_add_(0, gather_dst, dx_exp) return (dx_token, None, None, None, None,) pass + pass class Mxfp4GptOssExperts(nn.Module): @@ -243,8 +280,15 @@ def __init__(self, config): self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size + # MXFP4 quantized format (blocks + scales) self.gate_up_proj_blocks = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), + torch.zeros( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size // 32, + 16, + dtype=torch.uint8, + ), requires_grad=False, ) self.gate_up_proj_scales = nn.Parameter( @@ -252,32 +296,108 @@ def __init__(self, config): requires_grad=False, ) self.gate_up_proj_bias = nn.Parameter( - torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False + torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False, ) self.down_proj_blocks = nn.Parameter( - torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), - requires_grad=False, + torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), requires_grad=False ) self.down_proj_scales = nn.Parameter( - torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), - requires_grad=False, + torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), requires_grad=False ) self.down_proj_bias = nn.Parameter( torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False ) + self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.gate_up_proj_precision_config = None self.down_proj_precision_config = None - def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: + @property + def gate_up_proj(self): + """Return gate_up_proj tensor (created from blocks/scales or stored directly).""" + # Check if already set as an attribute (from checkpoint loading or previous dequantization) + if "_gate_up_proj" in self.__dict__: + return self.__dict__["_gate_up_proj"] + + # Check if MXFP4 weights are present (blocks and scales are not all zeros) + blocks_valid = ( + self.gate_up_proj_blocks.device.type != "meta" + and self.gate_up_proj_blocks.numel() > 0 + and self.gate_up_proj_blocks.any() + ) + + if not blocks_valid: + # No MXFP4 weights and no regular weights loaded + raise AttributeError( + f"Mxfp4GptOssExperts.gate_up_proj: No weights loaded. " + f"Try 'openai/gpt-oss-20b' with load_in_4bit=True instead." + ) + + # MXFP4 weights present - dequantize them + try: + from transformers.integrations.mxfp4 import dequantize + # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) + dequantized = dequantize(self.gate_up_proj_blocks, self.gate_up_proj_scales) + # Cache for future accesses + self.__dict__["_gate_up_proj"] = dequantized + return dequantized + except Exception as e: + raise RuntimeError( + f"Failed to dequantize MXFP4 gate_up_proj: {e}. " + f"Ensure transformers.integrations.mxfp4.dequantize is available." + ) + + @gate_up_proj.setter + def gate_up_proj(self, value): + """Set gate_up_proj tensor (called during checkpoint loading).""" + self.__dict__["_gate_up_proj"] = value + + @property + def down_proj(self): + """Return down_proj tensor (created from blocks/scales or stored directly).""" + if "_down_proj" in self.__dict__: + return self.__dict__["_down_proj"] + + blocks_valid = ( + self.down_proj_blocks.device.type != "meta" + and self.down_proj_blocks.numel() > 0 + and self.down_proj_blocks.any() + ) + + if not blocks_valid: + raise AttributeError( + f"Mxfp4GptOssExperts.down_proj: No weights loaded." + ) + + # MXFP4 weights present - dequantize them + try: + from transformers.integrations.mxfp4 import dequantize + # Dequantize: (E, out_dim, in_dim//32, 16) -> (E, out_dim, in_dim) + dequantized = dequantize(self.down_proj_blocks, self.down_proj_scales) + # Cache for future accesses + self.__dict__["_down_proj"] = dequantized + return dequantized + except Exception as e: + raise RuntimeError( + f"Failed to dequantize MXFP4 down_proj: {e}" + ) + + @down_proj.setter + def down_proj(self, value): + """Set down_proj tensor (called during checkpoint loading).""" + self.__dict__["_down_proj"] = value + + def forward( + self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx + ) -> torch.Tensor: with torch_cuda_device(hidden_states.device): if not hasattr(self, "act"): self.act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) if not hidden_states.requires_grad: intermediate_cache1 = matmul_ogs( - hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware + hidden_states.to(torch.bfloat16), # tl.dot_scaled upcasts to BF16 for old hardware self.gate_up_proj, self.gate_up_proj_bias, routing_data, @@ -304,36 +424,41 @@ def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter scatter_idx, ) return intermediate_cache3 + pass + patch_function(transformers.integrations.mxfp4, "Mxfp4GptOssExperts", Mxfp4GptOssExperts) - try: - routing = triton_kernels.routing.routing - routing = torch.compiler.disable(routing) - except Exception as e: - return raise_error("triton_kernels.routing.routing", e) + if HAS_TRITON_KERNELS: + try: + routing = triton_kernels.routing.routing + routing = torch.compiler.disable(routing) + except Exception as e: + return raise_error("triton_kernels.routing.routing", e) - def mlp_forward(self, hidden_states): - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) - router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) + def mlp_forward(self, hidden_states): + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) + router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) - with torch_cuda_device(router_logits.device): - routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) + with torch_cuda_device(router_logits.device): + routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) - routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) - routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) - return routed_out, router_logits - patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) + routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) + return routed_out, router_logits - try: - PrecisionConfig, FlexCtx, InFlexData = ( - triton_kernels.matmul_ogs.PrecisionConfig, - triton_kernels.matmul_ogs.FlexCtx, - triton_kernels.matmul_ogs.InFlexData, - ) - except Exception as e: - return raise_error("triton_kernels.matmul_ogs", e) + patch_function(transformers.integrations.mxfp4, "mlp_forward", mlp_forward) + + if HAS_TRITON_KERNELS: + try: + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels.matmul_ogs.PrecisionConfig, + triton_kernels.matmul_ogs.FlexCtx, + triton_kernels.matmul_ogs.InFlexData, + ) + except Exception as e: + return raise_error("triton_kernels.matmul_ogs", e) try: from transformers.integrations.tensor_parallel import shard_and_distribute_module @@ -351,17 +476,22 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args for proj in ["gate_up_proj", "down_proj"]: if proj in param_name: if device_mesh is not None: - shard_and_distribute_module( - model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh - ) + shard_and_distribute_module(model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh) else: setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) blocks_attr = f"{proj}_blocks" scales_attr = f"{proj}_scales" blocks = getattr(module, blocks_attr) scales = getattr(module, scales_attr) - # Check if both blocks and scales both not on on meta device - if blocks.device.type != "meta" and scales.device.type != "meta": + # Check if both blocks and scales both not on meta device AND not all zeros + # (if blocks are all zeros, they're from initialization, not from checkpoint) + blocks_valid = ( + blocks.device.type != "meta" + and scales.device.type != "meta" + and blocks.numel() > 0 + and blocks.any() # At least some non-zero values + ) + if blocks_valid: # need it for ep local_experts = blocks.size(0) if proj == "gate_up_proj": @@ -402,6 +532,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, *args delattr(module, blocks_attr) # setattr(module, blocks_attr, torch.nn.Parameter(triton_weight_tensor.storage.data, requires_grad=False)) del blocks + pass patch_function(transformers.integrations.mxfp4, "load_and_swizzle_mxfp4", load_and_swizzle_mxfp4, match_level = "relaxed") @@ -418,17 +549,11 @@ def replace_with_mxfp4_linear( config=None, ): if quantization_config.dequantize: return model - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + modules_to_not_convert = (["lm_head"] if modules_to_not_convert is None else modules_to_not_convert) if quantization_config.modules_to_not_convert is not None: modules_to_not_convert.extend(quantization_config.modules_to_not_convert) modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_mxfp4_linear( - model, - modules_to_not_convert, - current_key_name, - quantization_config, - config=config, - ) + model, has_been_replaced = _replace_with_mxfp4_linear(model, modules_to_not_convert, current_key_name, quantization_config, config=config) if not has_been_replaced: logger.warning_once( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." @@ -437,18 +562,288 @@ def replace_with_mxfp4_linear( ) return model + patch_function(transformers.integrations.mxfp4, "replace_with_mxfp4_linear", replace_with_mxfp4_linear) pass TEMPORARY_PATCHES.append(patch_gpt_oss) +class ParameterModule(nn.Linear): + """ + A module that wraps a parameter to look like a Linear layer for PEFT. + It inherits from nn.Linear but manages 3D <-> 2D weight conversion. + Unsloth grouped_mm requires 3D weights: + - gate_up: (E, H, 2I) + - down: (E, I, H) + PEFT Linear requires 2D weights: (Out, In). + We store the weight as 2D for PEFT, and reshape for Unsloth via get_param(). + """ + + def __init__( + self, in_features, out_features, shape_3d, permute_to_2d, permute_to_3d + ): + # Initialize nn.Linear (creates weight of size (out, in)) + super().__init__(in_features, out_features, bias=False) + self.shape_3d = shape_3d + self.permute_to_2d = permute_to_2d + self.permute_to_3d = permute_to_3d + + # We expect the caller to set the weight content correctly. + # nn.Linear initialized it randomly. We will overwrite it. + + def extra_repr(self): + return f"in_features={self.in_features}, out_features={self.out_features}, shape_3d={self.shape_3d}" + + def get_param(self): + """Restores the 3D weight for Unsloth computation.""" + # 2D weight (Out, In) -> View (Unflattened 2D) -> Permute -> 3D(E, ...) + # We need to know the unflattened shape. + # gate_up: 2D (E*2I, H). View (E, 2I, H). Permute(0,2,1) -> (E, H, 2I). + # We store (E*2I, H). + # To restore: View (E, 2I, H)? + # (E, 2I, H) is shape_3d permuted by permute_to_2d. + + unflattened_shape = [self.shape_3d[i] for i in self.permute_to_2d] + return self.weight.view(*unflattened_shape).permute(*self.permute_to_3d).contiguous() + + def set_weight_from_3d(self, weight_3d): + """Sets the 2D weight from a 3D tensor.""" + # 3D -> Permute -> Flatten + weight_2d = weight_3d.permute(*self.permute_to_2d).reshape(self.out_features, self.in_features) + self.weight.data.copy_(weight_2d) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # 'gate_up_proj' in checkpoint is 3D. + # We need to load it into 'gate_up_proj.weight' (2D). + # key is '...gate_up_proj' + key = prefix[:-1] + + if key in state_dict: + # Found the parameter (likely from original model structure where it was a Param) + val = state_dict[key] + # Convert 3D val to 2D + val_2d = val.permute(*self.permute_to_2d).reshape(self.out_features, self.in_features) + + # Put into 'weight' key + state_dict[prefix + "weight"] = val_2d + del state_dict[key] + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + class GptOssExperts(nn.Module): + """ + GPT OSS MoE Experts layer with 3D stacked parameters. + Compatible with transformers' _init_weights and supports grouped_mm with split LoRA. + + Uses the same structure as the original transformers GptOssExperts: + - gate_up_proj: (num_experts, hidden_size, 2 * expert_dim) + - gate_up_proj_bias: (num_experts, 2 * expert_dim) + - down_proj: (num_experts, expert_dim, hidden_size) + - down_proj_bias: (num_experts, hidden_size) + """ + + def __init__(self, config): + super().__init__() + + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.expert_dim = config.intermediate_size + self.intermediate_size = config.intermediate_size # Alias for compatibility + self.alpha = 1.702 + self.limit = getattr(config, "swiglu_limit", 7.0) + self.dtype = dtype_from_config(config) + + # gate_up_proj: 3D (E, H, 2I). Target 2D (E*2I, H). + # Permute (0, 2, 1) -> (E, 2I, H). Reverse (0, 2, 1). + self.gate_up_proj = ParameterModule( + in_features=self.hidden_size, + out_features=self.num_experts * 2 * self.expert_dim, + shape_3d=(self.num_experts, self.hidden_size, 2 * self.expert_dim), + permute_to_2d=(0, 2, 1), + permute_to_3d=(0, 2, 1), + ) + # Initialize 3D zero tensor and set 2D weight + self.gate_up_proj.set_weight_from_3d( + torch.zeros( + self.num_experts, + self.hidden_size, + 2 * self.expert_dim, + dtype=self.dtype, + ) + ) + + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim, dtype=self.dtype)) + + # down_proj: 3D (E, I, H). Target 2D (H, E*I). + # Permute (2, 0, 1) -> (H, E, I). Reverse (1, 2, 0) + self.down_proj = ParameterModule( + in_features=self.num_experts * self.expert_dim, + out_features=self.hidden_size, + shape_3d=(self.num_experts, self.expert_dim, self.hidden_size), + permute_to_2d=(2, 0, 1), + permute_to_3d=(1, 2, 0), + ) + self.down_proj.set_weight_from_3d( + torch.empty( + self.num_experts, self.expert_dim, self.hidden_size, dtype=self.dtype + ) + ) + + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, dtype=self.dtype)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Override to handle loading 3D tensors (gate_up_proj, down_proj) from original checkpoints. + The original checkpoint has these as nn.Parameter (3D tensors), but we now have + them as ParameterModule (nn.Linear subclass) which expects .weight as 2D tensors. + + This method intercepts the 3D tensors and converts them to the 2D format + that ParameterModule expects. + """ + # Handle gate_up_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight + gate_up_key = prefix + "gate_up_proj" + gate_up_weight_key = prefix + "gate_up_proj.weight" + if gate_up_key in state_dict and gate_up_weight_key not in state_dict: + val_3d = state_dict.pop(gate_up_key) + # gate_up_proj: 3D (E, H, 2I) -> permute (0, 2, 1) -> (E, 2I, H) -> reshape (E*2I, H) + val_2d = val_3d.permute(0, 2, 1).reshape(self.num_experts * 2 * self.expert_dim, self.hidden_size) + state_dict[gate_up_weight_key] = val_2d + + # Handle down_proj: checkpoint has 3D tensor, we need 2D for ParameterModule.weight + down_key = prefix + "down_proj" + down_weight_key = prefix + "down_proj.weight" + if down_key in state_dict and down_weight_key not in state_dict: + val_3d = state_dict.pop(down_key) + # down_proj: 3D (E, I, H) -> permute (2, 0, 1) -> (H, E, I) -> reshape (H, E*I) + val_2d = val_3d.permute(2, 0, 1).reshape(self.hidden_size, self.num_experts * self.expert_dim) + state_dict[down_weight_key] = val_2d + + # Call parent implementation + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + """Forward using grouped_mm or loop fallback with LoRA support.""" + # Use optimized grouped_mm if available + if _check_torch_grouped_mm_supported(): + return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights) + + # Fallback to loop-based implementation + return torch_native_forward(self, hidden_states, router_indices, routing_weights) + + +pass + + +class _RouterLinearParams(nn.Module): + """ + Simple parameter container that stores weight/bias like nn.Linear + but is NOT nn.Linear itself, so BitsAndBytes will NOT quantize it. + State dict keys: linear.weight, linear.bias (matching BnB 4-bit checkpoints + where the router was saved via an nn.Linear submodule). + """ + def __init__(self, in_features, out_features, dtype): + super().__init__() + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + self.bias = nn.Parameter(torch.zeros(out_features, dtype=dtype)) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + + +class GptOssTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + # Use _RouterLinearParams (not nn.Linear) to avoid BnB 4-bit quantization. + # State dict keys are router.linear.weight / router.linear.bias, matching + # the BnB 4-bit checkpoint format where router was stored via nn.Linear. + self.linear = _RouterLinearParams(self.hidden_dim, self.num_experts, dtype=dtype_from_config(config)) + + # Properties for compatibility with transformers' _init_weights which expects .weight and .bias + @property + def weight(self): + return self.linear.weight + + @weight.setter + def weight(self, value): + self.linear.weight = value + + @property + def bias(self): + return self.linear.bias + + @bias.setter + def bias(self, value): + self.linear.bias = value + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits, dtype=router_logits.dtype).scatter_(1, router_indices, router_top_value) + if transformers_version >= Version("5.0.0"): + return router_logits, router_scores, router_indices + else: + return router_scores, router_indices + +pass + + +# BitsAndBytes 4bit compatible classes for loading pre-quantized models +class GptOssExpertsBnb4bit(nn.Module): + """ + GPT OSS MoE Experts using nn.Linear layers for BitsAndBytes 4bit compatibility. + This version uses gate_up_projs and down_projs as ModuleLists of Linear layers, + which can be properly quantized with BitsAndBytes. + """ + def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = config.intermediate_size + self.intermediate_size = config.intermediate_size self.alpha = 1.702 self.limit = getattr(config, "swiglu_limit", 7.0) self.dtype = dtype_from_config(config) @@ -462,16 +857,26 @@ def __init__(self, config): for _ in range(self.num_experts) ]) - def forward( - self, - hidden_states: torch.Tensor, - router_indices = None, - routing_weights = None - ) -> torch.Tensor: + # Provide minimal tensors for transformers _init_weights compatibility. + # Keep them empty to avoid allocating large bf16 weights in 4-bit mode. + self.register_buffer( + "gate_up_proj", torch.empty(0, dtype=self.dtype), persistent=False + ) + self.register_buffer( + "gate_up_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False + ) + self.register_buffer( + "down_proj", torch.empty(0, dtype=self.dtype), persistent=False + ) + self.register_buffer( + "down_proj_bias", torch.empty(0, dtype=self.dtype), persistent=False + ) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) num_experts = routing_weights.shape[1] + if self.training: next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) # with torch.no_grad(): @@ -512,30 +917,176 @@ def forward( rw = routing_weights.transpose(0, 1).unsqueeze(-1) mixed = (outs.to(torch.float32) * rw.to(torch.float32)).sum(dim=0) return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) + pass -class GptOssTopKRouter(nn.Module): + +class GptOssTopKRouterBnb4bit(nn.Module): + """ + GPT OSS Router using direct parameters for BitsAndBytes 4bit compatibility. + This version uses weight/bias as direct nn.Parameter instead of nested Linear. + """ + def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.linear = nn.Linear(self.hidden_dim, self.num_experts, dtype=dtype_from_config(config)) + self.dtype = dtype_from_config(config) + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, dtype=self.dtype)) + self.bias = nn.Parameter(torch.zeros(self.num_experts, dtype=self.dtype)) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # Accept checkpoints that used a Linear router (linear.weight/bias) + linear_weight_key = prefix + "linear.weight" + linear_bias_key = prefix + "linear.bias" + weight_key = prefix + "weight" + bias_key = prefix + "bias" + moved_weight = False + moved_bias = False + if linear_weight_key in state_dict and weight_key not in state_dict: + state_dict[weight_key] = state_dict.pop(linear_weight_key) + moved_weight = True + if linear_bias_key in state_dict and bias_key not in state_dict: + state_dict[bias_key] = state_dict.pop(linear_bias_key) + moved_bias = True + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + if moved_weight: + if linear_weight_key in unexpected_keys: + unexpected_keys.remove(linear_weight_key) + if weight_key in missing_keys: + missing_keys.remove(weight_key) + if moved_bias: + if linear_bias_key in unexpected_keys: + unexpected_keys.remove(linear_bias_key) + if bias_key in missing_keys: + missing_keys.remove(bias_key) + if self.weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + self.weight.data = self.weight.data.to(self.dtype) + if self.bias is not None and self.bias.dtype not in ( + torch.float16, torch.bfloat16, torch.float32 + ): + self.bias.data = self.bias.data.to(self.dtype) - @torch_compile(dynamic = True, fullgraph = True) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.linear(hidden_states.to(self.linear.weight.dtype)) # (batch_size * seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_logits = torch.nn.functional.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) dtype = torch.float32 if router_logits.dtype == torch.float16 else router_logits.dtype router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) - router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = torch.zeros_like(router_logits, dtype=dtype).scatter_(1, router_indices, router_top_value) + if transformers_version >= Version("5.0.0"): + return router_logits, router_scores, router_indices + else: + return router_scores, router_indices + + pass +def patch_gpt_oss_bnb4bit(): + """ + Patch transformers to use BnB 4bit compatible classes when loading pre-quantized models. + This should be called before loading models that were saved with BitsAndBytes quantization + using the linear-based expert structure. + + Usage: + from unsloth_zoo.temporary_patches.gpt_oss import patch_gpt_oss_bnb4bit + patch_gpt_oss_bnb4bit() # Call before loading the model + model = FastLanguageModel.from_pretrained(...) + """ + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + # Store original classes for potential restoration + if not hasattr(transformers.models.gpt_oss.modeling_gpt_oss, '_original_GptOssExperts'): + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssExperts = \ + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssTopKRouter = \ + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter + + # Replace with BnB 4bit compatible versions + # Preserve original symbol names for compiler-generated modules. + GptOssExpertsBnb4bit.__name__ = "GptOssExperts" + GptOssExpertsBnb4bit.__qualname__ = "GptOssExperts" + + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExpertsBnb4bit + # Use the unsloth GptOssTopKRouter (with self.linear = nn.Linear) for the router. + # The BnB 4-bit checkpoint stores router weights as router.linear.weight/bias. + # GptOssTopKRouterBnb4bit had self.weight/bias directly with a _load_from_state_dict + # override to remap keys, but transformers v5 bypasses _load_from_state_dict + # (uses accelerate's set_module_tensor_to_device), so the remapping never ran + # and router weights were randomly initialized - causing high loss (~4-5). + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter + + logger.info("Unsloth: Patched GPT OSS with BitsAndBytes 4bit compatible classes") + os.environ["UNSLOTH_GPT_OSS_BNB4BIT_PATCHED"] = "1" + return True + + +pass + + +def restore_gpt_oss_original(): + """ + Restore original GPT-OSS classes (undo BnB 4bit patch). + """ + try: + import transformers.models.gpt_oss.modeling_gpt_oss + if hasattr(transformers.models.gpt_oss.modeling_gpt_oss, '_original_GptOssExperts'): + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = \ + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssExperts + transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = \ + transformers.models.gpt_oss.modeling_gpt_oss._original_GptOssTopKRouter + logger.info("Unsloth: Restored original GPT OSS classes") + return True + except Exception: + pass + return False + +def patch_gpt_oss_bnb4bit_auto(): + """ + Auto-patch GPT-OSS for BnB 4-bit when load_in_4bit is active. + Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to opt out. + """ + if not _should_use_gpt_oss_bnb4bit(): + return + # Avoid compiler-generated modules missing BnB helpers + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + patch_gpt_oss_bnb4bit() + # Ensure inference path avoids torch.compile for 4-bit + try: + global moe_forward_inference + moe_forward_inference = torch.compiler.disable(moe_forward_inference) + except Exception: + pass + + +TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) + + # Combo kernels uses too much VRAM for low memory GPUs from ..device_type import DEVICE_TYPE + if DEVICE_TYPE == "xpu": device_memory = torch.xpu.memory.mem_get_info(0)[-1] else: @@ -566,11 +1117,16 @@ def forward(self, hidden_states): logging = UNSLOTH_ENABLE_LOGGING, ) -@_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) + +@_torch_compile(dynamic=None, fullgraph=True, options=fused_torch_compile_options) def moe_forward_inference(self, hidden_states): """Torch compile for forward inference path only with CUDAGraphs""" # Router - router_scores, router_indices = self.router(hidden_states) + router_out = self.router(hidden_states) + if isinstance(router_out, tuple) and len(router_out) == 3: + _, router_scores, router_indices = router_out + else: + router_scores, router_indices = router_out routing_weights = router_scores moe = self.experts batch_size = hidden_states.shape[0] @@ -579,25 +1135,41 @@ def moe_forward_inference(self, hidden_states): num_experts = routing_weights.shape[1] X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - # Gate up projection - gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] - gate_up = torch.stack(gate_up_list, dim = 0) - dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype - fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype = dtype) - - # Down projection must be done in float32 if not bfloat16 otherwise infinites - fused = fused.to(dtype) - device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out_list = [down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs)] - outs = torch.stack(out_list, dim=0) + # Check if using ModuleList (old style) or 3D parameters (new style) + if hasattr(moe, "gate_up_projs"): + # ModuleList style + gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(moe.gate_up_projs)] + gate_up = torch.stack(gate_up_list, dim=0) + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype + fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) + + fused = fused.to(dtype) + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + out_list = [down_l(fused[e].to(dtype)) for e, down_l in enumerate(moe.down_projs)] + outs = torch.stack(out_list, dim=0) + else: + # 3D parameter style (compatible with transformers) + # gate_up_proj: (E, hidden_size, 2*expert_dim) - bmm: (E, N, H) @ (E, H, 2I) -> (E, N, 2I) + gate_up = torch.bmm(X_rep, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype + fused = swiglu_torch_forward(gate_up, moe.alpha, moe.limit, dtype=dtype) + + fused = fused.to(dtype) + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + # down_proj: (E, expert_dim, hidden_size) - bmm: (E, N, I) @ (E, I, H) -> (E, N, H) + outs = torch.bmm(fused.to(dtype), moe.down_proj) + moe.down_proj_bias[..., None, :] rw = routing_weights.to(dtype).transpose(0, 1).unsqueeze(-1) mixed = (outs * rw).sum(dim=0) return mixed.view(batch_size, -1, moe.hidden_size).to(hidden_states.dtype) + + pass -@torch_compile(dynamic = True, fullgraph = True) + +@torch_compile(dynamic=True, fullgraph=True) def moe_router_forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, self.bias) # (seq_len, num_experts) @@ -606,32 +1178,84 @@ def moe_router_forward(self, hidden_states): router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=torch.float32).to(dtype) router_scores = torch.zeros_like(router_logits, dtype = dtype).scatter_(1, router_indices, router_top_value) return router_scores, router_indices + + pass -# Combo Kernels errors with InductorError: AttributeError: 'NullKernelHandler' object has no attribute 'index_to_str' -@_torch_compile(dynamic = None, fullgraph = True, options = no_combo_fused_torch_compile_options) -def moe_forward_inference_bf16(self, hidden_states): - router_scores, router_indices = moe_router_forward(self.router, hidden_states) - routing_weights = router_scores - moe = self.experts +# Combo Kernels errors with InductorError: AttributeError: 'NullKernelHandler' object has no attribute 'index_to_str' +@_torch_compile( + dynamic=None, fullgraph=True, options=no_combo_fused_torch_compile_options +) +def _moe_forward_inference_bf16_kernel( + hidden_states, routing_weights, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias, limit, alpha, hidden_size +): + """Inner compiled kernel for BF16 MoE inference - works with raw tensors only.""" batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, moe.hidden_size) + hidden_states = hidden_states.reshape(-1, hidden_size) num_experts = routing_weights.shape[1] hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, moe.hidden_size) - gate_up = torch.bmm(hidden_states, moe.gate_up_proj) + moe.gate_up_proj_bias[..., None, :] + hidden_states = hidden_states.view(num_experts, -1, hidden_size) + + gate_up = ( + torch.bmm(hidden_states, gate_up_proj) + gate_up_proj_bias[..., None, :] + ) gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=moe.limit) - up = up.clamp(min=-moe.limit, max=moe.limit) - glu = gate * torch.sigmoid(gate.to(torch.float32) * moe.alpha).to(gate.dtype) - next_states = torch.bmm(((up + 1) * glu), moe.down_proj) - next_states = next_states + moe.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, moe.hidden_size) - next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate.to(torch.float32) * alpha).to(gate.dtype) + next_states = torch.bmm(((up + 1) * glu), down_proj) + next_states = next_states + down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, hidden_size) + next_states = ( + next_states + * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + ) next_states = next_states.sum(dim=0) return next_states -pass + + +def _unwrap_peft_experts(module): + """Unwrap PEFT ParamWrapper chain to get the actual experts module.""" + while hasattr(module, 'base_layer'): + module = module.base_layer + return module + + +def moe_forward_inference_bf16(self, hidden_states): + """Wrapper that extracts weights from ParameterModule before calling the compiled kernel.""" + router_scores, router_indices = moe_router_forward(self.router, hidden_states) + routing_weights = router_scores + + moe = _unwrap_peft_experts(self.experts) + + # Handle ParameterModule (which wraps nn.Linear) vs direct 3D tensor + # Extract weights BEFORE the compiled region + gate_up_proj = moe.gate_up_proj + if hasattr(gate_up_proj, "get_param"): + gate_up_proj = gate_up_proj.get_param() + elif hasattr(gate_up_proj, "weight"): + gate_up_proj = gate_up_proj.weight + + down_proj = moe.down_proj + if hasattr(down_proj, "get_param"): + down_proj = down_proj.get_param() + elif hasattr(down_proj, "weight"): + down_proj = down_proj.weight + + return _moe_forward_inference_bf16_kernel( + hidden_states, + routing_weights, + gate_up_proj, + moe.gate_up_proj_bias, + down_proj, + moe.down_proj_bias, + moe.limit, + moe.alpha, + moe.hidden_size, + ) + + class GptOssMLP(nn.Module): @@ -644,94 +1268,531 @@ def forward(self, hidden_states): bsz, qlen, hd = hidden_states.shape if qlen == 1 and not self.training: return moe_forward_inference(self, hidden_states), None - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) + router_out = self.router(hidden_states) + if isinstance(router_out, tuple) and len(router_out) == 3: + _, router_scores, router_indices = router_out + else: + router_scores, router_indices = router_out routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) return routed_out, router_scores + + +pass + + +# ============================================================================ +# GPT OSS MoE LoRA Support using grouped GEMM kernels +# ============================================================================ + +# IMPORTS FROM MOE UTILS +from .moe_utils import ( + _check_grouped_gemm_available, + _TORCH_GROUPED_MM_AVAILABLE, + _check_torch_grouped_mm_supported, + native_moe_grouped_mm, + _get_moe_lora_weights, + _apply_lora_grouped_mm, + _get_lora_wrapper_for_param, + select_moe_backend, + patch_param_wrapper_for_moe, + forward_native_grouped_mm, + # torch_native_forward, +) + + +def _should_use_gpt_oss_bnb4bit() -> bool: + """ + Decide if GPT-OSS should use BnB-compatible 4-bit experts. + Default: True when load_in_4bit is active. + Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return False + if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return False + return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" + + +def _is_gpt_oss_4bit_load() -> bool: + return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") + + +def _is_transformers_v5() -> bool: + return transformers_version >= Version("5.0.0.dev0") + + +def patch_gpt_oss_moe_for_lora(): + """ + Patch GPT OSS MoE experts for LoRA training with grouped GEMM support. + This patches the original GptOssExperts class (with 3D parameter tensors) + to use optimized grouped GEMM kernels with LoRA support. + + IMPORTANT: We only patch the forward method, NOT replace the entire class. + This preserves the original class structure so weights load correctly. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): + # 4-bit loads should keep quantized weights and use default PEFT LoRA. + return + if not _is_transformers_v5(): + # Split-LoRA grouped_mm path is only needed for transformers v5+ + return + # First patch ParamWrapper for MoE separated LoRA + patch_param_wrapper_for_moe() + + try: + import transformers.models.gpt_oss.modeling_gpt_oss + + # Get the ORIGINAL class - don't replace it! + GptOssExpertsClass = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch GPT OSS MoE for LoRA: {e}") + return + + # Check if already patched + if hasattr(GptOssExpertsClass, "_unsloth_lora_patched"): + return + + # Select backend + backend = select_moe_backend() + + if backend == "grouped_mm": + forward = forward_native_grouped_mm + else: + forward = torch_native_forward + + # Store original forward and patch - but DON'T replace the class! + GptOssExpertsClass._original_forward = GptOssExpertsClass.forward + GptOssExpertsClass.forward = forward + GptOssExpertsClass._unsloth_lora_patched = True + + if UNSLOTH_ENABLE_LOGGING: + backend_desc = { + "grouped_mm": "torch._grouped_mm (batched, fastest)", + "unsloth_triton": "Triton kernels", + "native_torch": "loop fallback (slower)", + }.get(backend, backend) + logger.info( + f"Unsloth: Patched GPT OSS MoE for LoRA training using {backend_desc}" + ) + + +TEMPORARY_PATCHES.append(patch_gpt_oss_moe_for_lora) + + +# ============================================================================ +# MXFP4 (4-bit) GPT OSS MoE LoRA Support +# ============================================================================ + +_MXFP4_LORA_PATH_LOGGED = False + +@torch.compiler.disable +def forward_mxfp4_gpt_oss_with_lora( + self, + hidden_states: torch.Tensor, + routing_data, + gather_idx, + scatter_idx, +) -> torch.Tensor: + """ + MXFP4 GPT OSS MoE forward pass with LoRA support. + + For LoRA training with MXFP4: + - Base MXFP4 matmul is computed without gradients (frozen weights) + - LoRA delta is computed with gradients + - Output = base_output + lora_delta + + This allows finetuning MXFP4 quantized models with LoRA. + + Requires triton_kernels for native MXFP4 matmul. + If triton_kernels is not available, model should be loaded with + Mxfp4Config(dequantize=True) to convert to bf16. + """ + if not is_triton_kernels_available(): + raise RuntimeError( + "triton_kernels is required for native MXFP4 GPT OSS forward pass. " + "Either:\n" + " 1. Install triton_kernels from OpenAI, OR\n" + " 2. Load model with dequantization: Mxfp4Config(dequantize=True), OR\n" + " 3. Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'\n" + "Set UNSLOTH_MXFP4_NO_DEQUANTIZE=0 (default) to auto-dequantize." + ) + + from triton_kernels import matmul_ogs, swiglu + + matmul_ogs_fn = matmul_ogs.matmul_ogs + FnSpecs = matmul_ogs.FnSpecs + FusedActivation = matmul_ogs.FusedActivation + swiglu_fn = swiglu.swiglu_fn + + # Get LoRA wrappers + gate_up_wrapper = _get_lora_wrapper_for_param(self, "gate_up_proj") + down_wrapper = _get_lora_wrapper_for_param(self, "down_proj") + + has_lora = gate_up_wrapper is not None or down_wrapper is not None + + # Log path info once + global _MXFP4_LORA_PATH_LOGGED + if not _MXFP4_LORA_PATH_LOGGED: + _MXFP4_LORA_PATH_LOGGED = True + logger.warning_once( + f"Unsloth: GPT-OSS MoE training path: MXFP4 + triton_kernels. " + f"LoRA={has_lora}, experts={self.num_experts}. " + f"Tip: Increase batch_size for better GPU utilization." + ) + + # If no LoRA, use the original MXFP4 forward (inference path) + if not has_lora: + with torch_cuda_device(hidden_states.device): + if not hasattr(self, "act"): + self.act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (self.alpha, self.limit), + 2, + ) + intermediate_cache1 = matmul_ogs_fn( + hidden_states.to(torch.bfloat16), + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=self.act, + ) + intermediate_cache3 = matmul_ogs_fn( + intermediate_cache1, + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=routing_data.gate_scal if routing_data else None, + ) + return intermediate_cache3 + + # With LoRA: compute base MXFP4 output (no grad) + LoRA delta (with grad) + with torch_cuda_device(hidden_states.device): + # 1. Compute MXFP4 base output (detached, no gradients for base weights) + with torch.no_grad(): + if not hasattr(self, "act"): + self.act = FusedActivation( + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (self.alpha, self.limit), + 2, + ) + + # Gate-up projection (MXFP4) + pre_activation = matmul_ogs_fn( + hidden_states.to(torch.bfloat16), + self.gate_up_proj, + self.gate_up_proj_bias, + routing_data, + gather_indx=gather_idx, + precision_config=self.gate_up_proj_precision_config, + gammas=None, + fused_activation=None, # Don't fuse activation so we can add LoRA before + ) + + # 2. Add LoRA for gate_up_proj using grouped_mm + if gate_up_wrapper is not None: + lora_data = _get_moe_lora_weights(gate_up_wrapper) + if lora_data is not None: + lora_A, lora_B, scaling, num_experts = lora_data + + # Convert triton_kernels routing format to grouped_mm offsets + # routing_data.exp_indx contains expert index for each token + gather_src = gather_idx.src_indx + permuted_input = hidden_states[gather_src].to(torch.bfloat16) + + # Compute offsets from expert indices (cumsum of token counts) + expert_ids = routing_data.exp_indx + num_tokens_per_expert = torch.bincount( + expert_ids.int(), minlength=num_experts + ).int() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + + # Use grouped_mm for LoRA computation (much faster than loop!) + lora_delta = _apply_lora_grouped_mm( + permuted_input, + lora_A, + lora_B, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, + ) + pre_activation = pre_activation + lora_delta + + # 3. Apply SwiGLU activation + alpha = getattr(self, "alpha", 1.702) + limit = getattr(self, "limit", 7.0) + swiglu_output = swiglu_torch_forward(pre_activation, alpha, limit) + + # 4. Down projection (MXFP4) + with torch.no_grad(): + base_output = matmul_ogs_fn( + swiglu_output.detach(), # Detach to prevent grad through MXFP4 + self.down_proj, + self.down_proj_bias, + routing_data, + scatter_indx=scatter_idx, + precision_config=self.down_proj_precision_config, + gammas=routing_data.gate_scal if routing_data else None, + ) + + # 5. Add LoRA for down_proj using grouped_mm + if down_wrapper is not None: + lora_data = _get_moe_lora_weights(down_wrapper) + if lora_data is not None: + lora_A, lora_B, scaling, num_experts = lora_data + + # Compute offsets from expert indices (reuse if available) + expert_ids = routing_data.exp_indx + num_tokens_per_expert = torch.bincount( + expert_ids.int(), minlength=num_experts + ).int() + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + + # Use grouped_mm for LoRA computation + lora_delta = _apply_lora_grouped_mm( + swiglu_output, + lora_A, + lora_B, + offsets, + scaling, + grouped_mm_func=native_moe_grouped_mm, + ) + + # Scatter LoRA delta back to output (apply routing weights) + scatter_dst = scatter_idx.dst_indx + gamma = routing_data.gate_scal.unsqueeze(-1) + base_output.index_add_( + 0, scatter_dst, (lora_delta * gamma).to(base_output.dtype) + ) + + return base_output + + +def patch_mxfp4_gpt_oss_for_lora(): + """ + Patch MXFP4 GPT OSS experts for LoRA training. + + This enables finetuning the MXFP4 quantized model (unsloth/gpt-oss-20b) + with LoRA. + + IMPORTANT: Requires triton_kernels for native MXFP4 matmul. + If triton_kernels is not available, users must either: + - Use Mxfp4Config(dequantize=True) when loading, OR + - Use the BF16 model: 'unsloth/gpt-oss-20b-BF16' + """ + # First patch ParamWrapper for MoE separated LoRA (v5 only) + if _is_transformers_v5(): + patch_param_wrapper_for_moe() + + try: + import transformers.integrations.mxfp4 + + Mxfp4GptOssExpertsClass = getattr( + transformers.integrations.mxfp4, "Mxfp4GptOssExperts", None + ) + if Mxfp4GptOssExpertsClass is None: + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + "Unsloth: Mxfp4GptOssExperts not found in transformers.integrations.mxfp4" + ) + return + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch MXFP4 GPT OSS for LoRA: {e}") + return + + # Check if already patched + if hasattr(Mxfp4GptOssExpertsClass, "_unsloth_mxfp4_lora_patched"): + return + + # Only patch if triton_kernels is available + # Without triton_kernels, MXFP4 weights cannot be used directly + # Users must use dequantization or BF16 model + if is_triton_kernels_available(): + # Use native MXFP4 + LoRA (keeps weights quantized) + Mxfp4GptOssExpertsClass._original_forward = Mxfp4GptOssExpertsClass.forward + Mxfp4GptOssExpertsClass.forward = forward_mxfp4_gpt_oss_with_lora + Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched MXFP4 GPT OSS MoE for LoRA training") + else: + # triton_kernels NOT available - do NOT patch + # The model will fail with a helpful error if user tries to use MXFP4 without dequantization + Mxfp4GptOssExpertsClass._unsloth_mxfp4_lora_patched = True + if UNSLOTH_ENABLE_LOGGING: + logger.warning( + "Unsloth: triton_kernels is not installed. MXFP4 GPT OSS will NOT be patched for LoRA.\n" + "To train GPT OSS with LoRA, either:\n" + " 1. Install triton_kernels from OpenAI (for native MXFP4), OR\n" + " 2. Use Mxfp4Config(dequantize=True) when loading (dequantizes to bf16), OR\n" + " 3. Use the BF16 model: 'unsloth/gpt-oss-20b-BF16'" + ) + + +TEMPORARY_PATCHES.append(patch_mxfp4_gpt_oss_for_lora) + + +_MXFP4_DEQUANT_WARNED = False + + +def should_dequantize_mxfp4(): + """ + Check if MXFP4 should be dequantized to bf16. + + Returns True if: + - UNSLOTH_MXFP4_NO_DEQUANTIZE is not set or "0", OR + - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" but triton_kernels is not available + + Returns False if: + - UNSLOTH_MXFP4_NO_DEQUANTIZE="1" AND triton_kernels is available + + MEMORY IMPACT: + - MXFP4 quantized: ~10GB for GPT-OSS 20B + - MXFP4 dequantized to bf16: ~40GB for GPT-OSS 20B + + To keep MXFP4 quantized (requires triton_kernels): + export UNSLOTH_MXFP4_NO_DEQUANTIZE=1 + """ + global _MXFP4_DEQUANT_WARNED + + if not UNSLOTH_MXFP4_NO_DEQUANTIZE: + # Default: dequantize to bf16 + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.warning( + "Unsloth: MXFP4 will be dequantized to bf16 (~4x memory increase). " + "To keep 4-bit: set UNSLOTH_MXFP4_NO_DEQUANTIZE=1 and install triton_kernels." + ) + return True + + if not is_triton_kernels_available(): + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.warning( + "Unsloth: UNSLOTH_MXFP4_NO_DEQUANTIZE=1 but triton_kernels not available. " + "Will dequantize MXFP4 to bf16 (~4x memory increase). " + "Install triton_kernels to keep 4-bit quantized weights." + ) + return True # triton_kernels required for native MXFP4 + + if UNSLOTH_ENABLE_LOGGING and not _MXFP4_DEQUANT_WARNED: + _MXFP4_DEQUANT_WARNED = True + logger.info("Unsloth: Keeping MXFP4 quantized (triton_kernels available)") + return False # Keep MXFP4 quantized + + +def torch_native_forward( + self, + hidden_states: torch.Tensor, + router_indices = None, + routing_weights = None +) -> torch.Tensor: + + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + num_experts = routing_weights.shape[1] + if self.training: + next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) + # with torch.no_grad(): + # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + # expert_mask = expert_mask.permute(2, 1, 0) + # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + # for expert_idx in expert_hitted[:]: + for expert_idx in range(num_experts): + with torch.no_grad(): + # _, token_idx = torch.where(expert_mask[expert_idx[0]]) + token_idx, _ = torch.where(router_indices == expert_idx) + current_state = hidden_states[token_idx] + gate_up = self.gate_up_projs[expert_idx](current_state) + down_proj = self.down_projs[expert_idx] + gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = torch.float32) + # gate, up = gate_up[..., ::2], gate_up[..., 1::2] + # gate = gate.clamp(min=None, max=self.limit) + # up = up.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # gated_output = (up + 1) * glu + + # Force float32 matrix multiply on some down projection modules + gated_output = gated_output.to(torch.float32) + device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + out = down_proj(gated_output) + weighted_output = out.to(torch.float32) * routing_weights[token_idx, expert_idx, None].to(torch.float32) + next_states.index_add_(0, token_idx, weighted_output) + next_states = next_states.view(batch_size, -1, self.hidden_size) + return next_states.to(torch.float32) + else: + X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) + gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] + gate_up = torch.stack(gate_up_list, dim=0) + dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype + fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = dtype) + # gate = gate_up[..., ::2] + # up_h = gate_up[..., 1::2] + # gate = gate.clamp(max=self.limit) + # up_h = up_h.clamp(min=-self.limit, max=self.limit) + # glu = gate * torch.sigmoid(gate * self.alpha) + # fused = (up_h + 1) * glu + + # Force float32 matrix multiply on down projection only + device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + out_list = [ + down_l(fused[e].to(dtype)) + for e, down_l in enumerate(self.down_projs) + ] + outs = torch.stack(out_list, dim=0) + rw = routing_weights.transpose(0, 1).unsqueeze(-1) + mixed = (outs.to(dtype) * rw.to(dtype)).sum(dim=0) + return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) + pass pass def patch_gpt_oss_linearized(): + """ + Patch GPT OSS for 4bit loading with grouped_mm support. + Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. + """ if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if _should_use_gpt_oss_bnb4bit(): return try: import transformers.models.gpt_oss.modeling_gpt_oss except Exception as e: return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) - # We find down_proj overflows in GPT OSS - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - def forward( - self, - hidden_states: torch.Tensor, - router_indices = None, - routing_weights = None + # Don't replace classes - just patch the forward method of GptOssExperts + # This keeps the original class structure which properly handles 4-bit weight loading + backend = select_moe_backend() + + if backend == "grouped_mm": + + def experts_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None ) -> torch.Tensor: + return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights) + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward + else: - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - if self.training: - next_states = torch.zeros_like(hidden_states, dtype=torch.float32, device=hidden_states.device) - # with torch.no_grad(): - # expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - # expert_mask = expert_mask.permute(2, 1, 0) - # expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - # for expert_idx in expert_hitted[:]: - for expert_idx in range(num_experts): - with torch.no_grad(): - # _, token_idx = torch.where(expert_mask[expert_idx[0]]) - token_idx, _ = torch.where(router_indices == expert_idx) - current_state = hidden_states[token_idx] - gate_up = self.gate_up_projs[expert_idx](current_state) - down_proj = self.down_projs[expert_idx] - gated_output = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = torch.float32) - # gate, up = gate_up[..., ::2], gate_up[..., 1::2] - # gate = gate.clamp(min=None, max=self.limit) - # up = up.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # gated_output = (up + 1) * glu - - # Force float32 matrix multiply on some down projection modules - gated_output = gated_output.to(torch.float32) - device_type = gated_output.device.type if isinstance(gated_output.device.type, str) and gated_output.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out = down_proj(gated_output) - weighted_output = out.to(torch.float32) * routing_weights[token_idx, expert_idx, None].to(torch.float32) - next_states.index_add_(0, token_idx, weighted_output) - next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states.to(torch.float32) - else: - X_rep = hidden_states.unsqueeze(0).expand(num_experts, -1, -1) - gate_up_list = [up_l(X_rep[e]) for e, up_l in enumerate(self.gate_up_projs)] - gate_up = torch.stack(gate_up_list, dim=0) - dtype = torch.float32 if hidden_states.dtype != torch.bfloat16 else hidden_states.dtype - fused = swiglu_torch_forward(gate_up, self.alpha, self.limit, dtype = dtype) - # gate = gate_up[..., ::2] - # up_h = gate_up[..., 1::2] - # gate = gate.clamp(max=self.limit) - # up_h = up_h.clamp(min=-self.limit, max=self.limit) - # glu = gate * torch.sigmoid(gate * self.alpha) - # fused = (up_h + 1) * glu - - # Force float32 matrix multiply on down projection only - device_type = fused.device.type if isinstance(fused.device.type, str) and fused.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - out_list = [ - down_l(fused[e].to(dtype)) - for e, down_l in enumerate(self.down_projs) - ] - outs = torch.stack(out_list, dim=0) - rw = routing_weights.transpose(0, 1).unsqueeze(-1) - mixed = (outs.to(dtype) * rw.to(dtype)).sum(dim=0) - return mixed.view(batch_size, -1, self.hidden_size).to(hidden_states.dtype) - pass - pass - GptOssExperts.forward = forward - pass + def experts_forward( + self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None + ) -> torch.Tensor: + return torch_native_forward(self, hidden_states, router_indices, routing_weights) - transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GptOssExperts - transformers.models.gpt_oss.modeling_gpt_oss.GptOssTopKRouter = GptOssTopKRouter - transformers.models.gpt_oss.modeling_gpt_oss.GptOssMLP = GptOssMLP + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts.forward = experts_forward + + if UNSLOTH_ENABLE_LOGGING: logger.info(f"Unsloth: Patched GPT OSS MoE for 4bit loading (backend: {backend})") return + + pass TEMPORARY_PATCHES.append(patch_gpt_oss_linearized) @@ -746,6 +1807,7 @@ def patch_GptOssAttention(): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) + assert flex_attention_with_sink is not None except Exception as e: return raise_error("flex_attention_with_sink", e) @@ -757,6 +1819,7 @@ def patch_GptOssAttention(): return raise_error("transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention", e) torch._dynamo.config.cache_size_limit = 256 + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -786,7 +1849,8 @@ def inplace_eager_attention_forward( bsz, n_heads, qlen, _ = query.shape bsz, n_heads, kvlen, _ = key_states.shape - combined_logits = key_states.new_empty((bsz, n_heads, qlen, kvlen+1)) + out_dtype = torch.result_type(query, key_states) + combined_logits = key_states.new_empty((bsz, n_heads, qlen, kvlen + 1), dtype=out_dtype) attn_weights = matmul(query, key_states.transpose(2, 3), out = combined_logits[:,:,:,:kvlen]) attn_weights *= scaling @@ -805,9 +1869,11 @@ def inplace_eager_attention_forward( probs = combined_logits scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) + attn_weights = attn_weights.to(value_states.dtype) attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None + pass def eager_attention_forward( @@ -838,14 +1904,16 @@ def eager_attention_forward( probs = combined_logits scores = probs[..., :-1] # we drop the sink here attn_weights = F_dropout(scores, p=dropout, training=module.training, inplace=True) + attn_weights = attn_weights.to(value_states.dtype) attn_output = matmul(attn_weights, value_states, out = query) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None + pass apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) - if False: # Version(torch.__version__) >= Version("2.10.0"): - eager_attention_forward = torch_compile(eager_attention_forward, dynamic = None, fullgraph = True) + if False: # Version(torch.__version__) >= Version("2.10.0"): + eager_attention_forward = torch_compile(eager_attention_forward, dynamic=None, fullgraph=True) else: # Too many recompilation failures on 2.8.0, 2.9.0 eager_attention_forward = inplace_eager_attention_forward @@ -867,8 +1935,22 @@ def forward_function( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: + cache_dtype = getattr(past_key_value, "dtype", None) + if cache_dtype is None and hasattr(past_key_value, "layers"): + try: + cache_layer = past_key_value.layers[self.layer_idx] + if hasattr(cache_layer, "keys") and cache_layer.keys is not None: + cache_dtype = cache_layer.keys.dtype + except Exception: + cache_dtype = None + if cache_dtype is not None and key_states.dtype != cache_dtype: + key_states = key_states.to(cache_dtype) + value_states = value_states.to(cache_dtype) cache_kwargs = {"cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if key_states.dtype != query_states.dtype or value_states.dtype != query_states.dtype: + key_states = key_states.to(query_states.dtype) + value_states = value_states.to(query_states.dtype) # flex_attention_with_sink only works for training since KV cache is wrong # switch to flex_attention_with_sink which allows all to work @@ -923,6 +2005,7 @@ def forward_function( pass functions = [] + def forward( self, hidden_states: torch.Tensor, @@ -933,7 +2016,9 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs) + functions.append(forward) + def forward( self, hidden_states: torch.Tensor, @@ -944,6 +2029,7 @@ def forward( **kwargs: KWARGS_TYPE, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: return forward_function(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs) + functions.append(forward) patch_function_past_key_values(transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention, "forward", functions) # Set env variable for padding purposes @@ -1018,6 +2104,7 @@ def return_attention_mask(*args, **kwargs): flex_attention_with_sink_decoding, flex_attention_add_sinks, ) + apply_rotary_pos_emb = torch_compile(apply_rotary_pos_emb) try: from transformers.integrations.mxfp4 import mlp_forward @@ -1045,6 +2132,7 @@ def pre_attention_decoding( key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) return query_states, key_states, value_states, input_shape pass + # Do flex_attention_with_sink_decoding with cannot be compiled # attn_output, logsumexp = flex_attention_with_sink_decoding( # self, @@ -1057,6 +2145,7 @@ def post_attention_decoding(self_attn, attn_output, logsumexp, input_shape): attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self_attn.o_proj(attn_output) return attn_output + pass # RMSNorm forward @@ -1108,6 +2197,7 @@ def pre_forward( use_block_ptr = True, logging = UNSLOTH_ENABLE_LOGGING, ) + @_torch_compile(dynamic = None, fullgraph = True, options = fused_torch_compile_options) def post_forward( self, @@ -1180,15 +2270,13 @@ def forward( if inputs_embeds is None: # Account for CPU offloaded embed_tokens embed_device = self.embed_tokens.weight.device - inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking = True)).to(input_ids.device) + inputs_embeds = self.embed_tokens(input_ids.to(embed_device, non_blocking=True)).to(input_ids.device) if not self.training: inputs_embeds.requires_grad_(False) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + past_seen_tokens = (past_key_values.get_seq_length() if past_key_values is not None else 0) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) hidden_states = inputs_embeds @@ -1221,6 +2309,8 @@ def forward( # Add hack since residuals need to clone outside of the torch.compile region?? # This forces it to free past residuals torch.compiler.cudagraph_mark_step_begin() + # Initialize for common return path + all_hidden_states = None for decoder_layer in self.layers: hidden_states, residual = inference_forward( decoder_layer, @@ -1233,9 +2323,14 @@ def forward( position_embeddings, **kwargs, ) - if hasattr(decoder_layer.mlp.experts, "gate_up_projs"): - hidden_states = moe_forward_inference(decoder_layer.mlp, hidden_states) - elif decoder_layer.mlp.experts.__class__.__name__ == "Mxfp4GptOssExperts": + _actual_experts = _unwrap_peft_experts(decoder_layer.mlp.experts) + if hasattr(_actual_experts, "gate_up_projs"): + hidden_states = moe_forward_inference( + decoder_layer.mlp, hidden_states + ) + elif ( + _actual_experts.__class__.__name__ == "Mxfp4GptOssExperts" + ): if mlp_forward is None: raise RuntimeError("Unsloth: MXFP4 forward is not found") hidden_states, _ = mlp_forward(decoder_layer.mlp, hidden_states) @@ -1245,8 +2340,36 @@ def forward( pass hidden_states = rms_layernorm_forward(self.norm, hidden_states) else: + # Fix 2D attention_mask being passed to eager_attention_forward which expects 4D + if ( + self.training + and attention_mask is not None + and attention_mask.dim() == 2 + ): + bsz, seq_len = attention_mask.shape + min_dtype = torch.finfo(inputs_embeds.dtype).min + # 1. Expand padding mask to (B, 1, 1, S) + expanded_mask = attention_mask[:, None, None, :].to( + dtype=inputs_embeds.dtype + ) + expanded_mask = (1.0 - expanded_mask) * min_dtype + + # 2. Causal mask (1, 1, S, S) + causal_mask = torch.full((seq_len, seq_len), min_dtype, device=attention_mask.device, dtype=inputs_embeds.dtype) + causal_mask = torch.triu(causal_mask, diagonal=1) + attention_mask = causal_mask[None, None, :, :] + expanded_mask + + # Accumulate hidden states if requested + output_hidden_states = kwargs.get( + "output_hidden_states", self.config.output_hidden_states + ) + all_hidden_states = () if output_hidden_states else None + for decoder_layer in self.layers: - mask = attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask + if output_hidden_states: + all_hidden_states += (hidden_states,) + + mask = (attention_mask[decoder_layer.attention_type] if isinstance(attention_mask, dict) else attention_mask) hidden_states = decoder_layer( hidden_states, attention_mask=mask, @@ -1259,12 +2382,18 @@ def forward( ) pass hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + # Fix float16 / float32 mismatching hidden_states = hidden_states.to(inputs_embeds.dtype) return process_return(MoeModelOutputWithPast, { - "last_hidden_state" : hidden_states, - "past_key_values" : past_key_values, - }) + "last_hidden_state": hidden_states, + "past_key_values": past_key_values, + "hidden_states": all_hidden_states, + }) + patch_function(transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel, "forward", forward, match_level = "relaxed") pass TEMPORARY_PATCHES.append(patch_GptOssModel) @@ -1282,9 +2411,12 @@ def forward( load_harmony_encoding, ReasoningEffort ) + encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) except: pass + + def encode_conversations_with_harmony( messages, reasoning_effort = "medium", @@ -1301,14 +2433,15 @@ def encode_conversations_with_harmony( assert reasoning_effort in ("low", "medium", "high") match reasoning_effort: - case "low": harmony_reasoning = ReasoningEffort.LOW - case "medium": harmony_reasoning = ReasoningEffort.MEDIUM - case "high": harmony_reasoning = ReasoningEffort.HIGH + case "low": harmony_reasoning = ReasoningEffort.LOW + case "medium": harmony_reasoning = ReasoningEffort.MEDIUM + case "high": harmony_reasoning = ReasoningEffort.HIGH convos = [] # Create system message import datetime + today = datetime.datetime.today().strftime("%Y-%m-%d") system = Message.from_role_and_content(Role.SYSTEM, SystemContent.new() @@ -1316,7 +2449,7 @@ def encode_conversations_with_harmony( .with_reasoning_effort(harmony_reasoning) .with_conversation_start_date(today) .with_knowledge_cutoff("2024-06") - .with_required_channels(["analysis", "commentary", "final"]) + .with_required_channels(["analysis", "commentary", "final"]), ) convos.append(system) @@ -1340,27 +2473,21 @@ def encode_conversations_with_harmony( for message in messages: if message["role"] == "user": - convos.append( - Message.from_role_and_content(Role.USER, message["content"]) - ) + convos.append(Message.from_role_and_content(Role.USER, message["content"])) elif message["role"] == "assistant": if "thinking" in message: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("analysis") elif "tool_calls" in message: - x = Message.from_role_and_content(Role.ASSISTANT, message['tool_calls'][0]["arguments"]) - x = x.with_channel("commentary")\ - .with_recipient(f"functions.{message['tool_calls'][0]['name']}")\ - .with_content_type("json") + x = Message.from_role_and_content(Role.ASSISTANT, message["tool_calls"][0]["arguments"]) + x = x.with_channel("commentary").with_recipient(f"functions.{message['tool_calls'][0]['name']}").with_content_type("json") else: x = Message.from_role_and_content(Role.ASSISTANT, message["content"]) x = x.with_channel("final") convos.append(x) elif message["role"] == "tool": - x = Message.from_author_and_content( - Author.new(Role.TOOL, f"functions.{message['name']}"), - message["content"], - ).with_recipient("assistant").with_channel("commentary") + x = Message.from_author_and_content(Author.new(Role.TOOL, f"functions.{message['name']}"), message["content"]) + x = x.with_recipient("assistant").with_channel("commentary") convos.append(x) pass @@ -1380,8 +2507,10 @@ def encode_conversations_with_harmony( # AutoConfig error: 'GptOssConfig' object has no attribute 'max_position_embeddings' try: from transformers.configuration_utils import layer_type_validation + try: from transformers.configuration_utils import PreTrainedConfig + PretrainedConfig = PreTrainedConfig except: from transformers.configuration_utils import PretrainedConfig @@ -1463,9 +2592,7 @@ def __init__( self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads self.layer_types = layer_types if self.layer_types is None: - self.layer_types = [ - "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) - ] + self.layer_types = ["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)] layer_type_validation(self.layer_types) # Validate the correctness of rotary position embeddings parameters @@ -1588,6 +2715,7 @@ def __init__( def patch_gpt_oss_config(): try: import transformers.models.gpt_oss.configuration_gpt_oss + transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) @@ -1601,7 +2729,146 @@ def patch_gpt_oss_config(): patch_function(transformers.models.gpt_oss.configuration_gpt_oss, "GptOssConfig", GptOssConfig) except Exception as e: return raise_error("transformers.models.gpt_oss.configuration_gpt_oss", e) + pass TEMPORARY_PATCHES.append(patch_gpt_oss_config) except Exception as e: raise_error("transformers.models.gpt_oss.configuration_gpt_oss.GptOssConfig", e) + + +def patch_gpt_oss_init_weights_modulelist_fix(): + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + return raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + + GptOssPreTrainedModel = ( + transformers.models.gpt_oss.modeling_gpt_oss.GptOssPreTrainedModel + ) + GptOssExperts = transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts + if getattr(GptOssPreTrainedModel, "_unsloth_init_weights_fixed", False): + return + _original_init_weights = GptOssPreTrainedModel._init_weights + + def _patched_init_weights(self, module): + if isinstance(module, GptOssExperts) and not hasattr(module, "gate_up_proj"): + std = self.config.initializer_range + for up in getattr(module, "gate_up_projs", []): + init.normal_(up.weight, mean=0.0, std=std) + if up.bias is not None: + init.zeros_(up.bias) + for down in getattr(module, "down_projs", []): + init.normal_(down.weight, mean=0.0, std=std) + if down.bias is not None: + init.zeros_(down.bias) + return + _original_init_weights(self, module) + + patch_function(GptOssPreTrainedModel, "_init_weights", _patched_init_weights) + GptOssPreTrainedModel._unsloth_init_weights_fixed = True +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_init_weights_modulelist_fix) + + +# ============================================================================ +# Patch GptOssForCausalLM.forward for GRPO training +# When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits +# ============================================================================ +def patch_gpt_oss_for_grpo(): + """ + Patch GptOssForCausalLM.forward for GRPO training. + When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. + This fixes the matrix multiplication dimension mismatch issue in GRPO training. + """ + if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + return + + try: + import transformers.models.gpt_oss.modeling_gpt_oss + from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssForCausalLM, + MoeCausalLMOutputWithPast, + ) + + _original_causal_lm_forward = GptOssForCausalLM.forward + + def _patched_causal_lm_forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + logits_to_keep=0, + **kwargs, + ): + # This Unsloth Zoo code section is licensed under AGPL3 + + RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1" + + if not RETURN_HIDDEN_STATES: + # Normal forward pass + return _original_causal_lm_forward( + self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + # RETURN_HIDDEN_STATES mode - return hidden_states instead of logits + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Apply slice_indices to hidden_states (same indexing as for logits) + if logits_to_keep != 0: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + hidden_states = hidden_states[:, slice_indices, :] + + # Return hidden_states as "logits" for GRPO to use + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=getattr(outputs, 'aux_loss', None), + logits=hidden_states, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=getattr(outputs, 'router_logits', None), + ) + + GptOssForCausalLM.forward = _patched_causal_lm_forward + if UNSLOTH_ENABLE_LOGGING: + logger.info("Unsloth: Patched GptOssForCausalLM.forward for GRPO hidden states.") + + except Exception as e: + if UNSLOTH_ENABLE_LOGGING: + logger.warning(f"Unsloth: Could not patch GptOssForCausalLM.forward: {e}") +pass +TEMPORARY_PATCHES.append(patch_gpt_oss_for_grpo)