diff --git a/.circleci/config.yml b/.circleci/config.yml index 2554847b..f09985b8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -250,6 +250,34 @@ jobs: workflows: version: 2 build_and_test: + jobs: + - approve_run: + type: approval + requires: [] + filters: + branches: + ignore: main + - unit_test: + requires: + - approve_run + - discovery_integration_test: + requires: + - approve_run + - chatgpt_api_integration_test_mlx: + requires: + - approve_run + - test_macos_m1: + requires: + - approve_run + - chatgpt_api_integration_test_torch_linux_cpu: + requires: + - approve_run + - chatgpt_api_integration_test_torch_mac: + requires: + - approve_run + + # Workflow for forked PRs without approval + forked_pr_workflow: jobs: - unit_test - discovery_integration_test diff --git a/.gitignore b/.gitignore index 93227e3c..cb98c151 100644 --- a/.gitignore +++ b/.gitignore @@ -169,5 +169,14 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# XCode **/*.xcodeproj/* + +# Aider .aider* + +# PyTorch interface +.offload + +# neovim/vim settings +.vimrc \ No newline at end of file diff --git a/build/lib/exo/__init__.py b/build/lib/exo/__init__.py new file mode 100644 index 00000000..e802d331 --- /dev/null +++ b/build/lib/exo/__init__.py @@ -0,0 +1 @@ +from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION diff --git a/build/lib/exo/api/__init__.py b/build/lib/exo/api/__init__.py new file mode 100644 index 00000000..660e7507 --- /dev/null +++ b/build/lib/exo/api/__init__.py @@ -0,0 +1 @@ +from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI diff --git a/build/lib/exo/api/chatgpt_api.py b/build/lib/exo/api/chatgpt_api.py new file mode 100644 index 00000000..1abda85f --- /dev/null +++ b/build/lib/exo/api/chatgpt_api.py @@ -0,0 +1,358 @@ +import uuid +import time +import asyncio +import json +from pathlib import Path +from transformers import AutoTokenizer +from typing import List, Literal, Union, Dict +from aiohttp import web +import aiohttp_cors +import traceback +from exo import DEBUG, VERSION +from exo.helpers import PrefixDict +from exo.inference.shard import Shard +from exo.inference.tokenizers import resolve_tokenizer +from exo.orchestration import Node +from exo.models import model_base_shards +from typing import Callable + +class Message: + def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]): + self.role = role + self.content = content + + def to_dict(self): + return {"role": self.role, "content": self.content} + + +class ChatCompletionRequest: + def __init__(self, model: str, messages: List[Message], temperature: float): + self.model = model + self.messages = messages + self.temperature = temperature + + def to_dict(self): + return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature} + + +def generate_completion( + chat_request: ChatCompletionRequest, + tokenizer, + prompt: str, + request_id: str, + tokens: List[int], + stream: bool, + finish_reason: Union[Literal["length", "stop"], None], + object_type: Literal["chat.completion", "text_completion"], +) -> dict: + completion = { + "id": f"chatcmpl-{request_id}", + "object": object_type, + "created": int(time.time()), + "model": chat_request.model, + "system_fingerprint": f"exo_{VERSION}", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": tokenizer.decode(tokens)}, + "logprobs": None, + 
"finish_reason": finish_reason, + }], + } + + if not stream: + completion["usage"] = { + "prompt_tokens": len(tokenizer.encode(prompt)), + "completion_tokens": len(tokens), + "total_tokens": len(tokenizer.encode(prompt)) + len(tokens), + } + + choice = completion["choices"][0] + if object_type.startswith("chat.completion"): + key_name = "delta" if stream else "message" + choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)} + elif object_type == "text_completion": + choice["text"] = tokenizer.decode(tokens) + else: + ValueError(f"Unsupported response type: {object_type}") + + return completion + + +def remap_messages(messages: List[Message]) -> List[Message]: + remapped_messages = [] + last_image = None + for message in messages: + if not isinstance(message.content, list): + remapped_messages.append(message) + continue + + remapped_content = [] + for content in message.content: + if isinstance(content, dict): + if content.get("type") in ["image_url", "image"]: + image_url = content.get("image_url", {}).get("url") or content.get("image") + if image_url: + last_image = {"type": "image", "image": image_url} + remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"}) + else: + remapped_content.append(content) + else: + remapped_content.append(content) + remapped_messages.append(Message(role=message.role, content=remapped_content)) + + if last_image: + # Replace the last image placeholder with the actual image content + for message in reversed(remapped_messages): + for i, content in enumerate(message.content): + if isinstance(content, dict): + if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]": + message.content[i] = last_image + return remapped_messages + + return remapped_messages + + +def build_prompt(tokenizer, _messages: List[Message]): + if len(_messages) == 1: + user_msg = _messages[0] + + # get instruct sys message + sys_msg = Message(role="system", content="You are a helpful assistant.") + + # restructure for sys_msg to go first + _messages = [sys_msg, user_msg] + + messages = remap_messages(_messages) + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + if DEBUG >= 3: + print(f"prompt: {str(prompt)}") + for msg in messages: + print(f"chat role: {msg.role}\ncontent: {msg.content}") + + image_str = None + for message in messages: + if not isinstance(message.content, list): + continue + + for content in message.content: + # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 + # follows the convention in https://platform.openai.com/docs/guides/vision + if isinstance(content, dict) and content.get("type", None) == "image": + image_str = content.get("image", None) + break + + return prompt, image_str + + +def parse_message(data: dict): + if "role" not in data or "content" not in data: + raise ValueError(f"Invalid message: {data}. 
Must have 'role' and 'content'") + return Message(data["role"], data["content"]) + + +def parse_chat_request(data: dict): + return ChatCompletionRequest( + data.get("model", "llama-3.1-8b"), + [parse_message(msg) for msg in data["messages"]], + data.get("temperature", 0.0), + ) + + +class PromptSession: + def __init__(self, request_id: str, timestamp: int, prompt: str): + self.request_id = request_id + self.timestamp = timestamp + self.prompt = prompt + + +class ChatGPTAPI: + def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None): + self.node = node + self.inference_engine_classname = inference_engine_classname + self.response_timeout_secs = response_timeout_secs + self.on_chat_completion_request = on_chat_completion_request + self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload + self.prompts: PrefixDict[str, PromptSession] = PrefixDict() + self.prev_token_lens: Dict[str, int] = {} + self.stream_tasks: Dict[str, asyncio.Task] = {} + cors = aiohttp_cors.setup(self.app) + cors_options = aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + allow_methods="*", + ) + cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options}) + cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options}) + cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}) + cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}) + cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) + cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) + + self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat" + self.app.router.add_get("/", self.handle_root) + self.app.router.add_static("/", self.static_dir, name="static") + + # Add middleware to log every request + self.app.middlewares.append(self.log_request) + + async def log_request(self, app, handler): + async def middleware(request): + if DEBUG >= 2: print(f"Received request: {request.method} {request.path}") + return await handler(request) + + return middleware + + async def handle_root(self, request): + return web.FileResponse(self.static_dir/"index.html") + + async def handle_get_models(self, request): + return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()]) + + async def handle_post_chat_token_encode(self, request): + data = await request.json() + shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname) + messages = [parse_message(msg) for msg in data.get("messages", [])] + tokenizer = await resolve_tokenizer(shard.model_id) + return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])}) + + async def handle_post_chat_completions(self, request): + data = await request.json() + if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}") + stream = data.get("stream", False) + chat_request = parse_chat_request(data) + if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama 
instead + chat_request.model = "llama-3.1-8b" + if not chat_request.model or chat_request.model not in model_base_shards: + if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b") + chat_request.model = "llama-3.1-8b" + shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None) + if not shard: + supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines] + return web.json_response( + {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, + status=400, + ) + + tokenizer = await resolve_tokenizer(shard.model_id) + if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}") + + prompt, image_str = build_prompt(tokenizer, chat_request.messages) + request_id = str(uuid.uuid4()) + if self.on_chat_completion_request: + try: + self.on_chat_completion_request(request_id, chat_request, prompt) + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + # request_id = None + # match = self.prompts.find_longest_prefix(prompt) + # if match and len(prompt) > len(match[1].prompt): + # if DEBUG >= 2: + # print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}") + # request_id = match[1].request_id + # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) + # # remove the matching prefix from the prompt + # prompt = prompt[len(match[1].prompt):] + # else: + # request_id = str(uuid.uuid4()) + # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) + + callback_id = f"chatgpt-api-wait-response-{request_id}" + callback = self.node.on_token.register(callback_id) + + if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}") + try: + await self.node.process_prompt(shard, prompt, image_str, request_id=request_id) + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) + + try: + if DEBUG >= 2: print(f"Waiting for response to finish. 
timeout={self.response_timeout_secs}s") + + if stream: + response = web.StreamResponse( + status=200, + reason="OK", + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + }, + ) + await response.prepare(request) + + async def stream_result(request_id: str, tokens: List[int], is_finished: bool): + prev_last_tokens_len = self.prev_token_lens.get(request_id, 0) + self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens)) + new_tokens = tokens[prev_last_tokens_len:] + finish_reason = None + eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, + AutoTokenizer) else getattr(tokenizer, "eos_token_id", None) + if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id: + new_tokens = new_tokens[:-1] + if is_finished: + finish_reason = "stop" + if is_finished and not finish_reason: + finish_reason = "length" + + completion = generate_completion( + chat_request, + tokenizer, + prompt, + request_id, + new_tokens, + stream, + finish_reason, + "chat.completion", + ) + if DEBUG >= 2: print(f"Streaming completion: {completion}") + try: + await response.write(f"data: {json.dumps(completion)}\n\n".encode()) + except Exception as e: + if DEBUG >= 2: print(f"Error streaming completion: {e}") + if DEBUG >= 2: traceback.print_exc() + + def on_result(_request_id: str, tokens: List[int], is_finished: bool): + self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished)) + + return _request_id == request_id and is_finished + + _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs) + if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete + if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.") + try: + await asyncio.wait_for(self.stream_tasks[request_id], timeout=30) + except asyncio.TimeoutError: + print("WARNING: Stream task timed out. 
This should not happen.") + await response.write_eof() + return response + else: + _, tokens, _ = await callback.wait( + lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, + timeout=self.response_timeout_secs, + ) + + finish_reason = "length" + eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id + if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}") + if tokens[-1] == eos_token_id: + tokens = tokens[:-1] + finish_reason = "stop" + + return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion")) + except asyncio.TimeoutError: + return web.json_response({"detail": "Response generation timed out"}, status=408) + finally: + deregistered_callback = self.node.on_token.deregister(callback_id) + if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}") + + async def run(self, host: str = "0.0.0.0", port: int = 8000): + runner = web.AppRunner(self.app) + await runner.setup() + site = web.TCPSite(runner, host, port) + await site.start() diff --git a/build/lib/exo/download/__init__.py b/build/lib/exo/download/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/download/download_progress.py b/build/lib/exo/download/download_progress.py new file mode 100644 index 00000000..779e5328 --- /dev/null +++ b/build/lib/exo/download/download_progress.py @@ -0,0 +1,61 @@ +from typing import Dict, Callable, Coroutine, Any, Literal +from dataclasses import dataclass +from datetime import timedelta + + +@dataclass +class RepoFileProgressEvent: + repo_id: str + repo_revision: str + file_path: str + downloaded: int + downloaded_this_session: int + total: int + speed: int + eta: timedelta + status: Literal["not_started", "in_progress", "complete"] + + def to_dict(self): + return { + "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session, + "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status + } + + @classmethod + def from_dict(cls, data): + if 'eta' in data: data['eta'] = timedelta(seconds=data['eta']) + return cls(**data) + + +@dataclass +class RepoProgressEvent: + repo_id: str + repo_revision: str + completed_files: int + total_files: int + downloaded_bytes: int + downloaded_bytes_this_session: int + total_bytes: int + overall_speed: int + overall_eta: timedelta + file_progress: Dict[str, RepoFileProgressEvent] + status: Literal["not_started", "in_progress", "complete"] + + def to_dict(self): + return { + "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes, + "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(), + "file_progress": {k: v.to_dict() + for k, v in self.file_progress.items()}, "status": self.status + } + + @classmethod + def from_dict(cls, data): + if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta']) + if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()} + + return 
cls(**data) + + +RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]] +RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]] diff --git a/build/lib/exo/download/hf/__init__.py b/build/lib/exo/download/hf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/download/hf/hf_helpers.py b/build/lib/exo/download/hf/hf_helpers.py new file mode 100644 index 00000000..8fd96dc5 --- /dev/null +++ b/build/lib/exo/download/hf/hf_helpers.py @@ -0,0 +1,403 @@ +import asyncio +import aiohttp +import json +import os +from urllib.parse import urljoin +from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal +from datetime import datetime, timedelta +from fnmatch import fnmatch +from pathlib import Path +from typing import Generator, Iterable, TypeVar, TypedDict +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type +from exo.helpers import DEBUG +from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback +from exo.inference.shard import Shard +import aiofiles +from aiofiles import os as aios + +T = TypeVar("T") + +async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]: + refs_dir = get_repo_root(repo_id)/"refs" + refs_file = refs_dir/revision + if await aios.path.exists(refs_file): + async with aiofiles.open(refs_file, 'r') as f: + commit_hash = (await f.read()).strip() + snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash + return snapshot_dir + return None + + +def filter_repo_objects( + items: Iterable[T], + *, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + key: Optional[Callable[[T], str]] = None, +) -> Generator[T, None, None]: + if isinstance(allow_patterns, str): + allow_patterns = [allow_patterns] + if isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + if allow_patterns is not None: + allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns] + if ignore_patterns is not None: + ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] + + if key is None: + + def _identity(item: T) -> str: + if isinstance(item, str): + return item + if isinstance(item, Path): + return str(item) + raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") + + key = _identity + + for item in items: + path = key(item) + if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): + continue + if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): + continue + yield item + + +def _add_wildcard_to_directories(pattern: str) -> str: + if pattern[-1] == "/": + return pattern + "*" + return pattern + + +def get_hf_home() -> Path: + """Get the Hugging Face home directory.""" + return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface")) + + +async def get_hf_token(): + """Retrieve the Hugging Face token from the user's HF_HOME directory.""" + token_path = get_hf_home()/"token" + if await aios.path.exists(token_path): + async with aiofiles.open(token_path, 'r') as f: + return (await f.read()).strip() + return None + + +async def get_auth_headers(): + """Get authentication headers if a token is available.""" + token = await get_hf_token() + if token: + return {"Authorization": f"Bearer {token}"} + return {} + + +def 
get_repo_root(repo_id: str) -> Path: + """Get the root directory for a given repo ID in the Hugging Face cache.""" + sanitized_repo_id = repo_id.replace("/", "--") + return get_hf_home()/"hub"/f"models--{sanitized_repo_id}" + + +async def fetch_file_list(session, repo_id, revision, path=""): + api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}" + url = f"{api_url}/{path}" if path else api_url + + headers = await get_auth_headers() + async with session.get(url, headers=headers) as response: + if response.status == 200: + data = await response.json() + files = [] + for item in data: + if item["type"] == "file": + files.append({"path": item["path"], "size": item["size"]}) + elif item["type"] == "directory": + subfiles = await fetch_file_list(session, repo_id, revision, item["path"]) + files.extend(subfiles) + return files + else: + raise Exception(f"Failed to fetch file list: {response.status}") + + +@retry( + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True +) +async def download_file( + session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True +): + base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/" + url = urljoin(base_url, file_path) + local_path = os.path.join(save_directory, file_path) + + await aios.makedirs(os.path.dirname(local_path), exist_ok=True) + + # Check if file already exists and get its size + local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0 + + headers = await get_auth_headers() + if use_range_request: + headers["Range"] = f"bytes={local_file_size}-" + + async with session.get(url, headers=headers) as response: + total_size = int(response.headers.get('Content-Length', 0)) + downloaded_size = local_file_size + downloaded_this_session = 0 + mode = 'ab' if use_range_request else 'wb' + if downloaded_size == total_size: + if DEBUG >= 2: print(f"File already downloaded: {file_path}") + if progress_callback: + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) + return + + if response.status == 200: + # File doesn't support range requests or we're not using them, start from beginning + mode = 'wb' + downloaded_size = 0 + elif response.status == 206: + # Partial content, resume download + content_range = response.headers.get('Content-Range', '') + try: + total_size = int(content_range.split('/')[-1]) + except ValueError: + if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. 
Starting download from scratch...") + return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) + elif response.status == 416: + # Range not satisfiable, get the actual file size + content_range = response.headers.get('Content-Range', '') + try: + total_size = int(content_range.split('/')[-1]) + if downloaded_size == total_size: + if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}") + if progress_callback: + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) + return + except ValueError: + if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") + return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) + else: + raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}") + + if downloaded_size == total_size: + print(f"File already downloaded: {file_path}") + if progress_callback: + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) + return + + DOWNLOAD_CHUNK_SIZE = 32768 + start_time = datetime.now() + async with aiofiles.open(local_path, mode) as f: + async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): + await f.write(chunk) + downloaded_size += len(chunk) + downloaded_this_session += len(chunk) + if progress_callback and total_size: + elapsed_time = (datetime.now() - start_time).total_seconds() + speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0 + remaining_size = total_size - downloaded_size + eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0) + status = "in_progress" if downloaded_size < total_size else "complete" + if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}") + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status)) + if DEBUG >= 2: print(f"Downloaded: {file_path}") + + +async def download_repo_files( + repo_id: str, + revision: str = "main", + progress_callback: Optional[RepoProgressCallback] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + max_parallel_downloads: int = 4 +) -> Path: + repo_root = get_repo_root(repo_id) + refs_dir = repo_root/"refs" + snapshots_dir = repo_root/"snapshots" + cachedreqs_dir = repo_root/"cachedreqs" + + # Ensure directories exist + await aios.makedirs(refs_dir, exist_ok=True) + await aios.makedirs(snapshots_dir, exist_ok=True) + await aios.makedirs(cachedreqs_dir, exist_ok=True) + + # Check if we have a cached commit hash + refs_file = refs_dir/revision + if await aios.path.exists(refs_file): + async with aiofiles.open(refs_file, 'r') as f: + commit_hash = (await f.read()).strip() + if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}") + else: + async with aiohttp.ClientSession() as session: + # Fetch the commit hash for the given revision + api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}" + headers = 
await get_auth_headers() + async with session.get(api_url, headers=headers) as response: + if response.status != 200: + raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}") + revision_info = await response.json() + commit_hash = revision_info['sha'] + + # Cache the commit hash + async with aiofiles.open(refs_file, 'w') as f: + await f.write(commit_hash) + + # Set up the snapshot directory + snapshot_dir = snapshots_dir/commit_hash + await aios.makedirs(snapshot_dir, exist_ok=True) + + # Set up the cached file list directory + cached_file_list_dir = cachedreqs_dir/commit_hash + await aios.makedirs(cached_file_list_dir, exist_ok=True) + cached_file_list_path = cached_file_list_dir/"fetch_file_list.json" + + async with aiohttp.ClientSession() as session: + # Check if we have a cached file list + if await aios.path.exists(cached_file_list_path): + async with aiofiles.open(cached_file_list_path, 'r') as f: + file_list = json.loads(await f.read()) + if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}") + else: + file_list = await fetch_file_list(session, repo_id, revision) + # Cache the file list + async with aiofiles.open(cached_file_list_path, 'w') as f: + await f.write(json.dumps(file_list)) + if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}") + + filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"])) + total_files = len(filtered_file_list) + total_bytes = sum(file["size"] for file in filtered_file_list) + file_progress: Dict[str, RepoFileProgressEvent] = { + file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") + for file in filtered_file_list + } + start_time = datetime.now() + + async def download_with_progress(file_info, progress_state): + local_path = snapshot_dir/file_info["path"] + if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]: + if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}") + progress_state['completed_files'] += 1 + progress_state['downloaded_bytes'] += file_info["size"] + file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete") + if progress_callback: + elapsed_time = (datetime.now() - start_time).total_seconds() + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 + remaining_bytes = total_bytes - progress_state['downloaded_bytes'] + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) + status = "in_progress" if progress_state['completed_files'] < total_files else "complete" + await progress_callback( + RepoProgressEvent( + repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, + overall_eta, file_progress, status + ) + ) + return + + async def file_progress_callback(event: RepoFileProgressEvent): + progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded + progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session + file_progress[event.file_path] = event + if progress_callback: + elapsed_time = (datetime.now() 
- start_time).total_seconds() + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 + remaining_bytes = total_bytes - progress_state['downloaded_bytes'] + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) + status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete" + await progress_callback( + RepoProgressEvent( + repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, + overall_eta, file_progress, status + ) + ) + + await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback) + progress_state['completed_files'] += 1 + file_progress[ + file_info["path"] + ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete") + if progress_callback: + elapsed_time = (datetime.now() - start_time).total_seconds() + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 + remaining_bytes = total_bytes - progress_state['downloaded_bytes'] + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) + status = "in_progress" if progress_state['completed_files'] < total_files else "complete" + await progress_callback( + RepoProgressEvent( + repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, + overall_eta, file_progress, status + ) + ) + + progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0} + + semaphore = asyncio.Semaphore(max_parallel_downloads) + + async def download_with_semaphore(file_info): + async with semaphore: + await download_with_progress(file_info, progress_state) + + tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list] + await asyncio.gather(*tasks) + + return snapshot_dir + + +async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]: + """ + Retrieve the weight map from the model.safetensors.index.json file. + + Args: + repo_id (str): The Hugging Face repository ID. + revision (str): The revision of the repository to use. + + Returns: + Optional[Dict[str, str]]: The weight map if it exists, otherwise None. 
+ """ + + # Download the index file + await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json") + + # Check if the file exists + repo_root = get_repo_root(repo_id) + snapshot_dir = repo_root/"snapshots" + index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None) + + if index_file: + index_file_path = snapshot_dir/index_file + if await aios.path.exists(index_file_path): + async with aiofiles.open(index_file_path, 'r') as f: + index_data = json.loads(await f.read()) + return index_data.get("weight_map") + + return None + + +def extract_layer_num(tensor_name: str) -> Optional[int]: + # This is a simple example and might need to be adjusted based on the actual naming convention + parts = tensor_name.split('.') + for part in parts: + if part.isdigit(): + return int(part) + return None + + +def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: + default_patterns = [ + "*.json", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + ] + shard_specific_patterns = [] + if weight_map: + for tensor_name, filename in weight_map.items(): + layer_num = extract_layer_num(tensor_name) + if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: + shard_specific_patterns.append(filename) + sorted_file_names = sorted(weight_map.values()) + if shard.is_first_layer(): + shard_specific_patterns.append(sorted_file_names[0]) + elif shard.is_last_layer(): + shard_specific_patterns.append(sorted_file_names[-1]) + else: + shard_specific_patterns = ["*.safetensors"] + return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates diff --git a/build/lib/exo/download/hf/hf_shard_download.py b/build/lib/exo/download/hf/hf_shard_download.py new file mode 100644 index 00000000..eb562c3c --- /dev/null +++ b/build/lib/exo/download/hf/hf_shard_download.py @@ -0,0 +1,77 @@ +import asyncio +import traceback +from pathlib import Path +from typing import Dict, List, Tuple +from exo.inference.shard import Shard +from exo.download.shard_download import ShardDownloader +from exo.download.download_progress import RepoProgressEvent +from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root +from exo.helpers import AsyncCallbackSystem, DEBUG + + +class HFShardDownloader(ShardDownloader): + def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4): + self.quick_check = quick_check + self.max_parallel_downloads = max_parallel_downloads + self.active_downloads: Dict[Shard, asyncio.Task] = {} + self.completed_downloads: Dict[Shard, Path] = {} + self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() + + async def ensure_shard(self, shard: Shard) -> Path: + if shard in self.completed_downloads: + return self.completed_downloads[shard] + if self.quick_check: + repo_root = get_repo_root(shard.model_id) + snapshots_dir = repo_root/"snapshots" + if snapshots_dir.exists(): + visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')] + if visible_dirs: + most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime) + return most_recent_dir + + # If a download on this shard is already in progress, keep that one + for active_shard in self.active_downloads: + if active_shard == shard: + if DEBUG >= 2: print(f"Download already in progress for {shard}. 
Keeping that one.") + return await self.active_downloads[shard] + + # Cancel any downloads for this model_id on a different shard + existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id] + for active_shard in existing_active_shards: + if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})") + task = self.active_downloads[active_shard] + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # This is expected when cancelling a task + except Exception as e: + if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}") + traceback.print_exc() + self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id} + + # Start new download + download_task = asyncio.create_task(self._download_shard(shard)) + self.active_downloads[shard] = download_task + try: + path = await download_task + self.completed_downloads[shard] = path + return path + finally: + # Ensure the task is removed even if an exception occurs + print(f"Removing download task for {shard}: {shard in self.active_downloads}") + if shard in self.active_downloads: + self.active_downloads.pop(shard) + + async def _download_shard(self, shard: Shard) -> Path: + async def wrapped_progress_callback(event: RepoProgressEvent): + self._on_progress.trigger_all(shard, event) + + weight_map = await get_weight_map(shard.model_id) + allow_patterns = get_allow_patterns(weight_map, shard) + + return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads) + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self._on_progress diff --git a/build/lib/exo/download/shard_download.py b/build/lib/exo/download/shard_download.py new file mode 100644 index 00000000..771fb868 --- /dev/null +++ b/build/lib/exo/download/shard_download.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple +from pathlib import Path +from exo.inference.shard import Shard +from exo.download.download_progress import RepoProgressEvent +from exo.helpers import AsyncCallbackSystem + + +class ShardDownloader(ABC): + @abstractmethod + async def ensure_shard(self, shard: Shard) -> Path: + """ + Ensures that the shard is downloaded. + Does not allow multiple overlapping downloads at once. + If you try to download a Shard which overlaps a Shard that is already being downloaded, + the download will be cancelled and a new download will start. + + Args: + shard (Shard): The shard to download. 
+ """ + pass + + @property + @abstractmethod + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + pass diff --git a/build/lib/exo/helpers.py b/build/lib/exo/helpers.py new file mode 100644 index 00000000..d8a5c6cc --- /dev/null +++ b/build/lib/exo/helpers.py @@ -0,0 +1,234 @@ +import os +import asyncio +from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List +import socket +import random +import platform +import psutil +import uuid +import netifaces +from pathlib import Path + +DEBUG = int(os.getenv("DEBUG", default="0")) +DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0")) +VERSION = "0.0.1" + +exo_text = r""" + _____ _____ + / _ \ \/ / _ \ +| __/> < (_) | + \___/_/\_\___/ + """ + + +def get_system_info(): + if psutil.MACOS: + if platform.machine() == "arm64": + return "Apple Silicon Mac" + if platform.machine() in ["x86_64", "i386"]: + return "Intel Mac" + return "Unknown Mac architecture" + if psutil.LINUX: + return "Linux" + return "Non-Mac, non-Linux system" + +def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int: + used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports") + + def read_used_ports(): + if os.path.exists(used_ports_file): + with open(used_ports_file, "r") as f: + return [int(line.strip()) for line in f if line.strip().isdigit()] + return [] + + def write_used_port(port, used_ports): + with open(used_ports_file, "w") as f: + print(used_ports[-19:]) + for p in used_ports[-19:] + [port]: + f.write(f"{p}\n") + + used_ports = read_used_ports() + available_ports = set(range(min_port, max_port + 1)) - set(used_ports) + + while available_ports: + port = random.choice(list(available_ports)) + if DEBUG >= 2: print(f"Trying to find available port {port=}") + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((host, port)) + write_used_port(port, used_ports) + return port + except socket.error: + available_ports.remove(port) + + raise RuntimeError("No available ports in the specified range") + + +def print_exo(): + print(exo_text) + + +def print_yellow_exo(): + yellow = "\033[93m" # ANSI escape code for yellow + reset = "\033[0m" # ANSI escape code to reset color + print(f"{yellow}{exo_text}{reset}") + + +def terminal_link(uri, label=None): + if label is None: + label = uri + parameters = "" + + # OSC 8 ; params ; URI ST OSC 8 ;; ST + escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\" + + return escape_mask.format(parameters, uri, label) + + +T = TypeVar("T") +K = TypeVar("K") + + +class AsyncCallback(Generic[T]): + def __init__(self) -> None: + self.condition: asyncio.Condition = asyncio.Condition() + self.result: Optional[Tuple[T, ...]] = None + self.observers: list[Callable[..., None]] = [] + + async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]: + async with self.condition: + await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout) + assert self.result is not None # for type checking + return self.result + + def on_next(self, callback: Callable[..., None]) -> None: + self.observers.append(callback) + + def set(self, *args: T) -> None: + self.result = args + for observer in self.observers: + observer(*args) + asyncio.create_task(self.notify()) + + async def notify(self) -> None: + async with self.condition: + self.condition.notify_all() + + +class AsyncCallbackSystem(Generic[K, T]): + def 
__init__(self) -> None: + self.callbacks: Dict[K, AsyncCallback[T]] = {} + + def register(self, name: K) -> AsyncCallback[T]: + if name not in self.callbacks: + self.callbacks[name] = AsyncCallback[T]() + return self.callbacks[name] + + def deregister(self, name: K) -> None: + if name in self.callbacks: + del self.callbacks[name] + + def trigger(self, name: K, *args: T) -> None: + if name in self.callbacks: + self.callbacks[name].set(*args) + + def trigger_all(self, *args: T) -> None: + for callback in self.callbacks.values(): + callback.set(*args) + + +K = TypeVar('K', bound=str) +V = TypeVar('V') + + +class PrefixDict(Generic[K, V]): + def __init__(self): + self.items: Dict[K, V] = {} + + def add(self, key: K, value: V) -> None: + self.items[key] = value + + def find_prefix(self, argument: str) -> List[Tuple[K, V]]: + return [(key, value) for key, value in self.items.items() if argument.startswith(key)] + + def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]: + matches = self.find_prefix(argument) + if len(matches) == 0: + return None + + return max(matches, key=lambda x: len(x[0])) + + +def is_valid_uuid(val): + try: + uuid.UUID(str(val)) + return True + except ValueError: + return False + + +def get_or_create_node_id(): + NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id" + try: + if NODE_ID_FILE.is_file(): + with open(NODE_ID_FILE, "r") as f: + stored_id = f.read().strip() + if is_valid_uuid(stored_id): + if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}") + return stored_id + else: + if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.") + + new_id = str(uuid.uuid4()) + with open(NODE_ID_FILE, "w") as f: + f.write(new_id) + + if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}") + return new_id + except IOError as e: + if DEBUG >= 2: print(f"IO error creating node_id: {e}") + return str(uuid.uuid4()) + except Exception as e: + if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}") + return str(uuid.uuid4()) + + +def pretty_print_bytes(size_in_bytes: int) -> str: + if size_in_bytes < 1024: + return f"{size_in_bytes} B" + elif size_in_bytes < 1024**2: + return f"{size_in_bytes / 1024:.2f} KB" + elif size_in_bytes < 1024**3: + return f"{size_in_bytes / (1024 ** 2):.2f} MB" + elif size_in_bytes < 1024**4: + return f"{size_in_bytes / (1024 ** 3):.2f} GB" + else: + return f"{size_in_bytes / (1024 ** 4):.2f} TB" + + +def pretty_print_bytes_per_second(bytes_per_second: int) -> str: + if bytes_per_second < 1024: + return f"{bytes_per_second} B/s" + elif bytes_per_second < 1024**2: + return f"{bytes_per_second / 1024:.2f} KB/s" + elif bytes_per_second < 1024**3: + return f"{bytes_per_second / (1024 ** 2):.2f} MB/s" + elif bytes_per_second < 1024**4: + return f"{bytes_per_second / (1024 ** 3):.2f} GB/s" + else: + return f"{bytes_per_second / (1024 ** 4):.2f} TB/s" + + +def get_all_ip_addresses(): + try: + ip_addresses = [] + for interface in netifaces.interfaces(): + ifaddresses = netifaces.ifaddresses(interface) + if netifaces.AF_INET in ifaddresses: + for link in ifaddresses[netifaces.AF_INET]: + ip = link['addr'] + ip_addresses.append(ip) + return list(set(ip_addresses)) + except: + if DEBUG >= 1: print("Failed to get all IP addresses. 
Defaulting to localhost.") + return ["localhost"] diff --git a/build/lib/exo/inference/__init__.py b/build/lib/exo/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/debug_inference_engine.py b/build/lib/exo/inference/debug_inference_engine.py new file mode 100644 index 00000000..27bcb592 --- /dev/null +++ b/build/lib/exo/inference/debug_inference_engine.py @@ -0,0 +1,59 @@ +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine +import asyncio +import numpy as np + + +# An inference engine should work the same for any number of Shards, as long as the Shards are continuous. +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): + from exo.inference.tinygrad.inference import Tokenizer + from pathlib import Path + + _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model")) + + prompt = "In a single word only, what is the last name of the president of the United States? " + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), + input_data=resp_full, + inference_state=inference_state_full, + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), + input_data=resp1, + inference_state=inference_state_1, + ) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), + input_data=resp2, + inference_state=inference_state_2, + ) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), + input_data=resp3, + inference_state=inference_state_3, + ) + + print(f"{resp2=}") + print(f"full: {_tokenizer.decode(resp_full)}") + print(f"next full: {_tokenizer.decode(next_resp_full)}") + print(f"resp2: {_tokenizer.decode(resp2)}") + print(f"{resp4=}") + print(f"resp4: {_tokenizer.decode(resp4)}") + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + + +asyncio.run(test_inference_engine( + TinygradDynamicShardInferenceEngine(), + TinygradDynamicShardInferenceEngine(), + "llama3-8b-sfr", +)) diff --git a/build/lib/exo/inference/inference_engine.py b/build/lib/exo/inference/inference_engine.py new file mode 100644 index 00000000..2b98adbe --- /dev/null +++ b/build/lib/exo/inference/inference_engine.py @@ -0,0 +1,34 @@ +import numpy as np +import os + +from typing import Tuple, Optional +from abc import ABC, abstractmethod +from .shard import Shard + + +class InferenceEngine(ABC): + @abstractmethod + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + pass + + @abstractmethod + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, 
inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + pass + + +def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'): + if inference_engine_name == "mlx": + from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine + + return MLXDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "tinygrad": + from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine + import tinygrad.helpers + tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) + + return TinygradDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "pytorch": + from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine + return PyTorchDynamicShardInferenceEngine(shard_downloader) + else: + raise ValueError(f"Inference engine {inference_engine_name} not supported") diff --git a/build/lib/exo/inference/mlx/__init__.py b/build/lib/exo/inference/mlx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/mlx/models/__init__.py b/build/lib/exo/inference/mlx/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/mlx/models/base.py b/build/lib/exo/inference/mlx/models/base.py new file mode 100644 index 00000000..a1f1878c --- /dev/null +++ b/build/lib/exo/inference/mlx/models/base.py @@ -0,0 +1,9 @@ +from typing import Optional +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import KVCache + + +class IdentityBlock(nn.Module): + def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array: + return x diff --git a/build/lib/exo/inference/mlx/models/deepseek_v2.py b/build/lib/exo/inference/mlx/models/deepseek_v2.py new file mode 100644 index 00000000..9ea271ed --- /dev/null +++ b/build/lib/exo/inference/mlx/models/deepseek_v2.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass, field +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import KVCache +from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer +from .base import IdentityBlock +from exo.inference.shard import Shard + + +@dataclass +class ModelArgs(ModelArgs): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, Shard): + return + if not isinstance(self.shard, dict): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + self.shard = Shard(**self.shard) + + +class DeepseekV2Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.num_hidden_layers = config.num_hidden_layers + self.vocab_size = config.vocab_size + if self.args.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = [] + for i in range(self.num_hidden_layers): + if self.args.shard.start_layer <= i <= self.args.shard.end_layer: + self.layers.append(DeepseekV2DecoderLayer(config, i)) + else: + self.layers.append(IdentityBlock()) + + if self.args.shard.is_last_layer(): + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + x: mx.array, + cache: Optional[KVCache] = None, + ) -> mx.array: + if self.args.shard.is_first_layer(): + h = self.embed_tokens(x) + else: + h = x + + mask = None + T = h.shape[1] + if T > 1: + mask = 
nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None]*len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + if self.args.shard.is_last_layer(): + h = self.norm(h) + return h + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.model_type = config.model_type + self.model = DeepseekV2Model(config) + if self.args.shard.is_last_layer(): + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[KVCache] = None, + ): + out = self.model(inputs, cache) + if self.args.shard.is_last_layer(): + return self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + + for key, value in weights.items(): + if key.startswith('model.layers.'): + layer_num = int(key.split('.')[2]) + if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: + shard_state_dict[key] = value + elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')): + shard_state_dict[key] = value + + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict: + to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)] + shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + + return shard_state_dict + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return ( + self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, + self.args.v_head_dim, + ) + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/build/lib/exo/inference/mlx/models/llama.py b/build/lib/exo/inference/mlx/models/llama.py new file mode 100644 index 00000000..719d6a88 --- /dev/null +++ b/build/lib/exo/inference/mlx/models/llama.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass, field + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.llama import TransformerBlock, ModelArgs + +from ...shard import Shard +from .base import IdentityBlock + + +@dataclass +class ModelArgs(ModelArgs): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + super().__post_init__() # Ensure parent initializations are respected + + if isinstance(self.shard, Shard): + return + if not isinstance(self.shard, dict): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + self.shard = Shard(**self.shard) + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + if self.args.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [] + for i in range(self.num_hidden_layers): + if self.args.shard.start_layer <= i <= self.args.shard.end_layer: + self.layers.append(TransformerBlock(args=args)) + else: + 
self.layers.append(IdentityBlock()) + if self.args.shard.is_last_layer(): + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + if self.args.shard.is_first_layer(): + h = self.embed_tokens(inputs) + else: + h = inputs + + mask = None + if h.shape[1] > 1: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None]*len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + if self.args.shard.is_last_layer(): + h = self.norm(h) + return h + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if self.args.shard.is_last_layer(): + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.shard.is_last_layer(): + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + + for key, value in weights.items(): + if "self_attn.rotary_emb.inv_freq" in key: + continue + if key.startswith('model.layers.'): + layer_num = int(key.split('.')[2]) + if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: + shard_state_dict[key] = value + elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'): + shard_state_dict[key] = value + elif self.args.shard.is_last_layer() and (key.startswith('model.norm')): + shard_state_dict[key] = value + + return shard_state_dict + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads) + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/build/lib/exo/inference/mlx/models/llava.py b/build/lib/exo/inference/mlx/models/llava.py new file mode 100644 index 00000000..b734b09b --- /dev/null +++ b/build/lib/exo/inference/mlx/models/llava.py @@ -0,0 +1,585 @@ +# Copyright © 2024 Apple Inc. 
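+# Sharded LLaVA for MLX: a CLIP vision tower and multi-modal projector feed a
+# sharded Llama language model. Shards that do not contain the first layer drop
+# the vision components entirely and operate on incoming hidden states only.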
+ +import math +import inspect +from dataclasses import dataclass, field +from typing import Optional, Dict, Union + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import BaseModelArgs, KVCache +from exo.inference.shard import Shard +from .base import IdentityBlock +import numpy as np + + +@dataclass +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + + @classmethod + def from_dict(cls, params): + return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) + + +class VisionAttention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError("The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0") + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + scale = math.sqrt(1/queries.shape[-1]) + scores = (queries*scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + +class VisionMLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = nn.GELU(approx="fast") + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x + + +class VisionEncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + 
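+    # second residual connection: add the MLP output back onto the
+    # attention output before returning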
return x + y + + +class VisionEncoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = mx.zeros((config.hidden_size,)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + patch_embeddings = self.patch_embedding(x) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim)) + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + embeddings += self.position_embedding.weight + return embeddings + + +class ClipVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = VisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) + + encoder_states = (x,) if output_hidden_states else None + + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) + + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states + + +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + + self.model_type = config.model_type + if self.model_type != "clip_vision_model": + raise ValueError(f"Unsupported model type: {self.model_type}") + + self.vision_model = ClipVisionModel(config) + + def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array: + return self.vision_model(x, output_hidden_states) + + def sanitize(self, weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_embedding.weight" in k: + # PyTorch conv2d weight tensors have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights + + +@dataclass +class TextConfig: + model_type: str + hidden_size: int = 4096 + num_hidden_layers: int = 32 + intermediate_size: int = 11008 + num_attention_heads: int = 32 + head_dim: int = None + rms_norm_eps: float = 1e-6 + vocab_size: int = 32000 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + @classmethod + def from_dict(cls, params): + return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) + + def 
__post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.head_dim is None: + self.head_dim = self.hidden_size // self.num_attention_heads + + if self.model_type is None: + self.model_type = "llama" + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class TextAttention(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + + dim = config.hidden_size + self.n_heads = n_heads = config.num_attention_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = config.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False) + self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False) + + rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1) + self.rope = nn.RoPE( + head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class TextMLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.self_attn = TextAttention(config) + self.mlp = TextMLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + 
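+# The sharded language model below follows the same pattern as the other MLX
+# models in this package: only layers inside [shard.start_layer, shard.end_layer]
+# are real TransformerBlocks (the rest are IdentityBlocks), embed_tokens exists
+# only on the first shard, and norm/lm_head exist only on the last shard, so
+# intermediate shards simply pass hidden states through.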
+class Llama(nn.Module): + def __init__(self, config: TextConfig, shard: Shard): + super().__init__() + self.config = config + self.shard = shard + self.vocab_size = config.vocab_size + self.model_type = config.model_type + self.num_hidden_layers = config.num_hidden_layers + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + assert self.vocab_size > 0 + if self.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [] + for i in range(self.num_hidden_layers): + if self.shard.start_layer <= i <= self.shard.end_layer: + self.layers.append(TransformerBlock(config=config)) + else: + self.layers.append(IdentityBlock()) + if self.shard.is_last_layer(): + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + # for passing merged input embeddings + if inputs_embeds is None: + if self.shard.is_first_layer(): + h = self.embed_tokens(inputs) + else: + h = inputs + else: + h = inputs_embeds + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None]*len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + if self.shard.is_last_layer(): + h = self.norm(h) + return h + + +class LanguageModel(nn.Module): + def __init__(self, config: TextConfig, shard: Shard): + super().__init__() + self.model_type = config.model_type + if self.model_type != "llama": + raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported") + self.shard = shard + self.model = Llama(config, shard) + if self.shard.is_last_layer(): + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + out = self.model(inputs, cache, inputs_embeds) + if self.shard.is_last_layer(): + out = self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + for key, value in weights.items(): + if "self_attn.rotary_emb.inv_freq" in key: + continue + + if key.startswith('language_model.model.layers.'): + layer_num = int(key.split('.')[3]) + if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer: + continue + if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'): + continue + elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')): + continue + + shard_state_dict[key] = value + + return shard_state_dict + + +@dataclass +class LlaVAConfig(BaseModelArgs): + text_config: TextConfig + vision_config: VisionConfig = None + model_type: str = "llava" + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + updated_params = {} + class_params = inspect.signature(cls).parameters + for k, v in params.items(): + if k in class_params: + if k in ["text_config", "vision_config"]: + v = class_params[k].annotation.from_dict(v) + updated_params.update({k: v}) + + return cls(**updated_params) + + +@dataclass +class ModelArgs(LlaVAConfig): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, dict): + 
self.shard = Shard(**self.shard) + + if not isinstance(self.shard, Shard): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + if not self.shard.is_first_layer(): + self.vision_config = None + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlaVAConfig): + super().__init__() + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + x = self.linear_1(x) + x = self.gelu(x) + x = self.linear_2(x) + return x + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.model_type = config.model_type + if config.vision_config: + self.vision_tower = VisionModel(config.vision_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + self.language_model = LanguageModel(config.text_config, config.shard) + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): + if pixel_values is None: + return self.language_model(input_ids) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get the ouptut hidden states from the vision model + *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True) + + # Select the hidden states from the desired layer + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError("Unexpected feature selection strategy: " + f"{self.vision_feature_select_strategy}") + + # Pass image features through the multi-modal projector + image_features = self.multi_modal_projector(selected_image_feature) + + # Insert special image tokens in the input_ids + final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids) + return final_inputs_embeds + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + + if len(image_positions) != num_images: + raise ValueError(f"The number of image tokens ({len(image_positions)}) does not " + f" match the number of image inputs ({num_images}).") + + text_segments = [] + start_idx = 0 + + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 + + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] + + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) + + def __call__(self, 
input_ids: mx.array, pixel_values: mx.array = None, cache=None): + input_embddings = None + if pixel_values is not None: + input_embddings = self.get_input_embeddings(input_ids, pixel_values) + logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings) + return logits + + def sanitize(self, weights): + if self.config.vision_config: + weights = self.vision_tower.sanitize(weights) + else: + weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))} + weights = self.language_model.sanitize(weights) + return weights + + @property + def layers(self): + return self.language_model.model.layers + + @property + def head_dim(self): + return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads) + + @property + def n_kv_heads(self): + return self.language_model.model.num_key_value_heads diff --git a/build/lib/exo/inference/mlx/sharded_inference_engine.py b/build/lib/exo/inference/mlx/sharded_inference_engine.py new file mode 100644 index 00000000..40cabfeb --- /dev/null +++ b/build/lib/exo/inference/mlx/sharded_inference_engine.py @@ -0,0 +1,40 @@ +import numpy as np +import mlx.core as mx +from ..inference_engine import InferenceEngine +from .sharded_model import StatefulShardedModel +from .sharded_utils import load_shard, get_image_from_str +from ..shard import Shard +from typing import Optional +from exo.download.shard_download import ShardDownloader + + +class MLXDynamicShardInferenceEngine(InferenceEngine): + def __init__(self, shard_downloader: ShardDownloader): + self.shard = None + self.shard_downloader = shard_downloader + + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + await self.ensure_shard(shard) + if image_str: + image = await get_image_from_str(image_str) + inputs = self.tokenizer(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values)) + else: + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))) + return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id + + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + await self.ensure_shard(shard) + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data))) + return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id + + async def ensure_shard(self, shard: Shard): + if self.shard == shard: + return + + model_path = await self.shard_downloader.ensure_shard(shard) + model_shard, self.tokenizer = await load_shard(model_path, shard) + self.stateful_sharded_model = StatefulShardedModel(shard, model_shard) + self.shard = shard diff --git a/build/lib/exo/inference/mlx/sharded_model.py b/build/lib/exo/inference/mlx/sharded_model.py new file mode 100644 index 00000000..c4570fbf --- /dev/null +++ b/build/lib/exo/inference/mlx/sharded_model.py @@ -0,0 +1,86 @@ +from typing import Dict, Generator, Optional, Tuple +from collections import OrderedDict + 
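+# Per-request KV caches are stored in an OrderedDict so they can be evicted in
+# LRU order once max_caches is exceeded (see init_cache and step below).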
+import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import KVCache, RotatingKVCache +from mlx_lm.sample_utils import top_p_sampling + +from ..shard import Shard + + +class StatefulShardedModel: + def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2): + self.shard = shard + self.model = model + self.max_kv_size = max_kv_size + self.max_caches = max_caches + self.caches = OrderedDict() + + def step( + self, + request_id: str, + x, + pixel_values=None, + temp: float = 0.0, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, + ) -> Generator[Tuple[mx.array, mx.array], None, None]: + def sample(logits: mx.array) -> Tuple[mx.array, float]: + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + logits[:, indices] += values + + if temp == 0: + token = mx.argmax(logits, axis=-1) + else: + if top_p > 0 and top_p < 1.0: + token = top_p_sampling(logits, top_p, temp) + else: + token = mx.random.categorical(logits*(1/temp)) + + return token + + y = x + + if request_id not in self.caches: + self.init_cache(request_id) + else: + self.caches.move_to_end(request_id) + + cache = self.caches[request_id] + + if pixel_values is None: + output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache) + else: + output = self.model(y, pixel_values=pixel_values, cache=cache) + + if self.shard.is_last_layer(): + logits = output[:, -1, :] + y = sample(logits) + return y + else: + return output + + def __call__( + self, + request_id: str, + x, + temp: float = 0.0, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, + ) -> Generator[Tuple[mx.array, mx.array], None, None]: + return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias) + + def init_cache(self, request_id: str): + kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads) + if self.max_kv_size is not None: + cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads] + else: + cache = [KVCache(self.model.head_dim, n) for n in kv_heads] + + if len(self.caches) >= self.max_caches: + self.caches.popitem(last=False) + + self.caches[request_id] = cache diff --git a/build/lib/exo/inference/mlx/sharded_utils.py b/build/lib/exo/inference/mlx/sharded_utils.py new file mode 100644 index 00000000..7fa38eaa --- /dev/null +++ b/build/lib/exo/inference/mlx/sharded_utils.py @@ -0,0 +1,207 @@ +# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py + +import glob +import importlib +import json +import logging +import asyncio +import aiohttp +from functools import partial +from pathlib import Path +from typing import Optional, Tuple, Union, List, Callable +from PIL import Image +from io import BytesIO +import base64 + +import mlx.core as mx +import mlx.nn as nn +from transformers import AutoProcessor + +from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper +from mlx_lm.tuner.utils import apply_lora_layers + +from exo import DEBUG +from ..shard import Shard + + +class ModelNotFoundError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +MODEL_REMAPPING = { + "mistral": "llama", # mistral is compatible with llama + "phi-msft": "phixtral", +} + + +def _get_classes(config: dict): + """ + Retrieve the model and model args classes based on the configuration. 
+ + Args: + config (dict): The model configuration. + + Returns: + A tuple containing the Model class and the ModelArgs class. + """ + model_type = config["model_type"] + model_type = MODEL_REMAPPING.get(model_type, model_type) + try: + arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}") + except ImportError: + msg = f"Model type {model_type} not supported." + logging.error(msg) + raise ValueError(msg) + + return arch.Model, arch.ModelArgs + + +def load_config(model_path: Path) -> dict: + try: + with open(model_path/"config.json", "r") as f: + config = json.load(f) + except FileNotFoundError: + logging.error(f"Config file not found in {model_path}") + raise + return config + + +def load_model_shard( + model_path: Path, + shard: Shard, + lazy: bool = False, + model_config: dict = {}, +) -> nn.Module: + """ + Load and initialize the model from a given path. + + Args: + model_path (Path): The path to load the model from. + lazy (bool): If False eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + model_config(dict, optional): Configuration parameters for the model. + Defaults to an empty dictionary. + + Returns: + nn.Module: The loaded and initialized model. + + Raises: + FileNotFoundError: If the weight files (.safetensors) are not found. + ValueError: If the model class or args class are not found or cannot be instantiated. + """ + config = load_config(model_path) + config.update(model_config) + + # TODO hack + config["shard"] = { + "model_id": model_path.name, + "start_layer": shard.start_layer, + "end_layer": shard.end_layer, + "n_layers": shard.n_layers, + } + + weight_files = glob.glob(str(model_path/"model*.safetensors")) + + if not weight_files: + # Try weight for back-compat + weight_files = glob.glob(str(model_path/"weight*.safetensors")) + + if not weight_files: + logging.error(f"No safetensors found in {model_path}") + raise FileNotFoundError(f"No safetensors found in {model_path}") + + weights = {} + for wf in sorted(weight_files): + if DEBUG >= 8: + layer_nums = set() + for k in mx.load(wf): + if k.startswith("model.layers."): + layer_num = int(k.split(".")[2]) + layer_nums.add(layer_num) + if k.startswith("language_model.model.layers."): + layer_num = int(k.split(".")[3]) + layer_nums.add(layer_num) + print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},") + + weights.update(mx.load(wf)) + + model_class, model_args_class = _get_classes(config=config) + + model_args = model_args_class.from_dict(config) + model = model_class(model_args) + + if hasattr(model, "sanitize"): + weights = model.sanitize(weights) + + if (quantization := config.get("quantization", None)) is not None: + # Handle legacy models which may not have everything quantized + def class_predicate(p, m): + if not hasattr(m, "to_quantized"): + return False + return f"{p}.scales" in weights + + nn.quantize( + model, + **quantization, + class_predicate=class_predicate, + ) + + model.load_weights(list(weights.items()), strict=True) + + if not lazy: + mx.eval(model.parameters()) + + model.eval() + return model + + +async def load_shard( + model_path: str, + shard: Shard, + tokenizer_config={}, + model_config={}, + adapter_path: Optional[str] = None, + lazy: bool = False, +) -> Tuple[nn.Module, TokenizerWrapper]: + model = load_model_shard(model_path, shard, lazy, model_config) + if adapter_path is not None: + model = apply_lora_layers(model, adapter_path) + model.eval() + + # TODO: figure out a generic 
solution + if model.model_type == "llava": + processor = AutoProcessor.from_pretrained(model_path) + processor.eos_token_id = processor.tokenizer.eos_token_id + processor.encode = processor.tokenizer.encode + return model, processor + else: + tokenizer = load_tokenizer(model_path, tokenizer_config) + return model, tokenizer + + +async def get_image_from_str(_image_str: str): + image_str = _image_str.strip() + + if image_str.startswith("http"): + async with aiohttp.ClientSession() as session: + async with session.get(image_str, timeout=10) as response: + content = await response.read() + return Image.open(BytesIO(content)).convert("RGB") + elif image_str.startswith("data:image/"): + # Extract the image format and base64 data + format_prefix, base64_data = image_str.split(";base64,") + image_format = format_prefix.split("/")[1].lower() + if DEBUG >= 2: print(f"{image_str=} {image_format=}") + imgdata = base64.b64decode(base64_data) + img = Image.open(BytesIO(imgdata)) + + # Convert to RGB if not already + if img.mode != "RGB": + img = img.convert("RGB") + + return img + else: + raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.") diff --git a/build/lib/exo/inference/mlx/test_sharded_llama.py b/build/lib/exo/inference/mlx/test_sharded_llama.py new file mode 100644 index 00000000..1c48b936 --- /dev/null +++ b/build/lib/exo/inference/mlx/test_sharded_llama.py @@ -0,0 +1,40 @@ +import mlx.core as mx +from exo.inference.mlx.sharded_model import StatefulShardedModel +from exo.inference.mlx.sharded_utils import load_shard +from exo.inference.shard import Shard + +# 79, 80 for Llama-3-70B +shard_full = Shard("llama", 0, 31, 32) +shard1 = Shard("llama", 0, 12, 32) +shard2 = Shard("llama", 13, 31, 32) + +full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full) +model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1) +model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2) + +full = StatefulShardedModel(shard_full, full_model_shard) +m1 = StatefulShardedModel(shard1, model_shard1) +m2 = StatefulShardedModel(shard2, model_shard2) + +prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:" +prompt_tokens = mx.array(full_tokenizer.encode(prompt)) +max_tokens = 50 + +resp = prompt_tokens +full_generated_tokens = [] +for _ in range(max_tokens): + resp = full.step(resp) + full_generated_tokens.append(resp.item()) + +print("full response: ", full_tokenizer.decode(full_generated_tokens)) + +sharded_generated_tokens = [] +sharded_resp = prompt_tokens +for _ in range(max_tokens): + resp1 = m1.step(sharded_resp) + sharded_resp = m2.step(resp1) + sharded_generated_tokens.append(sharded_resp.item()) + +print("sharded response: ", tokenizer1.decode(sharded_generated_tokens)) + +assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens) diff --git a/build/lib/exo/inference/mlx/test_sharded_llava.py b/build/lib/exo/inference/mlx/test_sharded_llava.py new file mode 100644 index 00000000..958a5acc --- /dev/null +++ b/build/lib/exo/inference/mlx/test_sharded_llava.py @@ -0,0 +1,64 @@ +import codecs +import asyncio +import requests +from PIL import Image +from io import BytesIO + +import mlx.core as mx +from mlx_lm.models.base import KVCache + +from exo.inference.mlx.sharded_model import StatefulShardedModel +from exo.inference.mlx.sharded_utils import load_shard +from 
exo.inference.shard import Shard + +shard_full = Shard("llava", 0, 31, 32) +shard1 = Shard("llava", 0, 12, 32) +shard2 = Shard("llava", 13, 31, 32) + +model_path = "llava-hf/llava-1.5-7b-hf" + +full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full)) +model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1)) +model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2)) + +full = StatefulShardedModel(shard_full, full_model_shard) +m1 = StatefulShardedModel(shard1, model_shard1) +m2 = StatefulShardedModel(shard2, model_shard2) + +PROMPT = "USER: \nWhat are these?\nASSISTANT:" +IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg" +response = requests.get(IMAGE_FILE) +img = Image.open(BytesIO(response.content)) +prompt = codecs.decode(PROMPT, "unicode_escape") +inputs = full_processor(prompt, img, return_tensors="np") +pixel_values = mx.array(inputs["pixel_values"]) +input_ids = mx.array(inputs["input_ids"]) + +print(prompt) +y = full.step("full", input_ids, pixel_values, temp=0) +full_generated_tokens = [y.item()] + +for _ in range(13): + y = full.step("full", y, temp=0) + full_generated_tokens.append(y.item()) + +full_response = full_processor.tokenizer.decode(full_generated_tokens) +print("full response:", full_response) + +inputs = processor1(prompt, img, return_tensors="np") +pixel_values = mx.array(inputs["pixel_values"]) +input_ids = mx.array(inputs["input_ids"]) + +y = m1.step("shard", input_ids, pixel_values, temp=0) +y = m2.step("shard", y, temp=0) +full_generated_tokens = [y.item()] + +for _ in range(13): + y = m1.step("shard", y, temp=0) + y = m2.step("shard", y, temp=0) + full_generated_tokens.append(y.item()) + +sharded_response = processor2.tokenizer.decode(full_generated_tokens) +print("sharded response:", sharded_response) + +assert full_response == sharded_response diff --git a/build/lib/exo/inference/mlx/test_sharded_model.py b/build/lib/exo/inference/mlx/test_sharded_model.py new file mode 100644 index 00000000..c9743d07 --- /dev/null +++ b/build/lib/exo/inference/mlx/test_sharded_model.py @@ -0,0 +1,52 @@ +from exo.inference.shard import Shard +import mlx.core as mx +import mlx.nn as nn +from typing import Optional +import numpy as np + + +class DummyModel(nn.Module): + def __init__(self, shard: Optional[Shard] = None): + self.shard = shard + self.layers = [ + nn.Linear(8, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 8), + ] + + self.n_kv_heads = 4 + self.head_dim = 4 + + def __call__(self, x, cache=None): + if self.shard: + for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]: + x = layer(x) + if self.shard.is_last_layer(): + x = x.reshape((1, 2, 4)) + else: + for layer in self.layers: + x = layer(x) + x = x.reshape((1, 2, 4)) + + return x + + +model = DummyModel() +model.save_weights("./test_weights.npz") +n_layers = 5 +shard1 = Shard("test", 0, n_layers // 2, n_layers) +sharded_model1 = DummyModel(shard1) +shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers) +sharded_model2 = DummyModel(shard2) + +model.load_weights("./test_weights.npz") +sharded_model1.load_weights("./test_weights.npz") +sharded_model2.load_weights("./test_weights.npz") + +fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8])) +resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8])) +resp2 = sharded_model2(resp1) + +assert np.all(np.array(fullresp) == np.array(resp2)) diff --git a/build/lib/exo/inference/pytorch/__init__.py 
b/build/lib/exo/inference/pytorch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/build/lib/exo/inference/pytorch/helpers.py b/build/lib/exo/inference/pytorch/helpers.py
new file mode 100644
index 00000000..addea2db
--- /dev/null
+++ b/build/lib/exo/inference/pytorch/helpers.py
@@ -0,0 +1,24 @@
+# Helper functions for PyTorch inference.
+# Some of this code is adapted from tinygrad and rewritten for PyTorch.
+
+import asyncio
+import aiohttp
+from tqdm import tqdm
+from pathlib import Path
+from typing import List
+
+async def fetch_file_async(session, url: str, output_path: Path):
+  async with session.get(url) as response:
+    response.raise_for_status()
+    with open(output_path, 'wb') as f:
+      async for chunk in response.content.iter_chunked(8192):
+        f.write(chunk)
+
+async def download_files(urls: List[str], output_paths: List[Path]):
+  async with aiohttp.ClientSession() as session:
+    tasks = []
+    for url, output_path in zip(urls, output_paths):
+      tasks.append(fetch_file_async(session, url, output_path))
+
+    for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Downloading files"):
+      await f
diff --git a/build/lib/exo/inference/pytorch/inference.py b/build/lib/exo/inference/pytorch/inference.py
new file mode 100644
index 00000000..ba834eb6
--- /dev/null
+++ b/build/lib/exo/inference/pytorch/inference.py
@@ -0,0 +1,210 @@
+# experimental, based off of tinygrad/inference.py
+import numpy as np
+import torch
+import json
+from typing import Optional, Tuple
+from exo.inference.shard import Shard
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
+from exo.api.chatgpt_api import resolve_tokenizer
+from exo.helpers import DEBUG
+from transformers import DynamicCache
+from accelerate import disk_offload
+
+class PyTorchDynamicShardInferenceEngine(InferenceEngine):
+  """
+  PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models.
+  """
+
+  def __init__(self, shard):
+    """
+    Initialize the inference engine.
+
+    Args:
+      shard (Shard): The initial model shard for this engine.
+ """ + self.shard = shard + self.model = None + self.tokenizer = None + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + async def infer_prompt( + self, + request_id: str, + shard: Optional[Shard] = None, + prompt: str = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + + await self.ensure_shard(shard) + + # need to make this so inference_state is not a string + # cant use it with dynamic cache + + tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + tokens = self.model.embed_tokens(tokens) + current_kvs = None + + if DEBUG >= 4: + print("infer_prompt called") + print(f"tokens: {tokens}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] + + output_data, current_kvs = self.model.forward( + tokens, + past_kv + ) + + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + + if DEBUG >= 4: + print(f"output_data: {output_data}\n") + print(f"output_data.size {output_data.size}\n") + + print(f"finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + print(f"output_data[-1] {output_data[-1]}") + + if output_data.size == 1: + print(f"size 1 output_data.item() {output_data.item()}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } + + return ( + output_data, + json.dumps(cache_dict), + is_finished + ) + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + + await self.ensure_shard(shard) + + current_kvs = None + + if input_data.size == 1: + in_tensor = torch.from_numpy( + input_data, + ).unsqueeze(0).long().to(self.device) + else: + in_tensor = torch.from_numpy( + input_data + ).long().to(self.device) + + in_tensor = self.model.embed_tokens(in_tensor) + + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"input_data.size: {input_data.size}") + print(f"input_tensor: {in_tensor}\n") + print(f"shard: {self.shard}") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + try: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] + + if DEBUG >= 4: + print("Loaded past_kv from JSON") + print(f"past_kv: {past_kv}") + print(f"past_kv.key_cache len: {len(past_kv.key_cache)}") + print(f"past_kv.value_cache len: 
{len(past_kv.value_cache)}") + except json.JSONDecodeError: + print(f"ERROR DECODING INFERENCE STATE") + + output_data, current_kvs = self.model.forward( + in_tensor, + past_kv + ) + + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + + if DEBUG >= 4: + print(f"in_tensor: {in_tensor}\n") + print(f"output_data: {output_data}\n") + print(f"output_data.size {output_data.size}\n") + print(f"finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + print(f"output_data[-1] {output_data[-1]}") + + if output_data.size == 1: + print(f"size 1 output_data.item() {output_data.item()}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + + + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } + + return ( + output_data, + json.dumps(cache_dict), + is_finished + ) + + async def ensure_shard(self, shard: Optional[Shard]): + """ + Ensure the model shard is loaded and ready for inference. + + Args: + shard (Optional[Shard]): Shard information for the model. + """ + # if self.shard == shard: + # return + + if DEBUG >= 4: + print(f"Loading new shard: {shard}") + + if self.model: + if DEBUG >= 2: + print(f"\nCLEARING MODEL {shard.model_id}\n") + print(f"before allocated: {torch.cuda.memory_allocated()}") + print(f"before reserved: {torch.cuda.memory_reserved()}") + + # delete model and free up memory to reload + # self.model.cuda() + # disk_offload(model=self.model, offload_dir="./.offload") + import gc + + del self.model + gc.collect() + torch.cuda.empty_cache() + + if DEBUG >= 2: + print(f"after allocated: {torch.cuda.memory_allocated()}") + print(f"after reserved: {torch.cuda.memory_reserved()}") + + self.shard = shard + self.tokenizer = await resolve_tokenizer(shard.model_id) + self.model = ShardedHuggingFaceModel(shard, self.tokenizer) + + if DEBUG >= 4: + print(f"Shard loaded successfully: {shard}") \ No newline at end of file diff --git a/build/lib/exo/inference/pytorch/model/__init__.py b/build/lib/exo/inference/pytorch/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/pytorch/model/hf.py b/build/lib/exo/inference/pytorch/model/hf.py new file mode 100644 index 00000000..aa2873c5 --- /dev/null +++ b/build/lib/exo/inference/pytorch/model/hf.py @@ -0,0 +1,155 @@ +import torch +import numpy as np +from transformers import AutoModelForCausalLM, DynamicCache, Cache +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from typing import Tuple, Optional, Union, List +from exo.inference.pytorch.model.utils import sample_logits + +TOP_P = 0.75 #0.95 +TOP_K = 20 +TEMP = 0.8 + +class ShardedHuggingFaceModel(torch.nn.Module): + def __init__(self, shard: Shard, tokenizer: any): + super(ShardedHuggingFaceModel, self).__init__() + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.shard = shard + self.tokenizer = tokenizer + + # Load the model + try: + self.llm_model = AutoModelForCausalLM.from_pretrained( + shard.model_id, + torch_dtype=torch.float32, + device_map="auto", + # offload_buffers=True + ) + + # disk_offload(model=self.llm_model, offload_dir="./.offload") + + self.base_model = self.llm_model.model + except Exception as err: + print(f"Error loading model: {err}") + raise + + if DEBUG >= 2: + 
print(f"\nShardedHuggingFaceModel init with shard {shard}") + print(f"self.llm_model: {self.llm_model}") + print(f"self.base_model: {self.base_model}") + + if DEBUG >= 2: + print(f"full_model.model layer: {len(self.base_model.layers)}") + + # Embeddings and final layer norm + # used for doing what forward LlamaModel does in transformers + self.norm = self.base_model.norm + self.lm_head = self.llm_model.lm_head + self.embed_tokens = self.base_model.embed_tokens + + def forward( + self, + input_ids: torch.tensor, + past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + ) -> Tuple[np.ndarray, any]: + """ + Forward through layers using the base model + + Args: + input_ids: tensor input + past_kvs: past key value stores for cache + use_cache: use cache + + Returns: + hidden_states: numpy of states between layers + or logits: numpy of normalization and linearization of last hidden state + past_kvs: DynamicCache of past key values if use_cache is true + + Ref: + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 + """ + if DEBUG >= 4: + print("forward called") + print(f"input_ids: {input_ids}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + past_kvs = DynamicCache.from_legacy_cache(past_kvs) + past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 + + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + input_ids.shape[1], + device=input_ids.device + ).to(self.device) + + position_ids = cache_position.unsqueeze(0).to(self.device) + + try: + position_embeddings = self.base_model.rotary_emb( + input_ids, + position_ids + ) + except Exception as err: + print(f"rotary_emb not found in base_model") + position_embeddings = None + + # progress through layers + for i in range(self.shard.start_layer, self.shard.end_layer + 1): + decoder_layer = self.base_model.layers[i] + + if DEBUG >= 4: + print("Going through layer") + print(f"{decoder_layer}") + print("input_ids") + print(f"{input_ids}") + + layer_outputs = decoder_layer( + input_ids, + position_ids=position_ids if not position_embeddings else None, + position_embeddings=position_embeddings, + past_key_value=past_kvs, + use_cache=True, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + next_kvs = layer_outputs[1] + + if DEBUG >= 3: + print(f"layer_outputs {layer_outputs}") + + if self.shard.is_last_layer(): + hs_norm = self.norm(hidden_states) + hs_lm_head = self.llm_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings + with torch.no_grad(): + output_token = sample_logits( + hs_lm_head[:, -1, :], + TEMP, + TOP_P, + TOP_K + ).numpy(force=True).flatten() + + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + print(f"hs_lm_head: {hs_lm_head}") + print(f"output_token: {output_token}") + + return (output_token, next_kvs) + + with torch.no_grad(): + out_hidden_states = hidden_states.numpy(force=True) + + return ( + out_hidden_states, + next_kvs + ) \ No newline at end of file diff --git a/build/lib/exo/inference/pytorch/model/utils.py b/build/lib/exo/inference/pytorch/model/utils.py new file mode 100644 index 00000000..df84b397 --- /dev/null +++ b/build/lib/exo/inference/pytorch/model/utils.py @@ -0,0 +1,83 @@ +import torch +from torch.nn import functional 
as F + +def top_p_sampling(scaled_logits: torch.Tensor, top_p: float) -> torch.Tensor: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + scaled_logits (torch.Tensor): The scaled logits from the model's output. + top_p (float): The cumulative probability threshold for top-p filtering. + temp (float): Temperature parameter for softmax distribution reshaping. + + Returns: + torch.Tensor: Token selected based on the top-p criterion. + + Ref: + https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L67C1-L97C17 + """ + scaled_logits = torch.where(torch.isnan(scaled_logits), torch.zeros_like(scaled_logits), scaled_logits) + scaled_logits = torch.where(torch.isinf(scaled_logits), torch.full_like(scaled_logits, 1e6), scaled_logits) + + probs = torch.softmax(scaled_logits, dim=-1) + + sorted_probs, sorted_indices = torch.sort( + probs, + descending=True, + dim=-1 + ) + + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + mask = cumulative_probs > top_p + + top_probs = torch.where(mask, torch.zeros_like(sorted_probs), sorted_probs) + sum_probs = top_probs.sum(dim=-1, keepdim=True) + top_probs = torch.where(sum_probs > 0, top_probs / sum_probs, torch.ones_like(top_probs) / top_probs.size(-1)) + + if torch.isnan(top_probs).any() or torch.isinf(top_probs).any(): + print("Warning: Top probabilities contain NaN or Inf values after normalization") + top_probs = torch.where(torch.isnan(top_probs) | torch.isinf(top_probs), + 1.0 / top_probs.size(-1), + top_probs) + + sorted_token = torch.multinomial(top_probs, num_samples=1) + + token = sorted_indices.gather(-1, sorted_token) + + return token.squeeze(-1) + +def sample_logits(logits, temp, top_p, top_k): + """ + Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. + + Args: + logits (torch.Tensor): The logits distribution to sample from. + temp (float): temp for scaling logits. + top_p (float): The cumulative probability threshold for nucleus sampling. + + Returns: + torch.Tensor: The selected token index. + """ + + # Ensure logits are float + logits = logits.float() + + # If temp is very low, just use argmax + if temp == 0: + return logits.argmax(dim=-1) + + scaled_logits = logits/temp + + # top k + if top_k > 0: + top_values, top_indices = torch.topk(scaled_logits, top_k, dim=-1) + scaled_logits = torch.zeros_like(logits).scatter_(-1, top_indices, top_values) + + # Top-p sampling + if 0 < top_p < 1.0: + return top_p_sampling(scaled_logits, top_p) + else: + # random distribution selection + probs = torch.softmax(scaled_logits, dim=-1) + rand_sample = torch.distributions.Categorical(probs) + return rand_sample.sample().squeeze() \ No newline at end of file diff --git a/build/lib/exo/inference/pytorch/test_inference_engine.py b/build/lib/exo/inference/pytorch/test_inference_engine.py new file mode 100644 index 00000000..bacf53bc --- /dev/null +++ b/build/lib/exo/inference/pytorch/test_inference_engine.py @@ -0,0 +1,141 @@ + +import asyncio +from exo.inference.shard import Shard +from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.helpers import DEBUG +import os +import numpy as np + +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): + # prompt = "Why is the sky blue?" 
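+  # NOTE: the assertions at the end of this function reference resp_full,
+  # next_resp_full and resp4, which are only defined by the commented-out
+  # passes above; they will raise a NameError unless those passes are re-enabled.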
+ prompt = "In a single word only, what is the last name of the current president of the USA?" + + # shard = Shard( + # model_id=model_id, + # start_layer=0, + # end_layer=n_layers-1, + # n_layers=n_layers + # ) + + # resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + # "A", + # shard=shard, + # prompt=prompt + # ) + + # print(f"resp_full: {resp_full}") + + # next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + # "A", + # shard=shard, + # input_data=resp_full, + # inference_state=inference_state_full, + # ) + + # print(f"next_resp_full: {next_resp_full}") + + pp = int(n_layers/2) + + resp_shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=pp, + n_layers=n_layers + ) + + resp_shard2 = Shard( + model_id=model_id, + start_layer=pp + 1, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( + "B", + shard=resp_shard, + prompt=prompt + ) + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp1, + inference_state=inference_state_1, + ) + + # resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + # "B", + # shard=resp_shard, + # input_data=resp2, + # inference_state=inference_state_2, + # ) + + # resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + # "B", + # shard=resp_shard2, + # input_data=resp3, + # inference_state=inference_state_3, + # ) + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + +if __name__ == '__main__': + # try: + # print(f"\n\n -------- TEST QWEN2 -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "andrijdavid/Llama3-1B-Base", + # 3 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "meta-llama/Meta-Llama-3.1-8B", + # 32 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Chickaboo/ChickaQ-Large", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") + + try: + print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "ambrosfitz/TinyLlama-1.1B-Chat-yawp", + 22 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! 
ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") + diff --git a/build/lib/exo/inference/shard.py b/build/lib/exo/inference/shard.py new file mode 100644 index 00000000..21b662f6 --- /dev/null +++ b/build/lib/exo/inference/shard.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class Shard: + model_id: str + start_layer: int + end_layer: int + n_layers: int + + def __hash__(self): + return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers)) + + def is_first_layer(self) -> bool: + return self.start_layer == 0 + + def is_last_layer(self) -> bool: + return self.end_layer == self.n_layers - 1 + + def get_layer_count(self) -> int: + return self.end_layer - self.start_layer + 1 + + def to_dict(self) -> dict: + return { + "model_id": self.model_id, + "start_layer": self.start_layer, + "end_layer": self.end_layer, + "n_layers": self.n_layers, + } + + def from_dict(data: dict) -> 'Shard': + return Shard(**data) + + def overlaps(self, other: 'Shard') -> bool: + return shards_overlap(self, other) + + +def shards_overlap(shard1: Shard, shard2: Shard) -> bool: + return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)) diff --git a/build/lib/exo/inference/test_inference_engine.py b/build/lib/exo/inference/test_inference_engine.py new file mode 100644 index 00000000..e57c608d --- /dev/null +++ b/build/lib/exo/inference/test_inference_engine.py @@ -0,0 +1,64 @@ +from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.helpers import DEBUG +import os +import asyncio +import numpy as np + + +# An inference engine should work the same for any number of Shards, as long as the Shards are continuous. +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): + prompt = "In a single word only, what is the last name of the current president of the USA?" 
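As a quick illustration of the Shard dataclass and the shards_overlap helper introduced in shard.py above, here is a small self-contained sketch; the model id is a placeholder:

from exo.inference.shard import Shard

first = Shard(model_id="example/model", start_layer=0, end_layer=15, n_layers=32)
second = Shard(model_id="example/model", start_layer=16, end_layer=31, n_layers=32)

assert first.is_first_layer() and not first.is_last_layer()
assert second.is_last_layer()
assert first.get_layer_count() + second.get_layer_count() == 32
assert not first.overlaps(second)  # contiguous but disjoint layer ranges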
+ resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), + input_data=resp_full, + inference_state=inference_state_full, + ) + + pp = 15 + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + input_data=resp1, + inference_state=inference_state_1, + ) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), + input_data=resp2, + inference_state=inference_state_2, + ) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + input_data=resp3, + inference_state=inference_state_3, + ) + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + + +asyncio.run(test_inference_engine( + MLXDynamicShardInferenceEngine(HFShardDownloader()), + MLXDynamicShardInferenceEngine(HFShardDownloader()), + "mlx-community/Meta-Llama-3-8B-Instruct-4bit", +)) + +if os.getenv("RUN_TINYGRAD", default="0") == "1": + import tinygrad + import os + from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine + tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) + asyncio.run( + test_inference_engine( + TinygradDynamicShardInferenceEngine(HFShardDownloader()), + TinygradDynamicShardInferenceEngine(HFShardDownloader()), + "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", + ) + ) diff --git a/build/lib/exo/inference/tokenizers.py b/build/lib/exo/inference/tokenizers.py new file mode 100644 index 00000000..9accd943 --- /dev/null +++ b/build/lib/exo/inference/tokenizers.py @@ -0,0 +1,45 @@ +import traceback +from aiofiles import os as aios +from transformers import AutoTokenizer, AutoProcessor +from exo.download.hf.hf_helpers import get_local_snapshot_dir +from exo.helpers import DEBUG + +async def resolve_tokenizer(model_id: str): + local_path = await get_local_snapshot_dir(model_id) + if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}") + try: + if await aios.path.exists(local_path): + if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}") + return await _resolve_tokenizer(local_path) + except: + if DEBUG >= 5: print(f"Local check for {local_path=} failed. 
Resolving tokenizer for {model_id=} normally...") + if DEBUG >= 5: traceback.print_exc() + return await _resolve_tokenizer(model_id) + +async def _resolve_tokenizer(model_id_or_local_path: str): + try: + if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}") + if "Mistral-Large" in str(model_id_or_local_path): + use_fast = True + else: + use_fast = False + processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=use_fast) + if not hasattr(processor, 'eos_token_id'): + processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id + if not hasattr(processor, 'encode'): + processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode + if not hasattr(processor, 'decode'): + processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode + return processor + except Exception as e: + if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}") + if DEBUG >= 4: print(traceback.format_exc()) + + try: + if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}") + return AutoTokenizer.from_pretrained(model_id_or_local_path) + except Exception as e: + if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}") + if DEBUG >= 4: print(traceback.format_exc()) + + raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}") diff --git a/build/lib/exo/models.py b/build/lib/exo/models.py new file mode 100644 index 00000000..137b881c --- /dev/null +++ b/build/lib/exo/models.py @@ -0,0 +1,44 @@ +from exo.inference.shard import Shard + +model_base_shards = { + ### llama + "llama-3.1-8b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), + }, + "llama-3.1-70b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),}, + "llama-3-8b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32), + }, + "llama-3-70b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3-2B-Base": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=6), + }, + "llama-3-1B-Base": { + "PyTorchDynamicShardInferenceEngine": 
Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), + }, + "TinyLlama-1.1B-Chat-yaw": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="ambrosfitz/TinyLlama-1.1B-Chat-yawp", start_layer=0, end_layer=0, n_layers=22), + }, + ### mistral + "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, + "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, + ### deepseek v2 + "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),}, + ### llava + "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, + ### qwen + "Qwen2-0.5B-Instruct": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), + }, + +} diff --git a/build/lib/exo/networking/__init__.py b/build/lib/exo/networking/__init__.py new file mode 100644 index 00000000..44a10a30 --- /dev/null +++ b/build/lib/exo/networking/__init__.py @@ -0,0 +1,5 @@ +from .discovery import Discovery +from .peer_handle import PeerHandle +from .server import Server + +__all__ = ["Discovery", "PeerHandle", "Server"] diff --git a/build/lib/exo/networking/discovery.py b/build/lib/exo/networking/discovery.py new file mode 100644 index 00000000..cdcbfabc --- /dev/null +++ b/build/lib/exo/networking/discovery.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from typing import List +from .peer_handle import PeerHandle + + +class Discovery(ABC): + @abstractmethod + async def start(self) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass + + @abstractmethod + async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: + pass diff --git a/build/lib/exo/networking/grpc/__init__.py b/build/lib/exo/networking/grpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/networking/grpc/grpc_discovery.py b/build/lib/exo/networking/grpc/grpc_discovery.py new file mode 100644 index 00000000..eb08a838 --- /dev/null +++ b/build/lib/exo/networking/grpc/grpc_discovery.py @@ -0,0 +1,188 @@ +import asyncio +import json +import socket +import time +from typing import List, Dict, Callable, Tuple, Coroutine +from ..discovery import Discovery +from ..peer_handle import PeerHandle +from .grpc_peer_handle import GRPCPeerHandle +from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES +from exo import DEBUG_DISCOVERY + + +class ListenProtocol(asyncio.DatagramProtocol): + def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]): + super().__init__() + self.on_message = on_message + self.loop = asyncio.get_event_loop() + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + asyncio.create_task(self.on_message(data, addr)) + + +class GRPCDiscovery(Discovery): + def __init__( + self, + node_id: str, + node_port: int, + listen_port: int, + broadcast_port: int = None, + broadcast_interval: int = 1, + device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, + discovery_timeout: int = 30, + ): + self.node_id = node_id + self.node_port = node_port + 
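The model_base_shards table in models.py above is keyed first by a short model name and then by the inference engine's class name. A minimal lookup sketch, with the model and engine chosen only as examples:

from exo.models import model_base_shards

shards_for_model = model_base_shards.get("llama-3.1-8b", {})
base_shard = shards_for_model.get("PyTorchDynamicShardInferenceEngine")
if base_shard is None:
  raise ValueError("llama-3.1-8b is not supported by the selected engine")
print(base_shard.model_id, base_shard.n_layers)  # meta-llama/Meta-Llama-3.1-8B 32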
self.device_capabilities = device_capabilities + self.listen_port = listen_port + self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port + self.broadcast_interval = broadcast_interval + self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {} + self.broadcast_task = None + self.listen_task = None + self.cleanup_task = None + self.discovery_timeout = discovery_timeout + + async def start(self): + self.device_capabilities = device_capabilities() + self.broadcast_task = asyncio.create_task(self.task_broadcast_presence()) + self.listen_task = asyncio.create_task(self.task_listen_for_peers()) + self.cleanup_task = asyncio.create_task(self.task_cleanup_peers()) + + async def stop(self): + if self.broadcast_task: + self.broadcast_task.cancel() + if self.listen_task: + self.listen_task.cancel() + if self.cleanup_task: + self.cleanup_task.cancel() + if self.broadcast_task or self.listen_task or self.cleanup_task: + await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True) + + async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: + if DEBUG_DISCOVERY >= 2: + print("Starting peer discovery process...") + + if wait_for_peers > 0: + while len(self.known_peers) == 0: + if DEBUG_DISCOVERY >= 2: + print("No peers discovered yet, retrying in 1 second...") + await asyncio.sleep(1) # Keep trying to find peers + if DEBUG_DISCOVERY >= 2: + print(f"Discovered first peer: {next(iter(self.known_peers.values()))}") + + grace_period = 5 # seconds + while True: + initial_peer_count = len(self.known_peers) + if DEBUG_DISCOVERY >= 2: + print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...") + if len(self.known_peers) == initial_peer_count: + if wait_for_peers > 0: + await asyncio.sleep(grace_period) + if DEBUG_DISCOVERY >= 2: + print(f"Waiting additional {wait_for_peers} seconds for more peers.") + wait_for_peers = 0 + else: + if DEBUG_DISCOVERY >= 2: + print("No new peers discovered in the last grace period. 
Ending discovery process.") + break # No new peers found in the grace period, we are done + + return [peer_handle for peer_handle, _, _ in self.known_peers.values()] + + async def task_broadcast_presence(self): + transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET) + sock = transport.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + + message = json.dumps({ + "type": "discovery", + "node_id": self.node_id, + "grpc_port": self.node_port, + "device_capabilities": self.device_capabilities.to_dict(), + }).encode("utf-8") + + while True: + try: + if DEBUG_DISCOVERY >= 3: + print(f"Broadcast presence: {message}") + transport.sendto(message, ("", self.broadcast_port)) + await asyncio.sleep(self.broadcast_interval) + except Exception as e: + print(f"Error in broadcast presence: {e}") + import traceback + + print(traceback.format_exc()) + + async def on_listen_message(self, data, addr): + if not data: + return + + decoded_data = data.decode("utf-8", errors="ignore") + + # Check if the decoded data starts with a valid JSON character + if not (decoded_data.strip() and decoded_data.strip()[0] in "{["): + if DEBUG_DISCOVERY >= 2: + print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}") + return + + try: + decoder = json.JSONDecoder(strict=False) + message = decoder.decode(decoded_data) + except json.JSONDecodeError as e: + if DEBUG_DISCOVERY >= 2: + print(f"Error decoding JSON data from {addr}: {e}") + return + + if DEBUG_DISCOVERY >= 2: + print(f"received from peer {addr}: {message}") + + if message["type"] == "discovery" and message["node_id"] != self.node_id: + peer_id = message["node_id"] + peer_host = addr[0] + peer_port = message["grpc_port"] + device_capabilities = DeviceCapabilities(**message["device_capabilities"]) + if peer_id not in self.known_peers: + self.known_peers[peer_id] = ( + GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), + time.time(), + time.time(), + ) + if DEBUG_DISCOVERY >= 2: + print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}") + self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time()) + + async def task_listen_for_peers(self): + await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port)) + if DEBUG_DISCOVERY >= 2: + print("Started listen task") + + async def task_cleanup_peers(self): + while True: + try: + current_time = time.time() + peers_to_remove = [ + peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() + if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout + ] + if DEBUG_DISCOVERY >= 2: + print( + "Peer statuses:", + {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" + for peer_handle, connected_at, last_seen in self.known_peers.values()}, + ) + if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: + print(f"Cleaning up peers: {peers_to_remove}") + for peer_id in peers_to_remove: + if peer_id in self.known_peers: + del self.known_peers[peer_id] + if DEBUG_DISCOVERY >= 2: + print(f"Removed peer {peer_id} due to inactivity.") + await asyncio.sleep(self.broadcast_interval) + except Exception as e: + print(f"Error in cleanup peers: {e}") + import traceback + + 
print(traceback.format_exc()) diff --git a/build/lib/exo/networking/grpc/grpc_peer_handle.py b/build/lib/exo/networking/grpc/grpc_peer_handle.py new file mode 100644 index 00000000..0629dc77 --- /dev/null +++ b/build/lib/exo/networking/grpc/grpc_peer_handle.py @@ -0,0 +1,109 @@ +import grpc +import numpy as np +from typing import Optional, Tuple, List + +# These would be generated from the .proto file +from . import node_service_pb2 +from . import node_service_pb2_grpc + +from ..peer_handle import PeerHandle +from exo.inference.shard import Shard +from exo.topology.topology import Topology +from exo.topology.device_capabilities import DeviceCapabilities + + +class GRPCPeerHandle(PeerHandle): + def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities): + self._id = _id + self.address = address + self._device_capabilities = device_capabilities + self.channel = None + self.stub = None + + def id(self) -> str: + return self._id + + def device_capabilities(self) -> DeviceCapabilities: + return self._device_capabilities + + async def connect(self): + self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)]) + self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel) + + async def is_connected(self) -> bool: + return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY + + async def disconnect(self): + if self.channel: + await self.channel.close() + self.channel = None + self.stub = None + + async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + request = node_service_pb2.PromptRequest( + prompt=prompt, + image_str=image_str, + shard=node_service_pb2.Shard( + model_id=shard.model_id, + start_layer=shard.start_layer, + end_layer=shard.end_layer, + n_layers=shard.n_layers, + ), + request_id=request_id, + inference_state=inference_state, + ) + response = await self.stub.SendPrompt(request) + + if not response.tensor_data or not response.shape or not response.dtype: + return None + + return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) + + async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + request = node_service_pb2.TensorRequest( + shard=node_service_pb2.Shard( + model_id=shard.model_id, + start_layer=shard.start_layer, + end_layer=shard.end_layer, + n_layers=shard.n_layers, + ), + tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)), + request_id=request_id, + inference_state=inference_state, + ) + response = await self.stub.SendTensor(request) + + if not response.tensor_data or not response.shape or not response.dtype: + return None + + return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) + + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + request = node_service_pb2.GetInferenceResultRequest(request_id=request_id) + response = await self.stub.GetInferenceResult(request) + if response.tensor is None: + return None, response.is_finished + return ( + np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), + response.is_finished, + ) + + async def collect_topology(self, visited: set[str], max_depth: int) -> 
Topology: + request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth) + response = await self.stub.CollectTopology(request) + topology = Topology() + for node_id, capabilities in response.nodes.items(): + device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops) + topology.update_node(node_id, device_capabilities) + for node_id, peers in response.peer_graph.items(): + for peer_id in peers.peer_ids: + topology.add_edge(node_id, peer_id) + return topology + + async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished) + await self.stub.SendResult(request) + + async def send_opaque_status(self, request_id: str, status: str) -> None: + request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status) + await self.stub.SendOpaqueStatus(request) diff --git a/build/lib/exo/networking/grpc/grpc_server.py b/build/lib/exo/networking/grpc/grpc_server.py new file mode 100644 index 00000000..1481ef51 --- /dev/null +++ b/build/lib/exo/networking/grpc/grpc_server.py @@ -0,0 +1,118 @@ +import grpc +from concurrent import futures +import numpy as np +from asyncio import CancelledError + +from . import node_service_pb2 +from . import node_service_pb2_grpc +from exo import DEBUG +from exo.inference.shard import Shard +from exo.orchestration import Node + + +class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): + def __init__(self, node: Node, host: str, port: int): + self.node = node + self.host = host + self.port = port + self.server = None + + async def start(self) -> None: + self.server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_metadata_size", 32*1024*1024), + ("grpc.max_send_message_length", 128*1024*1024), + ("grpc.max_receive_message_length", 128*1024*1024), + ], + ) + node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server) + listen_addr = f"{self.host}:{self.port}" + self.server.add_insecure_port(listen_addr) + await self.server.start() + if DEBUG >= 1: print(f"Server started, listening on {listen_addr}") + + async def stop(self) -> None: + if self.server: + try: + await self.server.stop(grace=5) + await self.server.wait_for_termination() + except CancelledError: + pass + if DEBUG >= 1: print("Server stopped and all connections are closed") + + async def SendPrompt(self, request, context): + shard = Shard( + model_id=request.shard.model_id, + start_layer=request.shard.start_layer, + end_layer=request.shard.end_layer, + n_layers=request.shard.n_layers, + ) + prompt = request.prompt + image_str = request.image_str + request_id = request.request_id + result = await self.node.process_prompt(shard, prompt, image_str, request_id) + if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}") + tensor_data = result.tobytes() if result is not None else None + return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() + + async def SendTensor(self, request, context): + shard = Shard( + model_id=request.shard.model_id, + start_layer=request.shard.start_layer, + end_layer=request.shard.end_layer, + n_layers=request.shard.n_layers, + ) + tensor = np.frombuffer(request.tensor.tensor_data, 
dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape) + request_id = request.request_id + inference_state = request.inference_state + + result = await self.node.process_tensor(shard, tensor, request_id, inference_state) + if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}") + tensor_data = result.tobytes() if result is not None else None + return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() + + async def GetInferenceResult(self, request, context): + request_id = request.request_id + result = await self.node.get_inference_result(request_id) + if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}") + tensor_data = result[0].tobytes() if result[0] is not None else None + return ( + node_service_pb2.InferenceResult( + tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), + is_finished=result[1], + ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1]) + ) + + async def CollectTopology(self, request, context): + max_depth = request.max_depth + visited = set(request.visited) + topology = await self.node.collect_topology(visited, max_depth) + nodes = { + node_id: + node_service_pb2.DeviceCapabilities( + model=cap.model, + chip=cap.chip, + memory=cap.memory, + flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8), + ) + for node_id, cap in topology.nodes.items() + } + peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()} + if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}") + return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph) + + async def SendResult(self, request, context): + request_id = request.request_id + result = request.result + is_finished = request.is_finished + if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}") + self.node.on_token.trigger_all(request_id, result, is_finished) + return node_service_pb2.Empty() + + async def SendOpaqueStatus(self, request, context): + request_id = request.request_id + status = request.status + if DEBUG >= 5: print(f"Received SendOpaqueStatus request: {request_id=} {status=}") + self.node.on_opaque_status.trigger_all(request_id, status) + return node_service_pb2.Empty() diff --git a/build/lib/exo/networking/grpc/node_service_pb2.py b/build/lib/exo/networking/grpc/node_service_pb2.py new file mode 100644 index 00000000..cae2d080 --- /dev/null +++ b/build/lib/exo/networking/grpc/node_service_pb2.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
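The gRPC services above move tensors as raw bytes plus an explicit dtype string and shape. A small round-trip sketch of that wire format using plain numpy:

import numpy as np

original = np.arange(12, dtype=np.float32).reshape(3, 4)
tensor_data, dtype, shape = original.tobytes(), str(original.dtype), original.shape

# Mirrors how SendPrompt/SendTensor responses are rebuilt on the receiving side.
restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert np.array_equal(original, restored)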
+# source: node_service.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 
\x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None + _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001' + _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001' + _globals['_SHARD']._serialized_start = 36 + _globals['_SHARD']._serialized_end = 119 + _globals['_PROMPTREQUEST']._serialized_start = 122 + _globals['_PROMPTREQUEST']._serialized_end = 317 + _globals['_TENSORREQUEST']._serialized_start = 320 + _globals['_TENSORREQUEST']._serialized_end = 499 + _globals['_GETINFERENCERESULTREQUEST']._serialized_start = 501 + _globals['_GETINFERENCERESULTREQUEST']._serialized_end = 548 + _globals['_INFERENCERESULT']._serialized_start = 550 + _globals['_INFERENCERESULT']._serialized_end = 642 + _globals['_TENSOR']._serialized_start = 644 + _globals['_TENSOR']._serialized_end = 703 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start = 705 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end = 765 + _globals['_TOPOLOGY']._serialized_start = 768 + _globals['_TOPOLOGY']._serialized_end = 1038 + _globals['_TOPOLOGY_NODESENTRY']._serialized_start = 889 + _globals['_TOPOLOGY_NODESENTRY']._serialized_end = 967 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start = 969 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end = 1038 + _globals['_PEERS']._serialized_start = 1040 + _globals['_PEERS']._serialized_end = 1065 + _globals['_DEVICEFLOPS']._serialized_start = 1067 + _globals['_DEVICEFLOPS']._serialized_end = 1122 + _globals['_DEVICECAPABILITIES']._serialized_start = 1124 + _globals['_DEVICECAPABILITIES']._serialized_end = 1231 + _globals['_SENDRESULTREQUEST']._serialized_start = 1233 + _globals['_SENDRESULTREQUEST']._serialized_end = 1309 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start = 1311 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end = 1372 + _globals['_EMPTY']._serialized_start = 1374 + _globals['_EMPTY']._serialized_end = 1381 + _globals['_NODESERVICE']._serialized_start = 1384 + _globals['_NODESERVICE']._serialized_end = 1862 +# @@protoc_insertion_point(module_scope) diff --git a/build/lib/exo/networking/grpc/node_service_pb2_grpc.py b/build/lib/exo/networking/grpc/node_service_pb2_grpc.py new file mode 100644 index 00000000..ea1d3c98 --- /dev/null +++ b/build/lib/exo/networking/grpc/node_service_pb2_grpc.py @@ -0,0 +1,272 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . 
import node_service_pb2 as node__service__pb2 + +GRPC_GENERATED_VERSION = '1.64.1' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning + ) + + +class NodeServiceStub(object): + """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendPrompt = channel.unary_unary( + '/node_service.NodeService/SendPrompt', + request_serializer=node__service__pb2.PromptRequest.SerializeToString, + response_deserializer=node__service__pb2.Tensor.FromString, + _registered_method=True + ) + self.SendTensor = channel.unary_unary( + '/node_service.NodeService/SendTensor', + request_serializer=node__service__pb2.TensorRequest.SerializeToString, + response_deserializer=node__service__pb2.Tensor.FromString, + _registered_method=True + ) + self.GetInferenceResult = channel.unary_unary( + '/node_service.NodeService/GetInferenceResult', + request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString, + response_deserializer=node__service__pb2.InferenceResult.FromString, + _registered_method=True + ) + self.CollectTopology = channel.unary_unary( + '/node_service.NodeService/CollectTopology', + request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString, + response_deserializer=node__service__pb2.Topology.FromString, + _registered_method=True + ) + self.SendResult = channel.unary_unary( + '/node_service.NodeService/SendResult', + request_serializer=node__service__pb2.SendResultRequest.SerializeToString, + response_deserializer=node__service__pb2.Empty.FromString, + _registered_method=True + ) + self.SendOpaqueStatus = channel.unary_unary( + '/node_service.NodeService/SendOpaqueStatus', + request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString, + response_deserializer=node__service__pb2.Empty.FromString, + _registered_method=True + ) + + +class NodeServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + def SendPrompt(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTensor(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetInferenceResult(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + 
raise NotImplementedError('Method not implemented!') + + def CollectTopology(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendResult(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendOpaqueStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_NodeServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendPrompt': + grpc.unary_unary_rpc_method_handler( + servicer.SendPrompt, + request_deserializer=node__service__pb2.PromptRequest.FromString, + response_serializer=node__service__pb2.Tensor.SerializeToString, + ), + 'SendTensor': + grpc.unary_unary_rpc_method_handler( + servicer.SendTensor, + request_deserializer=node__service__pb2.TensorRequest.FromString, + response_serializer=node__service__pb2.Tensor.SerializeToString, + ), + 'GetInferenceResult': + grpc.unary_unary_rpc_method_handler( + servicer.GetInferenceResult, + request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString, + response_serializer=node__service__pb2.InferenceResult.SerializeToString, + ), + 'CollectTopology': + grpc.unary_unary_rpc_method_handler( + servicer.CollectTopology, + request_deserializer=node__service__pb2.CollectTopologyRequest.FromString, + response_serializer=node__service__pb2.Topology.SerializeToString, + ), + 'SendResult': + grpc.unary_unary_rpc_method_handler( + servicer.SendResult, + request_deserializer=node__service__pb2.SendResultRequest.FromString, + response_serializer=node__service__pb2.Empty.SerializeToString, + ), + 'SendOpaqueStatus': + grpc.unary_unary_rpc_method_handler( + servicer.SendOpaqueStatus, + request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString, + response_serializer=node__service__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers) + + +# This class is part of an EXPERIMENTAL API. 
+class NodeService(object): + """Missing associated documentation comment in .proto file.""" + @staticmethod + def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendPrompt', + node__service__pb2.PromptRequest.SerializeToString, + node__service__pb2.Tensor.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendTensor', + node__service__pb2.TensorRequest.SerializeToString, + node__service__pb2.Tensor.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/GetInferenceResult', + node__service__pb2.GetInferenceResultRequest.SerializeToString, + node__service__pb2.InferenceResult.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/CollectTopology', + node__service__pb2.CollectTopologyRequest.SerializeToString, + node__service__pb2.Topology.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendResult', + node__service__pb2.SendResultRequest.SerializeToString, + node__service__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendOpaqueStatus', + node__service__pb2.SendOpaqueStatusRequest.SerializeToString, + node__service__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) diff --git a/build/lib/exo/networking/grpc/test_grpc_discovery.py 
b/build/lib/exo/networking/grpc/test_grpc_discovery.py new file mode 100644 index 00000000..13372bbb --- /dev/null +++ b/build/lib/exo/networking/grpc/test_grpc_discovery.py @@ -0,0 +1,22 @@ +import asyncio +import unittest +from .grpc_discovery import GRPCDiscovery + + +class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679) + self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678) + await self.node1.start() + await self.node2.start() + + async def asyncTearDown(self): + await self.node1.stop() + await self.node2.stop() + + async def test_discovery(self): + await asyncio.sleep(4) + + # Check discovered peers + print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()])) + print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()])) diff --git a/build/lib/exo/networking/peer_handle.py b/build/lib/exo/networking/peer_handle.py new file mode 100644 index 00000000..cf232d00 --- /dev/null +++ b/build/lib/exo/networking/peer_handle.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple, List +import numpy as np +from exo.inference.shard import Shard +from exo.topology.device_capabilities import DeviceCapabilities +from exo.topology.topology import Topology + + +class PeerHandle(ABC): + @abstractmethod + def id(self) -> str: + pass + + @abstractmethod + def device_capabilities(self) -> DeviceCapabilities: + pass + + @abstractmethod + async def connect(self) -> None: + pass + + @abstractmethod + async def is_connected(self) -> bool: + pass + + @abstractmethod + async def disconnect(self) -> None: + pass + + @abstractmethod + async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + pass + + @abstractmethod + async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + pass + + @abstractmethod + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + pass + + @abstractmethod + async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: + pass + + @abstractmethod + async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + pass diff --git a/build/lib/exo/networking/server.py b/build/lib/exo/networking/server.py new file mode 100644 index 00000000..8e7f9812 --- /dev/null +++ b/build/lib/exo/networking/server.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + + +class Server(ABC): + @abstractmethod + async def start(self) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass diff --git a/build/lib/exo/orchestration/__init__.py b/build/lib/exo/orchestration/__init__.py new file mode 100644 index 00000000..478af537 --- /dev/null +++ b/build/lib/exo/orchestration/__init__.py @@ -0,0 +1,4 @@ +from .node import Node +from .standard_node import StandardNode + +__all__ = ["Node", "StandardNode"] diff --git a/build/lib/exo/orchestration/node.py b/build/lib/exo/orchestration/node.py new file mode 100644 index 00000000..60b72974 --- /dev/null +++ b/build/lib/exo/orchestration/node.py @@ -0,0 +1,47 @@ +from typing import Optional, Tuple, List +import numpy as np +from abc import ABC, abstractmethod +from exo.helpers import AsyncCallbackSystem +from 
exo.inference.shard import Shard +from exo.topology.topology import Topology + + +class Node(ABC): + @abstractmethod + async def start(self, wait_for_peers: int = 0) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass + + @abstractmethod + async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + pass + + @abstractmethod + async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + pass + + @abstractmethod + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + pass + + @abstractmethod + async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology: + pass + + @property + @abstractmethod + def current_topology(self) -> Topology: + pass + + @property + @abstractmethod + def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]: + pass + + @property + @abstractmethod + def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: + pass diff --git a/build/lib/exo/orchestration/standard_node.py b/build/lib/exo/orchestration/standard_node.py new file mode 100644 index 00000000..b968b659 --- /dev/null +++ b/build/lib/exo/orchestration/standard_node.py @@ -0,0 +1,385 @@ +import numpy as np +import json +import asyncio +import uuid +import time +import traceback +from typing import List, Dict, Optional, Tuple, Union +from exo.networking import Discovery, PeerHandle, Server +from exo.inference.inference_engine import InferenceEngine, Shard +from .node import Node +from exo.topology.topology import Topology +from exo.topology.device_capabilities import device_capabilities +from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards +from exo import DEBUG +from exo.helpers import AsyncCallbackSystem +from exo.viz.topology_viz import TopologyViz +from exo.download.hf.hf_helpers import RepoProgressEvent + + +class StandardNode(Node): + def __init__( + self, + _id: str, + server: Server, + inference_engine: InferenceEngine, + discovery: Discovery, + partitioning_strategy: PartitioningStrategy = None, + max_generate_tokens: int = 1024, + chatgpt_api_endpoints: List[str] = [], + web_chat_urls: List[str] = [], + disable_tui: Optional[bool] = False, + topology_viz: Optional[TopologyViz] = None, + ): + self.id = _id + self.inference_engine = inference_engine + self.server = server + self.discovery = discovery + self.partitioning_strategy = partitioning_strategy + self.peers: List[PeerHandle] = {} + self.topology: Topology = Topology() + self.device_capabilities = device_capabilities() + self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {} + self.max_generate_tokens = max_generate_tokens + self.topology_viz = topology_viz + self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]() + self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]() + self._on_opaque_status.register("node_status").on_next(self.on_node_status) + self.node_download_progress: Dict[str, RepoProgressEvent] = {} + + async def start(self, wait_for_peers: int = 0) -> None: + await self.server.start() + await self.discovery.start() + await self.update_peers(wait_for_peers) + await self.collect_topology() + if DEBUG >= 2: print(f"Collected topology: {self.topology}") + 
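StandardNode exposes its token stream through the same AsyncCallbackSystem pattern it uses for node status above. A sketch of how a consumer might subscribe, assuming node is an already-running StandardNode and that callbacks receive the same (request_id, tokens, is_finished) triple passed to trigger_all elsewhere in this diff; the callback name is arbitrary:

def print_tokens(request_id: str, tokens: list, is_finished: bool):
  print(f"[{request_id}] {len(tokens)} tokens so far, finished={is_finished}")

node.on_token.register("example_token_printer").on_next(print_tokens)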
asyncio.create_task(self.periodic_topology_collection(5)) + + async def stop(self) -> None: + await self.discovery.stop() + await self.server.stop() + + def on_node_status(self, request_id, opaque_status): + try: + status_data = json.loads(opaque_status) + if status_data.get("type", "") == "node_status": + if status_data.get("status", "").startswith("start_"): + self.current_topology.active_node_id = status_data.get("node_id") + elif status_data.get("status", "").startswith("end_"): + if status_data.get("node_id") == self.current_topology.active_node_id: + self.current_topology.active_node_id = None + download_progress = None + if status_data.get("type", "") == "download_progress": + if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}") + download_progress = RepoProgressEvent.from_dict(status_data.get('progress')) + self.node_download_progress[status_data.get('node_id')] = download_progress + if self.topology_viz: + self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress) + except Exception as e: + if DEBUG >= 1: print(f"Error updating visualization: {e}") + if DEBUG >= 1: traceback.print_exc() + + async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + shard = self.get_current_shard(base_shard) + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "image_str": image_str, + "inference_state": inference_state, + "request_id": request_id, + }), + ) + ) + start_time = time.perf_counter_ns() + resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state) + end_time = time.perf_counter_ns() + elapsed_time_ns = end_time - start_time + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "image_str": image_str, + "inference_state": inference_state, + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), + ) + ) + return resp + + async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + if request_id is None: + request_id = str(uuid.uuid4()) + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + shard = self.get_current_shard(base_shard) + + if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}") + if shard.start_layer != 0: + if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}") + await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state) + return + + result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state) + is_finished = is_finished or 
len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if is_finished: + self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) + asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity + + if result.size == 1: + self.buffered_token_output[request_id][0].append(result.item()) + self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished) + + if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") + + if not is_finished: + asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state)) + + return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None + + async def process_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + inference_state: Optional[str] = None, + ) -> Optional[np.ndarray]: + shard = self.get_current_shard(base_shard) + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "tensor_size": tensor.size, + "tensor_shape": tensor.shape, + "request_id": request_id, + "inference_state": inference_state, + }), + ) + ) + start_time = time.perf_counter_ns() + resp = await self._process_tensor(shard, tensor, request_id, inference_state) + end_time = time.perf_counter_ns() + elapsed_time_ns = end_time - start_time + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), + ) + ) + return resp + + async def _process_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + inference_state: Optional[str] = None, + ) -> Optional[np.ndarray]: + if request_id is None: + request_id = str(uuid.uuid4()) + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + shard = self.get_current_shard(base_shard) + + try: + if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}") + result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state) + is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if is_finished: + self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) + asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity + + if result.size == 1: # we got a new token out + self.buffered_token_output[request_id][0].append(result.item()) + self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished) + if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: 
{len(self.buffered_token_output[request_id][0])}") + + if not is_finished: + asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state)) + + return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None + except Exception as e: + print(f"Error processing tensor for shard {shard}: {e}") + traceback.print_exc() + return None + + async def forward_to_next_shard( + self, + base_shard: Shard, + tensor_or_prompt: Union[np.ndarray, str], + request_id: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None, + ) -> None: + if not self.partitioning_strategy: + if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.") + return + shard = self.get_current_shard(base_shard) + + partitions = self.partitioning_strategy.partition(self.topology) + shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id) + current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) + if DEBUG >= 1: print(f"Current partition index: {current_partition_index}") + if current_partition_index is not None: + next_partition_index = (current_partition_index+1) % len(partitions) + next_partition: Partition = partitions[next_partition_index] + next_shard = shards[next_partition_index] + if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}") + + if next_partition.node_id == self.id: + if isinstance(tensor_or_prompt, np.ndarray): + await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state) + else: + await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state) + return + + target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None) + if not target_peer: + raise ValueError(f"Peer for {next_partition} not found") + + if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}") + + if isinstance(tensor_or_prompt, np.ndarray): + await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state) + else: + await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state) + + def get_current_shard(self, base_shard: Shard) -> Shard: + partitions = self.partitioning_strategy.partition(self.topology) + shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id) + current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) + if current_partition_index is None: + raise ValueError(f"No current partition found for node: {self.id}") + return shards[current_partition_index] + + async def update_peers(self, wait_for_peers: int = 0) -> None: + self.peers = await self.discovery.discover_peers(wait_for_peers) + for peer in self.peers: + is_connected = await peer.is_connected() + if DEBUG >= 2 and is_connected: + print(f"Already connected to {peer.id()}: {is_connected}") + if not is_connected: + if DEBUG >= 2: print(f"Connecting to {peer.id()}...") + await peer.connect() + if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})") + + async def periodic_topology_collection(self, interval: int): + while True: + await asyncio.sleep(interval) + try: + await self.update_peers() + await self.collect_topology() 
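Aside: each `start_*`/`end_*` status above is broadcast as opaque JSON, so an external observer can derive per-request latency without touching the inference path, in the same way the stats module later in this diff does. A minimal sketch, assuming only the payload fields shown above; the helper name and registration key are illustrative:

```python
import json

def attach_latency_logger(node):
  # Hypothetical observer: collects per-node processing times from the
  # "node_status" payloads (fields: type, node_id, status, elapsed_time_ns).
  timings = {}  # node_id -> list of elapsed seconds

  def on_status(request_id, opaque_status: str):
    try:
      data = json.loads(opaque_status)
    except ValueError:
      return
    if data.get("type") != "node_status":
      return
    if data.get("status", "").startswith("end_"):
      elapsed_s = data.get("elapsed_time_ns", 0) / 1e9
      timings.setdefault(data.get("node_id"), []).append(elapsed_s)

  node.on_opaque_status.register("latency_logger").on_next(on_status)
  return timings
```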
+ except Exception as e: + print(f"Error collecting topology: {e}") + traceback.print_exc() + + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + if request_id not in self.buffered_token_output: + return None, False + return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1] + + async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology: + next_topology = Topology() + next_topology.update_node(self.id, self.device_capabilities) + + if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}") + + prev_visited = visited.copy() + # TODO: should we add our own peer id here? + visited.update(p.id() for p in self.peers) + + for peer in self.peers: + next_topology.update_node(peer.id(), peer.device_capabilities()) + next_topology.add_edge(self.id, peer.id()) + + if peer.id() in prev_visited: + continue + + if max_depth <= 0: + if DEBUG >= 2: print("Max depth reached. Skipping...") + continue + + try: + other_topology = await peer.collect_topology(visited, max_depth=max_depth - 1) + if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") + self.topology.merge(other_topology) + except Exception as e: + print(f"Error collecting topology from {peer.id()}: {e}") + + next_topology.active_node_id = self.topology.active_node_id # this is not so clean. + self.topology = next_topology + if self.topology_viz: + self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id) + return next_topology + + @property + def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]: + return self._on_token + + @property + def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: + return self._on_opaque_status + + def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None: + if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}") + self.on_token.trigger_all(request_id, tokens, is_finished) + + async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + async def send_result_to_peer(peer): + try: + await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0) + except asyncio.TimeoutError: + print(f"Timeout broadcasting result to {peer.id()}") + except Exception as e: + print(f"Error broadcasting result to {peer.id()}: {e}") + traceback.print_exc() + + await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True) + + async def broadcast_opaque_status(self, request_id: str, status: str) -> None: + if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}") + + async def send_status_to_peer(peer): + try: + await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0) + except asyncio.TimeoutError: + print(f"Timeout sending opaque status to {peer.id()}") + except Exception as e: + print(f"Error sending opaque status to {peer.id()}: {e}") + traceback.print_exc() + + await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True) + # in the case of opaque status, we also want to receive our own opaque statuses + self.on_opaque_status.trigger_all(request_id, status) + + @property + def current_topology(self) -> Topology: + return self.topology diff --git a/build/lib/exo/orchestration/test_node.py 
b/build/lib/exo/orchestration/test_node.py new file mode 100644 index 00000000..230ef0cf --- /dev/null +++ b/build/lib/exo/orchestration/test_node.py @@ -0,0 +1,57 @@ +import unittest +from unittest.mock import Mock, AsyncMock +import numpy as np + +from .standard_node import StandardNode +from exo.networking.peer_handle import PeerHandle + + +class TestNode(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.mock_inference_engine = AsyncMock() + self.mock_server = AsyncMock() + self.mock_server.start = AsyncMock() + self.mock_server.stop = AsyncMock() + self.mock_discovery = AsyncMock() + self.mock_discovery.start = AsyncMock() + self.mock_discovery.stop = AsyncMock() + mock_peer1 = Mock(spec=PeerHandle) + mock_peer1.id.return_value = "peer1" + mock_peer2 = Mock(spec=PeerHandle) + mock_peer2.id.return_value = "peer2" + self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2]) + + self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery) + + async def asyncSetUp(self): + await self.node.start() + + async def asyncTearDown(self): + await self.node.stop() + + async def test_node_initialization(self): + self.assertEqual(self.node.node_id, "test_node") + self.assertEqual(self.node.host, "localhost") + self.assertEqual(self.node.port, 50051) + + async def test_node_start(self): + self.mock_server.start.assert_called_once_with("localhost", 50051) + + async def test_node_stop(self): + await self.node.stop() + self.mock_server.stop.assert_called_once() + + async def test_discover_and_connect_to_peers(self): + await self.node.discover_and_connect_to_peers() + self.assertEqual(len(self.node.peers), 2) + self.assertIn("peer1", map(lambda p: p.id(), self.node.peers)) + self.assertIn("peer2", map(lambda p: p.id(), self.node.peers)) + + async def test_process_tensor_calls_inference_engine(self): + mock_peer = Mock() + self.node.peers = [mock_peer] + + input_tensor = np.array([69, 1, 2]) + await self.node.process_tensor(input_tensor, None) + + self.node.inference_engine.process_shard.assert_called_once_with(input_tensor) diff --git a/build/lib/exo/stats/__init__.py b/build/lib/exo/stats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/stats/metrics.py b/build/lib/exo/stats/metrics.py new file mode 100644 index 00000000..f29533ff --- /dev/null +++ b/build/lib/exo/stats/metrics.py @@ -0,0 +1,29 @@ +from exo.orchestration import Node +from prometheus_client import start_http_server, Counter, Histogram +import json + +# Create metrics to track time spent and requests made. 
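The stats module that continues below turns these same opaque node statuses into Prometheus metrics. A usage sketch, assuming a running orchestration node; the port and helper name are arbitrary:

```python
from exo.orchestration import Node
from exo.stats.metrics import start_metrics_server

def enable_metrics(node: Node, port: int = 9090) -> None:
  # Registers the "stats" opaque-status listener and serves
  # process_prompt_total, process_tensor_total and process_tensor_seconds
  # at http://<host>:<port>/metrics for a Prometheus scraper.
  start_metrics_server(node, port)
```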
+PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"]) +PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"]) +PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"]) + + +def start_metrics_server(node: Node, port: int): + start_http_server(port) + + def _on_opaque_status(request_id, opaque_status: str): + status_data = json.loads(opaque_status) + _type = status_data.get("type", "") + node_id = status_data.get("node_id", "") + if _type != "node_status": + return + status = status_data.get("status", "") + + if status == "end_process_prompt": + PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc() + elif status == "end_process_tensor": + elapsed_time_ns = status_data.get("elapsed_time_ns", 0) + PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc() + PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds + + node.on_opaque_status.register("stats").on_next(_on_opaque_status) diff --git a/build/lib/exo/test_callbacks.py b/build/lib/exo/test_callbacks.py new file mode 100644 index 00000000..c10083d6 --- /dev/null +++ b/build/lib/exo/test_callbacks.py @@ -0,0 +1,50 @@ +import asyncio +from typing import Any, Callable +from exo.helpers import AsyncCallbackSystem, AsyncCallback + + +# Usage example +async def main() -> None: + callback_system = AsyncCallbackSystem[str, Any]() + + # Register callbacks + callback1 = callback_system.register("callback1") + callback2 = callback_system.register("callback2") + + def on_next_callback(name: str) -> Callable[..., None]: + def callback(*args: Any) -> None: + print(f"{name} received values: {args}") + + return callback + + callback1.on_next(on_next_callback("Callback1")) + callback2.on_next(on_next_callback("Callback2")) + + async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None: + try: + result = await callback.wait(condition, timeout=2) + print(f"{name} wait completed with result: {result}") + except asyncio.TimeoutError: + print(f"{name} wait timed out") + + # Trigger all callbacks at once + callback_system.trigger_all("Hello", 42, True) + + # Wait for all callbacks with different conditions + await asyncio.gather( + wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0), + wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True), + ) + + # Trigger individual callback + callback_system.trigger("callback2", "World", -10, False) + + # Demonstrate timeout + new_callback = callback_system.register("new_callback") + new_callback.on_next(on_next_callback("NewCallback")) + await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100) + + callback_system.trigger("callback2", "World", 200, False) + + +asyncio.run(main()) diff --git a/build/lib/exo/topology/__init__.py b/build/lib/exo/topology/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/topology/device_capabilities.py b/build/lib/exo/topology/device_capabilities.py new file mode 100644 index 00000000..51db53ef --- /dev/null +++ b/build/lib/exo/topology/device_capabilities.py @@ -0,0 +1,207 @@ +from exo import DEBUG +from dataclasses import dataclass, asdict +import subprocess +import psutil + +TFLOPS = 1.00 + + +@dataclass +class DeviceFlops: + # units of TFLOPS + fp32: float + fp16: float + int8: float + + def __str__(self): + return f"fp32: 
{self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS" + + def to_dict(self): + return asdict(self) + + +@dataclass +class DeviceCapabilities: + model: str + chip: str + memory: int + flops: DeviceFlops + + def __str__(self): + return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}" + + def __post_init__(self): + if isinstance(self.flops, dict): + self.flops = DeviceFlops(**self.flops) + + def to_dict(self): + return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()} + + +UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) + +CHIP_FLOPS = { + # Source: https://www.cpu-monkey.com + # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative + ### M chips + "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS), + "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS), + "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS), + "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS), + "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS), + "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS), + "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS), + "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS), + "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS), + "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + ### A chips + "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS), + "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS), + "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS), + "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS), + "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS), + ### NVIDIA GPUs + # RTX 40 series + "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS), + "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS), + "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS), + "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS), + "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS), + # RTX 30 series + "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS), + "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS), + "NVIDIA GEFORCE RTX 3060 TI": 
DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS), + "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS), + "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS), + "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS), + "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS), + "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS), + "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS), + "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS), + # RTX 20 series + "NVIDIA GEFORCE RTX 2060": DeviceFlops(fp32=6.45*TFLOPS, fp16=12.9*TFLOPS, int8=25.8*TFLOPS), + "NVIDIA GEFORCE RTX 2060 SUPER": DeviceFlops(fp32=7.2*TFLOPS, fp16=14.4*TFLOPS, int8=28.8*TFLOPS), + "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS), + "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS), + "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS), + "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS), + "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS), + # QUATRO RTX Ampere series + "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS), + "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS), + "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS), + "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS), + "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS), + # Common Server GPUs + "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS), + "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), + "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS), + "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS), + # ... add more devices if needed ... 
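As an aside, the table being built here is a plain dict keyed by the chip string reported by the OS, and callers fall back to an all-zero `DeviceFlops` for unknown chips. A small sketch of that lookup behaviour; the helper function is illustrative:

```python
from exo.topology.device_capabilities import CHIP_FLOPS, DeviceFlops, TFLOPS

def fp16_tflops(chip_name: str) -> float:
  # Unknown chips fall back to zero FLOPS, mirroring device_capabilities().
  flops = CHIP_FLOPS.get(chip_name, DeviceFlops(fp32=0, fp16=0, int8=0))
  return flops.fp16 / TFLOPS

print(fp16_tflops("Apple M3 Max"))       # -> 28.4
print(fp16_tflops("Some Unknown Chip"))  # -> 0.0
```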
+ ### AMD GPUs + # RX 6000 series + "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS), + "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS), + "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS), + "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS), + "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS), + "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS), + "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS), + "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS), + "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS), + # RX 7000 series + "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS), + "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS), + "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS), + "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS), + "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS), + "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS), + ### Qualcomm embedded chips: TODO +} +CHIP_FLOPS.update({f"LAPTOP GPU {key}": value for key, value in CHIP_FLOPS.items()}) +CHIP_FLOPS.update({f"Laptop GPU {key}": value for key, value in CHIP_FLOPS.items()}) +CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items()}) +CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()}) + + +def device_capabilities() -> DeviceCapabilities: + if psutil.MACOS: + return mac_device_capabilities() + elif psutil.LINUX: + return linux_device_capabilities() + else: + return DeviceCapabilities( + model="Unknown Device", + chip="Unknown Chip", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) + + +def mac_device_capabilities() -> DeviceCapabilities: + # Fetch the model of the Mac using system_profiler + model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8") + model_line = next((line for line in model.split("\n") if "Model Name" in line), None) + model_id = model_line.split(": ")[1] if model_line else "Unknown Model" + chip_line = next((line for line in model.split("\n") if "Chip" in line), None) + chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip" + memory_line = next((line for line in model.split("\n") if "Memory" in line), None) + memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory" + memory_units = memory_str.split() + memory_value = int(memory_units[0]) + if memory_units[1] == "GB": + memory = memory_value*1024 + else: + memory = memory_value + + # Assuming static values for other attributes for demonstration + return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))) + + +def linux_device_capabilities() -> DeviceCapabilities: + import psutil + from tinygrad import Device + + if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}") + if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU": + import pynvml + + pynvml.nvmlInit() + handle = 
pynvml.nvmlDeviceGetHandleByIndex(0) + gpu_name = pynvml.nvmlDeviceGetName(handle).upper() + gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}") + + return DeviceCapabilities( + model=f"Linux Box ({gpu_name})", + chip=gpu_name, + memory=gpu_memory_info.total // 2**20, + flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + elif Device.DEFAULT == "AMD": + # TODO AMD support + return DeviceCapabilities( + model="Linux Box (AMD)", + chip="Unknown AMD", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) + else: + return DeviceCapabilities( + model=f"Linux Box (Device: {Device.DEFAULT})", + chip=f"Unknown Chip (Device: {Device.DEFAULT})", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) diff --git a/build/lib/exo/topology/partitioning_strategy.py b/build/lib/exo/topology/partitioning_strategy.py new file mode 100644 index 00000000..29c3dc6a --- /dev/null +++ b/build/lib/exo/topology/partitioning_strategy.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import List +from dataclasses import dataclass +from .topology import Topology +from exo.inference.shard import Shard + + +# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1 +@dataclass +class Partition: + node_id: str + start: float + end: float + + +class PartitioningStrategy(ABC): + @abstractmethod + def partition(self, topology: Topology) -> List[Partition]: + pass + + +def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]: + shards = [] + for i, partition in enumerate(partitions): + start_layer = int(partition.start*num_layers) + end_layer = int(partition.end*num_layers) - 1 + + # Ensure the last partition covers up to num_layers - 1 + if i == len(partitions) - 1: + end_layer = num_layers - 1 + + # Ensure no empty shards + if start_layer <= end_layer: + shards.append(Shard(model_id, start_layer, end_layer, num_layers)) + + # Ensure full coverage + if shards and shards[-1].end_layer < num_layers - 1: + shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers) + + return shards diff --git a/build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py b/build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py new file mode 100644 index 00000000..6550aeb1 --- /dev/null +++ b/build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py @@ -0,0 +1,18 @@ +from typing import List +from .partitioning_strategy import PartitioningStrategy +from .topology import Topology +from .partitioning_strategy import Partition + + +class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy): + def partition(self, topology: Topology) -> List[Partition]: + nodes = list(topology.all_nodes()) + nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True) + total_memory = sum(node[1].memory for node in nodes) + partitions = [] + start = 0 + for node in nodes: + end = round(start + (node[1].memory/total_memory), 5) + partitions.append(Partition(node[0], start, end)) + start = end + return partitions diff --git a/build/lib/exo/topology/test_device_capabilities.py b/build/lib/exo/topology/test_device_capabilities.py new file mode 100644 index 00000000..5f8b4c3a --- /dev/null +++ b/build/lib/exo/topology/test_device_capabilities.py @@ -0,0 +1,91 @@ +import unittest +from 
unittest.mock import patch +from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS + + +class TestMacDeviceCapabilities(unittest.TestCase): + @patch("subprocess.check_output") + def test_mac_device_capabilities_pro(self, mock_check_output): + # Mock the subprocess output + mock_check_output.return_value = b""" +Hardware: + +Hardware Overview: + +Model Name: MacBook Pro +Model Identifier: Mac15,9 +Model Number: Z1CM000EFB/A +Chip: Apple M3 Max +Total Number of Cores: 16 (12 performance and 4 efficiency) +Memory: 128 GB +System Firmware Version: 10000.000.0 +OS Loader Version: 10000.000.0 +Serial Number (system): XXXXXXXXXX +Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX +Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX +Activation Lock Status: Enabled +""" + + # Call the function + result = mac_device_capabilities() + + # Check the results + self.assertIsInstance(result, DeviceCapabilities) + self.assertEqual(result.model, "MacBook Pro") + self.assertEqual(result.chip, "Apple M3 Max") + self.assertEqual(result.memory, 131072) # 16 GB in MB + self.assertEqual( + str(result), + "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS", + ) + + @patch("subprocess.check_output") + def test_mac_device_capabilities_air(self, mock_check_output): + # Mock the subprocess output + mock_check_output.return_value = b""" +Hardware: + +Hardware Overview: + +Model Name: MacBook Air +Model Identifier: Mac14,2 +Model Number: MLY33B/A +Chip: Apple M2 +Total Number of Cores: 8 (4 performance and 4 efficiency) +Memory: 8 GB +System Firmware Version: 10000.00.0 +OS Loader Version: 10000.00.0 +Serial Number (system): XXXXXXXXXX +Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX +Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX +Activation Lock Status: Disabled +""" + + # Call the function + result = mac_device_capabilities() + + # Check the results + self.assertIsInstance(result, DeviceCapabilities) + self.assertEqual(result.model, "MacBook Air") + self.assertEqual(result.chip, "Apple M2") + self.assertEqual(result.memory, 8192) # 8 GB in MB + + @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB") + def test_mac_device_capabilities_real(self): + # Call the function without mocking + result = mac_device_capabilities() + + # Check the results + self.assertIsInstance(result, DeviceCapabilities) + self.assertEqual(result.model, "MacBook Pro") + self.assertEqual(result.chip, "Apple M3 Max") + self.assertEqual(result.memory, 131072) # 128 GB in MB + self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)) + self.assertEqual( + str(result), + "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. 
Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/topology/test_map_partitions.py b/build/lib/exo/topology/test_map_partitions.py new file mode 100644 index 00000000..5254915e --- /dev/null +++ b/build/lib/exo/topology/test_map_partitions.py @@ -0,0 +1,81 @@ +import unittest +from typing import List +from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards +from exo.inference.shard import Shard + + +class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): + def test_map_partitions_to_shards(self): + partitions = [ + Partition("node1", 0.0, 0.42857), + Partition("node2", 0.42857, 0.71428), + Partition("node3", 0.71428, 0.99999), + ] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 12, 32), + Shard("model", 13, 21, 32), + Shard("model", 22, 31, 32), + ], + ) + + partitions = [ + Partition("node1", 0.0, 0.1), + Partition("node2", 0.1, 0.2), + Partition("node3", 0.2, 1.0), + ] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 2, 32), + Shard("model", 3, 5, 32), + Shard("model", 6, 31, 32), + ], + ) + + partitions = [ + Partition("node1", 0.0, 1.0), + ] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 31, 32), + ], + ) + + partitions = [] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual(shards, []) + + def test_broken_map_partitions_to_shards(self): + # this was an old broken implementation that sometimes had rounding errors! + def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str): + shards = [] + for i, partition in enumerate(partitions): + start_layer = int(partition.start*num_layers) + end_layer = int(partition.end*num_layers) - 1 + shards.append(Shard(model_id, start_layer, end_layer, num_layers)) + return shards + + partitions = [ + Partition("node1", 0.0, 0.42857), + Partition("node2", 0.42857, 0.71428), + Partition("node3", 0.71428, 0.99999), + ] + shards = _broken_map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 12, 32), + Shard("model", 13, 21, 32), + Shard("model", 22, 30, 32), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py b/build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py new file mode 100644 index 00000000..fd466f36 --- /dev/null +++ b/build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py @@ -0,0 +1,90 @@ +import unittest +from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy +from exo.topology.topology import Topology +from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops +from exo.topology.partitioning_strategy import Partition + + +class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): + def test_partition(self): + # triangle + # node1 -> node2 -> node3 -> node1 + topology = Topology() + topology.update_node( + "node1", + DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + topology.update_node( + "node2", + DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + topology.update_node( + 
"node3", + DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + topology.add_edge("node1", "node2") + topology.add_edge("node2", "node3") + topology.add_edge("node3", "node1") + topology.add_edge("node1", "node3") + + strategy = RingMemoryWeightedPartitioningStrategy() + partitions = strategy.partition(topology) + + self.assertEqual(len(partitions), 3) + self.assertEqual( + partitions, + [ + Partition("node3", 0.0, 0.6), + Partition("node1", 0.6, 0.9), + Partition("node2", 0.9, 1.0), + ], + ) + + def test_partition_rounding(self): + # triangle + # node1 -> node2 -> node3 -> node1 + topology = Topology() + topology.update_node( + "node1", + DeviceCapabilities( + model="MacBook Pro", + chip="test1", + memory=128*1024*1024*1024, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ), + ) + topology.update_node( + "node2", + DeviceCapabilities( + model="Mac Studio", + chip="test2", + memory=192*1024*1024*1024, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ), + ) + topology.update_node( + "node3", + DeviceCapabilities( + model="MacBook Pro", + chip="test3", + memory=128*1024*1024*1024, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ), + ) + + strategy = RingMemoryWeightedPartitioningStrategy() + partitions = strategy.partition(topology) + + self.assertEqual(len(partitions), 3) + self.assertEqual( + partitions, + [ + Partition("node3", 0.0, 0.42857), + Partition("node1", 0.6, 0.9), + Partition("node2", 0.9, 1.0), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/topology/topology.py b/build/lib/exo/topology/topology.py new file mode 100644 index 00000000..46b512e5 --- /dev/null +++ b/build/lib/exo/topology/topology.py @@ -0,0 +1,49 @@ +from .device_capabilities import DeviceCapabilities +from typing import Dict, Set, Optional + + +class Topology: + def __init__(self): + self.nodes: Dict[str, DeviceCapabilities] = {} # Maps node IDs to DeviceCapabilities + self.peer_graph: Dict[str, Set[str]] = {} # Adjacency list representing the graph + self.active_node_id: Optional[str] = None + + def update_node(self, node_id: str, device_capabilities: DeviceCapabilities): + self.nodes[node_id] = device_capabilities + + def get_node(self, node_id: str) -> DeviceCapabilities: + return self.nodes.get(node_id) + + def all_nodes(self): + return self.nodes.items() + + def add_edge(self, node1_id: str, node2_id: str): + if node1_id not in self.peer_graph: + self.peer_graph[node1_id] = set() + if node2_id not in self.peer_graph: + self.peer_graph[node2_id] = set() + self.peer_graph[node1_id].add(node2_id) + self.peer_graph[node2_id].add(node1_id) + + def get_neighbors(self, node_id: str) -> Set[str]: + return self.peer_graph.get(node_id, set()) + + def all_edges(self): + edges = [] + for node, neighbors in self.peer_graph.items(): + for neighbor in neighbors: + if (neighbor, node) not in edges: # Avoid duplicate edges + edges.append((node, neighbor)) + return edges + + def merge(self, other: "Topology"): + for node_id, capabilities in other.nodes.items(): + self.update_node(node_id, capabilities) + for node_id, neighbors in other.peer_graph.items(): + for neighbor in neighbors: + self.add_edge(node_id, neighbor) + + def __str__(self): + nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items()) + edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items()) + return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})" diff --git a/build/lib/exo/viz/__init__.py 
b/build/lib/exo/viz/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/viz/test_topology_viz.py b/build/lib/exo/viz/test_topology_viz.py new file mode 100644 index 00000000..e57de1ae --- /dev/null +++ b/build/lib/exo/viz/test_topology_viz.py @@ -0,0 +1,129 @@ +import asyncio +import unittest +from datetime import timedelta +from exo.viz.topology_viz import TopologyViz +from exo.topology.topology import Topology +from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops +from exo.topology.partitioning_strategy import Partition +from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent + + +def create_hf_repo_progress_event( + completed_files: int = 5, + total_files: int = 10, + downloaded_bytes: int = 500000000, + downloaded_bytes_this_session: int = 250000000, + total_bytes: int = 1000000000, + overall_speed: int = 5000000, + overall_eta: timedelta = timedelta(seconds=100), + file_progress: dict = None, + status: str = "in_progress" +) -> RepoProgressEvent: + if file_progress is None: + file_progress = { + "file1.bin": + RepoFileProgressEvent( + repo_id="repo_id", + repo_revision="repo_revision", + file_path="file1.bin", + downloaded=100000000, + downloaded_this_session=50000000, + total=200000000, + speed=1000000, + eta=timedelta(seconds=100), + status="in_progress" + ), "file2.bin": + RepoFileProgressEvent( + repo_id="repo_id", + repo_revision="repo_revision", + file_path="file2.bin", + downloaded=200000000, + downloaded_this_session=100000000, + total=200000000, + speed=2000000, + eta=timedelta(seconds=0), + status="complete" + ) + } + + return RepoProgressEvent( + repo_id="repo_id", + repo_revision="repo_revision", + completed_files=completed_files, + total_files=total_files, + downloaded_bytes=downloaded_bytes, + downloaded_bytes_this_session=downloaded_bytes_this_session, + total_bytes=total_bytes, + overall_speed=overall_speed, + overall_eta=overall_eta, + file_progress=file_progress, + status=status + ) + + +class TestNodeViz(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.topology = Topology() + self.topology.update_node( + "node1", + DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)), + ) + self.topology.update_node( + "node2", + DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)), + ) + self.topology.update_node( + "node3", + DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)), + ) + self.topology.update_node( + "node4", + DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)), + ) + + self.top_viz = TopologyViz() + await asyncio.sleep(2) # Simulate running for a short time + + async def test_layout_generation(self): + # self.top_viz._generate_layout() + self.top_viz.refresh() + import time + + time.sleep(2) + self.top_viz.update_visualization( + self.topology, + [ + Partition("node1", 0, 0.2), + Partition("node4", 0.2, 0.4), + Partition("node2", 0.4, 0.8), + Partition("node3", 0.8, 0.9), + ], + "node1", + { + "node1": create_hf_repo_progress_event(), + "node2": create_hf_repo_progress_event(), + "node3": create_hf_repo_progress_event(), + "node4": create_hf_repo_progress_event(), + }, + ) + time.sleep(2) + self.topology.active_node_id = "node3" + self.top_viz.update_visualization( + self.topology, + [ + Partition("node1", 0, 0.3), + 
Partition("node5", 0.3, 0.5), + Partition("node2", 0.5, 0.7), + Partition("node4", 0.7, 0.9), + ], + "node5", + { + "node1": create_hf_repo_progress_event(), + "node5": create_hf_repo_progress_event(), + }, + ) + time.sleep(2) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/viz/topology_viz.py b/build/lib/exo/viz/topology_viz.py new file mode 100644 index 00000000..3664f378 --- /dev/null +++ b/build/lib/exo/viz/topology_viz.py @@ -0,0 +1,307 @@ +import math +from collections import OrderedDict +from typing import List, Optional, Tuple, Dict +from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second +from exo.topology.topology import Topology +from exo.topology.partitioning_strategy import Partition +from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES +from rich.console import Console, Group +from rich.text import Text +from rich.live import Live +from rich.style import Style +from rich.table import Table +from rich.layout import Layout +from rich.syntax import Syntax +from rich.panel import Panel +from rich.markdown import Markdown + + +class TopologyViz: + def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []): + self.chatgpt_api_endpoints = chatgpt_api_endpoints + self.web_chat_urls = web_chat_urls + self.topology = Topology() + self.partitions: List[Partition] = [] + self.node_id = None + self.node_download_progress: Dict[str, RepoProgressEvent] = {} + self.requests: OrderedDict[str, Tuple[str, str]] = {} + + self.console = Console() + self.layout = Layout() + self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25)) + self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow") + self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green") + self.download_panel = Panel("", title="Download Progress", border_style="cyan") + self.layout["main"].update(self.main_panel) + self.layout["prompt_output"].update(self.prompt_output_panel) + self.layout["download"].update(self.download_panel) + + # Initially hide the prompt_output panel + self.layout["prompt_output"].visible = False + self.live_panel = Live(self.layout, auto_refresh=False, console=self.console) + self.live_panel.start() + + def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Dict[str, RepoProgressEvent] = {}): + self.topology = topology + self.partitions = partitions + self.node_id = node_id + if node_download_progress: + self.node_download_progress = node_download_progress + self.refresh() + + def update_prompt(self, request_id: str, prompt: Optional[str] = None): + if request_id in self.requests: + self.requests[request_id] = [prompt, self.requests[request_id][1]] + else: + self.requests[request_id] = [prompt, ""] + self.refresh() + + def update_prompt_output(self, request_id: str, output: Optional[str] = None): + if request_id in self.requests: + self.requests[request_id] = [self.requests[request_id][0], output] + else: + self.requests[request_id] = ["", output] + self.refresh() + + def refresh(self): + self.main_panel.renderable = self._generate_main_layout() + # Update the panel title with the number of nodes and partitions + node_count = len(self.topology.nodes) + self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else 
''})" + + # Update and show/hide prompt and output panel + if any(r[0] or r[1] for r in self.requests.values()): + self.prompt_output_panel = self._generate_prompt_output_layout() + self.layout["prompt_output"].update(self.prompt_output_panel) + self.layout["prompt_output"].visible = True + else: + self.layout["prompt_output"].visible = False + + # Only show download_panel if there are in-progress downloads + if any(progress.status == "in_progress" for progress in self.node_download_progress.values()): + self.download_panel.renderable = self._generate_download_layout() + self.layout["download"].visible = True + else: + self.layout["download"].visible = False + + self.live_panel.update(self.layout, refresh=True) + + def _generate_prompt_output_layout(self) -> Panel: + content = [] + requests = list(self.requests.values())[-3:] # Get the 3 most recent requests + max_width = self.console.width - 6 # Full width minus padding and icon + max_lines = 13 # Maximum number of lines for the entire panel content + + for (prompt, output) in reversed(requests): + prompt_icon, output_icon = "💬️", "🤖" + + # Process prompt + prompt_lines = prompt.split('\n') + if len(prompt_lines) > max_lines // 2: + prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...'] + prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue") + prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white") + + # Process output + output_lines = output.split('\n') + remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing + if len(output_lines) > remaining_lines: + output_lines = output_lines[:remaining_lines - 1] + ['...'] + output_text = Text(f"\n{output_icon} ", style="bold bright_magenta") + output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white") + + content.append(prompt_text) + content.append(output_text) + content.append(Text()) # Empty line between entries + + return Panel( + Group(*content), + title="", + border_style="cyan", + height=15, # Increased height to accommodate multiple lines + expand=True # Allow the panel to expand to full width + ) + + def _generate_main_layout(self) -> str: + # Calculate visualization parameters + num_partitions = len(self.partitions) + radius_x = 30 + radius_y = 12 + center_x, center_y = 50, 24 # Increased center_y to add more space + + # Generate visualization + visualization = [[" " for _ in range(100)] for _ in range(48)] # Increased height to 48 + + # Add exo_text at the top in bright yellow + exo_lines = exo_text.split("\n") + yellow_style = Style(color="bright_yellow") + max_line_length = max(len(line) for line in exo_lines) + for i, line in enumerate(exo_lines): + centered_line = line.center(max_line_length) + start_x = (100-max_line_length) // 2 + 15 + colored_line = Text(centered_line, style=yellow_style) + for j, char in enumerate(str(colored_line)): + if 0 <= start_x + j < 100 and i < len(visualization): + visualization[i][start_x + j] = char + + # Display chatgpt_api_endpoints and web_chat_urls + info_lines = [] + if len(self.web_chat_urls) > 0: + info_lines.append(f"Web Chat URL (tinychat): {' '.join(self.web_chat_urls[:1])}") + if len(self.chatgpt_api_endpoints) > 0: + info_lines.append(f"ChatGPT API endpoint: {' '.join(self.chatgpt_api_endpoints[:1])}") + + info_start_y = len(exo_lines) + 1 + for i, line in enumerate(info_lines): + start_x = (100 - len(line)) // 2 + 15 + for j, char in enumerate(line): + if 0 <= start_x + j < 100 and info_start_y + i < 48: + visualization[info_start_y + i][start_x + j] = 
char + + # Calculate total FLOPS and position on the bar + total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions) + bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2 + + # Add GPU poor/rich bar + bar_width = 30 + bar_start_x = (100-bar_width) // 2 + bar_y = info_start_y + len(info_lines) + 1 + + # Create a gradient bar using emojis + gradient_bar = Text() + emojis = ["🟥", "🟧", "🟨", "🟩"] + for i in range(bar_width): + emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1) + gradient_bar.append(emojis[emoji_index]) + + # Add the gradient bar to the visualization + visualization[bar_y][bar_start_x - 1] = "[" + visualization[bar_y][bar_start_x + bar_width] = "]" + for i, segment in enumerate(str(gradient_bar)): + visualization[bar_y][bar_start_x + i] = segment + + # Add labels + visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor" + visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich" + + # Add position indicator and FLOPS value + pos_x = bar_start_x + int(bar_pos*bar_width) + flops_str = f"{total_flops:.2f} TFLOPS" + visualization[bar_y - 1][pos_x] = "▼" + visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str + visualization[bar_y + 2][pos_x] = "▲" + + # Add an extra empty line for spacing + bar_y += 4 + + for i, partition in enumerate(self.partitions): + device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES) + + angle = 2*math.pi*i/num_partitions + x = int(center_x + radius_x*math.cos(angle)) + y = int(center_y + radius_y*math.sin(angle)) + + # Place node with different color for active node and this node + if partition.node_id == self.topology.active_node_id: + visualization[y][x] = "🔴" + elif partition.node_id == self.node_id: + visualization[y][x] = "🟢" + else: + visualization[y][x] = "🔵" + + # Place node info (model, memory, TFLOPS, partition) on three lines + node_info = [ + f"{device_capabilities.model} {device_capabilities.memory // 1024}GB", + f"{device_capabilities.flops.fp16}TFLOPS", + f"[{partition.start:.2f}-{partition.end:.2f}]", + ] + + # Calculate info position based on angle + info_distance_x = radius_x + 6 + info_distance_y = radius_y + 3 + info_x = int(center_x + info_distance_x*math.cos(angle)) + info_y = int(center_y + info_distance_y*math.sin(angle)) + + # Adjust text position to avoid overwriting the node icon and prevent cutoff + if info_x < x: + info_x = max(0, x - len(max(node_info, key=len)) - 1) + elif info_x > x: + info_x = min(99 - len(max(node_info, key=len)), info_x) + + # Adjust for top and bottom nodes + if 5*math.pi/4 < angle < 7*math.pi/4: + info_x += 4 + elif math.pi/4 < angle < 3*math.pi/4: + info_x += 3 + info_y -= 2 + + for j, line in enumerate(node_info): + for k, char in enumerate(line): + if 0 <= info_y + j < 48 and 0 <= info_x + k < 100: + if info_y + j != y or info_x + k != x: + visualization[info_y + j][info_x + k] = char + + # Draw line to next node + next_i = (i+1) % num_partitions + next_angle = 2*math.pi*next_i/num_partitions + next_x = int(center_x + radius_x*math.cos(next_angle)) + next_y = int(center_y + radius_y*math.sin(next_angle)) + + # Simple line drawing + steps = max(abs(next_x - x), abs(next_y - y)) + for step in range(1, steps): + line_x = int(x + (next_x-x)*step/steps) + line_y = int(y + (next_y-y)*step/steps) + if 0 <= line_y < 48 and 0 <= line_x < 100: + 
visualization[line_y][line_x] = "-" + + # Convert to string + return "\n".join("".join(str(char) for char in row) for row in visualization) + + def _generate_download_layout(self) -> Table: + summary = Table(show_header=False, box=None, padding=(0, 1), expand=True) + summary.add_column("Info", style="cyan", no_wrap=True, ratio=50) + summary.add_column("Progress", style="cyan", no_wrap=True, ratio=40) + summary.add_column("Percentage", style="cyan", no_wrap=True, ratio=10) + + # Current node download progress + if self.node_id in self.node_download_progress: + download_progress = self.node_download_progress[self.node_id] + title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):" + summary.add_row(Text(title, style="bold")) + progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})" + summary.add_row(progress_info) + + eta_info = f"{download_progress.overall_eta}" + summary.add_row(eta_info) + + summary.add_row("") # Empty row for spacing + + for file_path, file_progress in download_progress.file_progress.items(): + if file_progress.status != "complete": + progress = int(file_progress.downloaded/file_progress.total*30) + bar = f"[{'=' * progress}{' ' * (30 - progress)}]" + percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%" + summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage) + + summary.add_row("") # Empty row for spacing + + # Other nodes download progress summary + summary.add_row(Text("Other Nodes Download Progress:", style="bold")) + for node_id, progress in self.node_download_progress.items(): + if node_id != self.node_id: + device = self.topology.nodes.get(node_id) + partition = next((p for p in self.partitions if p.node_id == node_id), None) + partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else "" + percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0 + speed = pretty_print_bytes_per_second(progress.overall_speed) + device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}" + progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})" + progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]" + percentage_str = f"{percentage:.1f}%" + eta_str = f"{progress.overall_eta}" + summary.add_row(device_info, progress_info, percentage_str) + summary.add_row("", progress_bar, eta_str) + + return summary diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 7def1a41..64127406 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -62,6 +62,9 @@ def generate_completion( "finish_reason": finish_reason, }], } + + if DEBUG >= 3: + print(f"completion: {completion}") if not stream: completion["usage"] = { diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index e2eb2434..ab4deef7 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -194,6 +194,8 @@ async def download_file( if progress_callback: await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) return + if DEBUG >= 2: print(f"Range not satisfiable {file_path=} {total_size=} {downloaded_size=}") + 
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) except ValueError: if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) diff --git a/exo/helpers.py b/exo/helpers.py index 1bba9cd9..7477b360 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -34,7 +34,6 @@ def get_system_info(): return "Linux" return "Non-Mac, non-Linux system" - def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int: used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports") diff --git a/exo/inference/torch/.gitignore b/exo/inference/torch/.gitignore new file mode 100644 index 00000000..6d76c24d --- /dev/null +++ b/exo/inference/torch/.gitignore @@ -0,0 +1,2 @@ +data/ +model/archive/ diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md new file mode 100644 index 00000000..9d4e757d --- /dev/null +++ b/exo/inference/torch/README.md @@ -0,0 +1,92 @@ +# PyTorch inference engine + +## Devs +- [Vincent Castro](https://github.com/risingsunomi) + +## Notes/Issues +### 10/10/2024 +- To select a PyTorch device, set the TORCH_DEVICE environment variable + - XLA is currently not installed and will need to be added to inference.py; looking into doing this on a TPU VM + - With PyTorch, CUDA and ROCm use the same device type, so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) + - Looking into adding mobile device support properly +- If the device is not CPU, the data type defaults to float32; otherwise it defaults to float16. + +### 10/13/2024 +Still working on split model development (see test_split_model.py). The split itself appears to work, but transformers still loads more into RAM and GPU memory as larger models are loaded, causing an OOM. Will research and add to the next update. Tests are added and in development. + +### 10/21/2024 +Working on removing transformers due to inference and VRAM usage [issues](https://github.com/exo-explore/exo/pull/139#issuecomment-2424953962). Creating a pure PyTorch implementation of llama3, as using transformers won't work for exo. Using some code from Meta but also implementing the use of torchtune. + +### 10/27/2024 +Still working on the llama3 model, but noting that a better KVCache needs to be investigated. + +### 11/17/2024 +The sharded Llama model is now working; the next step is the inference engine. Still testing on the small Llama 3.2 1B but will try larger models.
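For reference, the device-selection order described in the 10/10 note above, and implemented inline by the engine later in this diff, can be sketched as a small helper (the function name is illustrative):

```python
import os
import torch

def pick_torch_device() -> torch.device:
  # Explicit TORCH_DEVICE wins, then CUDA (which also covers ROCm builds),
  # then Apple MPS, then CPU.
  if os.environ.get("TORCH_DEVICE"):
    return torch.device(os.environ["TORCH_DEVICE"])
  if torch.cuda.is_available():
    return torch.device("cuda")
  if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    return torch.device("mps")
  return torch.device("cpu")
```

Setting, for example, `TORCH_DEVICE=cpu` in the environment therefore forces CPU inference even when an accelerator is present.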
+ +## Tech + +Tested on + +```bash +# Laptop/PC +Distributor ID: Pop +Description: Pop!_OS 22.04 LTS +Release: 22.04 +Codename: jammy +CUDA Version: 12.4 +Nvidia Driver Version: 550.107.02 + +GPU 1: Nvidia GeForce RTX 3060 6GB Laptop +``` +```bash +# Server +Distributor ID: Pop +Description: Pop!_OS 22.04 LTS +Release: 22.04 +Codename: jammy +CUDA Version: 12.4 +Nvidia Driver Version: 550.90.07 + +GPU 1: NVIDIA T1000 8GB +GPU 2: NVIDIA Quadro M2000 4GB +GPU 3: NVIDIA Quadro M2000 4GB +GPU 4: NVIDIA Quadro P400 2GB +GPU 5: NVIDIA Quadro P400 2GB +``` + +## Current Model + +WIP pytorch llama model + +``` +# Llama-3.2-1B-Instruct # + +ShardedLlamaModel( + (model): ShardTransformerDecoder( + (tok_embeddings): Embedding(128256, 2048) + (layers): ModuleList( + (0-15): 16 x TransformerSelfAttentionLayer( + (attn): MultiHeadAttention( + (q_proj): Linear(in_features=2048, out_features=2048, bias=False) + (k_proj): Linear(in_features=2048, out_features=512, bias=False) + (v_proj): Linear(in_features=2048, out_features=512, bias=False) + (output_proj): Linear(in_features=2048, out_features=2048, bias=False) + (pos_embeddings): Llama3ScaledRoPE() + ) + (mlp): MultiLayerPreceptron( + (gate_proj): Linear(in_features=2048, out_features=8192, bias=False) + (up_proj): Linear(in_features=2048, out_features=8192, bias=False) + (down_proj): Linear(in_features=8192, out_features=2048, bias=False) + (act_fn): SiLU() + ) + (sa_norm): RMSNorm() + (mlp_norm): RMSNorm() + (sa_scale): Identity() + (mlp_scale): Identity() + ) + ) + (norm): RMSNorm() + ) +) + +``` diff --git a/exo/inference/torch/__init__.py b/exo/inference/torch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/torch/hf_inference.py b/exo/inference/torch/hf_inference.py new file mode 100644 index 00000000..4912a0a2 --- /dev/null +++ b/exo/inference/torch/hf_inference.py @@ -0,0 +1,377 @@ +""" +HFDynamicShardInferenceEngine +Sharded inference engine using PyTorch based HuggingFace transformers +""" +import asyncio +import os +import json +import functools +from concurrent.futures import ThreadPoolExecutor + +import numpy as np + +import torch + +from typing import Optional, Tuple, Union, List +from exo.inference.shard import Shard +from exo.inference.inference_engine import InferenceEngine +from exo.inference.torch.models.hf import ShardedHuggingFaceModel +from exo.inference.tokenizers import resolve_tokenizer +from exo.helpers import DEBUG +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.hf.hf_helpers import get_weight_map + +from transformers import Cache + +# model value options +TOP_K = 20 +TEMP = 0.6 +TOP_P = 0.9 + +class HFDynamicShardInferenceEngine(InferenceEngine): + def __init__(self, shard_downloader: HFShardDownloader): + """ + Initialize the inference engine. 
+ + Args: + shard_downloader: Model and weights sharding download + """ + self.shard = None + self.shard_downloader = shard_downloader + + # the whole history with new logits need to + # be passed to the model to reach the end token + # even with caching + self.past_input_ids = None + + # setup cuda device + if os.environ.get("TORCH_DEVICE"): + self.device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + torch.set_default_device(self.device) + + # setup cude dtype + self.dtype = torch.get_default_dtype() + + # setup device_map + if os.environ.get("TORCH_DEVICE_MAP"): + self.device_map = os.environ["TORCH_DEVICE_MAP"] + else: + self.device_map = str(self.device) + + def infer_caching( + self, + inference_state: Optional[str] = None + ) -> Tuple[Optional[torch.Tensor], Optional[dict]]: + """ + inference caching from inference_state json + """ + # setup cache and cached input_ids + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + cached_iids = infer_state["cached_iids"] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"cached_iids len: {len(cached_iids)}") + print(f"cached_iids: {cached_iids}") + + return (past_iids, cached_iids) + + async def async_forward( + self, + input_ids: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None + ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: + """ + Asynchronously performs the forward pass using a stateful sharded model. + + Args: + input_ids (torch.Tensor, optional): Input token IDs for the model. If not provided, `hidden_states` must be used. + hidden_states (torch.Tensor, optional): Precomputed hidden states to be used instead of `input_ids`. + attention_mask (torch.Tensor, optional): Mask to prevent attention on padding token indices. + + Returns: + A tuple containing: + + - shard_hidden_states (torch.Tensor, optional): Hidden states resulting from the forward pass. + - shard_past_kvs (list(torch.FloatTensor), optional): List of past key-value tensors (cache) used in the model. + - shard_logits (torch.Tensor, optional): The logits computed during the forward pass. + """ + loop = asyncio.get_running_loop() + + with ThreadPoolExecutor() as pool: + result = await loop.run_in_executor(pool, functools.partial( + self.stateful_sharded_model.forward, + input_ids=input_ids, + hidden_states=hidden_states, + attention_mask=attention_mask + )) + + if DEBUG >=4: + print("async_forward") + print(f"result: {result}") + + return result[0], result[1], result[2] + + async def async_logit_sample( + self, + logits: torch.Tensor + ) -> torch.Tensor: + """ + Asynchronously samples logits using the model's logit sampling method. + + Args: + logits (torch.Tensor): The logits produced by the model for sampling. 
+ + Returns: + next_logit (torch.Tensor): The next logit samples from given logis + """ + loop = asyncio.get_running_loop() + + with ThreadPoolExecutor() as pool: + result = await loop.run_in_executor(pool, functools.partial( + self.stateful_sharded_model.logits_sample, + logits=logits + )) + + return result + + async def infer_prompt( + self, + request_id: str, + shard: Shard, + prompt: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + """ + Asynchronously processes a prompt using the specified shard and returns the inference result. + + Args: + request_id (str): The unique identifier for the request. + shard (Shard): The model shard used for inference. + prompt (str): The text prompt to be processed by the model. + image_str (str, optional): A base64 encoded image string to be optionally used in the inference. Defaults to None. + inference_state (str, optional): The cached inference state for resuming or continuing inference. Defaults to None. + + Returns: + A tuple containing: + + - input_ids (np.ndarray): The processed token IDs as a NumPy array if logits were generated. Otherwise, it returns hidden states. + - cache_json (str): A JSON string containing the cached input IDs for further inference steps. + - is_finished (bool): A boolean indicating whether the model has reached the end-of-sequence (EOS) token. + """ + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + inputs = self.tokenizer([prompt], return_tensors="pt") + input_ids = inputs.input_ids.to(self.device) + input_attention_mask = inputs.attention_mask.to(self.device) + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"past_input_ids: {self.past_input_ids}\n") + + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( + input_ids=self.past_input_ids, + attention_mask=input_attention_mask + ) + + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + next_token = None + if shard_logits is not None: + next_token = await self.async_logit_sample(shard_logits) + self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + input_ids = next_token + + if DEBUG >= 4: + print(f"\nnext_token: {next_token}") + + if self.past_input_ids is not None: + cached_iids = {"input_ids": self.past_input_ids.tolist()} + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + """ + Asynchronously processes input tensor data using the specified shard and returns the inference result. + + Args: + request_id (str): The unique identifier for the request. + shard (Shard): The model shard used for inference. 
+ input_data (np.ndarray): The input data in NumPy array format to be processed by the model. + inference_state (str, optional): The cached inference state for resuming or continuing inference. Defaults to None. + + Returns: + A tuple containing: + + - input_ids (np.ndarray): The processed token IDs as a NumPy array if logits were generated. Otherwise, it returns hidden states. + - cache_json (str): A JSON string containing the cached input IDs for further inference steps. + - is_finished (bool): A boolean indicating whether the model has reached the end-of-sequence (EOS) token. + """ + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + input_ids = torch.tensor(input_data).to(self.device) + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + # detect if hidden_states or not + hidden_states = None + self.past_input_ids = None + if input_ids.size()[-1] > 1: + hidden_states = input_ids + self.past_input_ids = past_iids + else: + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"\npast_input_ids: {self.past_input_ids}") + print(f"\nhidden_state: {hidden_states}") + print(f"\ninference_state: {inference_state}") + + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( + input_ids=self.past_input_ids, + hidden_states=hidden_states + ) + + next_token = None + if shard_logits is not None: + next_token = await self.async_logit_sample(shard_logits) + input_ids = next_token + + #cache + next_cached_logits = None + if next_token is not None: + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + elif past_iids is not None: + next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) + + cached_iids = { + "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else [] + } + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if is_finished: + # clear cache + cached_iids = {"input_ids": []} + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def ensure_shard(self, shard: Shard): + """ + Ensure the model shard is loaded and ready for inference. + + Args: + shard (Optional[Shard]): Shard information for the model. 
+ """ + if self.shard == shard: + return + + if DEBUG >= 4: + print(f"Loading new shard: {shard}") + + model_path = await self.shard_downloader.ensure_shard(shard) + + # get model weight map + model_wm = await get_weight_map(repo_id=shard.model_id) + + self.stateful_sharded_model = ShardedHuggingFaceModel( + shard=shard, + local_model_path=model_path, + weight_map=model_wm, + device=self.device, + dtype=self.dtype, + device_map=self.device_map, + top_k=TOP_K, + temp=TEMP, + top_p=TOP_P + ) + self.shard = shard + + self.tokenizer = await resolve_tokenizer(shard.model_id) + + if DEBUG >= 4: + print(f"Shard loaded successfully: {shard}") diff --git a/exo/inference/torch/models/__init__.py b/exo/inference/torch/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/torch/models/hf.py b/exo/inference/torch/models/hf.py new file mode 100644 index 00000000..5d5b03e4 --- /dev/null +++ b/exo/inference/torch/models/hf.py @@ -0,0 +1,338 @@ +from typing import Tuple, Optional, Union, List +from pathlib import Path + +import torch +import torch.nn as nn + +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from exo.inference.torch.models.hf_safe_tensor_shard import HFSafeTensorShard + +from transformers import ( + AutoModelForCausalLM, + DynamicCache, + Cache, + LogitsProcessorList, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper +) + +# llama +from transformers.models.llama.modeling_llama import LlamaModel + +class ShardedHuggingFaceModel: + def __init__( + self, + shard: Shard, + local_model_path: Path, + weight_map: Optional[dict], + device: torch.device, + dtype: torch.dtype, + device_map: str, + top_k: int = 25, + temp: float = 0.7, + top_p: float = 0.9, + offload_buffers: bool = True + ): + """ + Initializes the ShardedHuggingFaceModel with a specified shard, model path, and device. + + Args: + shard (Shard): The model shard containing the start and end layers. + local_model_path (str): The local path to the model. + device (str): The device on which to run the model, e.g., "cuda" or "cpu". + dtype (torch.dtype): The data type (precision) to be used for model computations. + top_k (int, optional): The number of top tokens to consider for sampling. Defaults to 25. + temp (float, optional): The temperature for softmax sampling. Defaults to 0.7. + top_p (float, optional): The cumulative probability threshold for nucleus sampling. Defaults to 0.9. 
+ """ + + # class vars + self.shard = shard + self.local_model_path = local_model_path + self.weight_map = weight_map + self.device = device + self.dtype = dtype + self.device_map = device_map + self.offload_buffers = offload_buffers + self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json" + self.safetensor_sharder = HFSafeTensorShard( + self.local_model_path, + self.shard + ) + # setup logit processors + self.logits_processor = LogitsProcessorList([ + TopKLogitsWarper(top_k), + TemperatureLogitsWarper(temp), + TopPLogitsWarper(top_p) + ]) + + # setup sharded llm + try: + self.llm_model = self.load_sharded_model() + self.model = self.llm_model.model.to(self.device) + + # restore originals for next run, if one + self.safetensor_sharder.restore_backups() + except Exception as err: + print(f"error loading and sharding model: {err}") + raise + + # forward variables + self.hidden_states = None + self.input_ids = None + self.inputs_embeds = None + self.attention_mask = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.causal_mask = None + + def load_sharded_model(self) -> AutoModelForCausalLM: + """ + Loads sharded version of model where only needed + weights are loaded for necessary layers + + Returns: + llm_model (AutoModelForCausalLM) - sharded llm model with only needed layers loaded + """ + if DEBUG >= 4: + print("load_sharded_model called") + + # modify safetensor + self.safetensor_sharder.modify_safetensor() + self.safetensor_sharder.create_safetensor_index() + self.safetensor_sharder.shard_safetensor_index(self.weight_map) + + # load model + try: + shard_num_hidden_layers = (self.shard.end_layer - self.shard.start_layer) + 1 + if DEBUG >= 4: + print(f"config with {shard_num_hidden_layers} layers") + + llm_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=self.local_model_path, + device_map=self.device_map, + torch_dtype=self.dtype, + offload_buffers=self.offload_buffers, + local_files_only=True, + num_hidden_layers=shard_num_hidden_layers, + use_safetensors=True, + low_cpu_mem_usage=True + ) + + # restore backup for next run + self.safetensor_sharder.restore_backups() + + if self.device_map == "auto": + return llm_model + else: + return llm_model.to(self.device) + + except Exception as err: + print(f"err: {err}") + raise + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_legacy_cache: bool = False + ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: + """ + Performs a forward pass through the model shard, computing hidden states, past key values, and logits. + + Args: + input_ids (torch.Tensor, optional): The input token IDs for the model. Either input_ids or hidden_states must be provided. + hidden_states (torch.Tensor, optional): The hidden states of the model at the current layer. + attention_mask (torch.Tensor, optional): The attention mask to prevent attending to padding tokens. + past_key_values (Union[Cache, List[torch.FloatTensor]], optional): Cached past key values for fast autoregressive generation. + use_legacy_cache (bool, optional): Whether to use the legacy cache format for past key values. Defaults to False. 
+ + Returns: + Tuple: + - hidden_states (torch.Tensor, optional): The hidden states after the forward pass. + - past_key_values (Union[Cache, List[torch.FloatTensor]], optional): The updated past key values. + - logits (torch.Tensor, optional): The logits produced by the model if the last layer is processed. + """ + model_inputs = None + self.hidden_states = hidden_states + self.input_ids = input_ids + + # if there is hidden states and no position_ids, will need to be calculated + # this is not needed for Qwen model but Llama requires it + + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + self.position_ids = cache_position.unsqueeze(0) + + if DEBUG >= 4: + print("hf forward called") + print(f"hidden_states: {self.hidden_states}") + print(f"input_ids: {self.input_ids}") + print(f"input_embeds: {self.inputs_embeds}") + print(f"position_ids: {self.position_ids}") + print(f"past_key_values: {past_key_values}") + + if self.hidden_states is None: + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + False # dont out attentions + ) + + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + self.position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=self.position_ids, + cache_position=cache_position + ) + + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] + + if DEBUG >= 4: + print(f"model_inputs: {model_inputs}") + + # run through decoder layers + layer_amt = range(self.shard.end_layer - self.shard.start_layer) + + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"model layer amt: {len(self.model.layers)}") + print(f"layer_amt: {layer_amt}") + + for i in layer_amt: + decoder_layer = self.model.layers[i] + if DEBUG >= 5: + print(f"layer #{i}") + print("decoder_layer before") + print(f"decoder_layer: {decoder_layer}") + print(f"hidden_states: {self.hidden_states}") + print(f"position_ids: {self.position_ids}") + print(f"position_embeddings: {self.position_embeddings}") + + # TODO: fix caching as decoder layer is not returning + # present_key_value from attention layer on models + # might have some other generation functions needed to do it + # see https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2917 + # for qwen2 exhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L291 + layer_outputs = decoder_layer( + self.hidden_states, + attention_mask=self.causal_mask, + position_ids=self.position_ids, + 
past_key_values=self.past_key_values, + use_cache=True, + cache_position=self.cache_position + ) + + self.hidden_states = layer_outputs[0] + self.next_decoder_cache = layer_outputs[1] + + if DEBUG >= 5: + print("decoder_layer after") + print(f"layer_outputs: {layer_outputs}\n") + print(f"self.next_decoder_cache: {self.next_decoder_cache}") + print(f"hidden_states: {self.hidden_states}") + print(f"next_decoder_cache: {self.next_decoder_cache}") + + # handle last layer to get logits + # shard is last layer says true at the start and not detecting last layer correctly + if self.shard.is_last_layer(): + self.hidden_states = self.model.norm(self.hidden_states) + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = self.llm_model.lm_head(self.hidden_states).to(self.device) + + if DEBUG >= 4: + print(f"logits: {logits}") + + return ( + None, + None, + logits + ) + + if DEBUG >= 4: + print("hf out [no logit]") + print(f"hidden_states: {self.hidden_states}") + print(f"past_key_values: {self.past_key_values}") + print(f"position_ids: {self.position_ids}") + print(f"input_ids: {self.input_ids}") + + return ( + self.hidden_states, + self.past_key_values, + None + ) + + def logits_sample( + self, + logits: torch.Tensor, + use_max: Optional[bool] = False + ) -> torch.Tensor: + """ + Samples the next token from the model's output logits, either by using argmax or probabilistic sampling. + + Args: + logits (torch.Tensor): The logits output from the model's final layer. + use_max (bool, optional): If True, uses torch.argmax to select the next token from logits. Defaults to False. + + Returns: + torch.Tensor: The next predicted token. + """ + + # get a single cloned logit + logits = logits[:, -1, :].clone().float() + + next_token_scores = self.logits_processor(self.input_ids, logits) + + if not use_max: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_scores, dim=-1) + + if DEBUG >= 4: + print(f"input_ids: {self.input_ids}") + print(f"next_token: {next_token}") + + return next_token[:, None].squeeze(-1) diff --git a/exo/inference/torch/models/hf_safe_tensor_shard.py b/exo/inference/torch/models/hf_safe_tensor_shard.py new file mode 100644 index 00000000..c3afdea5 --- /dev/null +++ b/exo/inference/torch/models/hf_safe_tensor_shard.py @@ -0,0 +1,243 @@ +""" +HuggingFace Safetensor Shard +Sharding of safetensors to only use weights of models needed +""" +import os +import shutil +import json + +from typing import Optional +from pathlib import Path + +from safetensors import safe_open +from safetensors.torch import save_file + +import torch + +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from exo.inference.torch.utils import extract_layers + +class HFSafeTensorShard: + def __init__(self, model_path: Path, shard: Shard): + self.model_path = model_path + self.shard = shard + self.safetensors_path = self.get_safetensors() + self.safetensor_index_path = f"{self.model_path}/model.safetensors.index.json" + self.metadata = { + "metadata": { + "total_size": 0 + }, + "weight_map": {} + } + + def get_safetensors(self) -> list: + """ + Gets a list of all files that have the extension .safetensors + + Return: + list: A list of all the safetensors file paths + """ + safetensors_path = [] + try: + for file_name in os.listdir(self.model_path): + if 
file_name.endswith(".safetensors"): + safetensor_path = os.path.join(self.model_path, file_name) + if safetensor_path not in safetensors_path: + safetensors_path.append(safetensor_path) + except Exception as err: + print(f"Error in get_safetensor_path: {err}") + raise + + return safetensors_path + + def backup_safetensor(self): + try: + for safetensor_path in self.safetensors_path: + backup_path = safetensor_path+".backup" + if not os.path.exists(backup_path): + shutil.copy(safetensor_path, backup_path) + + if DEBUG >= 4: + print(f"Backup created at {backup_path}") + except Exception as err: + print(f"Error in backup_safetensor: {err}") + raise + + def modify_safetensor(self): + """ + Extract needed weights for layers from safetensor files + and create a new safetensor with same names + """ + try: + self.backup_safetensor() + safetensor_is_used = False + for safetensor_path in self.safetensors_path: + initial_size = os.path.getsize(safetensor_path) + with safe_open(safetensor_path, framework="pt") as f: + metadata = f.metadata() + new_tensors = {} + + # Iterate over tensors, including only those within the specified layer range + for key in f.keys(): + layer_number = self.extract_layer_number(key) + if self.shard.start_layer <= layer_number <= self.shard.end_layer: + if DEBUG >= 4: + print(f"modify_safetensor [{layer_number}] extracting {key}") + new_tensors[key] = f.get_tensor(key) + safetensor_is_used = True + + # Save the modified safetensor + if safetensor_is_used: + save_file(new_tensors, safetensor_path, metadata) + modified_size = os.path.getsize(safetensor_path) + + if DEBUG >= 4: + print(f"Safetensor modified and saved to {safetensor_path}") + print(f"Initial size: {initial_size / (1024**3):.2f} GB") + print(f"Modified size: {modified_size / (1024**3):.2f} GB") + else: + # remove unused safetensors + os.remove(safetensor_path) + + if DEBUG >= 4: + print(f"Removed safetensor: {safetensor_path}") + except Exception as err: + print(f"Error modifying safetensor: {err}") + raise + + def extract_layer_number(self, key): + """ + Extract the layer number from a tensor key. + This function assumes keys follow the format 'model.layers..'. + """ + try: + parts = key.split(".") + layer_idx = 0 + if parts[0] == "model" and parts[1] == "layers": + layer_idx = int(parts[2]) + return layer_idx + #layer_idx = next(i for i, part in enumerate(parts) if part.startswith("h")) + #return int(parts[layer_idx + 1]) + except (IndexError, ValueError) as err: + print(f"Error extracting layer number from key '{key}': {err}") + return -1 + + def create_safetensor_index(self): + """ + Creates a model.safetensors.index.json file from a list of safetensor files. 
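+    The resulting index follows the layout transformers expects:
+    {"metadata": {"total_size": <bytes>}, "weight_map": {"<tensor name>": "<safetensors file name>"}}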
+ """ + if os.path.exists(self.safetensor_index_path): + backup_index_path = f"{self.model_path}/model.safetensors.index.json.backup" + if not os.path.exists(backup_index_path): + shutil.copy(self.safetensor_index_path, backup_index_path) + + if DEBUG >= 4: + print(f"backed up index json {self.safetensor_index_path}") + + if self.safetensors_path: + # initialize the metadata and weight_map + for safetensor_file in self.safetensors_path: + # use the safetensor file name as the shard_name + shard_name = os.path.basename(safetensor_file) + + # open the safetensor file to read the metadata + with safe_open(safetensor_file, framework="pt", device="cpu") as f: + # get tensor names + tensor_names = f.keys() + + # collect metadata for each tensor + for name in tensor_names: + tensor_data = f.get_tensor(name) + shape = tensor_data.shape + dtype = tensor_data.dtype + + # calculate the tensor size in bytes based on dtype + total_elements = 1 + for dim in shape: + total_elements *= dim + + if dtype == torch.float32: + element_size = 4 + elif dtype == torch.float16 or dtype == torch.bfloat16: + element_size = 2 + # extend this to support more data types if needed + else: + raise ValueError(f"unsupported dtype: {dtype}") + + tensor_size = total_elements * element_size + self.metadata["metadata"]["total_size"] += tensor_size + + # add to weight_map, mapping the tensor to the shard (file) name + self.metadata["weight_map"][name] = shard_name + + # write the metadata and weight map to the index file + with open(self.safetensor_index_path, "w") as f: + json.dump(self.metadata, f, indent=4) + + if DEBUG >= 4: + print(f"created new {self.safetensor_index_path}") + else: + print("No safetensor files provided.") + + def shard_safetensor_index(self, weight_map: Optional[dict] = None): + """ + Modify the weight_map of the safetensors index json to only + get weights for the working layers + + Args: + weight_map(dict, Optional): holds which weight maps to which layer + """ + if weight_map is None: + weight_map = self.metadata["weight_map"] + + layer_weight_map = extract_layers( + weight_map, + self.shard + ) + + # rewrite model.safetensors.index.json for only needed layers + try: + mst_json = {} + with open(self.safetensor_index_path, "r") as mst_file: + mst_json = json.load(mst_file) + mst_json["weight_map"] = layer_weight_map + + if DEBUG >= 4: + print(f"new safetensor index\n{json.dumps(mst_json, indent=4)}\n") + + os.remove(self.safetensor_index_path) + + with open(self.safetensor_index_path, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) + except Exception as err: + print(f"err: {err}") + raise + + def restore_backups(self): + """ + Restore the original safetensor and index json, if any, from the backup file. 
+ """ + try: + for safetensor_path in self.safetensors_path: + backup_path = safetensor_path+".backup" + if os.path.exists(backup_path): + os.remove(safetensor_path) + shutil.copy(backup_path, safetensor_path) + os.remove(backup_path) + + if DEBUG >= 4: + print(f"Safetensor restored from backup at {backup_path}") + + backup_index_path = self.safetensor_index_path+".backup" + if os.path.exists(backup_index_path): + os.remove(self.safetensor_index_path) + shutil.copy(backup_index_path, self.safetensor_index_path) + os.remove(backup_index_path) + + if DEBUG >= 4: + print(f"Safetensor index JSON restored from backup at {backup_index_path}") + except Exception as err: + print(f"Error in restore_backup: {err}") + raise + diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py new file mode 100644 index 00000000..feef0baa --- /dev/null +++ b/exo/inference/torch/models/llama3.py @@ -0,0 +1,395 @@ +""" +llama3 model + +Written with pytorch using torchtune and other methods +""" + +from typing import Optional, Any, Tuple, List, Union, Callable + +import torch +import torch.nn as nn +import torchtune.modules as ttm +import torchtune.generation as ttg +from torchtune.models.llama3_1 import Llama3ScaledRoPE +from torchtune.modules.attention_utils import _MaskType + +from exo.inference.shard import Shard +from exo.inference.torch.models.llm_utils import MultiLayerPreceptron, RMSNorm, get_torch_dtype + + +class ShardTransformerDecoder(ttm.TransformerDecoder): + """ + ShardTransformerDecorder + Custom version of torchtune TransformerDecoder to allow for + sharding of models and passing of hidden layers between shards + """ + + def __init__( + self, + *, + shard: Shard, + tok_embeddings: nn.Embedding, + layers: Union[nn.Module, List[nn.Module], nn.ModuleList], + max_seq_len: int, + num_heads: int, + head_dim: int, + norm: nn.Module, + output: Union[nn.Linear, Callable], + num_layers: Optional[int] = None, + output_hidden_states: Optional[List[int]] = None, + ): + super().__init__( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=norm, + output=output, + num_layers=num_layers, + output_hidden_states=output_hidden_states, + ) + + self.shard = shard + + def setup_caches( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, + ): + """ + modified version for shard + + assume just decoder layers + """ + if decoder_max_seq_len is not None: + self.decoder_max_cache_seq_len = decoder_max_seq_len + else: + self.decoder_max_cache_seq_len = self.max_seq_len + + for layer in self.layers: + if layer is not None: + layer.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=self.encoder_max_cache_seq_len, + decoder_max_seq_len=self.decoder_max_cache_seq_len, + ) + + def caches_are_enabled(self) -> bool: + """ + modified version for shard + """ + if self.layers[0] is not None: + return self.layers[0].caches_are_enabled() + else: + for layer in self.layers: + if layer is not None: + return layer.caches_are_enabled() + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[_MaskType] = None, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + # Determine the type of input and shape + if tokens.ndim == 3: + h = tokens # Use directly as hidden states + else: + h = 
self.tok_embeddings(tokens) # Apply token tok_embeddings + + seq_len = h.shape[1] + + self._validate_inputs( + seq_len, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + # Initialize a list to capture hidden states if requested + # for captured hidden states + hidden = [] + + for i in range(self.shard.start_layer, self.shard.end_layer + 1): + layer = self.layers[i] + + print(f"\nhidden layer in H[{i}]\n{h}\nmask\n{mask}\ninput_pos\n{input_pos}\n{self.output_hidden_states}\n") + + # Process through each transformer layer + with torch.no_grad(): + h = layer( + h, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + if i in self.output_hidden_states: + hidden.append(h) + + print(f"\nhidden layer out H[{i}]->H[{i + 1}]\n{h}\n") + + # Apply normalization + h = self.norm(h) + + # Handle chunked output if needed + if self.num_output_chunks > 0: + output = self.chunked_output(h) + else: + output = self.output(h).float() + + # Return list if hidden states are requested + output = [hidden[-1], output] if hidden else output + print(f"\n\noutput {output}\n\n") + return output + + +def LlamaModel(config: dict, shard: Shard): + """ + LlamaModel using torchtune + """ + # rope scaling config + scale_factor = 32 + if config["rope_scaling"] is not None: + scale_factor = config["rope_scaling"].get("factor", 32) + + rope = Llama3ScaledRoPE( + dim=config["head_dim"], + max_seq_len=config["max_seq_len"], + base=config["rope_base"], + scale_factor=scale_factor, + ) + + # hack to align sharded weights with layers + # fill unused layer positions with None + layers = [None for _ in range(shard.n_layers)] + for i in range(shard.start_layer, shard.end_layer + 1): + self_attn = ttm.MultiHeadAttention( + embed_dim=config["embed_dim"], + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + head_dim=config["head_dim"], + q_proj=nn.Linear( + config["embed_dim"], + config["num_heads"] * config["head_dim"], + bias=config["attn_bias"], + ), + k_proj=nn.Linear( + config["embed_dim"], + config["num_kv_heads"] * config["head_dim"], + bias=config["attn_bias"], + ), + v_proj=nn.Linear( + config["embed_dim"], + config["num_kv_heads"] * config["head_dim"], + bias=config["attn_bias"], + ), + output_proj=nn.Linear( + config["embed_dim"], + config["embed_dim"], + bias=config["attn_bias"], + ), + max_seq_len=config["max_seq_len"], + attn_dropout=config["attn_dropout"], + pos_embeddings=rope, + ) + + mlp = MultiLayerPreceptron(config["embed_dim"], config["intermediate_dim"], config["hidden_act"]) + + layer = ttm.TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + mlp_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + ) + + layers[i] = layer + + #for i in range(len(layers)): + # print(f"layers[{i}]: {layers[i]}") + layers = nn.ModuleList(layers) + tok_embeddings = nn.Embedding(config["vocab_size"], config["embed_dim"]) + output_proj = ttm.TiedLinear(tok_embeddings) + # output_proj = nn.Linear( + # config["embed_dim"], + # config["vocab_size"], + # bias=config["attn_bias"], + # ) + + return ShardTransformerDecoder( + tok_embeddings=tok_embeddings, + shard=shard, + layers=layers, + max_seq_len=config["max_seq_len"], + num_heads=config["num_heads"], + head_dim=config["head_dim"], + norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + output=output_proj, + num_layers=config["num_layers"], + ) + + # return ttm.TransformerDecoder( 
+ # tok_embeddings=tok_embeddings, + # layers=layers, + # max_seq_len=config["max_seq_len"], + # num_heads=config["num_heads"], + # head_dim=config["head_dim"], + # norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + # output=output_proj, + # num_layers=config["num_layers"], + # ) + + +class ShardedLlamaModel(nn.Module): + def __init__( + self, + config: dict, + shard: Shard, + tokenizer: Any, + device: Optional[torch.device] = None, + max_new_tokens: int = 2048, + use_cache: Optional[bool] = False + ): + super(ShardedLlamaModel, self).__init__() + + self.tokenizer = tokenizer + self.shard = shard + self.config = config + self.dtype = get_torch_dtype(self.config["torch_dtype"]) if "torch_dtype" in self.config else torch.float + self.device = device if device is not None else torch.device("cpu") + self.max_new_tokens = max_new_tokens + self.max_seq_len = self.config["max_seq_len"] + + if use_cache: + self.use_cache = use_cache + else: + self.config.get("use_cache", False) + + self.model = LlamaModel(config, self.shard).to(dtype=self.dtype, device=self.device) + + print(f"model loaded: {self.model}\n") + + def generate( + self, + tokens: torch.Tensor, + hidden_state: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], bool]: + """ + Generate logits and/or hidden_states from llama model + + Args + tokens (torch.Tensor) - tokens from prompt tokenization and generation + hidden_state (torch.Tensor, optional) - hidden state from last activated hidden layer, if any + """ + if tokens.ndim == 1: + tokens = tokens.view(1, -1) + + bsz, tokens_length = tokens.size() + + # setup cache + print(self.model) + if not self.model.caches_are_enabled() and self.use_cache: + with self.device: + self.model.setup_caches( + bsz, + self.dtype, + decoder_max_seq_len=tokens.numel() + self.max_new_tokens + ) + + if not self.shard.is_last_layer(): + self.model.output_hidden_states = [self.shard.end_layer] + + total_response_length = tokens_length + self.max_seq_len + resp_max_seq_len = total_response_length if not self.model.caches_are_enabled() else self.model.decoder_max_cache_seq_len + + # clone tokens + generated_tokens = tokens.clone() + + # masking for proper attention + padding_masks = generated_tokens != self.tokenizer.pad_id + if not padding_masks.all(): + padding_masks = torch.nn.functional.pad(padding_masks, (0, self.max_seq_len), value=True) + + masks = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=resp_max_seq_len) + + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + masks = torch.tril( + torch.ones( + total_response_length, + resp_max_seq_len if resp_max_seq_len is not None else total_response_length, + dtype=torch.bool, + device=tokens.device, + ) + ).unsqueeze(0) + + input_pos = torch.arange(0, total_response_length, device=generated_tokens.device).unsqueeze(0) + + if self.model.caches_are_enabled(): + curr_masks = masks[:, :tokens_length] + else: + curr_masks = masks[:, :tokens_length, :tokens_length] + + input_pos = input_pos[:, :tokens_length].squeeze() + + if hidden_state is not None: + model_output = self.model( + tokens=hidden_state, + mask=curr_masks, + input_pos=input_pos, + ) + else: + model_output = self.model( + tokens=tokens, + mask=curr_masks, + input_pos=input_pos, + ) + + print(f"\nmodel_output: {model_output}") + + # stop token + stop_tokens = None + + stop_token_reached = torch.zeros( + bsz, + dtype=torch.bool, + device=tokens.device + ) + stop_tokens = ( + torch.tensor( + stop_tokens, + 
device=tokens.device, + dtype=tokens.dtype + ) + if stop_tokens + else None + ) + + finished = False + + if isinstance(model_output, list): + model_logits = model_output[1] + model_output.pop() # remove logits + model_hs = model_output[0] # get last hidden state + else: + model_logits = model_output + model_hs = None + + if stop_tokens is not None: + stop_token_reached = ttg._generation.update_stop_tokens_tracker( + tokens, stop_tokens, stop_token_reached + ) + + finished = True if stop_token_reached.all() else False + + return model_hs, model_logits, finished diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py new file mode 100644 index 00000000..9edd779a --- /dev/null +++ b/exo/inference/torch/models/llm_utils.py @@ -0,0 +1,239 @@ +""" +Utility methods used by LLMs +""" + +import re +import json +from pathlib import Path +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchtune.modules as ttm +from torchtune.models.convert_weights import hf_to_tune +import math + +from safetensors.torch import load_file as load_safetensors + +from transformers import LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper +from transformers.cache_utils import Cache, DynamicCache + +from exo.helpers import DEBUG +from exo.inference.shard import Shard + + +def get_torch_dtype(dtype_str: str) -> torch.dtype: + """ + Get dtype from setting in model's config.json + """ + if dtype_str == "bfloat16": + return torch.bfloat16 + elif dtype_str == "float16": + return torch.float16 + else: + return torch.float16 + + +def load_model_config(model_config_path: Path) -> dict: + """ + Loads the config.json of the model + + Args: + model_path (Path): local path to model config json + + Returns: + dict: The config as a dictionary + """ + model_config = {} + with open(model_config_path, "r") as f: + base_config = json.load(f) + + model_config = { + "rope_scaling": base_config.get("rope_scaling"), + "embed_dim": base_config["hidden_size"], + "num_heads": base_config["num_attention_heads"], + "head_dim": base_config["hidden_size"] // base_config["num_attention_heads"], # Assuming embed_dim = hidden_size + "num_kv_heads": base_config["num_key_value_heads"], + "max_seq_len": base_config["max_position_embeddings"], + "intermediate_dim": base_config["intermediate_size"], + "attn_dropout": base_config.get("attention_dropout", 0.0), + "norm_eps": base_config["rms_norm_eps"], + "rope_base": base_config["rope_theta"], + "vocab_size": base_config["vocab_size"], + "num_layers": base_config["num_hidden_layers"], + "attn_bias": base_config.get("attention_bias", False), + "hidden_act": base_config.get("hidden_act", "silu") + } + + return model_config + + +def check_weights(model, state_dict): + """ + Verifies that the weights from the state dictionary are properly loaded into the model. 
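+
+  Args:
+    model: The model whose state_dict is compared against the loaded weights.
+    state_dict (dict): The remapped state dict that was loaded into the model.
+
+  Shape mismatches and unexpected keys are printed rather than raised.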
+ """ + model_state_dict = model.state_dict() + for name, param in model_state_dict.items(): + if name in state_dict: + # print(f"\nchecking {name}\n") + loaded_param = state_dict[name] + if param.shape != loaded_param.shape: + print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") + else: + print(f"{name}: loaded correctly") + + for name in state_dict: + if name not in model_state_dict: + print(f"Unexpected weight {name} found in state_dict") + + +def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): + """ + Loads weights from huggingface and changes it to match torchtune naming structure + """ + # Load weights from safetensors files in the cache directory + safetensors_files = list(cache_dir.glob("*.safetensors")) + if not safetensors_files: + raise FileNotFoundError("No safetensors files found in the cache directory.") + + # Load weights from each found safetensors file + paried_lmhead = True + shard_layer_range = list(range(shard.start_layer, shard.end_layer)) + + full_state_dict = None + for safetensor_file in safetensors_files: + state_dict = load_safetensors(safetensor_file) + + if full_state_dict is not None: + full_state_dict = full_state_dict | state_dict + else: + full_state_dict = state_dict + + # remap to work with our model + remapped_state_dict = {} + paried_embed_weight = None + for key, value in full_state_dict.items(): + # load layer by shard + for layer_num in range(shard.start_layer, shard.end_layer + 1): + # change input layer norm to sa_norm for torchtune + re_iln = re.findall(rf"model.layers\.{layer_num}\.(input_layernorm)\.weight", key) + if len(re_iln) != 0: + new_key = f"model.layers.{layer_num}.sa_norm.weight" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + + # change post attention layernorm to mlp_norm for torchtune + re_pal = re.findall(rf"model.layers\.{layer_num}\.(post_attention_layernorm)\.weight", key) + if len(re_pal) != 0: + new_key = f"model.layers.{layer_num}.mlp_norm.weight" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + + # change self_attn to attn + # along with changing o_proj to output_proj + re_attn = re.findall(rf"model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)", key) + if len(re_attn) != 0 and re_attn[0][0] == "self_attn": + if re_attn[0][1] == "o_proj": + new_key = f"model.layers.{layer_num}.attn.output_proj.weight" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + else: + new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + + # set mlp weights + re_mlp = re.findall(rf"model\.layers\.{layer_num}.mlp.(\w+)\.(\w+)", key) + if len(re_mlp) != 0: + new_key = f"model.layers.{layer_num}.mlp.{re_mlp[0][0]}.{re_mlp[0][1]}" + # print(f"load mlp {key}") + remapped_state_dict[new_key] = value + + # saving embed for paired weights + if key == "model.embed_tokens.weight": + # paried_embed_weight = value + # change name for torchtune + # print("model.embed_tokens.weight == model.tok_embeddings.weight") + remapped_state_dict["model.tok_embeddings.weight"] = value + + # elif key == "lm_head.weight": + # paried_lmhead = False + + # get everything else except layers, embed_tokens and lm_head + if len(re.findall(r"model\.layers\..*", key)) == 0 and key != "model.embed_tokens.weight" and key != "lm_head.weight": + # print(f"loading other weight: {key}") + remapped_state_dict[key] = value + + # if paried_lmhead: + # 
print(f"model.output.weight: {paried_embed_weight}") + # remapped_state_dict["model.output.weight"] = paried_embed_weight + + # print("\nRemapped state dict\n") + # for rsdk in remapped_state_dict.keys(): + # print(f"-- {rsdk}") + del state_dict + del full_state_dict + model.load_state_dict(remapped_state_dict, strict=False) + + # if DEBUG >= 7: + # print("\n--- checking weights ----\n") + # print(f"\nremapped_state_dict: {remapped_state_dict.keys()}\n") + check_weights(model, remapped_state_dict) + + +class MultiLayerPreceptron(nn.Module): + def __init__(self, input_dim, hidden_dim, activation="silu", use_bias=False): + """ + General MLP (Multi-Layer Perceptron) module. + + Args: + input_dim (int): Dimensionality of the input. + hidden_dims (int): Hidden layer/intermediate dimensions. + output_dim (int): Dimensionality of the output. + activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). + use_bias (bool): Use bias with linearization + """ + super(MultiLayerPreceptron, self).__init__() + + # Activation function mapping + activations = { + "relu": nn.ReLU(), + "gelu": nn.GELU(), + "tanh": nn.Tanh(), + "sigmoid": nn.Sigmoid(), + "leaky_relu": nn.LeakyReLU(0.2), + "silu": nn.SiLU() + } + + # Ensure valid activation + if activation not in activations: + raise ValueError( + f"Invalid activation: {activation}. Choose from {list(activations.keys())}") + + # Construct MLP layers + self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.down_proj = nn.Linear(hidden_dim, input_dim, bias=use_bias) + self.act_fn = activations[activation] + + def forward(self, x) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + RMSNorm + designed for llama model but used for other models + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/exo/inference/torch/pt_inference.py b/exo/inference/torch/pt_inference.py new file mode 100644 index 00000000..b5a1c8fd --- /dev/null +++ b/exo/inference/torch/pt_inference.py @@ -0,0 +1,130 @@ +""" +TorchDynamicShardInferenceEngine +Sharded inference engine using PyTorch based torchtune models +""" +import os +from typing import Optional, Tuple, Union, List +import functools +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import asyncio +import torch + +from torchtune.models import llama3 + +from exo.inference.inference_engine import InferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, +) + +# supported models +from exo.inference.torch.models.llama3 import ShardedLlamaModel + +TEMP = 0.6 +TOP_K = 25 + +class TorchDynamicShardInferenceEngine(InferenceEngine): + def __init__(self, shard_downloader: HFShardDownloader, model_id: str="llama"): + self.shard = None + self.shard_downloader = shard_downloader + self.model_id = model_id + 
self.supported_models = ["llama"] + + # device settings + if os.environ.get("TORCH_DEVICE"): + self.device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + async def infer_prompt( + self, + request_id: str, + shard: Shard, + prompt: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + # ensure shard + await self.ensure_shard(shard) + + # tokenize + tokens = torch.tensor( + self.tokenizer.encode(prompt, add_bos=True, add_eos=True), + dtype=torch.int + ) + hidden_states = None + + # generate + loop = asyncio.get_running_loop() + with ThreadPoolExecutor() as pool: + hidden_states, logits, finished = await loop.run_in_executor( + pool, + functools.partial( + self.sharded_model.generate, + tokens=tokens + ) + ) + + if hidden_states is not None: + return hidden_states.numpy(force=True), "", finished + else: + return logits.numpy(force=True), "", finished + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + # ensure shard + await self.ensure_shard(shard) + + return np.empty((1,1)), "", False + + async def ensure_shard(self, shard: Shard): + if self.shard == shard: + return + + # download model safetensors and shard + model_path = await self.shard_downloader.ensure_shard(shard) + model_config = load_model_config(model_path / "config.json") + + self.tokenizer = llama3.llama3_tokenizer( + path=f"{model_path}/original/tokenizer.model" + ) + + if self.model_id not in self.supported_models: + raise ValueError( + f"Model {self.model_id} not supported, only supported models are\n{self.supported_models}" + ) + + self.sharded_model = ShardedLlamaModel( + model_config, + shard, + self.tokenizer, + self.device, + None, + use_cache=True + ) + + # load sharded weights + load_model_weights_torchtune( + model_path, + shard, + self.sharded_model + ) diff --git a/exo/inference/torch/tests/__init__.py b/exo/inference/torch/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/torch/tests/test_hf_inference_engine.py b/exo/inference/torch/tests/test_hf_inference_engine.py new file mode 100644 index 00000000..c7230c89 --- /dev/null +++ b/exo/inference/torch/tests/test_hf_inference_engine.py @@ -0,0 +1,141 @@ +""" +Test inference engine and model sharding +""" +import time +import asyncio + +from exo.inference.shard import Shard +from exo.inference.torch.hf_inference import HFDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine + +import numpy as np + +async def test_inference_engine( + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): + + prompt = "In a single word only, what is the last name of the current president of the USA?" 
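+  # first pass: run the prompt on a single engine, then re-run it split across two shard ranges and compare the outputs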
+ + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=0, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) + + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=shard, + input_data=resp_full, + inference_state=inference_state_full, + ) + + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") + + time.sleep(5) + + half_layer = int(n_layers/2) + + resp_shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=half_layer, + n_layers=n_layers + ) + + resp_shard2 = Shard( + model_id=model_id, + start_layer=half_layer+1, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( + "B", + shard=resp_shard, + prompt=prompt + ) + + print("\n------------resp1---------------\n") + print(resp1) + print("\n------------resp1---------------\n") + + time.sleep(5) + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp1, + inference_state=inference_state_1, + ) + + print("\n------------resp2---------------\n") + print(resp2) + print("\n------------resp2---------------\n") + + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=resp_shard, + input_data=resp2, + inference_state=inference_state_2, + ) + + print("\n------------resp3---------------\n") + print(resp3) + print("\n------------resp3---------------\n") + + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp3, + inference_state=inference_state_3, + ) + + print("\n------------resp4---------------\n") + print(resp4) + print("\n------------resp4---------------\n") + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + +if __name__ == '__main__': + try: + print("\n\n -------- TEST Qwen/Qwen2-0.5B-Instruct -------- \n\n") + asyncio.run(test_inference_engine( + HFDynamicShardInferenceEngine(HFShardDownloader()), + HFDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 36 + )) + except Exception as err: + print(f"\n!!!! QWEN2 TEST FAILED \n{err}\n") + + #try: + # print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + # asyncio.run(test_inference_engine( + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # "unsloth/Meta-Llama-3.1-8B-Instruct", + # 32 + # )) + #except Exception as err: + # print(f"\n!!!! 
unsloth/Meta-Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") + + diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py new file mode 100644 index 00000000..7ffb4dce --- /dev/null +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -0,0 +1,129 @@ +""" +Test of pytorch based llama3 models +full layer run +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg +from torchtune.models import llama3 +from torchtune.data import Message + + +from exo.inference.torch.models.llama3 import ShardedLlamaModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, +) + + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TEMP = 0.6 +TOP_K = 25 +MAX_NEW_TOKENS = 10 + +def main(model, prompt: str, device: torch.device=torch.device("cpu")): + # Tokenize input text + messages = [] + messages.extend([ + Message(role="system", content="You are a helpful and creative AI assistant."), + Message(role="user", content=prompt), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ]) + + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + print(f"tokenizer_out: {tokenizer_out}") + tokens = torch.tensor(tokenizer_out["tokens"], dtype=torch.int, device=device) + + _, logits = model.generate(tokens=tokens) + + tokens = ttg.sample(logits=logits[:, -1].clone(), temperature=TEMP, top_k=TOP_K) + + print(f"tokens: {tokens}") + + generated_tokens = tokens.clone().tolist() + print(f"generated_tokens: {generated_tokens}") + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens[0])}\n\n\n") + + +def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu")): + # Tokenize input text + messages = [] + messages.extend([ + Message(role="system", content="You are a helpful and creative AI assistant."), + Message(role="user", content=user_prompt), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ]) + + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + prompt = torch.tensor(tokenizer_out["tokens"], dtype=torch.int, device=device) + print(f"tokens prompt: {prompt}") + print(f"pad_id: {llama_tokenizer.pad_id}") + + + generated_tokens, _ = ttg.generate( + model=model.model, + prompt=prompt, + max_generated_tokens=MAX_NEW_TOKENS, + pad_id=llama_tokenizer.pad_id, + temperature=TEMP, + top_k=TOP_K, + stop_tokens=llama_tokenizer.stop_tokens, + ) + + generated_tokens = generated_tokens[:, -MAX_NEW_TOKENS:].tolist() + + print(f"generated_tokens: {generated_tokens}") + + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens[0])}\n\n\n") + + +if __name__ == "__main__": + # prompt = "hello" + prompt = "What is the capital of france?" 
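+  # The block below resolves the model snapshot from the Hugging Face cache, builds a single
+  # shard covering every layer, loads the torchtune weights, and runs normal_full() to sample
+  # up to MAX_NEW_TOKENS tokens for the prompt above.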
+ + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + print(f"Cache directory: {cache_dir}") + + # Load model configuration + config = load_model_config(cache_dir / "config.json") + + print(f"current config\n{config}") + + # Setup shard + n_layers = int(config["num_layers"]) + shard_1 = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers, + ) + + # Initialize tokenizer + llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" + llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) + + # Initialize LlamaModel with config and tokenizer + # device = torch.device("cuda") + device = None + shard_model_1 = ShardedLlamaModel( + config, + shard_1, + llama_tokenizer, + device, + MAX_NEW_TOKENS, + use_cache=True + ) + print(f"\nshard_model_1: {shard_model_1}") + + load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + + # main(shard_model_1, prompt, device) + normal_full(shard_model_1, prompt, device) diff --git a/exo/inference/torch/tests/test_llama3_split.py b/exo/inference/torch/tests/test_llama3_split.py new file mode 100644 index 00000000..68272765 --- /dev/null +++ b/exo/inference/torch/tests/test_llama3_split.py @@ -0,0 +1,134 @@ +""" +Test of pytorch based llama3 model +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg +from torchtune.models import llama3 +from torchtune.data import Message + + +from exo.inference.torch.models.llama3 import ShardedLlamaModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, +) + + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TEMP = 0.6 +TOP_K = 25 +MAX_NEW_TOKENS=10 + + +def test_generation_1(shard_model, prompt): + """ + Test the generation capabilities of the LlamaModel with sample text. 
+ """ + # Tokenize input text + messages = [] + messages.extend([ + Message(role="system", content="You are a helpful and creative AI assistant."), + Message(role="user", content=prompt), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ]) + + print(f"last?: {shard_model.shard.is_last_layer()}") + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + print(f"tokenizer_out: {tokenizer_out}") + tokens = torch.tensor(tokenizer_out["tokens"], dtype=torch.int) + + hidden_states, _ = shard_model.generate(tokens) + + if hidden_states is not None: + print(f"hidden_states[{len(hidden_states)}]: {hidden_states}") + + return hidden_states, tokens + + +def test_generation_2(shard_model, in_tokens, hidden_state): + print("Generate with the rest of layers") + hidden_states, logits = shard_model.generate(tokens=in_tokens, hidden_state=hidden_state) + + if hidden_states is not None: + print(f"hidden_states {hidden_states.shape}: {hidden_states}") + + if logits is not None: + print(f"logits: {logits.shape}\n{logits}") + + # rand_sample = torch.empty(( + # logits.size(0), + # shard_model.model.tok_embeddings.num_embeddings + # ), + # device=logits.device + # ).exponential_(1, generator=None) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + # q=rand_sample + ) + + print(f"tokens: {tokens}") + + generated_tokens = tokens.clone() + generated_tokens = generated_tokens.tolist() + + print(f"generated_tokens: {generated_tokens}") + + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens[0])}\n\n\n") + + +if __name__ == "__main__": + print("\nTesting generation:") + + prompt = "Hello, just say 'Hello' back nothing else" + + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + + # Load model configuration + config = load_model_config(cache_dir / "config.json") + + # Setup shard + n_layers = int(config["num_layers"]) + s1_end = int(n_layers / 2) + shard_1 = Shard(model_id=MODEL_NAME, start_layer=0, end_layer=s1_end, n_layers=n_layers) + + shard_2 = Shard(model_id=MODEL_NAME, start_layer=s1_end + 1, end_layer=n_layers - 1, n_layers=n_layers) + + # Initialize tokenizer + llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" + llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) + + # Initialize LlamaModel with config and tokenizer + shard_model_1 = ShardedLlamaModel( + config, + shard_1, + llama_tokenizer, + None, + MAX_NEW_TOKENS, + use_cache=True + ) + print(f"\nshard_model_1: {shard_model_1}") + load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + shard_1_hs, shard_1_tokens = test_generation_1(shard_model_1, prompt) + + shard_model_2 = ShardedLlamaModel( + config, + shard_2, + llama_tokenizer, + None, + MAX_NEW_TOKENS, + use_cache=True + ) + print(f"\nshard_model_2: {shard_model_2}") + load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) + test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) diff --git a/exo/inference/torch/tests/test_pt_inference_engine.py b/exo/inference/torch/tests/test_pt_inference_engine.py new file mode 100644 index 00000000..e430989a --- /dev/null +++ b/exo/inference/torch/tests/test_pt_inference_engine.py @@ -0,0 +1,53 @@ +""" +Test inference engine and model sharding +""" +import time +import asyncio + +from exo.inference.shard import Shard +from exo.inference.torch.pt_inference import TorchDynamicShardInferenceEngine +from 
exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine + +import numpy as np + +async def test_inference_engine( + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): + + prompt = "In a single word only, what is the last name of the current president of the USA?" + + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=0, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) + +if __name__ == '__main__': + try: + print("\n\n -------- TEST meta-llama/Llama-3.2-1B-Instruct -------- \n\n") + asyncio.run(test_inference_engine( + TorchDynamicShardInferenceEngine(HFShardDownloader()), + TorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Llama-3.2-1B-Instruct", + 16 + )) + except Exception as err: + print(f"\n!!!! LLAMA TEST FAILED \n{err}\n") + + diff --git a/exo/inference/torch/tests/test_safetensor_json.py b/exo/inference/torch/tests/test_safetensor_json.py new file mode 100644 index 00000000..3ec02c71 --- /dev/null +++ b/exo/inference/torch/tests/test_safetensor_json.py @@ -0,0 +1,120 @@ +""" +Create a model.safetensors.index.json from safetensors +""" +import json +import os + +import asyncio + +from safetensors import safe_open + +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard + +import torch + +def create_safetensor_index(safetensor_files: list, index_file: str): + """ + Creates a model.safetensors.index.json file from a list of safetensor files. + + Args: + safetensor_files (list): List of paths to the safetensor files. + index_file (str): Path where the index JSON file should be saved. + + Raises: + ValueError: If an unsupported data type is encountered. 
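+
+  Example (illustrative file names):
+    create_safetensor_index(["model-00001-of-00002.safetensors"], "model.safetensors.index.json")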
+ """ + if safetensor_files: + # Initialize the metadata and weight_map + metadata = { + "metadata": { + "total_size": 0 + }, + "weight_map": {} + } + + for safetensor_file in safetensor_files: + # Use the safetensor file name as the shard_name + shard_name = os.path.basename(safetensor_file) + + # Open the safetensor file to read the metadata + with safe_open(safetensor_file, framework="pt") as f: + # Get tensor names + tensor_names = f.keys() + + # Collect metadata for each tensor + for name in tensor_names: + tensor_data = f.get_tensor(name) + print(f"tensor_data: {tensor_data}") + shape = tensor_data.shape + dtype = tensor_data.dtype + print(f"shape: {shape}") + print(f"dtype: {str(dtype) == "torch.bfloat16"}") + + # Calculate the tensor size in bytes based on dtype + total_elements = 1 + for dim in shape: + total_elements *= dim + + if dtype == torch.float32: + element_size = 4 + elif dtype == torch.float16 or dtype == torch.bfloat16: + element_size = 2 + # Extend this to support more data types if needed + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + tensor_size = total_elements * element_size + metadata["metadata"]["total_size"] += tensor_size + + # Add to weight_map, mapping the tensor to the shard (file) name + metadata["weight_map"][name] = shard_name + + # Write the metadata and weight map to the index file + with open(index_file, "w") as f: + json.dump(metadata, f, indent=4) + + print(f"Index file created: {index_file}") + else: + print("No safetensor files provided.") + + +async def main(): + """ + Main asynchronous function to download the model shard and create an index file for safetensors. + + This function downloads a model shard from Hugging Face, identifies safetensor files, and + generates a corresponding index file using the `create_safetensor_index` function. 
+ """ + start_layer = 3 + end_layer = 5 + + # Create a Shard object + shard = Shard( + model_id="meta-llama/Llama-3.2-1B-Instruct", + start_layer=start_layer, + end_layer=end_layer-1, + n_layers=32 + ) + + print(f"Loading shard: {shard}") + shard_downloader = HFShardDownloader() + + # Ensure shard is downloaded + model_path = await shard_downloader.ensure_shard(shard) + + # Collect all safetensor files from the model path + safetensor_files = [ + os.path.join(model_path, file_name) + for file_name in os.listdir(model_path) if file_name.endswith(".safetensors") + ] + + # Create the index file + if safetensor_files: + create_safetensor_index(safetensor_files, os.path.join(model_path, "model.safetensors.index.json")) + else: + print("No safetensor files found in the model path.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/exo/inference/torch/tests/test_safetensor_shard.py b/exo/inference/torch/tests/test_safetensor_shard.py new file mode 100644 index 00000000..dd84ff18 --- /dev/null +++ b/exo/inference/torch/tests/test_safetensor_shard.py @@ -0,0 +1,69 @@ +""" +Sharding safetensor +""" + +import asyncio + +from exo.inference.shard import Shard +from exo.inference.torch.models.hf_safe_tensor_shard import HFSafeTensorShard +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.hf.hf_helpers import get_weight_map + +from transformers import AutoModelForCausalLM, AutoTokenizer + +async def main(): + start_layer = 0 + end_layer = 1 + + # Create a Shard object + shard = Shard( + model_id="unsloth/Meta-Llama-3.1-8B-Instruct", + start_layer=start_layer, + end_layer=end_layer-1, + n_layers=32 + ) + + print(f"Loading shard: {shard}") + shard_downloader = HFShardDownloader() + + # Ensure shard is downloaded + model_path = await shard_downloader.ensure_shard(shard) + + # weight map, if any + model_wm = await get_weight_map( + repo_id=shard.model_id + ) + + tensor_shard = HFSafeTensorShard(model_path, shard) + tensor_shard.modify_safetensor() + tensor_shard.create_safetensor_index() + + # load model and test + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=shard.model_id, + local_files_only=True, + num_hidden_layers=shard.end_layer - shard.start_layer, + #device_map="auto", + torch_dtype="float16" + ).to("cuda") + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "In one simple word, what is the color of a red apple?"} + ] + + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = tokenizer([text], return_tensors="pt") + + print(f"model_inputs:\n{model_inputs}") + + tensor_shard.restore_backups() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/exo/inference/torch/tests/test_simple_model.py b/exo/inference/torch/tests/test_simple_model.py new file mode 100644 index 00000000..5ffd30ef --- /dev/null +++ b/exo/inference/torch/tests/test_simple_model.py @@ -0,0 +1,50 @@ +""" +Simple model test using basic pytorch/huggingface LLM model loading, inference and generation +with logit sampling +""" +from transformers import AutoModelForCausalLM, AutoTokenizer + +def run_simple(prompt: str): + model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen2-0.5B-Instruct", + torch_dtype="auto", + device_map="auto" + ) + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + messages = [ + {"role": "system", 
"content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = tokenizer([text], return_tensors="pt") + + print(f"model_inputs:\n{model_inputs}") + + print(f"generation_config:\n{model.generation_config}") + + generated_ids = model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + max_new_tokens=512, + do_sample=True + ) + + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + print(f"Prompt: {prompt}\n") + print(f"Response: {response}\n") + +if __name__ == "__main__": + run_simple( + "In a single word only, what is the last name of the current president of the USA?" + ) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py new file mode 100644 index 00000000..197a7c07 --- /dev/null +++ b/exo/inference/torch/tests/test_split_model.py @@ -0,0 +1,214 @@ +""" +Testing of loading model by layer +""" +import asyncio +import re +import json +import os +from pathlib import Path +from typing import Optional + +import torch + +from exo.download.hf.hf_helpers import get_weight_map +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard +from exo.inference.torch.utils import print_cuda_vram_stats + +from transformers import AutoModelForCausalLM, AutoTokenizer + +def load_model( + shard: Shard, + model_path: Path, + weight_map: Optional[dict], + device: Optional[torch.device] = torch.device("cpu") +) -> Optional[AutoModelForCausalLM]: + """ + load model by layer and safetensors + return causal llm automodel with only requested layers, if weight maps + if no weight map, return and load the whole model + """ + print("load_model called") + model_st_snapshot = model_path/"model.safetensors.index.json" + + if os.environ.get("TORCH_DEVICE"): + device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + + torch.set_default_device(device) + + # setup cude dtype + dtype = torch.get_default_dtype() + + # setup device_map + if os.environ.get("TORCH_DEVICE_MAP"): + device_map = os.environ["TORCH_DEVICE_MAP"] + else: + device_map = str(device) + + if weight_map: + layer_weight_map = {} + non_layer_weights = [] + + for wname, wtensor in weight_map.items(): + # get layer number + layer_rgx = r'^model\.layers\.(\d+)\.*' + layer_found = re.findall(layer_rgx, wname) + print(f"wname: {wname}") + if layer_found: + print(f"layer_found: {layer_found}") + # slice up layer map to start and end layers + # from shard + layer_idx = int(layer_found[0]) + if shard.start_layer <= layer_idx <= shard.end_layer: + layer_weight_map[wname] = wtensor + else: + non_layer_weights.append((wname, wtensor)) + + non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1]) + + print(f"sorted non_layer_weights: {non_layer_weights}") + + if shard.is_first_layer(): + # this assumes at max only one first weight non-layer for model + first_weight = non_layer_weights[0] + layer_weight_map[first_weight[0]] = first_weight[1] + elif shard.is_last_layer(): + last_weights = non_layer_weights[1:] + for last_weight in last_weights: + 
layer_weight_map[last_weight[0]] = last_weight[1] + + # rewrite model.safetensors.index.json + try: + mst_json = {} + with open(model_st_snapshot, "r") as mst_file: + mst_json = json.load(mst_file) + mst_json["weight_map"] = layer_weight_map + + print(f"mst_json: {json.dumps(mst_json, indent=4)}") + + os.remove(model_st_snapshot) + + with open(model_st_snapshot, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) + except Exception as err: + print(f"err: {err}") + raise + + else: + print("weight_map not found, loading whole model") + + # setup the weight range for init_weights + shard_num_hidden_layers = shard.end_layer - shard.start_layer + print(f"Setting up LLM config with {shard_num_hidden_layers} hidden layers") + + # load model with layer edits + # or whole model if no weight_map + print(f"Loading sharded AutoModelForCausalLM from {model_path}") + shard_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=model_path, + device_map=device_map, + torch_dtype=dtype, + offload_buffers=True, + local_files_only=True, + num_hidden_layers=shard_num_hidden_layers + ).to(device) + + print("Loading tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=model_path, + local_files_only=True, + ) + + if torch.cuda.is_available() and device == "cuda": + print_cuda_vram_stats() + + prompt = "In a single word only, what color is a red apple?" + + model_inputs = tokenizer( + [prompt], + return_tensors="pt" + ) + + generated_ids = shard_model.generate( + model_inputs.input_ids.to(device), + attention_mask=model_inputs.attention_mask.to(device), + max_new_tokens=512, + do_sample=True + ) + + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip( + model_inputs.input_ids, + generated_ids + ) + ] + + response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + print(f"Prompt: {prompt}\n") + print(f"Response: {response}\n") + + # have to clear out edited model safetensors mst_json + os.remove(model_st_snapshot) + + return shard_model + +async def test_split_model( + model_id: str, + start_layer: int, + end_layer: int, + n_layers: int +): + """ + Test to load split models + """ + + shard = Shard( + model_id=model_id, + start_layer=start_layer, + end_layer=end_layer-1, + n_layers=n_layers + ) + + print(f"loading shard: {shard}") + shard_downloader = HFShardDownloader() + model_path = await shard_downloader.ensure_shard(shard) + weight_map = await get_weight_map(model_id) + + load_model( + shard, + model_path, + weight_map + ) + +if __name__ == "__main__": + n_layers = int(os.environ["N_LAYERS"]) if os.environ.get("N_LAYERS") else 32 + start_layer = int(os.environ["START_LAYER"]) if os.environ.get("START_LAYER") else 0 + end_layer = int(os.environ["END_LAYER"]) if os.environ.get("END_LAYER") else int(n_layers/2) + #Qwen/Qwen2.5-3B + #try: + # print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n") + # asyncio.run(test_split_model( + # "Qwen/Qwen2.5-3B-Instruct", + # 0, + # 6, + # 36 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! Qwen/Qwen2.5-3B-Instruct TEST FAILED \n{err}\n") + + # unsloth/Meta-Llama-3.1-8B-Instruct + try: + print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + asyncio.run(test_split_model( + "unsloth/Meta-Llama-3.1-8B-Instruct", + start_layer, + end_layer, + n_layers + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! 
unsloth/Meta-Llama-3.1-8B-Instruct TEST FAILED \n{err}\n")
diff --git a/exo/inference/torch/tests/utils.py b/exo/inference/torch/tests/utils.py
new file mode 100644
index 00000000..e4062da9
--- /dev/null
+++ b/exo/inference/torch/tests/utils.py
@@ -0,0 +1,185 @@
+import torch
+from torch.nn import functional as F
+
+def top_k_sampling(logits, thres):
+  num_logits = logits.shape[-1]
+  val, ind = torch.topk(logits, thres, dim=-1, largest=True, sorted=True)
+  mask = torch.zeros_like(logits)
+  mask.scatter_(-1, ind, 1)
+  logits = logits * mask
+
+  return logits
+
+def top_p_sampling(logits, thres):
+  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+  print(f"top_p_sampling sorted_logits\n{sorted_logits}\nsorted_indices {sorted_indices}")
+  softmax_logits = F.softmax(sorted_logits, dim=-1)
+  print(f"top_p_sampling\nsoftmax_logits {softmax_logits}")
+  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+  print(f"top_p_sampling\n{cumulative_probs}")
+
+
+  # Remove tokens with cumulative probability above the threshold
+  sorted_indices_to_remove = cumulative_probs > thres
+
+  # Shift the indices to the right to keep also the first token above the threshold
+  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+  sorted_indices_to_remove[..., 0] = 0
+
+  # scatter sorted tensors to original indexing
+  indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
+  print(f"top_p_sampling\nindices_to_remove: {indices_to_remove}")
+  logits[indices_to_remove] = float('-inf')
+  return logits
+
+def sample_logits(logits, temp, top_p, top_k):
+  """
+  Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling.
+
+  Args:
+    logits (torch.Tensor): The logits distribution to sample from.
+    temp (float): Temperature used to scale the logits.
+    top_p (float): The cumulative probability threshold for nucleus sampling.
+    top_k (int): The number of highest-probability tokens to keep.
+
+  Returns:
+    torch.Tensor: The selected token index.
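+
+  Example (illustrative values):
+    next_token = sample_logits(logits[:, -1, :], temp=0.6, top_p=0.9, top_k=25)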
+ """ + # If temp is very low, just use argmax + if temp == 0: + return logits.argmax(dim=-1) + + print(f"logits {logits}") + + scaled_logits = logits/temp + + print(f"scaled_logits: {scaled_logits}") + + if 0 < top_p < 1.0: + top_p_logits = top_p_sampling(scaled_logits, top_p) + print(f"top_p logits {top_p_logits}") + if top_k > 0: + top_k_logits = top_k_sampling(top_p_logits, top_k) + return top_k_logits.argmax(dim=-1) + elif top_k > 0: + top_k_logits = top_k_sampling(logits, top_k) + print(f"top_k logits {top_k_logits}") + return top_k_logits.argmax(dim=-1) + + return scaled_logits.argmax(dim=-1) + + +# from tinygrad llama model sample +def sample(logits: torch.Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 1, "only works on 1D tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.numel(), "k must be between 0 and numel" + + # If temperature is very low, just use argmax + if temp < 1e-6: + return logits.argmax().reshape(1) + + # Alpha sampling + if af or ap: + if not hasattr(sample, "alpha_counter"): + sample.alpha_counter = torch.zeros_like(logits, dtype=torch.int32).contiguous() + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0).float() * ap) + + # Replace NaNs with -inf + logits = torch.where(logits != logits, torch.tensor(-float("inf"), device=logits.device), logits) + + # Apply softmax after temperature scaling + t = F.softmax(logits / temp, dim=-1) + + counter = torch.arange(t.numel(), device=logits.device).contiguous() + counter2 = torch.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous() + + # Top-k sampling + if k: + output = torch.zeros(k, device=logits.device).contiguous() + output_indices = torch.zeros(k, device=logits.device, dtype=torch.int32).contiguous() + + for i in range(k): + t_max = t.max() + t_argmax = (t.numel() - ((t == t_max) * counter2).max() - 1).to(torch.int) + output[i] = t_max + output_indices[i] = t_argmax + t = torch.where(counter == t_argmax, torch.tensor(0.0, device=logits.device), t) + + # Approximate top-p sampling + output_cumsum = output.flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,)) + t.sum() + mask = output_cumsum >= (1 - p) + output = output * mask.float() + output_indices = output_indices * mask.int() + + # Sample from the distribution + output_idx = output.multinomial(num_samples=1) + output_token = output_indices[output_idx] + else: + output_token = t.multinomial(num_samples=1) + + # Increase alpha counter + if af or ap: + sample.alpha_counter = torch.where(counter == output_token, sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + + +def sample_3d(logits: torch.Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 3, "only works on 3D tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.shape[-1], "k must be between 0 and the last dimension size" + + batch_size, seq_len, vocab_size = logits.shape + + # If temperature is very low, just use argmax + if temp < 1e-6: + return logits.argmax(dim=-1) + + # Alpha sampling + if af or ap: + if not hasattr(sample, "alpha_counter"): + sample.alpha_counter = torch.zeros_like(logits, dtype=torch.int32).contiguous() + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0).float() * ap) + + # Replace NaNs with -inf + logits = torch.where(logits != logits, torch.tensor(-float("inf"), device=logits.device), logits) + + # Apply softmax after temperature scaling + t = F.softmax(logits / temp, dim=-1) + + 
counter = torch.arange(vocab_size, device=logits.device).unsqueeze(0).unsqueeze(0).expand_as(t).contiguous() + counter2 = torch.arange(vocab_size - 1, -1, -1, device=logits.device).unsqueeze(0).unsqueeze(0).expand_as(t).contiguous() + + # Top-k sampling + if k: + output = torch.zeros((batch_size, seq_len, k), device=logits.device).contiguous() + output_indices = torch.zeros((batch_size, seq_len, k), device=logits.device, dtype=torch.int32).contiguous() + + for i in range(k): + t_max, _ = t.max(dim=-1, keepdim=True) + t_argmax = (vocab_size - ((t == t_max) * counter2).max(dim=-1, keepdim=True)[0] - 1).to(torch.int) + output[:, :, i] = t_max.squeeze(-1) + output_indices[:, :, i] = t_argmax.squeeze(-1) + t = torch.where(counter == t_argmax, torch.tensor(0.0, device=logits.device), t) + + # Approximate top-p sampling + output_cumsum = output.flip(dims=(-1,)).cumsum(dim=-1).flip(dims=(-1,)) + t.sum(dim=-1, keepdim=True) + mask = output_cumsum >= (1 - p) + output = output * mask.float() + output_indices = output_indices * mask.int() + + # Sample from the distribution + output_flat = output.view(batch_size * seq_len, -1) + output_idx = output_flat.multinomial(num_samples=1).squeeze(-1) + output_indices_flat = output_indices.view(batch_size * seq_len, -1) + output_token = output_indices_flat.gather(dim=-1, index=output_idx.unsqueeze(-1)).view(batch_size, seq_len) + else: + output_flat = t.view(batch_size * seq_len, -1) + output_token = output_flat.multinomial(num_samples=1).view(batch_size, seq_len) + + # Increase alpha counter + if af or ap: + sample.alpha_counter = torch.where(counter == output_token.unsqueeze(-1), sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + diff --git a/exo/inference/torch/utils.py b/exo/inference/torch/utils.py new file mode 100644 index 00000000..b9c4f148 --- /dev/null +++ b/exo/inference/torch/utils.py @@ -0,0 +1,60 @@ +""" +Utility functions to be used by inference engine +and model +""" +import re + +from exo.inference.shard import Shard + +import torch + +def extract_layers( + weight_map: dict, + shard: Shard +) -> dict: + """ + Extract layers from weight map in range + + Args: + + Returns: + """ + + layer_rgx = r'^model\.layers\.(\d+)\.*' + layer_weight_map = {} + non_layer_weights = [] + + for wname, wtensor in weight_map.items(): + layer_found = re.findall(layer_rgx, wname) + if layer_found: + layer_idx = int(layer_found[0]) + if shard.start_layer <= layer_idx <= shard.end_layer: + layer_weight_map[wname] = wtensor + else: + non_layer_weights.append((wname, wtensor)) + + non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1]) + + if shard.is_first_layer(): + # this assumes at max only one first weight non-layer for model + first_weight = non_layer_weights[0] + layer_weight_map[first_weight[0]] = first_weight[1] + elif shard.is_last_layer(): + last_weights = non_layer_weights[1:] + for last_weight in last_weights: + layer_weight_map[last_weight[0]] = last_weight[1] + + return layer_weight_map + +def print_cuda_vram_stats(): + """ + Prints CUDA VRAM stats being used by pytorch + """ + allocated_memory = torch.cuda.memory_allocated() + max_memory = torch.cuda.max_memory_allocated() + cached_memory = torch.cuda.memory_reserved() + + print("CUDA stats") + print(f'Allocated memory: {allocated_memory / 1024**2} MB') + print(f'Max allocated memory: {max_memory / 1024**2} MB') + print(f'Cached memory: {cached_memory / 1024**2} MB') diff --git a/exo/main.py b/exo/main.py index 928dd4d1..aa3c00d9 100644 --- a/exo/main.py +++ b/exo/main.py 
@@ -111,6 +111,7 @@
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
   discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
+
 node = StandardNode(
   args.node_id,
   None,
diff --git a/exo/models.py b/exo/models.py
index 1fb567a6..b1b39e66 100644
--- a/exo/models.py
+++ b/exo/models.py
@@ -9,6 +9,7 @@
     "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
     "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
   },
+  "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16),
 },
 "llama-3.2-3b": {
   "layers": 28,
@@ -16,6 +17,7 @@
     "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
     "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
   },
+  "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-3B-Instruct", start_layer=0, end_layer=0, n_layers=28),
 },
 "llama-3.1-8b": {
   "layers": 32,
@@ -23,6 +25,7 @@
     "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
     "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
   },
+  "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
 },
 "llama-3.1-70b": {
   "layers": 80,
@@ -30,6 +33,7 @@
     "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
     "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
   },
+  "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
 },
 "llama-3.1-70b-bf16": {
   "layers": 80,
@@ -44,6 +48,7 @@
     "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
     "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
   },
+  "TorchDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32),
 },
 "llama-3-70b": {
   "layers": 80,
diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py
index a788b87a..8799b250 100644
--- a/exo/networking/grpc/grpc_peer_handle.py
+++ b/exo/networking/grpc/grpc_peer_handle.py
@@ -12,7 +12,6 @@
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 
-
 class GRPCPeerHandle(PeerHandle):
   def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
     self._id = _id
@@ -78,6 +77,7 @@ async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str]
       ),
       request_id=request_id,
     )
+
     response = await self.stub.SendPrompt(request)
 
     if not response.tensor_data or not response.shape or not response.dtype:
@@ -96,6 +96,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Option
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
     )
+
     response = await self.stub.SendTensor(request)
 
     if not response.tensor_data or not response.shape or not response.dtype:
diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html
index 44fbb99b..b9d105a3 100644
--- a/exo/tinychat/index.html
+++ b/exo/tinychat/index.html
@@ -18,7 +18,7 @@
-
+
diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py
index 4424f15d..65d6fe9f 100644
--- a/exo/topology/device_capabilities.py
+++ b/exo/topology/device_capabilities.py
@@ -117,6 +117,10 @@ def to_dict(self):
   "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
   "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
   "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS),
+  "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS),
+  "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS),
+  "NVIDIA A10": DeviceFlops(fp32=31.2 * TFLOPS, fp16=62.5 * TFLOPS, int8=2.5 * TFLOPS),
   # ... add more devices if needed ...
   ### AMD GPUs
   # RX 6000 series
diff --git a/install.ps1 b/install.ps1
new file mode 100644
index 00000000..c766cdd5
--- /dev/null
+++ b/install.ps1
@@ -0,0 +1,8 @@
+# Create a virtual environment
+python3 -m venv .venv
+
+# Activate the virtual environment
+& .\.venv\Scripts\Activate.ps1
+
+# Install the package in the virtual environment
+pip install .
diff --git a/setup.py b/setup.py
index e2a855c9..b7ac7f1f 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,8 @@
   "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
+  "torch==2.4.0",
+  "accelerate==0.34.2"
 ]
 
 extras_require = {