def _get_model_device(model):
    """Return the device `model` lives on.

    The device of the first parameter is used; if the module has no
    parameters, the device of the first buffer; if it has neither, CPU.
    """
    for tensor in model.parameters():
        return tensor.device
    for tensor in model.buffers():
        return tensor.device
    return torch.device("cpu")
pass


def patch_gemma4_vllm_lora_support():
    """Monkeypatch vLLM so Gemma4 model classes are accepted for LoRA.

    Three things are patched, each guarded by a marker attribute so that
    calling this function repeatedly is a no-op after the first call:

    1. Every importable Gemma4 class gets `supports_lora = True` and an empty
       `embedding_modules` dict.
    2. `supports_lora` (in `vllm.model_executor.models.interfaces` and, when
       that module is importable, `vllm.v1.worker.lora_model_runner_mixin`)
       is wrapped to return True for the Gemma4 class names.
    3. `create_lora_manager` (in `vllm.lora.model_manager` and in
       `unsloth_zoo.vllm_lora_worker_manager`) is wrapped so that, for Gemma4
       models, the manager class is instantiated directly instead of going
       through the original factory.

    Raises:
        ImportError: if vLLM's Gemma4 multimodal model or the other required
            vLLM modules cannot be imported.
    """
    from functools import wraps
    from vllm.model_executor.models.gemma4_mm import Gemma4ForConditionalGeneration
    from vllm.model_executor.models import interfaces as vllm_model_interfaces
    from vllm.lora import model_manager as vllm_lora_model_manager
    try:
        from vllm.v1.worker import lora_model_runner_mixin
    except ImportError:
        lora_model_runner_mixin = None
    from unsloth_zoo import vllm_lora_worker_manager

    gemma4_lora_classes = {"Gemma4ForConditionalGeneration"}
    classes_to_patch = [Gemma4ForConditionalGeneration]
    try:
        # The text-only class is optional: best effort, skipped on any failure.
        from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
        gemma4_lora_classes.add("Gemma4ForCausalLM")
        classes_to_patch.append(Gemma4ForCausalLM)
    except Exception:
        pass

    for cls in classes_to_patch:
        if not getattr(cls, "_unsloth_gemma4_class_patched", False):
            cls.supports_lora = True
            cls.embedding_modules = {}
            cls._unsloth_gemma4_class_patched = True

    # getattr on None falls through to the default, so this also covers the
    # case where the v1 mixin module could not be imported.
    original_supports_lora = getattr(
        lora_model_runner_mixin, "supports_lora", vllm_model_interfaces.supports_lora
    )
    if not hasattr(original_supports_lora, "_unsloth_gemma4_patch"):
        @wraps(original_supports_lora)
        def patched_supports_lora(model):
            if model.__class__.__name__ in gemma4_lora_classes:
                return True
            return original_supports_lora(model)

        patched_supports_lora._unsloth_gemma4_patch = True
        if lora_model_runner_mixin is not None:
            lora_model_runner_mixin.supports_lora = patched_supports_lora
        vllm_model_interfaces.supports_lora = patched_supports_lora

    if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"):
        original_create_lora_manager = vllm_lora_model_manager.create_lora_manager

        @wraps(original_create_lora_manager)
        def patched_create_lora_manager(model, *args, **kwargs):
            if model.__class__.__name__ in gemma4_lora_classes:
                # Bypass the original factory and build the manager directly.
                lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager)
                return lora_manager_cls(model, *args, **kwargs)
            return original_create_lora_manager(model, *args, **kwargs)

        patched_create_lora_manager._unsloth_gemma4_patch = True
        vllm_lora_model_manager.create_lora_manager = patched_create_lora_manager
        vllm_lora_worker_manager.create_lora_manager = patched_create_lora_manager
pass
def patch_gemma4_vllm_k_eq_v_support():
    """Patch vLLM's BitsAndBytes loader for Gemma4 `attention_k_eq_v` layers.

    Prequantized BnB Gemma4 k_eq_v layers lack a synthetic V quant-state
    shard, so K's quant state is duplicated onto V when the loader stacks
    quantization states. The patch wraps
    `BitsAndBytesModelLoader._stack_quantization_states` and is applied at
    most once (guarded by the `_unsloth_gemma4_k_eq_v_patch` marker).

    Raises:
        ImportError: if vLLM's bitsandbytes loader cannot be imported.
    """
    from functools import wraps
    from vllm.model_executor.model_loader.bitsandbytes_loader import (
        BitsAndBytesModelLoader,
    )

    if hasattr(
        BitsAndBytesModelLoader._stack_quantization_states,
        "_unsloth_gemma4_k_eq_v_patch",
    ):
        return

    original_stack_quantization_states = (
        BitsAndBytesModelLoader._stack_quantization_states
    )

    def _get_gemma4_text_config(model):
        # Returns the text config when `model.config` identifies a Gemma4
        # model (by top-level or text-level model_type), otherwise None.
        config = getattr(model, "config", None)
        if config is None:
            return None

        text_config = getattr(config, "text_config", config)
        model_type = getattr(config, "model_type", None)
        text_model_type = getattr(text_config, "model_type", None)
        if model_type != "gemma4" and text_model_type not in ("gemma4", "gemma4_text"):
            return None
        return text_config

    def _get_gemma4_k_eq_v_pairs(model):
        # Yields one (kind, source, target) triple per full-attention layer:
        #   ("split",  k_proj weight name, v_proj weight name) or
        #   ("packed", qkv_proj weight name, None)
        # depending on which parameter name exists on the model.
        text_config = _get_gemma4_text_config(model)
        if text_config is None or not getattr(text_config, "attention_k_eq_v", False):
            return ()

        param_names = {name for name, _ in model.named_parameters()}
        pairs = []
        for layer_idx, layer_type in enumerate(getattr(text_config, "layer_types", ())):
            if layer_type != "full_attention":
                continue

            for prefix in ("language_model.model", "model"):
                k_name = f"{prefix}.layers.{layer_idx}.self_attn.k_proj.weight"
                v_name = f"{prefix}.layers.{layer_idx}.self_attn.v_proj.weight"
                qkv_name = f"{prefix}.layers.{layer_idx}.self_attn.qkv_proj.weight"
                if k_name in param_names:
                    pairs.append(("split", k_name, v_name))
                    break
                if qkv_name in param_names:
                    pairs.append(("packed", qkv_name, None))
                    break
        return tuple(pairs)

    @wraps(original_stack_quantization_states)
    def patched_stack_quantization_states(self, model, quant_state_dict):
        stacked_quant_state_dict = original_stack_quantization_states(
            self, model, quant_state_dict
        )

        for kind, source, target in _get_gemma4_k_eq_v_pairs(model):
            quant_states = stacked_quant_state_dict.get(source)
            if quant_states is None:
                continue

            # k_eq_v reuses K as V: the raw-weight loader already duplicates
            # k_proj -> v_proj, so prequant BnB needs the matching QuantState.
            if kind == "packed":
                # Packed qkv: shard index 1 is copied to the missing index 2.
                if isinstance(quant_states, dict) and 2 not in quant_states and 1 in quant_states:
                    quant_states[2] = deepcopy(quant_states[1])
            elif kind == "split":
                if target not in stacked_quant_state_dict:
                    stacked_quant_state_dict[target] = deepcopy(quant_states)

        return stacked_quant_state_dict

    patched_stack_quantization_states._unsloth_gemma4_k_eq_v_patch = True
    BitsAndBytesModelLoader._stack_quantization_states = (
        patched_stack_quantization_states
    )
