107 changes: 64 additions & 43 deletions clip/convert.py
@@ -1,15 +1,68 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import json
import shutil
from pathlib import Path
from typing import Tuple
from typing import Any, Dict, Union

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += v.nbytes
shards.append(shard)
return shards
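
For illustration, a minimal sketch (not part of this diff) of how make_shards behaves; the tensors are made up and small enough to land in a single shard under the default 5 GB limit:

import mlx.core as mx

weights = {"a": mx.zeros((4, 4)), "b": mx.ones((2, 2))}  # tiny illustrative tensors
shards = make_shards(weights)
assert len(shards) == 1  # everything fits in one shard
# A weight dict whose total size exceeds max_file_size_gb would instead come
# back as several dicts, one per output .safetensors file.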


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
"""Save model weights into specified directory."""
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)

shards = make_shards(weights)
shards_count = len(shards)
shard_file_format = (
"model-{:05d}-of-{:05d}.safetensors"
if shards_count > 1
else "model.safetensors"
)

total_size = sum(v.nbytes for v in weights.values())
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

for i, shard in enumerate(shards):
shard_name = shard_file_format.format(i + 1, shards_count)
shard_path = save_path / shard_name

mx.save_safetensors(str(shard_path), shard)

for weight_name in shard.keys():
index_data["weight_map"][weight_name] = shard_name

index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}

with open(save_path / "model.safetensors.index.json", "w") as f:
json.dump(
index_data,
f,
indent=4,
)
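
As a hedged sketch of the reverse direction (assuming a directory produced by save_weights, given the hypothetical name mlx_model here), the index file written above can be used to load the shards back:

import json
from pathlib import Path

import mlx.core as mx

model_dir = Path("mlx_model")  # hypothetical output directory
index = json.loads((model_dir / "model.safetensors.index.json").read_text())
weights = {}
for shard_name in sorted(set(index["weight_map"].values())):
    weights.update(mx.load(str(model_dir / shard_name)))  # mx.load reads .safetensors
assert set(weights) == set(index["weight_map"])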


def get_model_path(path_or_hf_repo: str) -> Path:
model_path = Path(path_or_hf_repo)
if not model_path.exists():
@@ -32,44 +85,6 @@ def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
return mx.array(a.numpy(), getattr(mx, dtype))
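
A small sketch (illustrative values, not part of this diff) of the conversion that the new --dtype flag feeds into torch_to_mx:

import torch

t = torch.arange(4, dtype=torch.float32).reshape(2, 2)
a = torch_to_mx(t, dtype="float16")
print(a.dtype)  # mlx.core.float16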


def map_weights(key: str, value: torch.Tensor) -> Tuple[str, mx.array]:
key = key.replace("embeddings.", "")
key = key.replace("encoder.", "")
key = key.replace("position_embedding.weight", "position_embedding")

# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
if "layer_norm1." in key:
key = key.replace("layer_norm1.", "ln1.")
if "layer_norm2." in key:
key = key.replace("layer_norm2.", "ln2.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
# Fix layernorm typo
if "pre_layrnorm" in key:
# Fix typo in weights :)
key = key.replace("pre_layrnorm", "pre_layernorm")
if "patch_embedding.weight" in key:
# Initially, value: [out_channels, in_channels, kH, KW].
# We want [out_channels, kH, KW, in_channels]
value = value.permute(0, 2, 3, 1)
return (key, torch_to_mx(value, dtype=str(value.dtype).replace("torch.", "")))


def should_keep_weight(key: str):
return not ("position_ids" in key)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download and Convert (OpenAI) CLIP weights to MLX"
@@ -86,7 +101,12 @@ def should_keep_weight(key: str):
default="mlx_model",
help="Path to save the MLX model.",
)

parser.add_argument(
"--dtype",
help="The data type to save the converted model.",
type=str,
default="float32",
)
args = parser.parse_args()

torch_path = get_model_path(args.hf_repo)
@@ -96,10 +116,11 @@ def should_keep_weight(key: str):
print("[INFO] Loading")
torch_weights = torch.load(torch_path / "pytorch_model.bin")
print("[INFO] Converting")
mlx_weights = dict(map_weights(k, v) for (k, v) in torch_weights.items())
mlx_weights = {k: v for (k, v) in mlx_weights.items() if should_keep_weight(k)}
mlx_weights = {
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()
}
print("[INFO] Saving")
mx.savez(str(mlx_path / "weights.npz"), **mlx_weights)
save_weights(mlx_path, mlx_weights)
for fn in ["config.json", "merges.txt", "vocab.json", "preprocessor_config.json"]:
shutil.copyfile(
str(torch_path / f"{fn}"),
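
A hypothetical end-to-end invocation of the updated script (the --hf-repo and --mlx-path flag names are assumed from the surrounding context; only --dtype is new in this diff):

python convert.py --hf-repo openai/clip-vit-base-patch32 --mlx-path mlx_model --dtype float16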