-
Notifications
You must be signed in to change notification settings - Fork 0
[Qwen 3.5][gemma4] Qwen35 and Gemma 4 fast inference (clean-base mirror of #588) #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
20e42fe
a6597a7
ca22808
3fcf1ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |||||||||||||
| import torch | ||||||||||||||
| import re | ||||||||||||||
| import os | ||||||||||||||
| import functools | ||||||||||||||
| from copy import deepcopy | ||||||||||||||
| from .utils import get_quant_type | ||||||||||||||
| from .log import logger | ||||||||||||||
|
|
@@ -343,10 +344,11 @@ def patched_supports_lora(model): | |||||||||||||
| if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"): | ||||||||||||||
| original_create_lora_manager = vllm_lora_model_manager.create_lora_manager | ||||||||||||||
|
|
||||||||||||||
| @functools.wraps(original_create_lora_manager) | ||||||||||||||
| def patched_create_lora_manager(model, *args, **kwargs): | ||||||||||||||
| if model.__class__.__name__ == "Gemma4ForConditionalGeneration": | ||||||||||||||
| lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager) | ||||||||||||||
| return lora_manager_cls(model = model, *args, **kwargs) | ||||||||||||||
| return lora_manager_cls(model, *args, **kwargs) | ||||||||||||||
| return original_create_lora_manager(model, *args, **kwargs) | ||||||||||||||
|
|
||||||||||||||
| patched_create_lora_manager._unsloth_gemma4_patch = True | ||||||||||||||
|
|
@@ -758,6 +760,8 @@ def get_model_layer_config(return_non_layered=True): | |||||||||||||
| layer_templates = { | ||||||||||||||
| 'standard_layers': { | ||||||||||||||
| "model.language_model.layers.{kk}.layer_scalar", | ||||||||||||||
| "model.language_model.layers.{kk}.per_layer_input_gate", | ||||||||||||||
| "model.language_model.layers.{kk}.per_layer_projection", | ||||||||||||||
| "model.language_model.layers.{kk}.self_attn.q_proj", | ||||||||||||||
| "model.language_model.layers.{kk}.self_attn.k_proj", | ||||||||||||||
| "model.language_model.layers.{kk}.self_attn.v_proj", | ||||||||||||||
|
|
@@ -769,6 +773,8 @@ def get_model_layer_config(return_non_layered=True): | |||||||||||||
| "model.language_model.layers.{kk}.mlp.down_proj", | ||||||||||||||
|
|
||||||||||||||
| "model.layers.{kk}.layer_scalar", | ||||||||||||||
| "model.layers.{kk}.per_layer_input_gate", | ||||||||||||||
| "model.layers.{kk}.per_layer_projection", | ||||||||||||||
| "model.layers.{kk}.self_attn.q_proj", | ||||||||||||||
| "model.layers.{kk}.self_attn.k_proj", | ||||||||||||||
| "model.layers.{kk}.self_attn.v_proj", | ||||||||||||||
|
|
@@ -801,6 +807,7 @@ def get_model_layer_config(return_non_layered=True): | |||||||||||||
| "model.language_model.layers.{kk}.post_attention_layernorm", | ||||||||||||||
| "model.language_model.layers.{kk}.pre_feedforward_layernorm", | ||||||||||||||
| "model.language_model.layers.{kk}.post_feedforward_layernorm", | ||||||||||||||
| "model.language_model.layers.{kk}.post_per_layer_input_norm", | ||||||||||||||
| "model.language_model.layers.{kk}.self_attn.q_norm", | ||||||||||||||
| "model.language_model.layers.{kk}.self_attn.k_norm", | ||||||||||||||
| "model.language_model.layers.{kk}.cross_attn.q_norm", | ||||||||||||||
|
|
@@ -809,6 +816,7 @@ def get_model_layer_config(return_non_layered=True): | |||||||||||||
| "model.layers.{kk}.post_attention_layernorm", | ||||||||||||||
| "model.layers.{kk}.pre_feedforward_layernorm", | ||||||||||||||
| "model.layers.{kk}.post_feedforward_layernorm", | ||||||||||||||
| "model.layers.{kk}.post_per_layer_input_norm", | ||||||||||||||
| "model.layers.{kk}.self_attn.q_norm", | ||||||||||||||
| "model.layers.{kk}.self_attn.k_norm", | ||||||||||||||
| "model.visual.blocks.{kk}.norm1", | ||||||||||||||
|
|
@@ -925,7 +933,6 @@ def get_model_layer_config(return_non_layered=True): | |||||||||||||
| # qwen 3 vl | ||||||||||||||
| "model.visual.deepstack_merger_list.{kk}.linear_fc1", | ||||||||||||||
| "model.visual.deepstack_merger_list.{kk}.linear_fc2", | ||||||||||||||
| "model.visual.merger.linear_fc{kk}", | ||||||||||||||
|
|
||||||||||||||
| }, | ||||||||||||||
| "non_layered_components":{ | ||||||||||||||
|
|
@@ -965,6 +972,8 @@ def get_model_layer_config(return_non_layered=True): | |||||||||||||
| # qwen 3 vl | ||||||||||||||
| "model.visual.pos_embed", | ||||||||||||||
| "model.visual.merger.norm", | ||||||||||||||
| "model.visual.merger.linear_fc1", | ||||||||||||||
| "model.visual.merger.linear_fc2", | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
|
|
@@ -1077,13 +1086,22 @@ def store(name, value): | |||||||||||||
| get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) | ||||||||||||||
| get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) | ||||||||||||||
|
|
||||||||||||||
| get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) | ||||||||||||||
| get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) | ||||||||||||||
| if hasattr(gdn, "in_proj_ba"): | ||||||||||||||
| get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_ba) | ||||||||||||||
| get_state_dict(f"{prefix}.in_proj_a", 1, state_dict, gdn.in_proj_ba) | ||||||||||||||
| else: | ||||||||||||||
| get_state_dict(f"{prefix}.in_proj_b", 0, state_dict, gdn.in_proj_b, slice_weights=False) | ||||||||||||||
| get_state_dict(f"{prefix}.in_proj_a", 0, state_dict, gdn.in_proj_a, slice_weights=False) | ||||||||||||||
|
|
||||||||||||||
| store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) | ||||||||||||||
| store(f"{prefix}.dt_bias", gdn.dt_bias.data) | ||||||||||||||
| store(f"{prefix}.A_log", gdn.A_log.data) | ||||||||||||||
|
Comment on lines
+1096
to
+1098
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| norm = getattr(gdn, "norm", None) | ||||||||||||||
| norm_weight = getattr(norm, "weight", None) if norm is not None else None | ||||||||||||||
| if norm_weight is not None: | ||||||||||||||
| store(f"{prefix}.norm.weight", norm_weight.data) | ||||||||||||||
|
|
||||||||||||||
| get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) | ||||||||||||||
| pass | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1167,6 +1167,13 @@ def _is_fused_module(name: str) -> bool: | |
| if hasattr(layer, "layer_scalar"): | ||
| state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data | ||
| quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data | ||
| for per_layer_name in ("per_layer_input_gate", "per_layer_projection"): | ||
| per_layer_module = getattr(layer, per_layer_name, None) | ||
| if per_layer_module is not None and hasattr(per_layer_module, "weight"): | ||
| get_state_dict( | ||
| f"{vllm_text_model_prefix}.layers.{kk}.{per_layer_name}", | ||
| 0, state_dict, per_layer_module, | ||
| ) | ||
| pass | ||
|
|
||
| if len(skipped_layernorms) != 0: | ||
|
|
@@ -1216,6 +1223,8 @@ def assert_same_state_dict(old_state_dict, new_state_dict): | |
| def _normalize_state_dict_tensor(value): | ||
| if isinstance(value, torch.nn.Parameter): | ||
| value = value.detach() | ||
| if not isinstance(value, torch.Tensor): | ||
| return value | ||
| if value.is_sparse: | ||
| value = value.to_dense() | ||
| return value.contiguous() | ||
|
Comment on lines
+1223
to
+1230
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
@@ -1396,8 +1405,14 @@ def _override_to(self, *args, **kwargs): | |
| if layer_name in quant_state_dict: | ||
| # for attributes of type nn.Parameter, there's no .weight | ||
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) | ||
| layer = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) | ||
| exec(f"new_model.{layer_name_br} = layer") | ||
| value = _unwrap_tensor(weight) | ||
| parent_expr, attr_name = layer_name_br.rsplit(".", 1) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| parent_module = eval(f"new_model.{parent_expr}") | ||
| if attr_name in getattr(parent_module, "_buffers", {}): | ||
| parent_module._buffers[attr_name] = value | ||
| else: | ||
| layer = torch.nn.Parameter(value, requires_grad = False) | ||
| exec(f"new_model.{layer_name_br} = layer") | ||
| continue | ||
| elif fp8_weight_scale is not None: | ||
| if fp8_weight_scale.ndim == 1: | ||
|
|
@@ -1448,9 +1463,23 @@ def _override_to(self, *args, **kwargs): | |
| layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) | ||
| layer.bias = bias | ||
| else: | ||
| # LayerNorms (including vision norms) | ||
| # LayerNorms (including vision norms) and depthwise Conv1d | ||
| weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) | ||
| layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) | ||
| if layer_name.endswith(".conv1d"): | ||
| target = eval(f"new_model.{layer_name_br}") | ||
| w = _unwrap_tensor(weight) | ||
| out_channels = w.shape[0] | ||
| kernel_size = w.shape[-1] | ||
| target.out_channels = out_channels | ||
| target.in_channels = out_channels | ||
| target.groups = out_channels | ||
| target.kernel_size = (kernel_size,) | ||
| target.padding = (kernel_size - 1,) | ||
| target.weight = weight_param | ||
| if bias is not None: | ||
| target.bias = bias | ||
| continue | ||
|
Comment on lines
+1469
to
+1482
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use of |
||
| # Set weight | ||
| exec(f"new_model.{layer_name_br}.weight = None") | ||
| exec(f"new_model.{layer_name_br}.weight = weight_param") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the
patched_supports_lorafunction, usingisinstancehere is more robust than comparing class names as strings.