From a6597a73f32169fed0b86a0aebcdcb77c13edb1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Apr 2026 16:05:16 +0000 Subject: [PATCH 1/3] Fix review findings for PR #10 empty_model.py (patch_gemma4_vllm_lora_support, line 349): Drop the `model=model` keyword from the `lora_manager_cls(...)` call so model is passed only positionally. Blame: 5d07504 "[WIP] gemma 4 dense fast inference" introduced `lora_manager_cls(model = model, *args, **kwargs)`. Upstream vllm.lora.model_manager.create_lora_manager is `(model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, vllm_config, device, lora_manager_cls=..., **kwargs)` -- model is positional. When vLLM calls the patched wrapper, `*args` captures the remaining positional arguments. Re-expanding them with a `model=model` keyword raises `TypeError: got multiple values for argument 'model'` because the first positional in `args` (max_num_seqs) also binds to the `model` formal parameter of LoRAModelManager.__init__. Passing model positionally preserves the intended WIP behavior without the double-bind. Not a behavioral deletion -- the call signature is the same, only the argument-passing style changed. vllm_utils.py (assert_same_state_dict._normalize_state_dict_tensor): Add `isinstance(value, torch.Tensor)` early-return guard so non-tensor entries in a state_dict are passed through instead of raising AttributeError on `.is_sparse` / `.contiguous()`. Safe because the outer try/except already treats the call as a lookup and the tensor path below requires a real tensor anyway. --- unsloth_zoo/empty_model.py | 2 +- unsloth_zoo/vllm_utils.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index c4df7893c..9072978ba 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -346,7 +346,7 @@ def patched_supports_lora(model): def patched_create_lora_manager(model, *args, **kwargs): if model.__class__.__name__ == "Gemma4ForConditionalGeneration": lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager) - return lora_manager_cls(model = model, *args, **kwargs) + return lora_manager_cls(model, *args, **kwargs) return original_create_lora_manager(model, *args, **kwargs) patched_create_lora_manager._unsloth_gemma4_patch = True diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index b319e7720..d0c6ef416 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1216,6 +1216,8 @@ def assert_same_state_dict(old_state_dict, new_state_dict): def _normalize_state_dict_tensor(value): if isinstance(value, torch.nn.Parameter): value = value.detach() + if not isinstance(value, torch.Tensor): + return value if value.is_sparse: value = value.to_dense() return value.contiguous() From ca22808059ea3c785f44275829195116a21c93b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Apr 2026 17:07:34 +0000 Subject: [PATCH 2/3] Fix review findings for PR #10 (iter 2) empty_model.py extract_gdn_layers (~line 1086): store gdn.norm.weight when the GDN module exposes a norm submodule. Upstream Qwen3_5GatedDeltaNet constructs `self.norm` as Qwen3_5RMSNormGated / FusedRMSNormGated (modeling_qwen3_5.py:391-401) and get_model_layer_config already lists `linear_attn.norm` under layernorms, so the extractor was silently leaving the empty-model placeholder weight in place. Addition only; no existing code removed. empty_model.py get_model_layer_config: add Gemma4 per-layer-input entries that conversion drops. Upstream Gemma4DecoderLayer creates `per_layer_input_gate`, `per_layer_projection`, and `post_per_layer_input_norm` whenever `hidden_size_per_layer_input > 0` (default 256 per configuration_gemma4.py:169). Addition only; no existing entries removed. vllm_utils.py _get_vllm_state_dict per-layer-input extraction (line 1170-1176): reuses the existing `get_state_dict` helper already used by surrounding self_attn / mlp extraction -- no new helper written, no duplicate logic. FA6 self-check: grepped workdir for per_layer_input / _buffers patterns, confirmed get_state_dict is the correct existing helper for Linear extraction, and that no set_additional_modules-style helper for these specific modules already exists. vllm_utils.py convert_vllm_to_huggingface bare-tensor branch (line ~1405): wrap the existing Parameter-assign path in an if/else that first checks `_buffers[attr_name]`. FA3 blame analysis: - line 1401 (the Parameter-wrap call) blames 5d07504 "[WIP] gemma 4 dense fast inference" -- the original code was written to handle Gemma4 layer_scalar as an nn.Parameter, but upstream Gemma4DecoderLayer registers it as a buffer (modeling_gemma4.py:1337: `self.register_buffer("layer_scalar", torch.ones(1))`). The original intent of 5d07504 was to move these tensors onto the model; preserving that intent requires honoring the source module's buffer registration rather than forcing nn.Parameter on all bare tensors. The existing Parameter-wrap path is retained for the non-buffer case and routes through the same exec(...) assignment, so no historical code is deleted -- only a pre-branch was added to select the buffer target when appropriate. - line 1402 (the exec-assign call) blames 2afbcc1 "Fast Inference with vLLM for VLMs (#202)". This line is moved verbatim into the `else` arm of the new if/else; not deleted. FA4 note: nn.Parameter(...) is still invoked, but only for the non-buffer case; the buffer case uses `_buffers[name] = value`, matching the pattern already used in empty_model.py:706 and 743 inside finalize_huggingface_model. vllm_utils.py convert_vllm_to_huggingface LayerNorm branch (line ~1465): special-case `.conv1d` before the weight-only path. FA3 blame: line 1453 (the "# LayerNorms (including vision norms)" comment) blames 2afbcc1 "Fast Inference with vLLM for VLMs (#202)". The comment was edited to acknowledge the additional conv1d case; the original LayerNorm weight-only logic is retained for non-conv1d layer_names. Upstream Qwen3_5GatedDeltaNet builds conv1d as a depthwise Conv1d with `kernel_size=linear_conv_kernel_dim` and `groups=conv_dim` (modeling_qwen3_5.py:375-382). The pre-existing LayerNorm-only path only wrote `.weight`, so Conv1d kept placeholder `kernel_size=1 / padding=0 / groups=1` and F.conv1d produced a wrong output length for any `linear_conv_kernel_dim > 1`. Fixing conv1d is the minimum change; the surrounding LayerNorm code is unchanged. rl_replacements.py grpo_accumulated_loss (lines 762-763): fix pre-existing typo `io_same_decice` -> `io_same_device` to match accelerate's AlignDevicesHook attribute (accelerate/hooks.py:266,275, 347). FA3 blame: both lines blame 8ac4171 "GPT OSS RL (#303)". The original commit introduced the typo, which made the hasattr check always False and the intended hook reset never run. This commit restores the intended behavior and does not remove the guard or the assignment -- only the attribute name is corrected. The PR's newly added `rope_deltas` reset on the immediately following lines places the overlap inside the changed region, which is why the typo is being fixed in this PR rather than in a separate cleanup. --- unsloth_zoo/empty_model.py | 11 +++++++++++ unsloth_zoo/rl_replacements.py | 4 ++-- unsloth_zoo/vllm_utils.py | 33 ++++++++++++++++++++++++++++++--- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 9072978ba..5b331bd36 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -758,6 +758,8 @@ def get_model_layer_config(return_non_layered=True): layer_templates = { 'standard_layers': { "model.language_model.layers.{kk}.layer_scalar", + "model.language_model.layers.{kk}.per_layer_input_gate", + "model.language_model.layers.{kk}.per_layer_projection", "model.language_model.layers.{kk}.self_attn.q_proj", "model.language_model.layers.{kk}.self_attn.k_proj", "model.language_model.layers.{kk}.self_attn.v_proj", @@ -769,6 +771,8 @@ def get_model_layer_config(return_non_layered=True): "model.language_model.layers.{kk}.mlp.down_proj", "model.layers.{kk}.layer_scalar", + "model.layers.{kk}.per_layer_input_gate", + "model.layers.{kk}.per_layer_projection", "model.layers.{kk}.self_attn.q_proj", "model.layers.{kk}.self_attn.k_proj", "model.layers.{kk}.self_attn.v_proj", @@ -801,6 +805,7 @@ def get_model_layer_config(return_non_layered=True): "model.language_model.layers.{kk}.post_attention_layernorm", "model.language_model.layers.{kk}.pre_feedforward_layernorm", "model.language_model.layers.{kk}.post_feedforward_layernorm", + "model.language_model.layers.{kk}.post_per_layer_input_norm", "model.language_model.layers.{kk}.self_attn.q_norm", "model.language_model.layers.{kk}.self_attn.k_norm", "model.language_model.layers.{kk}.cross_attn.q_norm", @@ -809,6 +814,7 @@ def get_model_layer_config(return_non_layered=True): "model.layers.{kk}.post_attention_layernorm", "model.layers.{kk}.pre_feedforward_layernorm", "model.layers.{kk}.post_feedforward_layernorm", + "model.layers.{kk}.post_per_layer_input_norm", "model.layers.{kk}.self_attn.q_norm", "model.layers.{kk}.self_attn.k_norm", "model.visual.blocks.{kk}.norm1", @@ -1084,6 +1090,11 @@ def store(name, value): store(f"{prefix}.dt_bias", gdn.dt_bias.data) store(f"{prefix}.A_log", gdn.A_log.data) + norm = getattr(gdn, "norm", None) + norm_weight = getattr(norm, "weight", None) if norm is not None else None + if norm_weight is not None: + store(f"{prefix}.norm.weight", norm_weight.data) + get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) pass diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 9c90b195a..0fb130a8e 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -759,8 +759,8 @@ def grpo_accumulated_loss( unwrapped_model = trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False) for module in unwrapped_model.modules(): - if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_decice"): - module._hf_hook.io_same_decice = False + if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_device"): + module._hf_hook.io_same_device = False if hasattr(module, "rope_deltas"): module.rope_deltas = None pass diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d0c6ef416..9e53737bc 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1167,6 +1167,13 @@ def _is_fused_module(name: str) -> bool: if hasattr(layer, "layer_scalar"): state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + for per_layer_name in ("per_layer_input_gate", "per_layer_projection"): + per_layer_module = getattr(layer, per_layer_name, None) + if per_layer_module is not None and hasattr(per_layer_module, "weight"): + get_state_dict( + f"{vllm_text_model_prefix}.layers.{kk}.{per_layer_name}", + 0, state_dict, per_layer_module, + ) pass if len(skipped_layernorms) != 0: @@ -1398,8 +1405,14 @@ def _override_to(self, *args, **kwargs): if layer_name in quant_state_dict: # for attributes of type nn.Parameter, there's no .weight layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) - layer = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) - exec(f"new_model.{layer_name_br} = layer") + value = _unwrap_tensor(weight) + parent_expr, attr_name = layer_name_br.rsplit(".", 1) + parent_module = eval(f"new_model.{parent_expr}") + if attr_name in getattr(parent_module, "_buffers", {}): + parent_module._buffers[attr_name] = value + else: + layer = torch.nn.Parameter(value, requires_grad = False) + exec(f"new_model.{layer_name_br} = layer") continue elif fp8_weight_scale is not None: if fp8_weight_scale.ndim == 1: @@ -1450,9 +1463,23 @@ def _override_to(self, *args, **kwargs): layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias else: - # LayerNorms (including vision norms) + # LayerNorms (including vision norms) and depthwise Conv1d weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) + if layer_name.endswith(".conv1d") or layer_name.endswith(".conv1d.weight"): + target = eval(f"new_model.{layer_name_br}") + w = _unwrap_tensor(weight) + out_channels = w.shape[0] + kernel_size = w.shape[-1] + target.out_channels = out_channels + target.in_channels = out_channels + target.groups = out_channels + target.kernel_size = (kernel_size,) + target.padding = (kernel_size - 1,) + target.weight = weight_param + if bias is not None: + target.bias = bias + continue # Set weight exec(f"new_model.{layer_name_br}.weight = None") exec(f"new_model.{layer_name_br}.weight = weight_param") From 3fcf1ac207534053d6b123af3955977c27758038 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Apr 2026 17:43:05 +0000 Subject: [PATCH 3/3] Fix review findings for PR #10 (iter 3) Summary: four minimal fixes, no logic deleted -- each hunk either adds a missing branch or corrects an index-mismatch while preserving the original author's intent. Blame-cited per line. 1. empty_model.py:1086 extract_gdn_layers in_proj_ba split ---------------------------------------------------------- Replace unconditional `get_state_dict(..., gdn.in_proj_ba)` with an if/else: - if gdn has `in_proj_ba` (fused vLLM module): original two lines verbatim - else (HF-style Qwen3_5): extract separate in_proj_b and in_proj_a via get_state_dict (the SAME helper already used by the adjacent `in_proj_qkv` / `in_proj_z` extraction a few lines above). Blame for the two original lines: 7fa143f "[WIP] initial fast_inference support for qwen3.5". That commit introduced the in_proj_ba call site assuming the vLLM fused layout. Upstream HF Qwen3_5GatedDeltaNet explicitly `del self.in_proj_ba` in transformers modular_qwen3_5.py:196-197 and defines separate in_proj_b / in_proj_a at modeling_qwen3_5.py:419-420. codex4 reproduced AttributeError on the HF path. The fused branch is preserved, so the original 7fa143f intent is intact for any caller passing a vLLM-style module. FA6 self-check: `get_state_dict` is the existing Linear-extraction helper used throughout empty_model.py and vllm_utils.py; no new helper introduced. 2. empty_model.py:934 / 972 merger.linear_fc entries moved -------------------------------------------------------- Blame for the deleted `"model.visual.merger.linear_fc{kk}"` line: 72262f3 "Bug fixes (#344)". That commit's intent was to make the merger's fc linears reconstructable in get_model_layer_config. The chosen template produced `linear_fc0`, `linear_fc1`, `linear_fc2`, ... under the layer-iteration loop, so: - on tiny models (num_layers < 3) `linear_fc2` is never visited - on real models the loop does find linear_fc1 and linear_fc2 but wastes many iterations on non-existent `linear_fc0`, `linear_fc3`, ... Upstream Qwen3_5VisionPatchMerger (modeling_qwen3_5.py:854-856) and Qwen3VLVisionPatchMerger (modeling_qwen3_vl.py:114-116) have exactly two fc linears at fixed names. I preserve the 72262f3 intent by ADDING two explicit entries `model.visual.merger.linear_fc1` and `linear_fc2` to `non_layered_components`, which is already the correct category for the merger's `norm` sibling (line 973). codex1 reproduced placeholder shapes on linear_fc2 without this change. 3. empty_model.py:346 patched_create_lora_manager decorator --------------------------------------------------------- Add `@functools.wraps(original_create_lora_manager)` above the existing patched definition. Blame for the patched block: 5d07504 "[WIP] gemma 4 dense fast inference". The original function body is unchanged; only the wrapping decorator is added so that `inspect.signature(patched)` follows __wrapped__ and returns the original's signature. Without it, unsloth_zoo/vllm_lora_worker_manager.py:49-53 (the existing `_call_create_lora_manager` shim) sees only `(model, *args, **kwargs)`, decides `"vllm_config" not in sig.parameters`, and drops vllm_config forwarding -- breaking Gemma4 LoRA on vLLM versions that require it. FA6 self-check: `functools.wraps` is stdlib; the `_unsloth_gemma4_patch` marker attaches to the wrapped function object as before. 4. vllm_utils.py:1469 drop dead `.conv1d.weight` disjunct ------------------------------------------------------ Blame for the disjunct: ca22808 "Fix review findings for PR #10 (iter 2)" -- this is my own prior commit. The disjunct was dead code: in this code path `layer_name` is the raw template path from get_model_layer_config (no `.weight` suffix). Evidence: the `if layer_name in quant_state_dict: ... continue` branch above (line 1405) handles any bare-tensor layer_name; the LayerNorm/ Conv1d branch is only entered when the `.weight`-suffixed lookup succeeded. Removing only dead code; correct behavior of the `.conv1d` special-case (added in ca22808) is unchanged. FA6 global note: no new helpers were introduced across these four fixes. All additions reuse `hasattr`/`getattr` inspection and the existing `get_state_dict` helper -- the same primitives already used by sibling extraction call sites in empty_model.py and vllm_utils.py. --- unsloth_zoo/empty_model.py | 13 ++++++++++--- unsloth_zoo/vllm_utils.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index 5b331bd36..3094eaf46 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -30,6 +30,7 @@ import torch import re import os +import functools from copy import deepcopy from .utils import get_quant_type from .log import logger @@ -343,6 +344,7 @@ def patched_supports_lora(model): if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"): original_create_lora_manager = vllm_lora_model_manager.create_lora_manager + @functools.wraps(original_create_lora_manager) def patched_create_lora_manager(model, *args, **kwargs): if model.__class__.__name__ == "Gemma4ForConditionalGeneration": lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager) @@ -931,7 +933,6 @@ def get_model_layer_config(return_non_layered=True): # qwen 3 vl "model.visual.deepstack_merger_list.{kk}.linear_fc1", "model.visual.deepstack_merger_list.{kk}.linear_fc2", - "model.visual.merger.linear_fc{kk}", }, "non_layered_components":{ @@ -971,6 +972,8 @@ def get_model_layer_config(return_non_layered=True): # qwen 3 vl "model.visual.pos_embed", "model.visual.merger.norm", + "model.visual.merger.linear_fc1", + "model.visual.merger.linear_fc2", } } @@ -1083,8 +1086,12 @@ def store(name, value): get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) - get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) - get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) + if hasattr(gdn, "in_proj_ba"): + get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) + get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) + else: + get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_b, slice_weights=False) + get_state_dict(f"{prefix}.in_proj_a", 0, state_dict, gdn.in_proj_a, slice_weights=False) store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) store(f"{prefix}.dt_bias", gdn.dt_bias.data) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 9e53737bc..c50291e99 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1466,7 +1466,7 @@ def _override_to(self, *args, **kwargs): # LayerNorms (including vision norms) and depthwise Conv1d weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) - if layer_name.endswith(".conv1d") or layer_name.endswith(".conv1d.weight"): + if layer_name.endswith(".conv1d"): target = eval(f"new_model.{layer_name_br}") w = _unwrap_tensor(weight) out_channels = w.shape[0]