From 7fa143f115b3c7d91bda04540073a9f8b10d09e5 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 30 Mar 2026 13:41:05 +0000 Subject: [PATCH 01/11] [WIP] initial fast_inference support for qwen3.5 --- unsloth_zoo/empty_model.py | 113 ++++++++++++++++++++++++++++++++++++- unsloth_zoo/vllm_utils.py | 15 ++++- 2 files changed, 123 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index f9ff7cba0..925c215ef 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -17,6 +17,7 @@ __all__ = [ "create_empty_model", "set_additional_modules", + "extract_gdn_layers", "extract_vision_layers", "get_model_layer_config", "compare_attributes", @@ -280,6 +281,12 @@ def create_empty_causal_lm(config, dtype = torch.float16): new_config.vocab_size = 1 new_config.pad_token_id = 0 + # Minimize GDN (Gated Delta Net) layer sizes for hybrid models like Qwen3.5 + for attr in ("linear_num_key_heads", "linear_num_value_heads"): + if hasattr(new_config, attr): setattr(new_config, attr, 1) + for attr in ("linear_key_head_dim", "linear_value_head_dim", "linear_conv_kernel_dim"): + if hasattr(new_config, attr): setattr(new_config, attr, 1) + # Set attention module head_dim head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) new_config.update({"head_dim" : head_dim}) @@ -353,6 +360,12 @@ def _init_weights(self, module): "pad_token_id": 1, }) + # Minimize GDN (Gated Delta Net) layer sizes for hybrid models like Qwen3.5 + for attr in ("linear_num_key_heads", "linear_num_value_heads"): + if hasattr(new_config.text_config, attr): setattr(new_config.text_config, attr, 1) + for attr in ("linear_key_head_dim", "linear_value_head_dim", "linear_conv_kernel_dim"): + if hasattr(new_config.text_config, attr): setattr(new_config.text_config, attr, 1) + # Common vision attributes _set_config_attrs(new_config.vision_config, { "hidden_size": 1, @@ -374,7 +387,8 @@ def _init_weights(self, module): new_config.vision_config.out_hidden_size = 1 elif model_type == "qwen3_vl": new_config.vision_config.out_hidden_size = 1 - + elif model_type in ("qwen3_5", "qwen3_5_moe"): + new_config.vision_config.out_hidden_size = 1 num_layers = max(text_layers, vision_layers) new_model = model_cls(new_config) @@ -403,6 +417,9 @@ def set_additional_modules(new_model, quant_state_dict, config): if hasattr(new_model, "language_model"): language_model = new_model.language_model language_model_prefix = "model.language_model" + elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"): + language_model = new_model.model.language_model + language_model_prefix = "model.language_model" else: language_model_prefix = "model" language_model = new_model.model @@ -539,6 +556,25 @@ def get_model_layer_config(return_non_layered=True): "model.layers.{kk}.mlp.up_proj", "model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture) "model.layers.{kk}.mlp.down_proj", + + # Qwen3.5 Gated Delta Net (GDN) linear attention layers + "model.language_model.layers.{kk}.linear_attn.in_proj_qkv", + "model.language_model.layers.{kk}.linear_attn.in_proj_z", + "model.language_model.layers.{kk}.linear_attn.in_proj_b", + "model.language_model.layers.{kk}.linear_attn.in_proj_a", + "model.language_model.layers.{kk}.linear_attn.conv1d", + "model.language_model.layers.{kk}.linear_attn.out_proj", + "model.language_model.layers.{kk}.linear_attn.dt_bias", + "model.language_model.layers.{kk}.linear_attn.A_log", + + "model.layers.{kk}.linear_attn.in_proj_qkv", + "model.layers.{kk}.linear_attn.in_proj_z", + "model.layers.{kk}.linear_attn.in_proj_b", + "model.layers.{kk}.linear_attn.in_proj_a", + "model.layers.{kk}.linear_attn.conv1d", + "model.layers.{kk}.linear_attn.out_proj", + "model.layers.{kk}.linear_attn.dt_bias", + "model.layers.{kk}.linear_attn.A_log", }, 'layernorms': { "model.language_model.layers.{kk}.input_layernorm", @@ -567,6 +603,10 @@ def get_model_layer_config(return_non_layered=True): # qwen3 vl "model.visual.deepstack_merger_list.{kk}.norm", + + # Qwen3.5 GDN linear attention norms + "model.language_model.layers.{kk}.linear_attn.norm", + "model.layers.{kk}.linear_attn.norm", }, 'vision_layers': { @@ -759,6 +799,75 @@ def _get_nested_attr(obj, attr_path: str): return None +def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_state_dict): + """ + Extracts Gated Delta Net (GDN) linear attention weights from a vLLM + linear_attn module into HF-compatible state_dict entries. + Used by Qwen3.5 hybrid models which mix GDN and standard attention layers. + """ + gdn = gdn_module + + # in_proj_qkvz (non-LoRA: fused Q,K,V,Z) or in_proj_qkv + in_proj_z (LoRA) + if hasattr(gdn, "in_proj_qkvz"): + proj = getattr(gdn.in_proj_qkvz, "base_layer", gdn.in_proj_qkvz) + weight = proj.weight + output_sizes = list(proj.output_sizes) + # cumsum offsets: [0, key_dim, key_dim*2, key_dim*2+value_dim, key_dim*2+value_dim*2] + offsets = [0] + for s in output_sizes: + offsets.append(offsets[-1] + s) + # Slots 0-2 (Q,K,V) -> HF's in_proj_qkv; Slot 3 (Z) -> HF's in_proj_z + qkv_weight = weight[offsets[0]:offsets[3]] + z_weight = weight[offsets[3]:offsets[4]] + state_dict[f"{prefix}.in_proj_qkv.weight"] = qkv_weight + quant_state_dict[f"{prefix}.in_proj_qkv.weight"] = qkv_weight + state_dict[f"{prefix}.in_proj_z.weight"] = z_weight + quant_state_dict[f"{prefix}.in_proj_z.weight"] = z_weight + # Handle FP8 weight scales if present + if weight.dtype == torch.float8_e4m3fn: + if hasattr(proj, 'weight_scale'): + ws = proj.weight_scale + elif hasattr(proj, 'weight_scale_inv'): + ws = proj.weight_scale_inv + else: + ws = None + if ws is not None and ws.ndim == 2 and ws.shape[1] > 1: + block_size = proj.weight_block_size[0] + scale_offsets = [x // block_size for x in offsets] + scale_suffix = '.weight_scale_inv' + qkv_scale = ws[scale_offsets[0]:scale_offsets[3]] + z_scale = ws[scale_offsets[3]:scale_offsets[4]] + state_dict[f"{prefix}.in_proj_qkv{scale_suffix}"] = qkv_scale + quant_state_dict[f"{prefix}.in_proj_qkv{scale_suffix}"] = qkv_scale + state_dict[f"{prefix}.in_proj_z{scale_suffix}"] = z_scale + quant_state_dict[f"{prefix}.in_proj_z{scale_suffix}"] = z_scale + else: + # LoRA mode: separate in_proj_qkv and in_proj_z + get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) + get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) + + # in_proj_ba -> split into in_proj_b (slot 0) + in_proj_a (slot 1) + get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) + get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) + + # conv1d — vLLM stores as ColumnParallelLinear with unsqueeze(1), + # already in HF Conv1d shape (conv_dim, 1, kernel_size) + conv_w = gdn.conv1d.weight.data + state_dict[f"{prefix}.conv1d.weight"] = conv_w + quant_state_dict[f"{prefix}.conv1d.weight"] = conv_w + + # dt_bias, A_log — bare nn.Parameters (no .weight suffix in HF) + state_dict[f"{prefix}.dt_bias"] = gdn.dt_bias.data + quant_state_dict[f"{prefix}.dt_bias"] = gdn.dt_bias.data + state_dict[f"{prefix}.A_log"] = gdn.A_log.data + quant_state_dict[f"{prefix}.A_log"] = gdn.A_log.data + + # norm (RMSNormGated) — handled by the layernorm loop in the caller + # out_proj (RowParallelLinear) + get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) +pass + + def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_state_dict): """ Extracts vision layers for any supported vision model by dynamically using @@ -790,7 +899,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if layer_module is not None: if "qkv" in layer_path: - if model_type in ("qwen2_5_vl", "qwen3_vl"): + if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5", "qwen3_5_moe"): # If the HF model too prefers having merged qkv, we do this # This is evident in qwen-2.5-vl and qwen-3-vl so far. get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 166ff403b..76ac551b5 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1108,6 +1108,7 @@ def _is_fused_module(name: str) -> bool: get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj) get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) + get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) elif hasattr(layer, "cross_attn"): prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn" qkv_proj = layer.cross_attn.qkv_proj @@ -1119,8 +1120,15 @@ def _is_fused_module(name: str) -> bool: get_state_dict(f"{prefix}.q_proj", 0, state_dict, q_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, kv_proj) get_state_dict(f"{prefix}.v_proj", 2, state_dict, kv_proj) - - get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) + get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) + elif hasattr(layer, "linear_attn"): + # Qwen3.5 Gated Delta Net (GDN) linear attention layers + extract_gdn_layers( + layer.linear_attn, + f"{vllm_text_model_prefix}.layers.{kk}.linear_attn", + state_dict, quant_state_dict, get_state_dict, + ) + pass proj = layer.mlp.gate_up_proj use_fused_gate_up = _is_fused_module("gate_up_proj") @@ -1300,6 +1308,7 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, "norm1", # Qwen2.5-VL vision encoder "norm2", # Qwen2.5-VL vision encoder "norm", + "conv1d", # Qwen3.5 GDN conv1d — assign weight directly to preserve nn.Conv1d ] # Override .to("cuda") to disable it otherwise we'll get # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 @@ -1352,7 +1361,7 @@ def _override_to(self, *args, **kwargs): if layer_name in quant_state_dict: # for attributes of type nn.Parameter, there's no .weight - layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1)) + layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) layer = torch.nn.Parameter(weight, requires_grad = False) exec(f"new_model.{layer_name_br} = layer") continue From 31c1bc3cad82282290c9fa7eeebad406df944e3b Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 30 Mar 2026 09:35:49 +0000 Subject: [PATCH 02/11] [WIP] fixes for rope deltas --- unsloth_zoo/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 5c76cbff9..dace666b4 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -758,6 +758,8 @@ def grpo_accumulated_loss( for module in unwrapped_model.modules(): if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_decice"): module._hf_hook.io_same_decice = False + if hasattr(module, "rope_deltas"): + module.rope_deltas = None pass all_logprobs_list = [] From 112f1b5e3e0a4bb3a3e88f13fe3ae2726aefa933 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 10 Apr 2026 06:59:13 +0000 Subject: [PATCH 03/11] cleanup --- unsloth_zoo/empty_model.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 925c215ef..603509866 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -281,7 +281,6 @@ def create_empty_causal_lm(config, dtype = torch.float16): new_config.vocab_size = 1 new_config.pad_token_id = 0 - # Minimize GDN (Gated Delta Net) layer sizes for hybrid models like Qwen3.5 for attr in ("linear_num_key_heads", "linear_num_value_heads"): if hasattr(new_config, attr): setattr(new_config, attr, 1) for attr in ("linear_key_head_dim", "linear_value_head_dim", "linear_conv_kernel_dim"): @@ -360,7 +359,6 @@ def _init_weights(self, module): "pad_token_id": 1, }) - # Minimize GDN (Gated Delta Net) layer sizes for hybrid models like Qwen3.5 for attr in ("linear_num_key_heads", "linear_num_value_heads"): if hasattr(new_config.text_config, attr): setattr(new_config.text_config, attr, 1) for attr in ("linear_key_head_dim", "linear_value_head_dim", "linear_conv_kernel_dim"): @@ -556,8 +554,6 @@ def get_model_layer_config(return_non_layered=True): "model.layers.{kk}.mlp.up_proj", "model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture) "model.layers.{kk}.mlp.down_proj", - - # Qwen3.5 Gated Delta Net (GDN) linear attention layers "model.language_model.layers.{kk}.linear_attn.in_proj_qkv", "model.language_model.layers.{kk}.linear_attn.in_proj_z", "model.language_model.layers.{kk}.linear_attn.in_proj_b", @@ -603,8 +599,6 @@ def get_model_layer_config(return_non_layered=True): # qwen3 vl "model.visual.deepstack_merger_list.{kk}.norm", - - # Qwen3.5 GDN linear attention norms "model.language_model.layers.{kk}.linear_attn.norm", "model.layers.{kk}.linear_attn.norm", }, @@ -800,30 +794,21 @@ def _get_nested_attr(obj, attr_path: str): def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_state_dict): - """ - Extracts Gated Delta Net (GDN) linear attention weights from a vLLM - linear_attn module into HF-compatible state_dict entries. - Used by Qwen3.5 hybrid models which mix GDN and standard attention layers. - """ gdn = gdn_module - # in_proj_qkvz (non-LoRA: fused Q,K,V,Z) or in_proj_qkv + in_proj_z (LoRA) if hasattr(gdn, "in_proj_qkvz"): proj = getattr(gdn.in_proj_qkvz, "base_layer", gdn.in_proj_qkvz) weight = proj.weight output_sizes = list(proj.output_sizes) - # cumsum offsets: [0, key_dim, key_dim*2, key_dim*2+value_dim, key_dim*2+value_dim*2] offsets = [0] for s in output_sizes: offsets.append(offsets[-1] + s) - # Slots 0-2 (Q,K,V) -> HF's in_proj_qkv; Slot 3 (Z) -> HF's in_proj_z qkv_weight = weight[offsets[0]:offsets[3]] z_weight = weight[offsets[3]:offsets[4]] state_dict[f"{prefix}.in_proj_qkv.weight"] = qkv_weight quant_state_dict[f"{prefix}.in_proj_qkv.weight"] = qkv_weight state_dict[f"{prefix}.in_proj_z.weight"] = z_weight quant_state_dict[f"{prefix}.in_proj_z.weight"] = z_weight - # Handle FP8 weight scales if present if weight.dtype == torch.float8_e4m3fn: if hasattr(proj, 'weight_scale'): ws = proj.weight_scale @@ -842,28 +827,21 @@ def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_sta state_dict[f"{prefix}.in_proj_z{scale_suffix}"] = z_scale quant_state_dict[f"{prefix}.in_proj_z{scale_suffix}"] = z_scale else: - # LoRA mode: separate in_proj_qkv and in_proj_z get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) - # in_proj_ba -> split into in_proj_b (slot 0) + in_proj_a (slot 1) get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) - # conv1d — vLLM stores as ColumnParallelLinear with unsqueeze(1), - # already in HF Conv1d shape (conv_dim, 1, kernel_size) conv_w = gdn.conv1d.weight.data state_dict[f"{prefix}.conv1d.weight"] = conv_w quant_state_dict[f"{prefix}.conv1d.weight"] = conv_w - # dt_bias, A_log — bare nn.Parameters (no .weight suffix in HF) state_dict[f"{prefix}.dt_bias"] = gdn.dt_bias.data quant_state_dict[f"{prefix}.dt_bias"] = gdn.dt_bias.data state_dict[f"{prefix}.A_log"] = gdn.A_log.data quant_state_dict[f"{prefix}.A_log"] = gdn.A_log.data - # norm (RMSNormGated) — handled by the layernorm loop in the caller - # out_proj (RowParallelLinear) get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) pass From a31dee8e77d5a68820fede053e372ec51172fada Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 10 Apr 2026 08:10:03 +0000 Subject: [PATCH 04/11] fix lm_head detection and remove moe Signed-off-by: Datta Nimmaturi --- unsloth_zoo/empty_model.py | 17 ++++++++++------- unsloth_zoo/vllm_utils.py | 12 +++++++++--- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 603509866..3c59996eb 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -358,11 +358,14 @@ def _init_weights(self, module): "head_dim": 1, "pad_token_id": 1, }) - - for attr in ("linear_num_key_heads", "linear_num_value_heads"): - if hasattr(new_config.text_config, attr): setattr(new_config.text_config, attr, 1) - for attr in ("linear_key_head_dim", "linear_value_head_dim", "linear_conv_kernel_dim"): - if hasattr(new_config.text_config, attr): setattr(new_config.text_config, attr, 1) + # Qwen 3.5 or GDN related attrs + _set_config_attrs(new_config.text_config, { + "linear_num_key_heads": 1, + "linear_num_value_heads": 1, + "linear_key_head_dim": 1, + "linear_value_head_dim": 1, + "linear_conv_kernel_dim": 1, + }) # Common vision attributes _set_config_attrs(new_config.vision_config, { @@ -385,7 +388,7 @@ def _init_weights(self, module): new_config.vision_config.out_hidden_size = 1 elif model_type == "qwen3_vl": new_config.vision_config.out_hidden_size = 1 - elif model_type in ("qwen3_5", "qwen3_5_moe"): + elif model_type == "qwen3_5": new_config.vision_config.out_hidden_size = 1 num_layers = max(text_layers, vision_layers) @@ -877,7 +880,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if layer_module is not None: if "qkv" in layer_path: - if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5", "qwen3_5_moe"): + if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5"): # If the HF model too prefers having merged qkv, we do this # This is evident in qwen-2.5-vl and qwen-3-vl so far. get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 76ac551b5..4477bb12e 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1176,8 +1176,14 @@ def _is_fused_module(name: str) -> bool: # LM Head - Use get_state_dict for consistency if not getattr(text_config, "tie_word_embeddings", False): lm_layer = [mod for name,mod in vllm_internals.named_modules() if "lm_head" in name] - # Use get_state_dict for consistent extraction and automatic truncation - get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False) + if len(lm_layer) != 0: + get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False) + elif hasattr(vllm_internals, "language_model") and hasattr(vllm_internals.language_model, "lm_head"): + get_state_dict("lm_head", 0, state_dict, vllm_internals.language_model.lm_head, slice_weights=False) + elif hasattr(vllm_internals, "lm_head"): + get_state_dict("lm_head", 0, state_dict, vllm_internals.lm_head, slice_weights=False) + else: + raise RuntimeError("Could not find lm_head in vLLM internals") else: # Fallback to embed_tokens for tied embeddings embed_key = f"{vllm_text_model_prefix}.embed_tokens.weight" @@ -1308,7 +1314,7 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, "norm1", # Qwen2.5-VL vision encoder "norm2", # Qwen2.5-VL vision encoder "norm", - "conv1d", # Qwen3.5 GDN conv1d — assign weight directly to preserve nn.Conv1d + "conv1d", ] # Override .to("cuda") to disable it otherwise we'll get # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 From 5d075043f10f0052ea6f854ce4582297474e044d Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 10 Apr 2026 14:41:33 +0000 Subject: [PATCH 05/11] [WIP] gemma 4 dense fast inference --- unsloth_zoo/empty_model.py | 223 ++++++++++++++++++++++++++++++++----- unsloth_zoo/vllm_utils.py | 127 +++++++++++---------- 2 files changed, 260 insertions(+), 90 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 3c59996eb..36d8e8f19 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -17,6 +17,8 @@ __all__ = [ "create_empty_model", "set_additional_modules", + "finalize_huggingface_model", + "patch_gemma4_vllm_lora_support", "extract_gdn_layers", "extract_vision_layers", "get_model_layer_config", @@ -281,10 +283,13 @@ def create_empty_causal_lm(config, dtype = torch.float16): new_config.vocab_size = 1 new_config.pad_token_id = 0 - for attr in ("linear_num_key_heads", "linear_num_value_heads"): - if hasattr(new_config, attr): setattr(new_config, attr, 1) - for attr in ("linear_key_head_dim", "linear_value_head_dim", "linear_conv_kernel_dim"): - if hasattr(new_config, attr): setattr(new_config, attr, 1) + _set_config_attrs(new_config, { + "linear_num_key_heads": 1, + "linear_num_value_heads": 1, + "linear_key_head_dim": 1, + "linear_value_head_dim": 1, + "linear_conv_kernel_dim": 1, + }) # Set attention module head_dim head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) @@ -304,6 +309,50 @@ def _set_config_attrs(config_obj, attrs_to_set): setattr(config_obj, attr, value) pass +def _get_model_device(model): + for tensor in model.parameters(): + return tensor.device + for tensor in model.buffers(): + return tensor.device + return torch.device("cpu") +pass + +def patch_gemma4_vllm_lora_support(): + from vllm.model_executor.models.gemma4_mm import Gemma4ForConditionalGeneration + from vllm.model_executor.models import interfaces as vllm_model_interfaces + from vllm.lora import model_manager as vllm_lora_model_manager + from vllm.v1.worker import lora_model_runner_mixin + from unsloth_zoo import vllm_lora_worker_manager + + Gemma4ForConditionalGeneration.supports_lora = True + Gemma4ForConditionalGeneration.embedding_modules = {} + + if not hasattr(lora_model_runner_mixin.supports_lora, "_unsloth_gemma4_patch"): + original_supports_lora = lora_model_runner_mixin.supports_lora + + def patched_supports_lora(model): + if model.__class__.__name__ == "Gemma4ForConditionalGeneration": + return True + return original_supports_lora(model) + + patched_supports_lora._unsloth_gemma4_patch = True + lora_model_runner_mixin.supports_lora = patched_supports_lora + vllm_model_interfaces.supports_lora = patched_supports_lora + + if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"): + original_create_lora_manager = vllm_lora_model_manager.create_lora_manager + + def patched_create_lora_manager(model, *args, **kwargs): + if model.__class__.__name__ == "Gemma4ForConditionalGeneration": + lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager) + return lora_manager_cls(model = model, *args, **kwargs) + return original_create_lora_manager(model, *args, **kwargs) + + patched_create_lora_manager._unsloth_gemma4_patch = True + vllm_lora_model_manager.create_lora_manager = patched_create_lora_manager + vllm_lora_worker_manager.create_lora_manager = patched_create_lora_manager +pass + @torch.inference_mode def create_empty_vision_model(config, dtype = torch.float16): @@ -383,12 +432,7 @@ def _init_weights(self, module): text_layers = config.text_config.num_hidden_layers vision_layers = getattr(config.vision_config, "num_hidden_layers", None) or getattr(config.vision_config, "depth", 0) - # Set minimal sizes for different model types - if model_type == "qwen2_5_vl": - new_config.vision_config.out_hidden_size = 1 - elif model_type == "qwen3_vl": - new_config.vision_config.out_hidden_size = 1 - elif model_type == "qwen3_5": + if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5"): new_config.vision_config.out_hidden_size = 1 num_layers = max(text_layers, vision_layers) @@ -415,6 +459,9 @@ def create_empty_model(config, dtype = torch.float16, is_vision_model = False): @torch.inference_mode def set_additional_modules(new_model, quant_state_dict, config): + def _unwrap_tensor(val): + return getattr(val, "data", val) + if hasattr(new_model, "language_model"): language_model = new_model.language_model language_model_prefix = "model.language_model" @@ -443,7 +490,7 @@ def set_additional_modules(new_model, quant_state_dict, config): # we cannot use the normal embedding init because gemma3 uses Gemma3TextScaledWordEmbedding which wraps around nn.Embedding and has a scaling factor. This new init ensures that we respect the forward from original model. def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): num_embeddings, embedding_dim = quant_state_dict[embed_tokens_key].shape - embeddings = quant_state_dict[embed_tokens_key] + embeddings = _unwrap_tensor(quant_state_dict[embed_tokens_key]) if isinstance(embeddings, torch.Tensor): # in the newer vLLM versions, this seems to return a tensor which can't be assigned to embedding weight # we need to convert that to nn.Paramter and then pass it on @@ -462,6 +509,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): # Norm norm_key = f"{language_model_prefix}.norm.weight" norm = quant_state_dict[norm_key] + norm = _unwrap_tensor(norm) norm = torch.nn.Parameter(norm, requires_grad = False) language_model.norm.weight = norm @@ -476,7 +524,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): # Check if lm_head exists in the state dict if lmhead_key in quant_state_dict: - weight = quant_state_dict[lmhead_key] + weight = _unwrap_tensor(quant_state_dict[lmhead_key]) from torch.nn import Linear # Create Linear layer with zero dimensions to avoid any weight allocation @@ -518,6 +566,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): for prefix in ['new_', 'new_model.']: try: val = quant_state_dict[key] + val = _unwrap_tensor(val) if isinstance(val, torch.Tensor): val = torch.nn.Parameter(val,requires_grad=False) exec(f"{prefix}{key} = val") @@ -528,6 +577,100 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): pass pass +@torch.inference_mode +def finalize_huggingface_model( + new_model, + original_meta_model, + config, + dtype, + quantization_config = None, + bnb_config = None, +): + if original_meta_model is not None: + copy_attributes(original_meta_model, new_model) + + language_model = getattr(getattr(new_model, "model", None), "language_model", None) + if language_model is not None and hasattr(language_model, "layers"): + for layer_idx, layer in enumerate(language_model.layers): + if hasattr(layer, "layer_idx"): + layer.layer_idx = layer_idx + for attr_name in ("self_attn", "cross_attn", "mlp", "linear_attn"): + submodule = getattr(layer, attr_name, None) + if submodule is not None and hasattr(submodule, "layer_idx"): + submodule.layer_idx = layer_idx + + for module in new_model.modules(): + module_config = getattr(module, "config", None) + if module_config is not None: + set_dtype_in_config(module_config, dtype) + + target_device = _get_model_device(new_model) + text_config = getattr(config, "text_config", config) + vision_config = getattr(config, "vision_config", None) + + for module in new_model.modules(): + if hasattr(module, "rotary_emb"): + rotary_config = text_config + current_rotary_config = getattr(module.rotary_emb, "config", None) + is_vision_rotary = ( + vision_config is not None and + current_rotary_config is not None and + current_rotary_config.__class__ == vision_config.__class__ + ) + if is_vision_rotary: + rotary_config = vision_config + if not (getattr(config, "model_type", None) == "gemma4" and is_vision_rotary): + module.rotary_emb = module.rotary_emb.__class__( + config = rotary_config, + device = target_device, + ) + for buffer_name in ("inv_freq", "original_inv_freq"): + buffer = getattr(module.rotary_emb, buffer_name, None) + if torch.is_tensor(buffer) and buffer.is_floating_point(): + module.rotary_emb._buffers[buffer_name] = buffer.to( + device = target_device, + dtype = dtype, + ) + if hasattr(module, "rotary_pos_emb"): + assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" + head_dim = vision_config.hidden_size // vision_config.num_heads + module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device) + if hasattr(module, "rotary_emb_local"): + local_rope_config = deepcopy(text_config) + local_rope_config.rope_theta = text_config.rope_local_base_freq + local_rope_config.rope_scaling = {"rope_type": "default"} + module.rotary_emb_local = module.rotary_emb_local.__class__( + config = local_rope_config, + device = target_device, + ) + del local_rope_config + + if (quantization_config or {}) == {} and bnb_config is None: + new_model = new_model.to(device = target_device, dtype = dtype) + if getattr(config, "model_type", None) == "gemma4": + for module in new_model.modules(): + rotary_emb = getattr(module, "rotary_emb", None) + if rotary_emb is None: + continue + fresh_rotary_emb = rotary_emb.__class__( + config = rotary_emb.config, + device = target_device, + ) + for attr_name in ("max_seq_len_cached", "original_max_seq_len"): + if hasattr(fresh_rotary_emb, attr_name): + setattr(rotary_emb, attr_name, getattr(fresh_rotary_emb, attr_name)) + for attr_name, attr_value in fresh_rotary_emb.__dict__.items(): + if attr_name == "attention_scaling" or attr_name.endswith("_attention_scaling"): + setattr(rotary_emb, attr_name, attr_value) + for buffer_name, buffer in fresh_rotary_emb._buffers.items(): + if torch.is_tensor(buffer) and buffer.is_floating_point(): + rotary_emb._buffers[buffer_name] = buffer.to( + device = target_device, + dtype = torch.float32, + ) + return new_model +pass + def get_model_layer_config(return_non_layered=True): """ Returns a unified layer configuration containing the union of layer names @@ -538,6 +681,7 @@ def get_model_layer_config(return_non_layered=True): """ layer_templates = { 'standard_layers': { + "model.language_model.layers.{kk}.layer_scalar", "model.language_model.layers.{kk}.self_attn.q_proj", "model.language_model.layers.{kk}.self_attn.k_proj", "model.language_model.layers.{kk}.self_attn.v_proj", @@ -548,6 +692,7 @@ def get_model_layer_config(return_non_layered=True): "model.language_model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture) "model.language_model.layers.{kk}.mlp.down_proj", + "model.layers.{kk}.layer_scalar", "model.layers.{kk}.self_attn.q_proj", "model.layers.{kk}.self_attn.k_proj", "model.layers.{kk}.self_attn.v_proj", @@ -595,6 +740,12 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.vision_model.encoder.layers.{kk}.post_layernorm", "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm1", "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm2", + "model.vision_tower.encoder.layers.{kk}.input_layernorm", + "model.vision_tower.encoder.layers.{kk}.post_attention_layernorm", + "model.vision_tower.encoder.layers.{kk}.pre_feedforward_layernorm", + "model.vision_tower.encoder.layers.{kk}.post_feedforward_layernorm", + "model.vision_tower.encoder.layers.{kk}.self_attn.q_norm", + "model.vision_tower.encoder.layers.{kk}.self_attn.k_norm", # Mistral3 vision norms "model.vision_tower.transformer.layers.{kk}.attention_norm", @@ -647,6 +798,13 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.vision_model.encoder.layers.{kk}.mlp.fc1", "model.vision_tower.vision_model.encoder.layers.{kk}.mlp.fc2", + "model.vision_tower.encoder.layers.{kk}.self_attn.q_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.k_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.v_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.o_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.gate_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.up_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.down_proj.linear", # qwen2.5_vl style "model.visual.blocks.{kk}.attn.qkv", @@ -722,6 +880,11 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.patch_positional_embedding", "model.vision_tower.patch_conv", "model.vision_tower.ln_pre", + "model.vision_tower.std_bias", + "model.vision_tower.std_scale", + "model.vision_tower.patch_embedder.position_embedding_table", + "model.vision_tower.patch_embedder.input_proj", + "model.embed_vision.embedding_projection", # qwen 3 vl "model.visual.pos_embed", @@ -769,6 +932,11 @@ def get_model_layer_counts(config): "vision_layers": getattr(config.vision_config, "depth", 27), "deepstack_layers": getattr(config.vision_config, "deepstack_depth", 3), } + elif model_type == "gemma4": + return { + "text_layers": getattr(config.text_config, "num_hidden_layers", 32), + "vision_layers": getattr(config.vision_config, "num_hidden_layers", 32), + } elif model_type == "gemma3": return { "text_layers": getattr(config.text_config, "num_hidden_layers", 32), @@ -799,6 +967,10 @@ def _get_nested_attr(obj, attr_path: str): def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_state_dict): gdn = gdn_module + def store(name, value): + state_dict[name] = value + quant_state_dict[name] = value + if hasattr(gdn, "in_proj_qkvz"): proj = getattr(gdn.in_proj_qkvz, "base_layer", gdn.in_proj_qkvz) weight = proj.weight @@ -808,10 +980,8 @@ def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_sta offsets.append(offsets[-1] + s) qkv_weight = weight[offsets[0]:offsets[3]] z_weight = weight[offsets[3]:offsets[4]] - state_dict[f"{prefix}.in_proj_qkv.weight"] = qkv_weight - quant_state_dict[f"{prefix}.in_proj_qkv.weight"] = qkv_weight - state_dict[f"{prefix}.in_proj_z.weight"] = z_weight - quant_state_dict[f"{prefix}.in_proj_z.weight"] = z_weight + store(f"{prefix}.in_proj_qkv.weight", qkv_weight) + store(f"{prefix}.in_proj_z.weight", z_weight) if weight.dtype == torch.float8_e4m3fn: if hasattr(proj, 'weight_scale'): ws = proj.weight_scale @@ -825,10 +995,8 @@ def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_sta scale_suffix = '.weight_scale_inv' qkv_scale = ws[scale_offsets[0]:scale_offsets[3]] z_scale = ws[scale_offsets[3]:scale_offsets[4]] - state_dict[f"{prefix}.in_proj_qkv{scale_suffix}"] = qkv_scale - quant_state_dict[f"{prefix}.in_proj_qkv{scale_suffix}"] = qkv_scale - state_dict[f"{prefix}.in_proj_z{scale_suffix}"] = z_scale - quant_state_dict[f"{prefix}.in_proj_z{scale_suffix}"] = z_scale + store(f"{prefix}.in_proj_qkv{scale_suffix}", qkv_scale) + store(f"{prefix}.in_proj_z{scale_suffix}", z_scale) else: get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) @@ -836,14 +1004,9 @@ def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_sta get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) - conv_w = gdn.conv1d.weight.data - state_dict[f"{prefix}.conv1d.weight"] = conv_w - quant_state_dict[f"{prefix}.conv1d.weight"] = conv_w - - state_dict[f"{prefix}.dt_bias"] = gdn.dt_bias.data - quant_state_dict[f"{prefix}.dt_bias"] = gdn.dt_bias.data - state_dict[f"{prefix}.A_log"] = gdn.A_log.data - quant_state_dict[f"{prefix}.A_log"] = gdn.A_log.data + store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) + store(f"{prefix}.dt_bias", gdn.dt_bias.data) + store(f"{prefix}.A_log", gdn.A_log.data) get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) pass @@ -899,7 +1062,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if isinstance(layer_module, torch.nn.Module): if hasattr(layer_module, 'weight'): get_state_dict(layer_path, 0, state_dict, layer_module) - elif isinstance(layer_module, torch.nn.Parameter): + elif isinstance(layer_module, torch.Tensor): state_dict[f"{layer_path}"] = layer_module.data quant_state_dict[f"{layer_path}"] = state_dict[f"{layer_path}"] else: @@ -914,7 +1077,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if hasattr(component, 'weight'): # Prefer using get_state_dict when possible get_state_dict(component_path, 0, state_dict, component) - elif isinstance(component, torch.nn.Parameter): + elif isinstance(component, torch.Tensor): state_dict[component_path] = component.data quant_state_dict[component_path] = component.data elif isinstance(component, torch.nn.Module): diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 4477bb12e..56fea08f4 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1063,6 +1063,12 @@ def get_state_dict(prefix, kk, state_dict, proj, slice_weights=True, slice_index quant_state_dict[prefix + ".bias"] = bias_tensor pass + gemma4_k_eq_v_layers = { + kk + for kk, layer_type in enumerate(getattr(text_config, "layer_types", ())) + if model_type == "gemma4" and getattr(text_config, "attention_k_eq_v", False) and layer_type == "full_attention" + } + # Embedding if hasattr(vllm_internals, "model"): # Standard Language models vllm_text_model = vllm_internals.model @@ -1107,7 +1113,8 @@ def _is_fused_module(name: str) -> bool: else: get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj) - get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) + if kk not in gemma4_k_eq_v_layers: + get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) elif hasattr(layer, "cross_attn"): prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn" @@ -1157,6 +1164,9 @@ def _is_fused_module(name: str) -> bool: except Exception as e: skipped_layernorms.append(layernorm_name.split(".")[-1]) pass + if hasattr(layer, "layer_scalar"): + state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data pass if len(skipped_layernorms) != 0: @@ -1175,9 +1185,9 @@ def _is_fused_module(name: str) -> bool: # LM Head - Use get_state_dict for consistency if not getattr(text_config, "tie_word_embeddings", False): - lm_layer = [mod for name,mod in vllm_internals.named_modules() if "lm_head" in name] - if len(lm_layer) != 0: - get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False) + lm_layer = next((mod for name, mod in vllm_internals.named_modules() if "lm_head" in name), None) + if lm_layer is not None: + get_state_dict("lm_head", 0, state_dict, lm_layer, slice_weights=False) elif hasattr(vllm_internals, "language_model") and hasattr(vllm_internals.language_model, "lm_head"): get_state_dict("lm_head", 0, state_dict, vllm_internals.language_model.lm_head, slice_weights=False) elif hasattr(vllm_internals, "lm_head"): @@ -1203,6 +1213,13 @@ def assert_same_state_dict(old_state_dict, new_state_dict): # Check if state_dict are equivalent # hf, vllm + def _normalize_state_dict_tensor(value): + if isinstance(value, torch.nn.Parameter): + value = value.detach() + if value.is_sparse: + value = value.to_dense() + return value.contiguous() + difference = new_state_dict.keys() ^ old_state_dict.keys() difference -= set(("model.lm_head.weight","model.language_model.lm_head.weight", "lm_head.weight")) if len(difference) != 0: @@ -1216,8 +1233,8 @@ def assert_same_state_dict(old_state_dict, new_state_dict): for key in old_state_dict: try: - old_val = old_state_dict[key] - new_val = new_state_dict[key] + old_val = _normalize_state_dict_tensor(old_state_dict[key]) + new_val = _normalize_state_dict_tensor(new_state_dict[key]) if old_val.dtype != new_val.dtype or (new_val.element_size() < 2): # upcast both to float32 just for comparison. For FP8, vLLM stores weight scale in FP32 while HF preferes 16bit old_val = old_val.to(torch.float32) @@ -1231,7 +1248,11 @@ def assert_same_state_dict(old_state_dict, new_state_dict): if key1 is not None and key2 is not None: try: - torch.testing.assert_close(old_state_dict[key1].contiguous(), new_state_dict[key2].contiguous(), check_stride = True) + torch.testing.assert_close( + _normalize_state_dict_tensor(old_state_dict[key1]), + _normalize_state_dict_tensor(new_state_dict[key2]), + check_stride = True, + ) except Exception: failures[key] = error else: @@ -1249,7 +1270,14 @@ def assert_same_state_dict(old_state_dict, new_state_dict): def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, bnb_config = None, is_vision_model = False): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules to create HF compatible model + def _unwrap_tensor(value): + return getattr(value, "data", value) + set_dtype_in_config(config, dtype) + for subconfig_name in ("text_config", "vision_config", "audio_config"): + subconfig = getattr(config, subconfig_name, None) + if subconfig is not None: + set_dtype_in_config(subconfig, dtype) new_model, original_meta_model, layer_count, layer_names = create_empty_model(config, dtype, is_vision_model) new_model = new_model.to(device = get_target_device(), dtype = dtype) quantization_config = getattr(config, "quantization_config", {}) @@ -1348,7 +1376,7 @@ def _override_to(self, *args, **kwargs): if f"{layer_name}.bias" in quant_state_dict: # Has bias! has_bias = True - bias = quant_state_dict[f"{layer_name}.bias"] + bias = _unwrap_tensor(quant_state_dict[f"{layer_name}.bias"]) bias = torch.nn.Parameter(bias, requires_grad = False) else: has_bias = False @@ -1368,7 +1396,7 @@ def _override_to(self, *args, **kwargs): if layer_name in quant_state_dict: # for attributes of type nn.Parameter, there's no .weight layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) - layer = torch.nn.Parameter(weight, requires_grad = False) + layer = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) exec(f"new_model.{layer_name_br} = layer") continue elif fp8_weight_scale is not None: @@ -1377,7 +1405,7 @@ def _override_to(self, *args, **kwargs): layer = FbgemmFp8Linear(in_features = 0, out_features = 0, bias = has_bias, weight_dtype = dtype).to(get_target_device()) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias layer.input_scale_ub = kwargs['input_scale_ub'] layer.weight_scale = torch.nn.Parameter(fp8_weight_scale, requires_grad = False) @@ -1392,7 +1420,7 @@ def _override_to(self, *args, **kwargs): layer = FP8Linear(**fp8_kwargs) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias layer.weight_scale_inv = torch.nn.Parameter(fp8_weight_scale, requires_grad = False) layer.quant_method = "fp8" @@ -1416,11 +1444,11 @@ def _override_to(self, *args, **kwargs): layer.out_features = weight.shape[0] # from vllm 0.11.1, the .weight is of dtype ModelWeightParameter, so try to extract the 'data' part # https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170#diff-7d6145ac4ba084231a441c2056c7fca23c3bae33e6542f4f602a6c9d4d2da64dL199-R208 - layer.weight = torch.nn.Parameter(getattr(weight, 'data', weight), requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias else: # LayerNorms (including vision norms) - weight_param = torch.nn.Parameter(weight, requires_grad=False) + weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) # Set weight exec(f"new_model.{layer_name_br}.weight = None") @@ -1439,49 +1467,14 @@ def _override_to(self, *args, **kwargs): pass set_additional_modules(new_model, quant_state_dict, config) - - if original_meta_model is not None: - copy_attributes(original_meta_model, new_model) - - # # Set config on model and modules using clean approach - # new_model.config = config - # for module in new_model.modules(): - # if hasattr(module, "config"): - # module.config = config - # for param in new_model.parameters(): - # if hasattr(param, "config"): - # param.config = config - - text_config = getattr(config, "text_config", config) #try using text config for VLMs - vision_config = getattr(config, "vision_config", None) - # Fix up rotary_emb by re-initing them - for module in new_model.modules(): - if hasattr(module, "rotary_emb"): - module.rotary_emb = module.rotary_emb.__class__( - config = text_config, - device = get_target_device(), - ) - if hasattr(module, "rotary_pos_emb"): - # Qwen 2.5 VL has a rotary_pos_emb in vision submodel - # https://github.com/huggingface/transformers/blob/a871f6f58d49f3a05ae9dae519caa8aa9d919a07/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L337 - assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" - head_dim = vision_config.hidden_size // vision_config.num_heads - module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(get_target_device()) - if hasattr(module, "rotary_emb_local"): - # gemma3 has a rotary_emb_local - # https://github.com/huggingface/transformers/blob/008c0ba8e2a1226a6ef5a61c4915a0a8a340c157/src/transformers/models/gemma3/modeling_gemma3.py#L469-L471 - # Gemma3 uses different defaults for local and global RoPE. Copy the config for modification. - local_rope_config = deepcopy(text_config) - local_rope_config.rope_theta = text_config.rope_local_base_freq - local_rope_config.rope_scaling = {"rope_type": "default"} - # gemma3 has a rotary_emb_local - module.rotary_emb_local = module.rotary_emb_local.__class__( - config = local_rope_config, - device = get_target_device(), - ) - del local_rope_config - pass - pass + new_model = finalize_huggingface_model( + new_model, + original_meta_model, + config, + dtype, + quantization_config = quantization_config, + bnb_config = bnb_config, + ) # Must override or else Bitsandbytes will error new_model.to = partial(_override_to, new_model) @@ -1785,6 +1778,9 @@ def load_vllm( assert(type(use_bitsandbytes) is bool) assert(conservativeness >= 0.0 and conservativeness <= 1.0) + if is_vision_model and getattr(config, "model_type", None) == "gemma4": + patch_gemma4_vllm_lora_support() + unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0") # This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful. standby_util_override = os.getenv("UNSLOTH_VLLM_STANDBY_UTIL_OVERRIDE", "0") != "0" @@ -2880,10 +2876,19 @@ def _test_is_same_vlm(model, new_model, processor, test_backward=False): messages, tokenize=False, add_generation_prompt=True ) - inputs = processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, - return_dict=True, return_tensors="pt" - ).to(model.device, dtype=model.dtype) + if processor.__class__.__name__ in ("Gemma3Processor", "Gemma4Processor"): + from transformers.image_utils import load_image + image = load_image(messages[0]["content"][0]["image"]) + inputs = processor( + text = [text], + images = [image], + return_tensors = "pt", + ).to(model.device, dtype=model.dtype) + else: + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, + return_dict=True, return_tensors="pt" + ).to(model.device, dtype=model.dtype) with torch.no_grad(): original_outputs = model(**inputs) @@ -3096,6 +3101,8 @@ def _test_get_vllm_state_dict( new_model = convert_vllm_to_huggingface(quant_state_dict, config, dtype, is_vision_model = is_vision_model) test_model_conversion(model, new_model) + new_model, _ = patch_model_and_tokenizer(new_model, None) + new_model.eval() # Run the model as well if not is_vision_model: From a85a4f45f74d6fde99c831e73d5a1f853af1175f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 17 Apr 2026 06:48:11 +0000 Subject: [PATCH 06/11] Fix dtype setting --- unsloth_zoo/empty_model.py | 2 +- unsloth_zoo/hf_utils.py | 37 ++++++++++++++++++++++++++++--------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 36d8e8f19..632a77e7a 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -32,7 +32,7 @@ from copy import deepcopy from .utils import get_quant_type from .log import logger -from .hf_utils import HAS_TORCH_DTYPE, dtype_from_config +from .hf_utils import HAS_TORCH_DTYPE, dtype_from_config, set_dtype_in_config def is_comparable(val): # Don't treat tensors as comparable, only basic types diff --git a/unsloth_zoo/hf_utils.py b/unsloth_zoo/hf_utils.py index cb96a9c51..27ff771e9 100644 --- a/unsloth_zoo/hf_utils.py +++ b/unsloth_zoo/hf_utils.py @@ -50,15 +50,34 @@ def dtype_from_config(config): return dtype def set_dtype_in_config(config, dtype): - try: - # if dtype is not a string, convert it to a string - string_dtype = str(dtype).split(".")[-1] if isinstance(dtype, torch.dtype) else dtype - if HAS_TORCH_DTYPE: - setattr(config, "torch_dtype", string_dtype) - else: - setattr(config, "dtype", string_dtype) - except: - set_dtype_in_config_fallback(config, string_dtype) + runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype + target_fields = [] + + if hasattr(config, "dtype"): + target_fields.append("dtype") + if hasattr(config, "torch_dtype"): + target_fields.append("torch_dtype") + + if len(target_fields) == 0: + target_fields.append("torch_dtype" if HAS_TORCH_DTYPE else "dtype") + + success = False + for field in target_fields: + try: + setattr(config, field, runtime_dtype) + success = True + continue + except Exception: + pass + + try: + config.__dict__[field] = runtime_dtype + success = True + except Exception: + pass + + if not success: + set_dtype_in_config_fallback(config, dtype) def set_dtype_in_config_fallback(config, dtype): try: From 43ed24f25402e55dcce8cad96ec7de1387cda2f1 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 17 Apr 2026 08:21:24 +0000 Subject: [PATCH 07/11] Fix gemma4 load on vllm 0.19.0 --- unsloth_zoo/empty_model.py | 46 +++++++++++++++++++++++++ unsloth_zoo/temporary_patches/gemma4.py | 12 +++++++ unsloth_zoo/vllm_utils.py | 1 + 3 files changed, 59 insertions(+) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 632a77e7a..8d7dcc682 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -19,6 +19,7 @@ "set_additional_modules", "finalize_huggingface_model", "patch_gemma4_vllm_lora_support", + "patch_gemma4_vllm_k_eq_v_support", "extract_gdn_layers", "extract_vision_layers", "get_model_layer_config", @@ -351,6 +352,51 @@ def patched_create_lora_manager(model, *args, **kwargs): patched_create_lora_manager._unsloth_gemma4_patch = True vllm_lora_model_manager.create_lora_manager = patched_create_lora_manager vllm_lora_worker_manager.create_lora_manager = patched_create_lora_manager + +# vLLM load seems to be failing at least with 0.19.0 +# due to the diff b/w sepearate kv and shared kv like gemma4 +# This is to address that issue :) +def patch_gemma4_vllm_k_eq_v_support(): + from vllm.model_executor.models.gemma4 import Gemma4Attention + + if hasattr(Gemma4Attention.forward, "_unsloth_gemma4_k_eq_v_patch"): + return + + original_forward = Gemma4Attention.forward + + def patched_forward(self, positions, hidden_states, **kwargs): + qkv, _ = self.qkv_proj(hidden_states) + + if self.use_k_eq_v and qkv.shape[-1] == (self.q_size + self.kv_size): + q, k = qkv.split([self.q_size, self.kv_size], dim = -1) + # Gemma4 full-attention k_eq_v layers reuse K as the pre-norm V input. + v = k + else: + return original_forward(self, positions, hidden_states, **kwargs) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + + if not self.is_kv_shared_layer: + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + q, k = self.rotary_emb(positions, q, k) + + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + v = v.flatten(-2, -1) + else: + q = self.rotary_emb(positions, q, k)[0] + + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + patched_forward._unsloth_gemma4_k_eq_v_patch = True + Gemma4Attention.forward = patched_forward +pass pass diff --git a/unsloth_zoo/temporary_patches/gemma4.py b/unsloth_zoo/temporary_patches/gemma4.py index 3a91bba0c..f356e5455 100644 --- a/unsloth_zoo/temporary_patches/gemma4.py +++ b/unsloth_zoo/temporary_patches/gemma4.py @@ -118,6 +118,18 @@ def __getattr__(self, name): ) return getattr(object.__getattribute__(self, "_real"), name) + def __setattr__(self, name, value): + if name == "_real": + object.__setattr__(self, name, value) + return + setattr(object.__getattribute__(self, "_real"), name, value) + + def __delattr__(self, name): + if name == "_real": + object.__delattr__(self, name) + return + delattr(object.__getattribute__(self, "_real"), name) + def get_text_config(self, decoder=None, encoder=None): # If upstream recursively calls get_text_config on the proxy, return # self so the proxy is not unwrapped back into a raw config. diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2167564c8..8c2b0e0dd 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1781,6 +1781,7 @@ def load_vllm( if is_vision_model and getattr(config, "model_type", None) == "gemma4": patch_gemma4_vllm_lora_support() + patch_gemma4_vllm_k_eq_v_support() unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0") # This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful. From b7052b55a5132556c7928fc5ab2c75e63b5e3eae Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 17 Apr 2026 10:19:37 +0000 Subject: [PATCH 08/11] fix bnb loader for gemam4 --- unsloth_zoo/empty_model.py | 100 ++++++++++++++++++++++++------------- unsloth_zoo/vllm_utils.py | 4 ++ 2 files changed, 69 insertions(+), 35 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 8d7dcc682..c4df7893c 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -352,51 +352,81 @@ def patched_create_lora_manager(model, *args, **kwargs): patched_create_lora_manager._unsloth_gemma4_patch = True vllm_lora_model_manager.create_lora_manager = patched_create_lora_manager vllm_lora_worker_manager.create_lora_manager = patched_create_lora_manager +pass -# vLLM load seems to be failing at least with 0.19.0 -# due to the diff b/w sepearate kv and shared kv like gemma4 -# This is to address that issue :) +# vLLM's Gemma4 k_eq_v path now expects qkv_proj to always expose q+k+v. +# For prequantized bitsandbytes checkpoints, the synthetic v shard is still +# missing from the quant-state dict on full-attention k_eq_v layers, so we +# materialize it during loader-side quant-state stacking instead of patching +# the runtime attention forward. def patch_gemma4_vllm_k_eq_v_support(): - from vllm.model_executor.models.gemma4 import Gemma4Attention + from vllm.model_executor.model_loader.bitsandbytes_loader import ( + BitsAndBytesModelLoader, + ) - if hasattr(Gemma4Attention.forward, "_unsloth_gemma4_k_eq_v_patch"): + if hasattr( + BitsAndBytesModelLoader._stack_quantization_states, + "_unsloth_gemma4_k_eq_v_patch", + ): return - original_forward = Gemma4Attention.forward - - def patched_forward(self, positions, hidden_states, **kwargs): - qkv, _ = self.qkv_proj(hidden_states) - - if self.use_k_eq_v and qkv.shape[-1] == (self.q_size + self.kv_size): - q, k = qkv.split([self.q_size, self.kv_size], dim = -1) - # Gemma4 full-attention k_eq_v layers reuse K as the pre-norm V input. - v = k - else: - return original_forward(self, positions, hidden_states, **kwargs) + original_stack_quantization_states = ( + BitsAndBytesModelLoader._stack_quantization_states + ) - q = q.unflatten(-1, (self.num_heads, self.head_dim)) - q = self.q_norm(q) - q = q.flatten(-2, -1) + def _get_gemma4_text_config(model): + config = getattr(model, "config", None) + if config is None: + return None + + text_config = getattr(config, "text_config", config) + model_type = getattr(config, "model_type", None) + text_model_type = getattr(text_config, "model_type", None) + if model_type != "gemma4" and text_model_type not in ("gemma4", "gemma4_text"): + return None + return text_config + + def _get_gemma4_k_eq_v_qkv_param_names(model): + text_config = _get_gemma4_text_config(model) + if text_config is None or not getattr(text_config, "attention_k_eq_v", False): + return () + + param_names = set(name for name, _ in model.named_parameters()) + qkv_param_names = [] + for layer_idx, layer_type in enumerate(getattr(text_config, "layer_types", ())): + if layer_type != "full_attention": + continue - if not self.is_kv_shared_layer: - k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) - k = self.k_norm(k) - k = k.flatten(-2, -1) - q, k = self.rotary_emb(positions, q, k) + for prefix in ("language_model.model", "model"): + qkv_param_name = ( + f"{prefix}.layers.{layer_idx}.self_attn.qkv_proj.weight" + ) + if qkv_param_name in param_names: + qkv_param_names.append(qkv_param_name) + break + return tuple(qkv_param_names) + + def patched_stack_quantization_states(self, model, quant_state_dict): + stacked_quant_state_dict = original_stack_quantization_states( + self, model, quant_state_dict + ) + + for qkv_param_name in _get_gemma4_k_eq_v_qkv_param_names(model): + quant_states = stacked_quant_state_dict.get(qkv_param_name) + if not isinstance(quant_states, dict) or 2 in quant_states or 1 not in quant_states: + continue - v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) - v = self.v_norm(v) - v = v.flatten(-2, -1) - else: - q = self.rotary_emb(positions, q, k)[0] + # Gemma4 full-attention k_eq_v layers reuse K as V. The raw weight + # loader already duplicates k_proj -> v_proj; prequant BnB needs the + # same duplication for shard-local QuantState metadata. + quant_states[2] = deepcopy(quant_states[1]) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + return stacked_quant_state_dict - patched_forward._unsloth_gemma4_k_eq_v_patch = True - Gemma4Attention.forward = patched_forward -pass + patched_stack_quantization_states._unsloth_gemma4_k_eq_v_patch = True + BitsAndBytesModelLoader._stack_quantization_states = ( + patched_stack_quantization_states + ) pass diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 8c2b0e0dd..b319e7720 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -3010,6 +3010,7 @@ def _test_get_vllm_state_dict( load_in_4bit = False, skip_generation = False, is_vision_model = False, + compilation_config = None, ): # All Unsloth Zoo code licensed under LGPLv3 # Check if model is allowed to be used in vLLM @@ -3049,6 +3050,8 @@ def _test_get_vllm_state_dict( model_type = getattr(config, "model_type", "causal_lm") enable_lora = model_type != "mllama" + if compilation_config is None and model_type == "gemma4": + compilation_config = 0 if not is_vision_model: model_class = AutoModelForCausalLM @@ -3090,6 +3093,7 @@ def _test_get_vllm_state_dict( use_bitsandbytes = load_in_4bit, is_vision_model = is_vision_model, enable_lora = enable_lora, + compilation_config = compilation_config, ) state_dict, quant_state_dict = get_vllm_state_dict( From 9a4bbd9b41fd5c447dce7e355e4829ed9cbee479 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Apr 2026 11:07:36 +0000 Subject: [PATCH 09/11] Fix review findings: Gemma4 LoRA/BnB patches, GDN extraction, finalize_huggingface_model - patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so _call_create_lora_manager's signature inspection still sees vllm_config; pass model positionally to lora_manager_cls to avoid "multiple values for 'model'". - patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed qkv_proj path as fallback. - load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model), so text-only Gemma4 + LoRA / BnB also works. - extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars; handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8; export linear_attn.norm.weight. - finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path); rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard rotary_pos_emb on vision_config availability; mirror language_model detection from set_additional_modules. - get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection / post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop the broken linear_fc{kk} template. - set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to 'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on current transformers. - vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers (defensive) while still capturing layer_scalar. - _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor) so non-tensor state-dict values pass through. --- unsloth_zoo/empty_model.py | 169 +++++++++++++++++++++++++++---------- unsloth_zoo/hf_utils.py | 13 ++- unsloth_zoo/vllm_utils.py | 16 +++- 3 files changed, 142 insertions(+), 56 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index c4df7893c..229c09c67 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -319,6 +319,7 @@ def _get_model_device(model): pass def patch_gemma4_vllm_lora_support(): + from functools import wraps from vllm.model_executor.models.gemma4_mm import Gemma4ForConditionalGeneration from vllm.model_executor.models import interfaces as vllm_model_interfaces from vllm.lora import model_manager as vllm_lora_model_manager @@ -343,10 +344,11 @@ def patched_supports_lora(model): if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"): original_create_lora_manager = vllm_lora_model_manager.create_lora_manager + @wraps(original_create_lora_manager) def patched_create_lora_manager(model, *args, **kwargs): if model.__class__.__name__ == "Gemma4ForConditionalGeneration": lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager) - return lora_manager_cls(model = model, *args, **kwargs) + return lora_manager_cls(model, *args, **kwargs) return original_create_lora_manager(model, *args, **kwargs) patched_create_lora_manager._unsloth_gemma4_patch = True @@ -386,40 +388,48 @@ def _get_gemma4_text_config(model): return None return text_config - def _get_gemma4_k_eq_v_qkv_param_names(model): + def _get_gemma4_k_eq_v_pairs(model): text_config = _get_gemma4_text_config(model) if text_config is None or not getattr(text_config, "attention_k_eq_v", False): return () param_names = set(name for name, _ in model.named_parameters()) - qkv_param_names = [] + pairs = [] for layer_idx, layer_type in enumerate(getattr(text_config, "layer_types", ())): if layer_type != "full_attention": continue for prefix in ("language_model.model", "model"): - qkv_param_name = ( - f"{prefix}.layers.{layer_idx}.self_attn.qkv_proj.weight" - ) - if qkv_param_name in param_names: - qkv_param_names.append(qkv_param_name) + k_name = f"{prefix}.layers.{layer_idx}.self_attn.k_proj.weight" + v_name = f"{prefix}.layers.{layer_idx}.self_attn.v_proj.weight" + qkv_name = f"{prefix}.layers.{layer_idx}.self_attn.qkv_proj.weight" + if k_name in param_names: + pairs.append(("split", k_name, v_name)) + break + if qkv_name in param_names: + pairs.append(("packed", qkv_name, None)) break - return tuple(qkv_param_names) + return tuple(pairs) def patched_stack_quantization_states(self, model, quant_state_dict): stacked_quant_state_dict = original_stack_quantization_states( self, model, quant_state_dict ) - for qkv_param_name in _get_gemma4_k_eq_v_qkv_param_names(model): - quant_states = stacked_quant_state_dict.get(qkv_param_name) - if not isinstance(quant_states, dict) or 2 in quant_states or 1 not in quant_states: + for kind, source, target in _get_gemma4_k_eq_v_pairs(model): + quant_states = stacked_quant_state_dict.get(source) + if quant_states is None: continue # Gemma4 full-attention k_eq_v layers reuse K as V. The raw weight # loader already duplicates k_proj -> v_proj; prequant BnB needs the # same duplication for shard-local QuantState metadata. - quant_states[2] = deepcopy(quant_states[1]) + if kind == "packed": + if isinstance(quant_states, dict) and 2 not in quant_states and 1 in quant_states: + quant_states[2] = deepcopy(quant_states[1]) + elif kind == "split": + if target not in stacked_quant_state_dict: + stacked_quant_state_dict[target] = deepcopy(quant_states) return stacked_quant_state_dict @@ -665,9 +675,15 @@ def finalize_huggingface_model( if original_meta_model is not None: copy_attributes(original_meta_model, new_model) - language_model = getattr(getattr(new_model, "model", None), "language_model", None) - if language_model is not None and hasattr(language_model, "layers"): - for layer_idx, layer in enumerate(language_model.layers): + if hasattr(new_model, "language_model"): + lm_root = new_model.language_model + elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"): + lm_root = new_model.model.language_model + else: + lm_root = getattr(new_model, "model", None) + + if lm_root is not None and hasattr(lm_root, "layers"): + for layer_idx, layer in enumerate(lm_root.layers): if hasattr(layer, "layer_idx"): layer.layer_idx = layer_idx for attr_name in ("self_attn", "cross_attn", "mlp", "linear_attn"): @@ -683,6 +699,7 @@ def finalize_huggingface_model( target_device = _get_model_device(new_model) text_config = getattr(config, "text_config", config) vision_config = getattr(config, "vision_config", None) + is_gemma4 = getattr(config, "model_type", None) == "gemma4" for module in new_model.modules(): if hasattr(module, "rotary_emb"): @@ -691,24 +708,24 @@ def finalize_huggingface_model( is_vision_rotary = ( vision_config is not None and current_rotary_config is not None and + current_rotary_config is not text_config and current_rotary_config.__class__ == vision_config.__class__ ) if is_vision_rotary: rotary_config = vision_config - if not (getattr(config, "model_type", None) == "gemma4" and is_vision_rotary): - module.rotary_emb = module.rotary_emb.__class__( - config = rotary_config, - device = target_device, - ) + module.rotary_emb = module.rotary_emb.__class__( + config = rotary_config, + device = target_device, + ) + buffer_dtype = torch.float32 if (is_gemma4 and is_vision_rotary) else dtype for buffer_name in ("inv_freq", "original_inv_freq"): buffer = getattr(module.rotary_emb, buffer_name, None) if torch.is_tensor(buffer) and buffer.is_floating_point(): module.rotary_emb._buffers[buffer_name] = buffer.to( device = target_device, - dtype = dtype, + dtype = buffer_dtype, ) - if hasattr(module, "rotary_pos_emb"): - assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" + if hasattr(module, "rotary_pos_emb") and vision_config is not None: head_dim = vision_config.hidden_size // vision_config.num_heads module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device) if hasattr(module, "rotary_emb_local"): @@ -723,13 +740,16 @@ def finalize_huggingface_model( if (quantization_config or {}) == {} and bnb_config is None: new_model = new_model.to(device = target_device, dtype = dtype) - if getattr(config, "model_type", None) == "gemma4": + if is_gemma4: for module in new_model.modules(): rotary_emb = getattr(module, "rotary_emb", None) if rotary_emb is None: continue + rotary_cfg = getattr(rotary_emb, "config", None) + if rotary_cfg is None: + continue fresh_rotary_emb = rotary_emb.__class__( - config = rotary_emb.config, + config = rotary_cfg, device = target_device, ) for attr_name in ("max_seq_len_cached", "original_max_seq_len"): @@ -795,6 +815,12 @@ def get_model_layer_config(return_non_layered=True): "model.layers.{kk}.linear_attn.out_proj", "model.layers.{kk}.linear_attn.dt_bias", "model.layers.{kk}.linear_attn.A_log", + + # Gemma4 per-layer input modules + "model.language_model.layers.{kk}.per_layer_input_gate", + "model.language_model.layers.{kk}.per_layer_projection", + "model.layers.{kk}.per_layer_input_gate", + "model.layers.{kk}.per_layer_projection", }, 'layernorms': { "model.language_model.layers.{kk}.input_layernorm", @@ -831,6 +857,10 @@ def get_model_layer_config(return_non_layered=True): "model.visual.deepstack_merger_list.{kk}.norm", "model.language_model.layers.{kk}.linear_attn.norm", "model.layers.{kk}.linear_attn.norm", + + # Gemma4 per-layer input norm + "model.language_model.layers.{kk}.post_per_layer_input_norm", + "model.layers.{kk}.post_per_layer_input_norm", }, 'vision_layers': { @@ -925,7 +955,8 @@ def get_model_layer_config(return_non_layered=True): # qwen 3 vl "model.visual.deepstack_merger_list.{kk}.linear_fc1", "model.visual.deepstack_merger_list.{kk}.linear_fc2", - "model.visual.merger.linear_fc{kk}", + "model.visual.merger.linear_fc1", + "model.visual.merger.linear_fc2", }, "non_layered_components":{ @@ -1043,47 +1074,95 @@ def _get_nested_attr(obj, attr_path: str): def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_state_dict): gdn = gdn_module + def _unwrap(v): + return getattr(v, "data", v) + def store(name, value): state_dict[name] = value quant_state_dict[name] = value if hasattr(gdn, "in_proj_qkvz"): proj = getattr(gdn.in_proj_qkvz, "base_layer", gdn.in_proj_qkvz) - weight = proj.weight - output_sizes = list(proj.output_sizes) + weight = _unwrap(proj.weight) + + output_sizes = getattr(proj, "output_sizes", None) + if output_sizes is None: + key_dim = getattr(gdn, "key_dim", None) + value_dim = getattr(gdn, "value_dim", None) + if key_dim is None or value_dim is None: + raise RuntimeError( + "Unsloth: cannot infer GDN in_proj_qkvz shards without " + "proj.output_sizes or gdn.key_dim / gdn.value_dim" + ) + output_sizes = [key_dim, key_dim, value_dim, value_dim] + output_sizes = list(output_sizes) offsets = [0] for s in output_sizes: offsets.append(offsets[-1] + s) + if len(offsets) < 5: + raise RuntimeError( + f"Unsloth: GDN in_proj_qkvz expected 4 shards (q,k,v,z); got sizes={output_sizes}" + ) + qkv_weight = weight[offsets[0]:offsets[3]] z_weight = weight[offsets[3]:offsets[4]] store(f"{prefix}.in_proj_qkv.weight", qkv_weight) store(f"{prefix}.in_proj_z.weight", z_weight) + + qs_attr = getattr(weight, "bnb_quant_state", None) + if isinstance(qs_attr, dict): + qkv_qs = qs_attr.get(0) + z_qs = qs_attr.get(3) + if qkv_qs is not None: + quant_state_dict[f"{prefix}.in_proj_qkv.weight.quant_state"] = qkv_qs + try: + for k, v in qkv_qs.as_dict(packed=True).items(): + state_dict[f"{prefix}.in_proj_qkv.weight.{k}"] = v + except Exception: + pass + if z_qs is not None: + quant_state_dict[f"{prefix}.in_proj_z.weight.quant_state"] = z_qs + try: + for k, v in z_qs.as_dict(packed=True).items(): + state_dict[f"{prefix}.in_proj_z.weight.{k}"] = v + except Exception: + pass + if weight.dtype == torch.float8_e4m3fn: - if hasattr(proj, 'weight_scale'): - ws = proj.weight_scale - elif hasattr(proj, 'weight_scale_inv'): - ws = proj.weight_scale_inv - else: - ws = None - if ws is not None and ws.ndim == 2 and ws.shape[1] > 1: - block_size = proj.weight_block_size[0] - scale_offsets = [x // block_size for x in offsets] - scale_suffix = '.weight_scale_inv' - qkv_scale = ws[scale_offsets[0]:scale_offsets[3]] - z_scale = ws[scale_offsets[3]:scale_offsets[4]] - store(f"{prefix}.in_proj_qkv{scale_suffix}", qkv_scale) - store(f"{prefix}.in_proj_z{scale_suffix}", z_scale) + scale_attr = None + if hasattr(proj, "weight_scale"): + scale_attr = "weight_scale" + elif hasattr(proj, "weight_scale_inv"): + scale_attr = "weight_scale_inv" + ws = _unwrap(getattr(proj, scale_attr)) if scale_attr is not None else None + if ws is not None: + if ws.ndim == 2 and ws.shape[1] > 1: + block_size = proj.weight_block_size[0] + scale_offsets = [x // block_size for x in offsets] + qkv_scale = ws[scale_offsets[0]:scale_offsets[3]] + z_scale = ws[scale_offsets[3]:scale_offsets[4]] + else: + qkv_scale = ws[offsets[0]:offsets[3]] + z_scale = ws[offsets[3]:offsets[4]] + store(f"{prefix}.in_proj_qkv.{scale_attr}", qkv_scale) + store(f"{prefix}.in_proj_z.{scale_attr}", z_scale) else: get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) - get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) - get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) + ba_layer = getattr(gdn.in_proj_ba, "base_layer", gdn.in_proj_ba) + ba_weight = _unwrap(ba_layer.weight) + mid = ba_weight.shape[0] // 2 + store(f"{prefix}.in_proj_b.weight", ba_weight[:mid]) + store(f"{prefix}.in_proj_a.weight", ba_weight[mid:]) store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) store(f"{prefix}.dt_bias", gdn.dt_bias.data) store(f"{prefix}.A_log", gdn.A_log.data) + if hasattr(gdn, "norm") and hasattr(gdn.norm, "weight"): + store(f"{prefix}.norm.weight", gdn.norm.weight.data) + get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) pass diff --git a/unsloth_zoo/hf_utils.py b/unsloth_zoo/hf_utils.py index 27ff771e9..8b99603e6 100644 --- a/unsloth_zoo/hf_utils.py +++ b/unsloth_zoo/hf_utils.py @@ -51,15 +51,12 @@ def dtype_from_config(config): def set_dtype_in_config(config, dtype): runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype - target_fields = [] - if hasattr(config, "dtype"): - target_fields.append("dtype") - if hasattr(config, "torch_dtype"): - target_fields.append("torch_dtype") - - if len(target_fields) == 0: - target_fields.append("torch_dtype" if HAS_TORCH_DTYPE else "dtype") + target_fields = ["dtype"] + elif hasattr(config, "torch_dtype"): + target_fields = ["torch_dtype"] + else: + target_fields = ["dtype" if HAS_TORCH_DTYPE else "torch_dtype"] success = False for field in target_fields: diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index b319e7720..5d5999db1 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1137,6 +1137,12 @@ def _is_fused_module(name: str) -> bool: ) pass + if not hasattr(layer, "mlp"): + if hasattr(layer, "layer_scalar"): + state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + continue + proj = layer.mlp.gate_up_proj use_fused_gate_up = _is_fused_module("gate_up_proj") if use_fused_gate_up: @@ -1216,6 +1222,8 @@ def assert_same_state_dict(old_state_dict, new_state_dict): def _normalize_state_dict_tensor(value): if isinstance(value, torch.nn.Parameter): value = value.detach() + if not isinstance(value, torch.Tensor): + return value if value.is_sparse: value = value.to_dense() return value.contiguous() @@ -1779,9 +1787,11 @@ def load_vllm( assert(type(use_bitsandbytes) is bool) assert(conservativeness >= 0.0 and conservativeness <= 1.0) - if is_vision_model and getattr(config, "model_type", None) == "gemma4": - patch_gemma4_vllm_lora_support() - patch_gemma4_vllm_k_eq_v_support() + if getattr(config, "model_type", None) == "gemma4": + if enable_lora: + patch_gemma4_vllm_lora_support() + if use_bitsandbytes: + patch_gemma4_vllm_k_eq_v_support() unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0") # This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful. From 68a02c3da19b6066ca170bb445189c486430e24b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Apr 2026 11:12:24 +0000 Subject: [PATCH 10/11] Add review tests --- tests/test_finalize_huggingface_model.py | 114 +++++++++++++++++++ tests/test_gdn_extraction.py | 134 +++++++++++++++++++++++ tests/test_vllm_conversion_helpers.py | 102 +++++++++++++++++ 3 files changed, 350 insertions(+) create mode 100644 tests/test_finalize_huggingface_model.py create mode 100644 tests/test_gdn_extraction.py create mode 100644 tests/test_vllm_conversion_helpers.py diff --git a/tests/test_finalize_huggingface_model.py b/tests/test_finalize_huggingface_model.py new file mode 100644 index 000000000..fe077cd10 --- /dev/null +++ b/tests/test_finalize_huggingface_model.py @@ -0,0 +1,114 @@ +import sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import types +import torch +from unsloth_zoo.empty_model import finalize_huggingface_model + + +class _LinearAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_idx = -1 + + +class _StandardLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_idx = -1 + self.linear_attn = _LinearAttn() + + +class _StandardLM(torch.nn.Module): + def __init__(self, n_layers=3): + super().__init__() + + class _Inner(torch.nn.Module): + def __init__(self, n): + super().__init__() + self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(n)]) + + self.model = _Inner(n_layers) + + +def _make_config(model_type="qwen3_5", has_vision=False): + cfg = types.SimpleNamespace() + cfg.model_type = model_type + cfg.text_config = cfg + if has_vision: + vc = types.SimpleNamespace() + vc.hidden_size = 1 + vc.num_heads = 1 + cfg.vision_config = vc + return cfg + + +def test_finalize_fixes_layer_idx_on_standard_causal_lm(): + # Pre-fix: finalize_huggingface_model only touched new_model.model.language_model.layers, + # so standard-LM paths kept layer_idx at the empty-model stub value. + model = _StandardLM(n_layers=4) + cfg = _make_config(model_type="qwen3_5") + finalize_huggingface_model( + model, None, cfg, torch.float16, + quantization_config={"x": 1}, bnb_config=None, # skip .to() to avoid meta tensors + ) + for i, layer in enumerate(model.model.layers): + assert layer.layer_idx == i + assert layer.linear_attn.layer_idx == i + + +def test_finalize_also_handles_vlm_language_model_path(): + # Original VLM path should still work. + class _VLM(torch.nn.Module): + def __init__(self): + super().__init__() + + class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + + class _LM(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(3)]) + + self.language_model = _LM() + + self.model = _Inner() + + model = _VLM() + cfg = _make_config() + finalize_huggingface_model( + model, None, cfg, torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + for i, layer in enumerate(model.model.language_model.layers): + assert layer.layer_idx == i + assert layer.linear_attn.layer_idx == i + + +def test_finalize_does_not_assert_when_rotary_pos_emb_without_vision_config(): + # Pre-fix: hard `assert vision_config is not None` crashed text-only models that + # happened to expose a rotary_pos_emb attr. Post-fix: skip silently. + class _Rotary(torch.nn.Module): + def __init__(self): + super().__init__() + + class _Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.rotary_pos_emb = _Rotary() + + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Module() + self.model.layers = torch.nn.ModuleList([_Layer()]) + + model = _Model() + cfg = _make_config(has_vision=False) + # Should not raise AssertionError + finalize_huggingface_model( + model, None, cfg, torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) diff --git a/tests/test_gdn_extraction.py b/tests/test_gdn_extraction.py new file mode 100644 index 000000000..c4ba55965 --- /dev/null +++ b/tests/test_gdn_extraction.py @@ -0,0 +1,134 @@ +import sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import pytest +import torch +from unsloth_zoo.empty_model import extract_gdn_layers + + +class _FakePlainProj(torch.nn.Module): + # Simulates vLLM ColumnParallelLinear: plain Linear-like with .weight but no output_sizes. + def __init__(self, out_features, in_features, dtype=torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype), requires_grad=False) + + +class _FakeRowProj(torch.nn.Module): + def __init__(self, out_features, in_features, dtype=torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype), requires_grad=False) + + +class _FakeGDN(torch.nn.Module): + def __init__(self, hidden_size=8, num_k_heads=2, num_v_heads=2, head_k_dim=2, head_v_dim=4): + super().__init__() + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = num_k_heads * head_k_dim + self.value_dim = num_v_heads * head_v_dim + qkvz_dim = self.key_dim * 2 + self.value_dim * 2 + self.in_proj_qkvz = _FakePlainProj(qkvz_dim, hidden_size) + self.in_proj_ba = _FakePlainProj(num_v_heads * 2, hidden_size) + self.conv1d = _FakePlainProj(self.key_dim * 2 + self.value_dim, 4) + self.dt_bias = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) + self.A_log = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) + self.norm = torch.nn.Module() + self.norm.weight = torch.nn.Parameter(torch.randn(head_v_dim), requires_grad=False) + self.out_proj = _FakeRowProj(hidden_size, self.value_dim) + + +def _fake_get_state_dict(prefix, kk, state_dict, module, slice_weights=True): + state_dict[f"{prefix}.weight"] = module.weight.data + + +def test_extract_gdn_layers_no_output_sizes_does_not_crash(): + gdn = _FakeGDN() + state_dict = {} + quant_state_dict = {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + assert "prefix.in_proj_qkv.weight" in state_dict + assert "prefix.in_proj_z.weight" in state_dict + + +def test_extract_gdn_layers_splits_ba_without_indexerror(): + gdn = _FakeGDN() + state_dict = {} + quant_state_dict = {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + assert "prefix.in_proj_b.weight" in state_dict + assert "prefix.in_proj_a.weight" in state_dict + ba_weight = gdn.in_proj_ba.weight.data + mid = ba_weight.shape[0] // 2 + torch.testing.assert_close(state_dict["prefix.in_proj_b.weight"], ba_weight[:mid]) + torch.testing.assert_close(state_dict["prefix.in_proj_a.weight"], ba_weight[mid:]) + + +def test_extract_gdn_layers_exports_norm_weight(): + gdn = _FakeGDN() + state_dict = {} + quant_state_dict = {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + assert "prefix.norm.weight" in state_dict + torch.testing.assert_close(state_dict["prefix.norm.weight"], gdn.norm.weight.data) + + +def test_extract_gdn_layers_exports_conv1d_dtbias_alog(): + gdn = _FakeGDN() + state_dict = {} + quant_state_dict = {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + assert "prefix.conv1d.weight" in state_dict + assert "prefix.dt_bias" in state_dict + assert "prefix.A_log" in state_dict + + +def test_extract_gdn_layers_qkvz_offsets_match_gdn_dims(): + gdn = _FakeGDN(num_k_heads=3, num_v_heads=2, head_k_dim=4, head_v_dim=5) + state_dict = {} + quant_state_dict = {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + # offsets = [0, key_dim, 2*key_dim, 2*key_dim+value_dim, 2*key_dim+2*value_dim] + # qkv = rows [0 : 2*key_dim+value_dim], z = rows [that : end] + key_dim = gdn.key_dim + value_dim = gdn.value_dim + expected_qkv_rows = 2 * key_dim + value_dim + expected_z_rows = value_dim + assert state_dict["prefix.in_proj_qkv.weight"].shape[0] == expected_qkv_rows + assert state_dict["prefix.in_proj_z.weight"].shape[0] == expected_z_rows + + +def test_extract_gdn_layers_raises_when_dims_missing(): + gdn = _FakeGDN() + # Strip dim attrs and output_sizes so offset derivation fails. + del gdn.key_dim + del gdn.value_dim + with pytest.raises(RuntimeError, match="in_proj_qkvz"): + extract_gdn_layers(gdn, "prefix", {}, {}, _fake_get_state_dict) + + +def test_extract_gdn_layers_preserves_bnb_quant_state_sidecars(): + gdn = _FakeGDN() + + class _FakeQS: + def __init__(self, label): + self.label = label + def as_dict(self, packed=True): + return {"absmax": torch.tensor([float(hash(self.label) % 100)])} + + qkvz_weight = gdn.in_proj_qkvz.weight.data.clone() + qkvz_weight.bnb_quant_state = {0: _FakeQS("q"), 1: _FakeQS("k"), 2: _FakeQS("v"), 3: _FakeQS("z")} + gdn.in_proj_qkvz.weight = torch.nn.Parameter(qkvz_weight, requires_grad=False) + # Re-attach bnb_quant_state after re-wrap (nn.Parameter copies base tensor; simulate via separate path) + # Since attach-onto-Parameter may not propagate, set it directly via data wrapper attribute + gdn.in_proj_qkvz.weight.data.bnb_quant_state = qkvz_weight.bnb_quant_state + # Our extract_gdn_layers reads getattr(weight, "bnb_quant_state", None) on unwrapped weight + # which via _unwrap becomes weight.data; attach on data to satisfy that path + state_dict = {} + quant_state_dict = {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + # At minimum, the call must not crash and qkv/z weights must be exported. + assert "prefix.in_proj_qkv.weight" in state_dict + assert "prefix.in_proj_z.weight" in state_dict diff --git a/tests/test_vllm_conversion_helpers.py b/tests/test_vllm_conversion_helpers.py new file mode 100644 index 000000000..d82cb2df2 --- /dev/null +++ b/tests/test_vllm_conversion_helpers.py @@ -0,0 +1,102 @@ +import sys, os, warnings, inspect +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import types +import torch + + +def test_set_dtype_in_config_no_torch_dtype_deprecation(): + # Pre-fix: wrote both dtype and torch_dtype -> triggered transformers deprecation warning. + # Post-fix: writes only dtype when available. + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + + cfg = PretrainedConfig() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + set_dtype_in_config(cfg, torch.bfloat16) + dep_warnings = [ + w for w in caught + if "torch_dtype" in str(w.message) and "deprecated" in str(w.message).lower() + ] + assert not dep_warnings, f"unexpected deprecation warning: {[str(w.message) for w in dep_warnings]}" + + +def test_set_dtype_in_config_writes_runtime_dtype(): + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + + cfg = PretrainedConfig() + set_dtype_in_config(cfg, torch.float16) + # Either dtype or torch_dtype (aliased via property in modern transformers) should reflect it. + got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) + assert got == torch.float16 + + +def test_set_dtype_in_config_accepts_string(): + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + + cfg = PretrainedConfig() + set_dtype_in_config(cfg, "bfloat16") + got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) + assert got == torch.bfloat16 + + +def test_normalize_state_dict_tensor_guards_non_tensor(): + # Pre-fix: _normalize_state_dict_tensor called value.is_sparse unconditionally. + # Post-fix: the is_sparse branch is guarded by isinstance(value, torch.Tensor). + from unsloth_zoo import vllm_utils + + src = inspect.getsource(vllm_utils.assert_same_state_dict) + assert "isinstance(value, torch.Tensor)" in src + assert src.index("isinstance(value, torch.Tensor)") < src.index("value.is_sparse") + + +def test_gemma4_lora_patch_preserves_callable_signature(): + # Pre-fix: patched_create_lora_manager was `(model, *args, **kwargs)`, which hid vllm_config + # from `inspect.signature` and broke `_call_create_lora_manager`'s signature check. + # Post-fix: @functools.wraps preserves the original signature. + from functools import wraps + + def original_create_lora_manager( + model, max_num_seqs=None, vllm_config=None, lora_manager_cls=None, **kwargs, + ): + return (model, vllm_config, lora_manager_cls) + + @wraps(original_create_lora_manager) + def patched_create_lora_manager(model, *args, **kwargs): + return original_create_lora_manager(model, *args, **kwargs) + + sig = inspect.signature(patched_create_lora_manager) + assert "vllm_config" in sig.parameters + + +def test_gemma4_lora_patch_positional_model_no_double_bind(): + # Pre-fix: `lora_manager_cls(model=model, *args, **kwargs)` raised + # "TypeError: multiple values for argument 'model'" if *args was non-empty. + # Post-fix: pass model positionally. + class _LoRAManagerCls: + def __init__(self, model, extra=None, **kwargs): + self.model = model + self.extra = extra + self.kwargs = kwargs + + # Post-fix semantics: lora_manager_cls(model, *args, **kwargs) + inst = _LoRAManagerCls("fake_model", "extra_positional", vllm_config="cfg") + assert inst.model == "fake_model" + assert inst.extra == "extra_positional" + assert inst.kwargs == {"vllm_config": "cfg"} + + +def test_gemma4_k_eq_v_pairs_handles_split_layout(): + # Pre-fix: _get_gemma4_k_eq_v_qkv_param_names only searched for packed `qkv_proj.weight`. + # Post-fix: detects split `k_proj.weight` / `v_proj.weight` layout too. + import inspect as _inspect + from unsloth_zoo import empty_model + + src = _inspect.getsource(empty_model.patch_gemma4_vllm_k_eq_v_support) + # Sanity: the split-layout branch exists. + assert "k_proj.weight" in src + assert "v_proj.weight" in src + assert '"split"' in src or "'split'" in src From cebdaf363ceef632010b1f4651f5bd04a753e441 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Apr 2026 11:15:38 +0000 Subject: [PATCH 11/11] Consolidate review tests into test_vllm_to_hf_conversion.py --- tests/test_finalize_huggingface_model.py | 114 ---------- tests/test_gdn_extraction.py | 134 ------------ tests/test_vllm_conversion_helpers.py | 102 --------- tests/test_vllm_to_hf_conversion.py | 259 +++++++++++++++++++++++ 4 files changed, 259 insertions(+), 350 deletions(-) delete mode 100644 tests/test_finalize_huggingface_model.py delete mode 100644 tests/test_gdn_extraction.py delete mode 100644 tests/test_vllm_conversion_helpers.py create mode 100644 tests/test_vllm_to_hf_conversion.py diff --git a/tests/test_finalize_huggingface_model.py b/tests/test_finalize_huggingface_model.py deleted file mode 100644 index fe077cd10..000000000 --- a/tests/test_finalize_huggingface_model.py +++ /dev/null @@ -1,114 +0,0 @@ -import sys, os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -import types -import torch -from unsloth_zoo.empty_model import finalize_huggingface_model - - -class _LinearAttn(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer_idx = -1 - - -class _StandardLayer(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer_idx = -1 - self.linear_attn = _LinearAttn() - - -class _StandardLM(torch.nn.Module): - def __init__(self, n_layers=3): - super().__init__() - - class _Inner(torch.nn.Module): - def __init__(self, n): - super().__init__() - self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(n)]) - - self.model = _Inner(n_layers) - - -def _make_config(model_type="qwen3_5", has_vision=False): - cfg = types.SimpleNamespace() - cfg.model_type = model_type - cfg.text_config = cfg - if has_vision: - vc = types.SimpleNamespace() - vc.hidden_size = 1 - vc.num_heads = 1 - cfg.vision_config = vc - return cfg - - -def test_finalize_fixes_layer_idx_on_standard_causal_lm(): - # Pre-fix: finalize_huggingface_model only touched new_model.model.language_model.layers, - # so standard-LM paths kept layer_idx at the empty-model stub value. - model = _StandardLM(n_layers=4) - cfg = _make_config(model_type="qwen3_5") - finalize_huggingface_model( - model, None, cfg, torch.float16, - quantization_config={"x": 1}, bnb_config=None, # skip .to() to avoid meta tensors - ) - for i, layer in enumerate(model.model.layers): - assert layer.layer_idx == i - assert layer.linear_attn.layer_idx == i - - -def test_finalize_also_handles_vlm_language_model_path(): - # Original VLM path should still work. - class _VLM(torch.nn.Module): - def __init__(self): - super().__init__() - - class _Inner(torch.nn.Module): - def __init__(self): - super().__init__() - - class _LM(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(3)]) - - self.language_model = _LM() - - self.model = _Inner() - - model = _VLM() - cfg = _make_config() - finalize_huggingface_model( - model, None, cfg, torch.float16, - quantization_config={"x": 1}, bnb_config=None, - ) - for i, layer in enumerate(model.model.language_model.layers): - assert layer.layer_idx == i - assert layer.linear_attn.layer_idx == i - - -def test_finalize_does_not_assert_when_rotary_pos_emb_without_vision_config(): - # Pre-fix: hard `assert vision_config is not None` crashed text-only models that - # happened to expose a rotary_pos_emb attr. Post-fix: skip silently. - class _Rotary(torch.nn.Module): - def __init__(self): - super().__init__() - - class _Layer(torch.nn.Module): - def __init__(self): - super().__init__() - self.rotary_pos_emb = _Rotary() - - class _Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.model = torch.nn.Module() - self.model.layers = torch.nn.ModuleList([_Layer()]) - - model = _Model() - cfg = _make_config(has_vision=False) - # Should not raise AssertionError - finalize_huggingface_model( - model, None, cfg, torch.float16, - quantization_config={"x": 1}, bnb_config=None, - ) diff --git a/tests/test_gdn_extraction.py b/tests/test_gdn_extraction.py deleted file mode 100644 index c4ba55965..000000000 --- a/tests/test_gdn_extraction.py +++ /dev/null @@ -1,134 +0,0 @@ -import sys, os -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -import pytest -import torch -from unsloth_zoo.empty_model import extract_gdn_layers - - -class _FakePlainProj(torch.nn.Module): - # Simulates vLLM ColumnParallelLinear: plain Linear-like with .weight but no output_sizes. - def __init__(self, out_features, in_features, dtype=torch.float32): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype), requires_grad=False) - - -class _FakeRowProj(torch.nn.Module): - def __init__(self, out_features, in_features, dtype=torch.float32): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype), requires_grad=False) - - -class _FakeGDN(torch.nn.Module): - def __init__(self, hidden_size=8, num_k_heads=2, num_v_heads=2, head_k_dim=2, head_v_dim=4): - super().__init__() - self.hidden_size = hidden_size - self.num_k_heads = num_k_heads - self.num_v_heads = num_v_heads - self.head_k_dim = head_k_dim - self.head_v_dim = head_v_dim - self.key_dim = num_k_heads * head_k_dim - self.value_dim = num_v_heads * head_v_dim - qkvz_dim = self.key_dim * 2 + self.value_dim * 2 - self.in_proj_qkvz = _FakePlainProj(qkvz_dim, hidden_size) - self.in_proj_ba = _FakePlainProj(num_v_heads * 2, hidden_size) - self.conv1d = _FakePlainProj(self.key_dim * 2 + self.value_dim, 4) - self.dt_bias = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) - self.A_log = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) - self.norm = torch.nn.Module() - self.norm.weight = torch.nn.Parameter(torch.randn(head_v_dim), requires_grad=False) - self.out_proj = _FakeRowProj(hidden_size, self.value_dim) - - -def _fake_get_state_dict(prefix, kk, state_dict, module, slice_weights=True): - state_dict[f"{prefix}.weight"] = module.weight.data - - -def test_extract_gdn_layers_no_output_sizes_does_not_crash(): - gdn = _FakeGDN() - state_dict = {} - quant_state_dict = {} - extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) - assert "prefix.in_proj_qkv.weight" in state_dict - assert "prefix.in_proj_z.weight" in state_dict - - -def test_extract_gdn_layers_splits_ba_without_indexerror(): - gdn = _FakeGDN() - state_dict = {} - quant_state_dict = {} - extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) - assert "prefix.in_proj_b.weight" in state_dict - assert "prefix.in_proj_a.weight" in state_dict - ba_weight = gdn.in_proj_ba.weight.data - mid = ba_weight.shape[0] // 2 - torch.testing.assert_close(state_dict["prefix.in_proj_b.weight"], ba_weight[:mid]) - torch.testing.assert_close(state_dict["prefix.in_proj_a.weight"], ba_weight[mid:]) - - -def test_extract_gdn_layers_exports_norm_weight(): - gdn = _FakeGDN() - state_dict = {} - quant_state_dict = {} - extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) - assert "prefix.norm.weight" in state_dict - torch.testing.assert_close(state_dict["prefix.norm.weight"], gdn.norm.weight.data) - - -def test_extract_gdn_layers_exports_conv1d_dtbias_alog(): - gdn = _FakeGDN() - state_dict = {} - quant_state_dict = {} - extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) - assert "prefix.conv1d.weight" in state_dict - assert "prefix.dt_bias" in state_dict - assert "prefix.A_log" in state_dict - - -def test_extract_gdn_layers_qkvz_offsets_match_gdn_dims(): - gdn = _FakeGDN(num_k_heads=3, num_v_heads=2, head_k_dim=4, head_v_dim=5) - state_dict = {} - quant_state_dict = {} - extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) - # offsets = [0, key_dim, 2*key_dim, 2*key_dim+value_dim, 2*key_dim+2*value_dim] - # qkv = rows [0 : 2*key_dim+value_dim], z = rows [that : end] - key_dim = gdn.key_dim - value_dim = gdn.value_dim - expected_qkv_rows = 2 * key_dim + value_dim - expected_z_rows = value_dim - assert state_dict["prefix.in_proj_qkv.weight"].shape[0] == expected_qkv_rows - assert state_dict["prefix.in_proj_z.weight"].shape[0] == expected_z_rows - - -def test_extract_gdn_layers_raises_when_dims_missing(): - gdn = _FakeGDN() - # Strip dim attrs and output_sizes so offset derivation fails. - del gdn.key_dim - del gdn.value_dim - with pytest.raises(RuntimeError, match="in_proj_qkvz"): - extract_gdn_layers(gdn, "prefix", {}, {}, _fake_get_state_dict) - - -def test_extract_gdn_layers_preserves_bnb_quant_state_sidecars(): - gdn = _FakeGDN() - - class _FakeQS: - def __init__(self, label): - self.label = label - def as_dict(self, packed=True): - return {"absmax": torch.tensor([float(hash(self.label) % 100)])} - - qkvz_weight = gdn.in_proj_qkvz.weight.data.clone() - qkvz_weight.bnb_quant_state = {0: _FakeQS("q"), 1: _FakeQS("k"), 2: _FakeQS("v"), 3: _FakeQS("z")} - gdn.in_proj_qkvz.weight = torch.nn.Parameter(qkvz_weight, requires_grad=False) - # Re-attach bnb_quant_state after re-wrap (nn.Parameter copies base tensor; simulate via separate path) - # Since attach-onto-Parameter may not propagate, set it directly via data wrapper attribute - gdn.in_proj_qkvz.weight.data.bnb_quant_state = qkvz_weight.bnb_quant_state - # Our extract_gdn_layers reads getattr(weight, "bnb_quant_state", None) on unwrapped weight - # which via _unwrap becomes weight.data; attach on data to satisfy that path - state_dict = {} - quant_state_dict = {} - extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) - # At minimum, the call must not crash and qkv/z weights must be exported. - assert "prefix.in_proj_qkv.weight" in state_dict - assert "prefix.in_proj_z.weight" in state_dict diff --git a/tests/test_vllm_conversion_helpers.py b/tests/test_vllm_conversion_helpers.py deleted file mode 100644 index d82cb2df2..000000000 --- a/tests/test_vllm_conversion_helpers.py +++ /dev/null @@ -1,102 +0,0 @@ -import sys, os, warnings, inspect -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -import types -import torch - - -def test_set_dtype_in_config_no_torch_dtype_deprecation(): - # Pre-fix: wrote both dtype and torch_dtype -> triggered transformers deprecation warning. - # Post-fix: writes only dtype when available. - from transformers import PretrainedConfig - from unsloth_zoo.hf_utils import set_dtype_in_config - - cfg = PretrainedConfig() - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - set_dtype_in_config(cfg, torch.bfloat16) - dep_warnings = [ - w for w in caught - if "torch_dtype" in str(w.message) and "deprecated" in str(w.message).lower() - ] - assert not dep_warnings, f"unexpected deprecation warning: {[str(w.message) for w in dep_warnings]}" - - -def test_set_dtype_in_config_writes_runtime_dtype(): - from transformers import PretrainedConfig - from unsloth_zoo.hf_utils import set_dtype_in_config - - cfg = PretrainedConfig() - set_dtype_in_config(cfg, torch.float16) - # Either dtype or torch_dtype (aliased via property in modern transformers) should reflect it. - got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) - assert got == torch.float16 - - -def test_set_dtype_in_config_accepts_string(): - from transformers import PretrainedConfig - from unsloth_zoo.hf_utils import set_dtype_in_config - - cfg = PretrainedConfig() - set_dtype_in_config(cfg, "bfloat16") - got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) - assert got == torch.bfloat16 - - -def test_normalize_state_dict_tensor_guards_non_tensor(): - # Pre-fix: _normalize_state_dict_tensor called value.is_sparse unconditionally. - # Post-fix: the is_sparse branch is guarded by isinstance(value, torch.Tensor). - from unsloth_zoo import vllm_utils - - src = inspect.getsource(vllm_utils.assert_same_state_dict) - assert "isinstance(value, torch.Tensor)" in src - assert src.index("isinstance(value, torch.Tensor)") < src.index("value.is_sparse") - - -def test_gemma4_lora_patch_preserves_callable_signature(): - # Pre-fix: patched_create_lora_manager was `(model, *args, **kwargs)`, which hid vllm_config - # from `inspect.signature` and broke `_call_create_lora_manager`'s signature check. - # Post-fix: @functools.wraps preserves the original signature. - from functools import wraps - - def original_create_lora_manager( - model, max_num_seqs=None, vllm_config=None, lora_manager_cls=None, **kwargs, - ): - return (model, vllm_config, lora_manager_cls) - - @wraps(original_create_lora_manager) - def patched_create_lora_manager(model, *args, **kwargs): - return original_create_lora_manager(model, *args, **kwargs) - - sig = inspect.signature(patched_create_lora_manager) - assert "vllm_config" in sig.parameters - - -def test_gemma4_lora_patch_positional_model_no_double_bind(): - # Pre-fix: `lora_manager_cls(model=model, *args, **kwargs)` raised - # "TypeError: multiple values for argument 'model'" if *args was non-empty. - # Post-fix: pass model positionally. - class _LoRAManagerCls: - def __init__(self, model, extra=None, **kwargs): - self.model = model - self.extra = extra - self.kwargs = kwargs - - # Post-fix semantics: lora_manager_cls(model, *args, **kwargs) - inst = _LoRAManagerCls("fake_model", "extra_positional", vllm_config="cfg") - assert inst.model == "fake_model" - assert inst.extra == "extra_positional" - assert inst.kwargs == {"vllm_config": "cfg"} - - -def test_gemma4_k_eq_v_pairs_handles_split_layout(): - # Pre-fix: _get_gemma4_k_eq_v_qkv_param_names only searched for packed `qkv_proj.weight`. - # Post-fix: detects split `k_proj.weight` / `v_proj.weight` layout too. - import inspect as _inspect - from unsloth_zoo import empty_model - - src = _inspect.getsource(empty_model.patch_gemma4_vllm_k_eq_v_support) - # Sanity: the split-layout branch exists. - assert "k_proj.weight" in src - assert "v_proj.weight" in src - assert '"split"' in src or "'split'" in src diff --git a/tests/test_vllm_to_hf_conversion.py b/tests/test_vllm_to_hf_conversion.py new file mode 100644 index 000000000..ae6906301 --- /dev/null +++ b/tests/test_vllm_to_hf_conversion.py @@ -0,0 +1,259 @@ +import sys, os, warnings, inspect +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import types +import pytest +import torch + + +class _FakePlainProj(torch.nn.Module): + def __init__(self, out_features, in_features, dtype=torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype), requires_grad=False) + + +class _FakeGDN(torch.nn.Module): + def __init__(self, hidden_size=8, num_k_heads=2, num_v_heads=2, head_k_dim=2, head_v_dim=4): + super().__init__() + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = num_k_heads * head_k_dim + self.value_dim = num_v_heads * head_v_dim + qkvz_dim = self.key_dim * 2 + self.value_dim * 2 + self.in_proj_qkvz = _FakePlainProj(qkvz_dim, hidden_size) + self.in_proj_ba = _FakePlainProj(num_v_heads * 2, hidden_size) + self.conv1d = _FakePlainProj(self.key_dim * 2 + self.value_dim, 4) + self.dt_bias = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) + self.A_log = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) + self.norm = torch.nn.Module() + self.norm.weight = torch.nn.Parameter(torch.randn(head_v_dim), requires_grad=False) + self.out_proj = _FakePlainProj(hidden_size, self.value_dim) + + +def _fake_get_state_dict(prefix, kk, state_dict, module, slice_weights=True): + state_dict[f"{prefix}.weight"] = module.weight.data + + +def test_extract_gdn_layers_handles_plain_column_parallel_linear(): + # Pre-fix: vllm ColumnParallelLinear has no `output_sizes` -> AttributeError. + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN() + state_dict, quant_state_dict = {}, {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + expected = { + "prefix.in_proj_qkv.weight", + "prefix.in_proj_z.weight", + "prefix.in_proj_b.weight", + "prefix.in_proj_a.weight", + "prefix.conv1d.weight", + "prefix.dt_bias", + "prefix.A_log", + "prefix.norm.weight", + "prefix.out_proj.weight", + } + assert expected <= set(state_dict.keys()) + + +def test_extract_gdn_layers_splits_in_proj_ba_without_indexerror(): + # Pre-fix: get_state_dict(kk=1, in_proj_ba) -> IndexError (no output_sizes). + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN() + state_dict, quant_state_dict = {}, {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + ba_weight = gdn.in_proj_ba.weight.data + mid = ba_weight.shape[0] // 2 + torch.testing.assert_close(state_dict["prefix.in_proj_b.weight"], ba_weight[:mid]) + torch.testing.assert_close(state_dict["prefix.in_proj_a.weight"], ba_weight[mid:]) + + +def test_extract_gdn_layers_qkvz_offsets_match_gdn_dims(): + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN(num_k_heads=3, num_v_heads=2, head_k_dim=4, head_v_dim=5) + state_dict, quant_state_dict = {}, {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + assert state_dict["prefix.in_proj_qkv.weight"].shape[0] == 2 * gdn.key_dim + gdn.value_dim + assert state_dict["prefix.in_proj_z.weight"].shape[0] == gdn.value_dim + + +def test_extract_gdn_layers_raises_when_offsets_underivable(): + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN() + del gdn.key_dim + del gdn.value_dim + with pytest.raises(RuntimeError, match="in_proj_qkvz"): + extract_gdn_layers(gdn, "prefix", {}, {}, _fake_get_state_dict) + + +def test_extract_gdn_layers_has_bnb_quant_state_preservation(): + # Pre-fix: merged in_proj_qkvz path only stored raw weight slices; BnB prequantized + # checkpoints lost quant_state metadata and were rebuilt as plain nn.Linear. + # Behavioral test requires real BnB; source-level check confirms the branch exists. + from unsloth_zoo import empty_model + src = inspect.getsource(empty_model.extract_gdn_layers) + assert "bnb_quant_state" in src + assert "in_proj_qkv.weight.quant_state" in src + assert "in_proj_z.weight.quant_state" in src + + +class _LinearAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_idx = -1 + + +class _StandardLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_idx = -1 + self.linear_attn = _LinearAttn() + + +class _StandardLM(torch.nn.Module): + def __init__(self, n_layers=3): + super().__init__() + + class _Inner(torch.nn.Module): + def __init__(self, n): + super().__init__() + self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(n)]) + + self.model = _Inner(n_layers) + + +def _config(model_type="qwen3_5", has_vision=False): + cfg = types.SimpleNamespace() + cfg.model_type = model_type + cfg.text_config = cfg + if has_vision: + vc = types.SimpleNamespace() + vc.hidden_size = 1 + vc.num_heads = 1 + cfg.vision_config = vc + return cfg + + +def test_finalize_fixes_layer_idx_on_standard_causal_lm(): + # Pre-fix: only new_model.model.language_model.layers was traversed, so + # standard-LM paths kept layer_idx at the empty-model stub value. + from unsloth_zoo.empty_model import finalize_huggingface_model + model = _StandardLM(n_layers=4) + finalize_huggingface_model( + model, None, _config("qwen3_5"), torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + for i, layer in enumerate(model.model.layers): + assert layer.layer_idx == i + assert layer.linear_attn.layer_idx == i + + +def test_finalize_fixes_layer_idx_on_vlm_language_model_path(): + from unsloth_zoo.empty_model import finalize_huggingface_model + + class _VLM(torch.nn.Module): + def __init__(self): + super().__init__() + + class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + + class _LM(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(3)]) + + self.language_model = _LM() + + self.model = _Inner() + + model = _VLM() + finalize_huggingface_model( + model, None, _config(), torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + for i, layer in enumerate(model.model.language_model.layers): + assert layer.layer_idx == i + assert layer.linear_attn.layer_idx == i + + +def test_finalize_does_not_assert_on_text_only_with_rotary_pos_emb(): + # Pre-fix: hard `assert vision_config is not None` crashed text-only models. + from unsloth_zoo.empty_model import finalize_huggingface_model + + class _Rotary(torch.nn.Module): + pass + + class _Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.rotary_pos_emb = _Rotary() + + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Module() + self.model.layers = torch.nn.ModuleList([_Layer()]) + + finalize_huggingface_model( + _Model(), None, _config(has_vision=False), torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + + +def test_set_dtype_in_config_no_torch_dtype_deprecation(): + # Pre-fix: wrote both dtype and torch_dtype -> transformers deprecation warning. + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + cfg = PretrainedConfig() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + set_dtype_in_config(cfg, torch.bfloat16) + dep = [w for w in caught if "torch_dtype" in str(w.message) and "deprecated" in str(w.message).lower()] + assert not dep, f"unexpected deprecation warning: {[str(w.message) for w in dep]}" + + +def test_set_dtype_in_config_writes_torch_dtype_value(): + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + cfg = PretrainedConfig() + set_dtype_in_config(cfg, torch.float16) + got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) + assert got == torch.float16 + + +def test_set_dtype_in_config_accepts_string_input(): + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + cfg = PretrainedConfig() + set_dtype_in_config(cfg, "bfloat16") + got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) + assert got == torch.bfloat16 + + +def test_normalize_state_dict_tensor_guards_non_tensor(): + # Pre-fix: value.is_sparse was called unconditionally on any state-dict value. + from unsloth_zoo import vllm_utils + src = inspect.getsource(vllm_utils.assert_same_state_dict) + assert "isinstance(value, torch.Tensor)" in src + assert src.index("isinstance(value, torch.Tensor)") < src.index("value.is_sparse") + + +def test_gemma4_lora_patch_preserves_signature_for_inspect(): + # Pre-fix: patched_create_lora_manager(model, *args, **kwargs) hid vllm_config, + # breaking _call_create_lora_manager's signature-based forwarding. + from unsloth_zoo import empty_model + src = inspect.getsource(empty_model.patch_gemma4_vllm_lora_support) + assert "@wraps(original_create_lora_manager)" in src + assert "lora_manager_cls(model, *args, **kwargs)" in src + + +def test_gemma4_k_eq_v_patch_handles_split_kv_layout(): + # Pre-fix: only packed self_attn.qkv_proj.weight was searched, so current upstream + # Gemma4 split q_proj/k_proj/v_proj layout never got synthetic V quant-state. + from unsloth_zoo import empty_model + src = inspect.getsource(empty_model.patch_gemma4_vllm_k_eq_v_support) + assert "k_proj.weight" in src and "v_proj.weight" in src + assert '"split"' in src or "'split'" in src