
[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #8

Open: danielhanchen wants to merge 39 commits into main from pr-588-code


@danielhanchen
Owner

Staging mirror of unslothai#588

Original PR: unslothai#588
Author: Datta0

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

  • Fast inference support for Qwen 3.5 with vLLM :)
    We tried to keep the changes minimal: when we detect linear_attn, we hand the layer off to a separate function.

This PR contains code changes only (3 files). Test changes are in a separate PR.

Changed files:

  • unsloth_zoo/empty_model.py
  • unsloth_zoo/hf_utils.py
  • unsloth_zoo/vllm_utils.py

Datta0 and others added 30 commits March 30, 2026 13:41
Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
…e_huggingface_model

- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
  _call_create_lora_manager's signature inspection still sees vllm_config; pass model
  positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
  upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
  qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
  so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
  ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
  of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
  handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
  export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
  rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
  rotary_pos_emb on vision_config availability; mirror language_model detection from
  set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
  the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
  'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
  current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
  (defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
  so non-tensor state-dict values pass through.
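
To illustrate the functools.wraps point in the first bullet, here is a minimal sketch using only the standard library. The `create_lora_manager` stand-in and its signature are assumptions made for the example, not vLLM's real factory; the sketch only shows why signature inspection stops seeing `vllm_config` when a wrapper omits `functools.wraps`.

```python
import functools
import inspect


def create_lora_manager(model, vllm_config=None):
    """Stand-in for the original factory (hypothetical signature)."""
    return ("manager", model, vllm_config)


def patch_without_wraps(fn):
    # A plain wrapper hides the original parameters behind *args/**kwargs.
    def patched(*args, **kwargs):
        return fn(*args, **kwargs)
    return patched


def patch_with_wraps(fn):
    # functools.wraps sets __wrapped__, which inspect.signature follows.
    @functools.wraps(fn)
    def patched(*args, **kwargs):
        return fn(*args, **kwargs)
    return patched


plain = patch_without_wraps(create_lora_manager)
wrapped = patch_with_wraps(create_lora_manager)

plain_params = list(inspect.signature(plain).parameters)
wrapped_params = list(inspect.signature(wrapped).parameters)

print(plain_params)    # parameter names seen through the plain wrapper
print(wrapped_params)  # parameter names seen through the wrapped version
```

A caller that checks `"vllm_config" in inspect.signature(fn).parameters` would only take the right path for the wrapped version.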
…n, finalize_huggingface_model

- hf_utils.set_dtype_in_config: store string (JSON-safe, keeps string
  comparisons in patch_model_and_tokenizer working); fix fallback
  else-branch that had the HAS_TORCH_DTYPE field selection inverted.
- empty_model.extract_gdn_layers: read bnb_quant_state off the raw
  Params4bit before unwrapping .data; emit weight.quant_state and FP8
  weight_scale(_inv) shards for the in_proj_b / in_proj_a split so
  quantized Qwen3.5 GDN layers round-trip correctly.
- vllm_utils.convert_vllm_to_huggingface: rebuild linear_attn.conv1d
  as a grouped Conv1d with real channels/kernel_size/groups/padding
  instead of treating it as a LayerNorm-style weight swap.
- empty_model.patch_gemma4_vllm_lora_support: soft-import
  vllm.v1.worker.lora_model_runner_mixin so older supported vLLM
  layouts keep working.
- vllm_utils._get_vllm_state_dict: extract Gemma4 per_layer_input_gate
  and per_layer_projection so converted HF models carry the real
  checkpoint weights.
- empty_model.finalize_huggingface_model: restrict dtype propagation
  to the top-level config and its known text/vision/audio subconfigs;
  consolidate the duplicated Gemma4 rotary re-init into one loop while
  keeping the post-.to(dtype) float32 buffer / attention_scaling
  restoration.
- vllm_utils.assert_same_state_dict: _normalize_state_dict_tensor now
  returns None for non-tensor entries (e.g. BnB QuantState dicts) and
  callers skip those; align tied-embedding fallback tolerances with
  the outer comparison (atol=1e-4, rtol=1e-3).
- vllm_utils._test_is_same_vlm: cast only floating-point tensors to
  model.dtype for Gemma3/Gemma4 processors, leaving integer inputs
  like pixel_values untouched.
- vllm_utils._get_vllm_state_dict: collapse the unreachable lm_head
  elif chain; hoist the constant model_type/attention_k_eq_v check
  out of the gemma4_k_eq_v_layers set comprehension.
- empty_model.get_model_layer_config: move model.visual.merger.
  linear_fc1 / linear_fc2 from additional_layers (which expected a
  {kk} placeholder) into non_layered_components.
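
The dtype handling described for hf_utils.set_dtype_in_config can be sketched roughly as below. This is an illustration under stated assumptions, not the PR's implementation: `set_dtype_sketch` is a hypothetical name, and a `SimpleNamespace` stands in for a transformers config object.

```python
import json
from types import SimpleNamespace


def set_dtype_sketch(config, dtype_name):
    """Hypothetical helper: store the dtype as a plain string.

    Prefers the modern 'dtype' field and only falls back to 'torch_dtype'
    when 'dtype' is absent, mirroring the behavior described above.
    """
    if hasattr(config, "dtype"):
        config.dtype = dtype_name
    elif hasattr(config, "torch_dtype"):
        config.torch_dtype = dtype_name
    else:
        config.dtype = dtype_name
    return config


modern = set_dtype_sketch(SimpleNamespace(dtype=None, torch_dtype=None), "bfloat16")
legacy = set_dtype_sketch(SimpleNamespace(torch_dtype=None), "float16")

# Storing a string keeps the config JSON-serializable and lets later code
# compare against string literals such as "bfloat16".
serialized = json.dumps(vars(modern))
print(serialized)
```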
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/hf_utils.py
#	unsloth_zoo/vllm_utils.py
Apply 16 accepted review fixes across two files:

- set_additional_modules now honors non_layered_components explicitly so
  Qwen3-VL merger.linear_fc1/2 are restored instead of dropped by the
  generic "linear" substring filter.
- _get_vllm_state_dict moves layernorm extraction (and layer_scalar
  capture) above the no-mlp early-continue so layers without an mlp
  attribute still get their input/post layernorms exported.
- extract_gdn_layers dequantizes per-shard BnB QuantStates before
  concatenating into the fused in_proj_qkv weight, avoiding K/V being
  dequantized with Q's scales. The in_proj_ba single-shard merged-layer
  case now dequantizes and splits instead of silently dropping
  in_proj_a quant_state.
- Gemma4 top-level per-layer-input modules (embed_tokens_per_layer,
  per_layer_model_projection, per_layer_projection_norm) are added to
  non_layered_components and extracted from the vLLM text model.
- patch_gemma4_vllm_lora_support now also patches Gemma4ForCausalLM
  (when available) and guards class-level supports_lora /
  embedding_modules writes behind an idempotency flag.
- finalize_huggingface_model reapplies dtype to the live config tree
  after copy_attributes, switches vision-rotary detection from class
  equality to identity-based id() membership, and keeps inv_freq
  buffers at float32 for all archs (matching transformers default).
- convert_vllm_to_huggingface preserves buffer registration for
  layer_scalar-style entries instead of unconditionally wrapping them
  in nn.Parameter.
- assert_same_state_dict only relaxes tolerances on the dtype-mismatch
  / FP8 upcast branch; same-dtype comparisons keep torch defaults.
- Conv1d rebuild branch is qualified with linear_attn substring so it
  won't silently rebuild future non-GDN conv1d layers as depthwise.
- _test_is_same_vlm falls back to a synthetic PIL image when the
  remote sloth URL load_image fails, so the test runs offline.
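
The switch from class equality to identity-based id() membership for vision-rotary detection can be illustrated with a small sketch. The class and variable names are hypothetical; the point is only that two modules sharing one Python class cannot be told apart by type.

```python
class RotaryEmbedding:
    """Stand-in for a rotary module class shared by text and vision towers."""

    def __init__(self, name):
        self.name = name


text_rotary = RotaryEmbedding("text")
vision_rotary = RotaryEmbedding("vision")

# Collect the ids of modules known to belong to the vision tower. The
# objects must stay alive for as long as the ids are used.
vision_ids = {id(vision_rotary)}


def is_vision_by_class(module):
    # Class comparison: true for any module of the same class.
    return type(module) is type(vision_rotary)


def is_vision_by_identity(module):
    # Identity comparison: true only for the exact registered objects.
    return id(module) in vision_ids


print(is_vision_by_class(text_rotary), is_vision_by_identity(text_rotary))
```

Identity checks stop working once the module objects are replaced or copied, which is why a later commit in this thread adds detection by module path as well.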
Append 9 regression tests to tests/test_vllm_to_hf_conversion.py covering
the fixes applied during review:

- set_additional_modules now restores visual merger linear_fc1/2.
- _get_vllm_state_dict extracts layernorms even when a decoder layer
  lacks an mlp attribute.
- finalize_huggingface_model propagates dtype to live config tree after
  copy_attributes replaces the config object.
- finalize_huggingface_model uses identity-based vision rotary detection
  so text rotary is not misclassified when text and vision configs
  share a Python class.
- convert_vllm_to_huggingface preserves buffer registration for
  layer_scalar-style entries instead of converting them to nn.Parameter.
- assert_same_state_dict uses tight torch defaults for same-dtype
  comparisons; loose tolerance only applies on the FP8/dtype-mismatch
  upcast branch.
- Conv1d rebuild branch is qualified with linear_attn substring.
- patch_gemma4_vllm_lora_support now covers both
  Gemma4ForConditionalGeneration and Gemma4ForCausalLM.
- get_model_layer_config includes Gemma4 top-level per-layer-input
  modules in non_layered_components.

Also corrects the rotary inv_freq dtype assertion in
test_finalize_non_gemma4_rotary_buffers_follow_model_dtype to match the
new always-float32 behavior of finalize_huggingface_model.
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/vllm_utils.py
- finalize_huggingface_model: guard Gemma4 multimodal rotary rebuild with
  try/except and broaden vision-rotary detection by module path, so a
  copy-attributes id() drift no longer reroutes a vision rotary through
  the text_config (which lacks the vision rope_parameters shape and
  crashed with KeyError 'rope_type' / NoneType ** Tensor).
- finalize_huggingface_model: lift float rotary buffers to float32 on
  all non-quantized models (not just Gemma4) after new_model.to(dtype),
  fixing an inv_freq / original_inv_freq downcast regression for e.g.
  Qwen3.5. Drops the redundant is_gemma4 fresh-rotary clone used only
  to re-copy attention_scaling (a Python float unaffected by .to).
- finalize_huggingface_model: hoist deepcopy(text_config) out of the
  rotary_emb_local loop so multi-layer Gemma3/4 models don't deepcopy
  the text config once per decoder layer.
- extract_gdn_layers: when dequantizing the fused in_proj_ba BnB shard,
  compute the b/a split midpoint on the dequantized tensor rather than
  the packed uint8 Params4bit buffer whose shape[0] is numel/2.
- _get_vllm_state_dict: match lm_head by exact name or .lm_head suffix
  instead of substring so unrelated submodule names containing
  'lm_head' cannot shadow the real head.
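
The lm_head matching rule in the last bullet can be sketched as follows. `is_lm_head` and the module names in the list are made up for illustration; the sketch only contrasts substring matching with exact-name-or-suffix matching.

```python
def is_lm_head(name):
    """Match the head by exact name or by a '.lm_head' suffix (sketch)."""
    return name == "lm_head" or name.endswith(".lm_head")


names = [
    "lm_head",
    "language_model.lm_head",
    "model.aux_lm_head_proj",  # hypothetical unrelated submodule
    "model.lm_head_norm",      # hypothetical unrelated submodule
]

substring_hits = [n for n in names if "lm_head" in n]
exact_hits = [n for n in names if is_lm_head(n)]

print(substring_hits)
print(exact_hits)
```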
Trim WHAT-restatement comments and collapse a multi-line rationale
to one line stating the load-bearing fact. No behavioural change.
…vation, GDN dequantize midpoint, and lm_head exact match
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/vllm_utils.py
- vllm_utils.py: convert_vllm_to_huggingface parameter-list path regex now
  handles trailing-digit segments (e.g. `embed_tokens_per_layer.0`) via the
  same anchor-or-end pattern used below for `exec` assignment.
- empty_model.py: finalize_huggingface_model rotary reinit no longer
  swallows failures silently; the float32 buffer lift is skipped when the
  reinit raised, and the exception is logged so wrong-shape rotary state
  does not propagate.
- empty_model.py: patch_gemma4_vllm_lora_support guards both the
  `gemma4_mm` and `gemma4` imports, does not clobber a pre-existing
  `embedding_modules` registry, and delegates to the original
  `create_lora_manager` so vLLM shim kwargs (e.g. `vllm_config`) reach the
  manager constructor correctly.
- empty_model.py: patch_gemma4_vllm_k_eq_v_support uses getattr for
  `_stack_quantization_states` so absent / renamed private attrs do not
  crash, and adds `model.language_model` to the k/v name-prefix search so
  HF-style Gemma4 multimodal parameters match.
- empty_model.py + vllm_utils.py: both Gemma4 gates share a new
  `_is_gemma4_config` helper that accepts `gemma4` and the text-only
  `gemma4_text` model_type.
- vllm_utils.py: Gemma4 patch block runs after the BnB autodetect so that
  `-bnb-4bit` / `quant_method == bitsandbytes` models trigger the
  k_eq_v patch even when the caller did not pass use_bitsandbytes=True.
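
A shared gate like the `_is_gemma4_config` helper mentioned above might look roughly like this. The body is an assumption based only on the description (accept `gemma4` and `gemma4_text`); the real helper may inspect more fields.

```python
from types import SimpleNamespace

# Model types named in the commit message above.
GEMMA4_MODEL_TYPES = ("gemma4", "gemma4_text")


def is_gemma4_config_sketch(config):
    """Hypothetical gate: True for multimodal and text-only Gemma4 configs."""
    return getattr(config, "model_type", None) in GEMMA4_MODEL_TYPES


for model_type in ("gemma4", "gemma4_text", "gemma3"):
    cfg = SimpleNamespace(model_type=model_type)
    print(model_type, is_gemma4_config_sketch(cfg))
```

Routing both call sites through one helper keeps the two gates from drifting apart when a new model_type is added.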
Adds 12 regression tests covering the iter-1 hardening (trailing-digit
regex path, rotary reinit success guard, _is_gemma4_config helper,
gemma4 gate migration, gemma4_mm import guard, private loader attr
guard, HF-style k_eq_v prefix, lora manager delegation, behavioral
no-op tests that stub missing vLLM modules). Updates
test_gemma4_lora_patch_preserves_signature_for_inspect and
test_gemma4_k_eq_v_set_hoists_constant_check to match the new source
shape.
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/vllm_utils.py
@danielhanchen
Owner Author

/gemini review


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) architectures, adding specialized layer extraction logic, vLLM patches for LoRA and BitsAndBytes, and a centralized model finalization process. The review feedback identifies a logic flaw in GDN quantization state preservation where only the first shard is checked, and points out that Gemma 4 k_eq_v layers require explicit mapping of the K shard to the V projection during extraction to avoid uninitialized weights. Additionally, a debug print statement should be removed or replaced with a logger call.

Comment on lines +1167 to +1168
    if isinstance(qs_attr, dict):
        _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0])

high

The logic here assumes that if only one shard is quantized, it must be at index 0. If qkv_states[1] or qkv_states[2] is the only non-None state, it will be ignored, and _store_quant_state will be called with None, leading to loss of quantization information for the fused layer. Consider using the first non-None state found in qkv_states to ensure the quantization state is preserved.

Suggested change
    - if isinstance(qs_attr, dict):
    -     _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0])
    + if isinstance(qs_attr, dict):
    +     _store_quant_state(f"{prefix}.in_proj_qkv", next((qs for qs in qkv_states if qs is not None), None))
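
As a small illustration of the suggestion above, the `next(...)` idiom picks the first non-None entry instead of assuming it sits at index 0. The dictionary below is a placeholder for a real QuantState object.

```python
# Only the second shard carries a quant state in this made-up example.
qkv_states = [None, {"shard": "k"}, None]

# Indexing assumes the state lives at position 0 and yields None here.
first_indexed = qkv_states[0]

# The generator form returns the first non-None state, or None if all are None.
first_present = next((qs for qs in qkv_states if qs is not None), None)

print(first_indexed, first_present)
```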

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +1120 to +1121
    if kk not in gemma4_k_eq_v_layers:
        get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)

high

Skipping the extraction of v_proj for k_eq_v layers will result in an incomplete state dictionary for the Hugging Face model, as the v_proj weights will remain uninitialized (or initialized with tiny dimensions from the empty model creation). Since these layers reuse K as V, you should extract shard 1 (K) and store it as v_proj to ensure the HF model is correctly populated.

Suggested change
    - if kk not in gemma4_k_eq_v_layers:
    -     get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)
    + v_shard_idx = 1 if kk in gemma4_k_eq_v_layers else 2
    + get_state_dict(f"{prefix}.v_proj", v_shard_idx, state_dict, qkv_proj)
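
The shard-index selection in this suggestion can be sketched with plain lists. The layout below (a packed projection holding only Q and K shards for k_eq_v layers) is an assumption made for the example, and `extract_v` is a hypothetical stand-in for the real extraction call.

```python
# Layer 3 reuses K as V (k_eq_v); layer 4 has a separate V shard.
gemma4_k_eq_v_layers = {3}
packed_qkv = {
    3: ["q3", "k3"],
    4: ["q4", "k4", "v4"],
}


def extract_v(layer_idx):
    """Pick shard 1 (K) for k_eq_v layers, shard 2 (V) otherwise."""
    v_shard_idx = 1 if layer_idx in gemma4_k_eq_v_layers else 2
    return packed_qkv[layer_idx][v_shard_idx]


state_dict = {f"layers.{kk}.v_proj": extract_v(kk) for kk in packed_qkv}
print(state_dict)
```

With this selection every layer gets a v_proj entry, so the converted model never keeps a placeholder weight for it.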

            or not any(substr in x for substr in ("layers", "blocks", embed_tokens_key, norm_key, "lm_head", "mlp", "linear", "list"))
        )
    )
    print(f'Performing substitution for {additional_keys=}')

low

This debug print statement should be removed or replaced with a proper logger call (e.g., logger.info) to maintain clean output during model conversion.
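
One way to follow this suggestion, sketched with the standard logging module. The logger name and the `report_substitution` helper are hypothetical; the project may already expose its own logger.

```python
import logging

# Hypothetical logger name used only for this sketch.
logger = logging.getLogger("unsloth_zoo.sketch")


def report_substitution(additional_keys):
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Performing substitution for additional_keys=%s", additional_keys)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    report_substitution(["model.visual.merger.linear_fc1"])
```

Callers can then silence or redirect the message through normal logging configuration instead of editing the source.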

danielhanchen added a commit that referenced this pull request Apr 20, 2026
Addresses review iter-1 findings. The four hunks refine (not delete)
logic previously added in review-fix commits 4587bc6, e4f530c,
ca60088, e6ebab4 — each hunk preserves the original intent while
extending it to cover cases those commits missed.

- extract_gdn_layers (refines 4587bc6 / ca60088): keep the sharded-
  QuantState branch intact; add a defensive shape guard on the
  else-branch for the fused single-QuantState case, where the stored
  quant_state shape does not match the 3/4 qkv slice. When detected,
  dequantize the fused buffer once and split into qkv / z, so neither
  shard is paired with a mismatched quant_state.
- _store_quant_state (refines 4587bc6): keep the try/as_dict fallback;
  swap the bare `except ... pass` for `logger.warning` so a bitsandbytes
  version mismatch in `quant_state.as_dict(packed=True)` is visible
  rather than silent. Behavior on success is unchanged.
- get_model_layer_config (refines 4587bc6): move
  `model.visual.merger.linear_fc1/fc2` from `non_layered_components` to
  `additional_layers`. The non_layered path routes them through
  set_additional_modules' plain `exec` assignment, which strips BnB/FP8
  quant_state for quantized Qwen3-VL mergers; additional_layers keeps
  the main quantized loop in charge (Linear4bit / FP8Linear build).
  why safe: linear_fc1/fc2 are literal names (no {kk}), and the main
  loop's trailing regex `r"\.([\d]{1,})"` matches only digit-after-dot,
  which does not occur in `linear_fc1` / `linear_fc2`.
- finalize_huggingface_model (refines e4f530c / e6ebab4): the existing
  `.to(dtype)` post-lift loop only restored rotary_emb buffers to fp32;
  Gemma3 `rotary_emb_local` and Qwen2.5-VL `rotary_pos_emb` buffers
  were still left at bf16/fp16, breaking RoPE precision. Extend the
  loop to cover all three rotary variants. The `if rotary is None or
  not hasattr(rotary, "_buffers"): continue` guard mirrors the prior
  None-check, just broadened — why safe: the original was already a
  continue-on-None, and the hasattr check only short-circuits for
  non-nn.Module attributes, which were never iterated before anyway.
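
The "why safe" claim about the trailing regex can be checked directly. The sketch below uses the pattern quoted in the commit message; the list of names is illustrative.

```python
import re

# Pattern quoted in the commit message: a dot followed by one or more digits.
TRAILING_INDEX = re.compile(r"\.([\d]{1,})")

names = [
    "model.visual.merger.linear_fc1",
    "model.visual.merger.linear_fc2",
    "model.layers.12.mlp",
    "embed_tokens_per_layer.0",
]

matches = {}
for name in names:
    found = TRAILING_INDEX.search(name)
    matches[name] = found.group(1) if found else None

for name, index in matches.items():
    print(name, index)
```

The digit in `linear_fc1` follows the letter `c`, not a dot, so the literal merger names pass through the main loop without being treated as layer indices.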
@danielhanchen
Owner Author

/gemini review


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) architectures, including specialized logic for Gated Delta Net (GDN) layers and Gemma 4's shared K/V attention. Key updates include new vLLM patching mechanisms for LoRA and BitsAndBytes support, expanded layer configuration templates, and a centralized finalize_huggingface_model function. A critical issue was identified in the GDN layer extraction logic where merging mixed quantized and non-quantized shards into a single weight tensor could lead to weight corruption; a suggestion was provided to ensure all shards are dequantized if any quantization is present.

Comment on lines +1153 to +1168
    if sum(qs is not None for qs in qkv_states) > 1:
        try:
            from bitsandbytes.functional import dequantize_4bit
        except Exception:
            raise RuntimeError(
                "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction."
            )
        parts = []
        for i, qs in enumerate(qkv_states):
            shard = weight[offsets[i]:offsets[i + 1]]
            parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard)
        store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0))
    else:
        store(f"{prefix}.in_proj_qkv.weight", qkv_weight)
        if isinstance(qs_attr, dict):
            _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0])

high

The logic for merging quantized qkv shards in GDN layers is problematic when sum(qs is not None for qs in qkv_states) == 1. Merging a quantized shard with other non-quantized shards into a single in_proj_qkv.weight tensor while only storing the quant state of the first shard will lead to corrupted weights for the other shards during dequantization. Additionally, if the single quant state is at index 1 or 2, it is currently lost. For fused layers where shards are merged into one HF weight, you should dequantize all shards if any quant state is present to ensure correctness.

Suggested change
    - if sum(qs is not None for qs in qkv_states) > 1:
    -     try:
    -         from bitsandbytes.functional import dequantize_4bit
    -     except Exception:
    -         raise RuntimeError(
    -             "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction."
    -         )
    -     parts = []
    -     for i, qs in enumerate(qkv_states):
    -         shard = weight[offsets[i]:offsets[i + 1]]
    -         parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard)
    -     store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0))
    - else:
    -     store(f"{prefix}.in_proj_qkv.weight", qkv_weight)
    -     if isinstance(qs_attr, dict):
    -         _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0])
    + if any(qs is not None for qs in qkv_states):
    +     try:
    +         from bitsandbytes.functional import dequantize_4bit
    +     except Exception:
    +         raise RuntimeError(
    +             "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction."
    +         )
    +     parts = []
    +     for i, qs in enumerate(qkv_states):
    +         shard = weight[offsets[i]:offsets[i + 1]]
    +         parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard)
    +     store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0))
    + else:
    +     store(f"{prefix}.in_proj_qkv.weight", qkv_weight)
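
The difference between the `sum(...) > 1` and `any(...)` conditions can be seen in a toy model of the merge. Everything here is a stand-in: lists replace tensors, and `fake_dequantize` multiplies by a scale instead of calling bitsandbytes.

```python
def fake_dequantize(shard, scale):
    """Toy replacement for a real dequantize call: scale every element."""
    return [x * scale for x in shard]


def merge_shards(shards, scales):
    """Dequantize every quantized shard if any shard has a scale, then fuse."""
    if any(s is not None for s in scales):
        parts = [
            fake_dequantize(shard, scale) if scale is not None else shard
            for shard, scale in zip(shards, scales)
        ]
    else:
        parts = shards
    return [x for part in parts for x in part]


shards = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
scales = [None, 0.5, None]  # only the middle shard is "quantized"

# The original condition skips this case, because only one state is present.
old_condition = sum(s is not None for s in scales) > 1
new_condition = any(s is not None for s in scales)

merged = merge_shards(shards, scales)
print(old_condition, new_condition, merged)
```

In the single-quantized-shard case the old condition falls through to the raw concatenation path, which is the corruption risk the review describes.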

danielhanchen added a commit that referenced this pull request Apr 20, 2026
Review iter-2 flagged that my iter-1 refinement in 7ec99b6 regressed
in_proj_z handling on the multi-quant path from PR #7's GDN extractor
(commits ca60088 / e6ebab4 / 3575a41). This commit refines 7ec99b6 and
preserves all three prior intents, rather than deleting them.

Preserved from PR #7 (ca60088 / e6ebab4 / 3575a41):
- `sum(qs is not None for qs in qkv_states) > 1` per-shard dequant
  branch is kept as an `elif`, not dropped.
- `dequantize_4bit` import and RuntimeError behaviour are kept.
- `_store_quant_state` emission for the packed `qs0 covers qkv + qs3
  covers z` BnB convention is preserved in the final else branch.

Preserved from my iter-1 fix (7ec99b6):
- Fused-single-QuantState shape guard that triggers a full dequantize +
  split. Moved out of the else branch into its own leading branch so
  the subsequent per-shard dequant path can also emit the z shard.

New in this commit:
- Multi-quant (`sum > 1`) branch now also stores `in_proj_z.weight`
  (dequant via `qs_attr.get(3)` if present, else raw `z_weight`).
  Without this, the HF reconstruction saw no z key and rebuilt z from
  the random placeholder in `create_empty_causal_lm`. Bug introduced
  by 7ec99b6.
- Fused-full detection now compares `qs_shape[0]` against `offsets[4]`
  (logical unpacked out_features summed from `proj.output_sizes`)
  instead of `weight.shape[0]`. The BnB Params4bit buffer is stored
  packed as uint8 with half the rows, so the prior comparison never
  matched real BnB layouts and the fused-full guard silently fell
  through to the per-shard branch.
- Final else branch now re-emits `_store_quant_state(in_proj_qkv,
  qkv_states[0])` alongside the z quant_state, restoring the
  emission that 7ec99b6 inadvertently dropped when it collapsed the
  old nested `if/else` block.
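
The packed-size mismatch behind the fused-full detection fix can be illustrated with a toy 4-bit packer. The nibble order below is an arbitrary choice for the sketch and is not claimed to match the bitsandbytes storage layout; the point is only that the packed buffer holds half as many entries as the logical tensor.

```python
def pack_4bit(values):
    """Pack pairs of 4-bit integers into single bytes (illustrative layout)."""
    if len(values) % 2 != 0 or not all(0 <= v < 16 for v in values):
        raise ValueError("expected an even number of values in [0, 16)")
    return bytes((values[i] << 4) | values[i + 1] for i in range(0, len(values), 2))


def unpack_4bit(packed):
    """Inverse of pack_4bit: split each byte into its two nibbles."""
    out = []
    for byte in packed:
        out.extend((byte >> 4, byte & 0x0F))
    return out


logical = [1, 2, 3, 4, 5, 6, 7, 8]  # logical (unpacked) entries
packed = pack_4bit(logical)

# Comparing a logical size against len(packed) can never match,
# since the packed buffer is half as long.
print(len(logical), len(packed))
```

This is why the commit compares the quant_state shape against the logical out_features sum rather than against the packed buffer's shape.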
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Six review-added tests covering the iter-1 / iter-2 fixes to
extract_gdn_layers, finalize_huggingface_model, and the merger
routing change. All added to the existing test module:

- test_extract_gdn_layers_fused_single_quant_state_dequantizes_and_splits
  -- fused-single-QuantState path dequantizes whole buffer and splits.
- test_extract_gdn_layers_multi_quant_branch_stores_in_proj_z_dequantized
  -- multi-quant path dequantizes z via qs_attr[3].
- test_extract_gdn_layers_multi_quant_branch_stores_in_proj_z_raw_when_no_z_state
  -- multi-quant path falls back to raw z when qs_attr[3] absent.
- test_store_quant_state_logs_warning_when_as_dict_raises
  -- _store_quant_state surfaces as_dict failures via logger.warning
  instead of silent pass.
- test_finalize_lifts_rotary_emb_local_to_fp32_after_dtype_cast
  -- Gemma3 rotary_emb_local buffers stay fp32 after .to(dtype).
- test_finalize_lifts_rotary_pos_emb_to_fp32_after_dtype_cast
  -- Qwen2.5-VL rotary_pos_emb buffers stay fp32 after .to(dtype).

Plus two rename+flip updates of existing tests to track the
merger.linear_fc1/fc2 move from non_layered_components back to
additional_layers:
- test_merger_linear_fc_moved_to_non_layered ->
  test_merger_linear_fc_routed_to_additional_layers.
- test_set_additional_modules_loads_visual_merger_linear_fc ->
  test_set_additional_modules_skips_visual_merger_linear_fc.
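
A test in the spirit of test_store_quant_state_logs_warning_when_as_dict_raises could be sketched as below. The helper body, the state-dict key format, and the logger name are assumptions for illustration; only the behavior (warn instead of silently passing) comes from the commit messages above.

```python
import logging
import unittest

logger = logging.getLogger("unsloth_zoo.sketch")  # hypothetical logger name


def store_quant_state_sketch(state_dict, prefix, quant_state):
    """Hypothetical helper: store a quant state and surface as_dict failures."""
    if quant_state is None:
        return
    state_dict[f"{prefix}.weight.quant_state"] = quant_state
    try:
        packed = quant_state.as_dict(packed=True)
    except Exception as exc:
        # Previously a bare pass; the warning makes version mismatches visible.
        logger.warning("quant_state.as_dict failed for %s: %s", prefix, exc)
        return
    for key, value in packed.items():
        state_dict[f"{prefix}.weight.{key}"] = value


class BrokenQuantState:
    """Simulates an as_dict that rejects the 'packed' keyword."""

    def as_dict(self, packed=True):
        raise TypeError("unexpected keyword argument 'packed'")


class StoreQuantStateTest(unittest.TestCase):
    def test_logs_warning_when_as_dict_raises(self):
        state_dict = {}
        with self.assertLogs("unsloth_zoo.sketch", level="WARNING") as captured:
            store_quant_state_sketch(state_dict, "layer.in_proj_qkv", BrokenQuantState())
        self.assertEqual(len(captured.output), 1)
        self.assertIn("layer.in_proj_qkv.weight.quant_state", state_dict)


if __name__ == "__main__":
    unittest.main()
```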