diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py
index 70bf8c83f..5ff19e2bb 100644
--- a/mlx_lm/utils.py
+++ b/mlx_lm/utils.py
@@ -412,6 +412,22 @@ def _maybe_qq(m):
         model.update_modules(leaves)
 
     model.eval()
+    # Keep only checkpoint entries that match a parameter of the instantiated
+    # model; anything else is reported once and dropped before load_weights.
+    expected_keys = {key for key, _ in tree_flatten(model.parameters())}
+    unknown_keys = [key for key in weights if key not in expected_keys]
+    if unknown_keys:
+        # Log a short, sorted sample (at most three keys) plus the total count.
+        sample_keys = sorted(unknown_keys)[:3]
+        sample = ", ".join(sample_keys)
+        if len(unknown_keys) > 3:
+            sample = f"{sample}, ..."
+        print(
+            "[INFO] Dropping weights not present in the instantiated model: "
+            f"count={len(unknown_keys)} sample=[{sample}]"
+        )
+        weights = {key: value for key, value in weights.items() if key in expected_keys}
+
     model.load_weights(list(weights.items()), strict=strict)
 
     if not lazy:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3434b4ac7..d55ad346b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,8 +1,10 @@
 # Copyright © 2024 Apple Inc.
 
+import json
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -13,6 +15,42 @@
 HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
 
 
+class TinyArgs:
+    """Minimal model-args stand-in carrying only vocab and hidden sizes."""
+
+    def __init__(self, vocab_size, hidden_size):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+
+    @classmethod
+    def from_dict(cls, config):
+        """Build the args from the two size entries of a config dict."""
+        return cls(config["vocab_size"], config["hidden_size"])
+
+
+class TinyModel(nn.Module):
+    """Toy module (embedding + bias-free linear) used as a load target."""
+
+    def __init__(self, args):
+        super().__init__()
+        self.embed = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.proj = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def sanitize(self, weights):
+        """Strip a leading ``language_model.`` from every key that has one."""
+        clean = {}
+        for key, value in weights.items():
+            if key.startswith("language_model."):
+                key = key[len("language_model.") :]
+            clean[key] = value
+        return clean
+
+
+def get_tiny_classes(config):
+    """Return the toy model/args classes regardless of ``config``."""
+    return TinyModel, TinyArgs
+
+
 class TestUtils(unittest.TestCase):
 
     @classmethod
@@ -123,6 +161,103 @@ def custom_get_classes(config):
         self.assertEqual(model.custom_attribute, "This is a custom model")
         self.assertTrue(hasattr(model, "qwenWeights"))
 
+    def _tiny_weights(self):
+        """Return flattened parameters of a fresh TinyModel (vocab 16, hidden 32)."""
+        model = TinyModel(TinyArgs(16, 32))
+        mx.eval(model.parameters())
+        return dict(tree_flatten(model.parameters()))
+
+    def _quantized_tiny_weights(self):
+        """Return tiny weights with ``proj.weight`` quantized (4 bits, group 32)."""
+        weights = self._tiny_weights()
+        q_weight, scales, biases = mx.quantize(
+            weights["proj.weight"],
+            bits=4,
+            group_size=32,
+        )
+        return {
+            "embed.weight": weights["embed.weight"],
+            "proj.weight": q_weight,
+            "proj.scales": scales,
+            "proj.biases": biases,
+        }
+
+    def _write_model_dir(self, weights, extra_config=None):
+        """Write config.json and model.safetensors to a new dir under test_dir."""
+        model_dir = Path(tempfile.mkdtemp(dir=self.test_dir))
+        config = {
+            "model_type": "tiny",
+            "vocab_size": 16,
+            "hidden_size": 32,
+        }
+        if extra_config is not None:
+            config.update(extra_config)
+        with open(model_dir / "config.json", "w") as fid:
+            json.dump(config, fid)
+        mx.save_safetensors(str(model_dir / "model.safetensors"), weights)
+        return model_dir
+
+    def _load_tiny_model(self, model_dir):
+        """Load ``model_dir`` via utils.load_model with the toy classes."""
+        return utils.load_model(
+            model_dir,
+            get_model_classes=get_tiny_classes,
+        )
+
+    def test_load_model_drops_unknown_weights(self):
+        # Extra tower tensors have no matching parameter in TinyModel.
+        base = self._tiny_weights()
+        weights = dict(base)
+        weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32)
+        weights["audio_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32)
+        model_dir = self._write_model_dir(weights)
+
+        model, _ = self._load_tiny_model(model_dir)
+
+        loaded = dict(tree_flatten(model.parameters()))
+        self.assertTrue(mx.allclose(loaded["embed.weight"], base["embed.weight"]))
+        self.assertTrue(mx.allclose(loaded["proj.weight"], base["proj.weight"]))
+
+    def test_load_model_drops_unknown_weights_after_sanitize(self):
+        # Keys only line up with TinyModel once sanitize() strips the prefix.
+        base = self._tiny_weights()
+        weights = {f"language_model.{key}": value for key, value in base.items()}
+        weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32)
+        model_dir = self._write_model_dir(weights)
+
+        model, _ = self._load_tiny_model(model_dir)
+
+        loaded = dict(tree_flatten(model.parameters()))
+        self.assertTrue(mx.allclose(loaded["embed.weight"], base["embed.weight"]))
+        self.assertTrue(mx.allclose(loaded["proj.weight"], base["proj.weight"]))
+
+    def test_load_model_still_fails_for_missing_supported_weights(self):
+        # Dropping extras must not hide a parameter the model actually needs.
+        weights = self._tiny_weights()
+        weights.pop("proj.weight")
+        model_dir = self._write_model_dir(weights)
+
+        with self.assertRaises(ValueError):
+            self._load_tiny_model(model_dir)
+
+    def test_load_model_keeps_supported_quantized_weights(self):
+        # Quantized proj tensors must survive while the unknown key is dropped.
+        weights = self._quantized_tiny_weights()
+        weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32)
+        model_dir = self._write_model_dir(
+            weights,
+            extra_config={"quantization": {"bits": 4, "group_size": 32}},
+        )
+
+        model, _ = self._load_tiny_model(model_dir)
+
+        loaded = dict(tree_flatten(model.parameters()))
+        self.assertIn("proj.weight", loaded)
+        self.assertIn("proj.scales", loaded)
+        self.assertIn("proj.biases", loaded)
+        self.assertTrue(mx.allclose(loaded["proj.scales"], weights["proj.scales"]))
+        self.assertTrue(mx.allclose(loaded["proj.biases"], weights["proj.biases"]))
+
 
 if __name__ == "__main__":
     unittest.main()