pass
"model.language_model" + elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"): + language_model = new_model.model.language_model + language_model_prefix = "model.language_model" else: language_model_prefix = "model" language_model = new_model.model @@ -422,13 +587,13 @@ def set_additional_modules(new_model, quant_state_dict, config): # freeze = True, # padding_idx = pad_token_id, # ) - # we cannot use the normal embedding init because gemma3 uses Gemma3TextScaledWordEmbedding which wraps around nn.Embedding and has a scaling factor. This new init ensures that we respect the forward from original model. + # gemma3 uses Gemma3TextScaledWordEmbedding (nn.Embedding subclass with + # an embed_scale); in-place weight assignment preserves its forward. def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): num_embeddings, embedding_dim = quant_state_dict[embed_tokens_key].shape - embeddings = quant_state_dict[embed_tokens_key] + embeddings = _unwrap_tensor(quant_state_dict[embed_tokens_key]) if isinstance(embeddings, torch.Tensor): - # in the newer vLLM versions, this seems to return a tensor which can't be assigned to embedding weight - # we need to convert that to nn.Paramter and then pass it on + # Newer vLLM returns a plain tensor; wrap it so it can be assigned. 
embeddings = torch.nn.Parameter(embeddings, requires_grad = requires_grad) module.weight = embeddings module.padding_idx = pad_token_id @@ -444,6 +609,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): # Norm norm_key = f"{language_model_prefix}.norm.weight" norm = quant_state_dict[norm_key] + norm = _unwrap_tensor(norm) norm = torch.nn.Parameter(norm, requires_grad = False) language_model.norm.weight = norm @@ -456,42 +622,38 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): else: lmhead_key = "lm_head.weight" - # Check if lm_head exists in the state dict if lmhead_key in quant_state_dict: - weight = quant_state_dict[lmhead_key] + weight = _unwrap_tensor(quant_state_dict[lmhead_key]) from torch.nn import Linear - # Create Linear layer with zero dimensions to avoid any weight allocation + # Zero-dim Linear skips default weight allocation before we assign the real one. layer = Linear(0, 0, device=weight.device, bias=False) - # Set correct dimensions layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - # Assign the weight directly (no deletion needed since no weight was allocated) layer.weight = torch.nn.Parameter(weight, requires_grad=False) - # Set lm_head at the correct level if hasattr(new_model, "lm_head"): new_model.lm_head = layer + elif hasattr(language_model, "lm_head"): + language_model.lm_head = layer else: - # For multimodal models, check if language_model has lm_head - if hasattr(language_model, "lm_head"): - language_model.lm_head = layer - else: - new_model.lm_head = layer + new_model.lm_head = layer if getattr(config, "tie_word_embeddings", False): - # For tied embeddings, tie the weights properly if hasattr(new_model, "tie_weights"): new_model.tie_weights() elif hasattr(language_model, "tie_weights"): language_model.tie_weights() - # Process additional keys - # For any layers that are potentially in non layered components. 
def _cast_rotary_buffers_to_fp32(rotary_emb, device):
    """Move every floating-point buffer of `rotary_emb` to `device` as float32."""
    for buffer_name, buffer in list(rotary_emb._buffers.items()):
        if torch.is_tensor(buffer) and buffer.is_floating_point():
            rotary_emb._buffers[buffer_name] = buffer.to(
                device = device,
                dtype = torch.float32,
            )
pass


