[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #8
Conversation
Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
…e_huggingface_model
- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
_call_create_lora_manager's signature inspection still sees vllm_config; pass model
positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
rotary_pos_emb on vision_config availability; mirror language_model detection from
set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
(defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
so non-tensor state-dict values pass through.
…n, finalize_huggingface_model
- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
_call_create_lora_manager's signature inspection still sees vllm_config; pass model
positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
rotary_pos_emb on vision_config availability; mirror language_model detection from
set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
(defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
so non-tensor state-dict values pass through.
- hf_utils.set_dtype_in_config: store string (JSON-safe, keeps string
comparisons in patch_model_and_tokenizer working); fix fallback
else-branch that had the HAS_TORCH_DTYPE field selection inverted.
- empty_model.extract_gdn_layers: read bnb_quant_state off the raw
Params4bit before unwrapping .data; emit weight.quant_state and FP8
weight_scale(_inv) shards for the in_proj_b / in_proj_a split so
quantized Qwen3.5 GDN layers round-trip correctly.
- vllm_utils.convert_vllm_to_huggingface: rebuild linear_attn.conv1d
as a grouped Conv1d with real channels/kernel_size/groups/padding
instead of treating it as a LayerNorm-style weight swap.
- empty_model.patch_gemma4_vllm_lora_support: soft-import
vllm.v1.worker.lora_model_runner_mixin so older supported vLLM
layouts keep working.
- vllm_utils._get_vllm_state_dict: extract Gemma4 per_layer_input_gate
and per_layer_projection so converted HF models carry the real
checkpoint weights.
- empty_model.finalize_huggingface_model: restrict dtype propagation
to the top-level config and its known text/vision/audio subconfigs;
consolidate the duplicated Gemma4 rotary re-init into one loop while
keeping the post-.to(dtype) float32 buffer / attention_scaling
restoration.
- vllm_utils.assert_same_state_dict: _normalize_state_dict_tensor now
returns None for non-tensor entries (e.g. BnB QuantState dicts) and
callers skip those; align tied-embedding fallback tolerances with
the outer comparison (atol=1e-4, rtol=1e-3).
- vllm_utils._test_is_same_vlm: cast only floating-point tensors to
model.dtype for Gemma3/Gemma4 processors, leaving integer inputs
like pixel_values untouched.
- vllm_utils._get_vllm_state_dict: collapse the unreachable lm_head
elif chain; hoist the constant model_type/attention_k_eq_v check
out of the gemma4_k_eq_v_layers set comprehension.
- empty_model.get_model_layer_config: move model.visual.merger.
linear_fc1 / linear_fc2 from additional_layers (which expected a
{kk} placeholder) into non_layered_components.
# Conflicts: # unsloth_zoo/empty_model.py # unsloth_zoo/hf_utils.py # unsloth_zoo/vllm_utils.py
Apply 16 accepted review fixes across two files: - set_additional_modules now honors non_layered_components explicitly so Qwen3-VL merger.linear_fc1/2 are restored instead of dropped by the generic "linear" substring filter. - _get_vllm_state_dict moves layernorm extraction (and layer_scalar capture) above the no-mlp early-continue so layers without an mlp attribute still get their input/post layernorms exported. - extract_gdn_layers dequantizes per-shard BnB QuantStates before concatenating into the fused in_proj_qkv weight, avoiding K/V being dequantized with Q's scales. The in_proj_ba single-shard merged-layer case now dequantizes and splits instead of silently dropping in_proj_a quant_state. - Gemma4 top-level per-layer-input modules (embed_tokens_per_layer, per_layer_model_projection, per_layer_projection_norm) are added to non_layered_components and extracted from the vLLM text model. - patch_gemma4_vllm_lora_support now also patches Gemma4ForCausalLM (when available) and guards class-level supports_lora / embedding_modules writes behind an idempotency flag. - finalize_huggingface_model reapplies dtype to the live config tree after copy_attributes, switches vision-rotary detection from class equality to identity-based id() membership, and keeps inv_freq buffers at float32 for all archs (matching transformers default). - convert_vllm_to_huggingface preserves buffer registration for layer_scalar-style entries instead of unconditionally wrapping them in nn.Parameter. - assert_same_state_dict only relaxes tolerances on the dtype-mismatch / FP8 upcast branch; same-dtype comparisons keep torch defaults. - Conv1d rebuild branch is qualified with linear_attn substring so it won't silently rebuild future non-GDN conv1d layers as depthwise. - _test_is_same_vlm falls back to a synthetic PIL image when the remote sloth URL load_image fails, so the test runs offline.
Append 9 regression tests to tests/test_vllm_to_hf_conversion.py covering the fixes applied during review: - set_additional_modules now restores visual merger linear_fc1/2. - _get_vllm_state_dict extracts layernorms even when a decoder layer lacks an mlp attribute. - finalize_huggingface_model propagates dtype to live config tree after copy_attributes replaces the config object. - finalize_huggingface_model uses identity-based vision rotary detection so text rotary is not misclassified when text and vision configs share a Python class. - convert_vllm_to_huggingface preserves buffer registration for layer_scalar-style entries instead of converting them to nn.Parameter. - assert_same_state_dict uses tight torch defaults for same-dtype comparisons; loose tolerance only applies on the FP8/dtype-mismatch upcast branch. - Conv1d rebuild branch is qualified with linear_attn substring. - patch_gemma4_vllm_lora_support now covers both Gemma4ForConditionalGeneration and Gemma4ForCausalLM. - get_model_layer_config includes Gemma4 top-level per-layer-input modules in non_layered_components. Also corrects the rotary inv_freq dtype assertion in test_finalize_non_gemma4_rotary_buffers_follow_model_dtype to match the new always-float32 behavior of finalize_huggingface_model.
# Conflicts: # unsloth_zoo/empty_model.py # unsloth_zoo/vllm_utils.py
- finalize_huggingface_model: guard Gemma4 multimodal rotary rebuild with try/except and broaden vision-rotary detection by module path, so a copy-attributes id() drift no longer reroutes a vision rotary through the text_config (which lacks the vision rope_parameters shape and crashed with KeyError 'rope_type' / NoneType ** Tensor). - finalize_huggingface_model: lift float rotary buffers to float32 on all non-quantized models (not just Gemma4) after new_model.to(dtype), fixing an inv_freq / original_inv_freq downcast regression for e.g. Qwen3.5. Drops the redundant is_gemma4 fresh-rotary clone used only to re-copy attention_scaling (a Python float unaffected by .to). - finalize_huggingface_model: hoist deepcopy(text_config) out of the rotary_emb_local loop so multi-layer Gemma3/4 models don't deepcopy the text config once per decoder layer. - extract_gdn_layers: when dequantizing the fused in_proj_ba BnB shard, compute the b/a split midpoint on the dequantized tensor rather than the packed uint8 Params4bit buffer whose shape[0] is numel/2. - _get_vllm_state_dict: match lm_head by exact name or .lm_head suffix instead of substring so unrelated submodule names containing 'lm_head' cannot shadow the real head.
Trim WHAT-restatement comments and collapse a multi-line rationale to one line stating the load-bearing fact. No behavioural change.
…vation, GDN dequantize midpoint, and lm_head exact match
# Conflicts: # unsloth_zoo/empty_model.py # unsloth_zoo/vllm_utils.py
- vllm_utils.py: convert_vllm_to_huggingface parameter-list path regex now handles trailing-digit segments (e.g. `embed_tokens_per_layer.0`) via the same anchor-or-end pattern used below for `exec` assignment. - empty_model.py: finalize_huggingface_model rotary reinit no longer swallows failures silently; the float32 buffer lift is skipped when the reinit raised, and the exception is logged so wrong-shape rotary state does not propagate. - empty_model.py: patch_gemma4_vllm_lora_support guards both the `gemma4_mm` and `gemma4` imports, does not clobber a pre-existing `embedding_modules` registry, and delegates to the original `create_lora_manager` so vLLM shim kwargs (e.g. `vllm_config`) reach the manager constructor correctly. - empty_model.py: patch_gemma4_vllm_k_eq_v_support uses getattr for `_stack_quantization_states` so absent / renamed private attrs do not crash, and adds `model.language_model` to the k/v name-prefix search so HF-style Gemma4 multimodal parameters match. - empty_model.py + vllm_utils.py: both Gemma4 gates share a new `_is_gemma4_config` helper that accepts `gemma4` and the text-only `gemma4_text` model_type. - vllm_utils.py: Gemma4 patch block runs after the BnB autodetect so that `-bnb-4bit` / `quant_method == bitsandbytes` models trigger the k_eq_v patch even when the caller did not pass use_bitsandbytes=True.
…in set_additional_modules
Adds 12 regression tests covering the iter-1 hardening (trailing-digit regex path, rotary reinit success guard, _is_gemma4_config helper, gemma4 gate migration, gemma4_mm import guard, private loader attr guard, HF-style k_eq_v prefix, lora manager delegation, behavioral no-op tests that stub missing vLLM modules). Updates test_gemma4_lora_patch_preserves_signature_for_inspect and test_gemma4_k_eq_v_set_hoists_constant_check to match the new source shape.
# Conflicts: # unsloth_zoo/empty_model.py # unsloth_zoo/vllm_utils.py
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) architectures, adding specialized layer extraction logic, vLLM patches for LoRA and BitsAndBytes, and a centralized model finalization process. The review feedback identifies a logic flaw in GDN quantization state preservation where only the first shard is checked, and points out that Gemma 4 k_eq_v layers require explicit mapping of the K shard to the V projection during extraction to avoid uninitialized weights. Additionally, a debug print statement should be removed or replaced with a logger call.
| if isinstance(qs_attr, dict): | ||
| _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0]) |
There was a problem hiding this comment.
The logic here assumes that if only one shard is quantized, it must be at index 0. If qkv_states[1] or qkv_states[2] is the only non-None state, it will be ignored, and _store_quant_state will be called with None, leading to loss of quantization information for the fused layer. Consider using the first non-None state found in qkv_states to ensure the quantization state is preserved.
| if isinstance(qs_attr, dict): | |
| _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0]) | |
| if isinstance(qs_attr, dict): | |
| _store_quant_state(f"{prefix}.in_proj_qkv", next((qs for qs in qkv_states if qs is not None), None)) |
| if kk not in gemma4_k_eq_v_layers: | ||
| get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) |
There was a problem hiding this comment.
Skipping the extraction of v_proj for k_eq_v layers will result in an incomplete state dictionary for the Hugging Face model, as the v_proj weights will remain uninitialized (or initialized with tiny dimensions from the empty model creation). Since these layers reuse K as V, you should extract shard 1 (K) and store it as v_proj to ensure the HF model is correctly populated.
| if kk not in gemma4_k_eq_v_layers: | |
| get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) | |
| v_shard_idx = 1 if kk in gemma4_k_eq_v_layers else 2 | |
| get_state_dict(f"{prefix}.v_proj", v_shard_idx, state_dict, qkv_proj) |
| or not any(substr in x for substr in ("layers", "blocks", embed_tokens_key, norm_key, "lm_head", "mlp", "linear", "list")) | ||
| ) | ||
| ) | ||
| print(f'Performing substitution for {additional_keys=}') |
Addresses review iter-1 findings. The four hunks refine (not delete) logic previously added in review-fix commits 4587bc6, e4f530c, ca60088, e6ebab4 — each hunk preserves the original intent while extending it to cover cases those commits missed. - extract_gdn_layers (refines 4587bc6 / ca60088): keep the sharded- QuantState branch intact; add a defensive shape guard on the else-branch for the fused single-QuantState case, where the stored quant_state shape does not match the 3/4 qkv slice. When detected, dequantize the fused buffer once and split into qkv / z, so neither shard is paired with a mismatched quant_state. - _store_quant_state (refines 4587bc6): keep the try/as_dict fallback; swap the bare `except ... pass` for `logger.warning` so a bitsandbytes version mismatch in `quant_state.as_dict(packed=True)` is visible rather than silent. Behavior on success is unchanged. - get_model_layer_config (refines 4587bc6): move `model.visual.merger.linear_fc1/fc2` from `non_layered_components` to `additional_layers`. The non_layered path routes them through set_additional_modules' plain `exec` assignment, which strips BnB/FP8 quant_state for quantized Qwen3-VL mergers; additional_layers keeps the main quantized loop in charge (Linear4bit / FP8Linear build). why safe: linear_fc1/fc2 are literal names (no {kk}), and the main loop's trailing regex `r"\.([\d]{1,})"` matches only digit-after-dot, which does not occur in `linear_fc1` / `linear_fc2`. - finalize_huggingface_model (refines e4f530c / e6ebab4): the existing `.to(dtype)` post-lift loop only restored rotary_emb buffers to fp32; Gemma3 `rotary_emb_local` and Qwen2.5-VL `rotary_pos_emb` buffers were still left at bf16/fp16, breaking RoPE precision. Extend the loop to cover all three rotary variants. The `if rotary is None or not hasattr(rotary, "_buffers"): continue` guard mirrors the prior None-check, just broadened — why safe: the original was already a continue-on-None, and the hasattr check only short-circuits for non-nn.Module attributes, which were never iterated before anyway.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) architectures, including specialized logic for Gated Delta Net (GDN) layers and Gemma 4's shared K/V attention. Key updates include new vLLM patching mechanisms for LoRA and BitsAndBytes support, expanded layer configuration templates, and a centralized finalize_huggingface_model function. A critical issue was identified in the GDN layer extraction logic where merging mixed quantized and non-quantized shards into a single weight tensor could lead to weight corruption; a suggestion was provided to ensure all shards are dequantized if any quantization is present.
| if sum(qs is not None for qs in qkv_states) > 1: | ||
| try: | ||
| from bitsandbytes.functional import dequantize_4bit | ||
| except Exception: | ||
| raise RuntimeError( | ||
| "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction." | ||
| ) | ||
| parts = [] | ||
| for i, qs in enumerate(qkv_states): | ||
| shard = weight[offsets[i]:offsets[i + 1]] | ||
| parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard) | ||
| store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0)) | ||
| else: | ||
| store(f"{prefix}.in_proj_qkv.weight", qkv_weight) | ||
| if isinstance(qs_attr, dict): | ||
| _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0]) |
There was a problem hiding this comment.
The logic for merging quantized qkv shards in GDN layers is problematic when sum(qs is not None for qs in qkv_states) == 1. Merging a quantized shard with other non-quantized shards into a single in_proj_qkv.weight tensor while only storing the quant state of the first shard will lead to corrupted weights for the other shards during dequantization. Additionally, if the single quant state is at index 1 or 2, it is currently lost. For fused layers where shards are merged into one HF weight, you should dequantize all shards if any quant state is present to ensure correctness.
| if sum(qs is not None for qs in qkv_states) > 1: | |
| try: | |
| from bitsandbytes.functional import dequantize_4bit | |
| except Exception: | |
| raise RuntimeError( | |
| "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction." | |
| ) | |
| parts = [] | |
| for i, qs in enumerate(qkv_states): | |
| shard = weight[offsets[i]:offsets[i + 1]] | |
| parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard) | |
| store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0)) | |
| else: | |
| store(f"{prefix}.in_proj_qkv.weight", qkv_weight) | |
| if isinstance(qs_attr, dict): | |
| _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0]) | |
| if any(qs is not None for qs in qkv_states): | |
| try: | |
| from bitsandbytes.functional import dequantize_4bit | |
| except Exception: | |
| raise RuntimeError( | |
| "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction." | |
| ) | |
| parts = [] | |
| for i, qs in enumerate(qkv_states): | |
| shard = weight[offsets[i]:offsets[i + 1]] | |
| parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard) | |
| store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0)) | |
| else: | |
| store(f"{prefix}.in_proj_qkv.weight", qkv_weight) |
Review iter-2 flagged that my iter-1 refinement in 7ec99b6 regressed in_proj_z handling on the multi-quant path from PR #7's GDN extractor (commits ca60088 / e6ebab4 / 3575a41). This commit refines 7ec99b6 and preserves all three prior intents, rather than deleting them. Preserved from PR #7 (ca60088 / e6ebab4 / 3575a41): - `sum(qs is not None for qs in qkv_states) > 1` per-shard dequant branch is kept as an `elif`, not dropped. - `dequantize_4bit` import and RuntimeError behaviour are kept. - `_store_quant_state` emission for the packed `qs0 covers qkv + qs3 covers z` BnB convention is preserved in the final else branch. Preserved from my iter-1 fix (7ec99b6): - Fused-single-QuantState shape guard that triggers a full dequantize + split. Moved out of the else branch into its own leading branch so the subsequent per-shard dequant path can also emit the z shard. New in this commit: - Multi-quant (`sum > 1`) branch now also stores `in_proj_z.weight` (dequant via `qs_attr.get(3)` if present, else raw `z_weight`). Without this, the HF reconstruction saw no z key and rebuilt z from the random placeholder in `create_empty_causal_lm`. Bug introduced by 7ec99b6. - Fused-full detection now compares `qs_shape[0]` against `offsets[4]` (logical unpacked out_features summed from `proj.output_sizes`) instead of `weight.shape[0]`. The BnB Params4bit buffer is stored packed as uint8 with half the rows, so the prior comparison never matched real BnB layouts and the fused-full guard silently fell through to the per-shard branch. - Final else branch now re-emits `_store_quant_state(in_proj_qkv, qkv_states[0])` alongside the z quant_state, restoring the emission that 7ec99b6 inadvertently dropped when it collapsed the old nested `if/else` block.
Six review-added tests covering the iter-1 / iter-2 fixes to extract_gdn_layers, finalize_huggingface_model, and the merger routing change. All added to the existing test module: - test_extract_gdn_layers_fused_single_quant_state_dequantizes_and_splits -- fused-single-QuantState path dequantizes whole buffer and splits. - test_extract_gdn_layers_multi_quant_branch_stores_in_proj_z_dequantized -- multi-quant path dequantizes z via qs_attr[3]. - test_extract_gdn_layers_multi_quant_branch_stores_in_proj_z_raw_when_no_z_state -- multi-quant path falls back to raw z when qs_attr[3] absent. - test_store_quant_state_logs_warning_when_as_dict_raises -- _store_quant_state surfaces as_dict failures via logger.warning instead of silent pass. - test_finalize_lifts_rotary_emb_local_to_fp32_after_dtype_cast -- Gemma3 rotary_emb_local buffers stay fp32 after .to(dtype). - test_finalize_lifts_rotary_pos_emb_to_fp32_after_dtype_cast -- Qwen2.5-VL rotary_pos_emb buffers stay fp32 after .to(dtype). Plus two rename+flip updates of existing tests to track the merger.linear_fc1/fc2 move from non_layered_components back to additional_layers: - test_merger_linear_fc_moved_to_non_layered -> test_merger_linear_fc_routed_to_additional_layers. - test_set_additional_modules_loads_visual_merger_linear_fc -> test_set_additional_modules_skips_visual_merger_linear_fc.
Addresses review iter-1 findings. The four hunks refine (not delete) logic previously added in review-fix commits 4587bc6, e4f530c, ca60088, e6ebab4 — each hunk preserves the original intent while extending it to cover cases those commits missed. - extract_gdn_layers (refines 4587bc6 / ca60088): keep the sharded- QuantState branch intact; add a defensive shape guard on the else-branch for the fused single-QuantState case, where the stored quant_state shape does not match the 3/4 qkv slice. When detected, dequantize the fused buffer once and split into qkv / z, so neither shard is paired with a mismatched quant_state. - _store_quant_state (refines 4587bc6): keep the try/as_dict fallback; swap the bare `except ... pass` for `logger.warning` so a bitsandbytes version mismatch in `quant_state.as_dict(packed=True)` is visible rather than silent. Behavior on success is unchanged. - get_model_layer_config (refines 4587bc6): move `model.visual.merger.linear_fc1/fc2` from `non_layered_components` to `additional_layers`. The non_layered path routes them through set_additional_modules' plain `exec` assignment, which strips BnB/FP8 quant_state for quantized Qwen3-VL mergers; additional_layers keeps the main quantized loop in charge (Linear4bit / FP8Linear build). why safe: linear_fc1/fc2 are literal names (no {kk}), and the main loop's trailing regex `r"\.([\d]{1,})"` matches only digit-after-dot, which does not occur in `linear_fc1` / `linear_fc2`. - finalize_huggingface_model (refines e4f530c / e6ebab4): the existing `.to(dtype)` post-lift loop only restored rotary_emb buffers to fp32; Gemma3 `rotary_emb_local` and Qwen2.5-VL `rotary_pos_emb` buffers were still left at bf16/fp16, breaking RoPE precision. Extend the loop to cover all three rotary variants. The `if rotary is None or not hasattr(rotary, "_buffers"): continue` guard mirrors the prior None-check, just broadened — why safe: the original was already a continue-on-None, and the hasattr check only short-circuits for non-nn.Module attributes, which were never iterated before anyway.
Review iter-2 flagged that my iter-1 refinement in 7ec99b6 regressed in_proj_z handling on the multi-quant path from PR #7's GDN extractor (commits ca60088 / e6ebab4 / 3575a41). This commit refines 7ec99b6 and preserves all three prior intents, rather than deleting them. Preserved from PR #7 (ca60088 / e6ebab4 / 3575a41): - `sum(qs is not None for qs in qkv_states) > 1` per-shard dequant branch is kept as an `elif`, not dropped. - `dequantize_4bit` import and RuntimeError behaviour are kept. - `_store_quant_state` emission for the packed `qs0 covers qkv + qs3 covers z` BnB convention is preserved in the final else branch. Preserved from my iter-1 fix (7ec99b6): - Fused-single-QuantState shape guard that triggers a full dequantize + split. Moved out of the else branch into its own leading branch so the subsequent per-shard dequant path can also emit the z shard. New in this commit: - Multi-quant (`sum > 1`) branch now also stores `in_proj_z.weight` (dequant via `qs_attr.get(3)` if present, else raw `z_weight`). Without this, the HF reconstruction saw no z key and rebuilt z from the random placeholder in `create_empty_causal_lm`. Bug introduced by 7ec99b6. - Fused-full detection now compares `qs_shape[0]` against `offsets[4]` (logical unpacked out_features summed from `proj.output_sizes`) instead of `weight.shape[0]`. The BnB Params4bit buffer is stored packed as uint8 with half the rows, so the prior comparison never matched real BnB layouts and the fused-full guard silently fell through to the per-shard branch. - Final else branch now re-emits `_store_quant_state(in_proj_qkv, qkv_states[0])` alongside the z quant_state, restoring the emission that 7ec99b6 inadvertently dropped when it collapsed the old nested `if/else` block.
Staging mirror of unslothai#588
Original PR: unslothai#588
Author: Datta0
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Tried to make sure the changes are minimal so when we detect linear_attn, we hand it off to a separate function
This PR contains code changes only (3 files). Test changes are in a separate PR.
Changed files:
unsloth_zoo/empty_model.pyunsloth_zoo/hf_utils.pyunsloth_zoo/vllm_utils.py