Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions unsloth_zoo/empty_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch
import re
import os
import functools
from copy import deepcopy
from .utils import get_quant_type
from .log import logger
Expand Down Expand Up @@ -343,10 +344,11 @@ def patched_supports_lora(model):
if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"):
original_create_lora_manager = vllm_lora_model_manager.create_lora_manager

@functools.wraps(original_create_lora_manager)
def patched_create_lora_manager(model, *args, **kwargs):
if model.__class__.__name__ == "Gemma4ForConditionalGeneration":
lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager)
return lora_manager_cls(model = model, *args, **kwargs)
return lora_manager_cls(model, *args, **kwargs)
return original_create_lora_manager(model, *args, **kwargs)
Comment on lines +348 to +352

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the patched_supports_lora function, using isinstance here is more robust than comparing class names as strings.

Suggested change
def patched_create_lora_manager(model, *args, **kwargs):
if model.__class__.__name__ == "Gemma4ForConditionalGeneration":
lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager)
return lora_manager_cls(model, *args, **kwargs)
return original_create_lora_manager(model, *args, **kwargs)
def patched_create_lora_manager(model, *args, **kwargs):
if isinstance(model, Gemma4ForConditionalGeneration):
lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager)
return lora_manager_cls(model, *args, **kwargs)
return original_create_lora_manager(model, *args, **kwargs)


patched_create_lora_manager._unsloth_gemma4_patch = True
Expand Down Expand Up @@ -758,6 +760,8 @@ def get_model_layer_config(return_non_layered=True):
layer_templates = {
'standard_layers': {
"model.language_model.layers.{kk}.layer_scalar",
"model.language_model.layers.{kk}.per_layer_input_gate",
"model.language_model.layers.{kk}.per_layer_projection",
"model.language_model.layers.{kk}.self_attn.q_proj",
"model.language_model.layers.{kk}.self_attn.k_proj",
"model.language_model.layers.{kk}.self_attn.v_proj",
Expand All @@ -769,6 +773,8 @@ def get_model_layer_config(return_non_layered=True):
"model.language_model.layers.{kk}.mlp.down_proj",

"model.layers.{kk}.layer_scalar",
"model.layers.{kk}.per_layer_input_gate",
"model.layers.{kk}.per_layer_projection",
"model.layers.{kk}.self_attn.q_proj",
"model.layers.{kk}.self_attn.k_proj",
"model.layers.{kk}.self_attn.v_proj",
Expand Down Expand Up @@ -801,6 +807,7 @@ def get_model_layer_config(return_non_layered=True):
"model.language_model.layers.{kk}.post_attention_layernorm",
"model.language_model.layers.{kk}.pre_feedforward_layernorm",
"model.language_model.layers.{kk}.post_feedforward_layernorm",
"model.language_model.layers.{kk}.post_per_layer_input_norm",
"model.language_model.layers.{kk}.self_attn.q_norm",
"model.language_model.layers.{kk}.self_attn.k_norm",
"model.language_model.layers.{kk}.cross_attn.q_norm",
Expand All @@ -809,6 +816,7 @@ def get_model_layer_config(return_non_layered=True):
"model.layers.{kk}.post_attention_layernorm",
"model.layers.{kk}.pre_feedforward_layernorm",
"model.layers.{kk}.post_feedforward_layernorm",
"model.layers.{kk}.post_per_layer_input_norm",
"model.layers.{kk}.self_attn.q_norm",
"model.layers.{kk}.self_attn.k_norm",
"model.visual.blocks.{kk}.norm1",
Expand Down Expand Up @@ -925,7 +933,6 @@ def get_model_layer_config(return_non_layered=True):
# qwen 3 vl
"model.visual.deepstack_merger_list.{kk}.linear_fc1",
"model.visual.deepstack_merger_list.{kk}.linear_fc2",
"model.visual.merger.linear_fc{kk}",

},
"non_layered_components":{
Expand Down Expand Up @@ -965,6 +972,8 @@ def get_model_layer_config(return_non_layered=True):
# qwen 3 vl
"model.visual.pos_embed",
"model.visual.merger.norm",
"model.visual.merger.linear_fc1",
"model.visual.merger.linear_fc2",
}
}