@torch.inference_mode
def finalize_huggingface_model(
    new_model,
    original_meta_model,
    config,
    dtype,
    quantization_config = None,
    bnb_config = None,
):
    """Apply the final fix-ups to a rebuilt HuggingFace model and return it.

    Steps, in order:
      1. If `original_meta_model` is given, `copy_attributes` is called to
         copy from it onto `new_model`.
      2. `layer_idx` on each language-model layer (and on its `self_attn`,
         `cross_attn`, `mlp`, `linear_attn` submodules when they carry one)
         is rewritten to the layer's position in `layers`.
      3. `set_dtype_in_config` is applied to the model's live config (and its
         text/vision/audio sub-configs) when those are different objects from
         `config`, and to every module-level `config` that is one of the
         known config objects.
      4. Rotary embedding modules (`rotary_emb`, `rotary_pos_emb`,
         `rotary_emb_local`) are re-instantiated on the model's device, and
         floating-point `rotary_emb` buffers are kept in float32.
      5. When no quantization config is supplied (`quantization_config` empty
         and `bnb_config` is None), the whole model is cast to `dtype` and the
         `rotary_emb` buffers are restored to float32 afterwards.

    Args:
        new_model: the model to finalize (modified in place).
        original_meta_model: optional source model for `copy_attributes`.
        config: the config the model was built from; may have `text_config`,
            `vision_config`, `audio_config` sub-configs.
        dtype: target dtype written into configs and used for the final cast.
        quantization_config: falsy / empty when the model is not quantized.
        bnb_config: None when the model is not bitsandbytes-quantized.

    Returns:
        The finalized model (the result of `.to(...)` when a cast happened).
    """
    if original_meta_model is not None:
        copy_attributes(original_meta_model, new_model)

    # Locate the language-model stack for the three layouts handled here.
    if hasattr(new_model, "language_model"):
        lm_root = new_model.language_model
    elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"):
        lm_root = new_model.model.language_model
    else:
        lm_root = getattr(new_model, "model", None)

    if lm_root is not None and hasattr(lm_root, "layers"):
        for layer_idx, layer in enumerate(lm_root.layers):
            if hasattr(layer, "layer_idx"):
                layer.layer_idx = layer_idx
            for attr_name in ("self_attn", "cross_attn", "mlp", "linear_attn"):
                submodule = getattr(layer, attr_name, None)
                if submodule is not None and hasattr(submodule, "layer_idx"):
                    submodule.layer_idx = layer_idx

    sub_config_names = ("text_config", "vision_config", "audio_config")

    # Config objects are tracked by identity, not by equality.
    known_configs = {id(config)}
    for sub_name in sub_config_names:
        sub_cfg = getattr(config, sub_name, None)
        if sub_cfg is not None:
            known_configs.add(id(sub_cfg))

    live_root = getattr(new_model, "config", None)
    if live_root is not None and id(live_root) not in known_configs:
        set_dtype_in_config(live_root, dtype)
        known_configs.add(id(live_root))
        for sub_name in sub_config_names:
            sub_cfg = getattr(live_root, sub_name, None)
            if sub_cfg is not None and id(sub_cfg) not in known_configs:
                set_dtype_in_config(sub_cfg, dtype)
                known_configs.add(id(sub_cfg))

    for module in new_model.modules():
        module_config = getattr(module, "config", None)
        if module_config is not None and id(module_config) in known_configs:
            set_dtype_in_config(module_config, dtype)

    target_device = _get_model_device(new_model)
    text_config = getattr(config, "text_config", config)
    vision_config = getattr(config, "vision_config", None)

    vision_config_ids = set()
    if vision_config is not None:
        vision_config_ids.add(id(vision_config))
    live_vision_config = getattr(live_root, "vision_config", None) if live_root is not None else None
    if live_vision_config is not None:
        vision_config_ids.add(id(live_vision_config))

    local_rope_config = None
    for module_name, module in new_model.named_modules():
        # A `rotary_emb = None` attribute is skipped, matching the post-cast
        # loop below (the previous hasattr check would crash on it).
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is not None:
            current_rotary_config = getattr(rotary_emb, "config", None)
            is_vision_rotary = vision_config is not None and (
                "vision_tower" in module_name
                or "vision_model" in module_name
                or (current_rotary_config is not None and id(current_rotary_config) in vision_config_ids)
            )
            rotary_config = vision_config if is_vision_rotary else text_config
            try:
                module.rotary_emb = rotary_emb.__class__(
                    config = rotary_config,
                    device = target_device,
                )
            except Exception:
                # Best effort: keep the existing rotary module if its class
                # cannot be rebuilt from (config, device).
                pass
            _cast_rotary_buffers_to_fp32(module.rotary_emb, target_device)
        if hasattr(module, "rotary_pos_emb") and vision_config is not None:
            head_dim = vision_config.hidden_size // vision_config.num_heads
            module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device)
        if hasattr(module, "rotary_emb_local"):
            if local_rope_config is None:
                # Built once and shared by every local rotary embedding.
                local_rope_config = deepcopy(text_config)
                local_rope_config.rope_theta = text_config.rope_local_base_freq
                local_rope_config.rope_scaling = {"rope_type": "default"}
            module.rotary_emb_local = module.rotary_emb_local.__class__(
                config = local_rope_config,
                device = target_device,
            )

    if (quantization_config or {}) == {} and bnb_config is None:
        new_model = new_model.to(device = target_device, dtype = dtype)
        # The cast above also converts rotary buffers; restore them to float32.
        # NOTE(review): only `rotary_emb` is restored here, not
        # `rotary_emb_local` / `rotary_pos_emb` - confirm that is intended.
        for module in new_model.modules():
            rotary_emb = getattr(module, "rotary_emb", None)
            if rotary_emb is None:
                continue
            _cast_rotary_buffers_to_fp32(rotary_emb, target_device)
    return new_model
