diff --git a/tests/test_vllm_to_hf_conversion.py b/tests/test_vllm_to_hf_conversion.py new file mode 100644 index 000000000..ae6906301 --- /dev/null +++ b/tests/test_vllm_to_hf_conversion.py @@ -0,0 +1,259 @@ +import sys, os, warnings, inspect +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import types +import pytest +import torch + + +class _FakePlainProj(torch.nn.Module): + def __init__(self, out_features, in_features, dtype=torch.float32): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype), requires_grad=False) + + +class _FakeGDN(torch.nn.Module): + def __init__(self, hidden_size=8, num_k_heads=2, num_v_heads=2, head_k_dim=2, head_v_dim=4): + super().__init__() + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = num_k_heads * head_k_dim + self.value_dim = num_v_heads * head_v_dim + qkvz_dim = self.key_dim * 2 + self.value_dim * 2 + self.in_proj_qkvz = _FakePlainProj(qkvz_dim, hidden_size) + self.in_proj_ba = _FakePlainProj(num_v_heads * 2, hidden_size) + self.conv1d = _FakePlainProj(self.key_dim * 2 + self.value_dim, 4) + self.dt_bias = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) + self.A_log = torch.nn.Parameter(torch.randn(num_v_heads), requires_grad=False) + self.norm = torch.nn.Module() + self.norm.weight = torch.nn.Parameter(torch.randn(head_v_dim), requires_grad=False) + self.out_proj = _FakePlainProj(hidden_size, self.value_dim) + + +def _fake_get_state_dict(prefix, kk, state_dict, module, slice_weights=True): + state_dict[f"{prefix}.weight"] = module.weight.data + + +def test_extract_gdn_layers_handles_plain_column_parallel_linear(): + # Pre-fix: vllm ColumnParallelLinear has no `output_sizes` -> AttributeError. + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN() + state_dict, quant_state_dict = {}, {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + expected = { + "prefix.in_proj_qkv.weight", + "prefix.in_proj_z.weight", + "prefix.in_proj_b.weight", + "prefix.in_proj_a.weight", + "prefix.conv1d.weight", + "prefix.dt_bias", + "prefix.A_log", + "prefix.norm.weight", + "prefix.out_proj.weight", + } + assert expected <= set(state_dict.keys()) + + +def test_extract_gdn_layers_splits_in_proj_ba_without_indexerror(): + # Pre-fix: get_state_dict(kk=1, in_proj_ba) -> IndexError (no output_sizes). + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN() + state_dict, quant_state_dict = {}, {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + ba_weight = gdn.in_proj_ba.weight.data + mid = ba_weight.shape[0] // 2 + torch.testing.assert_close(state_dict["prefix.in_proj_b.weight"], ba_weight[:mid]) + torch.testing.assert_close(state_dict["prefix.in_proj_a.weight"], ba_weight[mid:]) + + +def test_extract_gdn_layers_qkvz_offsets_match_gdn_dims(): + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN(num_k_heads=3, num_v_heads=2, head_k_dim=4, head_v_dim=5) + state_dict, quant_state_dict = {}, {} + extract_gdn_layers(gdn, "prefix", state_dict, quant_state_dict, _fake_get_state_dict) + assert state_dict["prefix.in_proj_qkv.weight"].shape[0] == 2 * gdn.key_dim + gdn.value_dim + assert state_dict["prefix.in_proj_z.weight"].shape[0] == gdn.value_dim + + +def test_extract_gdn_layers_raises_when_offsets_underivable(): + from unsloth_zoo.empty_model import extract_gdn_layers + gdn = _FakeGDN() + del gdn.key_dim + del gdn.value_dim + with pytest.raises(RuntimeError, match="in_proj_qkvz"): + extract_gdn_layers(gdn, "prefix", {}, {}, _fake_get_state_dict) + + +def test_extract_gdn_layers_has_bnb_quant_state_preservation(): + # Pre-fix: merged in_proj_qkvz path only stored raw weight slices; BnB prequantized + # checkpoints lost quant_state metadata and were rebuilt as plain nn.Linear. + # Behavioral test requires real BnB; source-level check confirms the branch exists. + from unsloth_zoo import empty_model + src = inspect.getsource(empty_model.extract_gdn_layers) + assert "bnb_quant_state" in src + assert "in_proj_qkv.weight.quant_state" in src + assert "in_proj_z.weight.quant_state" in src + + +class _LinearAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_idx = -1 + + +class _StandardLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_idx = -1 + self.linear_attn = _LinearAttn() + + +class _StandardLM(torch.nn.Module): + def __init__(self, n_layers=3): + super().__init__() + + class _Inner(torch.nn.Module): + def __init__(self, n): + super().__init__() + self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(n)]) + + self.model = _Inner(n_layers) + + +def _config(model_type="qwen3_5", has_vision=False): + cfg = types.SimpleNamespace() + cfg.model_type = model_type + cfg.text_config = cfg + if has_vision: + vc = types.SimpleNamespace() + vc.hidden_size = 1 + vc.num_heads = 1 + cfg.vision_config = vc + return cfg + + +def test_finalize_fixes_layer_idx_on_standard_causal_lm(): + # Pre-fix: only new_model.model.language_model.layers was traversed, so + # standard-LM paths kept layer_idx at the empty-model stub value. + from unsloth_zoo.empty_model import finalize_huggingface_model + model = _StandardLM(n_layers=4) + finalize_huggingface_model( + model, None, _config("qwen3_5"), torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + for i, layer in enumerate(model.model.layers): + assert layer.layer_idx == i + assert layer.linear_attn.layer_idx == i + + +def test_finalize_fixes_layer_idx_on_vlm_language_model_path(): + from unsloth_zoo.empty_model import finalize_huggingface_model + + class _VLM(torch.nn.Module): + def __init__(self): + super().__init__() + + class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + + class _LM(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([_StandardLayer() for _ in range(3)]) + + self.language_model = _LM() + + self.model = _Inner() + + model = _VLM() + finalize_huggingface_model( + model, None, _config(), torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + for i, layer in enumerate(model.model.language_model.layers): + assert layer.layer_idx == i + assert layer.linear_attn.layer_idx == i + + +def test_finalize_does_not_assert_on_text_only_with_rotary_pos_emb(): + # Pre-fix: hard `assert vision_config is not None` crashed text-only models. + from unsloth_zoo.empty_model import finalize_huggingface_model + + class _Rotary(torch.nn.Module): + pass + + class _Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.rotary_pos_emb = _Rotary() + + class _Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Module() + self.model.layers = torch.nn.ModuleList([_Layer()]) + + finalize_huggingface_model( + _Model(), None, _config(has_vision=False), torch.float16, + quantization_config={"x": 1}, bnb_config=None, + ) + + +def test_set_dtype_in_config_no_torch_dtype_deprecation(): + # Pre-fix: wrote both dtype and torch_dtype -> transformers deprecation warning. + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + cfg = PretrainedConfig() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + set_dtype_in_config(cfg, torch.bfloat16) + dep = [w for w in caught if "torch_dtype" in str(w.message) and "deprecated" in str(w.message).lower()] + assert not dep, f"unexpected deprecation warning: {[str(w.message) for w in dep]}" + + +def test_set_dtype_in_config_writes_torch_dtype_value(): + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + cfg = PretrainedConfig() + set_dtype_in_config(cfg, torch.float16) + got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) + assert got == torch.float16 + + +def test_set_dtype_in_config_accepts_string_input(): + from transformers import PretrainedConfig + from unsloth_zoo.hf_utils import set_dtype_in_config + cfg = PretrainedConfig() + set_dtype_in_config(cfg, "bfloat16") + got = getattr(cfg, "dtype", None) or getattr(cfg, "torch_dtype", None) + assert got == torch.bfloat16 + + +def test_normalize_state_dict_tensor_guards_non_tensor(): + # Pre-fix: value.is_sparse was called unconditionally on any state-dict value. + from unsloth_zoo import vllm_utils + src = inspect.getsource(vllm_utils.assert_same_state_dict) + assert "isinstance(value, torch.Tensor)" in src + assert src.index("isinstance(value, torch.Tensor)") < src.index("value.is_sparse") + + +def test_gemma4_lora_patch_preserves_signature_for_inspect(): + # Pre-fix: patched_create_lora_manager(model, *args, **kwargs) hid vllm_config, + # breaking _call_create_lora_manager's signature-based forwarding. + from unsloth_zoo import empty_model + src = inspect.getsource(empty_model.patch_gemma4_vllm_lora_support) + assert "@wraps(original_create_lora_manager)" in src + assert "lora_manager_cls(model, *args, **kwargs)" in src + + +def test_gemma4_k_eq_v_patch_handles_split_kv_layout(): + # Pre-fix: only packed self_attn.qkv_proj.weight was searched, so current upstream + # Gemma4 split q_proj/k_proj/v_proj layout never got synthetic V quant-state. + from unsloth_zoo import empty_model + src = inspect.getsource(empty_model.patch_gemma4_vllm_k_eq_v_support) + assert "k_proj.weight" in src and "v_proj.weight" in src + assert '"split"' in src or "'split'" in src diff --git a/unsloth_zoo/empty_model.py b/unsloth_zoo/empty_model.py index f9ff7cba0..229c09c67 100644 --- a/unsloth_zoo/empty_model.py +++ b/unsloth_zoo/empty_model.py @@ -17,6 +17,10 @@ __all__ = [ "create_empty_model", "set_additional_modules", + "finalize_huggingface_model", + "patch_gemma4_vllm_lora_support", + "patch_gemma4_vllm_k_eq_v_support", + "extract_gdn_layers", "extract_vision_layers", "get_model_layer_config", "compare_attributes", @@ -29,7 +33,7 @@ from copy import deepcopy from .utils import get_quant_type from .log import logger -from .hf_utils import HAS_TORCH_DTYPE, dtype_from_config +from .hf_utils import HAS_TORCH_DTYPE, dtype_from_config, set_dtype_in_config def is_comparable(val): # Don't treat tensors as comparable, only basic types @@ -280,6 +284,14 @@ def create_empty_causal_lm(config, dtype = torch.float16): new_config.vocab_size = 1 new_config.pad_token_id = 0 + _set_config_attrs(new_config, { + "linear_num_key_heads": 1, + "linear_num_value_heads": 1, + "linear_key_head_dim": 1, + "linear_value_head_dim": 1, + "linear_conv_kernel_dim": 1, + }) + # Set attention module head_dim head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) new_config.update({"head_dim" : head_dim}) @@ -298,6 +310,135 @@ def _set_config_attrs(config_obj, attrs_to_set): setattr(config_obj, attr, value) pass +def _get_model_device(model): + for tensor in model.parameters(): + return tensor.device + for tensor in model.buffers(): + return tensor.device + return torch.device("cpu") +pass + +def patch_gemma4_vllm_lora_support(): + from functools import wraps + from vllm.model_executor.models.gemma4_mm import Gemma4ForConditionalGeneration + from vllm.model_executor.models import interfaces as vllm_model_interfaces + from vllm.lora import model_manager as vllm_lora_model_manager + from vllm.v1.worker import lora_model_runner_mixin + from unsloth_zoo import vllm_lora_worker_manager + + Gemma4ForConditionalGeneration.supports_lora = True + Gemma4ForConditionalGeneration.embedding_modules = {} + + if not hasattr(lora_model_runner_mixin.supports_lora, "_unsloth_gemma4_patch"): + original_supports_lora = lora_model_runner_mixin.supports_lora + + def patched_supports_lora(model): + if model.__class__.__name__ == "Gemma4ForConditionalGeneration": + return True + return original_supports_lora(model) + + patched_supports_lora._unsloth_gemma4_patch = True + lora_model_runner_mixin.supports_lora = patched_supports_lora + vllm_model_interfaces.supports_lora = patched_supports_lora + + if not hasattr(vllm_lora_model_manager.create_lora_manager, "_unsloth_gemma4_patch"): + original_create_lora_manager = vllm_lora_model_manager.create_lora_manager + + @wraps(original_create_lora_manager) + def patched_create_lora_manager(model, *args, **kwargs): + if model.__class__.__name__ == "Gemma4ForConditionalGeneration": + lora_manager_cls = kwargs.pop("lora_manager_cls", vllm_lora_model_manager.LoRAModelManager) + return lora_manager_cls(model, *args, **kwargs) + return original_create_lora_manager(model, *args, **kwargs) + + patched_create_lora_manager._unsloth_gemma4_patch = True + vllm_lora_model_manager.create_lora_manager = patched_create_lora_manager + vllm_lora_worker_manager.create_lora_manager = patched_create_lora_manager +pass + +# vLLM's Gemma4 k_eq_v path now expects qkv_proj to always expose q+k+v. +# For prequantized bitsandbytes checkpoints, the synthetic v shard is still +# missing from the quant-state dict on full-attention k_eq_v layers, so we +# materialize it during loader-side quant-state stacking instead of patching +# the runtime attention forward. +def patch_gemma4_vllm_k_eq_v_support(): + from vllm.model_executor.model_loader.bitsandbytes_loader import ( + BitsAndBytesModelLoader, + ) + + if hasattr( + BitsAndBytesModelLoader._stack_quantization_states, + "_unsloth_gemma4_k_eq_v_patch", + ): + return + + original_stack_quantization_states = ( + BitsAndBytesModelLoader._stack_quantization_states + ) + + def _get_gemma4_text_config(model): + config = getattr(model, "config", None) + if config is None: + return None + + text_config = getattr(config, "text_config", config) + model_type = getattr(config, "model_type", None) + text_model_type = getattr(text_config, "model_type", None) + if model_type != "gemma4" and text_model_type not in ("gemma4", "gemma4_text"): + return None + return text_config + + def _get_gemma4_k_eq_v_pairs(model): + text_config = _get_gemma4_text_config(model) + if text_config is None or not getattr(text_config, "attention_k_eq_v", False): + return () + + param_names = set(name for name, _ in model.named_parameters()) + pairs = [] + for layer_idx, layer_type in enumerate(getattr(text_config, "layer_types", ())): + if layer_type != "full_attention": + continue + + for prefix in ("language_model.model", "model"): + k_name = f"{prefix}.layers.{layer_idx}.self_attn.k_proj.weight" + v_name = f"{prefix}.layers.{layer_idx}.self_attn.v_proj.weight" + qkv_name = f"{prefix}.layers.{layer_idx}.self_attn.qkv_proj.weight" + if k_name in param_names: + pairs.append(("split", k_name, v_name)) + break + if qkv_name in param_names: + pairs.append(("packed", qkv_name, None)) + break + return tuple(pairs) + + def patched_stack_quantization_states(self, model, quant_state_dict): + stacked_quant_state_dict = original_stack_quantization_states( + self, model, quant_state_dict + ) + + for kind, source, target in _get_gemma4_k_eq_v_pairs(model): + quant_states = stacked_quant_state_dict.get(source) + if quant_states is None: + continue + + # Gemma4 full-attention k_eq_v layers reuse K as V. The raw weight + # loader already duplicates k_proj -> v_proj; prequant BnB needs the + # same duplication for shard-local QuantState metadata. + if kind == "packed": + if isinstance(quant_states, dict) and 2 not in quant_states and 1 in quant_states: + quant_states[2] = deepcopy(quant_states[1]) + elif kind == "split": + if target not in stacked_quant_state_dict: + stacked_quant_state_dict[target] = deepcopy(quant_states) + + return stacked_quant_state_dict + + patched_stack_quantization_states._unsloth_gemma4_k_eq_v_patch = True + BitsAndBytesModelLoader._stack_quantization_states = ( + patched_stack_quantization_states + ) +pass + @torch.inference_mode def create_empty_vision_model(config, dtype = torch.float16): @@ -352,6 +493,14 @@ def _init_weights(self, module): "head_dim": 1, "pad_token_id": 1, }) + # Qwen 3.5 or GDN related attrs + _set_config_attrs(new_config.text_config, { + "linear_num_key_heads": 1, + "linear_num_value_heads": 1, + "linear_key_head_dim": 1, + "linear_value_head_dim": 1, + "linear_conv_kernel_dim": 1, + }) # Common vision attributes _set_config_attrs(new_config.vision_config, { @@ -369,13 +518,9 @@ def _init_weights(self, module): text_layers = config.text_config.num_hidden_layers vision_layers = getattr(config.vision_config, "num_hidden_layers", None) or getattr(config.vision_config, "depth", 0) - # Set minimal sizes for different model types - if model_type == "qwen2_5_vl": - new_config.vision_config.out_hidden_size = 1 - elif model_type == "qwen3_vl": + if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5"): new_config.vision_config.out_hidden_size = 1 - num_layers = max(text_layers, vision_layers) new_model = model_cls(new_config) @@ -400,9 +545,15 @@ def create_empty_model(config, dtype = torch.float16, is_vision_model = False): @torch.inference_mode def set_additional_modules(new_model, quant_state_dict, config): + def _unwrap_tensor(val): + return getattr(val, "data", val) + if hasattr(new_model, "language_model"): language_model = new_model.language_model language_model_prefix = "model.language_model" + elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"): + language_model = new_model.model.language_model + language_model_prefix = "model.language_model" else: language_model_prefix = "model" language_model = new_model.model @@ -425,7 +576,7 @@ def set_additional_modules(new_model, quant_state_dict, config): # we cannot use the normal embedding init because gemma3 uses Gemma3TextScaledWordEmbedding which wraps around nn.Embedding and has a scaling factor. This new init ensures that we respect the forward from original model. def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): num_embeddings, embedding_dim = quant_state_dict[embed_tokens_key].shape - embeddings = quant_state_dict[embed_tokens_key] + embeddings = _unwrap_tensor(quant_state_dict[embed_tokens_key]) if isinstance(embeddings, torch.Tensor): # in the newer vLLM versions, this seems to return a tensor which can't be assigned to embedding weight # we need to convert that to nn.Paramter and then pass it on @@ -444,6 +595,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): # Norm norm_key = f"{language_model_prefix}.norm.weight" norm = quant_state_dict[norm_key] + norm = _unwrap_tensor(norm) norm = torch.nn.Parameter(norm, requires_grad = False) language_model.norm.weight = norm @@ -458,7 +610,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): # Check if lm_head exists in the state dict if lmhead_key in quant_state_dict: - weight = quant_state_dict[lmhead_key] + weight = _unwrap_tensor(quant_state_dict[lmhead_key]) from torch.nn import Linear # Create Linear layer with zero dimensions to avoid any weight allocation @@ -500,6 +652,7 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): for prefix in ['new_', 'new_model.']: try: val = quant_state_dict[key] + val = _unwrap_tensor(val) if isinstance(val, torch.Tensor): val = torch.nn.Parameter(val,requires_grad=False) exec(f"{prefix}{key} = val") @@ -510,6 +663,110 @@ def set_embedding(module, embed_tokens_key, pad_token_id, requires_grad=False): pass pass +@torch.inference_mode +def finalize_huggingface_model( + new_model, + original_meta_model, + config, + dtype, + quantization_config = None, + bnb_config = None, +): + if original_meta_model is not None: + copy_attributes(original_meta_model, new_model) + + if hasattr(new_model, "language_model"): + lm_root = new_model.language_model + elif hasattr(new_model, "model") and hasattr(new_model.model, "language_model"): + lm_root = new_model.model.language_model + else: + lm_root = getattr(new_model, "model", None) + + if lm_root is not None and hasattr(lm_root, "layers"): + for layer_idx, layer in enumerate(lm_root.layers): + if hasattr(layer, "layer_idx"): + layer.layer_idx = layer_idx + for attr_name in ("self_attn", "cross_attn", "mlp", "linear_attn"): + submodule = getattr(layer, attr_name, None) + if submodule is not None and hasattr(submodule, "layer_idx"): + submodule.layer_idx = layer_idx + + for module in new_model.modules(): + module_config = getattr(module, "config", None) + if module_config is not None: + set_dtype_in_config(module_config, dtype) + + target_device = _get_model_device(new_model) + text_config = getattr(config, "text_config", config) + vision_config = getattr(config, "vision_config", None) + is_gemma4 = getattr(config, "model_type", None) == "gemma4" + + for module in new_model.modules(): + if hasattr(module, "rotary_emb"): + rotary_config = text_config + current_rotary_config = getattr(module.rotary_emb, "config", None) + is_vision_rotary = ( + vision_config is not None and + current_rotary_config is not None and + current_rotary_config is not text_config and + current_rotary_config.__class__ == vision_config.__class__ + ) + if is_vision_rotary: + rotary_config = vision_config + module.rotary_emb = module.rotary_emb.__class__( + config = rotary_config, + device = target_device, + ) + buffer_dtype = torch.float32 if (is_gemma4 and is_vision_rotary) else dtype + for buffer_name in ("inv_freq", "original_inv_freq"): + buffer = getattr(module.rotary_emb, buffer_name, None) + if torch.is_tensor(buffer) and buffer.is_floating_point(): + module.rotary_emb._buffers[buffer_name] = buffer.to( + device = target_device, + dtype = buffer_dtype, + ) + if hasattr(module, "rotary_pos_emb") and vision_config is not None: + head_dim = vision_config.hidden_size // vision_config.num_heads + module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(target_device) + if hasattr(module, "rotary_emb_local"): + local_rope_config = deepcopy(text_config) + local_rope_config.rope_theta = text_config.rope_local_base_freq + local_rope_config.rope_scaling = {"rope_type": "default"} + module.rotary_emb_local = module.rotary_emb_local.__class__( + config = local_rope_config, + device = target_device, + ) + del local_rope_config + + if (quantization_config or {}) == {} and bnb_config is None: + new_model = new_model.to(device = target_device, dtype = dtype) + if is_gemma4: + for module in new_model.modules(): + rotary_emb = getattr(module, "rotary_emb", None) + if rotary_emb is None: + continue + rotary_cfg = getattr(rotary_emb, "config", None) + if rotary_cfg is None: + continue + fresh_rotary_emb = rotary_emb.__class__( + config = rotary_cfg, + device = target_device, + ) + for attr_name in ("max_seq_len_cached", "original_max_seq_len"): + if hasattr(fresh_rotary_emb, attr_name): + setattr(rotary_emb, attr_name, getattr(fresh_rotary_emb, attr_name)) + for attr_name, attr_value in fresh_rotary_emb.__dict__.items(): + if attr_name == "attention_scaling" or attr_name.endswith("_attention_scaling"): + setattr(rotary_emb, attr_name, attr_value) + for buffer_name, buffer in fresh_rotary_emb._buffers.items(): + if torch.is_tensor(buffer) and buffer.is_floating_point(): + rotary_emb._buffers[buffer_name] = buffer.to( + device = target_device, + dtype = torch.float32, + ) + return new_model +pass + def get_model_layer_config(return_non_layered=True): """ Returns a unified layer configuration containing the union of layer names @@ -520,6 +777,7 @@ def get_model_layer_config(return_non_layered=True): """ layer_templates = { 'standard_layers': { + "model.language_model.layers.{kk}.layer_scalar", "model.language_model.layers.{kk}.self_attn.q_proj", "model.language_model.layers.{kk}.self_attn.k_proj", "model.language_model.layers.{kk}.self_attn.v_proj", @@ -530,6 +788,7 @@ def get_model_layer_config(return_non_layered=True): "model.language_model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture) "model.language_model.layers.{kk}.mlp.down_proj", + "model.layers.{kk}.layer_scalar", "model.layers.{kk}.self_attn.q_proj", "model.layers.{kk}.self_attn.k_proj", "model.layers.{kk}.self_attn.v_proj", @@ -539,6 +798,29 @@ def get_model_layer_config(return_non_layered=True): "model.layers.{kk}.mlp.up_proj", "model.layers.{kk}.mlp.gate_up_proj", # for extracting from vLLM (phi3 architecture) "model.layers.{kk}.mlp.down_proj", + "model.language_model.layers.{kk}.linear_attn.in_proj_qkv", + "model.language_model.layers.{kk}.linear_attn.in_proj_z", + "model.language_model.layers.{kk}.linear_attn.in_proj_b", + "model.language_model.layers.{kk}.linear_attn.in_proj_a", + "model.language_model.layers.{kk}.linear_attn.conv1d", + "model.language_model.layers.{kk}.linear_attn.out_proj", + "model.language_model.layers.{kk}.linear_attn.dt_bias", + "model.language_model.layers.{kk}.linear_attn.A_log", + + "model.layers.{kk}.linear_attn.in_proj_qkv", + "model.layers.{kk}.linear_attn.in_proj_z", + "model.layers.{kk}.linear_attn.in_proj_b", + "model.layers.{kk}.linear_attn.in_proj_a", + "model.layers.{kk}.linear_attn.conv1d", + "model.layers.{kk}.linear_attn.out_proj", + "model.layers.{kk}.linear_attn.dt_bias", + "model.layers.{kk}.linear_attn.A_log", + + # Gemma4 per-layer input modules + "model.language_model.layers.{kk}.per_layer_input_gate", + "model.language_model.layers.{kk}.per_layer_projection", + "model.layers.{kk}.per_layer_input_gate", + "model.layers.{kk}.per_layer_projection", }, 'layernorms': { "model.language_model.layers.{kk}.input_layernorm", @@ -560,6 +842,12 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.vision_model.encoder.layers.{kk}.post_layernorm", "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm1", "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm2", + "model.vision_tower.encoder.layers.{kk}.input_layernorm", + "model.vision_tower.encoder.layers.{kk}.post_attention_layernorm", + "model.vision_tower.encoder.layers.{kk}.pre_feedforward_layernorm", + "model.vision_tower.encoder.layers.{kk}.post_feedforward_layernorm", + "model.vision_tower.encoder.layers.{kk}.self_attn.q_norm", + "model.vision_tower.encoder.layers.{kk}.self_attn.k_norm", # Mistral3 vision norms "model.vision_tower.transformer.layers.{kk}.attention_norm", @@ -567,6 +855,12 @@ def get_model_layer_config(return_non_layered=True): # qwen3 vl "model.visual.deepstack_merger_list.{kk}.norm", + "model.language_model.layers.{kk}.linear_attn.norm", + "model.layers.{kk}.linear_attn.norm", + + # Gemma4 per-layer input norm + "model.language_model.layers.{kk}.post_per_layer_input_norm", + "model.layers.{kk}.post_per_layer_input_norm", }, 'vision_layers': { @@ -610,6 +904,13 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.vision_model.encoder.layers.{kk}.mlp.fc1", "model.vision_tower.vision_model.encoder.layers.{kk}.mlp.fc2", + "model.vision_tower.encoder.layers.{kk}.self_attn.q_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.k_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.v_proj.linear", + "model.vision_tower.encoder.layers.{kk}.self_attn.o_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.gate_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.up_proj.linear", + "model.vision_tower.encoder.layers.{kk}.mlp.down_proj.linear", # qwen2.5_vl style "model.visual.blocks.{kk}.attn.qkv", @@ -654,7 +955,8 @@ def get_model_layer_config(return_non_layered=True): # qwen 3 vl "model.visual.deepstack_merger_list.{kk}.linear_fc1", "model.visual.deepstack_merger_list.{kk}.linear_fc2", - "model.visual.merger.linear_fc{kk}", + "model.visual.merger.linear_fc1", + "model.visual.merger.linear_fc2", }, "non_layered_components":{ @@ -685,6 +987,11 @@ def get_model_layer_config(return_non_layered=True): "model.vision_tower.patch_positional_embedding", "model.vision_tower.patch_conv", "model.vision_tower.ln_pre", + "model.vision_tower.std_bias", + "model.vision_tower.std_scale", + "model.vision_tower.patch_embedder.position_embedding_table", + "model.vision_tower.patch_embedder.input_proj", + "model.embed_vision.embedding_projection", # qwen 3 vl "model.visual.pos_embed", @@ -732,6 +1039,11 @@ def get_model_layer_counts(config): "vision_layers": getattr(config.vision_config, "depth", 27), "deepstack_layers": getattr(config.vision_config, "deepstack_depth", 3), } + elif model_type == "gemma4": + return { + "text_layers": getattr(config.text_config, "num_hidden_layers", 32), + "vision_layers": getattr(config.vision_config, "num_hidden_layers", 32), + } elif model_type == "gemma3": return { "text_layers": getattr(config.text_config, "num_hidden_layers", 32), @@ -759,6 +1071,102 @@ def _get_nested_attr(obj, attr_path: str): return None +def extract_gdn_layers(gdn_module, prefix, state_dict, quant_state_dict, get_state_dict): + gdn = gdn_module + + def _unwrap(v): + return getattr(v, "data", v) + + def store(name, value): + state_dict[name] = value + quant_state_dict[name] = value + + if hasattr(gdn, "in_proj_qkvz"): + proj = getattr(gdn.in_proj_qkvz, "base_layer", gdn.in_proj_qkvz) + weight = _unwrap(proj.weight) + + output_sizes = getattr(proj, "output_sizes", None) + if output_sizes is None: + key_dim = getattr(gdn, "key_dim", None) + value_dim = getattr(gdn, "value_dim", None) + if key_dim is None or value_dim is None: + raise RuntimeError( + "Unsloth: cannot infer GDN in_proj_qkvz shards without " + "proj.output_sizes or gdn.key_dim / gdn.value_dim" + ) + output_sizes = [key_dim, key_dim, value_dim, value_dim] + output_sizes = list(output_sizes) + offsets = [0] + for s in output_sizes: + offsets.append(offsets[-1] + s) + if len(offsets) < 5: + raise RuntimeError( + f"Unsloth: GDN in_proj_qkvz expected 4 shards (q,k,v,z); got sizes={output_sizes}" + ) + + qkv_weight = weight[offsets[0]:offsets[3]] + z_weight = weight[offsets[3]:offsets[4]] + store(f"{prefix}.in_proj_qkv.weight", qkv_weight) + store(f"{prefix}.in_proj_z.weight", z_weight) + + qs_attr = getattr(weight, "bnb_quant_state", None) + if isinstance(qs_attr, dict): + qkv_qs = qs_attr.get(0) + z_qs = qs_attr.get(3) + if qkv_qs is not None: + quant_state_dict[f"{prefix}.in_proj_qkv.weight.quant_state"] = qkv_qs + try: + for k, v in qkv_qs.as_dict(packed=True).items(): + state_dict[f"{prefix}.in_proj_qkv.weight.{k}"] = v + except Exception: + pass + if z_qs is not None: + quant_state_dict[f"{prefix}.in_proj_z.weight.quant_state"] = z_qs + try: + for k, v in z_qs.as_dict(packed=True).items(): + state_dict[f"{prefix}.in_proj_z.weight.{k}"] = v + except Exception: + pass + + if weight.dtype == torch.float8_e4m3fn: + scale_attr = None + if hasattr(proj, "weight_scale"): + scale_attr = "weight_scale" + elif hasattr(proj, "weight_scale_inv"): + scale_attr = "weight_scale_inv" + ws = _unwrap(getattr(proj, scale_attr)) if scale_attr is not None else None + if ws is not None: + if ws.ndim == 2 and ws.shape[1] > 1: + block_size = proj.weight_block_size[0] + scale_offsets = [x // block_size for x in offsets] + qkv_scale = ws[scale_offsets[0]:scale_offsets[3]] + z_scale = ws[scale_offsets[3]:scale_offsets[4]] + else: + qkv_scale = ws[offsets[0]:offsets[3]] + z_scale = ws[offsets[3]:offsets[4]] + store(f"{prefix}.in_proj_qkv.{scale_attr}", qkv_scale) + store(f"{prefix}.in_proj_z.{scale_attr}", z_scale) + else: + get_state_dict(f"{prefix}.in_proj_qkv", 0, state_dict, gdn.in_proj_qkv, slice_weights=False) + get_state_dict(f"{prefix}.in_proj_z", 0, state_dict, gdn.in_proj_z, slice_weights=False) + + ba_layer = getattr(gdn.in_proj_ba, "base_layer", gdn.in_proj_ba) + ba_weight = _unwrap(ba_layer.weight) + mid = ba_weight.shape[0] // 2 + store(f"{prefix}.in_proj_b.weight", ba_weight[:mid]) + store(f"{prefix}.in_proj_a.weight", ba_weight[mid:]) + + store(f"{prefix}.conv1d.weight", gdn.conv1d.weight.data) + store(f"{prefix}.dt_bias", gdn.dt_bias.data) + store(f"{prefix}.A_log", gdn.A_log.data) + + if hasattr(gdn, "norm") and hasattr(gdn.norm, "weight"): + store(f"{prefix}.norm.weight", gdn.norm.weight.data) + + get_state_dict(f"{prefix}.out_proj", 0, state_dict, gdn.out_proj) +pass + + def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_state_dict): """ Extracts vision layers for any supported vision model by dynamically using @@ -790,7 +1198,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if layer_module is not None: if "qkv" in layer_path: - if model_type in ("qwen2_5_vl", "qwen3_vl"): + if model_type in ("qwen2_5_vl", "qwen3_vl", "qwen3_5"): # If the HF model too prefers having merged qkv, we do this # This is evident in qwen-2.5-vl and qwen-3-vl so far. get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False) @@ -809,7 +1217,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if isinstance(layer_module, torch.nn.Module): if hasattr(layer_module, 'weight'): get_state_dict(layer_path, 0, state_dict, layer_module) - elif isinstance(layer_module, torch.nn.Parameter): + elif isinstance(layer_module, torch.Tensor): state_dict[f"{layer_path}"] = layer_module.data quant_state_dict[f"{layer_path}"] = state_dict[f"{layer_path}"] else: @@ -824,7 +1232,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat if hasattr(component, 'weight'): # Prefer using get_state_dict when possible get_state_dict(component_path, 0, state_dict, component) - elif isinstance(component, torch.nn.Parameter): + elif isinstance(component, torch.Tensor): state_dict[component_path] = component.data quant_state_dict[component_path] = component.data elif isinstance(component, torch.nn.Module): diff --git a/unsloth_zoo/hf_utils.py b/unsloth_zoo/hf_utils.py index cb96a9c51..8b99603e6 100644 --- a/unsloth_zoo/hf_utils.py +++ b/unsloth_zoo/hf_utils.py @@ -50,15 +50,31 @@ def dtype_from_config(config): return dtype def set_dtype_in_config(config, dtype): - try: - # if dtype is not a string, convert it to a string - string_dtype = str(dtype).split(".")[-1] if isinstance(dtype, torch.dtype) else dtype - if HAS_TORCH_DTYPE: - setattr(config, "torch_dtype", string_dtype) - else: - setattr(config, "dtype", string_dtype) - except: - set_dtype_in_config_fallback(config, string_dtype) + runtime_dtype = getattr(torch, dtype, dtype) if isinstance(dtype, str) else dtype + if hasattr(config, "dtype"): + target_fields = ["dtype"] + elif hasattr(config, "torch_dtype"): + target_fields = ["torch_dtype"] + else: + target_fields = ["dtype" if HAS_TORCH_DTYPE else "torch_dtype"] + + success = False + for field in target_fields: + try: + setattr(config, field, runtime_dtype) + success = True + continue + except Exception: + pass + + try: + config.__dict__[field] = runtime_dtype + success = True + except Exception: + pass + + if not success: + set_dtype_in_config_fallback(config, dtype) def set_dtype_in_config_fallback(config, dtype): try: diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 791d4d9c1..9c90b195a 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -761,6 +761,8 @@ def grpo_accumulated_loss( for module in unwrapped_model.modules(): if hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "io_same_decice"): module._hf_hook.io_same_decice = False + if hasattr(module, "rope_deltas"): + module.rope_deltas = None pass all_logprobs_list = [] diff --git a/unsloth_zoo/temporary_patches/gemma4.py b/unsloth_zoo/temporary_patches/gemma4.py index 3a91bba0c..f356e5455 100644 --- a/unsloth_zoo/temporary_patches/gemma4.py +++ b/unsloth_zoo/temporary_patches/gemma4.py @@ -118,6 +118,18 @@ def __getattr__(self, name): ) return getattr(object.__getattribute__(self, "_real"), name) + def __setattr__(self, name, value): + if name == "_real": + object.__setattr__(self, name, value) + return + setattr(object.__getattribute__(self, "_real"), name, value) + + def __delattr__(self, name): + if name == "_real": + object.__delattr__(self, name) + return + delattr(object.__getattribute__(self, "_real"), name) + def get_text_config(self, decoder=None, encoder=None): # If upstream recursively calls get_text_config on the proxy, return # self so the proxy is not unwrapped back into a raw config. diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 4d77c88a5..5d5999db1 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1063,6 +1063,12 @@ def get_state_dict(prefix, kk, state_dict, proj, slice_weights=True, slice_index quant_state_dict[prefix + ".bias"] = bias_tensor pass + gemma4_k_eq_v_layers = { + kk + for kk, layer_type in enumerate(getattr(text_config, "layer_types", ())) + if model_type == "gemma4" and getattr(text_config, "attention_k_eq_v", False) and layer_type == "full_attention" + } + # Embedding if hasattr(vllm_internals, "model"): # Standard Language models vllm_text_model = vllm_internals.model @@ -1107,7 +1113,9 @@ def _is_fused_module(name: str) -> bool: else: get_state_dict(f"{prefix}.q_proj", 0, state_dict, qkv_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, qkv_proj) - get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) + if kk not in gemma4_k_eq_v_layers: + get_state_dict(f"{prefix}.v_proj", 2, state_dict, qkv_proj) + get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) elif hasattr(layer, "cross_attn"): prefix = f"{vllm_text_model_prefix}.layers.{kk}.cross_attn" qkv_proj = layer.cross_attn.qkv_proj @@ -1119,8 +1127,21 @@ def _is_fused_module(name: str) -> bool: get_state_dict(f"{prefix}.q_proj", 0, state_dict, q_proj) get_state_dict(f"{prefix}.k_proj", 1, state_dict, kv_proj) get_state_dict(f"{prefix}.v_proj", 2, state_dict, kv_proj) + get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) + elif hasattr(layer, "linear_attn"): + # Qwen3.5 Gated Delta Net (GDN) linear attention layers + extract_gdn_layers( + layer.linear_attn, + f"{vllm_text_model_prefix}.layers.{kk}.linear_attn", + state_dict, quant_state_dict, get_state_dict, + ) + pass - get_state_dict(f"{prefix}.o_proj", 0, state_dict, o_proj) + if not hasattr(layer, "mlp"): + if hasattr(layer, "layer_scalar"): + state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + continue proj = layer.mlp.gate_up_proj use_fused_gate_up = _is_fused_module("gate_up_proj") @@ -1149,6 +1170,9 @@ def _is_fused_module(name: str) -> bool: except Exception as e: skipped_layernorms.append(layernorm_name.split(".")[-1]) pass + if hasattr(layer, "layer_scalar"): + state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data + quant_state_dict[f"{vllm_text_model_prefix}.layers.{kk}.layer_scalar"] = layer.layer_scalar.data pass if len(skipped_layernorms) != 0: @@ -1167,9 +1191,15 @@ def _is_fused_module(name: str) -> bool: # LM Head - Use get_state_dict for consistency if not getattr(text_config, "tie_word_embeddings", False): - lm_layer = [mod for name,mod in vllm_internals.named_modules() if "lm_head" in name] - # Use get_state_dict for consistent extraction and automatic truncation - get_state_dict("lm_head", 0, state_dict, lm_layer[0], slice_weights=False) + lm_layer = next((mod for name, mod in vllm_internals.named_modules() if "lm_head" in name), None) + if lm_layer is not None: + get_state_dict("lm_head", 0, state_dict, lm_layer, slice_weights=False) + elif hasattr(vllm_internals, "language_model") and hasattr(vllm_internals.language_model, "lm_head"): + get_state_dict("lm_head", 0, state_dict, vllm_internals.language_model.lm_head, slice_weights=False) + elif hasattr(vllm_internals, "lm_head"): + get_state_dict("lm_head", 0, state_dict, vllm_internals.lm_head, slice_weights=False) + else: + raise RuntimeError("Could not find lm_head in vLLM internals") else: # Fallback to embed_tokens for tied embeddings embed_key = f"{vllm_text_model_prefix}.embed_tokens.weight" @@ -1189,6 +1219,15 @@ def assert_same_state_dict(old_state_dict, new_state_dict): # Check if state_dict are equivalent # hf, vllm + def _normalize_state_dict_tensor(value): + if isinstance(value, torch.nn.Parameter): + value = value.detach() + if not isinstance(value, torch.Tensor): + return value + if value.is_sparse: + value = value.to_dense() + return value.contiguous() + difference = new_state_dict.keys() ^ old_state_dict.keys() difference -= set(("model.lm_head.weight","model.language_model.lm_head.weight", "lm_head.weight")) if len(difference) != 0: @@ -1202,8 +1241,8 @@ def assert_same_state_dict(old_state_dict, new_state_dict): for key in old_state_dict: try: - old_val = old_state_dict[key] - new_val = new_state_dict[key] + old_val = _normalize_state_dict_tensor(old_state_dict[key]) + new_val = _normalize_state_dict_tensor(new_state_dict[key]) if old_val.dtype != new_val.dtype or (new_val.element_size() < 2): # upcast both to float32 just for comparison. For FP8, vLLM stores weight scale in FP32 while HF preferes 16bit old_val = old_val.to(torch.float32) @@ -1217,7 +1256,11 @@ def assert_same_state_dict(old_state_dict, new_state_dict): if key1 is not None and key2 is not None: try: - torch.testing.assert_close(old_state_dict[key1].contiguous(), new_state_dict[key2].contiguous(), check_stride = True) + torch.testing.assert_close( + _normalize_state_dict_tensor(old_state_dict[key1]), + _normalize_state_dict_tensor(new_state_dict[key2]), + check_stride = True, + ) except Exception: failures[key] = error else: @@ -1235,7 +1278,14 @@ def assert_same_state_dict(old_state_dict, new_state_dict): def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, bnb_config = None, is_vision_model = False): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules to create HF compatible model + def _unwrap_tensor(value): + return getattr(value, "data", value) + set_dtype_in_config(config, dtype) + for subconfig_name in ("text_config", "vision_config", "audio_config"): + subconfig = getattr(config, subconfig_name, None) + if subconfig is not None: + set_dtype_in_config(subconfig, dtype) new_model, original_meta_model, layer_count, layer_names = create_empty_model(config, dtype, is_vision_model) new_model = new_model.to(device = get_target_device(), dtype = dtype) quantization_config = getattr(config, "quantization_config", {}) @@ -1300,6 +1350,7 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, "norm1", # Qwen2.5-VL vision encoder "norm2", # Qwen2.5-VL vision encoder "norm", + "conv1d", ] # Override .to("cuda") to disable it otherwise we'll get # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 @@ -1333,7 +1384,7 @@ def _override_to(self, *args, **kwargs): if f"{layer_name}.bias" in quant_state_dict: # Has bias! has_bias = True - bias = quant_state_dict[f"{layer_name}.bias"] + bias = _unwrap_tensor(quant_state_dict[f"{layer_name}.bias"]) bias = torch.nn.Parameter(bias, requires_grad = False) else: has_bias = False @@ -1352,8 +1403,8 @@ def _override_to(self, *args, **kwargs): if layer_name in quant_state_dict: # for attributes of type nn.Parameter, there's no .weight - layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name.replace('model.','',1)) - layer = torch.nn.Parameter(weight, requires_grad = False) + layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) + layer = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) exec(f"new_model.{layer_name_br} = layer") continue elif fp8_weight_scale is not None: @@ -1362,7 +1413,7 @@ def _override_to(self, *args, **kwargs): layer = FbgemmFp8Linear(in_features = 0, out_features = 0, bias = has_bias, weight_dtype = dtype).to(get_target_device()) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias layer.input_scale_ub = kwargs['input_scale_ub'] layer.weight_scale = torch.nn.Parameter(fp8_weight_scale, requires_grad = False) @@ -1378,7 +1429,7 @@ def _override_to(self, *args, **kwargs): layer = FP8Linear(**fp8_kwargs) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] - layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias layer.weight_scale_inv = torch.nn.Parameter(fp8_weight_scale, requires_grad = False) layer.quant_method = "fp8" @@ -1402,11 +1453,11 @@ def _override_to(self, *args, **kwargs): layer.out_features = weight.shape[0] # from vllm 0.11.1, the .weight is of dtype ModelWeightParameter, so try to extract the 'data' part # https://github.com/vllm-project/vllm/commit/de94289a98d7ec52a5ef02719e01a1db8b505170#diff-7d6145ac4ba084231a441c2056c7fca23c3bae33e6542f4f602a6c9d4d2da64dL199-R208 - layer.weight = torch.nn.Parameter(getattr(weight, 'data', weight), requires_grad = False) + layer.weight = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad = False) layer.bias = bias else: # LayerNorms (including vision norms) - weight_param = torch.nn.Parameter(weight, requires_grad=False) + weight_param = torch.nn.Parameter(_unwrap_tensor(weight), requires_grad=False) layer_name_br = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) # Set weight exec(f"new_model.{layer_name_br}.weight = None") @@ -1425,49 +1476,14 @@ def _override_to(self, *args, **kwargs): pass set_additional_modules(new_model, quant_state_dict, config) - - if original_meta_model is not None: - copy_attributes(original_meta_model, new_model) - - # # Set config on model and modules using clean approach - # new_model.config = config - # for module in new_model.modules(): - # if hasattr(module, "config"): - # module.config = config - # for param in new_model.parameters(): - # if hasattr(param, "config"): - # param.config = config - - text_config = getattr(config, "text_config", config) #try using text config for VLMs - vision_config = getattr(config, "vision_config", None) - # Fix up rotary_emb by re-initing them - for module in new_model.modules(): - if hasattr(module, "rotary_emb"): - module.rotary_emb = module.rotary_emb.__class__( - config = text_config, - device = get_target_device(), - ) - if hasattr(module, "rotary_pos_emb"): - # Qwen 2.5 VL has a rotary_pos_emb in vision submodel - # https://github.com/huggingface/transformers/blob/a871f6f58d49f3a05ae9dae519caa8aa9d919a07/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L337 - assert vision_config is not None, "Unsloth: vision_config is required for models with vision rotary_pos_emb" - head_dim = vision_config.hidden_size // vision_config.num_heads - module.rotary_pos_emb = module.rotary_pos_emb.__class__(head_dim//2).to(get_target_device()) - if hasattr(module, "rotary_emb_local"): - # gemma3 has a rotary_emb_local - # https://github.com/huggingface/transformers/blob/008c0ba8e2a1226a6ef5a61c4915a0a8a340c157/src/transformers/models/gemma3/modeling_gemma3.py#L469-L471 - # Gemma3 uses different defaults for local and global RoPE. Copy the config for modification. - local_rope_config = deepcopy(text_config) - local_rope_config.rope_theta = text_config.rope_local_base_freq - local_rope_config.rope_scaling = {"rope_type": "default"} - # gemma3 has a rotary_emb_local - module.rotary_emb_local = module.rotary_emb_local.__class__( - config = local_rope_config, - device = get_target_device(), - ) - del local_rope_config - pass - pass + new_model = finalize_huggingface_model( + new_model, + original_meta_model, + config, + dtype, + quantization_config = quantization_config, + bnb_config = bnb_config, + ) # Must override or else Bitsandbytes will error new_model.to = partial(_override_to, new_model) @@ -1771,6 +1787,12 @@ def load_vllm( assert(type(use_bitsandbytes) is bool) assert(conservativeness >= 0.0 and conservativeness <= 1.0) + if getattr(config, "model_type", None) == "gemma4": + if enable_lora: + patch_gemma4_vllm_lora_support() + if use_bitsandbytes: + patch_gemma4_vllm_k_eq_v_support() + unsloth_vllm_standby = unsloth_vllm_standby or (os.getenv("UNSLOTH_VLLM_STANDBY", "0") != "0") # This would give the flexibility to override the util we set for standby mode. In some extreme cases, this can be helpful. standby_util_override = os.getenv("UNSLOTH_VLLM_STANDBY_UTIL_OVERRIDE", "0") != "0" @@ -2866,10 +2888,19 @@ def _test_is_same_vlm(model, new_model, processor, test_backward=False): messages, tokenize=False, add_generation_prompt=True ) - inputs = processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, - return_dict=True, return_tensors="pt" - ).to(model.device, dtype=model.dtype) + if processor.__class__.__name__ in ("Gemma3Processor", "Gemma4Processor"): + from transformers.image_utils import load_image + image = load_image(messages[0]["content"][0]["image"]) + inputs = processor( + text = [text], + images = [image], + return_tensors = "pt", + ).to(model.device, dtype=model.dtype) + else: + inputs = processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, + return_dict=True, return_tensors="pt" + ).to(model.device, dtype=model.dtype) with torch.no_grad(): original_outputs = model(**inputs) @@ -2989,6 +3020,7 @@ def _test_get_vllm_state_dict( load_in_4bit = False, skip_generation = False, is_vision_model = False, + compilation_config = None, ): # All Unsloth Zoo code licensed under LGPLv3 # Check if model is allowed to be used in vLLM @@ -3028,6 +3060,8 @@ def _test_get_vllm_state_dict( model_type = getattr(config, "model_type", "causal_lm") enable_lora = model_type != "mllama" + if compilation_config is None and model_type == "gemma4": + compilation_config = 0 if not is_vision_model: model_class = AutoModelForCausalLM @@ -3069,6 +3103,7 @@ def _test_get_vllm_state_dict( use_bitsandbytes = load_in_4bit, is_vision_model = is_vision_model, enable_lora = enable_lora, + compilation_config = compilation_config, ) state_dict, quant_state_dict = get_vllm_state_dict( @@ -3082,6 +3117,8 @@ def _test_get_vllm_state_dict( new_model = convert_vllm_to_huggingface(quant_state_dict, config, dtype, is_vision_model = is_vision_model) test_model_conversion(model, new_model) + new_model, _ = patch_model_and_tokenizer(new_model, None) + new_model.eval() # Run the model as well if not is_vision_model: