[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #5

Open
danielhanchen wants to merge 4 commits into main from pr-588-head

@danielhanchen

Owner

Staging mirror of unslothai#588

Original PR: unslothai#588
Author: Datta0

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

  • Fast inference support for Qwen 3.5 with vLLM :)
    Tried to keep the changes minimal: when we detect linear_attn, we hand it off to a separate function.

@danielhanchen

Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) architectures, including vLLM monkey-patches for LoRA and k_eq_v logic, enhanced model conversion utilities for vision and linear attention layers, and more robust dtype configuration. Review feedback identified the need to ensure complete state dictionary extraction for Gemma 4 k_eq_v layers by duplicating the K shard for V, refining the lm_head module search criteria, and adopting more robust regex patterns for converting layer names to bracket notation.

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +1116 to +1117
if kk not in gemma4_k_eq_v_layers:
    get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)

high

For Gemma 4 layers where k_eq_v is active, skipping the extraction of v_proj will result in an incomplete state dictionary for the Hugging Face model. Even if vLLM optimizes storage by sharing weights, the Hugging Face architecture still expects a v_proj module. You should extract the K shard again and assign it to the v_proj key to ensure the converted model is functional.

Suggested change
- if kk not in gemma4_k_eq_v_layers:
-     get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)
+ if kk not in gemma4_k_eq_v_layers:
+     get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj)
+ else:
+     # For k_eq_v layers, V is identical to K. Extract the K shard again as V.
+     get_state_dict(f"{prefix}.v_proj", 1, state_dict, qkv_proj)
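As a rough illustration of the suggestion above, the sketch below splits a toy fused projection into q/k/v shards and fills `v_proj` from the K shard when `k_eq_v` is set. The helper `split_fused_qkv`, its parameters, and the assumption that a k_eq_v buffer stores only `[q | k]` are all hypothetical; the real `get_state_dict` signature and vLLM layout may differ.

```python
def split_fused_qkv(fused_rows, q_size, kv_size, k_eq_v=False):
    """Split fused rows into per-projection shards.

    Assumption (not confirmed by the PR): when k_eq_v is True the fused
    buffer stores only [q | k], so v_proj is filled from the K shard.
    """
    q = fused_rows[:q_size]
    k = fused_rows[q_size:q_size + kv_size]
    if k_eq_v:
        # Reuse the K rows so the converted state dict still has a v_proj key.
        v = list(k)
    else:
        v = fused_rows[q_size + kv_size:q_size + 2 * kv_size]
    return {"q_proj": q, "k_proj": k, "v_proj": v}


# Toy "weights": each row is a labelled one-element list.
shared = split_fused_qkv([["q0"], ["q1"], ["k0"]], q_size=2, kv_size=1, k_eq_v=True)
separate = split_fused_qkv([["q0"], ["q1"], ["k0"], ["v0"]], q_size=2, kv_size=1)
```

In both cases the result has a `v_proj` entry, which is the property the review comment asks for.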

Comment thread unsloth_zoo/vllm_utils.py Outdated
lm_layer = [mod for name,mod in vllm_internals.named_modules() if "lm_head" in name]
# Use get_state_dict for consistent extraction and automatic truncation
get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False)
lm_layer = next((mod for name, mod in vllm_internals.named_modules() if "lm_head" in name), None)

low

The search for lm_head using if "lm_head" in name is somewhat loose and might match unintended modules (e.g., wrappers or custom components containing that string). A more precise check like name.split('.')[-1] == "lm_head" or name.endswith(".lm_head") is recommended.

Suggested change
- lm_layer = next((mod for name, mod in vllm_internals.named_modules() if "lm_head" in name), None)
+ lm_layer = next((mod for name, mod in vllm_internals.named_modules() if name.split('.')[-1] == "lm_head"), None)

There's no reason why it would match something unintentionally. There aren't multiple lm_heads.
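For reference, the two matching rules discussed in this thread only differ if some module name merely contains `lm_head`. A minimal sketch with a made-up name list (`model.lm_head_adapter.proj` is invented for illustration and is not claimed to exist in any real model):

```python
# Hypothetical module names; "model.lm_head_adapter.proj" is invented.
names = [
    "model.embed_tokens",
    "model.layers.0.mlp.down_proj",
    "model.lm_head_adapter.proj",
    "lm_head",
]

# Substring match, as in the current code.
substring_hits = [n for n in names if "lm_head" in n]

# Last-component match, as in the suggested change.
exact_hits = [n for n in names if n.split(".")[-1] == "lm_head"]

# next(...) returns whichever match comes first in traversal order.
first_substring_hit = next((n for n in names if "lm_head" in n), None)
```

If every supported model exposes exactly one module whose name contains `lm_head`, both rules pick the same module.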

Comment thread unsloth_zoo/vllm_utils.py Outdated
# for attributes of type nn.Parameter, there's no .weight
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1))
layer = torch.nn.Parameter(weight, requires_grad = False)
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)

low

The regex r"\.([\d]{1,})\." only matches digits surrounded by dots, so an index at the end of the path (e.g., model.layers.0) is left unconverted. A more robust regex such as r"\.([\d]+)(?=\.|$)", or the one used later in this function (r"\.([\d]{1,})"), ensures all indexed components are converted to [] syntax.

Suggested change
- layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)
+ layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name)

This is a false positive. We always have digits wrapped by dots.
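To make the disagreement concrete, here is a small comparison of the two patterns from this thread. The wrapper names `to_brackets_strict` and `to_brackets_loose` are hypothetical; the regexes themselves are copied from the diff and the suggestion.

```python
import re


def to_brackets_strict(name):
    # Pattern currently in the diff: the index must be followed by a dot.
    return re.sub(r"\.([\d]{1,})\.", r"[\1].", name)


def to_brackets_loose(name):
    # Pattern from the suggested change: a trailing dot is not required.
    return re.sub(r"\.([\d]{1,})", lambda m: f"[{m.group(1)}]", name)


mid_path = "model.layers.0.self_attn.q_proj"
end_of_path = "model.layers.0"
```

Both patterns agree whenever the index sits in the middle of the path, which is the case the author says always holds here; they differ only for a name that ends in an index.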

Comment thread unsloth_zoo/vllm_utils.py
# LayerNorms (including vision norms)
weight_param = torch.nn.Parameter(weight, requires_grad=False)
weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False)
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)

low

The regex used here is inconsistent with the more robust version used at the end of the loop. It should be updated to handle digits at the end of the path string.

Suggested change
- layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)
+ layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name)

danielhanchen added a commit that referenced this pull request Apr 19, 2026
- patched_create_lora_manager: pass model positionally and preserve __signature__ so unsloth_zoo.vllm_lora_worker_manager._call_create_lora_manager still dispatches vllm_config correctly on vLLM versions that require it.
- patch_gemma4_vllm_lora_support: add embedding_padding_modules class attribute to avoid AttributeError in the vLLM LoRA runner; wrap vLLM imports in try/except so text-only or non-vLLM environments do not crash on import.
- patch_gemma4_vllm_k_eq_v_support: tolerate older vLLM releases without BitsAndBytesModelLoader._stack_quantization_states.
- load_vllm: gate Gemma4 LoRA patch on is_vision_model and enable_lora, and gate the k_eq_v patch on use_bitsandbytes so non-LoRA / non-BnB Gemma4 loads do not force optional vLLM internals.
- extract_gdn_layers: dequantize fused in_proj_qkvz when BnB quant state is attached to the weight so 4-bit Qwen3.5 GDN does not fall through to dense Linear; pick FP8 scale suffix from whichever of weight_scale / weight_scale_inv the source exposes so downstream FP8 detection sees the right key.
- finalize_huggingface_model: apply layer_idx fix to both model.model.language_model.layers (VLM path) and model.model.layers (text-only Qwen3.5) so GDN submodules have correct per-layer index.
- finalize_huggingface_model: keep RoPE inv_freq / original_inv_freq in float32 for all models after fresh rotary_emb re-init, and run the Gemma4 rotary finalization block for quantized Gemma4 too.
- get_model_layer_config: add Gemma4 per_layer_input_gate, per_layer_projection, and post_per_layer_input_norm entries so the reconstructed HF model does not retain 1-wide dummies.
- _get_vllm_state_dict: extract per_layer_input_gate and per_layer_projection directly from vLLM layers.
- convert_vllm_to_huggingface: preserve buffer semantics (not nn.Parameter) for attributes restored through the layer_name-in-quant-state-dict path, so Gemma4 layer_scalar stays a buffer.
@danielhanchen

Owner Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) models, including vLLM integration patches and expanded layer configuration patterns. Key changes include the addition of finalize_huggingface_model for post-processing model configurations, robust dtype setting in hf_utils.py, and specialized layer extraction for Gated Delta Net (GDN) architectures. Feedback focuses on optimizing module traversal in finalize_huggingface_model to reduce redundant passes, ensuring correct torch.dtype resolution when string prefixes are present, and refining regex patterns for more robust layer name mapping.

Comment thread unsloth_zoo/empty_model.py Outdated
Comment on lines +753 to +786
if getattr(config, "model_type", None) == "gemma4":
    for module in new_model.modules():
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is None:
            continue
        fresh_rotary_emb = rotary_emb.__class__(
            config = rotary_emb.config,
            device = target_device,
        )
        for attr_name in ("max_seq_len_cached", "original_max_seq_len"):
            if hasattr(fresh_rotary_emb, attr_name):
                setattr(rotary_emb, attr_name, getattr(fresh_rotary_emb, attr_name))
        for attr_name, attr_value in fresh_rotary_emb.__dict__.items():
            if attr_name == "attention_scaling" or attr_name.endswith("_attention_scaling"):
                setattr(rotary_emb, attr_name, attr_value)
        for buffer_name, buffer in fresh_rotary_emb._buffers.items():
            if torch.is_tensor(buffer) and buffer.is_floating_point():
                rotary_emb._buffers[buffer_name] = buffer.to(
                    device = target_device,
                    dtype = torch.float32,
                )
else:
    for module in new_model.modules():
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is None:
            continue
        for buffer_name in ("inv_freq", "original_inv_freq"):
            buffer = getattr(rotary_emb, buffer_name, None)
            if torch.is_tensor(buffer) and buffer.is_floating_point() and buffer.dtype != torch.float32:
                rotary_emb._buffers[buffer_name] = buffer.to(
                    device = target_device,
                    dtype = torch.float32,
                )
return new_model

medium

The finalize_huggingface_model function performs multiple redundant passes over all model modules. The logic for Gemma 4 rotary patching (lines 753-773) and the precision fix for other models (lines 775-785) can be integrated into the initial module loop (starting at line 713). This would improve performance by reducing the number of full model traversals and would also resolve redundant operations for Gemma 4 text modules, which are currently handled in both loops.

This is intentional and was always the case to find rope modules. They might be at varying depths

Comment thread unsloth_zoo/hf_utils.py Outdated
setattr(config, "dtype", string_dtype)
except:
set_dtype_in_config_fallback(config, string_dtype)
runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype

medium

The runtime_dtype resolution does not account for strings with a "torch." prefix. If dtype is "torch.float16", getattr(torch, dtype, dtype) will return the string itself, which is then set in the config. This can break downstream code expecting a torch.dtype object. Stripping the prefix ensures correct resolution.

Suggested change
- runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype
+ runtime_dtype = getattr(torch, dtype.replace("torch.", ""), dtype) if isinstance(dtype, str) else dtype
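The prefix issue described above can be reproduced without PyTorch. In this sketch `fake_torch` is a stand-in namespace with placeholder values (an assumption made so the example stays dependency-free), and `resolve_naive` / `resolve_stripped` are hypothetical names for the two variants.

```python
from types import SimpleNamespace

# Stand-in for the torch module so the sketch runs without PyTorch.
# The attribute values are placeholders, not real torch.dtype objects.
fake_torch = SimpleNamespace(float16="DTYPE_FLOAT16", bfloat16="DTYPE_BFLOAT16")


def resolve_naive(dtype, ns=fake_torch):
    # Mirrors the original line: look the string up directly on the module.
    return getattr(ns, dtype, dtype) if isinstance(dtype, str) else dtype


def resolve_stripped(dtype, ns=fake_torch):
    # Mirrors the suggestion: drop a leading "torch." before the lookup.
    if not isinstance(dtype, str):
        return dtype
    name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
    return getattr(ns, name, dtype)
```

With a dotted name the naive lookup falls back to the default and hands the string back unchanged, which is the behavior the comment warns about.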

Comment thread unsloth_zoo/vllm_utils.py Outdated
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1))
layer = torch.nn.Parameter(weight, requires_grad = False)
exec(f"new_model.{layer_name_br} = layer")
layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)

low

The regex r"\.([\d]{1,})\." is less robust than the one used at line 1480 because it requires a trailing dot. Using a more flexible regex that handles indices at the end of a path ensures consistent transformation of dot-notation indices into bracket notation across all keys.

Suggested change
- layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name)
+ layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name)

danielhanchen added a commit that referenced this pull request Apr 19, 2026
- hf_utils.set_dtype_in_config: store string (JSON-safe, keeps string
  comparisons in patch_model_and_tokenizer working); fix fallback
  else-branch that had the HAS_TORCH_DTYPE field selection inverted.
- empty_model.extract_gdn_layers: read bnb_quant_state off the raw
  Params4bit before unwrapping .data; emit weight.quant_state and FP8
  weight_scale(_inv) shards for the in_proj_b / in_proj_a split so
  quantized Qwen3.5 GDN layers round-trip correctly.
- vllm_utils.convert_vllm_to_huggingface: rebuild linear_attn.conv1d
  as a grouped Conv1d with real channels/kernel_size/groups/padding
  instead of treating it as a LayerNorm-style weight swap.
- empty_model.patch_gemma4_vllm_lora_support: soft-import
  vllm.v1.worker.lora_model_runner_mixin so older supported vLLM
  layouts keep working.
- vllm_utils._get_vllm_state_dict: extract Gemma4 per_layer_input_gate
  and per_layer_projection so converted HF models carry the real
  checkpoint weights.
- empty_model.finalize_huggingface_model: restrict dtype propagation
  to the top-level config and its known text/vision/audio subconfigs;
  consolidate the duplicated Gemma4 rotary re-init into one loop while
  keeping the post-.to(dtype) float32 buffer / attention_scaling
  restoration.
- vllm_utils.assert_same_state_dict: _normalize_state_dict_tensor now
  returns None for non-tensor entries (e.g. BnB QuantState dicts) and
  callers skip those; align tied-embedding fallback tolerances with
  the outer comparison (atol=1e-4, rtol=1e-3).
- vllm_utils._test_is_same_vlm: cast only floating-point tensors to
  model.dtype for Gemma3/Gemma4 processors, leaving integer inputs
  like pixel_values untouched.
- vllm_utils._get_vllm_state_dict: collapse the unreachable lm_head
  elif chain; hoist the constant model_type/attention_k_eq_v check
  out of the gemma4_k_eq_v_layers set comprehension.
- empty_model.get_model_layer_config: move model.visual.merger.
  linear_fc1 / linear_fc2 from additional_layers (which expected a
  {kk} placeholder) into non_layered_components.
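The last item in the commit message above turns on the difference between per-layer templates containing a `{kk}` placeholder and literal module names. A minimal sketch of that distinction, assuming the placeholder is expanded once per layer index; `expand_layer_names` and the `q_proj` template are hypothetical, not the actual `get_model_layer_config` logic:

```python
def expand_layer_names(templates, num_layers):
    """Expand '{kk}' templates once per layer; literal names pass through once."""
    names = []
    for template in templates:
        if "{kk}" in template:
            names.extend(template.replace("{kk}", str(i)) for i in range(num_layers))
        else:
            names.append(template)
    return names


# Hypothetical per-layer template.
layered = ["model.layers.{kk}.self_attn.q_proj"]
# Literal names taken from the commit message; they carry no placeholder.
literal = ["model.visual.merger.linear_fc1", "model.visual.merger.linear_fc2"]

expanded_layered = expand_layer_names(layered, 2)
expanded_literal = expand_layer_names(literal, 2)
```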
Comment thread unsloth_zoo/vllm_utils.py
inputs = inputs.to(model.device)
for _k, _v in list(inputs.items()):
    if torch.is_tensor(_v) and torch.is_floating_point(_v):
        inputs[_k] = _v.to(dtype = model.dtype)

Interesting. I thought HF's tokenized dict has supported .to for a while now.

Comment thread unsloth_zoo/vllm_utils.py
layer.to = partial(_override_to, layer)
layer.weight.to = partial(_override_to, layer.weight)

elif layer_name.endswith(".conv1d") and "linear_attn" in layer_name:

I think we do these things in empty_model.py? This is not the right place for this.

Comment thread unsloth_zoo/vllm_utils.py
_normalize_state_dict_tensor(old_state_dict[key1]),
_normalize_state_dict_tensor(new_state_dict[key2]),
check_stride = True,
check_stride = False,

Which parameter is making it require this? Conv?
Codex/Claude can do this just to get the test working. We need to be careful.

Comment thread unsloth_zoo/vllm_utils.py
state_dict[norm_prefix] = vllm_text_model.norm.weight.data
quant_state_dict[norm_prefix] = state_dict[norm_prefix]

# Gemma4 top-level per-layer-input modules

All these extra VLM or bridge layers are dealt with in empty_model.py

Comment thread unsloth_zoo/hf_utils.py
if not success:
set_dtype_in_config_fallback(config, dtype)
try:
# if dtype is not a string, convert it to a string

I think this was failing for Qwen 3.5 MRoPE or some other module. That is why I had to check dtype vs torch_dtype and set it accordingly.

# k_proj -> v_proj, so prequant BnB needs the matching QuantState.
if kind == "packed":
    if isinstance(quant_states, dict) and 2 not in quant_states and 1 in quant_states:
        quant_states[2] = deepcopy(quant_states[1])

Do we really want to deep copy? Wouldn't that duplicate memory usage?
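The trade-off behind this question can be shown with a toy object. `ToyQuantState` is a hypothetical stand-in, not the bitsandbytes class; whether aliasing is safe in the real code depends on whether anything downstream mutates the state per shard, which this thread does not settle.

```python
from copy import deepcopy


class ToyQuantState:
    """Hypothetical stand-in for a quantization state holding a buffer."""

    def __init__(self, absmax):
        self.absmax = absmax


def make_states():
    return {0: ToyQuantState([0.1] * 4), 1: ToyQuantState([0.2] * 4)}


# Aliasing: shard 2 points at the same object as shard 1 (no extra memory).
aliased = make_states()
aliased[2] = aliased[1]

# Deep copy: shard 2 gets its own object and its own buffer.
copied = make_states()
copied[2] = deepcopy(copied[1])

# Mutating shard 2 leaks into shard 1 only in the aliased case.
aliased[2].absmax[0] = 9.9
copied[2].absmax[0] = 9.9
```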

assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb"
except Exception as rotary_reinit_error:
reinit_ok = False
logger.warning(

Is this to avoid double init? Otherwise skipping would lead to a damaged model.

z_scale = ws[scale_offsets[3]:scale_offsets[4]]
store(f"{prefix}.in_proj_qkv{scale_suffix}", qkv_scale)
store(f"{prefix}.in_proj_z{scale_suffix}", z_scale)
scale_attr = None

A lot of this logic already exists elsewhere in generic FP8 checks. Maybe we should make it a function and use it here?

quant_state_dict[f"{name}.weight.quant_state"] = quant_state
try:
    for k, v in quant_state.as_dict(packed=True).items():
        state_dict[f"{name}.weight.{k}"] = v

I don't fully understand what this is trying to do... Iterate over all items and reset the name?
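Read literally, the loop in the snippet flattens the packed quant-state fields into the flat state dict under a `{name}.weight.` prefix. A minimal sketch of that reading, using a plain dict in place of `quant_state.as_dict(packed=True)`; the field names, values, and the layer name are made up for illustration:

```python
def flatten_quant_state(name, packed_fields, state_dict):
    """Copy each packed field into state_dict under '{name}.weight.{field}'."""
    for key, value in packed_fields.items():
        state_dict[f"{name}.weight.{key}"] = value
    return state_dict


# Made-up field names and values standing in for the packed quant state.
packed_fields = {"absmax": [0.5, 0.25], "quant_map": [-1.0, 0.0, 1.0]}

# Hypothetical layer name.
layer_name = "model.layers.0.linear_attn.in_proj_z"
state_dict = flatten_quant_state(layer_name, packed_fields, {})
```

Under this reading nothing is renamed or reset; each field simply gets its own prefixed key next to the weight.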

Addresses review iter-1 findings. The four hunks refine (not delete)
logic previously added in review-fix commits 4587bc6, e4f530c,
ca60088, e6ebab4 — each hunk preserves the original intent while
extending it to cover cases those commits missed.

- extract_gdn_layers (refines 4587bc6 / ca60088): keep the sharded-
  QuantState branch intact; add a defensive shape guard on the
  else-branch for the fused single-QuantState case, where the stored
  quant_state shape does not match the 3/4 qkv slice. When detected,
  dequantize the fused buffer once and split into qkv / z, so neither
  shard is paired with a mismatched quant_state.
- _store_quant_state (refines 4587bc6): keep the try/as_dict fallback;
  swap the bare `except ... pass` for `logger.warning` so a bitsandbytes
  version mismatch in `quant_state.as_dict(packed=True)` is visible
  rather than silent. Behavior on success is unchanged.
- get_model_layer_config (refines 4587bc6): move
  `model.visual.merger.linear_fc1/fc2` from `non_layered_components` to
  `additional_layers`. The non_layered path routes them through
  set_additional_modules' plain `exec` assignment, which strips BnB/FP8
  quant_state for quantized Qwen3-VL mergers; additional_layers keeps
  the main quantized loop in charge (Linear4bit / FP8Linear build).
  why safe: linear_fc1/fc2 are literal names (no {kk}), and the main
  loop's trailing regex `r"\.([\d]{1,})"` matches only digit-after-dot,
  which does not occur in `linear_fc1` / `linear_fc2`.
- finalize_huggingface_model (refines e4f530c / e6ebab4): the existing
  `.to(dtype)` post-lift loop only restored rotary_emb buffers to fp32;
  Gemma3 `rotary_emb_local` and Qwen2.5-VL `rotary_pos_emb` buffers
  were still left at bf16/fp16, breaking RoPE precision. Extend the
  loop to cover all three rotary variants. The `if rotary is None or
  not hasattr(rotary, "_buffers"): continue` guard mirrors the prior
  None-check, just broadened — why safe: the original was already a
  continue-on-None, and the hasattr check only short-circuits for
  non-nn.Module attributes, which were never iterated before anyway.

Review iter-2 flagged that my iter-1 refinement in 7ec99b6 regressed
in_proj_z handling on the multi-quant path from PR #7's GDN extractor
(commits ca60088 / e6ebab4 / 3575a41). This commit refines 7ec99b6 and
preserves all three prior intents, rather than deleting them.

Preserved from PR #7 (ca60088 / e6ebab4 / 3575a41):
- `sum(qs is not None for qs in qkv_states) > 1` per-shard dequant
  branch is kept as an `elif`, not dropped.
- `dequantize_4bit` import and RuntimeError behaviour are kept.
- `_store_quant_state` emission for the packed `qs0 covers qkv + qs3
  covers z` BnB convention is preserved in the final else branch.

Preserved from my iter-1 fix (7ec99b6):
- Fused-single-QuantState shape guard that triggers a full dequantize +
  split. Moved out of the else branch into its own leading branch so
  the subsequent per-shard dequant path can also emit the z shard.

New in this commit:
- Multi-quant (`sum > 1`) branch now also stores `in_proj_z.weight`
  (dequant via `qs_attr.get(3)` if present, else raw `z_weight`).
  Without this, the HF reconstruction saw no z key and rebuilt z from
  the random placeholder in `create_empty_causal_lm`. Bug introduced
  by 7ec99b6.
- Fused-full detection now compares `qs_shape[0]` against `offsets[4]`
  (logical unpacked out_features summed from `proj.output_sizes`)
  instead of `weight.shape[0]`. The BnB Params4bit buffer is stored
  packed as uint8 with half the rows, so the prior comparison never
  matched real BnB layouts and the fused-full guard silently fell
  through to the per-shard branch.
- Final else branch now re-emits `_store_quant_state(in_proj_qkv,
  qkv_states[0])` alongside the z quant_state, restoring the
  emission that 7ec99b6 inadvertently dropped when it collapsed the
  old nested `if/else` block.
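The fused-full detection fix described in the commit message above comes down to which row count the quant-state shape is compared against. A small arithmetic sketch, assuming four equal hypothetical output sizes and following the commit's description of the packed buffer as having half the rows (the real Params4bit layout may differ):

```python
from itertools import accumulate

# Hypothetical q, k, v, z out_features, standing in for proj.output_sizes.
output_sizes = [8, 8, 8, 8]

# Cumulative offsets; offsets[4] is the logical (unpacked) row count.
offsets = [0, *accumulate(output_sizes)]

logical_rows = offsets[4]
# Per the commit message: the packed uint8 buffer has half the rows.
packed_rows = logical_rows // 2

# Assumption: the quant state records the unpacked shape.
qs_shape = (logical_rows, 16)

old_check = qs_shape[0] == packed_rows    # comparison against weight.shape[0]
new_check = qs_shape[0] == logical_rows   # comparison against offsets[4]
```

With these numbers the old comparison never matches, so the fused-full guard would be skipped, while the new comparison matches as intended.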