diff --git a/tests/test_mlx_save_export_regressions.py b/tests/test_mlx_save_export_regressions.py new file mode 100644 index 000000000..3051962e7 --- /dev/null +++ b/tests/test_mlx_save_export_regressions.py @@ -0,0 +1,888 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Fast regressions for MLX save/export parity fixes. + +These tests cover the contracts behind the real save / GGUF export bugs without +downloading or converting large models. +""" + +from __future__ import annotations + +import dataclasses +import json +import os +import sys +import types +from pathlib import Path + +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def _install_mlx_torch_shim(): + from mlx_simulation import simulate_mlx_on_torch + + simulate_mlx_on_torch() + + +def test_vlm_config_save_uses_vlm_helper_and_preserves_quantization_config( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + fake_vlm_utils = types.ModuleType("mlx_vlm.utils") + + def fake_save_config(config, path): + calls["config"] = config + calls["path"] = Path(path) + Path(path).write_text(json.dumps(config), encoding="utf-8") + + fake_vlm_utils.save_config = fake_save_config + monkeypatch.setitem(sys.modules, "mlx_vlm.utils", fake_vlm_utils) + + config = { + "model_type": "gemma3", + "vision_config": {"hidden_size": 8}, + "quantization": {"group_size": 64, "bits": 4}, + } + mutils._save_mlx_config(config, tmp_path / "config.json", is_vlm=True) + + assert calls["path"] == tmp_path / "config.json" + assert calls["config"]["quantization"] == config["quantization"] + assert calls["config"]["quantization_config"] == config["quantization"] + assert "quantization_config" not in config + + +def test_merged_16bit_save_fully_dequantizes_model(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {"fuse": [], "dequantize": 0} + + class LoRALinear: + def fuse(self, dequantize=False): + calls["fuse"].append(dequantize) + return "fused-linear" + + class Model: + _config = { + "model_type": "llama", + "tie_word_embeddings": False, + "quantization": {"bits": 4}, + "nested": {"quantization_config": {"bits": 4}}, + } + + def eval(self): + calls["eval"] = True + + def named_modules(self): + return [("layers.0.self_attn.q_proj", LoRALinear())] + + def update_modules(self, modules): + calls["updated"] = modules + + class Tokenizer: + def save_pretrained(self, path): + Path(path).mkdir(parents=True, exist_ok=True) + calls["tokenizer_path"] = Path(path) + + fake_mlx_lm_utils = types.ModuleType("mlx_lm.utils") + + def fake_dequantize_model(model): + calls["dequantize"] += 1 + return model + + def fake_save_model(path, model, donate_model=False): + Path(path).mkdir(parents=True, exist_ok=True) + calls["donate_model"] = donate_model + + def fake_save_config(config, path): + calls["saved_config"] = config + Path(path).write_text(json.dumps(config), encoding="utf-8") + + fake_mlx_lm_utils.dequantize_model = fake_dequantize_model + fake_mlx_lm_utils.save_model = fake_save_model + fake_mlx_lm_utils.save_config = fake_save_config + fake_mlx_lm_utils.create_model_card = lambda path, hf_repo: None + monkeypatch.setitem(sys.modules, "mlx_lm.utils", fake_mlx_lm_utils) + + fake_mlx_utils = types.ModuleType("mlx.utils") + fake_mlx_utils.tree_unflatten = dict + monkeypatch.setitem(sys.modules, "mlx.utils", fake_mlx_utils) + + mutils.save_merged_model(Model(), Tokenizer(), tmp_path, dequantize=True) + + assert calls["eval"] is True + assert calls["fuse"] == [True] + assert calls["dequantize"] == 1 + assert calls["donate_model"] is False + assert "quantization" not in calls["saved_config"] + assert "quantization_config" not in calls["saved_config"]["nested"] + + +def test_bound_gguf_save_filters_cuda_only_kwargs(monkeypatch, tmp_path): + import unsloth_zoo.mlx.loader as loader + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + def fake_save_pretrained_gguf( + model, + tokenizer, + save_directory, + quantization_method="fast_quantized", + **kwargs, + ): + calls["tokenizer"] = tokenizer + calls["save_directory"] = Path(save_directory) + calls["quantization_method"] = quantization_method + calls["kwargs"] = kwargs + + monkeypatch.setattr(mutils, "save_pretrained_gguf", fake_save_pretrained_gguf) + tokenizer = object() + model = types.SimpleNamespace(_tokenizer=tokenizer) + + loader._mlx_save_pretrained_gguf( + model, + tmp_path, + quantization_method="not_quantized", + first_conversion="f16", + maximum_memory_usage=0.5, + temporary_location="/tmp/ignored", + ) + + assert calls == { + "tokenizer": tokenizer, + "save_directory": tmp_path, + "quantization_method": "not_quantized", + "kwargs": {"first_conversion": "f16"}, + } + + +def test_bound_gguf_push_filters_kwargs(monkeypatch): + import unsloth_zoo.mlx.loader as loader + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + def fake_push_to_hub_gguf( + model, + tokenizer, + save_directory, + repo_id, + quantization_method="fast_quantized", + **kwargs, + ): + calls["tokenizer"] = tokenizer + calls["save_directory"] = save_directory + calls["repo_id"] = repo_id + calls["quantization_method"] = quantization_method + calls["kwargs"] = kwargs + + monkeypatch.setattr(mutils, "push_to_hub_gguf", fake_push_to_hub_gguf) + tokenizer = object() + model = types.SimpleNamespace(_tokenizer=tokenizer) + + loader._mlx_push_to_hub_gguf( + model, + "org/model", + quantization_method="q8_0", + first_conversion="bf16", + token="hf_token", + private=True, + maximum_memory_usage=0.5, + temporary_location="/tmp/ignored", + ) + + assert calls == { + "tokenizer": tokenizer, + "save_directory": "org/model", + "repo_id": "org/model", + "quantization_method": "q8_0", + "kwargs": { + "first_conversion": "bf16", + "token": "hf_token", + "private": True, + }, + } + + +def test_lora_push_uses_lora_adapter_hub_path(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + class Model: + def named_modules(self): + return [("layers.0.q_proj", types.SimpleNamespace(fuse=lambda: None))] + + def trainable_parameters(self): + return {} + + class Tokenizer: + def save_pretrained(self, path): + calls["tokenizer_path"] = Path(path) + + def fake_save_lora_adapters(model, save_directory): + calls["adapter_dir"] = Path(save_directory) + + def fake_push_lora_adapters_to_hub( + save_directory, + **kwargs, + ): + calls["hub_dir"] = Path(save_directory) + calls["hub_kwargs"] = kwargs + + monkeypatch.setattr( + mutils, + "collect_mlx_lora_adapter_tensors", + lambda model: {"layers.0.q_proj.lora_a": object()}, + ) + monkeypatch.setattr(mutils, "iter_mlx_lora_modules", lambda model: []) + monkeypatch.setattr(mutils, "save_lora_adapters", fake_save_lora_adapters) + monkeypatch.setattr( + mutils, + "_push_lora_adapters_to_hub", + fake_push_lora_adapters_to_hub, + ) + monkeypatch.setattr( + mutils, + "push_to_hub_merged", + lambda *args, **kwargs: pytest.fail("push_to_hub_merged should not run"), + ) + + mutils.save_pretrained_merged( + Model(), + Tokenizer(), + tmp_path, + save_method="lora", + push_to_hub=True, + token="hf_token", + private=True, + ) + + assert calls["adapter_dir"] == tmp_path + assert calls["hub_dir"] == tmp_path + assert calls["hub_kwargs"]["repo_id"] is None + assert calls["hub_kwargs"]["token"] == "hf_token" + assert calls["hub_kwargs"]["private"] is True + + +def _patch_mlx_tensor_helpers_for_torch(monkeypatch, mutils): + import torch + + monkeypatch.setattr( + mutils.mx, + "transpose", + lambda tensor, axes=None, **kwargs: tensor.permute(*axes) + if axes is not None + else tensor.permute(*reversed(range(tensor.ndim))), + ) + monkeypatch.setattr(mutils.mx, "all", torch.all) + + +def test_vlm_rewrite_prefers_hf_alias_before_current_name(monkeypatch): + import torch + import unsloth_zoo.mlx.utils as mutils + + _patch_mlx_tensor_helpers_for_torch(monkeypatch, mutils) + + class QwenSanitizer: + @staticmethod + def sanitize(weights): + renamed = {} + for name, tensor in weights.items(): + if name.startswith("visual."): + name = f"vision_tower.{name[len('visual.'):]}" + renamed[name] = tensor + return renamed + + tensor = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) + new_name, new_tensor, changed = mutils._rewrite_mlx_vlm_tensor_for_gguf( + "vision_tower.patch_embed.proj.weight", + tensor, + [(QwenSanitizer, None)], + ) + + assert changed is True + assert new_name == "visual.patch_embed.proj.weight" + assert mutils._mlx_arrays_match(new_tensor, tensor) + + +def test_vlm_rewrite_handles_same_name_layout_transforms(monkeypatch): + import torch + import unsloth_zoo.mlx.utils as mutils + + _patch_mlx_tensor_helpers_for_torch(monkeypatch, mutils) + + class SameNameConvSanitizer: + @staticmethod + def sanitize(weights): + return { + name: mutils.mx.transpose(tensor, (0, 2, 3, 1)) + for name, tensor in weights.items() + } + + mlx_layout = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) + new_name, hf_layout, changed = mutils._rewrite_mlx_vlm_tensor_for_gguf( + "vision_tower.patch_embed.proj.weight", + mlx_layout, + [(SameNameConvSanitizer, None)], + ) + + assert changed is True + assert new_name == "vision_tower.patch_embed.proj.weight" + assert tuple(hf_layout.shape) == (2, 5, 3, 4) + assert mutils._mlx_arrays_match( + mutils.mx.transpose(hf_layout, (0, 2, 3, 1)), + mlx_layout, + ) + + +def test_vlm_rewrite_skips_untransformable_text_tensors(): + import torch + import unsloth_zoo.mlx.utils as mutils + + calls = 0 + + class CountingSanitizer: + @staticmethod + def sanitize(weights): + nonlocal calls + calls += 1 + return weights + + tensor = torch.zeros(2, 3) + new_name, new_tensor, changed = mutils._rewrite_mlx_vlm_tensor_for_gguf( + "language_model.model.layers.0.mlp.gate_proj.weight", + tensor, + [(CountingSanitizer, None)], + ) + + assert calls == 0 + assert changed is False + assert new_name == "language_model.model.layers.0.mlp.gate_proj.weight" + assert new_tensor is tensor + + +def test_mlx_arrays_match_checks_2d_tensor_values(monkeypatch): + import torch + import unsloth_zoo.mlx.utils as mutils + + monkeypatch.setattr(mutils.mx, "all", torch.all) + + assert mutils._mlx_arrays_match( + torch.zeros(2, 3), + torch.zeros(2, 3), + ) + assert not mutils._mlx_arrays_match( + torch.zeros(2, 3), + torch.ones(2, 3), + ) + + +def test_vlm_sanitizer_replay_uses_real_model_instances(): + import unsloth_zoo.mlx.utils as mutils + + class VisionTower: + def sanitize(self, weights): + assert "vision_tower.proj.weight" in weights + return {"visual.proj.weight": weights["vision_tower.proj.weight"]} + + class Model: + def __init__(self): + self.vision_tower = VisionTower() + + def sanitize(self, weights): + return self.vision_tower.sanitize(weights) + + model = Model() + pipelines = mutils._get_mlx_vlm_model_sanitize_pipelines(model) + + assert pipelines[0][0][0] is model + assert mutils._apply_mlx_vlm_sanitizers( + pipelines[0], + {"vision_tower.proj.weight": "tensor"}, + ) == {"visual.proj.weight": "tensor"} + + +def test_repair_degraded_vlm_processor_rebuilds_from_sidecar_configs( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.loader as loader + + class FakeProcessor: + def __init__(self, image_processor, tokenizer, chat_template=None): + self.image_processor = image_processor + self.tokenizer = tokenizer + self.chat_template = chat_template + + fake_processing = types.ModuleType("mlx_vlm.models.glm_ocr.processing") + fake_processing.FakeProcessor = FakeProcessor + monkeypatch.setitem( + sys.modules, + "mlx_vlm.models.glm_ocr.processing", + fake_processing, + ) + + image_processor = object() + monkeypatch.setattr( + loader, + "_build_vlm_image_processor_from_config", + lambda model_path, processor_config, preprocessor_config: image_processor, + ) + + (tmp_path / "processor_config.json").write_text( + json.dumps({"processor_class": "FakeProcessor"}), + encoding="utf-8", + ) + (tmp_path / "preprocessor_config.json").write_text( + json.dumps({"image_processor_type": "FakeImageProcessor"}), + encoding="utf-8", + ) + + tokenizer = types.SimpleNamespace( + chat_template=None, + save_pretrained=lambda path: None, + ) + degraded = types.SimpleNamespace( + tokenizer=tokenizer, + chat_template="{{ messages }}", + ) + + repaired = loader._repair_degraded_vlm_processor( + degraded, + tmp_path, + "glm_ocr", + ) + + assert isinstance(repaired, FakeProcessor) + assert repaired.image_processor is image_processor + assert repaired.tokenizer is tokenizer + assert repaired.chat_template == "{{ messages }}" + assert tokenizer.chat_template == "{{ messages }}" + + +def test_read_json_file_returns_empty_for_missing_or_malformed_files(tmp_path): + import unsloth_zoo.mlx.loader as loader + + assert loader._read_json_file(tmp_path / "missing.json") == {} + + malformed = tmp_path / "malformed.json" + malformed.write_text("{not-json", encoding="utf-8") + + assert loader._read_json_file(malformed) == {} + + +def test_read_json_file_does_not_swallow_unexpected_errors(monkeypatch, tmp_path): + import builtins + import unsloth_zoo.mlx.loader as loader + + def fail_open(*args, **kwargs): + raise RuntimeError("unexpected") + + monkeypatch.setattr(builtins, "open", fail_open) + + with pytest.raises(RuntimeError, match="unexpected"): + loader._read_json_file(tmp_path / "config.json") + + +def test_get_model_config_extracts_dataclass_configs(): + import unsloth_zoo.mlx.utils as mutils + + @dataclasses.dataclass + class VisionConfig: + hidden_size: int + + @dataclasses.dataclass + class ModelConfig: + model_type: str + vision_config: VisionConfig + scales: tuple[int, int] + + model = types.SimpleNamespace( + config=ModelConfig( + model_type="glm_ocr", + vision_config=VisionConfig(hidden_size=16), + scales=(1, 2), + ) + ) + + assert mutils._get_model_config(model) == { + "model_type": "glm_ocr", + "vision_config": {"hidden_size": 16}, + "scales": [1, 2], + } + + +def test_get_model_config_prefers_copied_raw_config(): + import unsloth_zoo.mlx.utils as mutils + + raw_config = {"model_type": "qwen3", "nested": {"values": [1]}} + model = types.SimpleNamespace( + _config=raw_config, + config=types.SimpleNamespace(to_dict=lambda: {"model_type": "wrong"}), + ) + + extracted = mutils._get_model_config(model) + extracted["nested"]["values"].append(2) + + assert extracted["model_type"] == "qwen3" + assert raw_config["nested"]["values"] == [1] + + +def test_has_vision_config_handles_nested_and_malformed_configs(): + import unsloth_zoo.mlx.utils as mutils + + assert not mutils._has_vision_config(None) + assert not mutils._has_vision_config({"thinker_config": "bad"}) + assert mutils._has_vision_config({"vision_config": {}}) + assert mutils._has_vision_config({"thinker_config": {"vision_config": {}}}) + + +def test_save_merged_model_detects_nested_vlm_config(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + class Model: + _config = { + "model_type": "glm_ocr", + "thinker_config": {"vision_config": {"hidden_size": 8}}, + } + + def eval(self): + pass + + def named_modules(self): + return [] + + class Tokenizer: + def save_pretrained(self, path): + Path(path).mkdir(parents=True, exist_ok=True) + + fake_mlx_lm_utils = types.ModuleType("mlx_lm.utils") + fake_mlx_lm_utils.dequantize_model = lambda model: model + fake_mlx_lm_utils.save_model = lambda path, model, donate_model=False: Path( + path + ).mkdir(parents=True, exist_ok=True) + fake_mlx_lm_utils.create_model_card = lambda path, hf_repo: None + fake_mlx_lm_utils.save_config = lambda config, path: pytest.fail( + "text save_config should not run" + ) + monkeypatch.setitem(sys.modules, "mlx_lm.utils", fake_mlx_lm_utils) + + fake_mlx_utils = types.ModuleType("mlx.utils") + fake_mlx_utils.tree_unflatten = dict + monkeypatch.setitem(sys.modules, "mlx.utils", fake_mlx_utils) + + monkeypatch.setattr(mutils, "_is_vlm_model", lambda model: False) + + def fake_save_mlx_config(config, config_path, *, is_vlm=False): + calls["is_vlm"] = is_vlm + calls["config"] = config + + monkeypatch.setattr(mutils, "_save_mlx_config", fake_save_mlx_config) + + mutils.save_merged_model(Model(), Tokenizer(), tmp_path) + + assert calls["is_vlm"] is True + assert calls["config"]["thinker_config"]["vision_config"]["hidden_size"] == 8 + + +def test_prepare_vlm_gguf_export_directory_writes_nextn_config_without_tensors( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.utils as mutils + + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "model_type": "glm_ocr", + "vision_config": {}, + "text_config": { + "num_hidden_layers": 16, + "num_nextn_predict_layers": 1, + "mtp_num_hidden_layers": 1, + "nextn_predict_layers": 1, + }, + } + ), + encoding="utf-8", + ) + monkeypatch.setattr(mutils, "_get_transformer_layers", lambda model: [object()] * 16) + monkeypatch.setattr( + mutils, + "_build_mlx_vlm_sanitize_pipelines", + lambda config, model=None: [], + ) + + rewritten = mutils._prepare_vlm_gguf_export_directory(tmp_path, model=object()) + + assert rewritten == 0 + updated = json.loads(config_path.read_text(encoding="utf-8")) + assert "num_nextn_predict_layers" not in updated["text_config"] + assert "mtp_num_hidden_layers" not in updated["text_config"] + assert "nextn_predict_layers" not in updated["text_config"] + + +def test_prepare_vlm_gguf_export_directory_ignores_malformed_thinker_config( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.utils as mutils + + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "model_type": "glm_ocr", + "thinker_config": "bad", + "text_config": {"num_hidden_layers": 16}, + } + ), + encoding="utf-8", + ) + monkeypatch.setattr( + mutils, + "_get_transformer_layers", + lambda model: [object()] * 16, + ) + monkeypatch.setattr( + mutils, + "_build_mlx_vlm_sanitize_pipelines", + lambda config, model=None: [], + ) + + assert mutils._prepare_vlm_gguf_export_directory(tmp_path, model=object()) == 0 + + +def test_copy_source_sidecars_preserves_image_processor_metadata(tmp_path): + import unsloth_zoo.mlx.utils as mutils + + src = tmp_path / "src" + dst = tmp_path / "dst" + src.mkdir() + dst.mkdir() + + for name in ( + "preprocessor_config.json", + "processor_config.json", + "video_preprocessor_config.json", + "chat_template.jinja", + "tokenizer.model", + "vocab.txt", + "custom_processing.py", + "config.json", + "README.md", + ".gitattributes", + "model.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + ): + (src / name).write_text(name, encoding="utf-8") + (dst / "preprocessor_config.json").write_text("existing", encoding="utf-8") + + copied = mutils._copy_source_sidecars(src, dst) + + assert copied == 6 + assert (dst / "preprocessor_config.json").read_text(encoding="utf-8") == "existing" + for name in ( + "processor_config.json", + "video_preprocessor_config.json", + "chat_template.jinja", + "tokenizer.model", + "vocab.txt", + "custom_processing.py", + ): + assert (dst / name).read_text(encoding="utf-8") == name + for skipped in ( + "config.json", + "README.md", + ".gitattributes", + "model.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + ): + assert not (dst / skipped).exists() + + +def test_copy_source_sidecars_ignores_non_directory_source(tmp_path): + import unsloth_zoo.mlx.utils as mutils + + src = tmp_path / "model.safetensors" + dst = tmp_path / "dst" + src.write_text("weights", encoding="utf-8") + dst.mkdir() + + assert mutils._copy_source_sidecars(src, dst) == 0 + assert list(dst.iterdir()) == [] + + +def test_save_pretrained_gguf_anchors_patcher_to_checked_llama_cpp_root( + monkeypatch, + tmp_path, +): + import unsloth_zoo.llama_cpp as llama_cpp + import unsloth_zoo.mlx.utils as mutils + + monkeypatch.setitem(sys.modules, "gguf", types.ModuleType("gguf")) + + llama_root = tmp_path / "llama.cpp" + llama_root.mkdir() + converter = llama_root / "convert_hf_to_gguf.py" + converter.write_text("# converter", encoding="utf-8") + quantizer = llama_root / "llama-quantize" + quantizer.write_text("# quantizer", encoding="utf-8") + + calls = {} + + def fake_save_merged_model(model, tokenizer, path, dequantize=False): + calls["dequantize"] = dequantize + Path(path).mkdir(parents=True, exist_ok=True) + + def fake_download_convert_hf_to_gguf(): + calls["scripts_dir"] = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") + patched = llama_root / "unsloth_convert_hf_to_gguf.py" + patched.write_text("# patched converter", encoding="utf-8") + return str(patched), {"Qwen3ForCausalLM"}, {"Gemma3ForConditionalGeneration"} + + def fake_convert_to_gguf(**kwargs): + calls["convert_kwargs"] = kwargs + output = Path( + f"{kwargs['model_name']}.{kwargs['quantization_type'].upper()}.gguf" + ) + output.write_bytes(b"GGUF") + + monkeypatch.setattr(mutils, "save_merged_model", fake_save_merged_model) + monkeypatch.setattr(mutils, "_is_vlm_model", lambda model: False) + monkeypatch.setattr(llama_cpp, "LLAMA_CPP_DEFAULT_DIR", str(tmp_path / "unused")) + monkeypatch.setattr( + llama_cpp, + "check_llama_cpp", + lambda llama_cpp_folder: (str(quantizer), str(converter)), + ) + monkeypatch.setattr( + llama_cpp, + "install_llama_cpp", + lambda llama_cpp_folder: pytest.fail("install_llama_cpp should not run"), + ) + monkeypatch.setattr( + llama_cpp, + "_download_convert_hf_to_gguf", + fake_download_convert_hf_to_gguf, + ) + monkeypatch.setattr(llama_cpp, "convert_to_gguf", fake_convert_to_gguf) + monkeypatch.setattr( + llama_cpp, + "quantize_gguf", + lambda **kwargs: pytest.fail("quantize_gguf should not run"), + ) + + old_scripts_dir = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") + model = types.SimpleNamespace(_hf_repo="org/TestModel") + out = tmp_path / "out" + mutils.save_pretrained_gguf( + model, + tokenizer=object(), + save_directory=out, + quantization_method="not_quantized", + first_conversion="f16", + ) + + assert calls["dequantize"] is True + assert calls["scripts_dir"] == str(llama_root) + assert calls["convert_kwargs"]["converter_location"] == str( + llama_root / "unsloth_convert_hf_to_gguf.py" + ) + assert calls["convert_kwargs"]["supported_text_archs"] == {"Qwen3ForCausalLM"} + assert (out / "TestModel.F16.gguf").read_bytes() == b"GGUF" + assert os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") == old_scripts_dir + + +def test_push_to_hub_gguf_forwards_first_conversion(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + class FakeHfApi: + def __init__(self, token=None): + calls["token"] = token + + def create_repo(self, repo_id, exist_ok=True, private=None): + calls["repo"] = { + "repo_id": repo_id, + "exist_ok": exist_ok, + "private": private, + } + + def update_repo_settings(self, **kwargs): + calls["update_repo_settings"] = kwargs + + def upload_file(self, path_or_fileobj, path_in_repo, repo_id): + calls["upload"] = { + "path_or_fileobj": Path(path_or_fileobj), + "path_in_repo": path_in_repo, + "repo_id": repo_id, + } + + fake_hub = types.ModuleType("huggingface_hub") + fake_hub.HfApi = FakeHfApi + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub) + + def fake_save_pretrained_gguf( + model, + tokenizer, + save_directory, + quantization_method="fast_quantized", + first_conversion=None, + ): + calls["save"] = { + "quantization_method": quantization_method, + "first_conversion": first_conversion, + } + Path(save_directory).mkdir(parents=True, exist_ok=True) + (Path(save_directory) / "model.F16.gguf").write_bytes(b"GGUF") + + monkeypatch.setattr(mutils, "save_pretrained_gguf", fake_save_pretrained_gguf) + + mutils.push_to_hub_gguf( + model=object(), + tokenizer=object(), + save_directory=tmp_path, + repo_id="org/model", + quantization_method="not_quantized", + first_conversion="f16", + token="hf_token", + private=True, + ) + + assert calls["save"] == { + "quantization_method": "not_quantized", + "first_conversion": "f16", + } + assert calls["token"] == "hf_token" + assert calls["repo"] == { + "repo_id": "org/model", + "exist_ok": True, + "private": True, + } + assert calls["upload"]["path_in_repo"] == "model.F16.gguf" diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 98c265d63..2a21e966f 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -1338,7 +1338,12 @@ def _download_convert_hf_to_gguf_cached(name, _local_script_info, _conversion_in # 4. Write Patched File - patched_filename = os.path.join(LLAMA_CPP_DEFAULT_DIR, f"{name}.py") + # Package-layout converters import sibling modules from conversion/. + # Keep the patched entrypoint beside that package so subprocess + # execution resolves `from conversion import ...`. + patched_dir = _llama_cpp_dir if _layout == "package" else LLAMA_CPP_DEFAULT_DIR + os.makedirs(patched_dir, exist_ok=True) + patched_filename = os.path.join(patched_dir, f"{name}.py") logger.info(f"Unsloth: Saving patched script to {patched_filename}") with open(patched_filename, "wb") as file: file.write(patched_content) diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index efd1403f5..823c0aa3a 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -360,6 +360,154 @@ def _load_weights_without_projection_quant_state(self, file_or_weights, strict=T nn.Module.load_weights = original_load_weights +def _read_json_file(path): + """Read a JSON object, returning an empty dict for missing/bad sidecars.""" + try: + with open(path, "r", encoding="utf-8") as file: + return json.load(file) + except (FileNotFoundError, json.JSONDecodeError, OSError, UnicodeDecodeError): + return {} + + +def _resolve_mlx_vlm_processor_class(model_type, processor_class_name): + """Resolve a custom mlx-vlm or Transformers processor class by name.""" + if not processor_class_name: + return None + + module_model_type = (model_type or "").replace("-", "_") + module_candidates = ( + f"mlx_vlm.models.{module_model_type}.processing", + f"mlx_vlm.models.{module_model_type}.processing_{module_model_type}", + ) + for module_name in module_candidates: + try: + module = importlib.import_module(module_name) + except Exception: + continue + processor_class = getattr(module, processor_class_name, None) + if processor_class is not None: + return processor_class + + try: + import transformers + return getattr(transformers, processor_class_name, None) + except Exception: + return None + + +def _build_vlm_image_processor_from_config(model_path, processor_config, preprocessor_config): + """Recreate the image processor from saved processor sidecar configs.""" + image_config = processor_config.get("image_processor") + if not isinstance(image_config, dict): + image_config = preprocessor_config + if not isinstance(image_config, dict): + image_config = {} + + image_processor_type = ( + image_config.get("image_processor_type") + or preprocessor_config.get("image_processor_type") + ) + image_kwargs = dict(image_config) + image_kwargs.pop("image_processor_type", None) + image_kwargs.pop("processor_class", None) + + if image_processor_type: + try: + import transformers + image_processor_class = getattr(transformers, image_processor_type, None) + if image_processor_class is not None: + return image_processor_class(**image_kwargs) + except Exception: + pass + + try: + from transformers import AutoImageProcessor + return AutoImageProcessor.from_pretrained(model_path) + except Exception: + return None + + +def _repair_degraded_vlm_processor( + processor, + model_path, + model_type, + *, + token=None, + trust_remote_code=False, +): + """Rebuild VLM processors when mlx-vlm falls back to tokenizer-only. + + mlx-vlm registers several custom processors through an AutoProcessor patch. + If the custom processor's image processor cannot be constructed through + AutoImageProcessor, the patch falls back to the prior tokenizer-only loader. + Rebuild from the source processor configs so downstream saves preserve real + multimodal processor metadata. + """ + if processor is None or getattr(processor, "image_processor", None) is not None: + return processor + + if not model_path or not os.path.isdir(str(model_path)): + return processor + + processor_config = _read_json_file( + os.path.join(str(model_path), "processor_config.json") + ) + preprocessor_config = _read_json_file( + os.path.join(str(model_path), "preprocessor_config.json") + ) + processor_class_name = ( + processor_config.get("processor_class") + or preprocessor_config.get("processor_class") + ) + processor_class = _resolve_mlx_vlm_processor_class( + model_type, processor_class_name, + ) + if processor_class is None: + return processor + + image_processor = _build_vlm_image_processor_from_config( + model_path, processor_config, preprocessor_config, + ) + if image_processor is None: + return processor + + tokenizer = getattr(processor, "tokenizer", None) or processor + if tokenizer is None or not hasattr(tokenizer, "save_pretrained"): + try: + from transformers import AutoTokenizer + tokenizer_kwargs = {"trust_remote_code": trust_remote_code} + if token: + tokenizer_kwargs["token"] = token + tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs) + except Exception: + return processor + + chat_template = getattr(processor, "chat_template", None) + if chat_template is not None and getattr(tokenizer, "chat_template", None) is None: + tokenizer.chat_template = chat_template + + try: + repaired = processor_class( + image_processor=image_processor, + tokenizer=tokenizer, + chat_template=chat_template, + ) + except TypeError: + try: + repaired = processor_class( + image_processor=image_processor, + tokenizer=tokenizer, + ) + except Exception: + return processor + except Exception: + return processor + + if chat_template is not None and getattr(repaired, "chat_template", None) is None: + repaired.chat_template = chat_template + return repaired + + def _build_vlm_model_types(): """Build the set of model_type strings that mlx_vlm supports. @@ -2413,12 +2561,18 @@ def _mlx_save_pretrained_merged(self, save_directory, tokenizer=None, **kwargs): save_pretrained_merged(self, tokenizer, save_directory, **kwargs) +def _mlx_supported_kwargs(kwargs, supported): + """Keep CUDA-compatible kwargs out of MLX-only save/export APIs.""" + return {key: kwargs[key] for key in supported if key in kwargs} + + def _mlx_save_pretrained_gguf(self, save_directory, tokenizer=None, quantization_method="fast_quantized", **kwargs): from .utils import save_pretrained_gguf tokenizer = tokenizer or self._tokenizer + kwargs = _mlx_supported_kwargs(kwargs, ("first_conversion",)) save_pretrained_gguf(self, tokenizer, save_directory, - quantization_method=quantization_method) + quantization_method=quantization_method, **kwargs) def _mlx_push_to_hub_merged(self, repo_id, tokenizer=None, save_directory=None, **kwargs): @@ -2435,6 +2589,7 @@ def _mlx_push_to_hub_gguf(self, repo_id, tokenizer=None, quantization_method="fast_quantized", **kwargs): from .utils import push_to_hub_gguf tokenizer = tokenizer or self._tokenizer + kwargs = _mlx_supported_kwargs(kwargs, ("first_conversion", "token", "private")) push_to_hub_gguf(self, tokenizer, repo_id, repo_id=repo_id, quantization_method=quantization_method, **kwargs) @@ -3297,6 +3452,14 @@ def from_pretrained( hf_token=token, ) + processor = _repair_degraded_vlm_processor( + processor, + local_path or model_name, + model_type, + token=token, + trust_remote_code=trust_remote_code, + ) + if target_dtype is not None: _convert_mlx_dtype(model, target_dtype, model_type=model_type) elif want_runtime_quant: diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 88e90e5c8..922e2b918 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -27,18 +27,23 @@ import mlx.utils import copy import inspect +import importlib import json import numpy as np import os import sys import shutil import tempfile +import threading from pathlib import Path from .cce import _get_runtime_cce +_LLAMA_CPP_PATCHER_ENV_LOCK = threading.Lock() + + def _safe_token_denominator(ntoks): return mx.maximum(ntoks.astype(mx.float32), mx.array(1.0, dtype=mx.float32)) @@ -3417,21 +3422,55 @@ def _enrich_mlx_adapter_config(model, adapter_config): return adapter_config +def _config_to_plain_python(value): + """Recursively convert config dataclasses and containers to plain Python.""" + import dataclasses + + if dataclasses.is_dataclass(value) and not isinstance(value, type): + value = dataclasses.asdict(value) + elif isinstance(value, dict): + value = copy.deepcopy(value) + elif isinstance(value, (list, tuple)): + return [_config_to_plain_python(item) for item in value] + else: + return value + + if isinstance(value, dict): + return { + key: _config_to_plain_python(item) + for key, item in value.items() + } + return value + + def _get_model_config(model): """Extract config dict from an MLX model. mlx-lm stores the raw config dict at model._config when loaded. + mlx-vlm exposes config dataclasses at model.config. Falls back to reconstructing from model.args dataclass. """ + import dataclasses + # Prefer the raw config dict stashed by our loader if hasattr(model, "_config") and isinstance(model._config, dict): - return dict(model._config) + return _config_to_plain_python(model._config) + + if hasattr(model, "config"): + config = model.config + if isinstance(config, dict) or ( + dataclasses.is_dataclass(config) and not isinstance(config, type) + ): + return _config_to_plain_python(config) + if hasattr(config, "to_dict"): + config = config.to_dict() + if isinstance(config, dict): + return _config_to_plain_python(config) # Reconstruct from the ModelArgs dataclass if hasattr(model, "args"): - import dataclasses - if dataclasses.is_dataclass(model.args): - return dataclasses.asdict(model.args) + if dataclasses.is_dataclass(model.args) and not isinstance(model.args, type): + return _config_to_plain_python(model.args) return {} @@ -3441,6 +3480,466 @@ def _get_src_path(model): return getattr(model, "_src_path", None) +def _save_mlx_config(config, config_path, *, is_vlm=False): + """Save MLX config using the backend-aware upstream helper.""" + config = copy.deepcopy(config) + if is_vlm: + if "quantization" in config: + config["quantization_config"] = config["quantization"] + from mlx_vlm.utils import save_config as save_vlm_config + save_vlm_config(config, config_path) + else: + from mlx_lm.utils import save_config as save_lm_config + save_lm_config(config, config_path) + + +def _has_vision_config(config): + """Return whether a raw or thinker-wrapped VLM config has vision settings.""" + if not isinstance(config, dict): + return False + thinker_config = config.get("thinker_config") + return ( + "vision_config" in config + or ( + isinstance(thinker_config, dict) + and "vision_config" in thinker_config + ) + ) + + +class _MlxVlmSanitizeProxy: + """Minimal instance shim for mlx-vlm class sanitize methods.""" + def __init__(self, config): + self.config = config + self.args = config + + +def _copy_mlx_vlm_sanitize_weights(weights): + """Copy MLX arrays before replaying sanitizer transforms.""" + return { + key: mx.array(value) if isinstance(value, mx.array) else value + for key, value in weights.items() + } + + +def _call_mlx_vlm_sanitize(cls, config, weights): + """Call an mlx-vlm sanitize method with its expected signature.""" + sanitize = getattr(cls, "sanitize", None) + if sanitize is None: + return weights + + weights = _copy_mlx_vlm_sanitize_weights(weights) + params = inspect.signature(sanitize).parameters + if len(params) == 1: + return sanitize(weights) + return sanitize(_MlxVlmSanitizeProxy(config), weights) + + +def _add_mlx_vlm_sanitize_step(steps, module): + """Append a real mlx-vlm module sanitizer once, preserving order.""" + if module is None or getattr(module, "sanitize", None) is None: + return + if all(existing is not module for existing, _ in steps): + steps.append((module, None)) + + +def _get_mlx_vlm_model_sanitize_pipelines(model): + """Build sanitizer pipelines from a loaded mlx-vlm model and submodules.""" + if model is None or getattr(model, "sanitize", None) is None: + return [] + + model_step = [(model, None)] + pipelines = [model_step] + + extra_steps = [] + for attr in ("thinker", "vision_tower", "vision_model", "vision_encoder", "visual"): + _add_mlx_vlm_sanitize_step(extra_steps, getattr(model, attr, None)) + + thinker = getattr(model, "thinker", None) + for attr in ("vision_tower", "vision_model", "vision_encoder", "visual"): + _add_mlx_vlm_sanitize_step(extra_steps, getattr(thinker, attr, None)) + + for idx in range(len(extra_steps)): + pipelines.append(model_step + extra_steps[: idx + 1]) + + return pipelines + + +def _get_nested_config(config, *names): + """Walk nested config attributes, returning None for missing segments.""" + cur = config + for name in names: + cur = getattr(cur, name, None) + if cur is None: + return None + return cur + + +def _build_mlx_vlm_sanitize_steps(config): + """Build class-based mlx-vlm sanitizer steps from a saved VLM config.""" + if not _has_vision_config(config): + return [] + + try: + from mlx_vlm.utils import get_model_and_args, update_module_configs + + config_copy = copy.deepcopy(config) + model_module, model_type = get_model_and_args(config_copy) + config_copy.setdefault("text_config", config_copy.pop("llm_config", {})) + config_copy.setdefault("vision_config", {}) + config_copy.setdefault("audio_config", {}) + + model_config = model_module.ModelConfig.from_dict(config_copy) + try: + model_config = update_module_configs( + model_config, + model_module, + config_copy, + ["text", "vision", "perceiver", "projector", "audio"], + ) + except Exception: + pass + except Exception: + return [] + + steps = [] + if hasattr(model_module, "Model"): + steps.append((model_module.Model, model_config)) + + thinker_config = _get_nested_config(model_config, "thinker_config") + if thinker_config is not None: + thinker_cls = getattr(model_module, "Thinker", None) + if thinker_cls is None: + try: + thinker_mod = importlib.import_module( + f"mlx_vlm.models.{model_type}.thinker" + ) + thinker_cls = getattr(thinker_mod, "Thinker", None) + except Exception: + thinker_cls = None + if thinker_cls is not None: + steps.append((thinker_cls, thinker_config)) + + vision_config = ( + _get_nested_config(model_config, "vision_config") + or _get_nested_config(model_config, "thinker_config", "vision_config") + ) + if vision_config is not None and hasattr(model_module, "VisionModel"): + steps.append((model_module.VisionModel, vision_config)) + + return [ + (cls, step_config) + for cls, step_config in steps + if getattr(cls, "sanitize", None) is not None + ] + + +def _build_mlx_vlm_sanitize_pipelines(config, model=None): + """Combine real-model and config-derived sanitizer replay pipelines.""" + pipelines = _get_mlx_vlm_model_sanitize_pipelines(model) + class_steps = _build_mlx_vlm_sanitize_steps(config) + if class_steps: + pipelines.append(class_steps) + return pipelines + + +def _apply_mlx_vlm_sanitizers(steps, weights): + """Replay a sanitizer pipeline and return None if any step rejects it.""" + sanitized = dict(weights) + for cls, config in steps: + try: + sanitized = _call_mlx_vlm_sanitize(cls, config, sanitized) + except Exception: + return None + return sanitized + + +def _vlm_gguf_name_candidates(name): + """Yield HF/llama.cpp tensor-name candidates for an MLX VLM tensor.""" + candidates = [] + + def add(value): + if value not in candidates: + candidates.append(value) + + if name.startswith("thinker.vision_tower."): + suffix = name[len("thinker.vision_tower."):] + add(f"thinker.visual.{suffix}") + if name.startswith("model.vision_tower."): + suffix = name[len("model.vision_tower."):] + add(f"model.visual.{suffix}") + if name.startswith("vision_tower."): + suffix = name[len("vision_tower."):] + add(f"visual.{suffix}") + add(f"model.visual.{suffix}") + add(f"model.language_model.visual.{suffix}") + add(f"vit.{suffix}") + + add(name) + return candidates + + +def _vlm_gguf_tensor_candidates(tensor): + """Yield HF-layout tensor candidates for an MLX VLM tensor.""" + candidates = [] + shape = getattr(tensor, "shape", ()) + + if len(shape) == 5: + candidates.append(mx.transpose(tensor, (0, 4, 1, 2, 3))) + elif len(shape) == 4: + candidates.append(mx.transpose(tensor, (0, 3, 1, 2))) + + if len(shape) == 1 and mx.issubdtype(tensor.dtype, mx.floating): + candidates.append(tensor - 1) + + candidates.append(tensor) + return candidates + + +def _has_vlm_gguf_tensor_candidate(tensor): + """Return whether a tensor shape can require HF-layout recovery.""" + shape = getattr(tensor, "shape", ()) + if len(shape) in (4, 5): + return True + if len(shape) == 1: + dtype = getattr(tensor, "dtype", None) + return dtype is not None and mx.issubdtype(dtype, mx.floating) + return False + + +def _has_vlm_gguf_rewrite_candidate(name, tensor): + """Return whether a tensor can differ between mlx-vlm and GGUF layouts.""" + if any(candidate_name != name for candidate_name in _vlm_gguf_name_candidates(name)): + return True + return _has_vlm_gguf_tensor_candidate(tensor) + + +def _mlx_arrays_match(actual, expected): + """Compare MLX-like arrays without assuming a concrete backend type.""" + shape = getattr(actual, "shape", None) + if shape != getattr(expected, "shape", None): + return False + if actual is expected: + return True + if shape is None: + return actual == expected + try: + result = mx.all(actual == expected) + item = getattr(result, "item", None) + return bool(item() if callable(item) else result) + except Exception: + return False + + +def _is_mlx_vlm_sanitize_step(value): + """Return whether a value is one sanitizer step tuple.""" + return isinstance(value, tuple) and len(value) == 2 + + +def _normalize_mlx_vlm_sanitize_pipelines(sanitize_steps): + """Normalize legacy step lists and multi-pipeline sanitizer inputs.""" + if not sanitize_steps: + return [] + if all(_is_mlx_vlm_sanitize_step(step) for step in sanitize_steps): + return [sanitize_steps] + return sanitize_steps + + +def _rewrite_mlx_vlm_tensor_for_gguf(name, tensor, sanitize_steps): + """Invert mlx-vlm sanitizers to recover HF tensor names/layouts for GGUF.""" + if not _has_vlm_gguf_rewrite_candidate(name, tensor): + return name, tensor, False + + for candidate_name in _vlm_gguf_name_candidates(name): + for candidate_tensor in _vlm_gguf_tensor_candidates(tensor): + for pipeline in _normalize_mlx_vlm_sanitize_pipelines(sanitize_steps): + sanitized = _apply_mlx_vlm_sanitizers( + pipeline, + {candidate_name: candidate_tensor}, + ) + if not sanitized or len(sanitized) != 1: + continue + sanitized_name, sanitized_tensor = next(iter(sanitized.items())) + if sanitized_name != name: + continue + if not _mlx_arrays_match(sanitized_tensor, tensor): + continue + changed = ( + candidate_name != name + or not _mlx_arrays_match(candidate_tensor, tensor) + ) + if not changed: + continue + return candidate_name, candidate_tensor, True + + return name, tensor, False + + +def _sync_gguf_nextn_layer_config(config, model): + """Align speculative-layer config metadata with exported MLX layers.""" + if model is None or not isinstance(config, dict): + return False + + layers = _get_transformer_layers(model) + if layers is None: + return False + try: + actual_layers = len(layers) + except Exception: + return False + + thinker_config = config.get("thinker_config") + text_configs = [ + config.get("text_config"), + config.get("language_config"), + ( + thinker_config.get("text_config") + if isinstance(thinker_config, dict) + else None + ), + ] + changed = False + for text_config in text_configs: + if not isinstance(text_config, dict): + continue + num_hidden_layers = text_config.get("num_hidden_layers") + if not isinstance(num_hidden_layers, int): + continue + + actual_nextn = actual_layers - num_hidden_layers + for key in ( + "num_nextn_predict_layers", + "mtp_num_hidden_layers", + "nextn_predict_layers", + ): + num_nextn = text_config.get(key) + if not isinstance(num_nextn, int) or num_nextn <= 0: + continue + if actual_layers < num_hidden_layers: + continue + if actual_nextn >= num_nextn: + continue + if actual_nextn > 0: + text_config[key] = actual_nextn + else: + text_config.pop(key, None) + changed = True + + return changed + + +def _prepare_vlm_gguf_export_directory(path, model=None): + """Rewrite MLX-native VLM tensor names in the temporary GGUF export dir.""" + path = Path(path) + config_path = path / "config.json" + if not config_path.exists(): + return 0 + with open(config_path, "r") as f: + config = json.load(f) + config_changed = _sync_gguf_nextn_layer_config(config, model) + sanitize_steps = _build_mlx_vlm_sanitize_pipelines(config, model=model) + if not sanitize_steps: + if config_changed: + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + return 0 + + rewritten = 0 + name_map = {} + for file in sorted(path.glob("*.safetensors")): + tensors = mx.load(str(file)) + updated = {} + file_rewritten = 0 + for name, tensor in tensors.items(): + new_name, tensor, changed = _rewrite_mlx_vlm_tensor_for_gguf( + name, tensor, sanitize_steps + ) + if new_name in updated: + raise RuntimeError( + f"Unsloth: duplicate tensor name after GGUF VLM rewrite: {new_name}" + ) + updated[new_name] = tensor + name_map[name] = new_name + file_rewritten += int(changed) + if file_rewritten: + # mx.load() may return arrays backed by the source safetensors file. + # Saving back to the same path can truncate those backing bytes before + # unchanged tensors are materialized, so write beside it and replace. + mx.eval(*updated.values()) + tmp_file = file.with_name(f"{file.stem}.tmp{file.suffix}") + mx.save_safetensors(str(tmp_file), updated, metadata={"format": "mlx"}) + os.replace(tmp_file, file) + rewritten += file_rewritten + + index_path = path / "model.safetensors.index.json" + if rewritten and index_path.exists(): + with open(index_path, "r") as f: + index_data = json.load(f) + weight_map = {} + for name, shard in index_data.get("weight_map", {}).items(): + new_name = name_map.get(name, name) + if new_name in weight_map: + raise RuntimeError( + f"Unsloth: duplicate index tensor name after GGUF VLM rewrite: {new_name}" + ) + weight_map[new_name] = shard + index_data["weight_map"] = dict(sorted(weight_map.items())) + with open(index_path, "w") as f: + json.dump(index_data, f, indent=4) + + if config_changed: + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + + return rewritten + + +_CORE_SAVE_FILENAMES = { + "config.json", + "model.safetensors.index.json", + "README.md", + ".gitattributes", +} +_MODEL_WEIGHT_SUFFIXES = ( + ".safetensors", + ".bin", + ".gguf", + ".h5", + ".msgpack", + ".onnx", + ".pt", + ".pth", +) +_MODEL_SIDECAR_SUFFIXES = (".json", ".jinja", ".model", ".txt", ".py") + + +def _copy_source_sidecars(src_path, path): + """Copy non-weight source sidecars that tokenizer/model saves may omit.""" + copied = 0 + src_path = Path(src_path) + path = Path(path) + if not src_path.is_dir(): + return copied + for source in src_path.iterdir(): + if not source.is_file(): + continue + name = source.name + if name in _CORE_SAVE_FILENAMES: + continue + if name.startswith("model-") or name.startswith("pytorch_model"): + continue + suffix = source.suffix + if suffix in _MODEL_WEIGHT_SUFFIXES: + continue + if suffix not in _MODEL_SIDECAR_SUFFIXES: + continue + target = path / name + if target.exists(): + continue + shutil.copy2(source, target) + copied += 1 + return copied + def save_merged_model(model, tokenizer, path, dequantize=False): """Fuse LoRA weights and save the full merged model. @@ -3457,7 +3956,7 @@ def save_merged_model(model, tokenizer, path, dequantize=False): base quantization (smaller checkpoint, only meaningful when the base was quantized). """ - from mlx_lm.utils import save_model, save_config, create_model_card + from mlx_lm.utils import save_model, create_model_card, dequantize_model from mlx.utils import tree_unflatten path = Path(path) @@ -3474,6 +3973,7 @@ def save_merged_model(model, tokenizer, path, dequantize=False): model.update_modules(tree_unflatten(fused_linears)) if dequantize: + model = dequantize_model(model) cfg = getattr(model, "_config", None) if isinstance(cfg, dict): model._config = _strip_mlx_quantization_metadata(cfg) @@ -3486,20 +3986,20 @@ def save_merged_model(model, tokenizer, path, dequantize=False): # Save config.json config = _get_model_config(model) if config: - save_config(config, config_path=path / "config.json") + is_vlm = _is_vlm_model(model) or _has_vision_config(config) + _save_mlx_config( + config, + path / "config.json", + is_vlm=is_vlm, + ) # Save tokenizer tokenizer.save_pretrained(str(path)) - # Copy auxiliary files (generation_config.json, *.py) from source + # Copy auxiliary source files that tokenizer/model saves may omit. src_path = _get_src_path(model) if src_path is not None: - src_path = Path(src_path) - if src_path.exists(): - import glob as globmod - for pattern in ["generation_config.json", "*.py"]: - for f in globmod.glob(str(src_path / pattern)): - shutil.copy(f, path) + _copy_source_sidecars(src_path, path) # Model card hf_repo = getattr(model, "_hf_repo", None) @@ -3917,6 +4417,7 @@ def save_pretrained_gguf( quantize_gguf, install_llama_cpp, check_llama_cpp, + LLAMA_CPP_DEFAULT_DIR, _download_convert_hf_to_gguf, ) @@ -3956,16 +4457,25 @@ def save_pretrained_gguf( # Step 1: Save merged model to a temp HF-format directory with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) / "merged" + is_vlm_model = _is_vlm_model(model) print("Unsloth: Merging LoRA weights and saving to 16-bit...") save_merged_model(model, tokenizer, tmp_path, dequantize=True) + if is_vlm_model: + rewritten = _prepare_vlm_gguf_export_directory(tmp_path, model=model) + if rewritten: + print( + "Unsloth: Rewrote " + f"{rewritten} MLX VLM tensors for llama.cpp GGUF export." + ) # Step 2: Ensure llama.cpp is installed and gguf package is available - llama_cpp_folder = "llama.cpp" + llama_cpp_folder = LLAMA_CPP_DEFAULT_DIR try: - check_llama_cpp(llama_cpp_folder) + quantizer_location, converter_location = check_llama_cpp(llama_cpp_folder) except Exception: print("Unsloth: Installing llama.cpp (this only happens once)...") - _install_llama_cpp_macos(llama_cpp_folder) + quantizer_location, converter_location = install_llama_cpp(llama_cpp_folder) + llama_cpp_folder = os.path.dirname(converter_location) # Ensure gguf Python package is installed (may be missing if # llama.cpp was built in a different venv) @@ -3992,7 +4502,17 @@ def save_pretrained_gguf( converter = os.path.join(llama_cpp_folder, "unsloth_convert_hf_to_gguf.py") supported_text_archs = None supported_vision_archs = None - result = _download_convert_hf_to_gguf() # no args — uses defaults + with _LLAMA_CPP_PATCHER_ENV_LOCK: + old_scripts_dir = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") + if old_scripts_dir is None: + os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = llama_cpp_folder + try: + result = _download_convert_hf_to_gguf() + finally: + if old_scripts_dir is None: + os.environ.pop("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR", None) + else: + os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = old_scripts_dir if isinstance(result, tuple) and len(result) >= 3: converter, supported_text_archs, supported_vision_archs = result[:3] elif isinstance(result, str): @@ -4009,7 +4529,6 @@ def save_pretrained_gguf( # Step 5: Convert HF -> GGUF print(f"Unsloth: Converting to GGUF format...") - is_vlm_model = bool(getattr(model, "_is_vlm_model", False)) kwargs = dict( model_name=output_base, input_folder=str(tmp_path), @@ -4027,7 +4546,7 @@ def save_pretrained_gguf( # Step 6: Quantize if the target quant differs from first_conversion if quant_type not in ("bf16", "f16", "f32") and first_conversion != quant_type: - quantizer = os.path.join(llama_cpp_folder, "llama-quantize") + quantizer = quantizer_location base_gguf = f"{output_base}.{first_conversion.upper()}.gguf" final_gguf = f"{output_base}.{quant_type.upper()}.gguf" @@ -4203,6 +4722,7 @@ def push_to_hub_gguf( save_directory, repo_id, quantization_method="fast_quantized", + first_conversion=None, token=None, private=None, ): @@ -4214,6 +4734,8 @@ def push_to_hub_gguf( save_directory: Local path for GGUF output. repo_id: HuggingFace repo ID. quantization_method: GGUF quantization type. + first_conversion: Optional intermediate GGUF dtype passed through to + save_pretrained_gguf. token: HuggingFace token. private: Whether repo should be private. """ @@ -4222,7 +4744,13 @@ def push_to_hub_gguf( save_directory = Path(save_directory) # Export to GGUF - save_pretrained_gguf(model, tokenizer, save_directory, quantization_method) + save_pretrained_gguf( + model, + tokenizer, + save_directory, + quantization_method=quantization_method, + first_conversion=first_conversion, + ) # Upload GGUF files api = HfApi(token=token)