[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #5
Conversation
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) architectures, including vLLM monkey-patches for LoRA and k_eq_v logic, enhanced model conversion utilities for vision and linear attention layers, and more robust dtype configuration. Review feedback identified the need to ensure complete state dictionary extraction for Gemma 4 k_eq_v layers by duplicating the K shard for V, refining the lm_head module search criteria, and adopting more robust regex patterns for converting layer names to bracket notation.
| if kk not in gemma4_k_eq_v_layers: | ||
| get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) |
There was a problem hiding this comment.
For Gemma 4 layers where k_eq_v is active, skipping the extraction of v_proj will result in an incomplete state dictionary for the Hugging Face model. Even if vLLM optimizes storage by sharing weights, the Hugging Face architecture still expects a v_proj module. You should extract the K shard again and assign it to the v_proj key to ensure the converted model is functional.
| if kk not in gemma4_k_eq_v_layers: | |
| get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) | |
| if kk not in gemma4_k_eq_v_layers: | |
| get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) | |
| else: | |
| # For k_eq_v layers, V is identical to K. Extract the K shard again as V. | |
| get_state_dict(f"{prefix}.v_proj", 1, state_dict, qkv_proj) |
There was a problem hiding this comment.
| lm_layer = [mod for name,mod in vllm_internals.named_modules() if "lm_head" in name] | ||
| # Use get_state_dict for consistent extraction and automatic truncation | ||
| get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False) | ||
| lm_layer = next((mod for name, mod in vllm_internals.named_modules() if "lm_head" in name), None) |
There was a problem hiding this comment.
The search for lm_head using if "lm_head" in name is somewhat loose and might match unintended modules (e.g., wrappers or custom components containing that string). A more precise check like name.split('.')[-1] == "lm_head" or name.endswith(".lm_head") is recommended.
| lm_layer = next((mod for name, mod in vllm_internals.named_modules() if "lm_head" in name), None) | |
| lm_layer = next((mod for name, mod in vllm_internals.named_modules() if name.split('.')[-1] == "lm_head"), None) |
There was a problem hiding this comment.
There's no reason why it would match something unintentionally. There aren't multiple lm_heads
| # for attributes of type nn.Parameter, there's no .weight | ||
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1)) | ||
| layer = torch.nn.Parameter(weight, requires_grad = False) | ||
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) |
There was a problem hiding this comment.
The regex r"\.([\d]{1,})\." only matches digits surrounded by dots, which will fail if a digit is at the end of the path (e.g., model.layers.0). Using a more robust regex like r"\.([\d]+)(?=\.||$)" or the one used later in this function (r"\.([\d]{1,})") ensures all indexed components are correctly converted to [] syntax.
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) | |
| layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name) |
There was a problem hiding this comment.
This is false positive. we always have digits wrapped by .
| # LayerNorms (including vision norms) | ||
| weight_param = torch.nn.Parameter(weight, requires_grad=False) | ||
| weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) | ||
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) |
There was a problem hiding this comment.
The regex used here is inconsistent with the more robust version used at the end of the loop. It should be updated to handle digits at the end of the path string.
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) | |
| layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name) |
- patched_create_lora_manager: pass model positionally and preserve __signature__ so unsloth_zoo.vllm_lora_worker_manager._call_create_lora_manager still dispatches vllm_config correctly on vLLM versions that require it. - patch_gemma4_vllm_lora_support: add embedding_padding_modules class attribute to avoid AttributeError in the vLLM LoRA runner; wrap vLLM imports in try/except so text-only or non-vLLM environments do not crash on import. - patch_gemma4_vllm_k_eq_v_support: tolerate older vLLM releases without BitsAndBytesModelLoader._stack_quantization_states. - load_vllm: gate Gemma4 LoRA patch on is_vision_model and enable_lora, and gate the k_eq_v patch on use_bitsandbytes so non-LoRA / non-BnB Gemma4 loads do not force optional vLLM internals. - extract_gdn_layers: dequantize fused in_proj_qkvz when BnB quant state is attached to the weight so 4-bit Qwen3.5 GDN does not fall through to dense Linear; pick FP8 scale suffix from whichever of weight_scale / weight_scale_inv the source exposes so downstream FP8 detection sees the right key. - finalize_huggingface_model: apply layer_idx fix to both model.model.language_model.layers (VLM path) and model.model.layers (text-only Qwen3.5) so GDN submodules have correct per-layer index. - finalize_huggingface_model: keep RoPE inv_freq / original_inv_freq in float32 for all models after fresh rotary_emb re-init, and run the Gemma4 rotary finalization block for quantized Gemma4 too. - get_model_layer_config: add Gemma4 per_layer_input_gate, per_layer_projection, and post_per_layer_input_norm entries so the reconstructed HF model does not retain 1-wide dummies. - _get_vllm_state_dict: extract per_layer_input_gate and per_layer_projection directly from vLLM layers. - convert_vllm_to_huggingface: preserve buffer semantics (not nn.Parameter) for attributes restored through the layer_name-in-quant-state-dict path, so Gemma4 layer_scalar stays a buffer.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces support for Gemma 4 and Qwen 3.5 (GDN) models, including vLLM integration patches and expanded layer configuration patterns. Key changes include the addition of finalize_huggingface_model for post-processing model configurations, robust dtype setting in hf_utils.py, and specialized layer extraction for Gated Delta Net (GDN) architectures. Feedback focuses on optimizing module traversal in finalize_huggingface_model to reduce redundant passes, ensuring correct torch.dtype resolution when string prefixes are present, and refining regex patterns for more robust layer name mapping.
| if getattr(config, "model_type", None) == "gemma4": | ||
| for module in new_model.modules(): | ||
| rotary_emb = getattr(module, "rotary_emb", None) | ||
| if rotary_emb is None: | ||
| continue | ||
| fresh_rotary_emb = rotary_emb.__class__( | ||
| config = rotary_emb.config, | ||
| device = target_device, | ||
| ) | ||
| for attr_name in ("max_seq_len_cached", "original_max_seq_len"): | ||
| if hasattr(fresh_rotary_emb, attr_name): | ||
| setattr(rotary_emb, attr_name, getattr(fresh_rotary_emb, attr_name)) | ||
| for attr_name, attr_value in fresh_rotary_emb.__dict__.items(): | ||
| if attr_name == "attention_scaling" or attr_name.endswith("_attention_scaling"): | ||
| setattr(rotary_emb, attr_name, attr_value) | ||
| for buffer_name, buffer in fresh_rotary_emb._buffers.items(): | ||
| if torch.is_tensor(buffer) and buffer.is_floating_point(): | ||
| rotary_emb._buffers[buffer_name] = buffer.to( | ||
| device = target_device, | ||
| dtype = torch.float32, | ||
| ) | ||
| else: | ||
| for module in new_model.modules(): | ||
| rotary_emb = getattr(module, "rotary_emb", None) | ||
| if rotary_emb is None: | ||
| continue | ||
| for buffer_name in ("inv_freq", "original_inv_freq"): | ||
| buffer = getattr(rotary_emb, buffer_name, None) | ||
| if torch.is_tensor(buffer) and buffer.is_floating_point() and buffer.dtype != torch.float32: | ||
| rotary_emb._buffers[buffer_name] = buffer.to( | ||
| device = target_device, | ||
| dtype = torch.float32, | ||
| ) | ||
| return new_model |
There was a problem hiding this comment.
The finalize_huggingface_model function performs multiple redundant passes over all model modules. The logic for Gemma 4 rotary patching (lines 753-773) and the precision fix for other models (lines 775-785) can be integrated into the initial module loop (starting at line 713). This would improve performance by reducing the number of full model traversals and would also resolve redundant operations for Gemma 4 text modules, which are currently handled in both loops.
There was a problem hiding this comment.
This is intentional and was always the case to find rope modules. They might be at varying depths
| setattr(config, "dtype", string_dtype) | ||
| except: | ||
| set_dtype_in_config_fallback(config, string_dtype) | ||
| runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype |
There was a problem hiding this comment.
The runtime_dtype resolution does not account for strings with a "torch." prefix. If dtype is "torch.float16", getattr(torch, dtype, dtype) will return the string itself, which is then set in the config. This can break downstream code expecting a torch.dtype object. Stripping the prefix ensures correct resolution.
| runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype | |
| runtime_dtype = getattr(torch, dtype.replace("torch.", ""), dtype) if isinstance(dtype, str) else dtype |
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1)) | ||
| layer = torch.nn.Parameter(weight, requires_grad = False) | ||
| exec(f"new_model.{layer_name_br} = layer") | ||
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) |
There was a problem hiding this comment.
The regex r"\.([\d]{1,})\." is less robust than the one used at line 1480 because it requires a trailing dot. Using a more flexible regex that handles indices at the end of a path ensures consistent transformation of dot-notation indices into bracket notation across all keys.
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) | |
| layer_name_br = re.sub(r"\.([\d]{1,})", lambda x: f"[{x.group(1)}]", layer_name) |
- hf_utils.set_dtype_in_config: store string (JSON-safe, keeps string
comparisons in patch_model_and_tokenizer working); fix fallback
else-branch that had the HAS_TORCH_DTYPE field selection inverted.
- empty_model.extract_gdn_layers: read bnb_quant_state off the raw
Params4bit before unwrapping .data; emit weight.quant_state and FP8
weight_scale(_inv) shards for the in_proj_b / in_proj_a split so
quantized Qwen3.5 GDN layers round-trip correctly.
- vllm_utils.convert_vllm_to_huggingface: rebuild linear_attn.conv1d
as a grouped Conv1d with real channels/kernel_size/groups/padding
instead of treating it as a LayerNorm-style weight swap.
- empty_model.patch_gemma4_vllm_lora_support: soft-import
vllm.v1.worker.lora_model_runner_mixin so older supported vLLM
layouts keep working.
- vllm_utils._get_vllm_state_dict: extract Gemma4 per_layer_input_gate
and per_layer_projection so converted HF models carry the real
checkpoint weights.
- empty_model.finalize_huggingface_model: restrict dtype propagation
to the top-level config and its known text/vision/audio subconfigs;
consolidate the duplicated Gemma4 rotary re-init into one loop while
keeping the post-.to(dtype) float32 buffer / attention_scaling
restoration.
- vllm_utils.assert_same_state_dict: _normalize_state_dict_tensor now
returns None for non-tensor entries (e.g. BnB QuantState dicts) and
callers skip those; align tied-embedding fallback tolerances with
the outer comparison (atol=1e-4, rtol=1e-3).
- vllm_utils._test_is_same_vlm: cast only floating-point tensors to
model.dtype for Gemma3/Gemma4 processors, leaving integer inputs
like pixel_values untouched.
- vllm_utils._get_vllm_state_dict: collapse the unreachable lm_head
elif chain; hoist the constant model_type/attention_k_eq_v check
out of the gemma4_k_eq_v_layers set comprehension.
- empty_model.get_model_layer_config: move model.visual.merger.
linear_fc1 / linear_fc2 from additional_layers (which expected a
{kk} placeholder) into non_layered_components.
4b7935e to
a6f73ac
Compare
| inputs = inputs.to(model.device) | ||
| for _k, _v in list(inputs.items()): | ||
| if torch.is_tensor(_v) and torch.is_floating_point(_v): | ||
| inputs[_k] = _v.to(dtype = model.dtype) |
There was a problem hiding this comment.
Interesting. I thought hf's tokenized dict supports .to for a while now
| layer.to = partial(_override_to, layer) | ||
| layer.weight.to = partial(_override_to, layer.weight) | ||
|
|
||
| elif layer_name.endswith(".conv1d") and "linear_attn" in layer_name: |
There was a problem hiding this comment.
I think we do these things in empty_model.py? This is not the right place for this
| _normalize_state_dict_tensor(old_state_dict[key1]), | ||
| _normalize_state_dict_tensor(new_state_dict[key2]), | ||
| check_stride = True, | ||
| check_stride = False, |
There was a problem hiding this comment.
Which parameter is making it require this? Conv?
Codex/Claude can do this just to get the test working. we need to be careful
| state_dict[norm_prefix] = vllm_text_model.norm.weight.data | ||
| quant_state_dict[norm_prefix] = state_dict[norm_prefix] | ||
|
|
||
| # Gemma4 top-level per-layer-input modules |
There was a problem hiding this comment.
All these extra VLM or bridge layers are dealt with in empty_model.py
| if not success: | ||
| set_dtype_in_config_fallback(config, dtype) | ||
| try: | ||
| # if dtype is not a string, convert it to a string |
There was a problem hiding this comment.
I think this was failing for qwen 3.5 MRoPE or some module. That is why I had to check dtype vs torch_dtype and set accordingly
| # k_proj -> v_proj, so prequant BnB needs the matching QuantState. | ||
| if kind == "packed": | ||
| if isinstance(quant_states, dict) and 2 not in quant_states and 1 in quant_states: | ||
| quant_states[2] = deepcopy(quant_states[1]) |
There was a problem hiding this comment.
Do we really want to deep copy? wouldn't that duplicate memory usage?
| assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" | ||
| except Exception as rotary_reinit_error: | ||
| reinit_ok = False | ||
| logger.warning( |
There was a problem hiding this comment.
Is this to avoid double init? Otherwise skipping would lead to damaged model
| z_scale = ws[scale_offsets[3]:scale_offsets[4]] | ||
| store(f"{prefix}.in_proj_qkv{scale_suffix}", qkv_scale) | ||
| store(f"{prefix}.in_proj_z{scale_suffix}", z_scale) | ||
| scale_attr = None |
There was a problem hiding this comment.
A lot of this logic already exists elsewhere in generic FP8 checks. Maybe we should make it a function and use it here?
| quant_state_dict[f"{name}.weight.quant_state"] = quant_state | ||
| try: | ||
| for k, v in quant_state.as_dict(packed=True).items(): | ||
| state_dict[f"{name}.weight.{k}"] = v |
There was a problem hiding this comment.
I don't fully understand what this is trying to do... Iterate over all items and set reset the name?
Addresses review iter-1 findings. The four hunks refine (not delete) logic previously added in review-fix commits 4587bc6, e4f530c, ca60088, e6ebab4 — each hunk preserves the original intent while extending it to cover cases those commits missed. - extract_gdn_layers (refines 4587bc6 / ca60088): keep the sharded- QuantState branch intact; add a defensive shape guard on the else-branch for the fused single-QuantState case, where the stored quant_state shape does not match the 3/4 qkv slice. When detected, dequantize the fused buffer once and split into qkv / z, so neither shard is paired with a mismatched quant_state. - _store_quant_state (refines 4587bc6): keep the try/as_dict fallback; swap the bare `except ... pass` for `logger.warning` so a bitsandbytes version mismatch in `quant_state.as_dict(packed=True)` is visible rather than silent. Behavior on success is unchanged. - get_model_layer_config (refines 4587bc6): move `model.visual.merger.linear_fc1/fc2` from `non_layered_components` to `additional_layers`. The non_layered path routes them through set_additional_modules' plain `exec` assignment, which strips BnB/FP8 quant_state for quantized Qwen3-VL mergers; additional_layers keeps the main quantized loop in charge (Linear4bit / FP8Linear build). why safe: linear_fc1/fc2 are literal names (no {kk}), and the main loop's trailing regex `r"\.([\d]{1,})"` matches only digit-after-dot, which does not occur in `linear_fc1` / `linear_fc2`. - finalize_huggingface_model (refines e4f530c / e6ebab4): the existing `.to(dtype)` post-lift loop only restored rotary_emb buffers to fp32; Gemma3 `rotary_emb_local` and Qwen2.5-VL `rotary_pos_emb` buffers were still left at bf16/fp16, breaking RoPE precision. Extend the loop to cover all three rotary variants. The `if rotary is None or not hasattr(rotary, "_buffers"): continue` guard mirrors the prior None-check, just broadened — why safe: the original was already a continue-on-None, and the hasattr check only short-circuits for non-nn.Module attributes, which were never iterated before anyway.
Review iter-2 flagged that my iter-1 refinement in 7ec99b6 regressed in_proj_z handling on the multi-quant path from PR #7's GDN extractor (commits ca60088 / e6ebab4 / 3575a41). This commit refines 7ec99b6 and preserves all three prior intents, rather than deleting them. Preserved from PR #7 (ca60088 / e6ebab4 / 3575a41): - `sum(qs is not None for qs in qkv_states) > 1` per-shard dequant branch is kept as an `elif`, not dropped. - `dequantize_4bit` import and RuntimeError behaviour are kept. - `_store_quant_state` emission for the packed `qs0 covers qkv + qs3 covers z` BnB convention is preserved in the final else branch. Preserved from my iter-1 fix (7ec99b6): - Fused-single-QuantState shape guard that triggers a full dequantize + split. Moved out of the else branch into its own leading branch so the subsequent per-shard dequant path can also emit the z shard. New in this commit: - Multi-quant (`sum > 1`) branch now also stores `in_proj_z.weight` (dequant via `qs_attr.get(3)` if present, else raw `z_weight`). Without this, the HF reconstruction saw no z key and rebuilt z from the random placeholder in `create_empty_causal_lm`. Bug introduced by 7ec99b6. - Fused-full detection now compares `qs_shape[0]` against `offsets[4]` (logical unpacked out_features summed from `proj.output_sizes`) instead of `weight.shape[0]`. The BnB Params4bit buffer is stored packed as uint8 with half the rows, so the prior comparison never matched real BnB layouts and the fused-full guard silently fell through to the per-shard branch. - Final else branch now re-emits `_store_quant_state(in_proj_qkv, qkv_states[0])` alongside the z quant_state, restoring the emission that 7ec99b6 inadvertently dropped when it collapsed the old nested `if/else` block.
1a32ddb to
70d7381
Compare
Staging mirror of unslothai#588
Original PR: unslothai#588
Author: Datta0
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Tried to make sure the changes are minimal so when we detect linear_attn, we hand it off to a separate function