Expand Down Expand Up @@ -1077,13 +1086,22 @@ def store(name, value):
get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False)
get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False)

get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba)
get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba)
if hasattr(gdn, "in_proj_ba"):
get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba)
get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba)
else:
get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_b, slice_weights=False)
get_state_dict(f"{prefix}.in_proj_a", 0, state_dict, gdn.in_proj_a, slice_weights=False)

store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data)
store(f"{prefix}.dt_bias", gdn.dt_bias.data)
store(f"{prefix}.A_log", gdn.A_log.data)
Comment on lines +1096 to +1098

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

Using .data is discouraged in modern PyTorch as it can lead to silent bugs by bypassing autograd tracking and safety checks. It is recommended to use .detach() instead when a view of the tensor is needed without gradient tracking.

Suggested change
store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data)
store(f"{prefix}.dt_bias", gdn.dt_bias.data)
store(f"{prefix}.A_log", gdn.A_log.data)
store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.detach())
store(f"{prefix}.dt_bias", gdn.dt_bias.detach())
store(f"{prefix}.A_log", gdn.A_log.detach())


norm = getattr(gdn, "norm", None)
norm_weight = getattr(norm, "weight", None) if norm is not None else None
if norm_weight is not None:
store(f"{prefix}.norm.weight", norm_weight.data)

get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj)
pass

Expand Down
4 changes: 2 additions & 2 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,8 @@ def grpo_accumulated_loss(
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False)

for module in unwrapped_model.modules():
if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_decice"):
module._hf_hook.io_same_decice = False
if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_device"):
module._hf_hook.io_same_device = False
if hasattr(module, "rope_deltas"):
module.rope_deltas = None
pass
Expand Down
55 changes: 54 additions & 1 deletion unsloth_zoo/temporary_patches/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import os
from .common import TEMPORARY_PATCHES
from .utils import raise_error
from .utils import raise_error, patch_function


