From ae961c619d08b95845182223c9c10b7d92660d06 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Sat, 21 Mar 2026 20:05:34 +0800 Subject: [PATCH 1/2] Add flag to drop unknown weights --- mlx_lm/benchmark.py | 12 ++- mlx_lm/cache_prompt.py | 6 ++ mlx_lm/chat.py | 13 ++- mlx_lm/convert.py | 7 ++ mlx_lm/evaluate.py | 11 ++- mlx_lm/fuse.py | 10 ++- mlx_lm/generate.py | 11 ++- mlx_lm/lora.py | 11 ++- mlx_lm/perplexity.py | 11 ++- mlx_lm/quant/awq.py | 12 ++- mlx_lm/quant/dwq.py | 19 ++++- mlx_lm/quant/dynamic_quant.py | 11 ++- mlx_lm/quant/gptq.py | 12 ++- mlx_lm/server.py | 28 ++++++- mlx_lm/utils.py | 56 +++++++++++-- tests/test_utils.py | 148 ++++++++++++++++++++++++++++++++++ 16 files changed, 357 insertions(+), 21 deletions(-) diff --git a/mlx_lm/benchmark.py b/mlx_lm/benchmark.py index 3b2cb66a4..9682e67f3 100644 --- a/mlx_lm/benchmark.py +++ b/mlx_lm/benchmark.py @@ -73,6 +73,11 @@ def setup_arg_parser(): default=0, help="Delay between each test in seconds (default: 0)", ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser @@ -94,7 +99,11 @@ def rprint(*args, **kwargs): if group.size() > 1: model, tokenizer, config = sharded_load( - model_path, pipeline_group, tensor_group, return_config=True + model_path, + pipeline_group, + tensor_group, + return_config=True, + drop_unknown_weights=args.drop_unknown_weights, ) else: model, tokenizer, config = load( @@ -102,6 +111,7 @@ def rprint(*args, **kwargs): return_config=True, tokenizer_config={"trust_remote_code": True}, model_config={"quantize_activations": args.quantize_activations}, + drop_unknown_weights=args.drop_unknown_weights, ) # Empty to avoid early stopping diff --git a/mlx_lm/cache_prompt.py b/mlx_lm/cache_prompt.py index 096eecc42..3bf54d64a 100644 --- a/mlx_lm/cache_prompt.py +++ b/mlx_lm/cache_prompt.py @@ -77,6 +77,11 @@ def setup_arg_parser(): type=int, default=DEFAULT_QUANTIZED_KV_START, ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser @@ -93,6 +98,7 @@ def main(): args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, + drop_unknown_weights=args.drop_unknown_weights, ) args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt diff --git a/mlx_lm/chat.py b/mlx_lm/chat.py index e168e2827..f611e2449 100644 --- a/mlx_lm/chat.py +++ b/mlx_lm/chat.py @@ -84,6 +84,11 @@ def setup_arg_parser(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser @@ -105,7 +110,12 @@ def rprint(*args, **kwargs): if group.size() > 1: if args.adapter_path: parser.error("Adapters not supported in distributed mode") - model, tokenizer = sharded_load(args.model, pipeline_group, tensor_group) + model, tokenizer = sharded_load( + args.model, + pipeline_group, + tensor_group, + drop_unknown_weights=args.drop_unknown_weights, + ) else: model, tokenizer = load( args.model, @@ -113,6 +123,7 @@ def rprint(*args, **kwargs): tokenizer_config={ "trust_remote_code": True if args.trust_remote_code else None }, + drop_unknown_weights=args.drop_unknown_weights, ) def print_help(): diff --git a/mlx_lm/convert.py b/mlx_lm/convert.py index ab3fc62ac..87d92bbda 100644 --- a/mlx_lm/convert.py +++ b/mlx_lm/convert.py @@ -97,6 +97,7 @@ def convert( Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str] ] = None, trust_remote_code: bool = False, + drop_unknown_weights: bool = False, ): # Check the save path is empty if isinstance(mlx_path, str): @@ -115,6 +116,7 @@ def convert( return_config=True, tokenizer_config={"trust_remote_code": trust_remote_code}, lazy=True, + drop_unknown_weights=drop_unknown_weights, ) if isinstance(quant_predicate, str): @@ -250,6 +252,11 @@ def configure_parser() -> argparse.ArgumentParser: action="store_true", default=False, ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser diff --git a/mlx_lm/evaluate.py b/mlx_lm/evaluate.py index f5170e763..82eec068b 100644 --- a/mlx_lm/evaluate.py +++ b/mlx_lm/evaluate.py @@ -81,11 +81,14 @@ def __init__( use_chat_template: Optional[bool] = None, trust_remote_code: bool = False, sampler: Optional[Callable[[mx.array], mx.array]] = None, + drop_unknown_weights: bool = False, ) -> None: super().__init__() tokenizer_config = {"trust_remote_code": True if trust_remote_code else None} self._model, self.tokenizer = load( - path_or_hf_repo, tokenizer_config=tokenizer_config + path_or_hf_repo, + tokenizer_config=tokenizer_config, + drop_unknown_weights=drop_unknown_weights, ) self._max_tokens = max_tokens self._batch_size = batch_size @@ -455,6 +458,11 @@ def main(): parser.add_argument("--temp", type=float, default=0.0, help="Sampling temperature") parser.add_argument("--top-p", type=float, default=1.0, help="Sampling top-p") parser.add_argument("--top-k", type=int, default=0, help="Sampling top-k") + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() output_dir = Path(args.output_dir) @@ -483,6 +491,7 @@ def main(): use_chat_template=args.apply_chat_template, trust_remote_code=args.trust_remote_code, sampler=sampler, + drop_unknown_weights=args.drop_unknown_weights, ) MLXLM.apply_chat_template = chat_template_fn(**args.chat_template_args) diff --git a/mlx_lm/fuse.py b/mlx_lm/fuse.py index 87f667752..4f95be110 100644 --- a/mlx_lm/fuse.py +++ b/mlx_lm/fuse.py @@ -54,6 +54,11 @@ def parse_arguments() -> argparse.Namespace: default="ggml-model-f16.gguf", type=str, ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser.parse_args() @@ -62,7 +67,10 @@ def main() -> None: args = parse_arguments() model, tokenizer, config = load( - args.model, adapter_path=args.adapter_path, return_config=True + args.model, + adapter_path=args.adapter_path, + return_config=True, + drop_unknown_weights=args.drop_unknown_weights, ) fused_linears = [ diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index ef8dbf7bf..944fed4c5 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -215,6 +215,11 @@ def setup_arg_parser(): help="Number of tokens to draft when using speculative decoding.", default=3, ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser @@ -1455,6 +1460,7 @@ def main(): adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, model_config={"quantize_activations": args.quantize_activations}, + drop_unknown_weights=args.drop_unknown_weights, ) for eos_token in args.extra_eos_token: tokenizer.add_eos_token(eos_token) @@ -1499,7 +1505,10 @@ def main(): prompt = tokenizer.encode(prompt) if args.draft_model is not None: - draft_model, draft_tokenizer = load(args.draft_model) + draft_model, draft_tokenizer = load( + args.draft_model, + drop_unknown_weights=args.drop_unknown_weights, + ) if draft_tokenizer.vocab_size != tokenizer.vocab_size: raise ValueError("Draft model tokenizer does not match model tokenizer.") else: diff --git a/mlx_lm/lora.py b/mlx_lm/lora.py index b5fcce349..a17749446 100644 --- a/mlx_lm/lora.py +++ b/mlx_lm/lora.py @@ -210,6 +210,11 @@ def build_parser(): help="Project name for logging. Defaults to the name of the root directory.", ) parser.add_argument("--seed", type=int, help="The PRNG seed") + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) return parser @@ -326,7 +331,11 @@ def run(args, training_callback: TrainingCallback = None): ) print("Loading pretrained model") - model, tokenizer = load(args.model, tokenizer_config={"trust_remote_code": True}) + model, tokenizer = load( + args.model, + tokenizer_config={"trust_remote_code": True}, + drop_unknown_weights=args.drop_unknown_weights, + ) print("Loading datasets") train_set, valid_set, test_set = load_dataset(args, tokenizer) diff --git a/mlx_lm/perplexity.py b/mlx_lm/perplexity.py index 64cde5d93..63781aed5 100644 --- a/mlx_lm/perplexity.py +++ b/mlx_lm/perplexity.py @@ -135,6 +135,11 @@ def main(): parser.add_argument( "--seed", type=int, default=123, help="Random seed for data sampling" ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() @@ -145,7 +150,11 @@ def main(): # Load model print(f"Loading model from {args.model}...") tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} - model, tokenizer = load(args.model, tokenizer_config=tokenizer_config) + model, tokenizer = load( + args.model, + tokenizer_config=tokenizer_config, + drop_unknown_weights=args.drop_unknown_weights, + ) # Count parameters total_params = get_total_parameters(model) diff --git a/mlx_lm/quant/awq.py b/mlx_lm/quant/awq.py index 13b64057d..949f8f2f3 100644 --- a/mlx_lm/quant/awq.py +++ b/mlx_lm/quant/awq.py @@ -544,6 +544,11 @@ def main(): parser.add_argument("--sequence-length", type=int, default=512) parser.add_argument("--n-grid", type=int, default=20) parser.add_argument("--seed", type=int, default=123) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() group = mx.distributed.init() @@ -554,7 +559,12 @@ def main(): mx.random.seed(args.seed) - model, tokenizer, config = load(args.model, lazy=True, return_config=True) + model, tokenizer, config = load( + args.model, + lazy=True, + return_config=True, + drop_unknown_weights=args.drop_unknown_weights, + ) model_type = config["model_type"] if (awq_config := AWQ_MODEL_CONFIGS.get(model_type, None)) is None: diff --git a/mlx_lm/quant/dwq.py b/mlx_lm/quant/dwq.py index d7b8144d0..dd60b55a9 100644 --- a/mlx_lm/quant/dwq.py +++ b/mlx_lm/quant/dwq.py @@ -300,6 +300,11 @@ def main(): action="store_true", help="Use pipeline parallel instead of data parallel.", ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() @@ -328,9 +333,18 @@ def main(): # Load the base model if we need it if not has_targets or args.quantized_model is None: if args.pipeline and group.size() > 1: - model, _, config = pipeline_load(args.model, return_config=True) + model, _, config = pipeline_load( + args.model, + return_config=True, + drop_unknown_weights=args.drop_unknown_weights, + ) else: - model, _, config = load(args.model, return_config=True, lazy=True) + model, _, config = load( + args.model, + return_config=True, + lazy=True, + drop_unknown_weights=args.drop_unknown_weights, + ) else: model = None @@ -366,6 +380,7 @@ def target_fn(batch, idx, split): args.quantized_model, lazy=True, return_config=True, + drop_unknown_weights=args.drop_unknown_weights, ) if "quantization" not in config: raise ValueError("Quantized model must already be quantized.") diff --git a/mlx_lm/quant/dynamic_quant.py b/mlx_lm/quant/dynamic_quant.py index c339015aa..1ca8504ad 100644 --- a/mlx_lm/quant/dynamic_quant.py +++ b/mlx_lm/quant/dynamic_quant.py @@ -182,10 +182,19 @@ def main(): choices=["float32", "bfloat16"], help="What type to use to accumulate the gradients for the sensitivities", ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() group = mx.distributed.init() - model, tokenizer, config = load(args.model, return_config=True) + model, tokenizer, config = load( + args.model, + return_config=True, + drop_unknown_weights=args.drop_unknown_weights, + ) if args.sensitivities is None: mx.random.seed(args.seed) diff --git a/mlx_lm/quant/gptq.py b/mlx_lm/quant/gptq.py index a516b7f2f..54b171765 100644 --- a/mlx_lm/quant/gptq.py +++ b/mlx_lm/quant/gptq.py @@ -197,11 +197,21 @@ def main(): help="Sequence length for the calibration data.", ) parser.add_argument("--seed", type=int, default=123) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() mx.random.seed(args.seed) - model, tokenizer, config = load(args.model, lazy=True, return_config=True) + model, tokenizer, config = load( + args.model, + lazy=True, + return_config=True, + drop_unknown_weights=args.drop_unknown_weights, + ) calibration_data = load_data(tokenizer, args.num_samples, args.sequence_length) model, config["quantization"] = gptq_quantize( diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 7fc91fa2a..078ae610d 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -513,6 +513,7 @@ def load(self, model_path, adapter_path=None, draft_model_path=None): model_path = self.default_model_map.get(model_path, model_path) if self.model_key == (model_path, adapter_path, draft_model_path): return self.model, self.tokenizer + drop = self.cli_args.drop_unknown_weights # Remove the old model if it exists. self.model = None @@ -537,25 +538,33 @@ def load(self, model_path, adapter_path=None, draft_model_path=None): # TODO: Generalize distributed load if self.is_distributed: model, tokenizer = sharded_load( - self.cli_args.model, self.pipeline_group, self.tensor_group + self.cli_args.model, + self.pipeline_group, + self.tensor_group, + drop_unknown_weights=drop, ) else: model, tokenizer = load( self.cli_args.model, adapter_path=adapter_path, tokenizer_config=tokenizer_config, + drop_unknown_weights=drop, ) else: # TODO: Generalize distributed load if self.is_distributed: model, tokenizer = sharded_load( - model_path, self.pipeline_group, self.tensor_group + model_path, + self.pipeline_group, + self.tensor_group, + drop_unknown_weights=drop, ) else: model, tokenizer = load( model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config, + drop_unknown_weights=drop, ) if self.cli_args.use_default_chat_template: @@ -579,11 +588,17 @@ def validate_draft_tokenizer(draft_tokenizer): draft_model_path == "default_model" and self.cli_args.draft_model is not None ): - self.draft_model, draft_tokenizer = load(self.cli_args.draft_model) + self.draft_model, draft_tokenizer = load( + self.cli_args.draft_model, + drop_unknown_weights=drop, + ) validate_draft_tokenizer(draft_tokenizer) elif draft_model_path is not None and draft_model_path != "default_model": - self.draft_model, draft_tokenizer = load(draft_model_path) + self.draft_model, draft_tokenizer = load( + draft_model_path, + drop_unknown_weights=drop, + ) validate_draft_tokenizer(draft_tokenizer) if self.draft_model is None: @@ -2028,6 +2043,11 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--drop-unknown-weights", + action="store_true", + help="Drop weights not present in the instantiated model.", + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py index 70bf8c83f..53137d0bd 100644 --- a/mlx_lm/utils.py +++ b/mlx_lm/utils.py @@ -285,6 +285,7 @@ def load_model( strict: bool = True, model_config: Optional[Dict[str, Any]] = None, get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, + drop_unknown_weights: bool = False, ) -> Tuple[nn.Module, dict]: """ Load and initialize the model from a given path. @@ -301,6 +302,9 @@ def load_model( get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): A function that returns the model class and model args class given a config. Defaults to the ``_get_classes`` function. + drop_unknown_weights (bool): If ``True`` drop sanitized weights + that do not exist in the instantiated model before loading. + Default: ``False`` Returns: Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config. @@ -412,6 +416,20 @@ def _maybe_qq(m): model.update_modules(leaves) model.eval() + if drop_unknown_weights: + expected_keys = {key for key, _ in tree_flatten(model.parameters())} + unknown_keys = [key for key in weights if key not in expected_keys] + if unknown_keys: + sample_keys = sorted(unknown_keys)[:3] + sample = ", ".join(sample_keys) + if len(unknown_keys) > 3: + sample = f"{sample}, ..." + print( + "[INFO] Dropping weights not present in the instantiated model: " + f"count={len(unknown_keys)} sample=[{sample}]" + ) + weights = {key: value for key, value in weights.items() if key in expected_keys} + model.load_weights(list(weights.items()), strict=strict) if not lazy: @@ -458,6 +476,7 @@ def load( lazy: bool = False, return_config: bool = False, revision: Optional[str] = None, + drop_unknown_weights: bool = False, ) -> Union[ Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]], @@ -478,6 +497,8 @@ def load( when needed. Default: ``False`` return_config (bool: If ``True`` return the model config as the last item.. revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. + drop_unknown_weights (bool): If ``True`` drop sanitized weights + that are not supported by the instantiated model before loading. Returns: Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]: A tuple containing the loaded model, tokenizer and, if requested, the model config. @@ -488,7 +509,12 @@ def load( """ model_path = _download(path_or_hf_repo, revision=revision) - model, config = load_model(model_path, lazy, model_config=model_config) + model, config = load_model( + model_path, + lazy, + model_config=model_config, + drop_unknown_weights=drop_unknown_weights, + ) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() @@ -507,6 +533,7 @@ def sharded_load( pipeline_group: Optional[mx.distributed.Group] = None, tensor_group: Optional[mx.distributed.Group] = None, return_config: bool = False, + drop_unknown_weights: bool = False, ): # Get model path with everything but weight safetensors model_path = _download( @@ -525,7 +552,12 @@ def sharded_load( # Lazy load model to figure out what type of sharding we can do and which # weights we need to download. - model, config = load_model(model_path, lazy=True, strict=False) + model, config = load_model( + model_path, + lazy=True, + strict=False, + drop_unknown_weights=drop_unknown_weights, + ) has_pipelining = hasattr(model, "model") and hasattr(model.model, "pipeline") has_tensor_parallel = hasattr(model, "shard") @@ -574,7 +606,12 @@ def sharded_load( {"trust_remote_code": True}, eos_token_ids=config.get("eos_token_id", None), ) - model, _ = load_model(model_path, lazy=True, strict=False) + model, _ = load_model( + model_path, + lazy=True, + strict=False, + drop_unknown_weights=drop_unknown_weights, + ) if tensor_group is not None: model.shard(tensor_group) if pipeline_group is not None: @@ -589,8 +626,17 @@ def sharded_load( return model, tokenizer -def pipeline_load(repo, return_config=False): - return sharded_load(repo, mx.distributed.init(), None, return_config) +def pipeline_load( + repo, + return_config: bool = False, + drop_unknown_weights: bool = False, +): + return sharded_load( + repo, + pipeline_group=mx.distributed.init(), + return_config=return_config, + drop_unknown_weights=drop_unknown_weights, + ) def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: diff --git a/tests/test_utils.py b/tests/test_utils.py index 3434b4ac7..d6efda5ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,12 @@ # Copyright © 2024 Apple Inc. +import io +import json import os import tempfile import unittest +from contextlib import redirect_stdout +from pathlib import Path import mlx.core as mx import mlx.nn as nn @@ -13,6 +17,35 @@ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" +class TinyArgs: + def __init__(self, vocab_size, hidden_size): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + + @classmethod + def from_dict(cls, config): + return cls(config["vocab_size"], config["hidden_size"]) + + +class TinyModel(nn.Module): + def __init__(self, args): + super().__init__() + self.embed = nn.Embedding(args.vocab_size, args.hidden_size) + self.proj = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def sanitize(self, weights): + clean = {} + for key, value in weights.items(): + if key.startswith("language_model."): + key = key[len("language_model.") :] + clean[key] = value + return clean + + +def get_tiny_classes(config): + return TinyModel, TinyArgs + + class TestUtils(unittest.TestCase): @classmethod @@ -123,6 +156,121 @@ def custom_get_classes(config): self.assertEqual(model.custom_attribute, "This is a custom model") self.assertTrue(hasattr(model, "qwenWeights")) + def _tiny_weights(self): + model = TinyModel(TinyArgs(16, 32)) + mx.eval(model.parameters()) + return dict(tree_flatten(model.parameters())) + + def _quantized_tiny_weights(self): + weights = self._tiny_weights() + q_weight, scales, biases = mx.quantize( + weights["proj.weight"], + bits=4, + group_size=32, + ) + return { + "embed.weight": weights["embed.weight"], + "proj.weight": q_weight, + "proj.scales": scales, + "proj.biases": biases, + } + + def _write_model_dir(self, weights, extra_config=None): + model_dir = Path(tempfile.mkdtemp(dir=self.test_dir)) + config = { + "model_type": "tiny", + "vocab_size": 16, + "hidden_size": 32, + } + if extra_config is not None: + config.update(extra_config) + with open(model_dir / "config.json", "w") as fid: + json.dump(config, fid) + mx.save_safetensors(str(model_dir / "model.safetensors"), weights) + return model_dir + + def _load_tiny_model(self, model_dir, **kwargs): + return utils.load_model(model_dir, get_model_classes=get_tiny_classes, **kwargs) + + def _capture_tiny_load(self, model_dir, **kwargs): + stdout = io.StringIO() + with redirect_stdout(stdout): + model, config = self._load_tiny_model(model_dir, **kwargs) + return model, config, stdout.getvalue() + + def test_load_model_drops_unknown_weights_when_enabled(self): + weights = self._tiny_weights() + weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) + weights["audio_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) + model_dir = self._write_model_dir(weights) + + with self.assertRaises(ValueError): + self._load_tiny_model(model_dir) + + # No output when there is nothing to drop + clean_dir = self._write_model_dir(self._tiny_weights()) + _, _, clean_output = self._capture_tiny_load( + clean_dir, drop_unknown_weights=True + ) + self.assertEqual(clean_output, "") + + model, _, output = self._capture_tiny_load( + model_dir, + drop_unknown_weights=True, + ) + + loaded = dict(tree_flatten(model.parameters())) + self.assertIn("count=2", output) + self.assertIn("audio_tower.encoder.weight", output) + self.assertIn("vision_tower.encoder.weight", output) + self.assertNotIn("vision_tower.encoder.weight", loaded) + self.assertNotIn("audio_tower.encoder.weight", loaded) + + def test_load_model_drops_unknown_weights_after_sanitize(self): + base = self._tiny_weights() + weights = {f"language_model.{key}": value for key, value in base.items()} + weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) + model_dir = self._write_model_dir(weights) + + model, _, output = self._capture_tiny_load( + model_dir, + drop_unknown_weights=True, + ) + + loaded = dict(tree_flatten(model.parameters())) + self.assertIn("count=1", output) + self.assertIn("vision_tower.encoder.weight", output) + self.assertTrue(mx.allclose(loaded["embed.weight"], base["embed.weight"])) + self.assertTrue(mx.allclose(loaded["proj.weight"], base["proj.weight"])) + + def test_load_model_still_fails_for_missing_supported_weights(self): + weights = self._tiny_weights() + weights.pop("proj.weight") + model_dir = self._write_model_dir(weights) + + with self.assertRaises(ValueError): + self._load_tiny_model(model_dir, drop_unknown_weights=True) + + def test_load_model_keeps_supported_quantized_weights(self): + weights = self._quantized_tiny_weights() + weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) + model_dir = self._write_model_dir( + weights, + extra_config={"quantization": {"bits": 4, "group_size": 32}}, + ) + + model, _ = self._load_tiny_model( + model_dir, + drop_unknown_weights=True, + ) + + loaded = dict(tree_flatten(model.parameters())) + self.assertIn("proj.weight", loaded) + self.assertIn("proj.scales", loaded) + self.assertIn("proj.biases", loaded) + self.assertTrue(mx.allclose(loaded["proj.scales"], weights["proj.scales"])) + self.assertTrue(mx.allclose(loaded["proj.biases"], weights["proj.biases"])) + if __name__ == "__main__": unittest.main() From 272f991745a9fa9ebd400a99433c32c3c67432f7 Mon Sep 17 00:00:00 2001 From: spicyneuron <183504714+spicyneuron@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:05:48 +0800 Subject: [PATCH 2/2] Remove CLI flags, drop as default behavior --- mlx_lm/benchmark.py | 12 +------ mlx_lm/cache_prompt.py | 6 ---- mlx_lm/chat.py | 13 +------ mlx_lm/convert.py | 7 ---- mlx_lm/evaluate.py | 11 +----- mlx_lm/fuse.py | 10 +----- mlx_lm/generate.py | 11 +----- mlx_lm/lora.py | 11 +----- mlx_lm/perplexity.py | 11 +----- mlx_lm/quant/awq.py | 12 +------ mlx_lm/quant/dwq.py | 19 ++-------- mlx_lm/quant/dynamic_quant.py | 11 +----- mlx_lm/quant/gptq.py | 12 +------ mlx_lm/server.py | 28 +++------------ mlx_lm/utils.py | 67 +++++++++-------------------------- tests/test_utils.py | 56 ++++++++--------------------- 16 files changed, 47 insertions(+), 250 deletions(-) diff --git a/mlx_lm/benchmark.py b/mlx_lm/benchmark.py index 9682e67f3..3b2cb66a4 100644 --- a/mlx_lm/benchmark.py +++ b/mlx_lm/benchmark.py @@ -73,11 +73,6 @@ def setup_arg_parser(): default=0, help="Delay between each test in seconds (default: 0)", ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser @@ -99,11 +94,7 @@ def rprint(*args, **kwargs): if group.size() > 1: model, tokenizer, config = sharded_load( - model_path, - pipeline_group, - tensor_group, - return_config=True, - drop_unknown_weights=args.drop_unknown_weights, + model_path, pipeline_group, tensor_group, return_config=True ) else: model, tokenizer, config = load( @@ -111,7 +102,6 @@ def rprint(*args, **kwargs): return_config=True, tokenizer_config={"trust_remote_code": True}, model_config={"quantize_activations": args.quantize_activations}, - drop_unknown_weights=args.drop_unknown_weights, ) # Empty to avoid early stopping diff --git a/mlx_lm/cache_prompt.py b/mlx_lm/cache_prompt.py index 3bf54d64a..096eecc42 100644 --- a/mlx_lm/cache_prompt.py +++ b/mlx_lm/cache_prompt.py @@ -77,11 +77,6 @@ def setup_arg_parser(): type=int, default=DEFAULT_QUANTIZED_KV_START, ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser @@ -98,7 +93,6 @@ def main(): args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, - drop_unknown_weights=args.drop_unknown_weights, ) args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt diff --git a/mlx_lm/chat.py b/mlx_lm/chat.py index f611e2449..e168e2827 100644 --- a/mlx_lm/chat.py +++ b/mlx_lm/chat.py @@ -84,11 +84,6 @@ def setup_arg_parser(): action="store_true", help="Use pipelining instead of tensor parallelism", ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser @@ -110,12 +105,7 @@ def rprint(*args, **kwargs): if group.size() > 1: if args.adapter_path: parser.error("Adapters not supported in distributed mode") - model, tokenizer = sharded_load( - args.model, - pipeline_group, - tensor_group, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, tokenizer = sharded_load(args.model, pipeline_group, tensor_group) else: model, tokenizer = load( args.model, @@ -123,7 +113,6 @@ def rprint(*args, **kwargs): tokenizer_config={ "trust_remote_code": True if args.trust_remote_code else None }, - drop_unknown_weights=args.drop_unknown_weights, ) def print_help(): diff --git a/mlx_lm/convert.py b/mlx_lm/convert.py index 87d92bbda..ab3fc62ac 100644 --- a/mlx_lm/convert.py +++ b/mlx_lm/convert.py @@ -97,7 +97,6 @@ def convert( Union[Callable[[str, nn.Module, dict], Union[bool, dict]], str] ] = None, trust_remote_code: bool = False, - drop_unknown_weights: bool = False, ): # Check the save path is empty if isinstance(mlx_path, str): @@ -116,7 +115,6 @@ def convert( return_config=True, tokenizer_config={"trust_remote_code": trust_remote_code}, lazy=True, - drop_unknown_weights=drop_unknown_weights, ) if isinstance(quant_predicate, str): @@ -252,11 +250,6 @@ def configure_parser() -> argparse.ArgumentParser: action="store_true", default=False, ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser diff --git a/mlx_lm/evaluate.py b/mlx_lm/evaluate.py index 82eec068b..f5170e763 100644 --- a/mlx_lm/evaluate.py +++ b/mlx_lm/evaluate.py @@ -81,14 +81,11 @@ def __init__( use_chat_template: Optional[bool] = None, trust_remote_code: bool = False, sampler: Optional[Callable[[mx.array], mx.array]] = None, - drop_unknown_weights: bool = False, ) -> None: super().__init__() tokenizer_config = {"trust_remote_code": True if trust_remote_code else None} self._model, self.tokenizer = load( - path_or_hf_repo, - tokenizer_config=tokenizer_config, - drop_unknown_weights=drop_unknown_weights, + path_or_hf_repo, tokenizer_config=tokenizer_config ) self._max_tokens = max_tokens self._batch_size = batch_size @@ -458,11 +455,6 @@ def main(): parser.add_argument("--temp", type=float, default=0.0, help="Sampling temperature") parser.add_argument("--top-p", type=float, default=1.0, help="Sampling top-p") parser.add_argument("--top-k", type=int, default=0, help="Sampling top-k") - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() output_dir = Path(args.output_dir) @@ -491,7 +483,6 @@ def main(): use_chat_template=args.apply_chat_template, trust_remote_code=args.trust_remote_code, sampler=sampler, - drop_unknown_weights=args.drop_unknown_weights, ) MLXLM.apply_chat_template = chat_template_fn(**args.chat_template_args) diff --git a/mlx_lm/fuse.py b/mlx_lm/fuse.py index 4f95be110..87f667752 100644 --- a/mlx_lm/fuse.py +++ b/mlx_lm/fuse.py @@ -54,11 +54,6 @@ def parse_arguments() -> argparse.Namespace: default="ggml-model-f16.gguf", type=str, ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser.parse_args() @@ -67,10 +62,7 @@ def main() -> None: args = parse_arguments() model, tokenizer, config = load( - args.model, - adapter_path=args.adapter_path, - return_config=True, - drop_unknown_weights=args.drop_unknown_weights, + args.model, adapter_path=args.adapter_path, return_config=True ) fused_linears = [ diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 944fed4c5..ef8dbf7bf 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -215,11 +215,6 @@ def setup_arg_parser(): help="Number of tokens to draft when using speculative decoding.", default=3, ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser @@ -1460,7 +1455,6 @@ def main(): adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, model_config={"quantize_activations": args.quantize_activations}, - drop_unknown_weights=args.drop_unknown_weights, ) for eos_token in args.extra_eos_token: tokenizer.add_eos_token(eos_token) @@ -1505,10 +1499,7 @@ def main(): prompt = tokenizer.encode(prompt) if args.draft_model is not None: - draft_model, draft_tokenizer = load( - args.draft_model, - drop_unknown_weights=args.drop_unknown_weights, - ) + draft_model, draft_tokenizer = load(args.draft_model) if draft_tokenizer.vocab_size != tokenizer.vocab_size: raise ValueError("Draft model tokenizer does not match model tokenizer.") else: diff --git a/mlx_lm/lora.py b/mlx_lm/lora.py index a17749446..b5fcce349 100644 --- a/mlx_lm/lora.py +++ b/mlx_lm/lora.py @@ -210,11 +210,6 @@ def build_parser(): help="Project name for logging. Defaults to the name of the root directory.", ) parser.add_argument("--seed", type=int, help="The PRNG seed") - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) return parser @@ -331,11 +326,7 @@ def run(args, training_callback: TrainingCallback = None): ) print("Loading pretrained model") - model, tokenizer = load( - args.model, - tokenizer_config={"trust_remote_code": True}, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, tokenizer = load(args.model, tokenizer_config={"trust_remote_code": True}) print("Loading datasets") train_set, valid_set, test_set = load_dataset(args, tokenizer) diff --git a/mlx_lm/perplexity.py b/mlx_lm/perplexity.py index 63781aed5..64cde5d93 100644 --- a/mlx_lm/perplexity.py +++ b/mlx_lm/perplexity.py @@ -135,11 +135,6 @@ def main(): parser.add_argument( "--seed", type=int, default=123, help="Random seed for data sampling" ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() @@ -150,11 +145,7 @@ def main(): # Load model print(f"Loading model from {args.model}...") tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} - model, tokenizer = load( - args.model, - tokenizer_config=tokenizer_config, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, tokenizer = load(args.model, tokenizer_config=tokenizer_config) # Count parameters total_params = get_total_parameters(model) diff --git a/mlx_lm/quant/awq.py b/mlx_lm/quant/awq.py index 949f8f2f3..13b64057d 100644 --- a/mlx_lm/quant/awq.py +++ b/mlx_lm/quant/awq.py @@ -544,11 +544,6 @@ def main(): parser.add_argument("--sequence-length", type=int, default=512) parser.add_argument("--n-grid", type=int, default=20) parser.add_argument("--seed", type=int, default=123) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() group = mx.distributed.init() @@ -559,12 +554,7 @@ def main(): mx.random.seed(args.seed) - model, tokenizer, config = load( - args.model, - lazy=True, - return_config=True, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, tokenizer, config = load(args.model, lazy=True, return_config=True) model_type = config["model_type"] if (awq_config := AWQ_MODEL_CONFIGS.get(model_type, None)) is None: diff --git a/mlx_lm/quant/dwq.py b/mlx_lm/quant/dwq.py index dd60b55a9..d7b8144d0 100644 --- a/mlx_lm/quant/dwq.py +++ b/mlx_lm/quant/dwq.py @@ -300,11 +300,6 @@ def main(): action="store_true", help="Use pipeline parallel instead of data parallel.", ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() @@ -333,18 +328,9 @@ def main(): # Load the base model if we need it if not has_targets or args.quantized_model is None: if args.pipeline and group.size() > 1: - model, _, config = pipeline_load( - args.model, - return_config=True, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, _, config = pipeline_load(args.model, return_config=True) else: - model, _, config = load( - args.model, - return_config=True, - lazy=True, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, _, config = load(args.model, return_config=True, lazy=True) else: model = None @@ -380,7 +366,6 @@ def target_fn(batch, idx, split): args.quantized_model, lazy=True, return_config=True, - drop_unknown_weights=args.drop_unknown_weights, ) if "quantization" not in config: raise ValueError("Quantized model must already be quantized.") diff --git a/mlx_lm/quant/dynamic_quant.py b/mlx_lm/quant/dynamic_quant.py index 1ca8504ad..c339015aa 100644 --- a/mlx_lm/quant/dynamic_quant.py +++ b/mlx_lm/quant/dynamic_quant.py @@ -182,19 +182,10 @@ def main(): choices=["float32", "bfloat16"], help="What type to use to accumulate the gradients for the sensitivities", ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() group = mx.distributed.init() - model, tokenizer, config = load( - args.model, - return_config=True, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, tokenizer, config = load(args.model, return_config=True) if args.sensitivities is None: mx.random.seed(args.seed) diff --git a/mlx_lm/quant/gptq.py b/mlx_lm/quant/gptq.py index 54b171765..a516b7f2f 100644 --- a/mlx_lm/quant/gptq.py +++ b/mlx_lm/quant/gptq.py @@ -197,21 +197,11 @@ def main(): help="Sequence length for the calibration data.", ) parser.add_argument("--seed", type=int, default=123) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() mx.random.seed(args.seed) - model, tokenizer, config = load( - args.model, - lazy=True, - return_config=True, - drop_unknown_weights=args.drop_unknown_weights, - ) + model, tokenizer, config = load(args.model, lazy=True, return_config=True) calibration_data = load_data(tokenizer, args.num_samples, args.sequence_length) model, config["quantization"] = gptq_quantize( diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 078ae610d..7fc91fa2a 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -513,7 +513,6 @@ def load(self, model_path, adapter_path=None, draft_model_path=None): model_path = self.default_model_map.get(model_path, model_path) if self.model_key == (model_path, adapter_path, draft_model_path): return self.model, self.tokenizer - drop = self.cli_args.drop_unknown_weights # Remove the old model if it exists. self.model = None @@ -538,33 +537,25 @@ def load(self, model_path, adapter_path=None, draft_model_path=None): # TODO: Generalize distributed load if self.is_distributed: model, tokenizer = sharded_load( - self.cli_args.model, - self.pipeline_group, - self.tensor_group, - drop_unknown_weights=drop, + self.cli_args.model, self.pipeline_group, self.tensor_group ) else: model, tokenizer = load( self.cli_args.model, adapter_path=adapter_path, tokenizer_config=tokenizer_config, - drop_unknown_weights=drop, ) else: # TODO: Generalize distributed load if self.is_distributed: model, tokenizer = sharded_load( - model_path, - self.pipeline_group, - self.tensor_group, - drop_unknown_weights=drop, + model_path, self.pipeline_group, self.tensor_group ) else: model, tokenizer = load( model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config, - drop_unknown_weights=drop, ) if self.cli_args.use_default_chat_template: @@ -588,17 +579,11 @@ def validate_draft_tokenizer(draft_tokenizer): draft_model_path == "default_model" and self.cli_args.draft_model is not None ): - self.draft_model, draft_tokenizer = load( - self.cli_args.draft_model, - drop_unknown_weights=drop, - ) + self.draft_model, draft_tokenizer = load(self.cli_args.draft_model) validate_draft_tokenizer(draft_tokenizer) elif draft_model_path is not None and draft_model_path != "default_model": - self.draft_model, draft_tokenizer = load( - draft_model_path, - drop_unknown_weights=drop, - ) + self.draft_model, draft_tokenizer = load(draft_model_path) validate_draft_tokenizer(draft_tokenizer) if self.draft_model is None: @@ -2043,11 +2028,6 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) - parser.add_argument( - "--drop-unknown-weights", - action="store_true", - help="Drop weights not present in the instantiated model.", - ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py index 53137d0bd..5ff19e2bb 100644 --- a/mlx_lm/utils.py +++ b/mlx_lm/utils.py @@ -285,7 +285,6 @@ def load_model( strict: bool = True, model_config: Optional[Dict[str, Any]] = None, get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, - drop_unknown_weights: bool = False, ) -> Tuple[nn.Module, dict]: """ Load and initialize the model from a given path. @@ -302,9 +301,6 @@ def load_model( get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): A function that returns the model class and model args class given a config. Defaults to the ``_get_classes`` function. - drop_unknown_weights (bool): If ``True`` drop sanitized weights - that do not exist in the instantiated model before loading. - Default: ``False`` Returns: Tuple[nn.Module, dict[str, Any]]: The loaded and initialized model and config. @@ -416,19 +412,18 @@ def _maybe_qq(m): model.update_modules(leaves) model.eval() - if drop_unknown_weights: - expected_keys = {key for key, _ in tree_flatten(model.parameters())} - unknown_keys = [key for key in weights if key not in expected_keys] - if unknown_keys: - sample_keys = sorted(unknown_keys)[:3] - sample = ", ".join(sample_keys) - if len(unknown_keys) > 3: - sample = f"{sample}, ..." - print( - "[INFO] Dropping weights not present in the instantiated model: " - f"count={len(unknown_keys)} sample=[{sample}]" - ) - weights = {key: value for key, value in weights.items() if key in expected_keys} + expected_keys = {key for key, _ in tree_flatten(model.parameters())} + unknown_keys = [key for key in weights if key not in expected_keys] + if unknown_keys: + sample_keys = sorted(unknown_keys)[:3] + sample = ", ".join(sample_keys) + if len(unknown_keys) > 3: + sample = f"{sample}, ..." + print( + "[INFO] Dropping weights not present in the instantiated model: " + f"count={len(unknown_keys)} sample=[{sample}]" + ) + weights = {key: value for key, value in weights.items() if key in expected_keys} model.load_weights(list(weights.items()), strict=strict) @@ -476,7 +471,6 @@ def load( lazy: bool = False, return_config: bool = False, revision: Optional[str] = None, - drop_unknown_weights: bool = False, ) -> Union[ Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]], @@ -497,8 +491,6 @@ def load( when needed. Default: ``False`` return_config (bool: If ``True`` return the model config as the last item.. revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash. - drop_unknown_weights (bool): If ``True`` drop sanitized weights - that are not supported by the instantiated model before loading. Returns: Union[Tuple[nn.Module, TokenizerWrapper], Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]]]: A tuple containing the loaded model, tokenizer and, if requested, the model config. @@ -509,12 +501,7 @@ def load( """ model_path = _download(path_or_hf_repo, revision=revision) - model, config = load_model( - model_path, - lazy, - model_config=model_config, - drop_unknown_weights=drop_unknown_weights, - ) + model, config = load_model(model_path, lazy, model_config=model_config) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() @@ -533,7 +520,6 @@ def sharded_load( pipeline_group: Optional[mx.distributed.Group] = None, tensor_group: Optional[mx.distributed.Group] = None, return_config: bool = False, - drop_unknown_weights: bool = False, ): # Get model path with everything but weight safetensors model_path = _download( @@ -552,12 +538,7 @@ def sharded_load( # Lazy load model to figure out what type of sharding we can do and which # weights we need to download. - model, config = load_model( - model_path, - lazy=True, - strict=False, - drop_unknown_weights=drop_unknown_weights, - ) + model, config = load_model(model_path, lazy=True, strict=False) has_pipelining = hasattr(model, "model") and hasattr(model.model, "pipeline") has_tensor_parallel = hasattr(model, "shard") @@ -606,12 +587,7 @@ def sharded_load( {"trust_remote_code": True}, eos_token_ids=config.get("eos_token_id", None), ) - model, _ = load_model( - model_path, - lazy=True, - strict=False, - drop_unknown_weights=drop_unknown_weights, - ) + model, _ = load_model(model_path, lazy=True, strict=False) if tensor_group is not None: model.shard(tensor_group) if pipeline_group is not None: @@ -626,17 +602,8 @@ def sharded_load( return model, tokenizer -def pipeline_load( - repo, - return_config: bool = False, - drop_unknown_weights: bool = False, -): - return sharded_load( - repo, - pipeline_group=mx.distributed.init(), - return_config=return_config, - drop_unknown_weights=drop_unknown_weights, - ) +def pipeline_load(repo, return_config=False): + return sharded_load(repo, mx.distributed.init(), None, return_config) def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list: diff --git a/tests/test_utils.py b/tests/test_utils.py index d6efda5ca..d55ad346b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,9 @@ # Copyright © 2024 Apple Inc. -import io import json import os import tempfile import unittest -from contextlib import redirect_stdout from pathlib import Path import mlx.core as mx @@ -189,42 +187,24 @@ def _write_model_dir(self, weights, extra_config=None): mx.save_safetensors(str(model_dir / "model.safetensors"), weights) return model_dir - def _load_tiny_model(self, model_dir, **kwargs): - return utils.load_model(model_dir, get_model_classes=get_tiny_classes, **kwargs) - - def _capture_tiny_load(self, model_dir, **kwargs): - stdout = io.StringIO() - with redirect_stdout(stdout): - model, config = self._load_tiny_model(model_dir, **kwargs) - return model, config, stdout.getvalue() + def _load_tiny_model(self, model_dir): + return utils.load_model( + model_dir, + get_model_classes=get_tiny_classes, + ) - def test_load_model_drops_unknown_weights_when_enabled(self): - weights = self._tiny_weights() + def test_load_model_drops_unknown_weights(self): + base = self._tiny_weights() + weights = dict(base) weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) weights["audio_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) model_dir = self._write_model_dir(weights) - with self.assertRaises(ValueError): - self._load_tiny_model(model_dir) - - # No output when there is nothing to drop - clean_dir = self._write_model_dir(self._tiny_weights()) - _, _, clean_output = self._capture_tiny_load( - clean_dir, drop_unknown_weights=True - ) - self.assertEqual(clean_output, "") - - model, _, output = self._capture_tiny_load( - model_dir, - drop_unknown_weights=True, - ) + model, _ = self._load_tiny_model(model_dir) loaded = dict(tree_flatten(model.parameters())) - self.assertIn("count=2", output) - self.assertIn("audio_tower.encoder.weight", output) - self.assertIn("vision_tower.encoder.weight", output) - self.assertNotIn("vision_tower.encoder.weight", loaded) - self.assertNotIn("audio_tower.encoder.weight", loaded) + self.assertTrue(mx.allclose(loaded["embed.weight"], base["embed.weight"])) + self.assertTrue(mx.allclose(loaded["proj.weight"], base["proj.weight"])) def test_load_model_drops_unknown_weights_after_sanitize(self): base = self._tiny_weights() @@ -232,14 +212,9 @@ def test_load_model_drops_unknown_weights_after_sanitize(self): weights["vision_tower.encoder.weight"] = mx.zeros((1,), dtype=mx.float32) model_dir = self._write_model_dir(weights) - model, _, output = self._capture_tiny_load( - model_dir, - drop_unknown_weights=True, - ) + model, _ = self._load_tiny_model(model_dir) loaded = dict(tree_flatten(model.parameters())) - self.assertIn("count=1", output) - self.assertIn("vision_tower.encoder.weight", output) self.assertTrue(mx.allclose(loaded["embed.weight"], base["embed.weight"])) self.assertTrue(mx.allclose(loaded["proj.weight"], base["proj.weight"])) @@ -249,7 +224,7 @@ def test_load_model_still_fails_for_missing_supported_weights(self): model_dir = self._write_model_dir(weights) with self.assertRaises(ValueError): - self._load_tiny_model(model_dir, drop_unknown_weights=True) + self._load_tiny_model(model_dir) def test_load_model_keeps_supported_quantized_weights(self): weights = self._quantized_tiny_weights() @@ -259,10 +234,7 @@ def test_load_model_keeps_supported_quantized_weights(self): extra_config={"quantization": {"bits": 4, "group_size": 32}}, ) - model, _ = self._load_tiny_model( - model_dir, - drop_unknown_weights=True, - ) + model, _ = self._load_tiny_model(model_dir) loaded = dict(tree_flatten(model.parameters())) self.assertIn("proj.weight", loaded)