[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #3
Conversation
Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request adds support for Gemma 4 and Qwen 3.5 (GDN) models, implementing vLLM patches for LoRA and k_eq_v support, and adding logic to extract Gated Delta Net layers. It also refactors model finalization, dtype setting, and state dict conversion to improve compatibility across various architectures. Review feedback suggests enhancing the robustness of language model detection, safeguarding against zero-dimension rotary embeddings, and replacing legacy .data usage with .detach() for better PyTorch compliance.
| if original_meta_model is not None: | ||
| copy_attributes(original_meta_model, new_model) | ||
|
|
||
| language_model = getattr(getattr(new_model, "model", None), "language_model", None) |
There was a problem hiding this comment.
The current logic for detecting the language_model is not robust enough. For models like Gemma4ForConditionalGeneration, the language_model attribute is at the top level, and new_model.model might be None or something else. It's better to use a more comprehensive check similar to the one used in set_additional_modules to ensure the language model and its layers are correctly identified.
if hasattr(new_model, "language_model"):
language_model = new_model.language_model
elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"):
language_model = new_model.model.language_model
else:
language_model = getattr(new_model, "model", None)| if hasattr(module, "rotary_pos_emb"): | ||
| assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" | ||
| head_dim = vision_config.hidden_size // vision_config.num_heads | ||
| module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device) |
There was a problem hiding this comment.
If head_dim is 1 (which is the case for the empty vision model created in create_empty_vision_model), head_dim // 2 will be 0. This might cause issues when re-initializing rotary_pos_emb modules, as they typically expect a non-zero dimension for frequency calculation. Consider ensuring head_dim is at least 2 or adding a safety clamp.
| module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device) | |
| module.rotary_pos_emb = module.rotary_pos_emb.__class__(max(1, head_dim//2)).to(target_device) |
| def _unwrap_tensor(val): | ||
| return getattr(val, "data", val) |
There was a problem hiding this comment.
Using .data is generally discouraged in modern PyTorch as it can bypass autograd checks and lead to silent bugs. It is safer to use .detach() to obtain a tensor that shares storage but is not tracked by autograd.
| def _unwrap_tensor(val): | |
| return getattr(val, "data", val) | |
| def _unwrap_tensor(val): | |
| return val.detach() if hasattr(val, "detach") else val |
| get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) | ||
| get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) | ||
|
|
||
| store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) |
| def _unwrap_tensor(value): | ||
| return getattr(value, "data", value) |
There was a problem hiding this comment.
Similar to the suggestion in empty_model.py, consider using .detach() instead of .data for unwrapping tensors to follow modern PyTorch best practices.
| def _unwrap_tensor(value): | |
| return getattr(value, "data", value) | |
| def _unwrap_tensor(value): | |
| return value.detach() if hasattr(value, "detach") else value |
…n, finalize_huggingface_model
- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
_call_create_lora_manager's signature inspection still sees vllm_config; pass model
positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
rotary_pos_emb on vision_config availability; mirror language_model detection from
set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
(defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
so non-tensor state-dict values pass through.
|
Fixes pushed to unslothai#588. |
Staging mirror of unslothai#588
Original PR: unslothai#588
Author: Datta0
This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.
Original description
Tried to make sure the changes are minimal so when we detect linear_attn, we hand it off to a separate function