From 615aff8ef868b3f6222b8f8748631b45bc4aaf3b Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 27 May 2026 15:26:33 +0800 Subject: [PATCH 01/16] fix(mlx): preserve vlm merged config saves --- unsloth_zoo/mlx/utils.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 88e90e5c8..15b09718e 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3441,6 +3441,19 @@ def _get_src_path(model): return getattr(model, "_src_path", None) +def _save_mlx_config(config, config_path, *, is_vlm=False): + """Save MLX config using the backend-aware upstream helper.""" + config = copy.deepcopy(config) + if is_vlm: + if "quantization" in config: + config["quantization_config"] = config["quantization"] + from mlx_vlm.utils import save_config as save_vlm_config + save_vlm_config(config, config_path) + else: + from mlx_lm.utils import save_config as save_lm_config + save_lm_config(config, config_path) + + def save_merged_model(model, tokenizer, path, dequantize=False): """Fuse LoRA weights and save the full merged model. @@ -3457,7 +3470,7 @@ def save_merged_model(model, tokenizer, path, dequantize=False): base quantization (smaller checkpoint, only meaningful when the base was quantized). """ - from mlx_lm.utils import save_model, save_config, create_model_card + from mlx_lm.utils import save_model, create_model_card from mlx.utils import tree_unflatten path = Path(path) @@ -3486,7 +3499,11 @@ def save_merged_model(model, tokenizer, path, dequantize=False): # Save config.json config = _get_model_config(model) if config: - save_config(config, config_path=path / "config.json") + _save_mlx_config( + config, + path / "config.json", + is_vlm=_is_vlm_model(model) or "vision_config" in config, + ) # Save tokenizer tokenizer.save_pretrained(str(path)) From 55fa1737784013ca5bab456a0dedffb67318b976 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 27 May 2026 15:26:42 +0800 Subject: [PATCH 02/16] fix(mlx): dequantize qlora merged saves --- unsloth_zoo/mlx/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 15b09718e..869af143d 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3470,7 +3470,7 @@ def save_merged_model(model, tokenizer, path, dequantize=False): base quantization (smaller checkpoint, only meaningful when the base was quantized). """ - from mlx_lm.utils import save_model, create_model_card + from mlx_lm.utils import save_model, create_model_card, dequantize_model from mlx.utils import tree_unflatten path = Path(path) @@ -3487,6 +3487,7 @@ def save_merged_model(model, tokenizer, path, dequantize=False): model.update_modules(tree_unflatten(fused_linears)) if dequantize: + model = dequantize_model(model) cfg = getattr(model, "_config", None) if isinstance(cfg, dict): model._config = _strip_mlx_quantization_metadata(cfg) From d8cdf89e90c5122a38ab66766ee3b2f11877160f Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 02:07:47 +0800 Subject: [PATCH 03/16] fix(mlx): materialize tied lm head exports --- unsloth_zoo/mlx/utils.py | 211 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 210 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 869af143d..20c316933 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3454,6 +3454,203 @@ def _save_mlx_config(config, config_path, *, is_vlm=False): save_lm_config(config, config_path) +def _has_tied_word_embeddings(config): + """Return whether any text config declares tied input/output embeddings.""" + if not isinstance(config, dict): + return False + candidates = [ + config, + config.get("text_config"), + config.get("language_config"), + (config.get("thinker_config") or {}).get("text_config"), + ] + return any( + isinstance(item, dict) and item.get("tie_word_embeddings") is True + for item in candidates + ) + + +def _lm_head_key_for_embed_key(embed_key): + """Map an input embedding tensor key to the matching tied LM head key.""" + if embed_key == "embed_tokens.weight": + return "lm_head.weight" + if embed_key == "model.embed_tokens.weight": + return "lm_head.weight" + + suffix = ".model.embed_tokens.weight" + if embed_key.endswith(suffix): + prefix = embed_key[:-len(suffix)] + return f"{prefix}.lm_head.weight" if prefix else "lm_head.weight" + + suffix = ".embed_tokens.weight" + if embed_key.endswith(suffix): + prefix = embed_key[:-len(suffix)] + if not prefix or prefix == "model": + return "lm_head.weight" + return f"{prefix}.lm_head.weight" + + return None + + +def _safetensor_names(path): + """Read tensor names from a safetensors directory or index file.""" + path = Path(path) + index_path = path / "model.safetensors.index.json" + if index_path.exists(): + with open(index_path, "r") as f: + return set(json.load(f).get("weight_map", {})) + + try: + from safetensors import safe_open + except Exception: + return set() + + names = set() + for file in sorted(path.glob("*.safetensors")): + with safe_open(str(file), framework="np") as f: + names.update(f.keys()) + return names + + +def _source_has_lm_head_tensor(source_path, lm_head_key): + """Check whether the source checkpoint explicitly stored an LM head.""" + if source_path is None: + return None + source_path = Path(source_path) + if not source_path.exists(): + return None + + names = _safetensor_names(source_path) + if not names: + return None + if lm_head_key in names: + return True + if lm_head_key.endswith(".lm_head.weight") and "lm_head.weight" in names: + return True + return False + + +def _tensor_nbytes(tensor): + """Return tensor byte size across MLX, NumPy, and torch-like objects.""" + value = getattr(tensor, "nbytes", None) + if callable(value): + value = value() + if value is not None: + return int(value) + itemsize = getattr(tensor, "itemsize", None) + if callable(itemsize): + itemsize = itemsize() + if itemsize is None: + element_size = getattr(tensor, "element_size", None) + itemsize = element_size() if callable(element_size) else 0 + return int(_tensor_size(tensor) * itemsize) + + +def _tensor_size(tensor): + """Return tensor element count across MLX, NumPy, and torch-like objects.""" + value = getattr(tensor, "size", None) + if callable(value): + numel = getattr(tensor, "numel", None) + if callable(numel): + return int(numel()) + shape = getattr(tensor, "shape", ()) + total = 1 + for dim in shape: + total *= int(dim) + return total + if value is not None: + return int(value) + return 0 + + +def _duplicate_tensor_for_safetensors(tensor): + """Clone tensors when available before writing a tied duplicate key.""" + clone = getattr(tensor, "clone", None) + if callable(clone): + return clone() + return tensor + + +def _materialize_tied_lm_head_in_saved_model( + path, + config, + *, + source_path=None, + is_vlm=False, +): + """Duplicate tied input embeddings into the saved LM head when CUDA does.""" + if not _has_tied_word_embeddings(config): + return 0 + + path = Path(path) + index_path = path / "model.safetensors.index.json" + if not index_path.exists(): + return 0 + + with open(index_path, "r") as f: + index_data = json.load(f) + + weight_map = dict(index_data.get("weight_map", {})) + additions = [] + for embed_key in sorted(weight_map): + if not embed_key.endswith("embed_tokens.weight"): + continue + lm_head_key = _lm_head_key_for_embed_key(embed_key) + if lm_head_key is None or lm_head_key in weight_map: + continue + + source_has_lm_head = _source_has_lm_head_tensor(source_path, lm_head_key) + if source_has_lm_head is False: + continue + if source_has_lm_head is None and is_vlm: + continue + if source_has_lm_head is None and lm_head_key != "lm_head.weight": + continue + + additions.append((embed_key, lm_head_key)) + + added = 0 + added_bytes = 0 + added_parameters = 0 + for embed_key, lm_head_key in additions: + shard_name = weight_map[embed_key] + shard_path = path / shard_name + tensors = mx.load(str(shard_path)) + if lm_head_key in tensors: + weight_map[lm_head_key] = shard_name + continue + if embed_key not in tensors: + continue + + tensor = tensors[embed_key] + tensors[lm_head_key] = _duplicate_tensor_for_safetensors(tensor) + mx.eval(*tensors.values()) + tmp_file = shard_path.with_name(f"{shard_path.stem}.tmp{shard_path.suffix}") + mx.save_safetensors(str(tmp_file), tensors, metadata={"format": "mlx"}) + os.replace(tmp_file, shard_path) + + weight_map[lm_head_key] = shard_name + added += 1 + added_bytes += _tensor_nbytes(tensor) + added_parameters += _tensor_size(tensor) + + if added: + metadata = index_data.setdefault("metadata", {}) + if "total_size" in metadata: + metadata["total_size"] = int(metadata["total_size"]) + added_bytes + if "total_parameters" in metadata: + metadata["total_parameters"] = ( + int(metadata["total_parameters"]) + added_parameters + ) + index_data["weight_map"] = { + key: weight_map[key] for key in sorted(weight_map) + } + with open(index_path, "w") as f: + json.dump(index_data, f, indent=4) + + return added + + def save_merged_model(model, tokenizer, path, dequantize=False): """Fuse LoRA weights and save the full merged model. @@ -3500,10 +3697,22 @@ def save_merged_model(model, tokenizer, path, dequantize=False): # Save config.json config = _get_model_config(model) if config: + is_vlm = _is_vlm_model(model) or "vision_config" in config + materialized = _materialize_tied_lm_head_in_saved_model( + path, + config, + source_path=_get_src_path(model), + is_vlm=is_vlm, + ) + if materialized: + print( + "Unsloth: Materialized " + f"{materialized} tied lm_head tensor(s) for CUDA export parity." + ) _save_mlx_config( config, path / "config.json", - is_vlm=_is_vlm_model(model) or "vision_config" in config, + is_vlm=is_vlm, ) # Save tokenizer From 0029c545defa5eb0376b25e9953470ef50d4d51d Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 14:51:52 +0800 Subject: [PATCH 04/16] fix(mlx): forward gguf export options --- unsloth_zoo/mlx/loader.py | 9 ++++++++- unsloth_zoo/mlx/utils.py | 11 ++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index efd1403f5..6127791a6 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -2413,12 +2413,18 @@ def _mlx_save_pretrained_merged(self, save_directory, tokenizer=None, **kwargs): save_pretrained_merged(self, tokenizer, save_directory, **kwargs) +def _mlx_supported_kwargs(kwargs, supported): + """Keep CUDA-compatible kwargs out of MLX-only save/export APIs.""" + return {key: kwargs[key] for key in supported if key in kwargs} + + def _mlx_save_pretrained_gguf(self, save_directory, tokenizer=None, quantization_method="fast_quantized", **kwargs): from .utils import save_pretrained_gguf tokenizer = tokenizer or self._tokenizer + kwargs = _mlx_supported_kwargs(kwargs, ("first_conversion",)) save_pretrained_gguf(self, tokenizer, save_directory, - quantization_method=quantization_method) + quantization_method=quantization_method, **kwargs) def _mlx_push_to_hub_merged(self, repo_id, tokenizer=None, save_directory=None, **kwargs): @@ -2435,6 +2441,7 @@ def _mlx_push_to_hub_gguf(self, repo_id, tokenizer=None, quantization_method="fast_quantized", **kwargs): from .utils import push_to_hub_gguf tokenizer = tokenizer or self._tokenizer + kwargs = _mlx_supported_kwargs(kwargs, ("first_conversion", "token", "private")) push_to_hub_gguf(self, tokenizer, repo_id, repo_id=repo_id, quantization_method=quantization_method, **kwargs) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 20c316933..33992be28 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -4430,6 +4430,7 @@ def push_to_hub_gguf( save_directory, repo_id, quantization_method="fast_quantized", + first_conversion=None, token=None, private=None, ): @@ -4441,6 +4442,8 @@ def push_to_hub_gguf( save_directory: Local path for GGUF output. repo_id: HuggingFace repo ID. quantization_method: GGUF quantization type. + first_conversion: Optional intermediate GGUF dtype passed through to + save_pretrained_gguf. token: HuggingFace token. private: Whether repo should be private. """ @@ -4449,7 +4452,13 @@ def push_to_hub_gguf( save_directory = Path(save_directory) # Export to GGUF - save_pretrained_gguf(model, tokenizer, save_directory, quantization_method) + save_pretrained_gguf( + model, + tokenizer, + save_directory, + quantization_method=quantization_method, + first_conversion=first_conversion, + ) # Upload GGUF files api = HfApi(token=token) From 6fcaa19a1c159a17870e955748081f53a7f7d6b0 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 25 May 2026 21:14:41 +0800 Subject: [PATCH 05/16] fix(mlx): normalize vlm mmproj export tensors --- unsloth_zoo/mlx/utils.py | 260 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 259 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 33992be28..edb98b4ae 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -27,6 +27,7 @@ import mlx.utils import copy import inspect +import importlib import json import numpy as np import os @@ -3651,6 +3652,256 @@ def _materialize_tied_lm_head_in_saved_model( return added +def _has_vision_config(config): + """Return whether a raw or thinker-wrapped VLM config has vision settings.""" + thinker_config = config.get("thinker_config") or {} + return "vision_config" in config or "vision_config" in thinker_config + + +class _MlxVlmSanitizeProxy: + """Minimal instance shim for mlx-vlm class sanitize methods.""" + def __init__(self, config): + self.config = config + self.args = config + + +def _copy_mlx_vlm_sanitize_weights(weights): + """Copy MLX arrays before replaying sanitizer transforms.""" + return { + key: mx.array(value) if isinstance(value, mx.array) else value + for key, value in weights.items() + } + + +def _call_mlx_vlm_sanitize(cls, config, weights): + """Call an mlx-vlm sanitize method with its expected signature.""" + sanitize = getattr(cls, "sanitize", None) + if sanitize is None: + return weights + + weights = _copy_mlx_vlm_sanitize_weights(weights) + params = inspect.signature(sanitize).parameters + if len(params) == 1: + return sanitize(weights) + return sanitize(_MlxVlmSanitizeProxy(config), weights) + + +def _get_nested_config(config, *names): + """Walk nested config attributes, returning None for missing segments.""" + cur = config + for name in names: + cur = getattr(cur, name, None) + if cur is None: + return None + return cur + + +def _build_mlx_vlm_sanitize_steps(config): + """Build class-based mlx-vlm sanitizer steps from a saved VLM config.""" + if not _has_vision_config(config): + return [] + + try: + from mlx_vlm.utils import get_model_and_args, update_module_configs + + config_copy = copy.deepcopy(config) + model_module, model_type = get_model_and_args(config_copy) + config_copy.setdefault("text_config", config_copy.pop("llm_config", {})) + config_copy.setdefault("vision_config", {}) + config_copy.setdefault("audio_config", {}) + + model_config = model_module.ModelConfig.from_dict(config_copy) + try: + model_config = update_module_configs( + model_config, + model_module, + config_copy, + ["text", "vision", "perceiver", "projector", "audio"], + ) + except Exception: + pass + except Exception: + return [] + + steps = [] + if hasattr(model_module, "Model"): + steps.append((model_module.Model, model_config)) + + thinker_config = _get_nested_config(model_config, "thinker_config") + if thinker_config is not None: + thinker_cls = getattr(model_module, "Thinker", None) + if thinker_cls is None: + try: + thinker_mod = importlib.import_module( + f"mlx_vlm.models.{model_type}.thinker" + ) + thinker_cls = getattr(thinker_mod, "Thinker", None) + except Exception: + thinker_cls = None + if thinker_cls is not None: + steps.append((thinker_cls, thinker_config)) + + vision_config = ( + _get_nested_config(model_config, "vision_config") + or _get_nested_config(model_config, "thinker_config", "vision_config") + ) + if vision_config is not None and hasattr(model_module, "VisionModel"): + steps.append((model_module.VisionModel, vision_config)) + + return [ + (cls, step_config) + for cls, step_config in steps + if getattr(cls, "sanitize", None) is not None + ] + + +def _apply_mlx_vlm_sanitizers(steps, weights): + """Replay a sanitizer pipeline and return None if any step rejects it.""" + sanitized = dict(weights) + for cls, config in steps: + try: + sanitized = _call_mlx_vlm_sanitize(cls, config, sanitized) + except Exception: + return None + return sanitized + + +def _vlm_gguf_name_candidates(name): + """Yield HF/llama.cpp tensor-name candidates for an MLX VLM tensor.""" + candidates = [] + + def add(value): + if value != name and value not in candidates: + candidates.append(value) + + if name.startswith("thinker.vision_tower."): + suffix = name[len("thinker.vision_tower."):] + add(f"thinker.visual.{suffix}") + if name.startswith("model.vision_tower."): + suffix = name[len("model.vision_tower."):] + add(f"model.visual.{suffix}") + if name.startswith("vision_tower."): + suffix = name[len("vision_tower."):] + add(f"visual.{suffix}") + add(f"model.visual.{suffix}") + add(f"model.language_model.visual.{suffix}") + add(f"vit.{suffix}") + + return candidates + + +def _vlm_gguf_tensor_candidates(tensor): + """Yield HF-layout tensor candidates for an MLX VLM tensor.""" + candidates = [] + shape = getattr(tensor, "shape", ()) + + if len(shape) == 5: + candidates.append(mx.transpose(tensor, (0, 4, 1, 2, 3))) + elif len(shape) == 4: + candidates.append(mx.transpose(tensor, (0, 3, 1, 2))) + + if len(shape) == 1 and mx.issubdtype(tensor.dtype, mx.floating): + candidates.append(tensor - 1) + + candidates.append(tensor) + return candidates + + +def _mlx_arrays_match(actual, expected): + """Compare MLX-like arrays without assuming a concrete backend type.""" + shape = getattr(actual, "shape", None) + if shape != getattr(expected, "shape", None): + return False + if actual is expected: + return True + if shape is None: + return actual == expected + if len(shape) not in (1, 4, 5): + return True + try: + return bool(mx.all(actual == expected).item()) + except Exception: + return False + + +def _rewrite_mlx_vlm_tensor_for_gguf(name, tensor, sanitize_steps): + """Invert mlx-vlm sanitizers to recover HF tensor names/layouts for GGUF.""" + for candidate_name in _vlm_gguf_name_candidates(name): + for candidate_tensor in _vlm_gguf_tensor_candidates(tensor): + sanitized = _apply_mlx_vlm_sanitizers( + sanitize_steps, + {candidate_name: candidate_tensor}, + ) + if not sanitized or len(sanitized) != 1: + continue + sanitized_name, sanitized_tensor = next(iter(sanitized.items())) + if sanitized_name != name: + continue + if not _mlx_arrays_match(sanitized_tensor, tensor): + continue + return candidate_name, candidate_tensor, True + + return name, tensor, False + + +def _prepare_vlm_gguf_export_directory(path): + """Rewrite MLX-native VLM tensor names in the temporary GGUF export dir.""" + path = Path(path) + config_path = path / "config.json" + if not config_path.exists(): + return 0 + with open(config_path, "r") as f: + config = json.load(f) + sanitize_steps = _build_mlx_vlm_sanitize_steps(config) + if not sanitize_steps: + return 0 + + rewritten = 0 + name_map = {} + for file in sorted(path.glob("*.safetensors")): + tensors = mx.load(str(file)) + updated = {} + file_rewritten = 0 + for name, tensor in tensors.items(): + new_name, tensor, changed = _rewrite_mlx_vlm_tensor_for_gguf( + name, tensor, sanitize_steps + ) + if new_name in updated: + raise RuntimeError( + f"Unsloth: duplicate tensor name after GGUF VLM rewrite: {new_name}" + ) + updated[new_name] = tensor + name_map[name] = new_name + file_rewritten += int(changed) + if file_rewritten: + # mx.load() may return arrays backed by the source safetensors file. + # Saving back to the same path can truncate those backing bytes before + # unchanged tensors are materialized, so write beside it and replace. + mx.eval(*updated.values()) + tmp_file = file.with_name(f"{file.stem}.tmp{file.suffix}") + mx.save_safetensors(str(tmp_file), updated, metadata={"format": "mlx"}) + os.replace(tmp_file, file) + rewritten += file_rewritten + + index_path = path / "model.safetensors.index.json" + if rewritten and index_path.exists(): + with open(index_path, "r") as f: + index_data = json.load(f) + weight_map = {} + for name, shard in index_data.get("weight_map", {}).items(): + new_name = name_map.get(name, name) + if new_name in weight_map: + raise RuntimeError( + f"Unsloth: duplicate index tensor name after GGUF VLM rewrite: {new_name}" + ) + weight_map[new_name] = shard + index_data["weight_map"] = dict(sorted(weight_map.items())) + with open(index_path, "w") as f: + json.dump(index_data, f, indent=4) + + return rewritten + + def save_merged_model(model, tokenizer, path, dequantize=False): """Fuse LoRA weights and save the full merged model. @@ -4183,8 +4434,16 @@ def save_pretrained_gguf( # Step 1: Save merged model to a temp HF-format directory with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) / "merged" + is_vlm_model = _is_vlm_model(model) print("Unsloth: Merging LoRA weights and saving to 16-bit...") save_merged_model(model, tokenizer, tmp_path, dequantize=True) + if is_vlm_model: + rewritten = _prepare_vlm_gguf_export_directory(tmp_path) + if rewritten: + print( + "Unsloth: Rewrote " + f"{rewritten} MLX VLM tensors for llama.cpp GGUF export." + ) # Step 2: Ensure llama.cpp is installed and gguf package is available llama_cpp_folder = "llama.cpp" @@ -4236,7 +4495,6 @@ def save_pretrained_gguf( # Step 5: Convert HF -> GGUF print(f"Unsloth: Converting to GGUF format...") - is_vlm_model = bool(getattr(model, "_is_vlm_model", False)) kwargs = dict( model_name=output_base, input_folder=str(tmp_path), From 97c980fcad41609608f365268588c5bd01915697 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 17:17:12 +0800 Subject: [PATCH 06/16] fix(mlx): rewrite VLM GGUF tensor candidates Prefer HF vision aliases when inverting mlx-vlm sanitizers, while keeping same-name tensor rewrites as a fallback. This preserves Gemma3 patch embedding layout fixes and avoids stopping early on Qwen-family MLX vision_tower names. --- unsloth_zoo/mlx/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index edb98b4ae..df8109fa5 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3771,7 +3771,7 @@ def _vlm_gguf_name_candidates(name): candidates = [] def add(value): - if value != name and value not in candidates: + if value not in candidates: candidates.append(value) if name.startswith("thinker.vision_tower."): @@ -3787,6 +3787,7 @@ def add(value): add(f"model.language_model.visual.{suffix}") add(f"vit.{suffix}") + add(name) return candidates @@ -3839,6 +3840,12 @@ def _rewrite_mlx_vlm_tensor_for_gguf(name, tensor, sanitize_steps): continue if not _mlx_arrays_match(sanitized_tensor, tensor): continue + changed = ( + candidate_name != name + or not _mlx_arrays_match(candidate_tensor, tensor) + ) + if not changed: + continue return candidate_name, candidate_tensor, True return name, tensor, False From c7002e78d582db9cfa0b66e545c024a898317117 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 18:38:40 +0800 Subject: [PATCH 07/16] fix(mlx): replay model sanitizers for VLM GGUF Use loaded VLM model instances when replaying mlx-vlm sanitizers for GGUF tensor rewrites, while keeping the config-derived class pipeline as a fallback. This covers models whose top-level sanitizer delegates through submodules, such as GLM-OCR. --- unsloth_zoo/mlx/utils.py | 96 +++++++++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 21 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index df8109fa5..a788f1037 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3686,6 +3686,36 @@ def _call_mlx_vlm_sanitize(cls, config, weights): return sanitize(_MlxVlmSanitizeProxy(config), weights) +def _add_mlx_vlm_sanitize_step(steps, module): + """Append a real mlx-vlm module sanitizer once, preserving order.""" + if module is None or getattr(module, "sanitize", None) is None: + return + if all(existing is not module for existing, _ in steps): + steps.append((module, None)) + + +def _get_mlx_vlm_model_sanitize_pipelines(model): + """Build sanitizer pipelines from a loaded mlx-vlm model and submodules.""" + if model is None or getattr(model, "sanitize", None) is None: + return [] + + model_step = [(model, None)] + pipelines = [model_step] + + extra_steps = [] + for attr in ("thinker", "vision_tower", "vision_model", "vision_encoder", "visual"): + _add_mlx_vlm_sanitize_step(extra_steps, getattr(model, attr, None)) + + thinker = getattr(model, "thinker", None) + for attr in ("vision_tower", "vision_model", "vision_encoder", "visual"): + _add_mlx_vlm_sanitize_step(extra_steps, getattr(thinker, attr, None)) + + for idx in range(len(extra_steps)): + pipelines.append(model_step + extra_steps[: idx + 1]) + + return pipelines + + def _get_nested_config(config, *names): """Walk nested config attributes, returning None for missing segments.""" cur = config @@ -3755,6 +3785,15 @@ def _build_mlx_vlm_sanitize_steps(config): ] +def _build_mlx_vlm_sanitize_pipelines(config, model=None): + """Combine real-model and config-derived sanitizer replay pipelines.""" + pipelines = _get_mlx_vlm_model_sanitize_pipelines(model) + class_steps = _build_mlx_vlm_sanitize_steps(config) + if class_steps: + pipelines.append(class_steps) + return pipelines + + def _apply_mlx_vlm_sanitizers(steps, weights): """Replay a sanitizer pipeline and return None if any step rejects it.""" sanitized = dict(weights) @@ -3825,33 +3864,48 @@ def _mlx_arrays_match(actual, expected): return False +def _is_mlx_vlm_sanitize_step(value): + """Return whether a value is one sanitizer step tuple.""" + return isinstance(value, tuple) and len(value) == 2 + + +def _normalize_mlx_vlm_sanitize_pipelines(sanitize_steps): + """Normalize legacy step lists and multi-pipeline sanitizer inputs.""" + if not sanitize_steps: + return [] + if all(_is_mlx_vlm_sanitize_step(step) for step in sanitize_steps): + return [sanitize_steps] + return sanitize_steps + + def _rewrite_mlx_vlm_tensor_for_gguf(name, tensor, sanitize_steps): """Invert mlx-vlm sanitizers to recover HF tensor names/layouts for GGUF.""" for candidate_name in _vlm_gguf_name_candidates(name): for candidate_tensor in _vlm_gguf_tensor_candidates(tensor): - sanitized = _apply_mlx_vlm_sanitizers( - sanitize_steps, - {candidate_name: candidate_tensor}, - ) - if not sanitized or len(sanitized) != 1: - continue - sanitized_name, sanitized_tensor = next(iter(sanitized.items())) - if sanitized_name != name: - continue - if not _mlx_arrays_match(sanitized_tensor, tensor): - continue - changed = ( - candidate_name != name - or not _mlx_arrays_match(candidate_tensor, tensor) - ) - if not changed: - continue - return candidate_name, candidate_tensor, True + for pipeline in _normalize_mlx_vlm_sanitize_pipelines(sanitize_steps): + sanitized = _apply_mlx_vlm_sanitizers( + pipeline, + {candidate_name: candidate_tensor}, + ) + if not sanitized or len(sanitized) != 1: + continue + sanitized_name, sanitized_tensor = next(iter(sanitized.items())) + if sanitized_name != name: + continue + if not _mlx_arrays_match(sanitized_tensor, tensor): + continue + changed = ( + candidate_name != name + or not _mlx_arrays_match(candidate_tensor, tensor) + ) + if not changed: + continue + return candidate_name, candidate_tensor, True return name, tensor, False -def _prepare_vlm_gguf_export_directory(path): +def _prepare_vlm_gguf_export_directory(path, model=None): """Rewrite MLX-native VLM tensor names in the temporary GGUF export dir.""" path = Path(path) config_path = path / "config.json" @@ -3859,7 +3913,7 @@ def _prepare_vlm_gguf_export_directory(path): return 0 with open(config_path, "r") as f: config = json.load(f) - sanitize_steps = _build_mlx_vlm_sanitize_steps(config) + sanitize_steps = _build_mlx_vlm_sanitize_pipelines(config, model=model) if not sanitize_steps: return 0 @@ -4445,7 +4499,7 @@ def save_pretrained_gguf( print("Unsloth: Merging LoRA weights and saving to 16-bit...") save_merged_model(model, tokenizer, tmp_path, dequantize=True) if is_vlm_model: - rewritten = _prepare_vlm_gguf_export_directory(tmp_path) + rewritten = _prepare_vlm_gguf_export_directory(tmp_path, model=model) if rewritten: print( "Unsloth: Rewrote " From 962dec8537d438fe9ce3790ac86d231a643e98d9 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 19:20:53 +0800 Subject: [PATCH 08/16] fix(mlx): extract vlm config objects for saves --- unsloth_zoo/mlx/utils.py | 42 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index a788f1037..1dc5b5fce 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3418,21 +3418,55 @@ def _enrich_mlx_adapter_config(model, adapter_config): return adapter_config +def _config_to_plain_python(value): + """Recursively convert config dataclasses and containers to plain Python.""" + import dataclasses + + if dataclasses.is_dataclass(value) and not isinstance(value, type): + value = dataclasses.asdict(value) + elif isinstance(value, dict): + value = copy.deepcopy(value) + elif isinstance(value, (list, tuple)): + return [_config_to_plain_python(item) for item in value] + else: + return value + + if isinstance(value, dict): + return { + key: _config_to_plain_python(item) + for key, item in value.items() + } + return value + + def _get_model_config(model): """Extract config dict from an MLX model. mlx-lm stores the raw config dict at model._config when loaded. + mlx-vlm exposes config dataclasses at model.config. Falls back to reconstructing from model.args dataclass. """ + import dataclasses + # Prefer the raw config dict stashed by our loader if hasattr(model, "_config") and isinstance(model._config, dict): - return dict(model._config) + return _config_to_plain_python(model._config) + + if hasattr(model, "config"): + config = model.config + if isinstance(config, dict) or ( + dataclasses.is_dataclass(config) and not isinstance(config, type) + ): + return _config_to_plain_python(config) + if hasattr(config, "to_dict"): + config = config.to_dict() + if isinstance(config, dict): + return _config_to_plain_python(config) # Reconstruct from the ModelArgs dataclass if hasattr(model, "args"): - import dataclasses - if dataclasses.is_dataclass(model.args): - return dataclasses.asdict(model.args) + if dataclasses.is_dataclass(model.args) and not isinstance(model.args, type): + return _config_to_plain_python(model.args) return {} From 4546225d62be6908171fde1f803e6d04e9eac72c Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 19:20:56 +0800 Subject: [PATCH 09/16] fix(gguf): run package converters beside conversion --- unsloth_zoo/llama_cpp.py | 7 ++++++- unsloth_zoo/mlx/utils.py | 21 ++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 98c265d63..2a21e966f 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -1338,7 +1338,12 @@ def _download_convert_hf_to_gguf_cached(name, _local_script_info, _conversion_in # 4. Write Patched File - patched_filename = os.path.join(LLAMA_CPP_DEFAULT_DIR, f"{name}.py") + # Package-layout converters import sibling modules from conversion/. + # Keep the patched entrypoint beside that package so subprocess + # execution resolves `from conversion import ...`. + patched_dir = _llama_cpp_dir if _layout == "package" else LLAMA_CPP_DEFAULT_DIR + os.makedirs(patched_dir, exist_ok=True) + patched_filename = os.path.join(patched_dir, f"{name}.py") logger.info(f"Unsloth: Saving patched script to {patched_filename}") with open(patched_filename, "wb") as file: file.write(patched_content) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 1dc5b5fce..5c4e5234c 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -4490,6 +4490,7 @@ def save_pretrained_gguf( quantize_gguf, install_llama_cpp, check_llama_cpp, + LLAMA_CPP_DEFAULT_DIR, _download_convert_hf_to_gguf, ) @@ -4541,12 +4542,13 @@ def save_pretrained_gguf( ) # Step 2: Ensure llama.cpp is installed and gguf package is available - llama_cpp_folder = "llama.cpp" + llama_cpp_folder = LLAMA_CPP_DEFAULT_DIR try: - check_llama_cpp(llama_cpp_folder) + quantizer_location, converter_location = check_llama_cpp(llama_cpp_folder) except Exception: print("Unsloth: Installing llama.cpp (this only happens once)...") - _install_llama_cpp_macos(llama_cpp_folder) + quantizer_location, converter_location = install_llama_cpp(llama_cpp_folder) + llama_cpp_folder = os.path.dirname(converter_location) # Ensure gguf Python package is installed (may be missing if # llama.cpp was built in a different venv) @@ -4573,7 +4575,16 @@ def save_pretrained_gguf( converter = os.path.join(llama_cpp_folder, "unsloth_convert_hf_to_gguf.py") supported_text_archs = None supported_vision_archs = None - result = _download_convert_hf_to_gguf() # no args — uses defaults + old_scripts_dir = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") + if old_scripts_dir is None: + os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = llama_cpp_folder + try: + result = _download_convert_hf_to_gguf() + finally: + if old_scripts_dir is None: + os.environ.pop("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR", None) + else: + os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = old_scripts_dir if isinstance(result, tuple) and len(result) >= 3: converter, supported_text_archs, supported_vision_archs = result[:3] elif isinstance(result, str): @@ -4607,7 +4618,7 @@ def save_pretrained_gguf( # Step 6: Quantize if the target quant differs from first_conversion if quant_type not in ("bf16", "f16", "f32") and first_conversion != quant_type: - quantizer = os.path.join(llama_cpp_folder, "llama-quantize") + quantizer = quantizer_location base_gguf = f"{output_base}.{first_conversion.upper()}.gguf" final_gguf = f"{output_base}.{quant_type.upper()}.gguf" From d6d5d39b1f2a6d0406368f1575b49f3c19505061 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 19:22:00 +0800 Subject: [PATCH 10/16] fix(mlx): align gguf nextn metadata with tensors --- unsloth_zoo/mlx/utils.py | 56 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 5c4e5234c..5dbc299f1 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3939,6 +3939,54 @@ def _rewrite_mlx_vlm_tensor_for_gguf(name, tensor, sanitize_steps): return name, tensor, False +def _sync_gguf_nextn_layer_config(config, model): + """Align speculative-layer config metadata with exported MLX layers.""" + if model is None or not isinstance(config, dict): + return False + + layers = _get_transformer_layers(model) + if layers is None: + return False + try: + actual_layers = len(layers) + except Exception: + return False + + text_configs = [ + config.get("text_config"), + config.get("language_config"), + (config.get("thinker_config") or {}).get("text_config"), + ] + changed = False + for text_config in text_configs: + if not isinstance(text_config, dict): + continue + num_hidden_layers = text_config.get("num_hidden_layers") + if not isinstance(num_hidden_layers, int): + continue + + actual_nextn = actual_layers - num_hidden_layers + for key in ( + "num_nextn_predict_layers", + "mtp_num_hidden_layers", + "nextn_predict_layers", + ): + num_nextn = text_config.get(key) + if not isinstance(num_nextn, int) or num_nextn <= 0: + continue + if actual_layers < num_hidden_layers: + continue + if actual_nextn >= num_nextn: + continue + if actual_nextn > 0: + text_config[key] = actual_nextn + else: + text_config.pop(key, None) + changed = True + + return changed + + def _prepare_vlm_gguf_export_directory(path, model=None): """Rewrite MLX-native VLM tensor names in the temporary GGUF export dir.""" path = Path(path) @@ -3947,8 +3995,12 @@ def _prepare_vlm_gguf_export_directory(path, model=None): return 0 with open(config_path, "r") as f: config = json.load(f) + config_changed = _sync_gguf_nextn_layer_config(config, model) sanitize_steps = _build_mlx_vlm_sanitize_pipelines(config, model=model) if not sanitize_steps: + if config_changed: + with open(config_path, "w") as f: + json.dump(config, f, indent=4) return 0 rewritten = 0 @@ -3994,6 +4046,10 @@ def _prepare_vlm_gguf_export_directory(path, model=None): with open(index_path, "w") as f: json.dump(index_data, f, indent=4) + if config_changed: + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + return rewritten From 398cb9ca57be8f2fdc7b61fcdb0bffc46003c335 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 19:22:25 +0800 Subject: [PATCH 11/16] fix(mlx): repair degraded VLM processors --- unsloth_zoo/mlx/loader.py | 156 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 6127791a6..f37cdb56a 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -360,6 +360,154 @@ def _load_weights_without_projection_quant_state(self, file_or_weights, strict=T nn.Module.load_weights = original_load_weights +def _read_json_file(path): + """Read a JSON object, returning an empty dict for missing/bad sidecars.""" + try: + with open(path, "r", encoding="utf-8") as file: + return json.load(file) + except Exception: + return {} + + +def _resolve_mlx_vlm_processor_class(model_type, processor_class_name): + """Resolve a custom mlx-vlm or Transformers processor class by name.""" + if not processor_class_name: + return None + + module_model_type = (model_type or "").replace("-", "_") + module_candidates = ( + f"mlx_vlm.models.{module_model_type}.processing", + f"mlx_vlm.models.{module_model_type}.processing_{module_model_type}", + ) + for module_name in module_candidates: + try: + module = importlib.import_module(module_name) + except Exception: + continue + processor_class = getattr(module, processor_class_name, None) + if processor_class is not None: + return processor_class + + try: + import transformers + return getattr(transformers, processor_class_name, None) + except Exception: + return None + + +def _build_vlm_image_processor_from_config(model_path, processor_config, preprocessor_config): + """Recreate the image processor from saved processor sidecar configs.""" + image_config = processor_config.get("image_processor") + if not isinstance(image_config, dict): + image_config = preprocessor_config + if not isinstance(image_config, dict): + image_config = {} + + image_processor_type = ( + image_config.get("image_processor_type") + or preprocessor_config.get("image_processor_type") + ) + image_kwargs = dict(image_config) + image_kwargs.pop("image_processor_type", None) + image_kwargs.pop("processor_class", None) + + if image_processor_type: + try: + import transformers + image_processor_class = getattr(transformers, image_processor_type, None) + if image_processor_class is not None: + return image_processor_class(**image_kwargs) + except Exception: + pass + + try: + from transformers import AutoImageProcessor + return AutoImageProcessor.from_pretrained(model_path) + except Exception: + return None + + +def _repair_degraded_vlm_processor( + processor, + model_path, + model_type, + *, + token=None, + trust_remote_code=False, +): + """Rebuild VLM processors when mlx-vlm falls back to tokenizer-only. + + mlx-vlm registers several custom processors through an AutoProcessor patch. + If the custom processor's image processor cannot be constructed through + AutoImageProcessor, the patch falls back to the prior tokenizer-only loader. + Rebuild from the source processor configs so downstream saves preserve real + multimodal processor metadata. + """ + if processor is None or getattr(processor, "image_processor", None) is not None: + return processor + + if not model_path or not os.path.isdir(str(model_path)): + return processor + + processor_config = _read_json_file( + os.path.join(str(model_path), "processor_config.json") + ) + preprocessor_config = _read_json_file( + os.path.join(str(model_path), "preprocessor_config.json") + ) + processor_class_name = ( + processor_config.get("processor_class") + or preprocessor_config.get("processor_class") + ) + processor_class = _resolve_mlx_vlm_processor_class( + model_type, processor_class_name, + ) + if processor_class is None: + return processor + + image_processor = _build_vlm_image_processor_from_config( + model_path, processor_config, preprocessor_config, + ) + if image_processor is None: + return processor + + tokenizer = getattr(processor, "tokenizer", None) or processor + if tokenizer is None or not hasattr(tokenizer, "save_pretrained"): + try: + from transformers import AutoTokenizer + tokenizer_kwargs = {"trust_remote_code": trust_remote_code} + if token: + tokenizer_kwargs["token"] = token + tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs) + except Exception: + return processor + + chat_template = getattr(processor, "chat_template", None) + if chat_template is not None and getattr(tokenizer, "chat_template", None) is None: + tokenizer.chat_template = chat_template + + try: + repaired = processor_class( + image_processor=image_processor, + tokenizer=tokenizer, + chat_template=chat_template, + ) + except TypeError: + try: + repaired = processor_class( + image_processor=image_processor, + tokenizer=tokenizer, + ) + except Exception: + return processor + except Exception: + return processor + + if chat_template is not None and getattr(repaired, "chat_template", None) is None: + repaired.chat_template = chat_template + return repaired + + def _build_vlm_model_types(): """Build the set of model_type strings that mlx_vlm supports. @@ -3304,6 +3452,14 @@ def from_pretrained( hf_token=token, ) + processor = _repair_degraded_vlm_processor( + processor, + local_path or model_name, + model_type, + token=token, + trust_remote_code=trust_remote_code, + ) + if target_dtype is not None: _convert_mlx_dtype(model, target_dtype, model_type=model_type) elif want_runtime_quant: From df5deb90f4b4322e47df7f61596acf986dd7abbb Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 22:07:25 +0800 Subject: [PATCH 12/16] fix(mlx): preserve source save sidecars --- unsloth_zoo/mlx/utils.py | 56 +++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 5dbc299f1..8a5f4f448 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -4053,6 +4053,53 @@ def _prepare_vlm_gguf_export_directory(path, model=None): return rewritten +_CORE_SAVE_FILENAMES = { + "config.json", + "model.safetensors.index.json", + "README.md", + ".gitattributes", +} +_MODEL_WEIGHT_SUFFIXES = ( + ".safetensors", + ".bin", + ".gguf", + ".h5", + ".msgpack", + ".onnx", + ".pt", + ".pth", +) +_MODEL_SIDECAR_SUFFIXES = (".json", ".jinja", ".model", ".txt", ".py") + + +def _copy_source_sidecars(src_path, path): + """Copy non-weight source sidecars that tokenizer/model saves may omit.""" + copied = 0 + src_path = Path(src_path) + path = Path(path) + if not src_path.exists(): + return copied + for source in src_path.iterdir(): + if not source.is_file(): + continue + name = source.name + if name in _CORE_SAVE_FILENAMES: + continue + if name.startswith("model-") or name.startswith("pytorch_model"): + continue + suffix = source.suffix + if suffix in _MODEL_WEIGHT_SUFFIXES: + continue + if suffix not in _MODEL_SIDECAR_SUFFIXES: + continue + target = path / name + if target.exists(): + continue + shutil.copy2(source, target) + copied += 1 + return copied + + def save_merged_model(model, tokenizer, path, dequantize=False): """Fuse LoRA weights and save the full merged model. @@ -4120,15 +4167,10 @@ def save_merged_model(model, tokenizer, path, dequantize=False): # Save tokenizer tokenizer.save_pretrained(str(path)) - # Copy auxiliary files (generation_config.json, *.py) from source + # Copy auxiliary source files that tokenizer/model saves may omit. src_path = _get_src_path(model) if src_path is not None: - src_path = Path(src_path) - if src_path.exists(): - import glob as globmod - for pattern in ["generation_config.json", "*.py"]: - for f in globmod.glob(str(src_path / pattern)): - shutil.copy(f, path) + _copy_source_sidecars(src_path, path) # Model card hf_repo = getattr(model, "_hf_repo", None) From e5305370ba94d7b5c62405a1afbc187857f7c972 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 26 May 2026 23:53:36 +0800 Subject: [PATCH 13/16] test(mlx): cover save export regressions --- tests/test_mlx_save_export_regressions.py | 767 ++++++++++++++++++++++ 1 file changed, 767 insertions(+) create mode 100644 tests/test_mlx_save_export_regressions.py diff --git a/tests/test_mlx_save_export_regressions.py b/tests/test_mlx_save_export_regressions.py new file mode 100644 index 000000000..efed95329 --- /dev/null +++ b/tests/test_mlx_save_export_regressions.py @@ -0,0 +1,767 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +"""Fast regressions for MLX save/export parity fixes. + +These tests cover the contracts behind the real save / GGUF export bugs without +downloading or converting large models. +""" + +from __future__ import annotations + +import dataclasses +import json +import os +import sys +import types +from pathlib import Path + +import pytest + + +@pytest.fixture(autouse=True, scope="module") +def _install_mlx_torch_shim(): + from mlx_simulation import simulate_mlx_on_torch + + simulate_mlx_on_torch() + + +def test_vlm_config_save_uses_vlm_helper_and_preserves_quantization_config( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + fake_vlm_utils = types.ModuleType("mlx_vlm.utils") + + def fake_save_config(config, path): + calls["config"] = config + calls["path"] = Path(path) + Path(path).write_text(json.dumps(config), encoding="utf-8") + + fake_vlm_utils.save_config = fake_save_config + monkeypatch.setitem(sys.modules, "mlx_vlm.utils", fake_vlm_utils) + + config = { + "model_type": "gemma3", + "vision_config": {"hidden_size": 8}, + "quantization": {"group_size": 64, "bits": 4}, + } + mutils._save_mlx_config(config, tmp_path / "config.json", is_vlm=True) + + assert calls["path"] == tmp_path / "config.json" + assert calls["config"]["quantization_config"] == config["quantization"] + assert "quantization_config" not in config + + +def test_merged_16bit_save_fully_dequantizes_model(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {"fuse": [], "dequantize": 0} + + class LoRALinear: + def fuse(self, dequantize=False): + calls["fuse"].append(dequantize) + return "fused-linear" + + class Model: + _config = { + "model_type": "llama", + "tie_word_embeddings": False, + "quantization": {"bits": 4}, + "nested": {"quantization_config": {"bits": 4}}, + } + + def eval(self): + calls["eval"] = True + + def named_modules(self): + return [("layers.0.self_attn.q_proj", LoRALinear())] + + def update_modules(self, modules): + calls["updated"] = modules + + class Tokenizer: + def save_pretrained(self, path): + Path(path).mkdir(parents=True, exist_ok=True) + calls["tokenizer_path"] = Path(path) + + fake_mlx_lm_utils = types.ModuleType("mlx_lm.utils") + + def fake_dequantize_model(model): + calls["dequantize"] += 1 + return model + + def fake_save_model(path, model, donate_model=False): + Path(path).mkdir(parents=True, exist_ok=True) + calls["donate_model"] = donate_model + + def fake_save_config(config, path): + calls["saved_config"] = config + Path(path).write_text(json.dumps(config), encoding="utf-8") + + fake_mlx_lm_utils.dequantize_model = fake_dequantize_model + fake_mlx_lm_utils.save_model = fake_save_model + fake_mlx_lm_utils.save_config = fake_save_config + fake_mlx_lm_utils.create_model_card = lambda path, hf_repo: None + monkeypatch.setitem(sys.modules, "mlx_lm.utils", fake_mlx_lm_utils) + + fake_mlx_utils = types.ModuleType("mlx.utils") + fake_mlx_utils.tree_unflatten = dict + monkeypatch.setitem(sys.modules, "mlx.utils", fake_mlx_utils) + + mutils.save_merged_model(Model(), Tokenizer(), tmp_path, dequantize=True) + + assert calls["eval"] is True + assert calls["fuse"] == [True] + assert calls["dequantize"] == 1 + assert calls["donate_model"] is False + assert "quantization" not in calls["saved_config"] + assert "quantization_config" not in calls["saved_config"]["nested"] + + +def test_materialize_tied_lm_head_updates_saved_index(monkeypatch, tmp_path): + import torch + import unsloth_zoo.mlx.utils as mutils + + shard = tmp_path / "model-00001-of-00001.safetensors" + shard.write_text("placeholder", encoding="utf-8") + index_path = tmp_path / "model.safetensors.index.json" + index_path.write_text( + json.dumps( + { + "metadata": {"total_size": 24, "total_parameters": 6}, + "weight_map": { + "model.embed_tokens.weight": shard.name, + }, + } + ), + encoding="utf-8", + ) + + saved = {} + embed = torch.ones(2, 3, dtype=torch.float32) + monkeypatch.setattr( + mutils.mx, + "load", + lambda path: {"model.embed_tokens.weight": embed}, + ) + monkeypatch.setattr(mutils.mx, "eval", lambda *values: None) + + def fake_save_safetensors(path, tensors, metadata=None): + saved["path"] = Path(path) + saved["tensors"] = tensors + Path(path).write_text("saved", encoding="utf-8") + + monkeypatch.setattr(mutils.mx, "save_safetensors", fake_save_safetensors) + + added = mutils._materialize_tied_lm_head_in_saved_model( + tmp_path, + {"tie_word_embeddings": True}, + ) + + assert added == 1 + assert "lm_head.weight" in saved["tensors"] + updated = json.loads(index_path.read_text(encoding="utf-8")) + assert updated["weight_map"]["lm_head.weight"] == shard.name + assert updated["metadata"]["total_size"] > 24 + assert updated["metadata"]["total_parameters"] == 12 + + +def test_bound_gguf_save_filters_cuda_only_kwargs(monkeypatch, tmp_path): + import unsloth_zoo.mlx.loader as loader + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + def fake_save_pretrained_gguf( + model, + tokenizer, + save_directory, + quantization_method="fast_quantized", + **kwargs, + ): + calls["tokenizer"] = tokenizer + calls["save_directory"] = Path(save_directory) + calls["quantization_method"] = quantization_method + calls["kwargs"] = kwargs + + monkeypatch.setattr(mutils, "save_pretrained_gguf", fake_save_pretrained_gguf) + tokenizer = object() + model = types.SimpleNamespace(_tokenizer=tokenizer) + + loader._mlx_save_pretrained_gguf( + model, + tmp_path, + quantization_method="not_quantized", + first_conversion="f16", + maximum_memory_usage=0.5, + temporary_location="/tmp/ignored", + ) + + assert calls == { + "tokenizer": tokenizer, + "save_directory": tmp_path, + "quantization_method": "not_quantized", + "kwargs": {"first_conversion": "f16"}, + } + + +def test_bound_gguf_push_filters_kwargs(monkeypatch): + import unsloth_zoo.mlx.loader as loader + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + def fake_push_to_hub_gguf( + model, + tokenizer, + save_directory, + repo_id, + quantization_method="fast_quantized", + **kwargs, + ): + calls["tokenizer"] = tokenizer + calls["save_directory"] = save_directory + calls["repo_id"] = repo_id + calls["quantization_method"] = quantization_method + calls["kwargs"] = kwargs + + monkeypatch.setattr(mutils, "push_to_hub_gguf", fake_push_to_hub_gguf) + tokenizer = object() + model = types.SimpleNamespace(_tokenizer=tokenizer) + + loader._mlx_push_to_hub_gguf( + model, + "org/model", + quantization_method="q8_0", + first_conversion="bf16", + token="hf_token", + private=True, + maximum_memory_usage=0.5, + temporary_location="/tmp/ignored", + ) + + assert calls == { + "tokenizer": tokenizer, + "save_directory": "org/model", + "repo_id": "org/model", + "quantization_method": "q8_0", + "kwargs": { + "first_conversion": "bf16", + "token": "hf_token", + "private": True, + }, + } + + +def test_lora_push_uses_lora_adapter_hub_path(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + class Model: + def named_modules(self): + return [("layers.0.q_proj", types.SimpleNamespace(fuse=lambda: None))] + + def trainable_parameters(self): + return {} + + class Tokenizer: + def save_pretrained(self, path): + calls["tokenizer_path"] = Path(path) + + def fake_save_lora_adapters(model, save_directory): + calls["adapter_dir"] = Path(save_directory) + + def fake_push_lora_adapters_to_hub( + save_directory, + **kwargs, + ): + calls["hub_dir"] = Path(save_directory) + calls["hub_kwargs"] = kwargs + + monkeypatch.setattr( + mutils, + "collect_mlx_lora_adapter_tensors", + lambda model: {"layers.0.q_proj.lora_a": object()}, + ) + monkeypatch.setattr(mutils, "iter_mlx_lora_modules", lambda model: []) + monkeypatch.setattr(mutils, "save_lora_adapters", fake_save_lora_adapters) + monkeypatch.setattr( + mutils, + "_push_lora_adapters_to_hub", + fake_push_lora_adapters_to_hub, + ) + monkeypatch.setattr( + mutils, + "push_to_hub_merged", + lambda *args, **kwargs: pytest.fail("push_to_hub_merged should not run"), + ) + + mutils.save_pretrained_merged( + Model(), + Tokenizer(), + tmp_path, + save_method="lora", + push_to_hub=True, + token="hf_token", + private=True, + ) + + assert calls["adapter_dir"] == tmp_path + assert calls["hub_dir"] == tmp_path + assert calls["hub_kwargs"]["repo_id"] is None + assert calls["hub_kwargs"]["token"] == "hf_token" + assert calls["hub_kwargs"]["private"] is True + + +def _patch_mlx_tensor_helpers_for_torch(monkeypatch, mutils): + import torch + + monkeypatch.setattr( + mutils.mx, + "transpose", + lambda tensor, axes=None, **kwargs: tensor.permute(*axes) + if axes is not None + else tensor.permute(*reversed(range(tensor.ndim))), + ) + monkeypatch.setattr(mutils.mx, "all", torch.all) + + +def test_vlm_rewrite_prefers_hf_alias_before_current_name(monkeypatch): + import torch + import unsloth_zoo.mlx.utils as mutils + + _patch_mlx_tensor_helpers_for_torch(monkeypatch, mutils) + + class QwenSanitizer: + @staticmethod + def sanitize(weights): + renamed = {} + for name, tensor in weights.items(): + if name.startswith("visual."): + name = f"vision_tower.{name[len('visual.'):]}" + renamed[name] = tensor + return renamed + + tensor = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) + new_name, new_tensor, changed = mutils._rewrite_mlx_vlm_tensor_for_gguf( + "vision_tower.patch_embed.proj.weight", + tensor, + [(QwenSanitizer, None)], + ) + + assert changed is True + assert new_name == "visual.patch_embed.proj.weight" + assert mutils._mlx_arrays_match(new_tensor, tensor) + + +def test_vlm_rewrite_handles_same_name_layout_transforms(monkeypatch): + import torch + import unsloth_zoo.mlx.utils as mutils + + _patch_mlx_tensor_helpers_for_torch(monkeypatch, mutils) + + class SameNameConvSanitizer: + @staticmethod + def sanitize(weights): + return { + name: mutils.mx.transpose(tensor, (0, 2, 3, 1)) + for name, tensor in weights.items() + } + + mlx_layout = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) + new_name, hf_layout, changed = mutils._rewrite_mlx_vlm_tensor_for_gguf( + "vision_tower.patch_embed.proj.weight", + mlx_layout, + [(SameNameConvSanitizer, None)], + ) + + assert changed is True + assert new_name == "vision_tower.patch_embed.proj.weight" + assert tuple(hf_layout.shape) == (2, 5, 3, 4) + assert mutils._mlx_arrays_match( + mutils.mx.transpose(hf_layout, (0, 2, 3, 1)), + mlx_layout, + ) + + +def test_vlm_sanitizer_replay_uses_real_model_instances(): + import unsloth_zoo.mlx.utils as mutils + + class VisionTower: + def sanitize(self, weights): + assert "vision_tower.proj.weight" in weights + return {"visual.proj.weight": weights["vision_tower.proj.weight"]} + + class Model: + def __init__(self): + self.vision_tower = VisionTower() + + def sanitize(self, weights): + return self.vision_tower.sanitize(weights) + + model = Model() + pipelines = mutils._get_mlx_vlm_model_sanitize_pipelines(model) + + assert pipelines[0][0][0] is model + assert mutils._apply_mlx_vlm_sanitizers( + pipelines[0], + {"vision_tower.proj.weight": "tensor"}, + ) == {"visual.proj.weight": "tensor"} + + +def test_repair_degraded_vlm_processor_rebuilds_from_sidecar_configs( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.loader as loader + + class FakeProcessor: + def __init__(self, image_processor, tokenizer, chat_template=None): + self.image_processor = image_processor + self.tokenizer = tokenizer + self.chat_template = chat_template + + fake_processing = types.ModuleType("mlx_vlm.models.glm_ocr.processing") + fake_processing.FakeProcessor = FakeProcessor + monkeypatch.setitem( + sys.modules, + "mlx_vlm.models.glm_ocr.processing", + fake_processing, + ) + + image_processor = object() + monkeypatch.setattr( + loader, + "_build_vlm_image_processor_from_config", + lambda model_path, processor_config, preprocessor_config: image_processor, + ) + + (tmp_path / "processor_config.json").write_text( + json.dumps({"processor_class": "FakeProcessor"}), + encoding="utf-8", + ) + (tmp_path / "preprocessor_config.json").write_text( + json.dumps({"image_processor_type": "FakeImageProcessor"}), + encoding="utf-8", + ) + + tokenizer = types.SimpleNamespace( + chat_template=None, + save_pretrained=lambda path: None, + ) + degraded = types.SimpleNamespace( + tokenizer=tokenizer, + chat_template="{{ messages }}", + ) + + repaired = loader._repair_degraded_vlm_processor( + degraded, + tmp_path, + "glm_ocr", + ) + + assert isinstance(repaired, FakeProcessor) + assert repaired.image_processor is image_processor + assert repaired.tokenizer is tokenizer + assert repaired.chat_template == "{{ messages }}" + assert tokenizer.chat_template == "{{ messages }}" + + +def test_get_model_config_extracts_dataclass_configs(): + import unsloth_zoo.mlx.utils as mutils + + @dataclasses.dataclass + class VisionConfig: + hidden_size: int + + @dataclasses.dataclass + class ModelConfig: + model_type: str + vision_config: VisionConfig + scales: tuple[int, int] + + model = types.SimpleNamespace( + config=ModelConfig( + model_type="glm_ocr", + vision_config=VisionConfig(hidden_size=16), + scales=(1, 2), + ) + ) + + assert mutils._get_model_config(model) == { + "model_type": "glm_ocr", + "vision_config": {"hidden_size": 16}, + "scales": [1, 2], + } + + +def test_get_model_config_prefers_copied_raw_config(): + import unsloth_zoo.mlx.utils as mutils + + raw_config = {"model_type": "qwen3", "nested": {"values": [1]}} + model = types.SimpleNamespace( + _config=raw_config, + config=types.SimpleNamespace(to_dict=lambda: {"model_type": "wrong"}), + ) + + extracted = mutils._get_model_config(model) + extracted["nested"]["values"].append(2) + + assert extracted["model_type"] == "qwen3" + assert raw_config["nested"]["values"] == [1] + + +def test_prepare_vlm_gguf_export_directory_writes_nextn_config_without_tensors( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.utils as mutils + + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "model_type": "glm_ocr", + "vision_config": {}, + "text_config": { + "num_hidden_layers": 16, + "num_nextn_predict_layers": 1, + "mtp_num_hidden_layers": 1, + "nextn_predict_layers": 1, + }, + } + ), + encoding="utf-8", + ) + monkeypatch.setattr(mutils, "_get_transformer_layers", lambda model: [object()] * 16) + monkeypatch.setattr( + mutils, + "_build_mlx_vlm_sanitize_pipelines", + lambda config, model=None: [], + ) + + rewritten = mutils._prepare_vlm_gguf_export_directory(tmp_path, model=object()) + + assert rewritten == 0 + updated = json.loads(config_path.read_text(encoding="utf-8")) + assert "num_nextn_predict_layers" not in updated["text_config"] + assert "mtp_num_hidden_layers" not in updated["text_config"] + assert "nextn_predict_layers" not in updated["text_config"] + + +def test_copy_source_sidecars_preserves_image_processor_metadata(tmp_path): + import unsloth_zoo.mlx.utils as mutils + + src = tmp_path / "src" + dst = tmp_path / "dst" + src.mkdir() + dst.mkdir() + + for name in ( + "preprocessor_config.json", + "processor_config.json", + "video_preprocessor_config.json", + "chat_template.jinja", + "tokenizer.model", + "vocab.txt", + "custom_processing.py", + "config.json", + "README.md", + ".gitattributes", + "model.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + ): + (src / name).write_text(name, encoding="utf-8") + (dst / "preprocessor_config.json").write_text("existing", encoding="utf-8") + + copied = mutils._copy_source_sidecars(src, dst) + + assert copied == 6 + assert (dst / "preprocessor_config.json").read_text(encoding="utf-8") == "existing" + for name in ( + "processor_config.json", + "video_preprocessor_config.json", + "chat_template.jinja", + "tokenizer.model", + "vocab.txt", + "custom_processing.py", + ): + assert (dst / name).read_text(encoding="utf-8") == name + for skipped in ( + "config.json", + "README.md", + ".gitattributes", + "model.safetensors", + "model-00001-of-00002.safetensors", + "pytorch_model.bin", + ): + assert not (dst / skipped).exists() + + +def test_save_pretrained_gguf_anchors_patcher_to_checked_llama_cpp_root( + monkeypatch, + tmp_path, +): + import unsloth_zoo.llama_cpp as llama_cpp + import unsloth_zoo.mlx.utils as mutils + + monkeypatch.setitem(sys.modules, "gguf", types.ModuleType("gguf")) + + llama_root = tmp_path / "llama.cpp" + llama_root.mkdir() + converter = llama_root / "convert_hf_to_gguf.py" + converter.write_text("# converter", encoding="utf-8") + quantizer = llama_root / "llama-quantize" + quantizer.write_text("# quantizer", encoding="utf-8") + + calls = {} + + def fake_save_merged_model(model, tokenizer, path, dequantize=False): + calls["dequantize"] = dequantize + Path(path).mkdir(parents=True, exist_ok=True) + + def fake_download_convert_hf_to_gguf(): + calls["scripts_dir"] = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") + patched = llama_root / "unsloth_convert_hf_to_gguf.py" + patched.write_text("# patched converter", encoding="utf-8") + return str(patched), {"Qwen3ForCausalLM"}, {"Gemma3ForConditionalGeneration"} + + def fake_convert_to_gguf(**kwargs): + calls["convert_kwargs"] = kwargs + output = Path( + f"{kwargs['model_name']}.{kwargs['quantization_type'].upper()}.gguf" + ) + output.write_bytes(b"GGUF") + + monkeypatch.setattr(mutils, "save_merged_model", fake_save_merged_model) + monkeypatch.setattr(mutils, "_is_vlm_model", lambda model: False) + monkeypatch.setattr(llama_cpp, "LLAMA_CPP_DEFAULT_DIR", str(tmp_path / "unused")) + monkeypatch.setattr( + llama_cpp, + "check_llama_cpp", + lambda llama_cpp_folder: (str(quantizer), str(converter)), + ) + monkeypatch.setattr( + llama_cpp, + "install_llama_cpp", + lambda llama_cpp_folder: pytest.fail("install_llama_cpp should not run"), + ) + monkeypatch.setattr( + llama_cpp, + "_download_convert_hf_to_gguf", + fake_download_convert_hf_to_gguf, + ) + monkeypatch.setattr(llama_cpp, "convert_to_gguf", fake_convert_to_gguf) + monkeypatch.setattr( + llama_cpp, + "quantize_gguf", + lambda **kwargs: pytest.fail("quantize_gguf should not run"), + ) + + old_scripts_dir = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") + model = types.SimpleNamespace(_hf_repo="org/TestModel") + out = tmp_path / "out" + mutils.save_pretrained_gguf( + model, + tokenizer=object(), + save_directory=out, + quantization_method="not_quantized", + first_conversion="f16", + ) + + assert calls["dequantize"] is True + assert calls["scripts_dir"] == str(llama_root) + assert calls["convert_kwargs"]["converter_location"] == str( + llama_root / "unsloth_convert_hf_to_gguf.py" + ) + assert calls["convert_kwargs"]["supported_text_archs"] == {"Qwen3ForCausalLM"} + assert (out / "TestModel.F16.gguf").read_bytes() == b"GGUF" + assert os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") == old_scripts_dir + + +def test_push_to_hub_gguf_forwards_first_conversion(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + class FakeHfApi: + def __init__(self, token=None): + calls["token"] = token + + def create_repo(self, repo_id, exist_ok=True, private=None): + calls["repo"] = { + "repo_id": repo_id, + "exist_ok": exist_ok, + "private": private, + } + + def update_repo_settings(self, **kwargs): + calls["update_repo_settings"] = kwargs + + def upload_file(self, path_or_fileobj, path_in_repo, repo_id): + calls["upload"] = { + "path_or_fileobj": Path(path_or_fileobj), + "path_in_repo": path_in_repo, + "repo_id": repo_id, + } + + fake_hub = types.ModuleType("huggingface_hub") + fake_hub.HfApi = FakeHfApi + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_hub) + + def fake_save_pretrained_gguf( + model, + tokenizer, + save_directory, + quantization_method="fast_quantized", + first_conversion=None, + ): + calls["save"] = { + "quantization_method": quantization_method, + "first_conversion": first_conversion, + } + Path(save_directory).mkdir(parents=True, exist_ok=True) + (Path(save_directory) / "model.F16.gguf").write_bytes(b"GGUF") + + monkeypatch.setattr(mutils, "save_pretrained_gguf", fake_save_pretrained_gguf) + + mutils.push_to_hub_gguf( + model=object(), + tokenizer=object(), + save_directory=tmp_path, + repo_id="org/model", + quantization_method="not_quantized", + first_conversion="f16", + token="hf_token", + private=True, + ) + + assert calls["save"] == { + "quantization_method": "not_quantized", + "first_conversion": "f16", + } + assert calls["token"] == "hf_token" + assert calls["repo"] == { + "repo_id": "org/model", + "exist_ok": True, + "private": True, + } + assert calls["upload"]["path_in_repo"] == "model.F16.gguf" From 0336fc3ec1bcc2ee6f27515630fdd56cd6de0691 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 27 May 2026 15:16:21 +0800 Subject: [PATCH 14/16] fix(mlx): harden save export review cases --- tests/test_mlx_save_export_regressions.py | 140 ++++++++++++++++++++++ unsloth_zoo/mlx/loader.py | 2 +- unsloth_zoo/mlx/utils.py | 57 ++++++--- 3 files changed, 181 insertions(+), 18 deletions(-) diff --git a/tests/test_mlx_save_export_regressions.py b/tests/test_mlx_save_export_regressions.py index efed95329..34f29d86e 100644 --- a/tests/test_mlx_save_export_regressions.py +++ b/tests/test_mlx_save_export_regressions.py @@ -64,6 +64,7 @@ def fake_save_config(config, path): mutils._save_mlx_config(config, tmp_path / "config.json", is_vlm=True) assert calls["path"] == tmp_path / "config.json" + assert calls["config"]["quantization"] == config["quantization"] assert calls["config"]["quantization_config"] == config["quantization"] assert "quantization_config" not in config @@ -401,6 +402,22 @@ def sanitize(weights): ) +def test_mlx_arrays_match_checks_2d_tensor_values(monkeypatch): + import torch + import unsloth_zoo.mlx.utils as mutils + + monkeypatch.setattr(mutils.mx, "all", torch.all) + + assert mutils._mlx_arrays_match( + torch.zeros(2, 3), + torch.zeros(2, 3), + ) + assert not mutils._mlx_arrays_match( + torch.zeros(2, 3), + torch.ones(2, 3), + ) + + def test_vlm_sanitizer_replay_uses_real_model_instances(): import unsloth_zoo.mlx.utils as mutils @@ -484,6 +501,30 @@ def __init__(self, image_processor, tokenizer, chat_template=None): assert tokenizer.chat_template == "{{ messages }}" +def test_read_json_file_returns_empty_for_missing_or_malformed_files(tmp_path): + import unsloth_zoo.mlx.loader as loader + + assert loader._read_json_file(tmp_path / "missing.json") == {} + + malformed = tmp_path / "malformed.json" + malformed.write_text("{not-json", encoding="utf-8") + + assert loader._read_json_file(malformed) == {} + + +def test_read_json_file_does_not_swallow_unexpected_errors(monkeypatch, tmp_path): + import builtins + import unsloth_zoo.mlx.loader as loader + + def fail_open(*args, **kwargs): + raise RuntimeError("unexpected") + + monkeypatch.setattr(builtins, "open", fail_open) + + with pytest.raises(RuntimeError, match="unexpected"): + loader._read_json_file(tmp_path / "config.json") + + def test_get_model_config_extracts_dataclass_configs(): import unsloth_zoo.mlx.utils as mutils @@ -528,6 +569,74 @@ def test_get_model_config_prefers_copied_raw_config(): assert raw_config["nested"]["values"] == [1] +def test_has_tied_word_embeddings_ignores_malformed_thinker_config(): + import unsloth_zoo.mlx.utils as mutils + + assert not mutils._has_tied_word_embeddings({"thinker_config": "bad"}) + assert mutils._has_tied_word_embeddings( + {"thinker_config": {"text_config": {"tie_word_embeddings": True}}} + ) + + +def test_has_vision_config_handles_nested_and_malformed_configs(): + import unsloth_zoo.mlx.utils as mutils + + assert not mutils._has_vision_config(None) + assert not mutils._has_vision_config({"thinker_config": "bad"}) + assert mutils._has_vision_config({"vision_config": {}}) + assert mutils._has_vision_config({"thinker_config": {"vision_config": {}}}) + + +def test_save_merged_model_detects_nested_vlm_config(monkeypatch, tmp_path): + import unsloth_zoo.mlx.utils as mutils + + calls = {} + + class Model: + _config = { + "model_type": "glm_ocr", + "thinker_config": {"vision_config": {"hidden_size": 8}}, + } + + def eval(self): + pass + + def named_modules(self): + return [] + + class Tokenizer: + def save_pretrained(self, path): + Path(path).mkdir(parents=True, exist_ok=True) + + fake_mlx_lm_utils = types.ModuleType("mlx_lm.utils") + fake_mlx_lm_utils.dequantize_model = lambda model: model + fake_mlx_lm_utils.save_model = lambda path, model, donate_model=False: Path( + path + ).mkdir(parents=True, exist_ok=True) + fake_mlx_lm_utils.create_model_card = lambda path, hf_repo: None + fake_mlx_lm_utils.save_config = lambda config, path: pytest.fail( + "text save_config should not run" + ) + monkeypatch.setitem(sys.modules, "mlx_lm.utils", fake_mlx_lm_utils) + + fake_mlx_utils = types.ModuleType("mlx.utils") + fake_mlx_utils.tree_unflatten = dict + monkeypatch.setitem(sys.modules, "mlx.utils", fake_mlx_utils) + + monkeypatch.setattr(mutils, "_is_vlm_model", lambda model: False) + + def fake_save_mlx_config(config, config_path, *, is_vlm=False): + calls["is_vlm"] = is_vlm + calls["config"] = config + + monkeypatch.setattr(mutils, "_save_mlx_config", fake_save_mlx_config) + + mutils.save_merged_model(Model(), Tokenizer(), tmp_path) + + assert calls["is_vlm"] is True + assert calls["config"]["thinker_config"]["vision_config"]["hidden_size"] == 8 + + def test_prepare_vlm_gguf_export_directory_writes_nextn_config_without_tensors( monkeypatch, tmp_path, @@ -566,6 +675,37 @@ def test_prepare_vlm_gguf_export_directory_writes_nextn_config_without_tensors( assert "nextn_predict_layers" not in updated["text_config"] +def test_prepare_vlm_gguf_export_directory_ignores_malformed_thinker_config( + monkeypatch, + tmp_path, +): + import unsloth_zoo.mlx.utils as mutils + + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "model_type": "glm_ocr", + "thinker_config": "bad", + "text_config": {"num_hidden_layers": 16}, + } + ), + encoding="utf-8", + ) + monkeypatch.setattr( + mutils, + "_get_transformer_layers", + lambda model: [object()] * 16, + ) + monkeypatch.setattr( + mutils, + "_build_mlx_vlm_sanitize_pipelines", + lambda config, model=None: [], + ) + + assert mutils._prepare_vlm_gguf_export_directory(tmp_path, model=object()) == 0 + + def test_copy_source_sidecars_preserves_image_processor_metadata(tmp_path): import unsloth_zoo.mlx.utils as mutils diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index f37cdb56a..823c0aa3a 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -365,7 +365,7 @@ def _read_json_file(path): try: with open(path, "r", encoding="utf-8") as file: return json.load(file) - except Exception: + except (FileNotFoundError, json.JSONDecodeError, OSError, UnicodeDecodeError): return {} diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 8a5f4f448..e99329c41 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -34,12 +34,16 @@ import sys import shutil import tempfile +import threading from pathlib import Path from .cce import _get_runtime_cce +_LLAMA_CPP_PATCHER_ENV_LOCK = threading.Lock() + + def _safe_token_denominator(ntoks): return mx.maximum(ntoks.astype(mx.float32), mx.array(1.0, dtype=mx.float32)) @@ -3493,11 +3497,16 @@ def _has_tied_word_embeddings(config): """Return whether any text config declares tied input/output embeddings.""" if not isinstance(config, dict): return False + thinker_config = config.get("thinker_config") candidates = [ config, config.get("text_config"), config.get("language_config"), - (config.get("thinker_config") or {}).get("text_config"), + ( + thinker_config.get("text_config") + if isinstance(thinker_config, dict) + else None + ), ] return any( isinstance(item, dict) and item.get("tie_word_embeddings") is True @@ -3688,8 +3697,16 @@ def _materialize_tied_lm_head_in_saved_model( def _has_vision_config(config): """Return whether a raw or thinker-wrapped VLM config has vision settings.""" - thinker_config = config.get("thinker_config") or {} - return "vision_config" in config or "vision_config" in thinker_config + if not isinstance(config, dict): + return False + thinker_config = config.get("thinker_config") + return ( + "vision_config" in config + or ( + isinstance(thinker_config, dict) + and "vision_config" in thinker_config + ) + ) class _MlxVlmSanitizeProxy: @@ -3890,10 +3907,10 @@ def _mlx_arrays_match(actual, expected): return True if shape is None: return actual == expected - if len(shape) not in (1, 4, 5): - return True try: - return bool(mx.all(actual == expected).item()) + result = mx.all(actual == expected) + item = getattr(result, "item", None) + return bool(item() if callable(item) else result) except Exception: return False @@ -3952,10 +3969,15 @@ def _sync_gguf_nextn_layer_config(config, model): except Exception: return False + thinker_config = config.get("thinker_config") text_configs = [ config.get("text_config"), config.get("language_config"), - (config.get("thinker_config") or {}).get("text_config"), + ( + thinker_config.get("text_config") + if isinstance(thinker_config, dict) + else None + ), ] changed = False for text_config in text_configs: @@ -4146,7 +4168,7 @@ def save_merged_model(model, tokenizer, path, dequantize=False): # Save config.json config = _get_model_config(model) if config: - is_vlm = _is_vlm_model(model) or "vision_config" in config + is_vlm = _is_vlm_model(model) or _has_vision_config(config) materialized = _materialize_tied_lm_head_in_saved_model( path, config, @@ -4673,16 +4695,17 @@ def save_pretrained_gguf( converter = os.path.join(llama_cpp_folder, "unsloth_convert_hf_to_gguf.py") supported_text_archs = None supported_vision_archs = None - old_scripts_dir = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") - if old_scripts_dir is None: - os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = llama_cpp_folder - try: - result = _download_convert_hf_to_gguf() - finally: + with _LLAMA_CPP_PATCHER_ENV_LOCK: + old_scripts_dir = os.environ.get("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR") if old_scripts_dir is None: - os.environ.pop("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR", None) - else: - os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = old_scripts_dir + os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = llama_cpp_folder + try: + result = _download_convert_hf_to_gguf() + finally: + if old_scripts_dir is None: + os.environ.pop("UNSLOTH_LLAMA_CPP_SCRIPTS_DIR", None) + else: + os.environ["UNSLOTH_LLAMA_CPP_SCRIPTS_DIR"] = old_scripts_dir if isinstance(result, tuple) and len(result) >= 3: converter, supported_text_archs, supported_vision_archs = result[:3] elif isinstance(result, str): From 16ca0d181e7751e8aa1e399f1aeb039520472e32 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 27 May 2026 16:48:51 +0800 Subject: [PATCH 15/16] fix(mlx): address save export review cases --- tests/test_mlx_save_export_regressions.py | 38 +++++++++++++++++++++++ unsloth_zoo/mlx/utils.py | 23 +++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/tests/test_mlx_save_export_regressions.py b/tests/test_mlx_save_export_regressions.py index 34f29d86e..f50ac6103 100644 --- a/tests/test_mlx_save_export_regressions.py +++ b/tests/test_mlx_save_export_regressions.py @@ -402,6 +402,32 @@ def sanitize(weights): ) +def test_vlm_rewrite_skips_untransformable_text_tensors(): + import torch + import unsloth_zoo.mlx.utils as mutils + + calls = 0 + + class CountingSanitizer: + @staticmethod + def sanitize(weights): + nonlocal calls + calls += 1 + return weights + + tensor = torch.zeros(2, 3) + new_name, new_tensor, changed = mutils._rewrite_mlx_vlm_tensor_for_gguf( + "language_model.model.layers.0.mlp.gate_proj.weight", + tensor, + [(CountingSanitizer, None)], + ) + + assert calls == 0 + assert changed is False + assert new_name == "language_model.model.layers.0.mlp.gate_proj.weight" + assert new_tensor is tensor + + def test_mlx_arrays_match_checks_2d_tensor_values(monkeypatch): import torch import unsloth_zoo.mlx.utils as mutils @@ -756,6 +782,18 @@ def test_copy_source_sidecars_preserves_image_processor_metadata(tmp_path): assert not (dst / skipped).exists() +def test_copy_source_sidecars_ignores_non_directory_source(tmp_path): + import unsloth_zoo.mlx.utils as mutils + + src = tmp_path / "model.safetensors" + dst = tmp_path / "dst" + src.write_text("weights", encoding="utf-8") + dst.mkdir() + + assert mutils._copy_source_sidecars(src, dst) == 0 + assert list(dst.iterdir()) == [] + + def test_save_pretrained_gguf_anchors_patcher_to_checked_llama_cpp_root( monkeypatch, tmp_path, diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index e99329c41..64e1fb0a4 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3898,6 +3898,24 @@ def _vlm_gguf_tensor_candidates(tensor): return candidates +def _has_vlm_gguf_tensor_candidate(tensor): + """Return whether a tensor shape can require HF-layout recovery.""" + shape = getattr(tensor, "shape", ()) + if len(shape) in (4, 5): + return True + if len(shape) == 1: + dtype = getattr(tensor, "dtype", None) + return dtype is not None and mx.issubdtype(dtype, mx.floating) + return False + + +def _has_vlm_gguf_rewrite_candidate(name, tensor): + """Return whether a tensor can differ between mlx-vlm and GGUF layouts.""" + if any(candidate_name != name for candidate_name in _vlm_gguf_name_candidates(name)): + return True + return _has_vlm_gguf_tensor_candidate(tensor) + + def _mlx_arrays_match(actual, expected): """Compare MLX-like arrays without assuming a concrete backend type.""" shape = getattr(actual, "shape", None) @@ -3931,6 +3949,9 @@ def _normalize_mlx_vlm_sanitize_pipelines(sanitize_steps): def _rewrite_mlx_vlm_tensor_for_gguf(name, tensor, sanitize_steps): """Invert mlx-vlm sanitizers to recover HF tensor names/layouts for GGUF.""" + if not _has_vlm_gguf_rewrite_candidate(name, tensor): + return name, tensor, False + for candidate_name in _vlm_gguf_name_candidates(name): for candidate_tensor in _vlm_gguf_tensor_candidates(tensor): for pipeline in _normalize_mlx_vlm_sanitize_pipelines(sanitize_steps): @@ -4099,7 +4120,7 @@ def _copy_source_sidecars(src_path, path): copied = 0 src_path = Path(src_path) path = Path(path) - if not src_path.exists(): + if not src_path.is_dir(): return copied for source in src_path.iterdir(): if not source.is_file(): From f73f432180b0fe1b9c9ea6949e49f5f6e3a6e275 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 27 May 2026 19:27:31 +0800 Subject: [PATCH 16/16] revert(mlx): drop tied lm head materialization This reverts commit d8cdf89e90c5122a38ab66766ee3b2f11877160f. Also removes the related regression coverage and later helper hardening now that bug-4 is out of scope. --- tests/test_mlx_save_export_regressions.py | 57 ------ unsloth_zoo/mlx/utils.py | 214 ---------------------- 2 files changed, 271 deletions(-) diff --git a/tests/test_mlx_save_export_regressions.py b/tests/test_mlx_save_export_regressions.py index f50ac6103..3051962e7 100644 --- a/tests/test_mlx_save_export_regressions.py +++ b/tests/test_mlx_save_export_regressions.py @@ -135,54 +135,6 @@ def fake_save_config(config, path): assert "quantization_config" not in calls["saved_config"]["nested"] -def test_materialize_tied_lm_head_updates_saved_index(monkeypatch, tmp_path): - import torch - import unsloth_zoo.mlx.utils as mutils - - shard = tmp_path / "model-00001-of-00001.safetensors" - shard.write_text("placeholder", encoding="utf-8") - index_path = tmp_path / "model.safetensors.index.json" - index_path.write_text( - json.dumps( - { - "metadata": {"total_size": 24, "total_parameters": 6}, - "weight_map": { - "model.embed_tokens.weight": shard.name, - }, - } - ), - encoding="utf-8", - ) - - saved = {} - embed = torch.ones(2, 3, dtype=torch.float32) - monkeypatch.setattr( - mutils.mx, - "load", - lambda path: {"model.embed_tokens.weight": embed}, - ) - monkeypatch.setattr(mutils.mx, "eval", lambda *values: None) - - def fake_save_safetensors(path, tensors, metadata=None): - saved["path"] = Path(path) - saved["tensors"] = tensors - Path(path).write_text("saved", encoding="utf-8") - - monkeypatch.setattr(mutils.mx, "save_safetensors", fake_save_safetensors) - - added = mutils._materialize_tied_lm_head_in_saved_model( - tmp_path, - {"tie_word_embeddings": True}, - ) - - assert added == 1 - assert "lm_head.weight" in saved["tensors"] - updated = json.loads(index_path.read_text(encoding="utf-8")) - assert updated["weight_map"]["lm_head.weight"] == shard.name - assert updated["metadata"]["total_size"] > 24 - assert updated["metadata"]["total_parameters"] == 12 - - def test_bound_gguf_save_filters_cuda_only_kwargs(monkeypatch, tmp_path): import unsloth_zoo.mlx.loader as loader import unsloth_zoo.mlx.utils as mutils @@ -595,15 +547,6 @@ def test_get_model_config_prefers_copied_raw_config(): assert raw_config["nested"]["values"] == [1] -def test_has_tied_word_embeddings_ignores_malformed_thinker_config(): - import unsloth_zoo.mlx.utils as mutils - - assert not mutils._has_tied_word_embeddings({"thinker_config": "bad"}) - assert mutils._has_tied_word_embeddings( - {"thinker_config": {"text_config": {"tie_word_embeddings": True}}} - ) - - def test_has_vision_config_handles_nested_and_malformed_configs(): import unsloth_zoo.mlx.utils as mutils diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 64e1fb0a4..922e2b918 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -3493,208 +3493,6 @@ def _save_mlx_config(config, config_path, *, is_vlm=False): save_lm_config(config, config_path) -def _has_tied_word_embeddings(config): - """Return whether any text config declares tied input/output embeddings.""" - if not isinstance(config, dict): - return False - thinker_config = config.get("thinker_config") - candidates = [ - config, - config.get("text_config"), - config.get("language_config"), - ( - thinker_config.get("text_config") - if isinstance(thinker_config, dict) - else None - ), - ] - return any( - isinstance(item, dict) and item.get("tie_word_embeddings") is True - for item in candidates - ) - - -def _lm_head_key_for_embed_key(embed_key): - """Map an input embedding tensor key to the matching tied LM head key.""" - if embed_key == "embed_tokens.weight": - return "lm_head.weight" - if embed_key == "model.embed_tokens.weight": - return "lm_head.weight" - - suffix = ".model.embed_tokens.weight" - if embed_key.endswith(suffix): - prefix = embed_key[:-len(suffix)] - return f"{prefix}.lm_head.weight" if prefix else "lm_head.weight" - - suffix = ".embed_tokens.weight" - if embed_key.endswith(suffix): - prefix = embed_key[:-len(suffix)] - if not prefix or prefix == "model": - return "lm_head.weight" - return f"{prefix}.lm_head.weight" - - return None - - -def _safetensor_names(path): - """Read tensor names from a safetensors directory or index file.""" - path = Path(path) - index_path = path / "model.safetensors.index.json" - if index_path.exists(): - with open(index_path, "r") as f: - return set(json.load(f).get("weight_map", {})) - - try: - from safetensors import safe_open - except Exception: - return set() - - names = set() - for file in sorted(path.glob("*.safetensors")): - with safe_open(str(file), framework="np") as f: - names.update(f.keys()) - return names - - -def _source_has_lm_head_tensor(source_path, lm_head_key): - """Check whether the source checkpoint explicitly stored an LM head.""" - if source_path is None: - return None - source_path = Path(source_path) - if not source_path.exists(): - return None - - names = _safetensor_names(source_path) - if not names: - return None - if lm_head_key in names: - return True - if lm_head_key.endswith(".lm_head.weight") and "lm_head.weight" in names: - return True - return False - - -def _tensor_nbytes(tensor): - """Return tensor byte size across MLX, NumPy, and torch-like objects.""" - value = getattr(tensor, "nbytes", None) - if callable(value): - value = value() - if value is not None: - return int(value) - itemsize = getattr(tensor, "itemsize", None) - if callable(itemsize): - itemsize = itemsize() - if itemsize is None: - element_size = getattr(tensor, "element_size", None) - itemsize = element_size() if callable(element_size) else 0 - return int(_tensor_size(tensor) * itemsize) - - -def _tensor_size(tensor): - """Return tensor element count across MLX, NumPy, and torch-like objects.""" - value = getattr(tensor, "size", None) - if callable(value): - numel = getattr(tensor, "numel", None) - if callable(numel): - return int(numel()) - shape = getattr(tensor, "shape", ()) - total = 1 - for dim in shape: - total *= int(dim) - return total - if value is not None: - return int(value) - return 0 - - -def _duplicate_tensor_for_safetensors(tensor): - """Clone tensors when available before writing a tied duplicate key.""" - clone = getattr(tensor, "clone", None) - if callable(clone): - return clone() - return tensor - - -def _materialize_tied_lm_head_in_saved_model( - path, - config, - *, - source_path=None, - is_vlm=False, -): - """Duplicate tied input embeddings into the saved LM head when CUDA does.""" - if not _has_tied_word_embeddings(config): - return 0 - - path = Path(path) - index_path = path / "model.safetensors.index.json" - if not index_path.exists(): - return 0 - - with open(index_path, "r") as f: - index_data = json.load(f) - - weight_map = dict(index_data.get("weight_map", {})) - additions = [] - for embed_key in sorted(weight_map): - if not embed_key.endswith("embed_tokens.weight"): - continue - lm_head_key = _lm_head_key_for_embed_key(embed_key) - if lm_head_key is None or lm_head_key in weight_map: - continue - - source_has_lm_head = _source_has_lm_head_tensor(source_path, lm_head_key) - if source_has_lm_head is False: - continue - if source_has_lm_head is None and is_vlm: - continue - if source_has_lm_head is None and lm_head_key != "lm_head.weight": - continue - - additions.append((embed_key, lm_head_key)) - - added = 0 - added_bytes = 0 - added_parameters = 0 - for embed_key, lm_head_key in additions: - shard_name = weight_map[embed_key] - shard_path = path / shard_name - tensors = mx.load(str(shard_path)) - if lm_head_key in tensors: - weight_map[lm_head_key] = shard_name - continue - if embed_key not in tensors: - continue - - tensor = tensors[embed_key] - tensors[lm_head_key] = _duplicate_tensor_for_safetensors(tensor) - mx.eval(*tensors.values()) - tmp_file = shard_path.with_name(f"{shard_path.stem}.tmp{shard_path.suffix}") - mx.save_safetensors(str(tmp_file), tensors, metadata={"format": "mlx"}) - os.replace(tmp_file, shard_path) - - weight_map[lm_head_key] = shard_name - added += 1 - added_bytes += _tensor_nbytes(tensor) - added_parameters += _tensor_size(tensor) - - if added: - metadata = index_data.setdefault("metadata", {}) - if "total_size" in metadata: - metadata["total_size"] = int(metadata["total_size"]) + added_bytes - if "total_parameters" in metadata: - metadata["total_parameters"] = ( - int(metadata["total_parameters"]) + added_parameters - ) - index_data["weight_map"] = { - key: weight_map[key] for key in sorted(weight_map) - } - with open(index_path, "w") as f: - json.dump(index_data, f, indent=4) - - return added - - def _has_vision_config(config): """Return whether a raw or thinker-wrapped VLM config has vision settings.""" if not isinstance(config, dict): @@ -4142,7 +3940,6 @@ def _copy_source_sidecars(src_path, path): copied += 1 return copied - def save_merged_model(model, tokenizer, path, dequantize=False): """Fuse LoRA weights and save the full merged model. @@ -4190,17 +3987,6 @@ def save_merged_model(model, tokenizer, path, dequantize=False): config = _get_model_config(model) if config: is_vlm = _is_vlm_model(model) or _has_vision_config(config) - materialized = _materialize_tied_lm_head_in_saved_model( - path, - config, - source_path=_get_src_path(model), - is_vlm=is_vlm, - ) - if materialized: - print( - "Unsloth: Materialized " - f"{materialized} tied lm_head tensor(s) for CUDA export parity." - ) _save_mlx_config( config, path / "config.json",