[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #7

Closed
danielhanchen wants to merge 33 commits into main from pr-588-code

Conversation

@danielhanchen

Owner

Staging mirror of unslothai#588

Original PR: unslothai#588
Author: Datta0

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

  • Fast inference support for Qwen 3.5 with vLLM :)
    I tried to keep the changes minimal: when we detect linear_attn, we hand off to a separate function.

This PR contains code changes only (3 files). Test changes are in a separate PR.

Changed files:

  • unsloth_zoo/empty_model.py
  • unsloth_zoo/hf_utils.py
  • unsloth_zoo/vllm_utils.py

Datta0 and others added 22 commits March 30, 2026 13:41
Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
…e_huggingface_model

- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
  _call_create_lora_manager's signature inspection still sees vllm_config; pass model
  positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
  upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
  qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
  so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
  ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
  of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
  handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
  export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
  rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
  rotary_pos_emb on vision_config availability; mirror language_model detection from
  set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
  the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
  'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
  current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
  (defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
  so non-tensor state-dict values pass through.
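
The first bullet above relies on standard-library behaviour: `functools.wraps` sets `__wrapped__` on the wrapper, and `inspect.signature` follows it, so a `*args, **kwargs` wrapper still reports the original parameters. A minimal sketch of that effect follows. The stand-in `create_lora_manager` and its parameter list are hypothetical, not the actual vLLM signature.

```python
import functools
import inspect

def create_lora_manager(model, vllm_config=None):
    # Stand-in for the function being patched; the parameters are hypothetical.
    return ("manager", model, vllm_config)

def patch_without_wraps(fn):
    def patched(*args, **kwargs):
        return fn(*args, **kwargs)
    return patched

def patch_with_wraps(fn):
    @functools.wraps(fn)  # copies metadata and sets patched.__wrapped__ = fn
    def patched(*args, **kwargs):
        return fn(*args, **kwargs)
    return patched

bare_params = list(inspect.signature(patch_without_wraps(create_lora_manager)).parameters)
wrapped_params = list(inspect.signature(patch_with_wraps(create_lora_manager)).parameters)

print(bare_params)     # ['args', 'kwargs']
print(wrapped_params)  # ['model', 'vllm_config']
```

Any caller that checks for a `vllm_config` parameter by inspecting the signature would take the wrong branch with the bare wrapper, which appears to be the failure the commit message describes.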
- hf_utils.set_dtype_in_config: store string (JSON-safe, keeps string
  comparisons in patch_model_and_tokenizer working); fix fallback
  else-branch that had the HAS_TORCH_DTYPE field selection inverted.
- empty_model.extract_gdn_layers: read bnb_quant_state off the raw
  Params4bit before unwrapping .data; emit weight.quant_state and FP8
  weight_scale(_inv) shards for the in_proj_b / in_proj_a split so
  quantized Qwen3.5 GDN layers round-trip correctly.
- vllm_utils.convert_vllm_to_huggingface: rebuild linear_attn.conv1d
  as a grouped Conv1d with real channels/kernel_size/groups/padding
  instead of treating it as a LayerNorm-style weight swap.
- empty_model.patch_gemma4_vllm_lora_support: soft-import
  vllm.v1.worker.lora_model_runner_mixin so older supported vLLM
  layouts keep working.
- vllm_utils._get_vllm_state_dict: extract Gemma4 per_layer_input_gate
  and per_layer_projection so converted HF models carry the real
  checkpoint weights.
- empty_model.finalize_huggingface_model: restrict dtype propagation
  to the top-level config and its known text/vision/audio subconfigs;
  consolidate the duplicated Gemma4 rotary re-init into one loop while
  keeping the post-.to(dtype) float32 buffer / attention_scaling
  restoration.
- vllm_utils.assert_same_state_dict: _normalize_state_dict_tensor now
  returns None for non-tensor entries (e.g. BnB QuantState dicts) and
  callers skip those; align tied-embedding fallback tolerances with
  the outer comparison (atol=1e-4, rtol=1e-3).
- vllm_utils._test_is_same_vlm: cast only floating-point tensors to
  model.dtype for Gemma3/Gemma4 processors, leaving integer inputs
  like pixel_values untouched.
- vllm_utils._get_vllm_state_dict: collapse the unreachable lm_head
  elif chain; hoist the constant model_type/attention_k_eq_v check
  out of the gemma4_k_eq_v_layers set comprehension.
- empty_model.get_model_layer_config: move model.visual.merger.
  linear_fc1 / linear_fc2 from additional_layers (which expected a
  {kk} placeholder) into non_layered_components.
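
The `set_dtype_in_config` bullets above (prefer `dtype`, fall back to `torch_dtype`, store a string) can be sketched as below. This is an illustration under assumptions, not the code in `hf_utils.py`: the function name, the use of `hasattr` to pick the field, and stripping a `torch.` prefix are all hypothetical choices, and a plain `SimpleNamespace` stands in for a real config object.

```python
import json
from types import SimpleNamespace

def set_dtype_in_config_sketch(config, dtype):
    """Write dtype as a plain string into 'dtype' if present, else 'torch_dtype'."""
    dtype_str = str(dtype)
    if dtype_str.startswith("torch."):
        # Assumption: keep only the short name, e.g. "bfloat16".
        dtype_str = dtype_str[len("torch."):]
    # Prefer the modern field; only touch the legacy one when 'dtype' is absent.
    field = "dtype" if hasattr(config, "dtype") else "torch_dtype"
    setattr(config, field, dtype_str)
    return field

modern = SimpleNamespace(dtype=None)
legacy = SimpleNamespace(torch_dtype=None)

modern_field = set_dtype_in_config_sketch(modern, "torch.bfloat16")
legacy_field = set_dtype_in_config_sketch(legacy, "float16")

# A string value serializes without a custom encoder.
serialized = json.dumps(vars(modern))
```

Storing a string rather than a dtype object is what keeps the value JSON-safe and comparable with `==` against other strings, as the commit message notes.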
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/hf_utils.py
#	unsloth_zoo/vllm_utils.py
@danielhanchen

Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for Gemma4 and Qwen 3.5 (Gated Delta Net) architectures, implementing specialized vLLM patches for LoRA and quantization, and a new finalize_huggingface_model function for robust model initialization. Feedback identifies a potential correctness issue in extract_gdn_layers where improper quantization state splitting for fused layers could cause precision loss, and recommends using the _unwrap helper for consistent tensor data access.

Comment thread unsloth_zoo/empty_model.py Outdated
Comment on lines +1128 to +1131
qs_attr = getattr(raw_weight, "bnb_quant_state", getattr(weight, "bnb_quant_state", None))
if isinstance(qs_attr, dict):
    _store_quant_state(f"{prefix}.in_proj_qkv", qs_attr.get(0))
    _store_quant_state(f"{prefix}.in_proj_z", qs_attr.get(3))


high

There is a potential correctness issue here for BitsAndBytes (BnB) quantized models. When splitting in_proj_qkvz into in_proj_qkv and in_proj_z, you are assigning only the quantization state of the first shard (qs_attr.get(0), typically the Query shard) to the entire in_proj_qkv layer.

In BnB, each shard (Q, K, V) often has its own QuantState with different absolute maximums. By using only the Q shard's state for the combined QKV layer in the Hugging Face model, the K and V weights will be dequantized using incorrect scales, leading to significant precision loss or garbage outputs. Since in_proj_qkv is defined as a single fused layer in the Qwen 3.5 HF architecture, you may need to either merge the quantization states (which is non-trivial for BnB) or ensure the HF model is patched to use separate layers for Q, K, and V.
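
To make the concern concrete, here is a toy illustration of what goes wrong when one shard's scale is reused for another. It uses a simple symmetric absmax quantizer in pure Python, which is an assumption for illustration only. It is not BnB's blockwise NF4 scheme, and the shard values are made up.

```python
def quantize_absmax(values, levels=7):
    """Toy symmetric absmax quantizer: returns integer codes and the scale."""
    absmax = max(abs(v) for v in values)
    codes = [round(v / absmax * levels) for v in values]
    return codes, absmax

def dequantize_absmax(codes, absmax, levels=7):
    return [c / levels * absmax for c in codes]

def max_abs_error(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

q_shard = [0.10, -0.05, 0.08, -0.02]   # small-magnitude shard (made-up values)
k_shard = [2.00, -1.50, 0.75, -0.25]   # larger-magnitude shard (made-up values)

q_codes, q_absmax = quantize_absmax(q_shard)
k_codes, k_absmax = quantize_absmax(k_shard)

k_right = dequantize_absmax(k_codes, k_absmax)   # K decoded with its own state
k_wrong = dequantize_absmax(k_codes, q_absmax)   # K decoded with Q's state

err_right = max_abs_error(k_shard, k_right)
err_wrong = max_abs_error(k_shard, k_wrong)
print(err_right, err_wrong)
```

With its own scale, K is recovered up to rounding error. With Q's scale, every K value is shrunk by the ratio of the two absmax values, which is the kind of mismatch the comment warns about.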

Comment on lines +1186 to +1188
store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data)
store(f"{prefix}.dt_bias", gdn.dt_bias.data)
store(f"{prefix}.A_log", gdn.A_log.data)


low

Consider using the _unwrap(v) helper defined at line 1082 instead of accessing .data directly. This ensures compatibility with vLLM's ModelWeightParameter and maintains consistency with the rest of the extraction logic in this function.

Suggested change
- store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data)
- store(f"{prefix}.dt_bias", gdn.dt_bias.data)
- store(f"{prefix}.A_log", gdn.A_log.data)
+ store(f"{prefix}.conv1d.weight", _unwrap(gdn.conv1d.weight))
+ store(f"{prefix}.dt_bias", _unwrap(gdn.dt_bias))
+ store(f"{prefix}.A_log", _unwrap(gdn.A_log))

Apply 16 accepted review fixes across two files:

- set_additional_modules now honors non_layered_components explicitly so
  Qwen3-VL merger.linear_fc1/2 are restored instead of dropped by the
  generic "linear" substring filter.
- _get_vllm_state_dict moves layernorm extraction (and layer_scalar
  capture) above the no-mlp early-continue so layers without an mlp
  attribute still get their input/post layernorms exported.
- extract_gdn_layers dequantizes per-shard BnB QuantStates before
  concatenating into the fused in_proj_qkv weight, avoiding K/V being
  dequantized with Q's scales. The in_proj_ba single-shard merged-layer
  case now dequantizes and splits instead of silently dropping
  in_proj_a quant_state.
- Gemma4 top-level per-layer-input modules (embed_tokens_per_layer,
  per_layer_model_projection, per_layer_projection_norm) are added to
  non_layered_components and extracted from the vLLM text model.
- patch_gemma4_vllm_lora_support now also patches Gemma4ForCausalLM
  (when available) and guards class-level supports_lora /
  embedding_modules writes behind an idempotency flag.
- finalize_huggingface_model reapplies dtype to the live config tree
  after copy_attributes, switches vision-rotary detection from class
  equality to identity-based id() membership, and keeps inv_freq
  buffers at float32 for all archs (matching transformers default).
- convert_vllm_to_huggingface preserves buffer registration for
  layer_scalar-style entries instead of unconditionally wrapping them
  in nn.Parameter.
- assert_same_state_dict only relaxes tolerances on the dtype-mismatch
  / FP8 upcast branch; same-dtype comparisons keep torch defaults.
- Conv1d rebuild branch is qualified with linear_attn substring so it
  won't silently rebuild future non-GDN conv1d layers as depthwise.
- _test_is_same_vlm falls back to a synthetic PIL image when the
  remote sloth URL load_image fails, so the test runs offline.
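
The tolerance bullet above can be illustrated with the elementwise criterion that `torch.allclose` documents, `|a - b| <= atol + rtol * |b|`. The sketch below is pure Python and assumes that criterion. Whether `assert_same_state_dict` uses `torch.allclose` or another comparison is not stated in this thread. The tight values are `torch.allclose`'s defaults, and the loose values are the ones quoted earlier in this PR (atol=1e-4, rtol=1e-3).

```python
def is_close(a, b, rtol, atol):
    """Elementwise criterion documented for torch.allclose."""
    return abs(a - b) <= atol + rtol * abs(b)

reference = 1.0
upcast = 1.0005   # hypothetical value after a lower-precision round trip

tight = is_close(upcast, reference, rtol=1e-5, atol=1e-8)   # torch.allclose defaults
loose = is_close(upcast, reference, rtol=1e-3, atol=1e-4)   # relaxed tolerances

print(tight, loose)  # False True
```

Keeping the loose tolerances confined to the dtype-mismatch branch means a same-dtype comparison with a drift of this size would still fail, which is the point of the change.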
Append 9 regression tests to tests/test_vllm_to_hf_conversion.py covering
the fixes applied during review:

- set_additional_modules now restores visual merger linear_fc1/2.
- _get_vllm_state_dict extracts layernorms even when a decoder layer
  lacks an mlp attribute.
- finalize_huggingface_model propagates dtype to live config tree after
  copy_attributes replaces the config object.
- finalize_huggingface_model uses identity-based vision rotary detection
  so text rotary is not misclassified when text and vision configs
  share a Python class.
- convert_vllm_to_huggingface preserves buffer registration for
  layer_scalar-style entries instead of converting them to nn.Parameter.
- assert_same_state_dict uses tight torch defaults for same-dtype
  comparisons; loose tolerance only applies on the FP8/dtype-mismatch
  upcast branch.
- Conv1d rebuild branch is qualified with linear_attn substring.
- patch_gemma4_vllm_lora_support now covers both
  Gemma4ForConditionalGeneration and Gemma4ForCausalLM.
- get_model_layer_config includes Gemma4 top-level per-layer-input
  modules in non_layered_components.

Also corrects the rotary inv_freq dtype assertion in
test_finalize_non_gemma4_rotary_buffers_follow_model_dtype to match the
new always-float32 behavior of finalize_huggingface_model.
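
The identity-based detection covered by these tests can be sketched in a few lines. The class and variable names below are hypothetical stand-ins, and the sketch only shows why a class check cannot tell two instances of the same class apart while an `id()` membership check can.

```python
class RotaryEmbedding:
    """Stand-in for a rotary class shared by the text and vision towers."""
    def __init__(self, name):
        self.name = name

text_rotary = RotaryEmbedding("text")
vision_rotary = RotaryEmbedding("vision")

# Collected once while walking the vision tower (assumed bookkeeping).
vision_ids = {id(vision_rotary)}

def is_vision_by_class(module):
    return type(module) is type(vision_rotary)

def is_vision_by_identity(module):
    return id(module) in vision_ids

print(is_vision_by_class(text_rotary))     # True: misclassified
print(is_vision_by_identity(text_rotary))  # False
```

An `id()` check is only valid while the recorded objects stay alive and in place. If a later step replaces them with copies, the recorded ids no longer match.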
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/vllm_utils.py
- finalize_huggingface_model: guard Gemma4 multimodal rotary rebuild with
  try/except and broaden vision-rotary detection by module path, so a
  copy-attributes id() drift no longer reroutes a vision rotary through
  the text_config (which lacks the vision rope_parameters shape and
  crashed with KeyError 'rope_type' / NoneType ** Tensor).
- finalize_huggingface_model: lift float rotary buffers to float32 on
  all non-quantized models (not just Gemma4) after new_model.to(dtype),
  fixing an inv_freq / original_inv_freq downcast regression for e.g.
  Qwen3.5. Drops the redundant is_gemma4 fresh-rotary clone used only
  to re-copy attention_scaling (a Python float unaffected by .to).
- finalize_huggingface_model: hoist deepcopy(text_config) out of the
  rotary_emb_local loop so multi-layer Gemma3/4 models don't deepcopy
  the text config once per decoder layer.
- extract_gdn_layers: when dequantizing the fused in_proj_ba BnB shard,
  compute the b/a split midpoint on the dequantized tensor rather than
  the packed uint8 Params4bit buffer whose shape[0] is numel/2.
- _get_vllm_state_dict: match lm_head by exact name or .lm_head suffix
  instead of substring so unrelated submodule names containing
  'lm_head' cannot shadow the real head.
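
The `lm_head` matching rule in the last bullet can be sketched as follows. The helper name and the example module names are hypothetical, and the sketch only contrasts substring matching with exact-name-or-suffix matching.

```python
def is_lm_head(name):
    """Match the head by exact name or by a dotted '.lm_head' suffix."""
    return name == "lm_head" or name.endswith(".lm_head")

# Hypothetical module names, two of which merely contain 'lm_head'.
names = [
    "lm_head",
    "model.language_model.lm_head",
    "model.aux_lm_head_proj",
    "model.lm_head_norm",
]

substring_hits = [n for n in names if "lm_head" in n]
exact_hits = [n for n in names if is_lm_head(n)]

print(len(substring_hits))  # 4
print(exact_hits)           # ['lm_head', 'model.language_model.lm_head']
```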
Trim WHAT-restatement comments and collapse a multi-line rationale
to one line stating the load-bearing fact. No behavioural change.
…vation, GDN dequantize midpoint, and lm_head exact match
# Conflicts:
#	unsloth_zoo/empty_model.py
#	unsloth_zoo/vllm_utils.py
@danielhanchen

Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for Gemma4 and Qwen 3.5 (GDN) models, including vLLM patching for LoRA and BitsAndBytes, and expanded layer extraction templates. It introduces a centralized model finalization process and improves tensor unwrapping and state dict normalization. Feedback highlights a restrictive regex for layer name conversion that may fail for keys ending in digits and suggests avoiding hardcoded padding in GDN Conv1d layers to ensure consistency with original model configurations.

Comment thread unsloth_zoo/vllm_utils.py
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1))
layer = torch.nn.Parameter(weight, requires_grad = False)
exec(f"new_model.{layer_name_br} = layer")
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)


high

The regular expression used here \.([\d]{1,})\. is too restrictive because it requires a trailing dot to match a digit component. This will fail to convert dots to brackets for keys that end in a digit (e.g., model.visual.merger.mlp.0), which are present in the new get_model_layer_config templates. This will result in invalid Python syntax in the subsequent exec call (e.g., new_model.model.visual.merger.mlp.0 = layer). It is recommended to use the more robust regex found later in the same file at line 1517.

Suggested change
- layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)
+ layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name)
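
The difference between the two patterns is easy to check in isolation. The snippet below runs both substitutions from this comment on a mid-path index and on the trailing-digit key mentioned in the review. The helper names are added for illustration.

```python
import re

OLD = r"\.([\d]{1,})\."   # requires a dot after the digits
NEW = r"\.([\d]{1,})"     # digits may end the string

def to_brackets_old(name):
    return re.sub(OLD, r"[\1].", name)

def to_brackets_new(name):
    return re.sub(NEW, lambda m: f"[{m.group(1)}]", name)

mid = "model.layers.3.mlp.weight"
tail = "model.visual.merger.mlp.0"

print(to_brackets_old(mid))   # model.layers[3].mlp.weight
print(to_brackets_new(mid))   # model.layers[3].mlp.weight
print(to_brackets_old(tail))  # model.visual.merger.mlp.0  (unchanged)
print(to_brackets_new(tail))  # model.visual.merger.mlp[0]
```

Only the bracketed form of the trailing-digit key is a valid attribute path when it is interpolated into the later `exec` assignment.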

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +1480 to +1495
# Qwen3.5 GDN depthwise Conv1d: rebuild with real channels/kernel/groups.
from torch.nn import Conv1d
conv_weight = _unwrap_tensor(weight)
channels = conv_weight.shape[0]
kernel_size = conv_weight.shape[-1]
layer = Conv1d(
    in_channels = channels,
    out_channels = channels,
    kernel_size = kernel_size,
    groups = channels,
    padding = kernel_size - 1,
    bias = has_bias,
    device = get_target_device(),
)
layer.weight = torch.nn.Parameter(conv_weight, requires_grad = False)
layer.bias = bias


medium

The reconstruction of the Conv1d layer for Qwen 3.5 GDN hardcodes padding = kernel_size - 1. While this is a common pattern for implementing causal convolutions in transformers (where the output is subsequently sliced), it assumes that the model's forward method handles the extra padding correctly. If the original model used a different padding scheme or manual padding in its forward pass, this could lead to shape mismatches or incorrect causal behavior. It would be safer to extract the padding value from the original model's configuration if available.
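
The pattern the comment refers to (symmetric padding of `kernel_size - 1`, then slicing the output back to the input length) can be checked with a single-channel toy in pure Python. This is a sketch of the arithmetic only, assuming the forward pass does the slicing. It is not the model's forward method, and the input and weight values are made up.

```python
def conv1d_padded(x, w, padding):
    """1-D cross-correlation with symmetric zero padding, like Conv1d's padding argument."""
    k = len(w)
    xp = [0.0] * padding + list(x) + [0.0] * padding
    return [sum(w[j] * xp[i + j] for j in range(k)) for i in range(len(xp) - k + 1)]

x = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [0.5, 0.25, 0.25]
k = len(w)

full = conv1d_padded(x, w, padding=k - 1)   # length is len(x) + k - 1
causal = full[:len(x)]                      # slice back to the input length

# Changing only the last input must not change any earlier output.
x_future = x[:-1] + [100.0]
causal_future = conv1d_padded(x_future, w, padding=k - 1)[:len(x)]
```

After slicing, output position `i` depends only on inputs `i - (k - 1)` through `i`, so the layer is causal. If the forward pass did not slice, the extra `k - 1` trailing outputs would produce the shape mismatch the comment mentions.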

danielhanchen added a commit that referenced this pull request Apr 20, 2026
- vllm_utils.py: convert_vllm_to_huggingface parameter-list path regex now
  handles trailing-digit segments (e.g. `embed_tokens_per_layer.0`) via the
  same anchor-or-end pattern used below for `exec` assignment.
- empty_model.py: finalize_huggingface_model rotary reinit no longer
  swallows failures silently; the float32 buffer lift is skipped when the
  reinit raised, and the exception is logged so wrong-shape rotary state
  does not propagate.
- empty_model.py: patch_gemma4_vllm_lora_support guards both the
  `gemma4_mm` and `gemma4` imports, does not clobber a pre-existing
  `embedding_modules` registry, and delegates to the original
  `create_lora_manager` so vLLM shim kwargs (e.g. `vllm_config`) reach the
  manager constructor correctly.
- empty_model.py: patch_gemma4_vllm_k_eq_v_support uses getattr for
  `_stack_quantization_states` so absent / renamed private attrs do not
  crash, and adds `model.language_model` to the k/v name-prefix search so
  HF-style Gemma4 multimodal parameters match.
- empty_model.py + vllm_utils.py: both Gemma4 gates share a new
  `_is_gemma4_config` helper that accepts `gemma4` and the text-only
  `gemma4_text` model_type.
- vllm_utils.py: Gemma4 patch block runs after the BnB autodetect so that
  `-bnb-4bit` / `quant_method == bitsandbytes` models trigger the
  k_eq_v patch even when the caller did not pass use_bitsandbytes=True.
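
A shared gate like the `_is_gemma4_config` helper described above might look like the following. The body is a guess for illustration: only the helper's name and the two accepted `model_type` values come from the commit message, and `SimpleNamespace` stands in for a real config object.

```python
from types import SimpleNamespace

GEMMA4_MODEL_TYPES = ("gemma4", "gemma4_text")

def is_gemma4_config_sketch(config):
    """Return True when config.model_type names a Gemma4 variant."""
    return getattr(config, "model_type", None) in GEMMA4_MODEL_TYPES

multimodal = SimpleNamespace(model_type="gemma4")
text_only = SimpleNamespace(model_type="gemma4_text")
other = SimpleNamespace(model_type="gemma3")

print(is_gemma4_config_sketch(multimodal))  # True
print(is_gemma4_config_sketch(text_only))   # True
print(is_gemma4_config_sketch(other))       # False
```

Routing both gates through one helper keeps the two files from drifting apart on which model types count as Gemma4.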
danielhanchen added a commit that referenced this pull request Apr 20, 2026
Adds 12 regression tests covering the iter-1 hardening (trailing-digit
regex path, rotary reinit success guard, _is_gemma4_config helper,
gemma4 gate migration, gemma4_mm import guard, private loader attr
guard, HF-style k_eq_v prefix, lora manager delegation, behavioral
no-op tests that stub missing vLLM modules). Updates
test_gemma4_lora_patch_preserves_signature_for_inspect and
test_gemma4_k_eq_v_set_hoists_constant_check to match the new source
shape.
@danielhanchen

Owner Author

Fixes pushed to unslothai#588.

danielhanchen added a commit that referenced this pull request Apr 20, 2026
Review iter-2 flagged that my iter-1 refinement in 7ec99b6 regressed
in_proj_z handling on the multi-quant path from PR #7's GDN extractor
(commits ca60088 / e6ebab4 / 3575a41). This commit refines 7ec99b6 and
preserves all three prior intents, rather than deleting them.

Preserved from PR #7 (ca60088 / e6ebab4 / 3575a41):
- `sum(qs is not None for qs in qkv_states) > 1` per-shard dequant
  branch is kept as an `elif`, not dropped.
- `dequantize_4bit` import and RuntimeError behaviour are kept.
- `_store_quant_state` emission for the packed `qs0 covers qkv + qs3
  covers z` BnB convention is preserved in the final else branch.

Preserved from my iter-1 fix (7ec99b6):
- Fused-single-QuantState shape guard that triggers a full dequantize +
  split. Moved out of the else branch into its own leading branch so
  the subsequent per-shard dequant path can also emit the z shard.

New in this commit:
- Multi-quant (`sum > 1`) branch now also stores `in_proj_z.weight`
  (dequant via `qs_attr.get(3)` if present, else raw `z_weight`).
  Without this, the HF reconstruction saw no z key and rebuilt z from
  the random placeholder in `create_empty_causal_lm`. Bug introduced
  by 7ec99b6.
- Fused-full detection now compares `qs_shape[0]` against `offsets[4]`
  (logical unpacked out_features summed from `proj.output_sizes`)
  instead of `weight.shape[0]`. The BnB Params4bit buffer is stored
  packed as uint8 with half the rows, so the prior comparison never
  matched real BnB layouts and the fused-full guard silently fell
  through to the per-shard branch.
- Final else branch now re-emits `_store_quant_state(in_proj_qkv,
  qkv_states[0])` alongside the z quant_state, restoring the
  emission that 7ec99b6 inadvertently dropped when it collapsed the
  old nested `if/else` block.
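
The packed-versus-logical shape mismatch behind the fused-full fix can be shown with a toy 4-bit packer. This is a pure-Python sketch under assumptions: the nibble order and the shapes are made up, and it does not reproduce BnB's actual storage layout. It only shows that sizes read off a packed buffer are not the logical weight's sizes.

```python
def pack_4bit(codes):
    """Pack pairs of 4-bit codes (0..15) into single bytes, high nibble first."""
    assert len(codes) % 2 == 0
    return bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))

def unpack_4bit(packed):
    out = []
    for byte in packed:
        out.extend((byte >> 4, byte & 0x0F))
    return out

out_features, in_features = 8, 4   # logical (unpacked) weight shape, made up
codes = [i % 16 for i in range(out_features * in_features)]
packed = pack_4bit(codes)

packed_len = len(packed)                    # numel / 2 = 16 bytes
midpoint_from_packed = packed_len // 2      # 8: not a row index of the logical weight
midpoint_from_logical = out_features // 2   # 4: splits the 8 logical rows in half
```

Comparing a logical size (such as a QuantState shape or a sum of `output_sizes`) against a size read from the packed buffer will never match, which is why the guard described above was changed to compare logical sizes on both sides.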