add support for selective model downloading. related: #16
AlexCheema committed Aug 1, 2024
1 parent be3a09c commit 4faa6c0
Showing 3 changed files with 133 additions and 14 deletions.
141 changes: 130 additions & 11 deletions exo/inference/mlx/sharded_utils.py
@@ -13,6 +13,7 @@
from io import BytesIO
import base64
import os
import concurrent.futures

from exo import DEBUG
import mlx.core as mx
@@ -120,7 +121,18 @@ def load_model_shard(
raise FileNotFoundError(f"No safetensors found in {model_path}")

weights = {}
for wf in weight_files:
for wf in sorted(weight_files):
if DEBUG >= 8:
layer_nums = set()
for k in mx.load(wf):
if k.startswith("model.layers."):
layer_num = int(k.split(".")[2])
layer_nums.add(layer_num)
if k.startswith("language_model.model.layers."):
layer_num = int(k.split(".")[3])
layer_nums.add(layer_num)
print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")

weights.update(mx.load(wf))

model_class, model_args_class = _get_classes(config=config)
@@ -150,14 +162,15 @@ def load_model_shard(
async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type)
files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path))
return sum(file.size for file in files if file.size is not None)
return sum(file.size for file in files if hasattr(file, "size") and file.size is not None)

async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None):
while True:
try:
await asyncio.sleep(0.1)
current_size = sum(os.path.getsize(os.path.join(root, file))
for root, _, files in os.walk(dir)
for file in files)
for root, _, files in os.walk(dir)
for file in files)
progress = min(current_size / total_size * 100, 100)
if print_progress:
print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True)
@@ -167,10 +180,15 @@ async def monitor_progress(dir, total_size, print_progress=False, on_progress: C
if print_progress:
print("\nDownload complete!")
break
except Exception as e:
print(f"Error monitoring progress: {e}")

async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None):
# Use snapshot_download in a separate thread to not block the event loop
return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
with concurrent.futures.ThreadPoolExecutor() as pool:
return await asyncio.get_event_loop().run_in_executor(
pool,
partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type)
)

async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None):
storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model"))
@@ -184,11 +202,113 @@ async def download_async_with_progress(repo_id: str, revision: Optional[str] = N
progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress))

# Wait for both tasks to complete
result = await asyncio.gather(download_task, progress_task)
result = await asyncio.gather(download_task, progress_task, return_exceptions=True)
return result[0] # Return the result from download_task

repo_id_safetensors_layers = {
"mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": {
"model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
},
"mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
"model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
"model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
"model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
"model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42],
"model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53],
"model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64],
"model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75],
"model-00008-of-00008.safetensors": [75, 76, 77, 78, 79],
},
"mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": {
"model-00001-of-00046.safetensors": [0, 1, 2],
"model-00002-of-00046.safetensors": [2, 3, 4, 5],
"model-00003-of-00046.safetensors": [5, 6, 7],
"model-00004-of-00046.safetensors": [8, 9, 10],
"model-00005-of-00046.safetensors": [10, 11, 12, 13],
"model-00006-of-00046.safetensors": [13, 14, 15, 16],
"model-00007-of-00046.safetensors": [16, 17, 18, 19],
"model-00008-of-00046.safetensors": [19, 20, 21],
"model-00009-of-00046.safetensors": [22, 23, 24],
"model-00010-of-00046.safetensors": [24, 25, 26, 27],
"model-00011-of-00046.safetensors": [27, 28, 29, 30],
"model-00012-of-00046.safetensors": [30, 31, 32, 33],
"model-00013-of-00046.safetensors": [33, 34, 35],
"model-00014-of-00046.safetensors": [36, 37, 38],
"model-00015-of-00046.safetensors": [38, 39, 40, 41],
"model-00016-of-00046.safetensors": [41, 42, 43, 44],
"model-00017-of-00046.safetensors": [44, 45, 46, 47],
"model-00018-of-00046.safetensors": [47, 48, 49],
"model-00019-of-00046.safetensors": [50, 51, 52],
"model-00020-of-00046.safetensors": [52, 53, 54, 55],
"model-00021-of-00046.safetensors": [55, 56, 57, 58],
"model-00022-of-00046.safetensors": [58, 59, 60, 61],
"model-00023-of-00046.safetensors": [61, 62, 63],
"model-00024-of-00046.safetensors": [64, 65, 66],
"model-00025-of-00046.safetensors": [66, 67, 68, 69],
"model-00026-of-00046.safetensors": [69, 70, 71, 72],
"model-00027-of-00046.safetensors": [72, 73, 74, 75],
"model-00028-of-00046.safetensors": [75, 76, 77],
"model-00029-of-00046.safetensors": [78, 79, 80],
"model-00030-of-00046.safetensors": [80, 81, 82, 83],
"model-00031-of-00046.safetensors": [83, 84, 85, 86],
"model-00032-of-00046.safetensors": [86, 87, 88, 89],
"model-00033-of-00046.safetensors": [89, 90, 91],
"model-00034-of-00046.safetensors": [92, 93, 94],
"model-00035-of-00046.safetensors": [94, 95, 96, 97],
"model-00036-of-00046.safetensors": [97, 98, 99, 100],
"model-00037-of-00046.safetensors": [100, 101, 102, 103],
"model-00038-of-00046.safetensors": [103, 104, 105],
"model-00039-of-00046.safetensors": [106, 107, 108],
"model-00040-of-00046.safetensors": [108, 109, 110, 111],
"model-00041-of-00046.safetensors": [111, 112, 113, 114],
"model-00042-of-00046.safetensors": [114, 115, 116, 117],
"model-00043-of-00046.safetensors": [117, 118, 119],
"model-00044-of-00046.safetensors": [120, 121, 122],
"model-00045-of-00046.safetensors": [122, 123, 124, 125],
"model-00046-of-00046.safetensors": [125]
},
"mlx-community/Mistral-Nemo-Instruct-2407-4bit": {
"model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
"model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39],
},
"mlx-community/Mistral-Large-Instruct-2407-4bit": {
"model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6],
"model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13],
"model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20],
"model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26],
"model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33],
"model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40],
"model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47],
"model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54],
"model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60],
"model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67],
"model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74],
"model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81],
"model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87],
"model-00014-of-00014.safetensors": [87]
},
"llava-hf/llava-1.5-7b-hf": {
"model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22],
"model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
}
}

def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None):
return ["*.safetensors"] # TODO: enable this
if not shard:
return ["*.safetensors"]

allow_patterns = []
for repo_id, safetensors_layers in repo_id_safetensors_layers.items():
if repo_id == shard.model_id:
for safetensor, layers in safetensors_layers.items():
if any(shard.start_layer <= layer <= shard.end_layer for layer in layers):
allow_patterns.append(safetensor)

return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"]

async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
@@ -209,12 +329,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, o
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
] + get_safetensors_allow_patterns(path_or_hf_repo, shard),
on_progress=on_download_progress,
)
)
@@ -259,7 +378,7 @@ async def load_shard(
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress)
model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress)

model = load_model_shard(model_path, shard, lazy, model_config)
if adapter_path is not None:
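For orientation, here is a minimal standalone sketch (not part of the commit) of the selection logic that get_safetensors_allow_patterns implements against the layer table above. The Shard fields used (model_id, start_layer, end_layer) are inferred from how the diff reads them, and note that in the committed code the early `return ["*.safetensors"]  # TODO: enable this` still disables the selective path.

# Hypothetical sketch of the layer-range -> safetensors-file selection; not the project's API.
# Only the 70B entry from the table above is reproduced here.
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Shard:  # assumed minimal shape, mirroring shard.model_id / start_layer / end_layer in the diff
    model_id: str
    start_layer: int
    end_layer: int

LAYER_TABLE: Dict[str, Dict[str, List[int]]] = {
    "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": {
        "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
        "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
        # ... remaining entries as listed in the commit ...
    },
}

def allow_patterns_for(shard: Shard) -> List[str]:
    # Keep only files whose layers overlap [start_layer, end_layer]; otherwise fall back to everything.
    table = LAYER_TABLE.get(shard.model_id)
    if not table:
        return ["*.safetensors"]
    patterns = [
        filename
        for filename, layers in table.items()
        if any(shard.start_layer <= layer <= shard.end_layer for layer in layers)
    ]
    return patterns or ["*.safetensors"]

# A node serving layers 0-20 of the 70B model needs only the first three shard files:
print(allow_patterns_for(Shard("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", 0, 20)))
# -> ['model-00001-of-00008.safetensors', 'model-00002-of-00008.safetensors', 'model-00003-of-00008.safetensors']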
2 changes: 1 addition & 1 deletion exo/networking/grpc/grpc_server.py
@@ -48,7 +48,7 @@ async def SendPrompt(self, request, context):
image_str = request.image_str
request_id = request.request_id
result = await self.node.process_prompt(shard, prompt, image_str, request_id)
if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image=} {request_id=} result: {result}")
if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

4 changes: 2 additions & 2 deletions exo/orchestration/standard_node.py
@@ -130,7 +130,7 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optio
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
if shard.start_layer != 0:
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
await self.forward_to_next_shard(shard, prompt, request_id, image_str)
await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
return

result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
@@ -146,7 +146,7 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optio
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")

if not is_finished:
asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))

return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
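A hedged caller-side sketch of the progress plumbing added above: load_shard forwards on_download_progress to get_model_path, which hands it to download_async_with_progress as on_progress. The example calls download_async_with_progress directly with the signature shown in the diff; the assumption that the callback receives (current_bytes, total_bytes) follows from the sizes computed in monitor_progress, whose exact call site is collapsed above.

# Hypothetical usage sketch; the callback argument order is an assumption, not confirmed by the diff.
import asyncio

from exo.inference.mlx.sharded_utils import download_async_with_progress  # the module changed in this commit

def report_progress(current_bytes: int, total_bytes: int) -> None:
    # Invoked periodically while monitor_progress polls the download directory.
    pct = current_bytes / total_bytes * 100 if total_bytes else 0.0
    print(f"download: {pct:.1f}% ({current_bytes}/{total_bytes} bytes)")

async def main() -> None:
    # The 8B repo in the table above ships a single model.safetensors file.
    path = await download_async_with_progress(
        "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        allow_patterns=["*.json", "tokenizer.model", "model.safetensors"],
        on_progress=report_progress,
    )
    print("snapshot downloaded to", path)  # local snapshot path from snapshot_download, per download_repo above

asyncio.run(main())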

