Skip to content

[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #3

Closed
danielhanchen wants to merge 10 commits into
mainfrom
pr-588-head
Closed

[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference #3
danielhanchen wants to merge 10 commits into
mainfrom
pr-588-head

Conversation

@danielhanchen

Copy link
Copy Markdown
Collaborator

Staging mirror of unslothai#588

Original PR: unslothai#588
Author: Datta0

This is a staging copy for review and editing. Once finalized, changes will be pushed back to the original PR.


Original description

  • fast inference support for Qwen 3.5 with vLLM :)
    Tried to make sure the changes are minimal so when we detect linear_attn, we hand it off to a separate function

@danielhanchen

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Gemma 4 and Qwen 3.5 (GDN) models, implementing vLLM patches for LoRA and k_eq_v support, and adding logic to extract Gated Delta Net layers. It also refactors model finalization, dtype setting, and state dict conversion to improve compatibility across various architectures. Review feedback suggests enhancing the robustness of language model detection, safeguarding against zero-dimension rotary embeddings, and replacing legacy .data usage with .detach() for better PyTorch compliance.

Comment thread unsloth_zoo/empty_model.py Outdated
if original_meta_model is not None:
copy_attributes(original_meta_model, new_model)

language_model = getattr(getattr(new_model, "model", None), "language_model", None)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current logic for detecting the language_model is not robust enough. For models like Gemma4ForConditionalGeneration, the language_model attribute is at the top level, and new_model.model might be None or something else. It's better to use a more comprehensive check similar to the one used in set_additional_modules to ensure the language model and its layers are correctly identified.

    if hasattr(new_model, "language_model"):
        language_model = new_model.language_model
    elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"):
        language_model = new_model.model.language_model
    else:
        language_model = getattr(new_model, "model", None)

if hasattr(module, "rotary_pos_emb"):
assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb"
head_dim = vision_config.hidden_size // vision_config.num_heads
module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If head_dim is 1 (which is the case for the empty vision model created in create_empty_vision_model), head_dim // 2 will be 0. This might cause issues when re-initializing rotary_pos_emb modules, as they typically expect a non-zero dimension for frequency calculation. Consider ensuring head_dim is at least 2 or adding a safety clamp.

Suggested change
module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device)
module.rotary_pos_emb = module.rotary_pos_emb.__class__(max(1, head_dim//2)).to(target_device)

Comment on lines +538 to +539
def _unwrap_tensor(val):
return getattr(val, "data", val)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

Using .data is generally discouraged in modern PyTorch as it can bypass autograd checks and lead to silent bugs. It is safer to use .detach() to obtain a tensor that shares storage but is not tracked by autograd.

Suggested change
def _unwrap_tensor(val):
return getattr(val, "data", val)
def _unwrap_tensor(val):
return val.detach() if hasattr(val, "detach") else val

get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba)
get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba)

store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

Accessing .data is legacy PyTorch practice. Using .detach() is preferred for extracting tensors without autograd tracking.

Suggested change
store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data)
store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.detach())

Comment thread unsloth_zoo/vllm_utils.py
Comment on lines +1273 to +1274
def _unwrap_tensor(value):
return getattr(value, "data", value)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

Similar to the suggestion in empty_model.py, consider using .detach() instead of .data for unwrapping tensors to follow modern PyTorch best practices.

Suggested change
def _unwrap_tensor(value):
return getattr(value, "data", value)
def _unwrap_tensor(value):
return value.detach() if hasattr(value, "detach") else value

…n, finalize_huggingface_model

- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
  _call_create_lora_manager's signature inspection still sees vllm_config; pass model
  positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
  upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
  qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
  so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
  ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
  of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
  handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
  export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
  rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
  rotary_pos_emb on vision_config availability; mirror language_model detection from
  set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
  the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
  'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
  current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
  (defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
  so non-tensor state-dict values pass through.
@danielhanchen

Copy link
Copy Markdown
Collaborator Author

Fixes pushed to unslothai#588.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants