diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py index bca459dec..316e5cd08 100644 --- a/exo/inference/mlx/sharded_utils.py +++ b/exo/inference/mlx/sharded_utils.py @@ -13,6 +13,7 @@ from io import BytesIO import base64 import os +import concurrent.futures from exo import DEBUG import mlx.core as mx @@ -120,7 +121,18 @@ def load_model_shard( raise FileNotFoundError(f"No safetensors found in {model_path}") weights = {} - for wf in weight_files: + for wf in sorted(weight_files): + if DEBUG >= 8: + layer_nums = set() + for k in mx.load(wf): + if k.startswith("model.layers."): + layer_num = int(k.split(".")[2]) + layer_nums.add(layer_num) + if k.startswith("language_model.model.layers."): + layer_num = int(k.split(".")[3]) + layer_nums.add(layer_num) + print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},") + weights.update(mx.load(wf)) model_class, model_args_class = _get_classes(config=config) @@ -150,14 +162,15 @@ def load_model_shard( async def get_repo_size(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None): it = await asyncio.to_thread(list_repo_tree, repo_id, revision=revision, repo_type=repo_type) files = list(filter_repo_objects(it, allow_patterns=allow_patterns, key=lambda f: f.path)) - return sum(file.size for file in files if file.size is not None) + return sum(file.size for file in files if hasattr(file, "size") and file.size is not None) async def monitor_progress(dir, total_size, print_progress=False, on_progress: Callable[[int, int], None] = None): while True: + try: await asyncio.sleep(0.1) current_size = sum(os.path.getsize(os.path.join(root, file)) - for root, _, files in os.walk(dir) - for file in files) + for root, _, files in os.walk(dir) + for file in files) progress = min(current_size / total_size * 100, 100) if print_progress: print(f"\rProgress: {progress:.2f}% ({current_size}/{total_size} bytes)", end="", flush=True) @@ -167,10 +180,15 @@ async def monitor_progress(dir, total_size, print_progress=False, on_progress: C if print_progress: print("\nDownload complete!") break + except Exception as e: + print(f"Error monitoring progress: {e}") async def download_repo(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None): - # Use snapshot_download in a separate thread to not block the event loop - return await asyncio.to_thread(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type) + with concurrent.futures.ThreadPoolExecutor() as pool: + return await asyncio.get_event_loop().run_in_executor( + pool, + partial(snapshot_download, repo_id=repo_id, revision=revision, allow_patterns=allow_patterns, repo_type=repo_type) + ) async def download_async_with_progress(repo_id: str, revision: Optional[str] = None, allow_patterns: Optional[Union[List[str], str]] = None, repo_type: Optional[str] = None, on_progress: Callable[[int, int], None] = None): storage_folder = os.path.join(HF_HUB_CACHE, repo_folder_name(repo_id=repo_id, repo_type="model")) @@ -184,11 +202,113 @@ async def download_async_with_progress(repo_id: str, revision: Optional[str] = N progress_task = asyncio.create_task(monitor_progress(storage_folder, total_size, on_progress=on_progress)) # Wait for both tasks to complete - result = await asyncio.gather(download_task, progress_task) + result = await asyncio.gather(download_task, progress_task, return_exceptions=True) return result[0] # Return the result from download_task +repo_id_safetensors_layers = { + "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit": { + "model.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + }, + "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit": { + "model-00001-of-00008.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "model-00002-of-00008.safetensors": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + "model-00003-of-00008.safetensors": [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "model-00004-of-00008.safetensors": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42], + "model-00005-of-00008.safetensors": [42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53], + "model-00006-of-00008.safetensors": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64], + "model-00007-of-00008.safetensors": [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75], + "model-00008-of-00008.safetensors": [75, 76, 77, 78, 79], + }, + "mlx-community/Meta-Llama-3.1-405B-Instruct-4bit": { + "model-00001-of-00046.safetensors": [0, 1, 2], + "model-00002-of-00046.safetensors": [2, 3, 4, 5], + "model-00003-of-00046.safetensors": [5, 6, 7], + "model-00004-of-00046.safetensors": [8, 9, 10], + "model-00005-of-00046.safetensors": [10, 11, 12, 13], + "model-00006-of-00046.safetensors": [13, 14, 15, 16], + "model-00007-of-00046.safetensors": [16, 17, 18, 19], + "model-00008-of-00046.safetensors": [19, 20, 21], + "model-00009-of-00046.safetensors": [22, 23, 24], + "model-00010-of-00046.safetensors": [24, 25, 26, 27], + "model-00011-of-00046.safetensors": [27, 28, 29, 30], + "model-00012-of-00046.safetensors": [30, 31, 32, 33], + "model-00013-of-00046.safetensors": [33, 34, 35], + "model-00014-of-00046.safetensors": [36, 37, 38], + "model-00015-of-00046.safetensors": [38, 39, 40, 41], + "model-00016-of-00046.safetensors": [41, 42, 43, 44], + "model-00017-of-00046.safetensors": [44, 45, 46, 47], + "model-00018-of-00046.safetensors": [47, 48, 49], + "model-00019-of-00046.safetensors": [50, 51, 52], + "model-00020-of-00046.safetensors": [52, 53, 54, 55], + "model-00021-of-00046.safetensors": [55, 56, 57, 58], + "model-00022-of-00046.safetensors": [58, 59, 60, 61], + "model-00023-of-00046.safetensors": [61, 62, 63], + "model-00024-of-00046.safetensors": [64, 65, 66], + "model-00025-of-00046.safetensors": [66, 67, 68, 69], + "model-00026-of-00046.safetensors": [69, 70, 71, 72], + "model-00027-of-00046.safetensors": [72, 73, 74, 75], + "model-00028-of-00046.safetensors": [75, 76, 77], + "model-00029-of-00046.safetensors": [78, 79, 80], + "model-00030-of-00046.safetensors": [80, 81, 82, 83], + "model-00031-of-00046.safetensors": [83, 84, 85, 86], + "model-00032-of-00046.safetensors": [86, 87, 88, 89], + "model-00033-of-00046.safetensors": [89, 90, 91], + "model-00034-of-00046.safetensors": [92, 93, 94], + "model-00035-of-00046.safetensors": [94, 95, 96, 97], + "model-00036-of-00046.safetensors": [97, 98, 99, 100], + "model-00037-of-00046.safetensors": [100, 101, 102, 103], + "model-00038-of-00046.safetensors": [103, 104, 105], + "model-00039-of-00046.safetensors": [106, 107, 108], + "model-00040-of-00046.safetensors": [108, 109, 110, 111], + "model-00041-of-00046.safetensors": [111, 112, 113, 114], + "model-00042-of-00046.safetensors": [114, 115, 116, 117], + "model-00043-of-00046.safetensors": [117, 118, 119], + "model-00044-of-00046.safetensors": [120, 121, 122], + "model-00045-of-00046.safetensors": [122, 123, 124, 125], + "model-00046-of-00046.safetensors": [125] + }, + "mlx-community/Mistral-Nemo-Instruct-2407-4bit": { + "model-00001-of-00002.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32], + "model-00002-of-00002.safetensors": [32, 33, 34, 35, 36, 37, 38, 39], + }, + "mlx-community/Mistral-Large-Instruct-2407-4bit": { + "model-00001-of-00014.safetensors": [0, 1, 2, 3, 4, 5, 6], + "model-00002-of-00014.safetensors": [6, 7, 8, 9, 10, 11, 12, 13], + "model-00003-of-00014.safetensors": [13, 14, 15, 16, 17, 18, 19, 20], + "model-00004-of-00014.safetensors": [20, 21, 22, 23, 24, 25, 26], + "model-00005-of-00014.safetensors": [27, 28, 29, 30, 31, 32, 33], + "model-00006-of-00014.safetensors": [33, 34, 35, 36, 37, 38, 39, 40], + "model-00007-of-00014.safetensors": [40, 41, 42, 43, 44, 45, 46, 47], + "model-00008-of-00014.safetensors": [47, 48, 49, 50, 51, 52, 53, 54], + "model-00009-of-00014.safetensors": [54, 55, 56, 57, 58, 59, 60], + "model-00010-of-00014.safetensors": [61, 62, 63, 64, 65, 66, 67], + "model-00011-of-00014.safetensors": [67, 68, 69, 70, 71, 72, 73, 74], + "model-00012-of-00014.safetensors": [74, 75, 76, 77, 78, 79, 80, 81], + "model-00013-of-00014.safetensors": [81, 82, 83, 84, 85, 86, 87], + "model-00014-of-00014.safetensors": [87] + }, + "llava-hf/llava-1.5-7b-hf": { + "model-00001-of-00003.safetensors": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "model-00002-of-00003.safetensors": [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], + "model-00003-of-00003.safetensors": [22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + } +} + +def get_safetensors_allow_patterns(repo_id: str, shard: Optional[Shard] = None): + return ["*.safetensors"] # TODO: enable this + if not shard: + return ["*.safetensors"] + + allow_patterns = [] + for repo_id, safetensors_layers in repo_id_safetensors_layers.items(): + if repo_id == shard.model_id: + for safetensor, layers in safetensors_layers.items(): + if any(shard.start_layer <= layer <= shard.end_layer for layer in layers): + allow_patterns.append(safetensor) + + return allow_patterns if len(allow_patterns) > 0 else ["*.safetensors"] -async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path: +async def get_model_path(path_or_hf_repo: str, shard: Optional[Shard] = None, revision: Optional[str] = None, on_download_progress: Callable[[int, int], None] = None) -> Path: """ Ensures the model is available locally. If the path does not exist locally, it is downloaded from the Hugging Face Hub. @@ -209,12 +329,11 @@ async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None, o revision=revision, allow_patterns=[ "*.json", - "*.safetensors", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", - ], + ] + get_safetensors_allow_patterns(path_or_hf_repo, shard), on_progress=on_download_progress, ) ) @@ -259,7 +378,7 @@ async def load_shard( FileNotFoundError: If config file or safetensors are not found. ValueError: If model class or args class are not found. """ - model_path = await get_model_path(path_or_hf_repo, on_download_progress=on_download_progress) + model_path = await get_model_path(path_or_hf_repo, shard, on_download_progress=on_download_progress) model = load_model_shard(model_path, shard, lazy, model_config) if adapter_path is not None: diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index 5992ce1f4..cd4c6e2a9 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -48,7 +48,7 @@ async def SendPrompt(self, request, context): image_str = request.image_str request_id = request.request_id result = await self.node.process_prompt(shard, prompt, image_str, request_id) - if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image=} {request_id=} result: {result}") + if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}") tensor_data = result.tobytes() if result is not None else None return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 242f6d113..c2d9e028f 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -130,7 +130,7 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optio if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}") if shard.start_layer != 0: if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}") - await self.forward_to_next_shard(shard, prompt, request_id, image_str) + await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state) return result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state) @@ -146,7 +146,7 @@ async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optio if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") if not is_finished: - asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state)) + asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state)) return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None