# ============================================================================
Expand Down Expand Up @@ -618,3 +618,56 @@ def forward(self, hidden_states, position_embeddings, attention_mask=None, **kwa
Gemma4AudioAttention.forward = forward
pass
TEMPORARY_PATCHES.append(patch_Gemma4AudioAttention)


# Gemma-4 float16 MLP overflow fix.
#
# `down_proj(act_fn(gate_proj(x)) * up_proj(x))` overflows fp16 at layers.0
# (E2B) / layers.1 (E4B): the gate*up product and the fp16 matmul accumulator
# saturate to +-inf and poison the residual stream, so generation samples NaN
# logits that trip the CUDA categorical sampler around GRPO step 2.
#
# Fix: fp32 gate*up, clamp to a safe bound, fp16 cast, nan_to_num on the
# down_proj output. Gated on gate output dtype so bf16/fp32 users see no
# change and no env flag is required. RMSNorm / Attention / Embedding
# patches are unnecessary (verified by bisection - identical loss/kl/grad
# trajectories).


def patch_Gemma4TextMLP():
    """Swap in an fp16-overflow-safe ``forward`` for ``Gemma4TextMLP``.

    The replacement inspects the dtype of the ``gate_proj`` output. For any
    dtype other than ``torch.float16`` it evaluates the MLP as
    ``down_proj(act_fn(gate) * up_proj(x))``. For float16 it forms the
    gate*up product in float32, clamps it to +-65280, casts it back to the
    gate dtype for ``down_proj`` and replaces NaN / +-inf in the result
    with zeros.

    Returns without patching when the gemma4 modeling module cannot be
    imported. A missing ``Gemma4TextMLP`` attribute, or an exception from
    ``patch_function``, is forwarded to ``raise_error``.
    """
    try:
        import transformers.models.gemma4.modeling_gemma4 as mod
    except ImportError:
        # No gemma4 modeling module available: nothing to patch.
        return

    try:
        Gemma4TextMLP = mod.Gemma4TextMLP
    except AttributeError as e:
        return raise_error("Gemma4TextMLP.forward", e)

    # Largest value representable in both fp16 and bf16 (65536 rounds to
    # fp16 inf).
    _SAFE_FP16 = 65280.0

    def forward(self, x):
        gate = self.gate_proj(x)
        # Branch on the matmul *output* dtype so autocast / PEFT fp16 casts
        # are caught as well.
        if gate.dtype == torch.float16:
            activated = self.act_fn(gate.float())
            upcast = self.up_proj(x).float()
            product = (activated * upcast).clamp(min=-_SAFE_FP16, max=_SAFE_FP16)
            out = self.down_proj(product.to(gate.dtype))
            # Zero overflows so the residual identity path survives.
            return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
        return self.down_proj(self.act_fn(gate) * self.up_proj(x))

    try:
        patch_function(Gemma4TextMLP, "forward", forward, fullgraph=False)
    except Exception as e:
        return raise_error("Gemma4TextMLP.forward", e)


TEMPORARY_PATCHES.append(patch_Gemma4TextMLP)
35 changes: 32 additions & 3 deletions unsloth_zoo/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,13 @@ def _is_fused_module(name: str) -> bool:
if hasattr(layer, "layer_scalar"):
state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data
quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data
for per_layer_name in ("per_layer_input_gate", "per_layer_projection"):
per_layer_module = getattr(layer, per_layer_name, None)
if per_layer_module is not None and hasattr(per_layer_module, "weight"):
get_state_dict(
f"{vllm_text_model_prefix}.layers.{kk}.{per_layer_name}",
0, state_dict, per_layer_module,
)
pass

if len(skipped_layernorms) != 0:
Expand Down Expand Up @@ -1216,6 +1223,8 @@ def assert_same_state_dict(old_state_dict, new_state_dict):
def _normalize_state_dict_tensor(value):
if isinstance(value, torch.nn.Parameter):
value = value.detach()
if not isinstance(value, torch.Tensor):
return value
if value.is_sparse:
value = value.to_dense()
return value.contiguous()
Comment on lines +1223 to +1230

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The _normalize_state_dict_tensor function can be simplified by chaining the operations directly.

    def _normalize_state_dict_tensor(value):
        if isinstance(value, torch.nn.Parameter):
            value = value.detach()
        return value.to_dense().contiguous() if value.is_sparse else value.contiguous()

Expand Down Expand Up @@ -1396,8 +1405,14 @@ def _override_to(self, *args, **kwargs):
if layer_name in quant_state_dict:
# for attributes of type nn.Parameter, there's no .weight
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)
layer = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False)
exec(f"new_model.{layer_name_br} = layer")
value = _unwrap_tensor(weight)
parent_expr, attr_name = layer_name_br.rsplit(".", 1)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

If layer_name_br does not contain a dot (e.g., if a top-level attribute is added to layer_names), rsplit(".", 1) will raise a ValueError. While current templates include dots, adding a guard or using a safer way to split would be more robust for future changes.

parent_module = eval(f"new_model.{parent_expr}")
if attr_name in getattr(parent_module, "_buffers", {}):
parent_module._buffers[attr_name] = value
else:
layer = torch.nn.Parameter(value, requires_grad = False)
exec(f"new_model.{layer_name_br} = layer")
continue
elif fp8_weight_scale is not None:
if fp8_weight_scale.ndim == 1:
Expand Down Expand Up @@ -1448,9 +1463,23 @@ def _override_to(self, *args, **kwargs):
layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False)
layer.bias = bias
else:
# LayerNorms (including vision norms)
# LayerNorms (including vision norms) and depthwise Conv1d
weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False)
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)
if layer_name.endswith(".conv1d"):
target = eval(f"new_model.{layer_name_br}")
w = _unwrap_tensor(weight)
out_channels = w.shape[0]
kernel_size = w.shape[-1]
target.out_channels = out_channels
target.in_channels = out_channels
target.groups = out_channels
target.kernel_size = (kernel_size,)
target.padding = (kernel_size - 1,)
target.weight = weight_param
if bias is not None:
target.bias = bias
continue
Comment on lines +1469 to +1482

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of eval() can be a security risk if the string being evaluated can be influenced by external input. While layer_name_br is constructed from internal data here, it's a good practice to avoid eval() for maintainability and security. Consider refactoring to use a safer method for accessing nested attributes, even though it might be more complex with attribute names containing [].

# Set weight
exec(f"new_model.{layer_name_br}.weight = None")
exec(f"new_model.{layer_name_br}.weight = weight_param")
Expand Down