Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,19 @@ def _maybe_qq(m):
model.update_modules(leaves)

model.eval()
expected_keys = {key for key, _ in tree_flatten(model.parameters())}
unknown_keys = [key for key in weights if key not in expected_keys]
if unknown_keys:
sample_keys = sorted(unknown_keys)[:3]
sample = ", ".join(sample_keys)
if len(unknown_keys) > 3:
sample = f"{sample}, ..."
print(
"[INFO] Dropping weights not present in the instantiated model: "
f"count={len(unknown_keys)} sample=[{sample}]"
)
weights = {key: value for key, value in weights.items() if key in expected_keys}

model.load_weights(list(weights.items()), strict=strict)

if not lazy:
Expand Down
120 changes: 120 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright © 2024 Apple Inc.

import json
import os
import tempfile
import unittest
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
Expand All @@ -13,6 +15,35 @@
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TinyArgs:
    """Minimal stand-in for a model-args/config object.

    Carries only the two dimensions the tiny test model needs, and mirrors
    the ``from_dict`` alternate constructor that real args classes expose.
    """

    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    @classmethod
    def from_dict(cls, config):
        """Build args from a plain config dict; extra keys are ignored."""
        vocab = config["vocab_size"]
        hidden = config["hidden_size"]
        return cls(vocab, hidden)


class TinyModel(nn.Module):
    """Tiny embedding + projection model used to exercise weight loading."""

    def __init__(self, args):
        super().__init__()
        self.embed = nn.Embedding(args.vocab_size, args.hidden_size)
        self.proj = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def sanitize(self, weights):
        """Strip a leading ``language_model.`` prefix from checkpoint keys."""
        prefix = "language_model."
        start = len(prefix)
        return {
            (key[start:] if key.startswith(prefix) else key): value
            for key, value in weights.items()
        }


def get_tiny_classes(config):
    """Class resolver hook: always return the tiny (model, args) pair.

    The ``config`` argument is accepted (and ignored) to match the
    signature that ``load_model`` expects of ``get_model_classes``.
    """
    classes = (TinyModel, TinyArgs)
    return classes


class TestUtils(unittest.TestCase):

@classmethod
Expand Down Expand Up @@ -123,6 +154,95 @@ def custom_get_classes(config):
self.assertEqual(model.custom_attribute, "This is a custom model")
self.assertTrue(hasattr(model, "qwenWeights"))

def _tiny_weights(self):
    """Materialize a fresh TinyModel's parameters as a flat name->array dict."""
    tiny = TinyModel(TinyArgs(16, 32))
    params = tiny.parameters()
    # Force evaluation so the arrays hold concrete values before saving.
    mx.eval(params)
    return dict(tree_flatten(params))

def _quantized_tiny_weights(self):
    """Like ``_tiny_weights`` but with the projection quantized to 4 bits.

    ``group_size=32`` matches the model's hidden size so the weight
    quantizes cleanly into whole groups.
    """
    base = self._tiny_weights()
    q_weight, scales, biases = mx.quantize(
        base["proj.weight"], bits=4, group_size=32
    )
    quantized = {"embed.weight": base["embed.weight"]}
    quantized["proj.weight"] = q_weight
    quantized["proj.scales"] = scales
    quantized["proj.biases"] = biases
    return quantized

def _write_model_dir(self, weights, extra_config=None):
    """Write a loadable model directory (config.json + safetensors).

    Returns the path of a fresh temp directory under ``self.test_dir``
    containing the given ``weights`` and a tiny-model config, optionally
    merged with ``extra_config`` entries (e.g. a quantization stanza).
    """
    model_dir = Path(tempfile.mkdtemp(dir=self.test_dir))
    config = {
        "model_type": "tiny",
        "vocab_size": 16,
        "hidden_size": 32,
    }
    # None means "no overrides"; an empty dict is a harmless no-op either way.
    config.update(extra_config or {})
    (model_dir / "config.json").write_text(json.dumps(config))
    mx.save_safetensors(str(model_dir / "model.safetensors"), weights)
    return model_dir

def _load_tiny_model(self, model_dir):
    """Load the model in ``model_dir`` using the tiny class resolver."""
    return utils.load_model(model_dir, get_model_classes=get_tiny_classes)

def test_load_model_drops_unknown_weights(self):
    """Checkpoint keys absent from the model are dropped, not fatal."""
    base = self._tiny_weights()
    weights = dict(base)
    # Simulate multimodal towers that the text-only model does not have.
    for name in ("vision_tower.encoder.weight", "audio_tower.encoder.weight"):
        weights[name] = mx.zeros((1,), dtype=mx.float32)
    model_dir = self._write_model_dir(weights)

    model, _ = self._load_tiny_model(model_dir)

    # The supported weights must still load with their original values.
    loaded = dict(tree_flatten(model.parameters()))
    for name in ("embed.weight", "proj.weight"):
        self.assertTrue(mx.allclose(loaded[name], base[name]))

def test_load_model_drops_unknown_weights_after_sanitize(self):
    """Unknown-key filtering must run on sanitized (prefix-stripped) names."""
    base = self._tiny_weights()
    # Prefix every key so only sanitize() can map them back to the model.
    weights = {"language_model." + name: array for name, array in base.items()}
    weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32)
    model_dir = self._write_model_dir(weights)

    model, _ = self._load_tiny_model(model_dir)

    loaded = dict(tree_flatten(model.parameters()))
    for name in ("embed.weight", "proj.weight"):
        self.assertTrue(mx.allclose(loaded[name], base[name]))

def test_load_model_still_fails_for_missing_supported_weights(self):
    """Dropping unknown keys must not mask genuinely missing weights."""
    weights = self._tiny_weights()
    del weights["proj.weight"]
    model_dir = self._write_model_dir(weights)

    with self.assertRaises(ValueError):
        self._load_tiny_model(model_dir)

def test_load_model_keeps_supported_quantized_weights(self):
    """Quantized parameter names (scales/biases) count as known keys."""
    weights = self._quantized_tiny_weights()
    weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32)
    model_dir = self._write_model_dir(
        weights,
        extra_config={"quantization": {"bits": 4, "group_size": 32}},
    )

    model, _ = self._load_tiny_model(model_dir)

    loaded = dict(tree_flatten(model.parameters()))
    # All quantized params must survive the unknown-key filter...
    for name in ("proj.weight", "proj.scales", "proj.biases"):
        self.assertIn(name, loaded)
    # ...and retain the values that were written to disk.
    for name in ("proj.scales", "proj.biases"):
        self.assertTrue(mx.allclose(loaded[name], weights[name]))


# Allow running this test module directly: `python test_utils.py`.
if __name__ == "__main__":
    unittest.main()