diff --git a/.circleci/config.yml b/.circleci/config.yml index d029943a3..b74b72c1b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -17,11 +17,11 @@ commands: source env/bin/activate # Start first instance - HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout-secs 900 > output1.log 2>&1 & + HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log & PID1=$! # Start second instance - HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout-secs 900 > output2.log 2>&1 & + HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --inference-engine <> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log & PID2=$! # Wait for discovery @@ -144,7 +144,7 @@ jobs: PID2=$! sleep 10 kill $PID1 $PID2 - if grep -q "Connected to peer" output1.log && grep -q "Connected to peer" output2.log; then + if grep -q "Successfully connected peers: \['node2@.*:.*'\]" output1.log && ! grep -q "Failed to connect peers:" output1.log && grep -q "Successfully connected peers: \['node1@.*:.*'\]" output2.log && ! grep -q "Failed to connect peers:" output2.log; then echo "Test passed: Both instances discovered each other" exit 0 else diff --git a/README.md b/README.md index a0becd1ff..089fef7ee 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,12 @@ exo logo -exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs_). +exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).

-[Discord](https://discord.gg/EUnjGpsmWw) | [Telegram](https://t.me/+Kh-KqHTzFYg3MGNk) | [X](https://x.com/exolabs_) +[Discord](https://discord.gg/EUnjGpsmWw) | [Telegram](https://t.me/+Kh-KqHTzFYg3MGNk) | [X](https://x.com/exolabs)

@@ -25,14 +25,12 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU: iPhone, iPad, Android, Mac, Linux, pretty much any device!
-

Update: Exo Supports Llama 3.1

-

Run 8B, 70B and 405B parameter Llama 3.1 models on your own devices

-

See the code

+

Update: exo is hiring. See here for more details.

## Get Involved -exo is **experimental** software. Expect bugs early on. Create issues so they can be fixed. The [exo labs](https://x.com/exolabs_) team will strive to resolve issues quickly. +exo is **experimental** software. Expect bugs early on. Create issues so they can be fixed. The [exo labs](https://x.com/exolabs) team will strive to resolve issues quickly. We also welcome contributions from the community. We have a list of bounties in [this sheet](https://docs.google.com/spreadsheets/d/1cTCpTIp48UnnIvHeLEUNg1iMy_Q6lRybgECSFCoVJpE/edit?usp=sharing). @@ -52,7 +50,7 @@ exo will [automatically discover](https://github.com/exo-explore/exo/blob/945f90 ### ChatGPT-compatible API -exo provides a [ChatGPT-compatible API](exo/api/chatgpt_api.py) for running models. It's a [one-line change](examples/chatgpt_api.py) in your application to run models on your own hardware using exo. +exo provides a [ChatGPT-compatible API](exo/api/chatgpt_api.py) for running models. It's a [one-line change](examples/chatgpt_api.sh) in your application to run models on your own hardware using exo. ### Device Equality @@ -108,8 +106,6 @@ python3 main.py That's it! No configuration required - exo will automatically discover the other device(s). -The native way to access models running on exo is using the exo library with peer handles. See how in [this example for Llama 3](examples/llama3_distributed.py). - exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000 For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curls: @@ -150,6 +146,26 @@ curl http://localhost:8000/v1/chat/completions \ }' ``` +### Example Usage on Multiple Heterogenous Devices (MacOS + Linux) + +#### Device 1 (MacOS): + +```sh +python3 main.py --inference-engine tinygrad +``` + +Here we explicitly tell exo to use the **tinygrad** inference engine. + +#### Device 2 (Linux): +```sh +python3 main.py +``` + +Linux devices will automatically default to using the **tinygrad** inference engine. + +You can read about tinygrad-specific env vars [here](https://docs.tinygrad.org/env_vars/). For example, you can configure tinygrad to use the cpu by specifying `CLANG=1`. + + ## Debugging Enable debug logs with the DEBUG environment variable (0-9). @@ -158,6 +174,12 @@ Enable debug logs with the DEBUG environment variable (0-9). DEBUG=9 python3 main.py ``` +For the **tinygrad** inference engine specifically, there is a separate DEBUG flag `TINYGRAD_DEBUG` that can be used to enable debug logs (1-6). + +```sh +TINYGRAD_DEBUG=2 python3 main.py +``` + ## Known Issues - 🚧 As the library is evolving so quickly, the iOS implementation has fallen behind Python. We have decided for now not to put out the buggy iOS version and receive a bunch of GitHub issues for outdated code. We are working on solving this properly and will make an announcement when it's ready. If you would like access to the iOS implementation now, please email alex@exolabs.net with your GitHub username explaining your use-case and you will be granted access on GitHub. diff --git a/examples/chatgpt_api.sh b/examples/chatgpt_api.sh new file mode 100755 index 000000000..2fc0e2bea --- /dev/null +++ b/examples/chatgpt_api.sh @@ -0,0 +1,39 @@ +# exo provides an API that aims to be a drop-in replacements for the ChatGPT-API. +# This example shows how you can use the API first without streaming and second with streaming. +# This works the same in a single-node set up and in a multi-node setup. +# You need to start exo before running this by running `python3 main.py`. + +API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}" +MODEL="llama-3.1-8b" +PROMPT="What is the meaning of exo?" +TEMPERATURE=0.7 + +echo "" +echo "" +echo "--- Output without streaming:" +echo "" +curl "${API_ENDPOINT}/v1/chat/completions" --silent \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'"${MODEL}"'", + "messages": [{"role": "user", "content": "'"${PROMPT}"'"}], + "temperature": '"${TEMPERATURE}"' + }' + +echo "" +echo "" +echo "--- Output with streaming:" +echo "" +curl "${API_ENDPOINT}/v1/chat/completions" --silent \ + -H "Content-Type: application/json" \ + -d '{ + "model": "'"${MODEL}"'", + "messages": [{"role": "user", "content": "'"${PROMPT}"'"}], + "temperature": '"${TEMPERATURE}"', + "stream": true + }' | while read -r line; do + if [[ $line == data:* ]]; then + content=$(echo "$line" | sed 's/^data: //') + echo "$content" | jq -r '.choices[].delta.content' --unbuffered | tr -d '\n' + fi + done \ No newline at end of file diff --git a/examples/llama3_distributed.py b/examples/llama3_distributed.py deleted file mode 100644 index f5d5ad4c7..000000000 --- a/examples/llama3_distributed.py +++ /dev/null @@ -1,81 +0,0 @@ -# In this example, a user is running a home cluster with 3 shards. -# They are prompting the cluster to generate a response to a question. -# The cluster is given the question, and the user is given the response. - -from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer -from exo.inference.shard import Shard -from exo.networking.peer_handle import PeerHandle -from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle -from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops -from typing import List -import asyncio -import argparse -import uuid - -models = { - "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), - "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80) -} - -path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit" -model_path = get_model_path(path_or_hf_repo) -tokenizer_config = {} -tokenizer = load_tokenizer(model_path, tokenizer_config) - -# we intentionally leave out peer1 to demonstrate equality of nodes in exo. -# there is no "master" node in exo, all nodes are equal and can take on any role. -# peer1 = GRPCPeerHandle( -# "node1", -# "localhost:8080", -# DeviceCapabilities(model="placeholder", chip="placeholder", memory=0) -# ) -peer2 = GRPCPeerHandle("node2", "localhost:8081", DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))) -shard = models[path_or_hf_repo] -request_id = str(uuid.uuid4()) - - -async def run_prompt(prompt: str): - if tokenizer.chat_template is None: - tokenizer.chat_template = tokenizer.default_chat_template - if (hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None): - messages = [{"role": "user", "content": prompt}] - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - await peer2.connect() - - try: - await peer2.send_prompt(shard, prompt, request_id) - except Exception as e: - print(e) - - import time - # poll 10 times per second for result (even though generation is faster, any more than this it's not nice for the user) - previous_length = 0 - n_tokens = 0 - start_time = time.perf_counter() - while True: - try: - result, is_finished = await peer2.get_inference_result(request_id) - except Exception as e: - continue - await asyncio.sleep(0.1) - - # Print the updated string in place - updated_string = tokenizer.decode(result) - n_tokens = len(result) - print(updated_string[previous_length:], end='', flush=True) - previous_length = len(updated_string) - - if is_finished: - print("\nDone") - break - end_time = time.perf_counter() - print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run prompt") - parser.add_argument("--prompt", type=str, help="The prompt to run") - args = parser.parse_args() - - asyncio.run(run_prompt(args.prompt)) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 1abda85fe..f8f8b6e46 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -58,6 +58,9 @@ def generate_completion( "finish_reason": finish_reason, }], } + + if DEBUG >= 3: + print(f"completion: {completion}") if not stream: completion["usage"] = { @@ -67,9 +70,16 @@ def generate_completion( } choice = completion["choices"][0] + print(f"\nchoice {choice}") if object_type.startswith("chat.completion"): key_name = "delta" if stream else "message" - choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)} + + token_decode = tokenizer.batch_decode( + tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + ) + choice[key_name] = {"role": "assistant", "content": token_decode} elif object_type == "text_completion": choice["text"] = tokenizer.decode(tokens) else: @@ -113,16 +123,9 @@ def remap_messages(messages: List[Message]) -> List[Message]: def build_prompt(tokenizer, _messages: List[Message]): - if len(_messages) == 1: - user_msg = _messages[0] - - # get instruct sys message - sys_msg = Message(role="system", content="You are a helpful assistant.") - - # restructure for sys_msg to go first - _messages = [sys_msg, user_msg] - messages = remap_messages(_messages) + if DEBUG >= 3: + print(f"messages: {messages}") prompt = tokenizer.apply_chat_template( messages, tokenize=False, @@ -140,7 +143,7 @@ def build_prompt(tokenizer, _messages: List[Message]): continue for content in message.content: - # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 + # note: wae only support one image at time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 # follows the convention in https://platform.openai.com/docs/guides/vision if isinstance(content, dict) and content.get("type", None) == "image": image_str = content.get("image", None) @@ -171,10 +174,10 @@ def __init__(self, request_id: str, timestamp: int, prompt: str): class ChatGPTAPI: - def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None): + def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None): self.node = node self.inference_engine_classname = inference_engine_classname - self.response_timeout_secs = response_timeout_secs + self.response_timeout = response_timeout self.on_chat_completion_request = on_chat_completion_request self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload self.prompts: PrefixDict[str, PromptSession] = PrefixDict() @@ -273,7 +276,7 @@ async def handle_post_chat_completions(self, request): return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) try: - if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s") + if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s") if stream: response = web.StreamResponse( @@ -322,7 +325,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool): return _request_id == request_id and is_finished - _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs) + _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout) if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.") try: @@ -334,7 +337,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool): else: _, tokens, _ = await callback.wait( lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, - timeout=self.response_timeout_secs, + timeout=self.response_timeout, ) finish_reason = "length" diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index 8fd96dc5f..3197605da 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -200,6 +200,36 @@ async def download_file( if DEBUG >= 2: print(f"Downloaded: {file_path}") +async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str: + repo_root = get_repo_root(repo_id) + refs_dir = repo_root/"refs" + refs_file = refs_dir/revision + + # Check if we have a cached commit hash + if await aios.path.exists(refs_file): + async with aiofiles.open(refs_file, 'r') as f: + commit_hash = (await f.read()).strip() + if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}") + return commit_hash + + # Fetch the commit hash for the given revision + async with aiohttp.ClientSession() as session: + api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}" + headers = await get_auth_headers() + async with session.get(api_url, headers=headers) as response: + if response.status != 200: + raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}") + revision_info = await response.json() + commit_hash = revision_info['sha'] + + # Cache the commit hash + await aios.makedirs(refs_dir, exist_ok=True) + async with aiofiles.open(refs_file, 'w') as f: + await f.write(commit_hash) + + return commit_hash + + async def download_repo_files( repo_id: str, revision: str = "main", @@ -209,35 +239,15 @@ async def download_repo_files( max_parallel_downloads: int = 4 ) -> Path: repo_root = get_repo_root(repo_id) - refs_dir = repo_root/"refs" snapshots_dir = repo_root/"snapshots" cachedreqs_dir = repo_root/"cachedreqs" # Ensure directories exist - await aios.makedirs(refs_dir, exist_ok=True) await aios.makedirs(snapshots_dir, exist_ok=True) await aios.makedirs(cachedreqs_dir, exist_ok=True) - # Check if we have a cached commit hash - refs_file = refs_dir/revision - if await aios.path.exists(refs_file): - async with aiofiles.open(refs_file, 'r') as f: - commit_hash = (await f.read()).strip() - if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}") - else: - async with aiohttp.ClientSession() as session: - # Fetch the commit hash for the given revision - api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}" - headers = await get_auth_headers() - async with session.get(api_url, headers=headers) as response: - if response.status != 200: - raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}") - revision_info = await response.json() - commit_hash = revision_info['sha'] - - # Cache the commit hash - async with aiofiles.open(refs_file, 'w') as f: - await f.write(commit_hash) + # Resolve revision to commit hash + commit_hash = await resolve_revision_to_commit_hash(repo_id, revision) # Set up the snapshot directory snapshot_dir = snapshots_dir/commit_hash @@ -357,7 +367,8 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[ # Check if the file exists repo_root = get_repo_root(repo_id) - snapshot_dir = repo_root/"snapshots" + commit_hash = await resolve_revision_to_commit_hash(repo_id, revision) + snapshot_dir = repo_root/"snapshots"/commit_hash index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None) if index_file: @@ -380,24 +391,19 @@ def extract_layer_num(tensor_name: str) -> Optional[int]: def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: - default_patterns = [ - "*.json", - "*.py", - "tokenizer.model", - "*.tiktoken", - "*.txt", - ] - shard_specific_patterns = [] + default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"]) + shard_specific_patterns = set() if weight_map: for tensor_name, filename in weight_map.items(): layer_num = extract_layer_num(tensor_name) if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: - shard_specific_patterns.append(filename) + shard_specific_patterns.add(filename) sorted_file_names = sorted(weight_map.values()) if shard.is_first_layer(): - shard_specific_patterns.append(sorted_file_names[0]) + shard_specific_patterns.add(sorted_file_names[0]) elif shard.is_last_layer(): - shard_specific_patterns.append(sorted_file_names[-1]) + shard_specific_patterns.add(sorted_file_names[-1]) else: shard_specific_patterns = ["*.safetensors"] - return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates + if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") + return list(default_patterns | shard_specific_patterns) diff --git a/exo/helpers.py b/exo/helpers.py index d8a5c6cc2..9ee4b8c18 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -8,6 +8,7 @@ import uuid import netifaces from pathlib import Path +import tempfile DEBUG = int(os.getenv("DEBUG", default="0")) DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0")) @@ -33,7 +34,7 @@ def get_system_info(): return "Non-Mac, non-Linux system" def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int: - used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports") + used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports") def read_used_ports(): if os.path.exists(used_ports_file): diff --git a/exo/inference/mlx/models/qwen2.py b/exo/inference/mlx/models/qwen2.py new file mode 100644 index 000000000..7aed2d045 --- /dev/null +++ b/exo/inference/mlx/models/qwen2.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass, field + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs + +from ...shard import Shard +from .base import IdentityBlock + + +@dataclass +class ModelArgs(ModelArgs): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + super().__post_init__() # Ensure parent initializations are respected + + if isinstance(self.shard, Shard): + return + if not isinstance(self.shard, dict): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + self.shard = Shard(**self.shard) + +class Qwen2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + if self.args.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [] + for i in range(self.num_hidden_layers): + if self.args.shard.start_layer <= i <= self.args.shard.end_layer: + self.layers.append(TransformerBlock(args=args)) + else: + self.layers.append(IdentityBlock()) + if self.args.shard.is_last_layer(): + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + if self.args.shard.is_first_layer(): + h = self.embed_tokens(inputs) + else: + h = inputs + + mask = None + if h.shape[1] > 1: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + if self.args.shard.is_last_layer(): + h = self.norm(h) + return h + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Qwen2Model(args) + if self.args.shard.is_last_layer(): + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.shard.is_last_layer(): + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + + for key, value in weights.items(): + if "self_attn.rotary_emb.inv_freq" in key: + continue + if key.startswith('model.layers.'): + layer_num = int(key.split('.')[2]) + if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: + shard_state_dict[key] = value + elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'): + shard_state_dict[key] = value + elif self.args.shard.is_last_layer() and (key.startswith('model.norm')): + shard_state_dict[key] = value + + if self.args.tie_word_embeddings: + shard_state_dict.pop("lm_head.weight", None) + + return shard_state_dict + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py index 40cabfeb6..7b920ccc4 100644 --- a/exo/inference/mlx/sharded_inference_engine.py +++ b/exo/inference/mlx/sharded_inference_engine.py @@ -6,28 +6,32 @@ from ..shard import Shard from typing import Optional from exo.download.shard_download import ShardDownloader - +import asyncio +from concurrent.futures import ThreadPoolExecutor class MLXDynamicShardInferenceEngine(InferenceEngine): def __init__(self, shard_downloader: ShardDownloader): self.shard = None self.shard_downloader = shard_downloader + self.executor = ThreadPoolExecutor(max_workers=1) async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): await self.ensure_shard(shard) + loop = asyncio.get_running_loop() if image_str: image = await get_image_from_str(image_str) - inputs = self.tokenizer(prompt, image, return_tensors="np") + inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np") pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) - output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values)) + output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values)) else: - output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))) + input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt)) + output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids)) return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): await self.ensure_shard(shard) - output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data))) + output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data))) return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id async def ensure_shard(self, shard: Shard): @@ -35,6 +39,10 @@ async def ensure_shard(self, shard: Shard): return model_path = await self.shard_downloader.ensure_shard(shard) - model_shard, self.tokenizer = await load_shard(model_path, shard) - self.stateful_sharded_model = StatefulShardedModel(shard, model_shard) - self.shard = shard + + if self.shard != shard: + loop = asyncio.get_running_loop() + def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard)) + model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper) + self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard) + self.shard = shard diff --git a/exo/inference/mlx/sharded_model.py b/exo/inference/mlx/sharded_model.py index c4570fbf6..46f06b3db 100644 --- a/exo/inference/mlx/sharded_model.py +++ b/exo/inference/mlx/sharded_model.py @@ -8,7 +8,7 @@ from ..shard import Shard - +# TODO: support a speculative model so we can parallelise compute across devices class StatefulShardedModel: def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2): self.shard = shard diff --git a/exo/inference/pytorch/.gitignore b/exo/inference/pytorch/.gitignore new file mode 100644 index 000000000..6d76c24de --- /dev/null +++ b/exo/inference/pytorch/.gitignore @@ -0,0 +1,2 @@ +data/ +model/archive/ diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 063a9e4a3..f3036e788 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,33 +1,90 @@ # experimental, based off of tinygrad/inference.py import numpy as np import torch -import numpy as np import json +import gc from typing import Optional, Tuple from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG -from transformers import DynamicCache +from transformers import DynamicCache, Cache from accelerate import disk_offload +from exo.download.shard_download import ShardDownloader + +# model value options +TOP_K = 20 +TEMP = 0.6 +TOP_P = 0.9 +MAX_LENGTH = 125 +MAX_TIME = 60.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ - PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. """ - def __init__(self, shard): + def __init__(self, shard_downloader: ShardDownloader): """ Initialize the inference engine. Args: debug (bool): If True, enables debug logging. Defaults to False. """ - self.shard = shard - self.model = None + self.shard = None + self.shard_downloader = shard_downloader + self.stateful_sharded_model = None self.tokenizer = None - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # the whole history with new logits need to + # be passed to the model to reach the end token + # even with caching + self.past_input_ids = None + + # setup cuda device + if torch.cuda.is_available(): + self.device = torch.device("cuda") + self.torch_dtype = torch.float32 + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + self.torch_dtype = torch.float32 + else: + self.device = torch.device("cpu") + self.torch_dtype = torch.float16 + + # setup unfinished sequence + self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) + + def infer_caching( + self, + inference_state: Optional[str] = None + ) -> Tuple[Optional[torch.tensor], Optional[dict]]: + """ + inference caching from inference_state json + """ + # setup cache and cached input_ids + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + cached_iids = infer_state["cached_iids"] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"cached_iids: {cached_iids}") + + return (past_iids, cached_iids) + async def infer_prompt( self, @@ -37,60 +94,78 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") await self.ensure_shard(shard) + + # setup prompt input + messages = [{"role": "user", "content": prompt}] + txt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + inputs = self.tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to(self.device) + input_attention_mask = inputs.attention_mask.to(self.device) + batch_size, seq_length = input_ids.shape[:2] + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) - # need to make this so inference_state is not a string - # cant use it with dynamic cache - - tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) - tokens = self.model.embed_tokens(tokens) - current_kvs = None + if past_iids is not None: + self.past_input_ids = past_iids, + else: + self.past_input_ids = input_ids if DEBUG >= 4: - print("infer_prompt called") - print(f"tokens: {tokens}\n") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") + print(f"past_input_ids: {self.past_input_ids}\n") - # convert inference_state or cache from json to DynamicCache - past_kv = DynamicCache() - if inference_state != None: - cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] - - output_data, current_kvs = self.model.forward( - tokens, - past_kv + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=self.past_input_ids, + attention_mask=input_attention_mask ) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + next_token = None + if shard_logits is not None: + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + input_ids = next_token + + if self.past_input_ids is not None: + cached_iids = {"input_ids": self.past_input_ids.tolist()} + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id if DEBUG >= 4: - print(f"output_data: {output_data}\n") - print(f"output_data.size {output_data.size}\n") - - print(f"finished: {is_finished}") - print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") - print(f"output_data[-1] {output_data[-1]}") - - if output_data.size == 1: - print(f"size 1 output_data.item() {output_data.item()}") - print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") - - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] - } - - return ( - output_data, - json.dumps(cache_dict), + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), is_finished ) + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + async def infer_tensor( self, request_id: str, @@ -98,79 +173,83 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") await self.ensure_shard(shard) - current_kvs = None + input_ids = torch.tensor(input_data).to(self.device) - if input_data.size == 1: - in_tensor = torch.from_numpy( - input_data, - ).unsqueeze(0).long().to(self.device) + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + # detect if hidden_states or not + hidden_states = None + self.past_input_ids = None + if input_ids.size()[-1] > 1: + hidden_states = input_ids else: - in_tensor = torch.from_numpy( - input_data - ).long().to(self.device) + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"input_ids: {input_ids}") + print(f"inference_state: {inference_state}") - in_tensor = self.model.embed_tokens(in_tensor) + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=self.past_input_ids, + hidden_states=hidden_states + ) - if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}") - print(f"input_data.size: {input_data.size}") - print(f"input_tensor: {in_tensor}\n") - print(f"shard: {self.shard}") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") + hidden_dict = None + if shard_hidden_states is not None: + hidden_dict = {"hidden_states": shard_hidden_states.tolist()} + + next_token = None + if shard_logits is not None: + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + input_ids = next_token - # convert inference_state or cache from json to DynamicCache - past_kv = DynamicCache() - if inference_state != None: - try: - cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] + #cache + if next_token is not None: + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + elif past_iids is not None: + next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) + + cached_iids = {"input_ids": next_cached_logits.tolist()} - if DEBUG >= 4: - print("Loaded past_kv from JSON") - print(f"past_kv: {past_kv}") - print(f"past_kv.key_cache len: {len(past_kv.key_cache)}") - print(f"past_kv.value_cache len: {len(past_kv.value_cache)}") - except json.JSONDecodeError: - print(f"ERROR DECODING INFERENCE STATE") - - output_data, current_kvs = self.model.forward( - in_tensor, - past_kv - ) + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + if is_finished: + # clear cache + cached_iids = {"input_ids": []} if DEBUG >= 4: - print(f"in_tensor: {in_tensor}\n") - print(f"output_data: {output_data}\n") - print(f"output_data.size {output_data.size}\n") - print(f"finished: {is_finished}") - print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") - print(f"output_data[-1] {output_data[-1]}") + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") - if output_data.size == 1: - print(f"size 1 output_data.item() {output_data.item()}") - print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") - - - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] - } - - return ( - output_data, - json.dumps(cache_dict), + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), is_finished ) + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def ensure_shard(self, shard: Optional[Shard]): """ Ensure the model shard is loaded and ready for inference. @@ -184,9 +263,25 @@ async def ensure_shard(self, shard: Optional[Shard]): if DEBUG >= 4: print(f"Loading new shard: {shard}") - self.shard = shard + # -- TO DO -- + # Build in shard downloader but requires pulling + # apart how TrainedModel loads weight in its __init__ + # function in the transformer library + # model_path = await self.shard_downloader.ensure_shard(shard) + self.tokenizer = await resolve_tokenizer(shard.model_id) - self.model = ShardedHuggingFaceModel(shard, self.tokenizer) + self.stateful_sharded_model = ShardedHuggingFaceModel( + shard=shard, + device=self.device, + dtype=self.torch_dtype, + top_k=TOP_K, + temp=TEMP, + top_p=TOP_P, + max_length=MAX_LENGTH, + max_time=MAX_TIME + ) + + self.shard = shard if DEBUG >= 4: - print(f"Shard loaded successfully: {shard}") \ No newline at end of file + print(f"Shard loaded successfully: {shard}") diff --git a/exo/inference/pytorch/model/archive/hf_manual.py b/exo/inference/pytorch/model/archive/hf_manual.py new file mode 100644 index 000000000..e5af2eaf8 --- /dev/null +++ b/exo/inference/pytorch/model/archive/hf_manual.py @@ -0,0 +1,203 @@ +# Attempted version to recreate manually using LlamaModel and others +# BROKEN +import torch +import numpy as np +from transformers import AutoModelForCausalLM, DynamicCache, Cache, AutoModel +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from typing import Tuple, Optional, Union, List +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from exo.inference.pytorch.model.archive.utils import sample_logits + +TOP_P = 0.7 #0.95 +TOP_K = 50 +TEMP = 0.01 + + +class ShardedHuggingFaceModel(torch.nn.Module): + def __init__(self, shard: Shard): + super(ShardedHuggingFaceModel, self).__init__() + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.shard = shard + + # Load the model + try: + self.base_model = AutoModel.from_pretrained( + shard.model_id, + torch_dtype=torch.float32, + device_map="auto", + # offload_buffers=True + ) + + # disk_offload(model=self.base_model, offload_dir="./.offload") + except Exception as err: + print(f"Error loading model: {err}") + raise + + if DEBUG >= 2: + print(f"\nShardedHuggingFaceModel init with shard {shard}") + print(f"self.base_model: {self.base_model}") + + # Embeddings and final layer norm + # used for doing what forward LlamaModel does in transformers + self.norm = self.base_model.norm + self.lm_head = torch.nn.Linear( + self.base_model.config.hidden_size, + self.base_model.config.vocab_size, + bias=False + ).to(self.device) + self.embed_tokens = self.base_model.embed_tokens + + def forward( + self, + input_ids: torch.tensor, + attention_mask: torch.tensor = None, + past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + ) -> Tuple[np.ndarray, any]: + """ + Forward through layers using the base model + + Args: + input_ids: tensor input + attention_mask: attention mask from tokenizer + past_kvs: past key value stores for cache + + Returns: + hidden_states: numpy of states between layers + or logits: numpy of normalization and linearization of last hidden state + past_kvs: DynamicCache of past key values if use_cache is true + + Ref: + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 + """ + if DEBUG >= 4: + print("forward called") + print(f"input_ids: {input_ids}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + if self.shard.is_first_layer(): + if DEBUG >= 2: + print("first layer, embed") + print(f"input_ids: {input_ids}") + input_ids = self.embed_tokens(input_ids) + + if DEBUG >= 2: + print(f"embeded input_ids: {input_ids}") + + if attention_mask == None: + # get attention mask + past_kv_length = len(past_kvs) + batch_size, seq_length = input_ids.shape[:2] + attention_mask = _prepare_4d_causal_attention_mask( + None, (batch_size, seq_length), input_ids, past_kv_length + ) + + past_kvs = DynamicCache.from_legacy_cache(past_kvs) + past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + input_ids.shape[1], + device=self.device + ) + + position_ids = cache_position.unsqueeze(0).to(self.device) + + try: + position_embeddings = self.base_model.rotary_emb( + input_ids, + position_ids + ) + except Exception as err: + print(f"rotary_emb not found in base_model") + position_embeddings = None + + causal_mask = self.base_model._update_causal_mask( + attention_mask, + input_ids, + cache_position, + past_kvs, + self.base_model.config.output_attentions + ) + + # progress through layers + for i in range(self.shard.start_layer, self.shard.end_layer + 1): + decoder_layer = self.base_model.layers[i] + + if DEBUG >= 4: + print("Going through layer") + print(f"{decoder_layer}") + print("input_ids") + print(f"{input_ids}") + print("causal_mask") + print(f"{causal_mask}") + + try: + layer_outputs = decoder_layer( + input_ids, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_value=past_kvs, + use_cache=True, + cache_position=cache_position, + output_logits=True + ) + except Exception as err: + print(f"Going through layer failed: {err}") + print(err.__traceback__.tb_lineno) + raise + + hidden_states = layer_outputs[0] + next_kvs = layer_outputs[1] + + if DEBUG >= 3: + print(f"layer_outputs {layer_outputs}") + print(layer_outputs[1:]) + + if self.shard.is_last_layer(): + hs_norm = self.norm(hidden_states).to(self.device) + # hs_lm_head = self.base_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings + with torch.no_grad(): + logits = self.lm_head( + hs_norm[:, -1:, :] + ).to(self.device).float() + + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + # print(f"hs_lm_head: {hs_lm_head}") + print(f"logits: {logits}") + print(f"logits.shape: {logits.shape}") + + # output_token = sample_logits( + # logits, + # TEMP, + # TOP_P, + # TOP_K + # ).unsqueeze(0).unsqueeze(0).long() + + output_token = torch.distributions.Categorical( + logits=logits + ).sample(sample_shape=(1,)) + + if DEBUG >= 2: + print(f"output_token: {output_token}") + + return (output_token.numpy(force=True), next_kvs) + + with torch.no_grad(): + out_hidden_states = hidden_states.float().numpy(force=True) + + return ( + out_hidden_states, + next_kvs + ) \ No newline at end of file diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index ed9e6ae17..1481c40cc 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,155 +1,293 @@ import torch +import torch.nn as nn import numpy as np -from transformers import AutoModelForCausalLM, DynamicCache, Cache +from typing import Tuple, Optional, Union, List + from exo.inference.shard import Shard from exo.helpers import DEBUG -from typing import Tuple, Optional, Union, List -from exo.inference.pytorch.model.utils import sample_logits +from exo.inference.inference_engine import InferenceEngine +from exo.download.shard_download import ShardDownloader + +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + DynamicCache, + Cache, + LogitsProcessorList, + #MinLengthLogitsProcessor, + LogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper, + StoppingCriteriaList, + MaxLengthCriteria, + MaxTimeCriteria +) -TOP_P = 0.9 #0.95 -TOP_K = 25 -TEMP = 0.85 +from transformers.generation.configuration_utils import ( + GenerationConfig, + GenerationMode +) -class ShardedHuggingFaceModel(torch.nn.Module): - def __init__(self, shard: Shard, tokenizer: any): - super(ShardedHuggingFaceModel, self).__init__() +# llama +from transformers.models.llama.modeling_llama import LlamaModel - if torch.cuda.is_available(): - self.device = torch.device("cuda") - else: - self.device = torch.device("cpu") +# qwen2 +from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + +class ShardedHuggingFaceModel: + def __init__( + self, + shard: Shard, + device, + dtype, + top_k: int = 25, + temp: float = 0.7, + top_p: float = 0.9, + max_length: int = 50, + max_time: float = 10.0 + ): + # class vars self.shard = shard - self.tokenizer = tokenizer + self.hidden_states = None + self.input_ids = None + self.inputs_embeds = None + self.attention_mask = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.causal_mask = None + + # setup logit processors + self.logits_processor = LogitsProcessorList([ + TopKLogitsWarper(top_k), + TemperatureLogitsWarper(temp), + TopPLogitsWarper(top_p) + ]) - # Load the model + # setup stopping critera for generation + self.stopping_critera = StoppingCriteriaList( + [ + #MaxLengthCriteria(max_length=max_length), + MaxTimeCriteria(max_time=max_time), + ] + ) + + self.device = device + self.torch_dtype = dtype + + # setup pytorch and transformer llm try: self.llm_model = AutoModelForCausalLM.from_pretrained( shard.model_id, - torch_dtype=torch.float32, + torch_dtype=self.torch_dtype, device_map="auto", - # offload_buffers=True + offload_buffers=True ) - # disk_offload(model=self.llm_model, offload_dir="./.offload") - - self.base_model = self.llm_model.model + self.model = self.llm_model.model except Exception as err: - print(f"Error loading model: {err}") + print(f"error loading and splitting model: {err}") raise - if DEBUG >= 2: - print(f"\nShardedHuggingFaceModel init with shard {shard}") - print(f"self.llm_model: {self.llm_model}") - print(f"self.base_model: {self.base_model}") - if DEBUG >= 2: - print(f"full_model.model layer: {len(self.base_model.layers)}") - - # Embeddings and final layer norm - # used for doing what forward LlamaModel does in transformers - self.norm = self.base_model.norm - self.lm_head = self.llm_model.lm_head - self.embed_tokens = self.base_model.embed_tokens - def forward( self, - input_ids: torch.tensor, - past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - ) -> Tuple[np.ndarray, any]: + shard: Optional[Shard] = None, + input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, + attention_mask: Optional[torch.tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_legacy_cache: bool = False + ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: + """ - Forward through layers using the base model + Generate hidden states or logits via passing through set amount of layers of a model + To be passed only input_ids OR hidden_state and not both. This is for connecting the model + layer to generate a complete output Args: - input_ids: tensor input - past_kvs: past key value stores for cache - use_cache: use cache - + model: base llm model tramsformers class + llm_model: llm chat model class + input_ids: tensor optional + attention_mask: tensor optional + past_key_values: Cache or list[tensor] optional + use_legacy_cache: bool optional + infer_tensor: bool optional, lets forward know to handle tensors + Returns: - hidden_states: numpy of states between layers - or logits: numpy of normalization and linearization of last hidden state - past_kvs: DynamicCache of past key values if use_cache is true + Tuple of + - hidden_states: tensor optional + - past_key_values: Cache or list[tensor] optional + - logits: tensor Optional - Ref: - https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 - https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 """ - if DEBUG >= 4: - print("forward called") - print(f"input_ids: {input_ids}\n") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") + + model_inputs = None + self.hidden_states = None + + if hidden_states is not None: + self.hidden_states = hidden_states + else: + self.input_ids = input_ids - past_kvs = DynamicCache.from_legacy_cache(past_kvs) - past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + input_ids.shape[1], - device=input_ids.device - ).to(self.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) - position_ids = cache_position.unsqueeze(0).to(self.device) + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + False # dont out attentions + ) - try: - position_embeddings = self.base_model.rotary_emb( - input_ids, - position_ids + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position ) - except Exception as err: - print(f"rotary_emb not found in base_model") - position_embeddings = None - # progress through layers - for i in range(self.shard.start_layer, self.shard.end_layer + 1): - decoder_layer = self.base_model.layers[i] + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] if DEBUG >= 4: - print("Going through layer") - print(f"{decoder_layer}") - print("input_ids") - print(f"{input_ids}") + print(f"model_inputs: {model_inputs}") + + # run through decoder layers + layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) + + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"layer_amt: {layer_amt}") + + for i in layer_amt: + decoder_layer = self.model.layers[i] + if DEBUG >= 5: + print("decoder_layer before") + print(f"decoder_layer: {decoder_layer}") + print(f"hidden_states: {self.hidden_states}") + # TODO: fix caching as decoder layer is not returning + # present_key_value from attention layer on models + # might have some other generation functions needed to do it + # see https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2917 + # for qwen2 exhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L291 layer_outputs = decoder_layer( - input_ids, - position_ids=position_ids if not position_embeddings else None, - position_embeddings=position_embeddings, - past_key_value=past_kvs, + self.hidden_states, + attention_mask=self.causal_mask, + position_ids=self.position_ids, + past_key_values=self.past_key_values, use_cache=True, - cache_position=cache_position, + cache_position=self.cache_position ) - hidden_states = layer_outputs[0] - next_kvs = layer_outputs[1] + self.hidden_states = layer_outputs[0] + self.next_decoder_cache = layer_outputs[1] - if DEBUG >= 3: - print(f"layer_outputs {layer_outputs}") - + if DEBUG >= 5: + print("decoder_layer after") + print(f"layer_outputs: {layer_outputs}\n") + print(f"self.next_decoder_cache: {self.next_decoder_cache}") + print(f"hidden_states: {self.hidden_states}") + print(f"next_decoder_cache: {self.next_decoder_cache}") + + + # handle last layer to get logits + # shard is last layer says true at the start and not detecting last layer correctly if self.shard.is_last_layer(): - hs_norm = self.norm(hidden_states) - hs_lm_head = self.llm_model.lm_head(hs_norm).float() - - # Use the sampling function with default settings - with torch.no_grad(): - output_token = sample_logits( - hs_lm_head[:, -1, :], - TEMP, - TOP_P, - TOP_K - ).numpy(force=True).flatten() - - if DEBUG >= 2: - print(f"hs_norm: {hs_norm}") - print(f"hs_lm_head: {hs_lm_head}") - print(f"output_token: {output_token}") - - return (output_token, next_kvs) - - with torch.no_grad(): - out_hidden_states = hidden_states.numpy(force=True) + self.hidden_states = self.model.norm(self.hidden_states) + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = self.llm_model.lm_head(self.hidden_states).to(self.device) + + if DEBUG >= 4: + print(f"logits: {logits}") + + return ( + None, + None, + logits + ) + + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"past_key_values: {self.past_key_values}") return ( - out_hidden_states, - next_kvs - ) \ No newline at end of file + self.hidden_states, + self.past_key_values, + None + ) + + def logits_sample( + self, + logits: torch.tensor, + use_max: Optional[bool] = False + ) -> torch.tensor: + """ + Get a sample of the logits from end of model run for next token + + Args: + logits: tensor + use_max: bool, if function should sample with argmax + + Returns: + next_token: tensor + """ + + # get a single cloned logit + logits = logits[:, -1, :].clone().float() + + next_token_scores = self.logits_processor(self.input_ids, logits) + + if not use_max: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_scores, dim=-1) + + if DEBUG >= 4: + print(f"input_ids: {self.input_ids}") + print(f"next_token: {next_token}") + + return next_token[:, None].squeeze(-1) + + diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py deleted file mode 100644 index df84b3977..000000000 --- a/exo/inference/pytorch/model/utils.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -from torch.nn import functional as F - -def top_p_sampling(scaled_logits: torch.Tensor, top_p: float) -> torch.Tensor: - """ - Apply top-p (nucleus) sampling to logits. - - Args: - scaled_logits (torch.Tensor): The scaled logits from the model's output. - top_p (float): The cumulative probability threshold for top-p filtering. - temp (float): Temperature parameter for softmax distribution reshaping. - - Returns: - torch.Tensor: Token selected based on the top-p criterion. - - Ref: - https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L67C1-L97C17 - """ - scaled_logits = torch.where(torch.isnan(scaled_logits), torch.zeros_like(scaled_logits), scaled_logits) - scaled_logits = torch.where(torch.isinf(scaled_logits), torch.full_like(scaled_logits, 1e6), scaled_logits) - - probs = torch.softmax(scaled_logits, dim=-1) - - sorted_probs, sorted_indices = torch.sort( - probs, - descending=True, - dim=-1 - ) - - cumulative_probs = torch.cumsum(sorted_probs, dim=-1) - mask = cumulative_probs > top_p - - top_probs = torch.where(mask, torch.zeros_like(sorted_probs), sorted_probs) - sum_probs = top_probs.sum(dim=-1, keepdim=True) - top_probs = torch.where(sum_probs > 0, top_probs / sum_probs, torch.ones_like(top_probs) / top_probs.size(-1)) - - if torch.isnan(top_probs).any() or torch.isinf(top_probs).any(): - print("Warning: Top probabilities contain NaN or Inf values after normalization") - top_probs = torch.where(torch.isnan(top_probs) | torch.isinf(top_probs), - 1.0 / top_probs.size(-1), - top_probs) - - sorted_token = torch.multinomial(top_probs, num_samples=1) - - token = sorted_indices.gather(-1, sorted_token) - - return token.squeeze(-1) - -def sample_logits(logits, temp, top_p, top_k): - """ - Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. - - Args: - logits (torch.Tensor): The logits distribution to sample from. - temp (float): temp for scaling logits. - top_p (float): The cumulative probability threshold for nucleus sampling. - - Returns: - torch.Tensor: The selected token index. - """ - - # Ensure logits are float - logits = logits.float() - - # If temp is very low, just use argmax - if temp == 0: - return logits.argmax(dim=-1) - - scaled_logits = logits/temp - - # top k - if top_k > 0: - top_values, top_indices = torch.topk(scaled_logits, top_k, dim=-1) - scaled_logits = torch.zeros_like(logits).scatter_(-1, top_indices, top_values) - - # Top-p sampling - if 0 < top_p < 1.0: - return top_p_sampling(scaled_logits, top_p) - else: - # random distribution selection - probs = torch.softmax(scaled_logits, dim=-1) - rand_sample = torch.distributions.Categorical(probs) - return rand_sample.sample().squeeze() \ No newline at end of file diff --git a/exo/inference/pytorch/tests/__init__.py b/exo/inference/pytorch/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/tests/test_inference_engine.py similarity index 61% rename from exo/inference/pytorch/test_inference_engine.py rename to exo/inference/pytorch/tests/test_inference_engine.py index 4bad37c26..7e64c137a 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/tests/test_inference_engine.py @@ -8,8 +8,14 @@ from exo.helpers import DEBUG import os import numpy as np +import time + +async def test_inference_engine( + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): -async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): # prompt = "Why is the sky blue?" prompt = "In a single word only, what is the last name of the current president of the USA?" @@ -26,7 +32,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e prompt=prompt ) - print(f"resp_full: {resp_full}") + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( "A", @@ -35,10 +45,14 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_full, ) - print(f"next_resp_full: {next_resp_full}") + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") + + time.sleep(5) pp = int(n_layers/2) - + resp_shard = Shard( model_id=model_id, start_layer=0, @@ -59,6 +73,13 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e prompt=prompt ) + print("\n------------resp1---------------\n") + print(resp1) + print("\n------------resp1---------------\n") + + time.sleep(5) + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( "B", shard=resp_shard2, @@ -66,6 +87,10 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_1, ) + print("\n------------resp2---------------\n") + print(resp2) + print("\n------------resp2---------------\n") + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( "B", shard=resp_shard, @@ -73,6 +98,10 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_2, ) + print("\n------------resp3---------------\n") + print(resp3) + print("\n------------resp3---------------\n") + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( "B", shard=resp_shard2, @@ -80,42 +109,46 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_3, ) + print("\n------------resp4---------------\n") + print(resp4) + print("\n------------resp4---------------\n") + assert np.array_equal(resp_full, resp2) assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': + try: + print(f"\n\n -------- TEST QWEN2 -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + # try: - # print(f"\n\n -------- TEST QWEN2 -------- \n\n") + # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") # asyncio.run(test_inference_engine( # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2-0.5B-Instruct", - # 24 + # "andrijdavid/Llama3-1B-Base", + # 3 # )) # except Exception as err: - # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") # try: - # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") # asyncio.run(test_inference_engine( # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "andrijdavid/Llama3-1B-Base", - # 3 + # "meta-llama/Meta-Llama-3.1-8B", + # 32 # )) # except Exception as err: - # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - - try: - print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "meta-llama/Meta-Llama-3.1-8B", - 32 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") # try: # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") @@ -128,14 +161,14 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") - try: - print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "ambrosfitz/TinyLlama-1.1B-Chat-yawp", - 22 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") + #try: + # print(f"\n\n --------- TEST TinyLlama/TinyLlama_v1.1 -------\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "TinyLlama/TinyLlama_v1.1", + # 22 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! TinyLlama/TinyLlama_v1.1 TEST FAILED \n{err}\n") diff --git a/exo/inference/pytorch/tests/test_simple_model.py b/exo/inference/pytorch/tests/test_simple_model.py new file mode 100644 index 000000000..1b08a1801 --- /dev/null +++ b/exo/inference/pytorch/tests/test_simple_model.py @@ -0,0 +1,44 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +device = "cuda" # the device to load the model onto + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen2-0.5B-Instruct", + torch_dtype="auto", + device_map="auto" +) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + +prompt = "In a single word only, what is the last name of the current president of the USA?" + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +model_inputs = tokenizer([text], return_tensors="pt").to(device) + +print(f"model_inputs:\n{model_inputs}") + +print(f"generation_config:\n{model.generation_config}") + +generated_ids = model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + max_new_tokens=512, + do_sample=True, + #top_k=20, + #num_beams=5, + #early_stopping=True +) +generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) +] + +response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + +print(f"Prompt: {prompt}\n") +print(f"Response: {response}\n") diff --git a/exo/inference/pytorch/tests/test_split_model.py b/exo/inference/pytorch/tests/test_split_model.py new file mode 100644 index 000000000..827bdec2e --- /dev/null +++ b/exo/inference/pytorch/tests/test_split_model.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +import asyncio +import gc +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + DynamicCache, + Cache, + LogitsProcessorList, + #MinLengthLogitsProcessor, + LogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper, + StoppingCriteriaList, + MaxLengthCriteria, + MaxTimeCriteria +) + +from transformers.generation.configuration_utils import ( + GenerationConfig, + GenerationMode +) + +# llama +from transformers.models.llama.modeling_llama import LlamaModel + +# qwen2 +from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + +from exo.api.chatgpt_api import resolve_tokenizer +from typing import Tuple, Optional, Union, List +import re + +TEMP = 0.6 +TOP_K = 60 + +class OnionHuggingFaceLM(): + def __init__(self, layers, is_last=False): + self.layers = layers + self.is_last = is_last + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.input_embed = None + self.causal_mask = None + self.position_embeddings = None + self.attention_mask = None + self.input_ids = None + self.hidden_states = None + self.next_decoder_cache = None + + def forward( + self, + model, + llm_model, + input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, + attention_mask: Optional[torch.tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + **kwargs + ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: + + """ + Generate hidden states or logits via passing through set amount of layers of a model + To be passed only input_ids OR hidden_state and not both. This is for connecting the model + layer to generate a complete output + + Args: + model: base llm model tramsformers class + llm_model: llm chat model class + input_ids: tensor Optional + hidden_states: tensor Optional + + Returns: + Tuple of + - hidden_states: tensor Optional + - past_key_values + - logits: tensor Optional + + """ + output_attentions = False # outputting attention not needed + use_legacy_cache = False # some models still use legacy kv store + + if input_ids is not None and hidden_states is not None: + raise ValueError + + if hidden_states is not None: + self.hidden_states = hidden_states + + if input_ids is not None: + self.input_ids = input_ids + + # embed input_ids + self.inputs_embeds = model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + print("Using legacy cache") + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + + # causal mask + self.attention_mask = attention_mask + self.causal_mask = model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + output_attentions + ) + + #print(f"causal_mask.dim(): {self.causal_mask.dim()}") + + print(f"\ncausal_mask:{self.causal_mask}\n\n") + + # embed positions, some models require and some dont + if isinstance(model, LlamaModel): + self.position_embeddings = model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + model_inputs = llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) + + print(f"model_inputs\n{model_inputs}") + + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] + + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + self.hidden_states, + attention_mask=self.causal_mask, + position_ids=self.position_ids, + past_key_values=self.past_key_values, + use_cache=True, + cache_position=self.cache_position + + ) + + self.hidden_states = layer_outputs[0] + self.next_decoder_cache = layer_outputs[1] + + if self.is_last: + self.hidden_states = model.norm(self.hidden_states) + + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = llm_model.lm_head(self.hidden_states).to("cuda") + + return ( + None, + None, + logits + ) + + return ( + self.hidden_states, + self.past_key_values, + None + ) + +async def model_half_split_test(prompt: str, model_id: str, layers: int): + """ + Test for splitting in half + """ + + half_layers = int(layers / 2) + + # inference + tokenizer = AutoTokenizer.from_pretrained(model_id) + max_length = 512 #tokenizer.model_max_length + + # get llm model + llm_model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype="auto", + device_map="auto", + use_cache=True + ) + + # get base model + model = llm_model.model + + # add pad token if none, depending on model + if tokenizer.pad_token == None: + if re.match(r"Llama|llama", model_id): + tokenizer.add_special_tokens({"pad_token":""}) + model.resize_token_embeddings(len(tokenizer)) + + + # generate input_ids + messages = [{"role": "user", "content": prompt}] + txt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + inputs = tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to("cuda") + input_attention_mask = inputs.attention_mask.to("cuda") + batch_size, seq_length = input_ids.shape[:2] + + is_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + logit_runs = 1 + + raw_logits = None + + while not is_finished: + print(f"\n\nLOGIT RUN {logit_runs}\n\n") + + print(f"input_ids:\n{input_ids}\n") + print(input_ids.shape) + + print("\n first half of layers") + shard_layers = nn.ModuleList(model.layers[:half_layers])#.to("cuda") + #shard_layers = nn.ModuleList(model.layers) + sharded_model = OnionHuggingFaceLM(layers=shard_layers) + #sharded_model.is_last = True + + # generate first half + # add if first layer of model check + shard_hidden_states, shard_past_kvs, shard_logits = sharded_model.forward( + model=model, + llm_model=llm_model, + attention_mask=input_attention_mask, + input_ids=input_ids, + hidden_states=None + ) + + # second half + print(f"\n second half of layers") + sharded_model.layers = nn.ModuleList(model.layers[half_layers:]) + sharded_model.is_last = True + + shard_hidden_states, shard_past_kvs, shard_logits = sharded_model.forward( + model=model, + llm_model=llm_model, + hidden_states=shard_hidden_states, + past_key_values=shard_past_kvs + ) + + # this part of the generation and _sample functions for transformers GenerationMixin + # ref: https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/generation/utils.py#L1301 + + # clone logit sample + logits = shard_logits[:, -1, :].clone().float() + + raw_logits = logits + + # distribute + logits_processor = LogitsProcessorList([ + TopKLogitsWarper(35), + TemperatureLogitsWarper(0.6), + TopPLogitsWarper(0.8) + ]) + + stopping_critera = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=50), + MaxTimeCriteria(max_time=10.0), + ] + ) + + next_token_scores = logits_processor(input_ids, logits) + + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + #next_tokens = torch.argmax(next_token_scores, dim=-1) + + # get inputs ready incase not finished + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + unfinished_sequences = unfinished_sequences & ~stopping_critera(input_ids, None) + is_finished = unfinished_sequences.max() == 0 + + print(f"is_finished?:\n{is_finished}\n") + + logit_runs += 1 + + del logits + del shard_logits + + print(f"model.generation_config\n{llm_model.generation_config}") + + generated_text = tokenizer.batch_decode( + input_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + )[0] + + print(f"generated_text:\n{generated_text}\n") + + # free model from memory + del model + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + #prompt = "In a single word only, what is the last name of the current president of the USA?" + prompt = "What color is the sky? Explain why" + #prompt = "In a single word only, what is the color of an apple?" + + #print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") + #model_id = "TinyLlama/TinyLlama_v1.1" + #model_layers = 22 + + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) + + #print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") + #model_id = "meta-llama/Meta-Llama-3.1-8B" + #model_layers = 32 + + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) + + print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") + model_id = "Qwen/Qwen2-0.5B-Instruct" + model_layers = 24 + + asyncio.run( + model_half_split_test( + prompt=prompt, + model_id=model_id, + layers=model_layers + ) + ) + diff --git a/exo/inference/pytorch/tests/utils.py b/exo/inference/pytorch/tests/utils.py new file mode 100644 index 000000000..e4062da96 --- /dev/null +++ b/exo/inference/pytorch/tests/utils.py @@ -0,0 +1,185 @@ +import torch +from torch.nn import functional as F + +def top_k_sampling(logits, thres): + num_logits = logits.shape[-1] + val, ind = torch.topk(logits, thres, dim=-1, largest=True, sorted=True) + mask = torch.zeros_like(logits) + mask.scatter_(-1, ind, 1) + logits = logits * mask + + return logits + +def top_p_sampling(logits, thres): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + print(f"top_p_sampling sorted_logits\n{sorted_logits}\nsorted_indices {sorted_indices}") + softmax_logits = F.softmax(sorted_logits, dim=-1) + print(f"top_p_sampling\nsoftmax_logits {softmax_logits}") + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + print(f"top_p_sampling\n{cumulative_probs}") + + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > thres + + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove) + print(f"top_p_sampling\nindicies_to_remove: {indices_to_remove}") + logits[indices_to_remove] = float('-inf') + return logits + +def sample_logits(logits, temp, top_p, top_k): + """ + Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. + + Args: + logits (torch.Tensor): The logits distribution to sample from. + temp (float): temp for scaling logits. + top_p (float): The cumulative probability threshold for nucleus sampling. + + Returns: + torch.Tensor: The selected token index. + """ + # If temp is very low, just use argmax + if temp == 0: + return logits.argmax(dim=-1) + + print(f"logits {logits}") + + scaled_logits = logits/temp + + print(f"scaled_logits: {scaled_logits}") + + if 0 < top_p < 1.0: + top_p_logits = top_p_sampling(scaled_logits, top_p) + print(f"top_p logits {top_p_logits}") + if top_k > 0: + top_k_logits = top_k_sampling(top_p_logits, top_k) + return top_k_logits.argmax(dim=-1) + elif top_k > 0: + top_k_logits = top_k_sampling(logits, top_k) + print(f"top_k logits {top_k_logits}") + return top_k_logits.argmax(dim=-1) + + return scaled_logits.argmax(dim=-1) + + +# from tinygrad llama model sample +def sample(logits: torch.Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 1, "only works on 1D tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.numel(), "k must be between 0 and numel" + + # If temperature is very low, just use argmax + if temp < 1e-6: + return logits.argmax().reshape(1) + + # Alpha sampling + if af or ap: + if not hasattr(sample, "alpha_counter"): + sample.alpha_counter = torch.zeros_like(logits, dtype=torch.int32).contiguous() + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0).float() * ap) + + # Replace NaNs with -inf + logits = torch.where(logits != logits, torch.tensor(-float("inf"), device=logits.device), logits) + + # Apply softmax after temperature scaling + t = F.softmax(logits / temp, dim=-1) + + counter = torch.arange(t.numel(), device=logits.device).contiguous() + counter2 = torch.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous() + + # Top-k sampling + if k: + output = torch.zeros(k, device=logits.device).contiguous() + output_indices = torch.zeros(k, device=logits.device, dtype=torch.int32).contiguous() + + for i in range(k): + t_max = t.max() + t_argmax = (t.numel() - ((t == t_max) * counter2).max() - 1).to(torch.int) + output[i] = t_max + output_indices[i] = t_argmax + t = torch.where(counter == t_argmax, torch.tensor(0.0, device=logits.device), t) + + # Approximate top-p sampling + output_cumsum = output.flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,)) + t.sum() + mask = output_cumsum >= (1 - p) + output = output * mask.float() + output_indices = output_indices * mask.int() + + # Sample from the distribution + output_idx = output.multinomial(num_samples=1) + output_token = output_indices[output_idx] + else: + output_token = t.multinomial(num_samples=1) + + # Increase alpha counter + if af or ap: + sample.alpha_counter = torch.where(counter == output_token, sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + + +def sample_3d(logits: torch.Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 3, "only works on 3D tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.shape[-1], "k must be between 0 and the last dimension size" + + batch_size, seq_len, vocab_size = logits.shape + + # If temperature is very low, just use argmax + if temp < 1e-6: + return logits.argmax(dim=-1) + + # Alpha sampling + if af or ap: + if not hasattr(sample, "alpha_counter"): + sample.alpha_counter = torch.zeros_like(logits, dtype=torch.int32).contiguous() + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0).float() * ap) + + # Replace NaNs with -inf + logits = torch.where(logits != logits, torch.tensor(-float("inf"), device=logits.device), logits) + + # Apply softmax after temperature scaling + t = F.softmax(logits / temp, dim=-1) + + counter = torch.arange(vocab_size, device=logits.device).unsqueeze(0).unsqueeze(0).expand_as(t).contiguous() + counter2 = torch.arange(vocab_size - 1, -1, -1, device=logits.device).unsqueeze(0).unsqueeze(0).expand_as(t).contiguous() + + # Top-k sampling + if k: + output = torch.zeros((batch_size, seq_len, k), device=logits.device).contiguous() + output_indices = torch.zeros((batch_size, seq_len, k), device=logits.device, dtype=torch.int32).contiguous() + + for i in range(k): + t_max, _ = t.max(dim=-1, keepdim=True) + t_argmax = (vocab_size - ((t == t_max) * counter2).max(dim=-1, keepdim=True)[0] - 1).to(torch.int) + output[:, :, i] = t_max.squeeze(-1) + output_indices[:, :, i] = t_argmax.squeeze(-1) + t = torch.where(counter == t_argmax, torch.tensor(0.0, device=logits.device), t) + + # Approximate top-p sampling + output_cumsum = output.flip(dims=(-1,)).cumsum(dim=-1).flip(dims=(-1,)) + t.sum(dim=-1, keepdim=True) + mask = output_cumsum >= (1 - p) + output = output * mask.float() + output_indices = output_indices * mask.int() + + # Sample from the distribution + output_flat = output.view(batch_size * seq_len, -1) + output_idx = output_flat.multinomial(num_samples=1).squeeze(-1) + output_indices_flat = output_indices.view(batch_size * seq_len, -1) + output_token = output_indices_flat.gather(dim=-1, index=output_idx.unsqueeze(-1)).view(batch_size, seq_len) + else: + output_flat = t.view(batch_size * seq_len, -1) + output_token = output_flat.multinomial(num_samples=1).view(batch_size, seq_len) + + # Increase alpha counter + if af or ap: + sample.alpha_counter = torch.where(counter == output_token.unsqueeze(-1), sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index 8f7503500..55c257397 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -12,6 +12,10 @@ import numpy as np from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load from exo.download.shard_download import ShardDownloader +from concurrent.futures import ThreadPoolExecutor +import asyncio +import threading +from functools import partial Tensor.no_grad = True # default settings @@ -52,14 +56,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine): def __init__(self, shard_downloader: ShardDownloader): self.shard = None self.shard_downloader = shard_downloader + self.executor = ThreadPoolExecutor(max_workers=1) async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): await self.ensure_shard(shard) start_pos = json.loads(inference_state or "{}").get("start_pos", 0) n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0) - toks = self.tokenizer.encode(prompt) - h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize() + toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt) + h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()) if h.shape == (1,): start_pos += len(toks) @@ -75,7 +80,7 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr start_pos = json.loads(inference_state or "{}").get("start_pos", 0) n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0) - h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize() + h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()) if h.shape == (1,): start_pos += n_captured_toks @@ -90,6 +95,10 @@ async def ensure_shard(self, shard: Shard): return model_path = await self.shard_downloader.ensure_shard(shard) - self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B") - self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent))) - self.shard = shard + + if self.shard != shard: + self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B") + + tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent)) + self.tokenizer = await resolve_tokenizer(tokenizer_path) + self.shard = shard diff --git a/exo/inference/tokenizers.py b/exo/inference/tokenizers.py index 9accd9436..bf4b5a6c2 100644 --- a/exo/inference/tokenizers.py +++ b/exo/inference/tokenizers.py @@ -1,5 +1,8 @@ import traceback from aiofiles import os as aios +from os import PathLike +from pathlib import Path +from typing import Union from transformers import AutoTokenizer, AutoProcessor from exo.download.hf.hf_helpers import get_local_snapshot_dir from exo.helpers import DEBUG @@ -8,7 +11,7 @@ async def resolve_tokenizer(model_id: str): local_path = await get_local_snapshot_dir(model_id) if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}") try: - if await aios.path.exists(local_path): + if local_path and await aios.path.exists(local_path): if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}") return await _resolve_tokenizer(local_path) except: @@ -16,14 +19,10 @@ async def resolve_tokenizer(model_id: str): if DEBUG >= 5: traceback.print_exc() return await _resolve_tokenizer(model_id) -async def _resolve_tokenizer(model_id_or_local_path: str): +async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]): try: if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}") - if "Mistral-Large" in str(model_id_or_local_path): - use_fast = True - else: - use_fast = False - processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=use_fast) + processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False) if not hasattr(processor, 'eos_token_id'): processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id if not hasattr(processor, 'encode'): diff --git a/exo/models.py b/exo/models.py index 137b881ce..f83bc1423 100644 --- a/exo/models.py +++ b/exo/models.py @@ -9,7 +9,11 @@ }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), - "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3.1-70b-bf16": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), }, "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),}, "llama-3-8b": { @@ -41,4 +45,20 @@ "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), }, + ### qwen + "qwen-2.5-7b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + }, + "qwen-2.5-math-7b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + }, + "qwen-2.5-14b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48), + }, + "qwen-2.5-72b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + }, + "qwen-2.5-math-72b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + }, } diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 0629dc777..1c7e17b07 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -11,6 +11,7 @@ from exo.topology.topology import Topology from exo.topology.device_capabilities import DeviceCapabilities +from exo.helpers import DEBUG class GRPCPeerHandle(PeerHandle): def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities): @@ -23,12 +24,16 @@ def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabiliti def id(self) -> str: return self._id + def addr(self) -> str: + return self.address + def device_capabilities(self) -> DeviceCapabilities: return self._device_capabilities async def connect(self): self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)]) self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel) + await self.channel.channel_ready() async def is_connected(self) -> bool: return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY @@ -52,6 +57,8 @@ async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] request_id=request_id, inference_state=inference_state, ) + + print(f"request: {request}") response = await self.stub.SendPrompt(request) if not response.tensor_data or not response.shape or not response.dtype: diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index 1481ef512..c7e591411 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -113,6 +113,6 @@ async def SendResult(self, request, context): async def SendOpaqueStatus(self, request, context): request_id = request.request_id status = request.status - if DEBUG >= 5: print(f"Received SendOpaqueStatus request: {request_id=} {status=}") + if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}") self.node.on_opaque_status.trigger_all(request_id, status) return node_service_pb2.Empty() diff --git a/exo/networking/grpc/test_grpc_discovery.py b/exo/networking/grpc/test_grpc_discovery.py deleted file mode 100644 index 13372bbb4..000000000 --- a/exo/networking/grpc/test_grpc_discovery.py +++ /dev/null @@ -1,22 +0,0 @@ -import asyncio -import unittest -from .grpc_discovery import GRPCDiscovery - - -class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679) - self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678) - await self.node1.start() - await self.node2.start() - - async def asyncTearDown(self): - await self.node1.stop() - await self.node2.stop() - - async def test_discovery(self): - await asyncio.sleep(4) - - # Check discovered peers - print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()])) - print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()])) diff --git a/exo/networking/peer_handle.py b/exo/networking/peer_handle.py index cf232d006..390966bd5 100644 --- a/exo/networking/peer_handle.py +++ b/exo/networking/peer_handle.py @@ -5,12 +5,15 @@ from exo.topology.device_capabilities import DeviceCapabilities from exo.topology.topology import Topology - class PeerHandle(ABC): @abstractmethod def id(self) -> str: pass + @abstractmethod + def addr(self) -> str: + pass + @abstractmethod def device_capabilities(self) -> DeviceCapabilities: pass @@ -36,13 +39,13 @@ async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional pass @abstractmethod - async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: pass @abstractmethod - async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: pass @abstractmethod - async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: pass diff --git a/exo/networking/test_udp_discovery.py b/exo/networking/test_udp_discovery.py new file mode 100644 index 000000000..fb5094a05 --- /dev/null +++ b/exo/networking/test_udp_discovery.py @@ -0,0 +1,76 @@ +import asyncio +import unittest +from unittest import mock +from exo.networking.udp_discovery import UDPDiscovery +from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle +from exo.networking.grpc.grpc_server import GRPCServer +from exo.orchestration.node import Node + +class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.peer1 = mock.AsyncMock() + self.peer2 = mock.AsyncMock() + self.peer1.connect = mock.AsyncMock() + self.peer2.connect = mock.AsyncMock() + self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1) + self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2) + await self.discovery1.start() + await self.discovery2.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.discovery2.stop() + + async def test_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=1) + assert len(peers1) == 1 + peers2 = await self.discovery2.discover_peers(wait_for_peers=1) + assert len(peers2) == 1 + + # connect has to be explicitly called after discovery + self.peer1.connect.assert_not_called() + self.peer2.connect.assert_not_called() + + +class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.node1 = mock.AsyncMock(spec=Node) + self.node2 = mock.AsyncMock(spec=Node) + self.server1 = GRPCServer(self.node1, "localhost", 50053) + self.server2 = GRPCServer(self.node2, "localhost", 50054) + await self.server1.start() + await self.server2.start() + self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities)) + self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities)) + await self.discovery1.start() + await self.discovery2.start() + + async def asyncTearDown(self): + await self.discovery1.stop() + await self.discovery2.stop() + await self.server1.stop() + await self.server2.stop() + + async def test_grpc_discovery(self): + peers1 = await self.discovery1.discover_peers(wait_for_peers=1) + assert len(peers1) == 1 + peers2 = await self.discovery2.discover_peers(wait_for_peers=1) + assert len(peers2) == 1 + assert not await peers1[0].is_connected() + assert not await peers2[0].is_connected() + + # Connect + await peers1[0].connect() + await peers2[0].connect() + assert await peers1[0].is_connected() + assert await peers2[0].is_connected() + + # Kill server1 + await self.server1.stop() + + assert await peers1[0].is_connected() + assert not await peers2[0].is_connected() + + +if __name__ == "__main__": + asyncio.run(unittest.main()) diff --git a/exo/networking/grpc/grpc_discovery.py b/exo/networking/udp_discovery.py similarity index 52% rename from exo/networking/grpc/grpc_discovery.py rename to exo/networking/udp_discovery.py index eb08a8385..e92b6c1b5 100644 --- a/exo/networking/grpc/grpc_discovery.py +++ b/exo/networking/udp_discovery.py @@ -2,13 +2,12 @@ import json import socket import time +import traceback from typing import List, Dict, Callable, Tuple, Coroutine -from ..discovery import Discovery -from ..peer_handle import PeerHandle -from .grpc_peer_handle import GRPCPeerHandle +from .discovery import Discovery +from .peer_handle import PeerHandle from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES -from exo import DEBUG_DISCOVERY - +from exo.helpers import DEBUG, DEBUG_DISCOVERY class ListenProtocol(asyncio.DatagramProtocol): def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]): @@ -23,28 +22,30 @@ def datagram_received(self, data, addr): asyncio.create_task(self.on_message(data, addr)) -class GRPCDiscovery(Discovery): +class UDPDiscovery(Discovery): def __init__( self, node_id: str, node_port: int, listen_port: int, - broadcast_port: int = None, + broadcast_port: int, + create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle], broadcast_interval: int = 1, - device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, discovery_timeout: int = 30, + device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, ): self.node_id = node_id self.node_port = node_port - self.device_capabilities = device_capabilities self.listen_port = listen_port - self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port + self.broadcast_port = broadcast_port + self.create_peer_handle = create_peer_handle self.broadcast_interval = broadcast_interval - self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {} + self.discovery_timeout = discovery_timeout + self.device_capabilities = device_capabilities + self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {} self.broadcast_task = None self.listen_task = None self.cleanup_task = None - self.discovery_timeout = discovery_timeout async def start(self): self.device_capabilities = device_capabilities() @@ -53,68 +54,45 @@ async def start(self): self.cleanup_task = asyncio.create_task(self.task_cleanup_peers()) async def stop(self): - if self.broadcast_task: - self.broadcast_task.cancel() - if self.listen_task: - self.listen_task.cancel() - if self.cleanup_task: - self.cleanup_task.cancel() + if self.broadcast_task: self.broadcast_task.cancel() + if self.listen_task: self.listen_task.cancel() + if self.cleanup_task: self.cleanup_task.cancel() if self.broadcast_task or self.listen_task or self.cleanup_task: await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True) async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: - if DEBUG_DISCOVERY >= 2: - print("Starting peer discovery process...") - if wait_for_peers > 0: - while len(self.known_peers) == 0: - if DEBUG_DISCOVERY >= 2: - print("No peers discovered yet, retrying in 1 second...") - await asyncio.sleep(1) # Keep trying to find peers - if DEBUG_DISCOVERY >= 2: - print(f"Discovered first peer: {next(iter(self.known_peers.values()))}") - - grace_period = 5 # seconds - while True: - initial_peer_count = len(self.known_peers) - if DEBUG_DISCOVERY >= 2: - print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...") - if len(self.known_peers) == initial_peer_count: - if wait_for_peers > 0: - await asyncio.sleep(grace_period) - if DEBUG_DISCOVERY >= 2: - print(f"Waiting additional {wait_for_peers} seconds for more peers.") - wait_for_peers = 0 - else: - if DEBUG_DISCOVERY >= 2: - print("No new peers discovered in the last grace period. Ending discovery process.") - break # No new peers found in the grace period, we are done - + while len(self.known_peers) < wait_for_peers: + if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...") + await asyncio.sleep(0.1) return [peer_handle for peer_handle, _, _ in self.known_peers.values()] async def task_broadcast_presence(self): - transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET) - sock = transport.get_extra_info("socket") - sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - - message = json.dumps({ - "type": "discovery", - "node_id": self.node_id, - "grpc_port": self.node_port, - "device_capabilities": self.device_capabilities.to_dict(), - }).encode("utf-8") - while True: try: - if DEBUG_DISCOVERY >= 3: - print(f"Broadcast presence: {message}") + message = json.dumps({ + "type": "discovery", + "node_id": self.node_id, + "grpc_port": self.node_port, + "device_capabilities": self.device_capabilities.to_dict(), + }).encode("utf-8") + if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}") + + transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET) + sock = transport.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) transport.sendto(message, ("", self.broadcast_port)) - await asyncio.sleep(self.broadcast_interval) except Exception as e: print(f"Error in broadcast presence: {e}") - import traceback - print(traceback.format_exc()) + finally: + if transport: + try: + transport.close() + except: + if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}") + if DEBUG_DISCOVERY >= 2: traceback.print_exc() + await asyncio.sleep(self.broadcast_interval) async def on_listen_message(self, data, addr): if not data: @@ -124,40 +102,35 @@ async def on_listen_message(self, data, addr): # Check if the decoded data starts with a valid JSON character if not (decoded_data.strip() and decoded_data.strip()[0] in "{["): - if DEBUG_DISCOVERY >= 2: - print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}") + if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}") return try: decoder = json.JSONDecoder(strict=False) message = decoder.decode(decoded_data) except json.JSONDecodeError as e: - if DEBUG_DISCOVERY >= 2: - print(f"Error decoding JSON data from {addr}: {e}") + if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}") return - if DEBUG_DISCOVERY >= 2: - print(f"received from peer {addr}: {message}") + if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}") if message["type"] == "discovery" and message["node_id"] != self.node_id: peer_id = message["node_id"] peer_host = addr[0] peer_port = message["grpc_port"] device_capabilities = DeviceCapabilities(**message["device_capabilities"]) - if peer_id not in self.known_peers: + if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}": + if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}") self.known_peers[peer_id] = ( - GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), + self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time(), ) - if DEBUG_DISCOVERY >= 2: - print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}") self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time()) async def task_listen_for_peers(self): await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port)) - if DEBUG_DISCOVERY >= 2: - print("Started listen task") + if DEBUG_DISCOVERY >= 2: print("Started listen task") async def task_cleanup_peers(self): while True: @@ -167,22 +140,12 @@ async def task_cleanup_peers(self): peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout ] - if DEBUG_DISCOVERY >= 2: - print( - "Peer statuses:", - {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" - for peer_handle, connected_at, last_seen in self.known_peers.values()}, - ) - if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: - print(f"Cleaning up peers: {peers_to_remove}") + if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()}) for peer_id in peers_to_remove: - if peer_id in self.known_peers: - del self.known_peers[peer_id] - if DEBUG_DISCOVERY >= 2: - print(f"Removed peer {peer_id} due to inactivity.") - await asyncio.sleep(self.broadcast_interval) + if peer_id in self.known_peers: del self.known_peers[peer_id] + if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.") except Exception as e: print(f"Error in cleanup peers: {e}") - import traceback - print(traceback.format_exc()) + finally: + await asyncio.sleep(self.broadcast_interval) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 2a54dc9b3..15e1e16d6 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -50,7 +50,7 @@ async def start(self, wait_for_peers: int = 0) -> None: await self.update_peers(wait_for_peers) await self.collect_topology() if DEBUG >= 2: print(f"Collected topology: {self.topology}") - asyncio.create_task(self.periodic_topology_collection(5)) + asyncio.create_task(self.periodic_topology_collection(1.0)) async def stop(self) -> None: await self.discovery.stop() @@ -67,7 +67,7 @@ def on_node_status(self, request_id, opaque_status): self.current_topology.active_node_id = None download_progress = None if status_data.get("type", "") == "download_progress": - if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}") + if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}") download_progress = RepoProgressEvent.from_dict(status_data.get('progress')) self.node_download_progress[status_data.get('node_id')] = download_progress if self.topology_viz: @@ -277,23 +277,75 @@ def get_current_shard(self, base_shard: Shard) -> Shard: raise ValueError(f"No current partition found for node: {self.id}") return shards[current_partition_index] - async def update_peers(self, wait_for_peers: int = 0) -> None: - self.peers = await self.discovery.discover_peers(wait_for_peers) - for peer in self.peers: - is_connected = await peer.is_connected() - if DEBUG >= 2 and is_connected: - print(f"Already connected to {peer.id()}: {is_connected}") - if not is_connected: - if DEBUG >= 2: print(f"Connecting to {peer.id()}...") - await peer.connect() - if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})") + async def update_peers(self, wait_for_peers: int = 0) -> bool: + next_peers = await self.discovery.discover_peers(wait_for_peers) + current_peer_ids = {peer.id() for peer in self.peers} + next_peer_ids = {peer.id() for peer in next_peers} + peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids] + peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids] + peers_updated = [ + peer for peer in next_peers + if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id()) + ] + peers_unchanged = [ + peer for peer in next_peers + if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id()) + ] + peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()] + peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()] + + def _pretty(peers: List[PeerHandle]) -> List[str]: + return [f"{peer.id()}@{peer.addr()}" for peer in peers] + if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}") + + async def disconnect_with_timeout(peer, timeout=5): + try: + await asyncio.wait_for(peer.disconnect(), timeout) + return True + except Exception as e: + print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}") + traceback.print_exc() + return False + + async def connect_with_timeout(peer, timeout=5): + try: + await asyncio.wait_for(peer.connect(), timeout) + return True + except Exception as e: + print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}") + traceback.print_exc() + return False + + disconnect_results = await asyncio.gather( + *(disconnect_with_timeout(peer) for peer in peers_to_disconnect), + return_exceptions=True + ) + connect_results = await asyncio.gather( + *(connect_with_timeout(peer) for peer in peers_to_connect), + return_exceptions=True + ) + + successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True] + failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False] + successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True] + failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False] + if DEBUG >= 1: + if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}") + if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}") + if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}") + if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}") + + self.peers = next_peers + return len(peers_to_connect) > 0 or len(peers_to_disconnect) > 0 async def periodic_topology_collection(self, interval: int): while True: await asyncio.sleep(interval) try: - await self.update_peers() - await self.collect_topology() + did_peers_change = await self.update_peers() + if DEBUG >= 2: print(f"{did_peers_change=}") + if did_peers_change: + await self.collect_topology() except Exception as e: print(f"Error collecting topology: {e}") traceback.print_exc() @@ -310,7 +362,7 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}") prev_visited = visited.copy() - # TODO: should we add our own peer id here? + visited.add(self.id) visited.update(p.id() for p in self.peers) for peer in self.peers: @@ -325,7 +377,7 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) continue try: - other_topology = await peer.collect_topology(visited, max_depth=max_depth - 1) + other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0) if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") self.topology.merge(other_topology) except Exception as e: @@ -362,7 +414,7 @@ async def send_result_to_peer(peer): await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True) async def broadcast_opaque_status(self, request_id: str, status: str) -> None: - if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}") + if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}") async def send_status_to_peer(peer): try: diff --git a/main.py b/main.py index a39294ff8..c09952afb 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,8 @@ import uuid from exo.orchestration.standard_node import StandardNode from exo.networking.grpc.grpc_server import GRPCServer -from exo.networking.grpc.grpc_discovery import GRPCDiscovery +from exo.networking.udp_discovery import UDPDiscovery +from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy from exo.api import ChatGPTAPI from exo.download.shard_download import ShardDownloader, RepoProgressEvent @@ -33,7 +34,7 @@ parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds") parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting") parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port") -parser.add_argument("--chatgpt-api-response-timeout-secs", type=int, default=90, help="ChatGPT API response timeout in seconds") +parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds") parser.add_argument("--max-generate-tokens", type=int, default=1024, help="Max tokens to generate in each request") parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use") parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI") @@ -66,7 +67,7 @@ for chatgpt_api_endpoint in chatgpt_api_endpoints: print(f" - {terminal_link(chatgpt_api_endpoint)}") -discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout) +discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout) topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None node = StandardNode( @@ -83,7 +84,7 @@ api = ChatGPTAPI( node, inference_engine.__class__.__name__, - response_timeout_secs=args.chatgpt_api_response_timeout_secs, + response_timeout=args.chatgpt_api_response_timeout, on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None ) node.on_token.register("update_topology_viz").on_next( diff --git a/test/reconnect.sh b/test/reconnect.sh new file mode 100755 index 000000000..537f9ac31 --- /dev/null +++ b/test/reconnect.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +echo "Starting node 1" +DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 & +PID1=$! +echo "Started node 1 PID: $PID1" +echo "Starting node 2" +DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 > output2.log 2>&1 & +PID2=$! +echo "Started node 2 PID: $PID2" +sleep 5 +kill $PID2 +sleep 5 +echo "Starting node 2 again..." +DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 > output3.log 2>&1 & +PID2=$! +sleep 5 +echo "Killing nodes and ending test..." +kill $PID1 +kill $PID2 +echo "Test complete." \ No newline at end of file diff --git a/test/test_hf.py b/test/test_hf.py new file mode 100644 index 000000000..0477d132b --- /dev/null +++ b/test/test_hf.py @@ -0,0 +1,26 @@ +import os +import sys + +# Add the project root to the Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + +import asyncio +from exo.download.hf.hf_helpers import get_weight_map + +async def test_get_weight_map(): + repo_ids = [ + "mlx-community/quantized-gemma-2b", + "mlx-community/Meta-Llama-3.1-8B-4bit", + "mlx-community/Meta-Llama-3.1-70B-4bit", + "mlx-community/Meta-Llama-3.1-405B-4bit", + ] + for repo_id in repo_ids: + weight_map = await get_weight_map(repo_id) + assert weight_map is not None, "Weight map should not be None" + assert isinstance(weight_map, dict), "Weight map should be a dictionary" + assert len(weight_map) > 0, "Weight map should not be empty" + print(f"OK: {repo_id}") + +if __name__ == "__main__": + asyncio.run(test_get_weight_map()) diff --git a/test/test_tokenizers.py b/test/test_tokenizers.py index d60255eda..931561003 100644 --- a/test/test_tokenizers.py +++ b/test/test_tokenizers.py @@ -1,3 +1,5 @@ +import os +import re from transformers import AutoTokenizer, AutoProcessor from exo.models import model_base_shards @@ -22,10 +24,10 @@ def test_tokenizer(name, tokenizer, verbose=False): strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id])) assert text == strip_tokens(decoded) == strip_tokens(reconstructed) -ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "llava-hf/llava-1.5-7b-hf"] -models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if shard.model_id not in ignore] +ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*"] +ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")") +models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if not ignore_pattern.match(shard.model_id)] -import os verbose = os.environ.get("VERBOSE", "0").lower() == "1" for m in models: # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding invididual tokens) for Mistral-Large-Instruct-2407-4bit diff --git a/tinychat/examples/tinychat/index.html b/tinychat/examples/tinychat/index.html index 350cea178..d20d78d74 100644 --- a/tinychat/examples/tinychat/index.html +++ b/tinychat/examples/tinychat/index.html @@ -27,7 +27,22 @@
+ + + + + + + + + + + + + + + +