pass
"model.language_model.layers.{kk}.mlp.down_proj", + "model.layers.{kk}.layer_scalar", "model.layers.{kk}.self_attn.q_proj", "model.layers.{kk}.self_attn.k_proj", "model.layers.{kk}.self_attn.v_proj", @@ -539,6 +815,29 @@ def get_model_layer_config(return_non_layered=True): "model.layers.{kk}.mlp.up_proj", "model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture) "model.layers.{kk}.mlp.down_proj", + "model.language_model.layers.{kk}.linear_attn.in_proj_qkv", + "model.language_model.layers.{kk}.linear_attn.in_proj_z", + "model.language_model.layers.{kk}.linear_attn.in_proj_b", + "model.language_model.layers.{kk}.linear_attn.in_proj_a", + "model.language_model.layers.{kk}.linear_attn.conv1d", + "model.language_model.layers.{kk}.linear_attn.out_proj", + "model.language_model.layers.{kk}.linear_attn.dt_bias", + "model.language_model.layers.{kk}.linear_attn.A_log", + + "model.layers.{kk}.linear_attn.in_proj_qkv", + "model.layers.{kk}.linear_attn.in_proj_z", + "model.layers.{kk}.linear_attn.in_proj_b", + "model.layers.{kk}.linear_attn.in_proj_a", + "model.layers.{kk}.linear_attn.conv1d", + "model.layers.{kk}.linear_attn.out_proj", + "model.layers.{kk}.linear_attn.dt_bias", + "model.layers.{kk}.linear_attn.A_log", + + # Gemma4 per-layer input modules + "model.language_model.layers.{kk}.per_layer_input_gate", + "model.language_model.layers.{kk}.per_layer_projection", + "model.layers.{kk}.per_layer_input_gate", + "model.layers.{kk}.per_layer_projection", }, 'layernorms': { "model.language_model.layers.{kk}.input_layernorm", @@ -560,6 +859,12 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.vision_model.encoder.layers.{kk}.post_layernorm", "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm1", "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm2", + "model.vision_tower.encoder.layers.{kk}.input_layernorm", + "model.vision_tower.encoder.layers.{kk}.post_attention_layernorm", + 
"model.vision_tower.encoder.layers.{kk}.pre_feedforward_layernorm", + "model.vision_tower.encoder.layers.{kk}.post_feedforward_layernorm", + "model.vision_tower.encoder.layers.{kk}.self_attn.q_norm", + "model.vision_tower.encoder.layers.{kk}.self_attn.k_norm", # Mistral3 vision norms "model.vision_tower.transformer.layers.{kk}.attention_norm", @@ -567,6 +872,12 @@ def get_model_layer_config(return_non_layered=True): # qwen3 vl "model.visual.deepstack_merger_list.{kk}.norm", + "model.language_model.layers.{kk}.linear_attn.norm", + "model.layers.{kk}.linear_attn.norm", + + # Gemma4 per-layer input norm + "model.language_model.layers.{kk}.post_per_layer_input_norm", + "model.layers.{kk}.post_per_layer_input_norm", }, 'vision_layers': { @@ -610,6 +921,13 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.vision_model.encoder.layers.{kk}.mlp.fc1", "model.vision_tower.vision_model.encoder.layers.{kk}.mlp.fc2", + "model.vision_tower.encoder.layers.{kk}.self_attn.q_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.k_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.v_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.o_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.gate_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.up_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.down_proj.linear", # qwen2.5_vl style "model.visual.blocks.{kk}.attn.qkv", @@ -654,12 +972,13 @@ def get_model_layer_config(return_non_layered=True): # qwen 3 vl "model.visual.deepstack_merger_list.{kk}.linear_fc1", "model.visual.deepstack_merger_list.{kk}.linear_fc2", - "model.visual.merger.linear_fc{kk}", }, "non_layered_components":{ # we do not handle quantization for these layers yet # the set_additional_modules would process these layers + "model.visual.merger.linear_fc1", + "model.visual.merger.linear_fc2", "model.multi_modal_projector", "model.language_model.norm", 
'model.vision_model.layernorm_pre', @@ -685,10 +1004,20 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.patch_positional_embedding", "model.vision_tower.patch_conv", "model.vision_tower.ln_pre", + "model.vision_tower.std_bias", + "model.vision_tower.std_scale", + "model.vision_tower.patch_embedder.position_embedding_table", + "model.vision_tower.patch_embedder.input_proj", + "model.embed_vision.embedding_projection", # qwen 3 vl "model.visual.pos_embed", "model.visual.merger.norm", + + # Gemma4 top-level per-layer-input modules + "model.language_model.embed_tokens_per_layer", + "model.language_model.per_layer_model_projection", + "model.language_model.per_layer_projection_norm", } } @@ -732,6 +1061,11 @@ def get_model_layer_counts(config): "vision_layers": getattr(config.vision_config, "depth", 27), "deepstack_layers": getattr(config.vision_config, "deepstack_depth", 3), } + elif model_type == "gemma4": + return { + "text_layers": getattr(config.text_config, "num_hidden_layers", 32), + "vision_layers": getattr(config.vision_config, "num_hidden_layers", 32), + } elif model_type == "gemma3": return { "text_layers": getattr(config.text_config, "num_hidden_layers", 32), @@ -759,6 +1093,152 @@ def _get_nested_attr(obj, attr_path: str): return None +def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_state_dict): + gdn = gdn_module + + def _unwrap(v): + return getattr(v, "data", v) + + def store(name, value): + state_dict[name] = value + quant_state_dict[name] = value + + def _store_quant_state(name, quant_state): + if quant_state is None: + return + quant_state_dict[f"{name}.weight.quant_state"] = quant_state + try: + for k, v in quant_state.as_dict(packed=True).items(): + state_dict[f"{name}.weight.{k}"] = v + except Exception: + pass + + if hasattr(gdn, "in_proj_qkvz"): + proj = getattr(gdn.in_proj_qkvz, "base_layer", gdn.in_proj_qkvz) + raw_weight = proj.weight + weight = _unwrap(raw_weight) + + output_sizes = 
getattr(proj, "output_sizes", None) + if output_sizes is None: + key_dim = getattr(gdn, "key_dim", None) + value_dim = getattr(gdn, "value_dim", None) + if key_dim is None or value_dim is None: + raise RuntimeError( + "Unsloth: cannot infer GDN in_proj_qkvz shards without " + "proj.output_sizes or gdn.key_dim / gdn.value_dim" + ) + output_sizes = [key_dim, key_dim, value_dim, value_dim] + output_sizes = list(output_sizes) + offsets = [0] + for s in output_sizes: + offsets.append(offsets[-1] + s) + if len(offsets) < 5: + raise RuntimeError( + f"Unsloth: GDN in_proj_qkvz expected 4 shards (q,k,v,z); got sizes={output_sizes}" + ) + + qkv_weight = weight[offsets[0]:offsets[3]] + z_weight = weight[offsets[3]:offsets[4]] + + qs_attr = getattr(raw_weight, "bnb_quant_state", getattr(weight, "bnb_quant_state", None)) + qkv_states = [qs_attr.get(i) for i in (0, 1, 2)] if isinstance(qs_attr, dict) else [None, None, None] + if sum(qs is not None for qs in qkv_states) > 1: + try: + from bitsandbytes.functional import dequantize_4bit + except Exception: + raise RuntimeError( + "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for fused in_proj_qkv reconstruction." 
+ ) + parts = [] + for i, qs in enumerate(qkv_states): + shard = weight[offsets[i]:offsets[i + 1]] + parts.append(dequantize_4bit(shard, quant_state=qs) if qs is not None else shard) + store(f"{prefix}.in_proj_qkv.weight", torch.cat(parts, dim=0)) + else: + store(f"{prefix}.in_proj_qkv.weight", qkv_weight) + if isinstance(qs_attr, dict): + _store_quant_state(f"{prefix}.in_proj_qkv", qkv_states[0]) + store(f"{prefix}.in_proj_z.weight", z_weight) + if isinstance(qs_attr, dict): + _store_quant_state(f"{prefix}.in_proj_z", qs_attr.get(3)) + + if weight.dtype == torch.float8_e4m3fn: + scale_attr = None + if hasattr(proj, "weight_scale"): + scale_attr = "weight_scale" + elif hasattr(proj, "weight_scale_inv"): + scale_attr = "weight_scale_inv" + ws = _unwrap(getattr(proj, scale_attr)) if scale_attr is not None else None + if ws is not None: + if ws.ndim == 2 and ws.shape[1] > 1: + block_size = proj.weight_block_size[0] + scale_offsets = [x // block_size for x in offsets] + qkv_scale = ws[scale_offsets[0]:scale_offsets[3]] + z_scale = ws[scale_offsets[3]:scale_offsets[4]] + else: + qkv_scale = ws[offsets[0]:offsets[3]] + z_scale = ws[offsets[3]:offsets[4]] + store(f"{prefix}.in_proj_qkv.{scale_attr}", qkv_scale) + store(f"{prefix}.in_proj_z.{scale_attr}", z_scale) + else: + get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) + get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) + + ba_layer = getattr(gdn.in_proj_ba, "base_layer", gdn.in_proj_ba) + raw_ba_weight = ba_layer.weight + ba_weight = _unwrap(raw_ba_weight) + mid = ba_weight.shape[0] // 2 + + ba_qs = getattr(raw_ba_weight, "bnb_quant_state", getattr(ba_weight, "bnb_quant_state", None)) + ba_states = [ba_qs.get(i) for i in (0, 1)] if isinstance(ba_qs, dict) else [None, None] + if isinstance(ba_qs, dict) and ba_states[0] is not None and ba_states[1] is None: + try: + from bitsandbytes.functional import dequantize_4bit + except Exception: + 
raise RuntimeError( + "Unsloth: prequantized BnB Qwen3.5 GDN requires bitsandbytes for in_proj_ba split." + ) + full = dequantize_4bit(ba_weight, quant_state=ba_states[0]) + full_mid = full.shape[0] // 2 + store(f"{prefix}.in_proj_b.weight", full[:full_mid]) + store(f"{prefix}.in_proj_a.weight", full[full_mid:]) + else: + store(f"{prefix}.in_proj_b.weight", ba_weight[:mid]) + store(f"{prefix}.in_proj_a.weight", ba_weight[mid:]) + if isinstance(ba_qs, dict): + _store_quant_state(f"{prefix}.in_proj_b", ba_states[0]) + _store_quant_state(f"{prefix}.in_proj_a", ba_states[1]) + + if ba_weight.dtype == torch.float8_e4m3fn: + scale_attr = None + if hasattr(ba_layer, "weight_scale"): + scale_attr = "weight_scale" + elif hasattr(ba_layer, "weight_scale_inv"): + scale_attr = "weight_scale_inv" + ws = _unwrap(getattr(ba_layer, scale_attr)) if scale_attr is not None else None + if ws is not None: + if ws.ndim == 2 and ws.shape[1] > 1: + block_size = ba_layer.weight_block_size[0] + scale_mid = mid // block_size + b_scale = ws[:scale_mid] + a_scale = ws[scale_mid:] + else: + b_scale = ws[:mid] + a_scale = ws[mid:] + store(f"{prefix}.in_proj_b.{scale_attr}", b_scale) + store(f"{prefix}.in_proj_a.{scale_attr}", a_scale) + + store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) + store(f"{prefix}.dt_bias", gdn.dt_bias.data) + store(f"{prefix}.A_log", gdn.A_log.data) + + if hasattr(gdn, "norm") and hasattr(gdn.norm, "weight"): + store(f"{prefix}.norm.weight", gdn.norm.weight.data) + + get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) +pass + + def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_state_dict): """ Extracts vision layers for any supported vision model by dynamically using @@ -790,7 +1270,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if layer_module is not None: if "qkv" in layer_path: - if model_type in ("qwen2_5_vl", "qwen3_vl"): + if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5"): 
# If the HF model too prefers having merged qkv, we do this # This is evident in qwen-2.5-vl and qwen-3-vl so far. get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False) @@ -809,7 +1289,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if isinstance(layer_module, torch.nn.Module): if hasattr(layer_module, 'weight'): get_state_dict(layer_path, 0, state_dict, layer_module) - elif isinstance(layer_module, torch.nn.Parameter): + elif isinstance(layer_module, torch.Tensor): state_dict[f"{layer_path}"] = layer_module.data quant_state_dict[f"{layer_path}"] = state_dict[f"{layer_path}"] else: @@ -824,7 +1304,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if hasattr(component, 'weight'): # Prefer using get_state_dict when possible get_state_dict(component_path, 0, state_dict, component) - elif isinstance(component, torch.nn.Parameter): + elif isinstance(component, torch.Tensor): state_dict[component_path] = component.data quant_state_dict[component_path] = component.data elif isinstance(component, torch.nn.Module): diff --git a/unsloth_zoo/hf_utils.py b/unsloth_zoo/hf_utils.py index cb96a9c51..4bd337dda 100644 --- a/unsloth_zoo/hf_utils.py +++ b/unsloth_zoo/hf_utils.py @@ -295,7 +295,7 @@ def get_auto_processor(name, **kwargs): with open(processor_config, "r", encoding="utf-8") as f: config = json.load(f) processor_class = config["processor_class"] - # Strip _Unsloth_Patched_ prefix from old saves (issue #4085) + # Strip _Unsloth_Patched_ prefix from old saves (unsloth issue 4085) if processor_class.startswith("_Unsloth_Patched_"): processor_class = processor_class[len("_Unsloth_Patched_"):] model_type = reversal_map[processor_class] @@ -345,7 +345,7 @@ def get_auto_processor(name, **kwargs): pass pass - # Fix _Unsloth_Patched_ prefix in copied config files (issue #4085) + # Fix _Unsloth_Patched_ prefix in copied config files (unsloth issue 4085) for cfg_name in 
["processor_config.json", "preprocessor_config.json", "tokenizer_config.json"]: cfg_path = os.path.join(temp_name, cfg_name) if os.path.exists(cfg_path): diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 4d77c88a5..8f46546b4 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1063,6 +1063,15 @@ def get_state_dict(prefix, kk, state_dict, proj, slice_weights=True, slice_index quant_state_dict[prefix + ".bias"] = bias_tensor pass + if model_type == "gemma4" and getattr(text_config, "attention_k_eq_v", False): + gemma4_k_eq_v_layers = { + kk + for kk, layer_type in enumerate(getattr(text_config, "layer_types", ())) + if layer_type == "full_attention" + } + else: + gemma4_k_eq_v_layers = set() + # Embedding if hasattr(vllm_internals, "model"): # Standard Language models vllm_text_model = vllm_internals.model @@ -1107,7 +1116,9 @@ def _is_fused_module(name: str) -> bool: else: get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj) - get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) + if kk not in gemma4_k_eq_v_layers: + get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) + get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) elif hasattr(layer, "cross_attn"): prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn" qkv_proj = layer.cross_attn.qkv_proj @@ -1119,22 +1130,26 @@ def _is_fused_module(name: str) -> bool: get_state_dict(f"{prefix}.q_proj", 0, state_dict, q_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, kv_proj) get_state_dict(f"{prefix}.v_proj", 2, state_dict, kv_proj) + get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) + elif hasattr(layer, "linear_attn"): + # Qwen3.5 Gated Delta Net (GDN) linear attention layers + extract_gdn_layers( + layer.linear_attn, + f"{vllm_text_model_prefix}.layers.{kk}.linear_attn", + state_dict, quant_state_dict, get_state_dict, + ) + pass - get_state_dict(f"{prefix}.o_proj", 0, 
state_dict, o_proj) - - proj = layer.mlp.gate_up_proj - use_fused_gate_up = _is_fused_module("gate_up_proj") - if use_fused_gate_up: - # For some model types like phi3 vllm will expect fused gate_up_proj (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct) - # so we should not split them here otherwise there will be a size mismatch when activating the adapter - # see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py - get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_up_proj", 0, state_dict, proj, slice_weights=False) - else: - get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_proj", 0, state_dict, proj) - get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.up_proj", 1, state_dict, proj) - - proj = layer.mlp.down_proj - get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.down_proj", 0, state_dict, proj) + if hasattr(layer, "per_layer_input_gate"): + get_state_dict( + f"{vllm_text_model_prefix}.layers.{kk}.per_layer_input_gate", + 0, state_dict, layer.per_layer_input_gate, + ) + if hasattr(layer, "per_layer_projection"): + get_state_dict( + f"{vllm_text_model_prefix}.layers.{kk}.per_layer_projection", + 0, state_dict, layer.per_layer_projection, + ) # Use layernorms from the layer configuration layernorm_names = [name.format(kk=kk) for name in layer_config['layernorms']] @@ -1149,6 +1164,27 @@ def _is_fused_module(name: str) -> bool: except Exception as e: skipped_layernorms.append(layernorm_name.split(".")[-1]) pass + + if hasattr(layer, "layer_scalar"): + state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + + if not hasattr(layer, "mlp"): + continue + + proj = layer.mlp.gate_up_proj + use_fused_gate_up = _is_fused_module("gate_up_proj") + if use_fused_gate_up: + # For some model types like phi3 vllm will expect fused 
gate_up_proj (e.g. Phi3, Phi3.5-mini-instruct, Phi4-mini-instruct) + # so we should not split them here otherwise there will be a size mismatch when activating the adapter + # see https://github.com/vllm-project/vllm/blob/9b693d023cf595e60b5346fdeeb41cf2a6eda838/vllm/model_executor/models/phi3.py + get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_up_proj", 0, state_dict, proj, slice_weights=False) + else: + get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.gate_proj", 0, state_dict, proj) + get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.up_proj", 1, state_dict, proj) + + proj = layer.mlp.down_proj + get_state_dict(f"{vllm_text_model_prefix}.layers.{kk}.mlp.down_proj", 0, state_dict, proj) pass if len(skipped_layernorms) != 0: @@ -1165,11 +1201,26 @@ def _is_fused_module(name: str) -> bool: state_dict[norm_prefix] = vllm_text_model.norm.weight.data quant_state_dict[norm_prefix] = state_dict[norm_prefix] + # Gemma4 top-level per-layer-input modules + for extra_name in ("embed_tokens_per_layer", "per_layer_model_projection", "per_layer_projection_norm"): + component = getattr(vllm_text_model, extra_name, None) + if component is None: + continue + prefix = f"{vllm_text_model_prefix}.{extra_name}" + if hasattr(component, "weight"): + get_state_dict(prefix, 0, state_dict, component, slice_weights=False) + else: + for param_name, param in component.named_parameters(): + key = f"{prefix}.{param_name}" + state_dict[key] = param.data + quant_state_dict[key] = param.data + # LM Head - Use get_state_dict for consistency if not getattr(text_config, "tie_word_embeddings", False): - lm_layer = [mod for name,mod in vllm_internals.named_modules() if "lm_head" in name] - # Use get_state_dict for consistent extraction and automatic truncation - get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False) + lm_layer = next((mod for name, mod in vllm_internals.named_modules() if name == "lm_head" or name.endswith(".lm_head")), None) + if 
lm_layer is None: + raise RuntimeError("Unsloth: could not find lm_head in vLLM internals") + get_state_dict("lm_head", 0, state_dict, lm_layer, slice_weights=False) else: # Fallback to embed_tokens for tied embeddings embed_key = f"{vllm_text_model_prefix}.embed_tokens.weight" @@ -1186,8 +1237,16 @@ def _is_fused_module(name: str) -> bool: @torch.inference_mode def assert_same_state_dict(old_state_dict, new_state_dict): # All Unsloth Zoo code licensed under LGPLv3 - # Check if state_dict are equivalent - # hf, vllm + # args: hf, vllm + + def _normalize_state_dict_tensor(value): + if isinstance(value, torch.nn.Parameter): + value = value.detach() + if not isinstance(value, torch.Tensor): + return None + if value.is_sparse: + value = value.to_dense() + return value.contiguous() difference = new_state_dict.keys() ^ old_state_dict.keys() difference -= set(("model.lm_head.weight","model.language_model.lm_head.weight", "lm_head.weight")) @@ -1202,13 +1261,18 @@ def assert_same_state_dict(old_state_dict, new_state_dict): for key in old_state_dict: try: - old_val = old_state_dict[key] - new_val = new_state_dict[key] - if old_val.dtype != new_val.dtype or (new_val.element_size() < 2): + old_val = _normalize_state_dict_tensor(old_state_dict[key]) + new_val = _normalize_state_dict_tensor(new_state_dict[key]) + if old_val is None or new_val is None: + continue + loose_tol = old_val.dtype != new_val.dtype or (new_val.element_size() < 2) + if loose_tol: # upcast both to float32 just for comparison. 
For FP8, vLLM stores weight scale in FP32 while HF preferes 16bit old_val = old_val.to(torch.float32) new_val = new_val.to(torch.float32) - torch.testing.assert_close(old_val, new_val, check_stride = False, atol = 1e-4, rtol = 1e-3) + torch.testing.assert_close(old_val, new_val, check_stride = False, atol = 1e-4, rtol = 1e-3) + else: + torch.testing.assert_close(old_val, new_val, check_stride = False) except Exception as error: if key == "lm_head.weight": # Try tied embeddings fallback @@ -1217,7 +1281,13 @@ def assert_same_state_dict(old_state_dict, new_state_dict): if key1 is not None and key2 is not None: try: - torch.testing.assert_close(old_state_dict[key1].contiguous(), new_state_dict[key2].contiguous(), check_stride = True) + torch.testing.assert_close( + _normalize_state_dict_tensor(old_state_dict[key1]), + _normalize_state_dict_tensor(new_state_dict[key2]), + check_stride = False, + atol = 1e-4, + rtol = 1e-3, + ) except Exception: failures[key] = error else: @@ -1235,7 +1305,14 @@ def assert_same_state_dict(old_state_dict, new_state_dict): def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, bnb_config = None, is_vision_model = False): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules to create HF compatible model + def _unwrap_tensor(value): + return getattr(value, "data", value) + set_dtype_in_config(config, dtype) + for subconfig_name in ("text_config", "vision_config", "audio_config"): + subconfig = getattr(config, subconfig_name, None) + if subconfig is not None: + set_dtype_in_config(subconfig, dtype) new_model, original_meta_model, layer_count, layer_names = create_empty_model(config, dtype, is_vision_model) new_model = new_model.to(device = get_target_device(), dtype = dtype) quantization_config = getattr(config, "quantization_config", {}) @@ -1331,16 +1408,14 @@ def _override_to(self, *args, **kwargs): weight = quant_state_dict[f"{layer_name}.weight"] if f"{layer_name}.bias" in quant_state_dict: - # 
Has bias! has_bias = True - bias = quant_state_dict[f"{layer_name}.bias"] + bias = _unwrap_tensor(quant_state_dict[f"{layer_name}.bias"]) bias = torch.nn.Parameter(bias, requires_grad = False) else: has_bias = False bias = None pass - # check if either of layer_name.weight_scale or layer_name.weight_scale_inv exists and set that attribute to fp8_weight_scale fp8_weight_scale = None if f"{layer_name}.weight_scale" in quant_state_dict: fp8_weight_scale = quant_state_dict[f"{layer_name}.weight_scale"] @@ -1352,9 +1427,15 @@ def _override_to(self, *args, **kwargs): if layer_name in quant_state_dict: # for attributes of type nn.Parameter, there's no .weight - layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1)) - layer = torch.nn.Parameter(weight, requires_grad = False) - exec(f"new_model.{layer_name_br} = layer") + layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) + raw_value = _unwrap_tensor(weight) + parent_path, _, attr_name = layer_name_br.rpartition(".") + parent = eval(f"new_model.{parent_path}") if parent_path else new_model + if attr_name in getattr(parent, "_buffers", {}): + parent._buffers[attr_name] = raw_value + else: + layer = torch.nn.Parameter(raw_value, requires_grad = False) + exec(f"new_model.{layer_name_br} = layer") continue elif fp8_weight_scale is not None: if fp8_weight_scale.ndim == 1: @@ -1362,7 +1443,7 @@ def _override_to(self, *args, **kwargs): layer = FbgemmFp8Linear(in_features = 0, out_features = 0, bias = has_bias, weight_dtype = dtype).to(get_target_device()) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias layer.input_scale_ub = kwargs['input_scale_ub'] layer.weight_scale = torch.nn.Parameter(fp8_weight_scale, requires_grad = False) @@ -1378,12 +1459,11 @@ def _override_to(self, *args, **kwargs): 
layer = FP8Linear(**fp8_kwargs) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias layer.weight_scale_inv = torch.nn.Parameter(fp8_weight_scale, requires_grad = False) layer.quant_method = "fp8" elif f"{layer_name}.weight.quant_state" in quant_state_dict: - # Layer is quantized! quant_state = quant_state_dict[f"{layer_name}.weight.quant_state"] layer = Linear4bit(0, 0, device = get_target_device(), bias = has_bias, compute_dtype = compute_dtype, **kwargs) layer.in_features = quant_state.shape[1] @@ -1396,22 +1476,37 @@ def _override_to(self, *args, **kwargs): layer.to = partial(_override_to, layer) layer.weight.to = partial(_override_to, layer.weight) + elif layer_name.endswith(".conv1d") and "linear_attn" in layer_name: + # Qwen3.5 GDN depthwise Conv1d: rebuild with real channels/kernel/groups. + from torch.nn import Conv1d + conv_weight = _unwrap_tensor(weight) + channels = conv_weight.shape[0] + kernel_size = conv_weight.shape[-1] + layer = Conv1d( + in_channels = channels, + out_channels = channels, + kernel_size = kernel_size, + groups = channels, + padding = kernel_size - 1, + bias = has_bias, + device = get_target_device(), + ) + layer.weight = torch.nn.Parameter(conv_weight, requires_grad = False) + layer.bias = bias elif not any(x in layer_name for x in layernorm_names): layer = Linear(0, 0, device = get_target_device(), bias = has_bias) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] # from vllm 0.11.1, the .weight is of dtype ModelWeightParameter, so try to extract the 'data' part # https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170#diff-7d6145ac4ba084231a441c2056c7fca23c3bae33e6542f4f602a6c9d4d2da64dL199-R208 - layer.weight = torch.nn.Parameter(getattr(weight, 'data', weight), requires_grad = False) + 
layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias else: # LayerNorms (including vision norms) - weight_param = torch.nn.Parameter(weight, requires_grad=False) + weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) - # Set weight exec(f"new_model.{layer_name_br}.weight = None") exec(f"new_model.{layer_name_br}.weight = weight_param") - # Set bias if it exists if bias is not None: exec(f"new_model.{layer_name_br}.bias = None") exec(f"new_model.{layer_name_br}.bias = bias") @@ -1425,49 +1520,14 @@ def _override_to(self, *args, **kwargs): pass set_additional_modules(new_model, quant_state_dict, config) - - if original_meta_model is not None: - copy_attributes(original_meta_model, new_model) - - # # Set config on model and modules using clean approach - # new_model.config = config - # for module in new_model.modules(): - # if hasattr(module, "config"): - # module.config = config - # for param in new_model.parameters(): - # if hasattr(param, "config"): - # param.config = config - - text_config = getattr(config, "text_config", config) #try using text config for VLMs - vision_config = getattr(config, "vision_config", None) - # Fix up rotary_emb by re-initing them - for module in new_model.modules(): - if hasattr(module, "rotary_emb"): - module.rotary_emb = module.rotary_emb.__class__( - config = text_config, - device = get_target_device(), - ) - if hasattr(module, "rotary_pos_emb"): - # Qwen 2.5 VL has a rotary_pos_emb in vision submodel - # https://github.com/huggingface/transformers/blob/a871f6f58d49f3a05ae9dae519caa8aa9d919a07/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L337 - assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" - head_dim = vision_config.hidden_size // vision_config.num_heads - module.rotary_pos_emb = 
module.rotary_pos_emb.__class__(head_dim//2).to(get_target_device()) - if hasattr(module, "rotary_emb_local"): - # gemma3 has a rotary_emb_local - # https://github.com/huggingface/transformers/blob/008c0ba8e2a1226a6ef5a61c4915a0a8a340c157/src/transformers/models/gemma3/modeling_gemma3.py#L469-L471 - # Gemma3 uses different defaults for local and global RoPE. Copy the config for modification. - local_rope_config = deepcopy(text_config) - local_rope_config.rope_theta = text_config.rope_local_base_freq - local_rope_config.rope_scaling = {"rope_type": "default"} - # gemma3 has a rotary_emb_local - module.rotary_emb_local = module.rotary_emb_local.__class__( - config = local_rope_config, - device = get_target_device(), - ) - del local_rope_config - pass - pass + new_model = finalize_huggingface_model( + new_model, + original_meta_model, + config, + dtype, + quantization_config = quantization_config, + bnb_config = bnb_config, + ) # Must override or else Bitsandbytes will error new_model.to = partial(_override_to, new_model) @@ -1771,6 +1831,12 @@ def load_vllm( assert(type(use_bitsandbytes) is bool) assert(conservativeness >= 0.0 and conservativeness <= 1.0) + if getattr(config, "model_type", None) == "gemma4": + if enable_lora: + patch_gemma4_vllm_lora_support() + if use_bitsandbytes: + patch_gemma4_vllm_k_eq_v_support() + unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0") # This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful. 
standby_util_override = os.getenv("UNSLOTH_VLLM_STANDBY_UTIL_OVERRIDE", "0") != "0" @@ -2866,10 +2932,27 @@ def _test_is_same_vlm(model, new_model, processor, test_backward=False): messages, tokenize=False, add_generation_prompt=True ) - inputs = processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, - return_dict=True, return_tensors="pt" - ).to(model.device, dtype=model.dtype) + if processor.__class__.__name__ in ("Gemma3Processor", "Gemma4Processor"): + try: + from transformers.image_utils import load_image + image = load_image(messages[0]["content"][0]["image"]) + except Exception: + from PIL import Image + image = Image.new("RGB", (224, 224), color = (128, 128, 128)) + inputs = processor( + text = [text], + images = [image], + return_tensors = "pt", + ) + else: + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, + return_dict=True, return_tensors="pt", + ) + inputs = inputs.to(model.device) + for _k, _v in list(inputs.items()): + if torch.is_tensor(_v) and torch.is_floating_point(_v): + inputs[_k] = _v.to(dtype = model.dtype) with torch.no_grad(): original_outputs = model(**inputs) @@ -2989,6 +3072,7 @@ def _test_get_vllm_state_dict( load_in_4bit = False, skip_generation = False, is_vision_model = False, + compilation_config = None, ): # All Unsloth Zoo code licensed under LGPLv3 # Check if model is allowed to be used in vLLM @@ -3028,6 +3112,8 @@ def _test_get_vllm_state_dict( model_type = getattr(config, "model_type", "causal_lm") enable_lora = model_type != "mllama" + if compilation_config is None and model_type == "gemma4": + compilation_config = 0 if not is_vision_model: model_class = AutoModelForCausalLM @@ -3069,6 +3155,7 @@ def _test_get_vllm_state_dict( use_bitsandbytes = load_in_4bit, is_vision_model = is_vision_model, enable_lora = enable_lora, + compilation_config = compilation_config, ) state_dict, quant_state_dict = get_vllm_state_dict( @@ -3082,6 +3169,8 @@ def 
_test_get_vllm_state_dict( new_model = convert_vllm_to_huggingface(quant_state_dict, config, dtype, is_vision_model = is_vision_model) test_model_conversion(model, new_model) + new_model, _ = patch_model_and_tokenizer(new_model, None) + new_model.eval() # Run the model as well if not is_vision_model: