From 5ec28bd2259a26a676c6688399e11da2daa039a2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 4 Aug 2024 21:15:48 -0800 Subject: [PATCH 001/491] adding in pytorch inference with some gen and custom code, added in install for powershell in windows env, base pytorch version off of tinygrid --- exo/inference/pytorch/index.ex.json | 3 + exo/inference/pytorch/inference.py | 221 ++++++++++++++++++++++++++ exo/inference/pytorch/models/llama.py | 0 install.ps1 | 8 + setup.py | 1 + 5 files changed, 233 insertions(+) create mode 100644 exo/inference/pytorch/index.ex.json create mode 100644 exo/inference/pytorch/inference.py create mode 100644 exo/inference/pytorch/models/llama.py create mode 100644 install.ps1 diff --git a/exo/inference/pytorch/index.ex.json b/exo/inference/pytorch/index.ex.json new file mode 100644 index 00000000..19b16d91 --- /dev/null +++ b/exo/inference/pytorch/index.ex.json @@ -0,0 +1,3 @@ +{ + "model": "huggingface_model_name" +} \ No newline at end of file diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py new file mode 100644 index 00000000..d8ae7289 --- /dev/null +++ b/exo/inference/pytorch/inference.py @@ -0,0 +1,221 @@ +# experimental, based off of tinygrad/inference.py + +import asyncio +from functools import partial +from pathlib import Path +from typing import List, Optional, Union, Callable, Dict +import json +import torch +from torch import nn +from transformers import AutoTokenizer, AutoModelForCausalLM +import numpy as np +import os + +MODEL_PARAMS = { + "8B": { + "args": { + "dim": 4096, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-5, + "rope_theta": 500000, + "vocab_size": 128256, + "hidden_dim": 14336, + }, + "files": 1, + }, + "70B": { + "args": { + "dim": 8192, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-5, + "rope_theta": 500000, + "vocab_size": 128256, + "hidden_dim": 28672, + }, + "files": 8, + }, +} + + +# **** helper functions **** +def load(fn: str) -> Union[str, Dict[str, torch.Tensor]]: + model = "" + if fn.endswith(".index.json"): + with open(fn) as fp: + model = json.load(fp)["model"] + + if model == "": + model = torch.load(fn, map_location="cpu") + + return model + +def build_transformer(model_path: Union[str, Path], model_size="8B", quantize=None, device=None): + # Load the model configuration and parameters + model = load(model_path) + if isinstance(model, str): + with torch.device(device): + model = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, + device_map="auto" if "cuda" in str(device) else None + ) + + # Quantize the model if specified + if quantize: + model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8) + + # Shard the model if using multiple devices + if isinstance(device, tuple): + for name, param in model.named_parameters(): + if "scale" in name: + param.data = param.data.chunk(len(device), dim=0) + elif ".attention." in name: + param.data = param.data.chunk(len(device), dim=-1) + elif ".feed_forward.w1." in name or ".feed_forward.w3." in name: + param.data = param.data.chunk(len(device), dim=0) + elif ".feed_forward." in name: + param.data = param.data.chunk(len(device), dim=-1) + elif "tok_embeddings.weight" in name or "output.weight" in name: + param.data = param.data.chunk(len(device), dim=0) + + return model + +# Sample function using the built transformer +def sample(logits: torch.Tensor, temp: float, k: int, p: float): + if temp < 1e-6: + return torch.argmax(logits) + + logits[torch.isnan(logits)] = -float("inf") + probs = torch.nn.functional.softmax(logits / temp, dim=-1) + + if k: + top_probs, top_indices = torch.topk(probs, k) + top_probs = top_probs[top_probs.cumsum(dim=-1) <= p] + top_indices = top_indices[:len(top_probs)] + sampled_idx = torch.multinomial(top_probs, 1) + return top_indices[sampled_idx] + + return torch.multinomial(probs, 1) + + +# default settings +TEMPERATURE = 0 # 0.85 +TOP_K = 25 +TOP_P = 0.9 +ALPHA_F = 0.1 +ALPHA_P = 0.0 + + +def prefill(model, toks, start_pos=0): + # prefill the model + for tok in tqdm(toks): + GlobalCounters.reset() + inputs = torch.tensor([[tok]], device=model.device) + model.generate(inputs, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, max_new_tokens=1) + start_pos += 1 + return start_pos + + +class PytorchDynamicShardInferenceEngine(InferenceEngine): + def __init__(self): + self.shard = None + + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + + toks = self.tokenizer.encode(prompt) + start_pos = prefill(self.model, toks[:-1], start_pos=start_pos) + last_tok = toks[-1] + + input_ids = torch.tensor([[last_tok]], device=self.model.device) + output_data = self.model.generate(input_ids, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, max_new_tokens=1).cpu().numpy() + if output_data.size == 1: + start_pos += 1 + + return ( + output_data, + json.dumps({"start_pos": start_pos}), + output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + ) + + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + + input_ids = torch.tensor(input_data, device=self.model.device) + output_data = self.model.generate(input_ids, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, max_new_tokens=1).cpu().numpy() + if output_data.size == 1: + start_pos += 1 + + return ( + output_data, + json.dumps({"start_pos": start_pos}), + output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + ) + + async def ensure_shard(self, shard: Shard): + if self.shard == shard: + return + + model_path = Path(shard.model_id) + models_dir = Path(_cache_dir) / "pytorch" / "downloads" + model_path = models_dir / shard.model_id + size = "8B" + if Path(model_path / "tokenizer_config.json").exists(): + model = model_path + else: + if DEBUG >= 2: + print(f"Downloading pytorch model {shard.model_id}...") + if shard.model_id.lower().find("llama3-8b-sfr") != -1: + num_files = 4 + for i in range(num_files): + await fetch_async( + f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.bin", + f"model-{(i+1):05d}-of-{num_files:05d}.bin", + subdir=shard.model_id, + ) + await fetch_async( + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json", + "config.json", + subdir=shard.model_id, + ) + model = await fetch_async( + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.index.json", + "model.index.json", + subdir=shard.model_id, + ) + await fetch_async( + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json", + "special_tokens_map.json", + subdir=shard.model_id, + ) + await fetch_async( + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json", + "tokenizer.json", + subdir=shard.model_id, + ) + await fetch_async( + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json", + "tokenizer_config.json", + subdir=shard.model_id, + ) + size = "8B" + elif shard.model_id.lower().find("llama3-70b-sfr") != -1: + raise NotImplementedError("llama3-70b-sfr is not implemented for pytorch") + else: + raise ValueError(f"pytorch doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}") + + model = build_transformer(model_path, shard=shard, model_size=size) + tokenizer = AutoTokenizer.from_pretrained(model_path if model_path.is_dir() else model_path.parent) + + self.shard = shard + self.model = model + self.tokenizer = tokenizer + + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): + pass diff --git a/exo/inference/pytorch/models/llama.py b/exo/inference/pytorch/models/llama.py new file mode 100644 index 00000000..e69de29b diff --git a/install.ps1 b/install.ps1 new file mode 100644 index 00000000..c766cdd5 --- /dev/null +++ b/install.ps1 @@ -0,0 +1,8 @@ +# Create a virtual environment +python3 -m venv .venv + +# Activate the virtual environment +& .\.venv\Scripts\Activate.ps1 + +# Install the package in the virtual environment +pip install . diff --git a/setup.py b/setup.py index 77641a2b..892548b1 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "transformers==4.43.3", "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7", + "torch==2.4.0" ] # Add macOS-specific packages if on Darwin (macOS) From 7fcc89decaca669da8e1af39a0ded94140aaf6d7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 4 Aug 2024 21:57:29 -0800 Subject: [PATCH 002/491] removing and rewriting pytorch inference, removing model file as using huggingface hub, adding in pytorch_model.bin.index.json loading to load method --- exo/inference/pytorch/index.ex.json | 3 - exo/inference/pytorch/inference.py | 166 ++++---------------------- exo/inference/pytorch/models/llama.py | 0 3 files changed, 23 insertions(+), 146 deletions(-) delete mode 100644 exo/inference/pytorch/index.ex.json delete mode 100644 exo/inference/pytorch/models/llama.py diff --git a/exo/inference/pytorch/index.ex.json b/exo/inference/pytorch/index.ex.json deleted file mode 100644 index 19b16d91..00000000 --- a/exo/inference/pytorch/index.ex.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "model": "huggingface_model_name" -} \ No newline at end of file diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d8ae7289..6d6a306d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -42,27 +42,35 @@ # **** helper functions **** +def concat_weights(models, device=None): + def convert(name) -> torch.Tensor: + disk_tensors: List[torch.Tensor] = [model[name] for model in models] + if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: + return disk_tensors[0].to(device=device) + axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0 + lazy_tensors = [data.to(device=device) for data in disk_tensors] + return torch.cat(lazy_tensors, dim=axis) + + return {name: convert(name) for name in {name for model in models for name in model}} + def load(fn: str) -> Union[str, Dict[str, torch.Tensor]]: model = "" - if fn.endswith(".index.json"): + if fn.endswith("pytorch_model.bin.index.json"): with open(fn) as fp: - model = json.load(fp)["model"] - - if model == "": + weight_map = json.load(fp)["weight_map"] + parts = {n: torch.load(str(Path(fn).parent / Path(n).name), map_location="cpu") for n in set(weight_map.values())} + return {k: parts[n][k] for k, n in weight_map.items()} + else: model = torch.load(fn, map_location="cpu") - return model -def build_transformer(model_path: Union[str, Path], model_size="8B", quantize=None, device=None): - # Load the model configuration and parameters - model = load(model_path) - if isinstance(model, str): - with torch.device(device): - model = AutoModelForCausalLM.from_pretrained( - model, - torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, - device_map="auto" if "cuda" in str(device) else None - ) +def build_transformer(model_name: str, model_size="8B", quantize=None, device=None): + with torch.device(device): + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, + device_map="auto" if "cuda" in str(device) else None + ) # Quantize the model if specified if quantize: @@ -84,24 +92,6 @@ def build_transformer(model_path: Union[str, Path], model_size="8B", quantize=No return model -# Sample function using the built transformer -def sample(logits: torch.Tensor, temp: float, k: int, p: float): - if temp < 1e-6: - return torch.argmax(logits) - - logits[torch.isnan(logits)] = -float("inf") - probs = torch.nn.functional.softmax(logits / temp, dim=-1) - - if k: - top_probs, top_indices = torch.topk(probs, k) - top_probs = top_probs[top_probs.cumsum(dim=-1) <= p] - top_indices = top_indices[:len(top_probs)] - sampled_idx = torch.multinomial(top_probs, 1) - return top_indices[sampled_idx] - - return torch.multinomial(probs, 1) - - # default settings TEMPERATURE = 0 # 0.85 TOP_K = 25 @@ -109,113 +99,3 @@ def sample(logits: torch.Tensor, temp: float, k: int, p: float): ALPHA_F = 0.1 ALPHA_P = 0.0 - -def prefill(model, toks, start_pos=0): - # prefill the model - for tok in tqdm(toks): - GlobalCounters.reset() - inputs = torch.tensor([[tok]], device=model.device) - model.generate(inputs, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, max_new_tokens=1) - start_pos += 1 - return start_pos - - -class PytorchDynamicShardInferenceEngine(InferenceEngine): - def __init__(self): - self.shard = None - - async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): - await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - - toks = self.tokenizer.encode(prompt) - start_pos = prefill(self.model, toks[:-1], start_pos=start_pos) - last_tok = toks[-1] - - input_ids = torch.tensor([[last_tok]], device=self.model.device) - output_data = self.model.generate(input_ids, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, max_new_tokens=1).cpu().numpy() - if output_data.size == 1: - start_pos += 1 - - return ( - output_data, - json.dumps({"start_pos": start_pos}), - output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - ) - - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): - await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - - input_ids = torch.tensor(input_data, device=self.model.device) - output_data = self.model.generate(input_ids, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, max_new_tokens=1).cpu().numpy() - if output_data.size == 1: - start_pos += 1 - - return ( - output_data, - json.dumps({"start_pos": start_pos}), - output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - ) - - async def ensure_shard(self, shard: Shard): - if self.shard == shard: - return - - model_path = Path(shard.model_id) - models_dir = Path(_cache_dir) / "pytorch" / "downloads" - model_path = models_dir / shard.model_id - size = "8B" - if Path(model_path / "tokenizer_config.json").exists(): - model = model_path - else: - if DEBUG >= 2: - print(f"Downloading pytorch model {shard.model_id}...") - if shard.model_id.lower().find("llama3-8b-sfr") != -1: - num_files = 4 - for i in range(num_files): - await fetch_async( - f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.bin", - f"model-{(i+1):05d}-of-{num_files:05d}.bin", - subdir=shard.model_id, - ) - await fetch_async( - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json", - "config.json", - subdir=shard.model_id, - ) - model = await fetch_async( - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.index.json", - "model.index.json", - subdir=shard.model_id, - ) - await fetch_async( - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json", - "special_tokens_map.json", - subdir=shard.model_id, - ) - await fetch_async( - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json", - "tokenizer.json", - subdir=shard.model_id, - ) - await fetch_async( - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json", - "tokenizer_config.json", - subdir=shard.model_id, - ) - size = "8B" - elif shard.model_id.lower().find("llama3-70b-sfr") != -1: - raise NotImplementedError("llama3-70b-sfr is not implemented for pytorch") - else: - raise ValueError(f"pytorch doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}") - - model = build_transformer(model_path, shard=shard, model_size=size) - tokenizer = AutoTokenizer.from_pretrained(model_path if model_path.is_dir() else model_path.parent) - - self.shard = shard - self.model = model - self.tokenizer = tokenizer - - def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): - pass diff --git a/exo/inference/pytorch/models/llama.py b/exo/inference/pytorch/models/llama.py deleted file mode 100644 index e69de29b..00000000 From 54b330648e549306a8255d7e5716455c04b1ec66 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 4 Aug 2024 22:46:51 -0800 Subject: [PATCH 003/491] separating out helpers, building unittest for build_transformer helper, gen pydoc --- exo/inference/pytorch/helpers.py | 231 ++++++++++++++++++ exo/inference/pytorch/inference.py | 109 ++------- .../pytorch/test_build_transformer.py | 46 ++++ 3 files changed, 299 insertions(+), 87 deletions(-) create mode 100644 exo/inference/pytorch/helpers.py create mode 100644 exo/inference/pytorch/test_build_transformer.py diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py new file mode 100644 index 00000000..8376ea26 --- /dev/null +++ b/exo/inference/pytorch/helpers.py @@ -0,0 +1,231 @@ +# Helper functions for pytorch inference +# Some code coming from tinygrad but written towards pytorch + +# import os +# import numpy as np +# import asyncio +import json +import torch +# from functools import partial +from pathlib import Path +from typing import List, Union, Dict, Any +from transformers import AutoModelForCausalLM +from exo.inference.shard import Shard +# from exo.inference.inference_engine import InferenceEngine + +MODEL_PARAMS = { + "8B": { + "args": { + "dim": 4096, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 32, + "norm_eps": 1e-5, + "rope_theta": 500000, + "vocab_size": 128256, + "hidden_dim": 14336, + }, + "files": 1, + }, + "70B": { + "args": { + "dim": 8192, + "n_heads": 64, + "n_kv_heads": 8, + "n_layers": 80, + "norm_eps": 1e-5, + "rope_theta": 500000, + "vocab_size": 128256, + "hidden_dim": 28672, + }, + "files": 8, + }, +} + +def concat_weights(models, device=None): + """ + Concatenates weights from multiple model parts along the appropriate axis. + + Args: + models (List[Dict[str, torch.Tensor]]): List of dictionaries containing model weights. + device (Optional[torch.device]): The device to move the weights to (e.g., 'cpu' or 'cuda'). + + Returns: + Dict[str, torch.Tensor]: A dictionary where the keys are the weight names and the values + are the concatenated tensors moved to the specified device. + """ + def convert(name) -> torch.Tensor: + disk_tensors: List[torch.Tensor] = [model[name] for model in models] + if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: + return disk_tensors[0].to(device=device) + + ewn = name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") + axis = 1 if ewn else 0 + + lazy_tensors = [data.to(device=device) for data in disk_tensors] + return torch.cat(lazy_tensors, dim=axis) + + return {name: convert(name) for name in {name for model in models for name in model}} + +def load_weights(fn: str) -> Union[str, Dict[str, torch.Tensor]]: + """ + Loads model weights from a specified file. Supports both individual model files and + index files that map to multiple weight files. + + Args: + fn (str): The file path to load weights from. + + Returns: + Union[str, Dict[str, torch.Tensor]]: A string representing the model or a + dictionary of model weights. + """ + model = "" + if fn.endswith("pytorch_model.bin.index.json"): + with open(fn) as fp: + weight_map = json.load(fp)["weight_map"] + + for n in set(weight_map.values()): + full_path = str(Path(fn).parent / Path(n).name) + parts = {n: torch.load(full_path, map_location="cpu")} + + return {k: parts[n][k] for k, n in weight_map.items()} + else: + model = torch.load(fn, map_location="cpu") + return model + +def convert_from_huggingface( + weights: Dict[str, torch.Tensor], + model: torch.nn.Module, + n_heads: int, + n_kv_heads: int, + shard: Shard) -> Dict[str, torch.Tensor]: + """ + Converts Hugging Face model weights to the format expected by the target model. + + Args: + weights (Dict[str, torch.Tensor]): Dictionary of Hugging Face model weights. + model (nn.Module): The target model. + n_heads (int): Number of attention heads. + n_kv_heads (int): Number of key-value heads. + shard (Shard): Shard object containing information about the model shard. + + Returns: + Dict[str, torch.Tensor]: Dictionary of converted weights. + """ + def permute(v: torch.Tensor, n_heads: int) -> torch.Tensor: + return v.view( + n_heads, + 2, + v.shape[0] // (2 * n_heads), + v.shape[1] + ).transpose(1, 2).reshape(*v.shape) + + keymap = { + "model.embed_tokens.weight": "tok_embeddings.weight", + **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))}, + **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w_{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))}, + **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))}, + **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w_{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))}, + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + sd = {} + for k, v in weights.items(): + if ".rotary_emb." in k: + continue + + if "model.layers" in k: + layer_num = int(k.split(".")[2]) + if shard.start_layer <= layer_num <= shard.end_layer: + k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:]) + else: + continue + + if "q_proj" in k: + v = permute(v, n_heads) + elif "k_proj" in k: + v = permute(v, n_kv_heads) + + if k in keymap: + sd[keymap[k]] = v + + return sd + +def fix_bf16(weights: Dict[Any, torch.Tensor]) -> Dict[Any, torch.Tensor]: + """ + Converts weights to bfloat16 if supported by the device, otherwise to float16. + + Args: + weights (Dict[Any, torch.Tensor]): Dictionary of model weights. + + Returns: + Dict[Any, torch.Tensor]: Dictionary of converted weights. + """ + supports_bf16 = torch.cuda.is_bf16_supported() + + if supports_bf16: + return {k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v for k, v in weights.items()} + else: + return {k: v.to(torch.float16) if v.dtype == torch.bfloat16 else v for k, v in weights.items()} + + +def build_transformer(model_name: str, shard: Shard, model_size="8B", quantize=None, device=None): + # Load model from Hugging Face hub + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, + device_map="auto" if "cuda" in str(device) else None + ) + + # Load weights + model_path = Path(model_name) + if model_path.is_dir(): + if (model_path / "pytorch_model.bin.index.json").exists(): + weights = load_weights(str(model_path / "pytorch_model.bin.index.json")) + else: + pth_weights = [] + for i in range(MODEL_PARAMS[model_size]["files"]): + pth_path = str(model_path / f"consolidated.{i:02d}.pth") + pth_weights.append(load_weights(pth_path)) + + weights = concat_weights( + pth_weights, + device[0] if isinstance(device, tuple) else device, + ) + else: + weights = load_weights(str(model_path)) + + if "model.embed_tokens.weight" in weights: + weights = convert_from_huggingface( + weights, + model, + MODEL_PARAMS[model_size]["args"]["n_heads"], + MODEL_PARAMS[model_size]["args"]["n_kv_heads"], + shard=shard, + ) + weights = fix_bf16(weights) + + # Quantize the model if specified + if quantize: + model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}) + + # Shard the model if using multiple devices + if isinstance(device, tuple): + for name, param in model.named_parameters(): + if "scale" in name: + param.data = param.data.chunk(len(device), dim=0) + elif ".attention." in name: + param.data = param.data.chunk(len(device), dim=-1) + elif ".feed_forward.w1." in name or ".feed_forward.w3." in name: + param.data = param.data.chunk(len(device), dim=0) + elif ".feed_forward." in name: + param.data = param.data.chunk(len(device), dim=-1) + elif "tok_embeddings.weight" in name or "output.weight" in name: + param.data = param.data.chunk(len(device), dim=0) + + # Replace weights in model + model.load_state_dict(weights, strict=False) + + return model + diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6d6a306d..d411481d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,96 +1,16 @@ # experimental, based off of tinygrad/inference.py - +import os +import numpy as np import asyncio +import json +import torch from functools import partial from pathlib import Path from typing import List, Optional, Union, Callable, Dict -import json -import torch -from torch import nn from transformers import AutoTokenizer, AutoModelForCausalLM -import numpy as np -import os - -MODEL_PARAMS = { - "8B": { - "args": { - "dim": 4096, - "n_heads": 32, - "n_kv_heads": 8, - "n_layers": 32, - "norm_eps": 1e-5, - "rope_theta": 500000, - "vocab_size": 128256, - "hidden_dim": 14336, - }, - "files": 1, - }, - "70B": { - "args": { - "dim": 8192, - "n_heads": 64, - "n_kv_heads": 8, - "n_layers": 80, - "norm_eps": 1e-5, - "rope_theta": 500000, - "vocab_size": 128256, - "hidden_dim": 28672, - }, - "files": 8, - }, -} - - -# **** helper functions **** -def concat_weights(models, device=None): - def convert(name) -> torch.Tensor: - disk_tensors: List[torch.Tensor] = [model[name] for model in models] - if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: - return disk_tensors[0].to(device=device) - axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0 - lazy_tensors = [data.to(device=device) for data in disk_tensors] - return torch.cat(lazy_tensors, dim=axis) - - return {name: convert(name) for name in {name for model in models for name in model}} - -def load(fn: str) -> Union[str, Dict[str, torch.Tensor]]: - model = "" - if fn.endswith("pytorch_model.bin.index.json"): - with open(fn) as fp: - weight_map = json.load(fp)["weight_map"] - parts = {n: torch.load(str(Path(fn).parent / Path(n).name), map_location="cpu") for n in set(weight_map.values())} - return {k: parts[n][k] for k, n in weight_map.items()} - else: - model = torch.load(fn, map_location="cpu") - return model - -def build_transformer(model_name: str, model_size="8B", quantize=None, device=None): - with torch.device(device): - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, - device_map="auto" if "cuda" in str(device) else None - ) - - # Quantize the model if specified - if quantize: - model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8) - - # Shard the model if using multiple devices - if isinstance(device, tuple): - for name, param in model.named_parameters(): - if "scale" in name: - param.data = param.data.chunk(len(device), dim=0) - elif ".attention." in name: - param.data = param.data.chunk(len(device), dim=-1) - elif ".feed_forward.w1." in name or ".feed_forward.w3." in name: - param.data = param.data.chunk(len(device), dim=0) - elif ".feed_forward." in name: - param.data = param.data.chunk(len(device), dim=-1) - elif "tok_embeddings.weight" in name or "output.weight" in name: - param.data = param.data.chunk(len(device), dim=0) - - return model +from exo.inference.shard import Shard +from exo.inference.inference_engine import InferenceEngine +# from exo.inference.pytorch.helpers import # default settings TEMPERATURE = 0 # 0.85 @@ -99,3 +19,18 @@ def build_transformer(model_name: str, model_size="8B", quantize=None, device=No ALPHA_F = 0.1 ALPHA_P = 0.0 + +# don't think prefill is needed +# think that is used for stats but will look into + +class PyTorchDynamicShardInferenceEngine(InferenceEngine): + def __init__(self): + self.shard = None + + # async def infer_prompt + + # async def infer_tensor + + # async def ensure_shard + + # def set_on_download_progess [is this needed?] diff --git a/exo/inference/pytorch/test_build_transformer.py b/exo/inference/pytorch/test_build_transformer.py new file mode 100644 index 00000000..1d553f47 --- /dev/null +++ b/exo/inference/pytorch/test_build_transformer.py @@ -0,0 +1,46 @@ +import unittest +from unittest.mock import patch, MagicMock +from pathlib import Path +import torch +from transformers import AutoModelForCausalLM +from exo.inference.shard import Shard +from exo.inference.pytorch.helpers import build_transformer + +class TestBuildTransformer(unittest.TestCase): + + @patch('torch.load') + @patch('transformers.AutoModelForCausalLM.from_pretrained') + @patch('builtins.open', new_callable=unittest.mock.mock_open, read_data='{"weight_map": {"0": "pytorch_model.bin"}}') + def test_build_transformer(self, mock_open, mock_from_pretrained, mock_torch_load): + # Mocking model and weights + mock_model = MagicMock(spec=AutoModelForCausalLM) + mock_from_pretrained.return_value = mock_model + + mock_weights = { + "model.embed_tokens.weight": torch.randn(1024, 768), + "model.layers.0.self_attn.q_proj.weight": torch.randn(768, 768), + # Add other necessary mock weights here + } + mock_torch_load.return_value = mock_weights + + # Define the shard + shard = Shard(model_id="mock_model", start_layer=0, end_layer=0, n_layers=1) + + # Call the build_transformer function + model = build_transformer("mock_model", shard, model_size="8B", quantize=True, device="cpu") + + # Assertions to verify the function behavior + mock_from_pretrained.assert_called_once_with( + "mock_model", + torch_dtype=torch.float32, + device_map=None + ) + + mock_open.assert_called_once_with("mock_model/pytorch_model.bin.index.json") + mock_torch_load.assert_called() + + mock_model.load_state_dict.assert_called() + self.assertEqual(model, mock_model) + +if __name__ == '__main__': + unittest.main() From 207422ff26f3a65e7285e98a9a39d185860dac61 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 4 Aug 2024 23:00:10 -0800 Subject: [PATCH 004/491] cleaning up bad gen code test for convert, simplifying test, taking out loading weights as not needed --- exo/inference/pytorch/helpers.py | 49 ++++++------------- .../pytorch/test_build_transformer.py | 3 +- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py index 8376ea26..cb93adc0 100644 --- a/exo/inference/pytorch/helpers.py +++ b/exo/inference/pytorch/helpers.py @@ -171,41 +171,27 @@ def fix_bf16(weights: Dict[Any, torch.Tensor]) -> Dict[Any, torch.Tensor]: def build_transformer(model_name: str, shard: Shard, model_size="8B", quantize=None, device=None): + """ + Builds a transformer model by loading it from the Hugging Face model hub and applying + weight conversion, quantization, and sharding as specified. + + Args: + model_name (str): The name of the model to load from the Hugging Face model hub. + shard (Shard): A Shard object containing information about the model shard. + model_size (str, optional): The size of the model to load (default is "8B"). + quantize (bool, optional): Whether to apply dynamic quantization to the model (default is None). + device (torch.device, optional): The device to load the model onto (default is None). + + Returns: + nn.Module: The constructed and configured transformer model. + """ # Load model from Hugging Face hub model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, + model_name, + torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, device_map="auto" if "cuda" in str(device) else None ) - # Load weights - model_path = Path(model_name) - if model_path.is_dir(): - if (model_path / "pytorch_model.bin.index.json").exists(): - weights = load_weights(str(model_path / "pytorch_model.bin.index.json")) - else: - pth_weights = [] - for i in range(MODEL_PARAMS[model_size]["files"]): - pth_path = str(model_path / f"consolidated.{i:02d}.pth") - pth_weights.append(load_weights(pth_path)) - - weights = concat_weights( - pth_weights, - device[0] if isinstance(device, tuple) else device, - ) - else: - weights = load_weights(str(model_path)) - - if "model.embed_tokens.weight" in weights: - weights = convert_from_huggingface( - weights, - model, - MODEL_PARAMS[model_size]["args"]["n_heads"], - MODEL_PARAMS[model_size]["args"]["n_kv_heads"], - shard=shard, - ) - weights = fix_bf16(weights) - # Quantize the model if specified if quantize: model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}) @@ -224,8 +210,5 @@ def build_transformer(model_name: str, shard: Shard, model_size="8B", quantize=N elif "tok_embeddings.weight" in name or "output.weight" in name: param.data = param.data.chunk(len(device), dim=0) - # Replace weights in model - model.load_state_dict(weights, strict=False) - return model diff --git a/exo/inference/pytorch/test_build_transformer.py b/exo/inference/pytorch/test_build_transformer.py index 1d553f47..981d73b3 100644 --- a/exo/inference/pytorch/test_build_transformer.py +++ b/exo/inference/pytorch/test_build_transformer.py @@ -14,6 +14,7 @@ class TestBuildTransformer(unittest.TestCase): def test_build_transformer(self, mock_open, mock_from_pretrained, mock_torch_load): # Mocking model and weights mock_model = MagicMock(spec=AutoModelForCausalLM) + mock_model.layers = [MagicMock()] * 2 # Mocking layers attribute mock_from_pretrained.return_value = mock_model mock_weights = { @@ -24,7 +25,7 @@ def test_build_transformer(self, mock_open, mock_from_pretrained, mock_torch_loa mock_torch_load.return_value = mock_weights # Define the shard - shard = Shard(model_id="mock_model", start_layer=0, end_layer=0, n_layers=1) + shard = Shard(model_id="mock_model", start_layer=0, end_layer=1, n_layers=2) # Call the build_transformer function model = build_transformer("mock_model", shard, model_size="8B", quantize=True, device="cpu") From fb7cf22bb3f5422b0f18259c4e10909a3a3d16de Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 4 Aug 2024 23:17:42 -0800 Subject: [PATCH 005/491] created PyTorchDynamicShardInferenceEngine and adding tooling --- exo/inference/pytorch/helpers.py | 27 +++- exo/inference/pytorch/inference.py | 117 +++++++++++++++++- .../pytorch/test_build_transformer.py | 36 +----- 3 files changed, 142 insertions(+), 38 deletions(-) diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py index cb93adc0..656f307c 100644 --- a/exo/inference/pytorch/helpers.py +++ b/exo/inference/pytorch/helpers.py @@ -170,7 +170,7 @@ def fix_bf16(weights: Dict[Any, torch.Tensor]) -> Dict[Any, torch.Tensor]: return {k: v.to(torch.float16) if v.dtype == torch.bfloat16 else v for k, v in weights.items()} -def build_transformer(model_name: str, shard: Shard, model_size="8B", quantize=None, device=None): +def build_transformer(model_name: str, quantize=None, device=None): """ Builds a transformer model by loading it from the Hugging Face model hub and applying weight conversion, quantization, and sharding as specified. @@ -212,3 +212,28 @@ def build_transformer(model_name: str, shard: Shard, model_size="8B", quantize=N return model +def shard_model(model: Any, model_name: str, num_shards: int) -> List[Shard]: + # Get the total number of layers + if hasattr(model, 'config'): + n_layers = model.config.num_hidden_layers + else: + raise ValueError("Unable to determine the number of layers in the model") + + # Calculate layers per shard + layers_per_shard = n_layers // num_shards + remainder = n_layers % num_shards + + shards = [] + start_layer = 0 + for i in range(num_shards): + end_layer = start_layer + layers_per_shard - 1 + if i < remainder: + end_layer += 1 + + shard = Shard(model_name, start_layer, end_layer, n_layers) + shards.append(shard) + + start_layer = end_layer + 1 + + return shards + diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d411481d..c31950f9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -6,11 +6,17 @@ import torch from functools import partial from pathlib import Path -from typing import List, Optional, Union, Callable, Dict +from typing import List, Optional, Union, Callable, Dict, Tuple from transformers import AutoTokenizer, AutoModelForCausalLM from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -# from exo.inference.pytorch.helpers import +from exo.inference.pytorch.helpers import ( + fix_bf16, + build_transformer, + load_weights, + convert_from_huggingface, + MODEL_PARAMS +) # default settings TEMPERATURE = 0 # 0.85 @@ -26,11 +32,110 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): def __init__(self): self.shard = None + self.model = None + self.tokenizer = None + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # async def infer_prompt + async def infer_prompt( + self, + request_id: str, + shard: Shard, + prompt: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - # async def infer_tensor + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + attention_mask = torch.ones_like(input_ids) + + with torch.no_grad(): + outputs = self.model.generate( + input_ids, + attention_mask=attention_mask, + max_new_tokens=1, + do_sample=True, + temperature=0.7, + top_k=50, + top_p=0.95, + pad_token_id=self.tokenizer.eos_token_id, + start_pos=start_pos + ) - # async def ensure_shard + output_token = outputs[0, -1].item() + output_data = np.array([output_token]) + start_pos += 1 - # def set_on_download_progess [is this needed?] + is_eos = output_token == self.tokenizer.eos_token_id + + return ( + output_data, + json.dumps({"start_pos": start_pos}), + is_eos + ) + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) + + with torch.no_grad(): + outputs = self.model.generate( + input_tensor, + max_new_tokens=1, + do_sample=True, + temperature=0.7, + top_k=50, + top_p=0.95, + pad_token_id=self.tokenizer.eos_token_id, + start_pos=start_pos + ) + + output_token = outputs[0, -1].item() + output_data = np.array([output_token]) + start_pos += 1 + + is_eos = output_token == self.tokenizer.eos_token_id + + return ( + output_data, + json.dumps({"start_pos": start_pos}), + is_eos + ) + + async def ensure_shard(self, shard: Shard): + if self.shard == shard: + return + + cache_dir = Path.home() / ".cache" / "huggingface" + model_path = cache_dir / "models--" / shard.model_id.replace('/', '--') + + if not model_path.exists(): + print(f"Downloading PyTorch model {shard.model_id}...") + weights = load_weights(str(model_path / "pytorch_model.bin")) + else: + weights = load_weights(str(model_path / "pytorch_model.bin")) + + model_size = "8B" # Assume 8B model, adjust as needed + n_heads = MODEL_PARAMS[model_size]["args"]["n_heads"] + n_kv_heads = MODEL_PARAMS[model_size]["args"]["n_kv_heads"] + + self.model = build_transformer(shard.model_id, device=self.device) + converted_weights = convert_from_huggingface(weights, self.model, n_heads, n_kv_heads, shard) + converted_weights = fix_bf16(converted_weights) + + self.model.load_state_dict(converted_weights, strict=False) + self.model.to(self.device) + + self.tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + self.shard = shard + + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): + # This method can be implemented if progress tracking is needed + pass \ No newline at end of file diff --git a/exo/inference/pytorch/test_build_transformer.py b/exo/inference/pytorch/test_build_transformer.py index 981d73b3..fd094658 100644 --- a/exo/inference/pytorch/test_build_transformer.py +++ b/exo/inference/pytorch/test_build_transformer.py @@ -2,46 +2,20 @@ from unittest.mock import patch, MagicMock from pathlib import Path import torch -from transformers import AutoModelForCausalLM from exo.inference.shard import Shard from exo.inference.pytorch.helpers import build_transformer class TestBuildTransformer(unittest.TestCase): - @patch('torch.load') - @patch('transformers.AutoModelForCausalLM.from_pretrained') - @patch('builtins.open', new_callable=unittest.mock.mock_open, read_data='{"weight_map": {"0": "pytorch_model.bin"}}') def test_build_transformer(self, mock_open, mock_from_pretrained, mock_torch_load): - # Mocking model and weights - mock_model = MagicMock(spec=AutoModelForCausalLM) - mock_model.layers = [MagicMock()] * 2 # Mocking layers attribute - mock_from_pretrained.return_value = mock_model - - mock_weights = { - "model.embed_tokens.weight": torch.randn(1024, 768), - "model.layers.0.self_attn.q_proj.weight": torch.randn(768, 768), - # Add other necessary mock weights here - } - mock_torch_load.return_value = mock_weights - - # Define the shard - shard = Shard(model_id="mock_model", start_layer=0, end_layer=1, n_layers=2) - # Call the build_transformer function - model = build_transformer("mock_model", shard, model_size="8B", quantize=True, device="cpu") - - # Assertions to verify the function behavior - mock_from_pretrained.assert_called_once_with( - "mock_model", - torch_dtype=torch.float32, - device_map=None + model = build_transformer( + "gpt2", + quantize=True, + device="cuda" ) - mock_open.assert_called_once_with("mock_model/pytorch_model.bin.index.json") - mock_torch_load.assert_called() - - mock_model.load_state_dict.assert_called() - self.assertEqual(model, mock_model) + self.assertIsNotNone(model) if __name__ == '__main__': unittest.main() From 5a262cfba7bb7ae0fd15e48c782bc5e506e3de4c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 4 Aug 2024 23:58:08 -0800 Subject: [PATCH 006/491] removing custom sharding from tinygrad inspo and using pytorch FSDP for sharding --- exo/inference/pytorch/helpers.py | 194 +---------------------------- exo/inference/pytorch/inference.py | 85 +++++-------- 2 files changed, 33 insertions(+), 246 deletions(-) diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py index 656f307c..32e0572e 100644 --- a/exo/inference/pytorch/helpers.py +++ b/exo/inference/pytorch/helpers.py @@ -1,174 +1,8 @@ # Helper functions for pytorch inference # Some code coming from tinygrad but written towards pytorch -# import os -# import numpy as np -# import asyncio -import json import torch -# from functools import partial -from pathlib import Path -from typing import List, Union, Dict, Any from transformers import AutoModelForCausalLM -from exo.inference.shard import Shard -# from exo.inference.inference_engine import InferenceEngine - -MODEL_PARAMS = { - "8B": { - "args": { - "dim": 4096, - "n_heads": 32, - "n_kv_heads": 8, - "n_layers": 32, - "norm_eps": 1e-5, - "rope_theta": 500000, - "vocab_size": 128256, - "hidden_dim": 14336, - }, - "files": 1, - }, - "70B": { - "args": { - "dim": 8192, - "n_heads": 64, - "n_kv_heads": 8, - "n_layers": 80, - "norm_eps": 1e-5, - "rope_theta": 500000, - "vocab_size": 128256, - "hidden_dim": 28672, - }, - "files": 8, - }, -} - -def concat_weights(models, device=None): - """ - Concatenates weights from multiple model parts along the appropriate axis. - - Args: - models (List[Dict[str, torch.Tensor]]): List of dictionaries containing model weights. - device (Optional[torch.device]): The device to move the weights to (e.g., 'cpu' or 'cuda'). - - Returns: - Dict[str, torch.Tensor]: A dictionary where the keys are the weight names and the values - are the concatenated tensors moved to the specified device. - """ - def convert(name) -> torch.Tensor: - disk_tensors: List[torch.Tensor] = [model[name] for model in models] - if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: - return disk_tensors[0].to(device=device) - - ewn = name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") - axis = 1 if ewn else 0 - - lazy_tensors = [data.to(device=device) for data in disk_tensors] - return torch.cat(lazy_tensors, dim=axis) - - return {name: convert(name) for name in {name for model in models for name in model}} - -def load_weights(fn: str) -> Union[str, Dict[str, torch.Tensor]]: - """ - Loads model weights from a specified file. Supports both individual model files and - index files that map to multiple weight files. - - Args: - fn (str): The file path to load weights from. - - Returns: - Union[str, Dict[str, torch.Tensor]]: A string representing the model or a - dictionary of model weights. - """ - model = "" - if fn.endswith("pytorch_model.bin.index.json"): - with open(fn) as fp: - weight_map = json.load(fp)["weight_map"] - - for n in set(weight_map.values()): - full_path = str(Path(fn).parent / Path(n).name) - parts = {n: torch.load(full_path, map_location="cpu")} - - return {k: parts[n][k] for k, n in weight_map.items()} - else: - model = torch.load(fn, map_location="cpu") - return model - -def convert_from_huggingface( - weights: Dict[str, torch.Tensor], - model: torch.nn.Module, - n_heads: int, - n_kv_heads: int, - shard: Shard) -> Dict[str, torch.Tensor]: - """ - Converts Hugging Face model weights to the format expected by the target model. - - Args: - weights (Dict[str, torch.Tensor]): Dictionary of Hugging Face model weights. - model (nn.Module): The target model. - n_heads (int): Number of attention heads. - n_kv_heads (int): Number of key-value heads. - shard (Shard): Shard object containing information about the model shard. - - Returns: - Dict[str, torch.Tensor]: Dictionary of converted weights. - """ - def permute(v: torch.Tensor, n_heads: int) -> torch.Tensor: - return v.view( - n_heads, - 2, - v.shape[0] // (2 * n_heads), - v.shape[1] - ).transpose(1, 2).reshape(*v.shape) - - keymap = { - "model.embed_tokens.weight": "tok_embeddings.weight", - **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))}, - **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w_{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))}, - **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))}, - **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w_{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))}, - "model.norm.weight": "norm.weight", - "lm_head.weight": "output.weight", - } - - sd = {} - for k, v in weights.items(): - if ".rotary_emb." in k: - continue - - if "model.layers" in k: - layer_num = int(k.split(".")[2]) - if shard.start_layer <= layer_num <= shard.end_layer: - k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:]) - else: - continue - - if "q_proj" in k: - v = permute(v, n_heads) - elif "k_proj" in k: - v = permute(v, n_kv_heads) - - if k in keymap: - sd[keymap[k]] = v - - return sd - -def fix_bf16(weights: Dict[Any, torch.Tensor]) -> Dict[Any, torch.Tensor]: - """ - Converts weights to bfloat16 if supported by the device, otherwise to float16. - - Args: - weights (Dict[Any, torch.Tensor]): Dictionary of model weights. - - Returns: - Dict[Any, torch.Tensor]: Dictionary of converted weights. - """ - supports_bf16 = torch.cuda.is_bf16_supported() - - if supports_bf16: - return {k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v for k, v in weights.items()} - else: - return {k: v.to(torch.float16) if v.dtype == torch.bfloat16 else v for k, v in weights.items()} - def build_transformer(model_name: str, quantize=None, device=None): """ @@ -210,30 +44,4 @@ def build_transformer(model_name: str, quantize=None, device=None): elif "tok_embeddings.weight" in name or "output.weight" in name: param.data = param.data.chunk(len(device), dim=0) - return model - -def shard_model(model: Any, model_name: str, num_shards: int) -> List[Shard]: - # Get the total number of layers - if hasattr(model, 'config'): - n_layers = model.config.num_hidden_layers - else: - raise ValueError("Unable to determine the number of layers in the model") - - # Calculate layers per shard - layers_per_shard = n_layers // num_shards - remainder = n_layers % num_shards - - shards = [] - start_layer = 0 - for i in range(num_shards): - end_layer = start_layer + layers_per_shard - 1 - if i < remainder: - end_layer += 1 - - shard = Shard(model_name, start_layer, end_layer, n_layers) - shards.append(shard) - - start_layer = end_layer + 1 - - return shards - + return model \ No newline at end of file diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index c31950f9..e2b873f3 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,49 +1,39 @@ # experimental, based off of tinygrad/inference.py -import os +# utilizing pytorch FSDP for sharding + import numpy as np -import asyncio import json import torch -from functools import partial -from pathlib import Path -from typing import List, Optional, Union, Callable, Dict, Tuple -from transformers import AutoTokenizer, AutoModelForCausalLM +from typing import Optional, Callable, Tuple +from transformers import AutoTokenizer from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.helpers import ( - fix_bf16, - build_transformer, - load_weights, - convert_from_huggingface, - MODEL_PARAMS -) - -# default settings -TEMPERATURE = 0 # 0.85 -TOP_K = 25 -TOP_P = 0.9 +from exo.inference.pytorch.helpers import build_transformer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import auto_wrap_policy +from torch.distributed import init_process_group, destroy_process_group + +# Default settings +TEMPERATURE = 0.7 +TOP_K = 50 +TOP_P = 0.95 ALPHA_F = 0.1 ALPHA_P = 0.0 - -# don't think prefill is needed -# think that is used for stats but will look into - class PyTorchDynamicShardInferenceEngine(InferenceEngine): def __init__(self): self.shard = None self.model = None self.tokenizer = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Initialize process group + init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo') async def infer_prompt( self, - request_id: str, - shard: Shard, prompt: str, - image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - await self.ensure_shard(shard) start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) @@ -55,9 +45,9 @@ async def infer_prompt( attention_mask=attention_mask, max_new_tokens=1, do_sample=True, - temperature=0.7, - top_k=50, - top_p=0.95, + temperature=TEMPERATURE, + top_k=TOP_K, + top_p=TOP_P, pad_token_id=self.tokenizer.eos_token_id, start_pos=start_pos ) @@ -90,9 +80,9 @@ async def infer_tensor( input_tensor, max_new_tokens=1, do_sample=True, - temperature=0.7, - top_k=50, - top_p=0.95, + temperature=TEMPERATURE, + top_k=TOP_K, + top_p=TOP_P, pad_token_id=self.tokenizer.eos_token_id, start_pos=start_pos ) @@ -113,29 +103,18 @@ async def ensure_shard(self, shard: Shard): if self.shard == shard: return - cache_dir = Path.home() / ".cache" / "huggingface" - model_path = cache_dir / "models--" / shard.model_id.replace('/', '--') - - if not model_path.exists(): - print(f"Downloading PyTorch model {shard.model_id}...") - weights = load_weights(str(model_path / "pytorch_model.bin")) - else: - weights = load_weights(str(model_path / "pytorch_model.bin")) - - model_size = "8B" # Assume 8B model, adjust as needed - n_heads = MODEL_PARAMS[model_size]["args"]["n_heads"] - n_kv_heads = MODEL_PARAMS[model_size]["args"]["n_kv_heads"] - - self.model = build_transformer(shard.model_id, device=self.device) - converted_weights = convert_from_huggingface(weights, self.model, n_heads, n_kv_heads, shard) - converted_weights = fix_bf16(converted_weights) - - self.model.load_state_dict(converted_weights, strict=False) - self.model.to(self.device) + # Load model and tokenizer from Hugging Face hub + self.model = build_transformer(shard.model_id, shard, device=self.device) + + # Wrap the model with FSDP + self.model = FSDP(self.model, auto_wrap_policy=auto_wrap_policy) - self.tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) self.shard = shard def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): # This method can be implemented if progress tracking is needed - pass \ No newline at end of file + pass + + def __del__(self): + destroy_process_group() From ec1f656202c3622297f0fe3383c176554de2b554 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 5 Aug 2024 00:03:12 -0800 Subject: [PATCH 007/491] fixing infer_prompt, working on testing, cleaning up PyTorchDynamicShardInferenceEngine --- exo/inference/pytorch/inference.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index e2b873f3..ffe9f9d8 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -32,8 +32,14 @@ def __init__(self): async def infer_prompt( self, + request_id: str, + shard: Shard, prompt: str, + image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + + await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) @@ -70,7 +76,9 @@ async def infer_tensor( shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) From a612b1fe5520e55adaa98bb4c2316b3c198d2560 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 5 Aug 2024 00:04:44 -0800 Subject: [PATCH 008/491] fixing test --- exo/inference/pytorch/test_build_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_build_transformer.py b/exo/inference/pytorch/test_build_transformer.py index fd094658..cdbfa6fc 100644 --- a/exo/inference/pytorch/test_build_transformer.py +++ b/exo/inference/pytorch/test_build_transformer.py @@ -7,7 +7,7 @@ class TestBuildTransformer(unittest.TestCase): - def test_build_transformer(self, mock_open, mock_from_pretrained, mock_torch_load): + def test_build_transformer(self): # Call the build_transformer function model = build_transformer( "gpt2", From d904caf9ea6805c14bb1b502c605c9ec439b8f49 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 5 Aug 2024 00:06:25 -0800 Subject: [PATCH 009/491] removing pytorch from setup.py, will need to think of another way to install right pytorch version --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 892548b1..3ea7c8ac 100644 --- a/setup.py +++ b/setup.py @@ -26,8 +26,7 @@ "tqdm==4.66.4", "transformers==4.43.3", "uuid==1.30", - "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7", - "torch==2.4.0" + "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@639af3f823cf242a1945dc24183e52a9df0af2b7" ] # Add macOS-specific packages if on Darwin (macOS) From 1fabef2067022b78067b82497ce2d37632862571 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 5 Aug 2024 01:06:45 -0800 Subject: [PATCH 010/491] working more on engine, removing pytorch fsd as exo is bascially what it is trying to do, implementing exo sharding next --- exo/inference/pytorch/helpers.py | 14 --------- exo/inference/pytorch/inference.py | 46 +++++++--------------------- exo/inference/pytorch/test_engine.py | 17 ++++++++++ 3 files changed, 28 insertions(+), 49 deletions(-) create mode 100644 exo/inference/pytorch/test_engine.py diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py index 32e0572e..27700509 100644 --- a/exo/inference/pytorch/helpers.py +++ b/exo/inference/pytorch/helpers.py @@ -30,18 +30,4 @@ def build_transformer(model_name: str, quantize=None, device=None): if quantize: model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}) - # Shard the model if using multiple devices - if isinstance(device, tuple): - for name, param in model.named_parameters(): - if "scale" in name: - param.data = param.data.chunk(len(device), dim=0) - elif ".attention." in name: - param.data = param.data.chunk(len(device), dim=-1) - elif ".feed_forward.w1." in name or ".feed_forward.w3." in name: - param.data = param.data.chunk(len(device), dim=0) - elif ".feed_forward." in name: - param.data = param.data.chunk(len(device), dim=-1) - elif "tok_embeddings.weight" in name or "output.weight" in name: - param.data = param.data.chunk(len(device), dim=0) - return model \ No newline at end of file diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ffe9f9d8..69b63a13 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,17 +1,17 @@ # experimental, based off of tinygrad/inference.py # utilizing pytorch FSDP for sharding +# look into shard being optional for the inferece import numpy as np import json import torch +import functools +import os from typing import Optional, Callable, Tuple from transformers import AutoTokenizer from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import build_transformer -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import auto_wrap_policy -from torch.distributed import init_process_group, destroy_process_group # Default settings TEMPERATURE = 0.7 @@ -21,25 +21,20 @@ ALPHA_P = 0.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): - def __init__(self): - self.shard = None - self.model = None - self.tokenizer = None - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Initialize process group - init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo') + def __init__(self, model_name: str = "gpt2", device: str = "cuda", tokenizer: str="gpt2"): + self.device = device + self.model_name = model_name + self.shard = Shard(model_id=model_name, start_layer=0, end_layer=1, n_layers=2) + self.model = build_transformer(self.shard.model_id, self.shard, device=self.device) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) async def infer_prompt( self, request_id: str, shard: Shard, prompt: str, - image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) @@ -54,8 +49,7 @@ async def infer_prompt( temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P, - pad_token_id=self.tokenizer.eos_token_id, - start_pos=start_pos + pad_token_id=self.tokenizer.eos_token_id ) output_token = outputs[0, -1].item() @@ -77,8 +71,6 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) @@ -107,22 +99,6 @@ async def infer_tensor( is_eos ) - async def ensure_shard(self, shard: Shard): - if self.shard == shard: - return - - # Load model and tokenizer from Hugging Face hub - self.model = build_transformer(shard.model_id, shard, device=self.device) - - # Wrap the model with FSDP - self.model = FSDP(self.model, auto_wrap_policy=auto_wrap_policy) - - self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) - self.shard = shard - def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): # This method can be implemented if progress tracking is needed - pass - - def __del__(self): - destroy_process_group() + pass \ No newline at end of file diff --git a/exo/inference/pytorch/test_engine.py b/exo/inference/pytorch/test_engine.py new file mode 100644 index 00000000..838958d4 --- /dev/null +++ b/exo/inference/pytorch/test_engine.py @@ -0,0 +1,17 @@ +import unittest +from .inference import PyTorchDynamicShardInferenceEngine +from exo.inference.shard import Shard +import asyncio + +class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): + def test_one(self): + shard = Shard(model_id="mock_model", start_layer=0, end_layer=1, n_layers=2) + engine = PyTorchDynamicShardInferenceEngine() + prompt_resp = asyncio.run( + engine.infer_prompt( + "", + shard, + "Why is the sky blue?") + ) + + self.assertIsNotNone(prompt_resp) From 1aff2e9425ccd5787e3800bc660f0482d5a2a479 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 19:28:04 -0800 Subject: [PATCH 011/491] building out engine more with sharding implemented and kv caching, added a llama focused sharding model and started work on a more general hugging face sharding model --- exo/inference/pytorch/helpers.py | 47 ++--- exo/inference/pytorch/inference.py | 194 +++++++++++------- exo/inference/pytorch/model/hf.py | 52 +++++ exo/inference/pytorch/model/llama.py | 41 ++++ exo/inference/pytorch/test_engine.py | 17 -- .../pytorch/test_inference_engine.py | 55 +++++ 6 files changed, 291 insertions(+), 115 deletions(-) create mode 100644 exo/inference/pytorch/model/hf.py create mode 100644 exo/inference/pytorch/model/llama.py delete mode 100644 exo/inference/pytorch/test_engine.py create mode 100644 exo/inference/pytorch/test_inference_engine.py diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py index 27700509..addea2db 100644 --- a/exo/inference/pytorch/helpers.py +++ b/exo/inference/pytorch/helpers.py @@ -1,33 +1,24 @@ # Helper functions for pytorch inference # Some code coming from tinygrad but written towards pytorch -import torch -from transformers import AutoModelForCausalLM +import asyncio +import aiohttp +from tqdm import tqdm +from pathlib import Path +from typing import List -def build_transformer(model_name: str, quantize=None, device=None): - """ - Builds a transformer model by loading it from the Hugging Face model hub and applying - weight conversion, quantization, and sharding as specified. +async def fetch_file_async(session, url: str, output_path: Path): + async with session.get(url) as response: + response.raise_for_status() + with open(output_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + f.write(chunk) - Args: - model_name (str): The name of the model to load from the Hugging Face model hub. - shard (Shard): A Shard object containing information about the model shard. - model_size (str, optional): The size of the model to load (default is "8B"). - quantize (bool, optional): Whether to apply dynamic quantization to the model (default is None). - device (torch.device, optional): The device to load the model onto (default is None). - - Returns: - nn.Module: The constructed and configured transformer model. - """ - # Load model from Hugging Face hub - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32, - device_map="auto" if "cuda" in str(device) else None - ) - - # Quantize the model if specified - if quantize: - model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}) - - return model \ No newline at end of file +async def download_files(urls: List[str], output_paths: List[Path]): + async with aiohttp.ClientSession() as session: + tasks = [] + for url, output_path in zip(urls, output_paths): + tasks.append(fetch_file_async(session, url, output_path)) + + for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Downloading files"): + await f diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 69b63a13..306758e7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -2,103 +2,157 @@ # utilizing pytorch FSDP for sharding # look into shard being optional for the inferece -import numpy as np +import os +import shutil import json import torch -import functools -import os +import numpy as np +from pathlib import Path from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoModelForCausalLM from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.helpers import build_transformer +from exo.inference.pytorch.helpers import download_files +from exo.inference.pytorch.model.llama import ShardedLLAMAModel # Default settings TEMPERATURE = 0.7 TOP_K = 50 -TOP_P = 0.95 -ALPHA_F = 0.1 -ALPHA_P = 0.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): - def __init__(self, model_name: str = "gpt2", device: str = "cuda", tokenizer: str="gpt2"): - self.device = device - self.model_name = model_name - self.shard = Shard(model_id=model_name, start_layer=0, end_layer=1, n_layers=2) - self.model = build_transformer(self.shard.model_id, self.shard, device=self.device) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + def __init__(self, debug: bool = False): + self.shard = None + self.debug = debug async def infer_prompt( self, request_id: str, - shard: Shard, + shard: Optional[Shard], prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) - attention_mask = torch.ones_like(input_ids) + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) - with torch.no_grad(): - outputs = self.model.generate( - input_ids, - attention_mask=attention_mask, - max_new_tokens=1, - do_sample=True, - temperature=TEMPERATURE, - top_k=TOP_K, - top_p=TOP_P, - pad_token_id=self.tokenizer.eos_token_id - ) - - output_token = outputs[0, -1].item() - output_data = np.array([output_token]) - start_pos += 1 - - is_eos = output_token == self.tokenizer.eos_token_id - - return ( - output_data, - json.dumps({"start_pos": start_pos}), - is_eos - ) + # Continue the sequence if inference state exists + past_key_values = None + if inference_state: + past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) + + output, past_key_values = self.model(input_ids, past_key_values=past_key_values) + + if self.shard.is_last_layer(): + logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) + next_token = torch.argmax(logits[:, -1, :], dim=-1) + output_data = np.array([next_token.item()]) + is_eos = next_token.item() == self.tokenizer.eos_token_id + else: + output_data = output.cpu().numpy() + is_eos = False + + new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) + + if self.debug: + print(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + + return output_data, new_inference_state, is_eos async def infer_tensor( self, request_id: str, - shard: Shard, + shard: Optional[Shard], input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) - with torch.no_grad(): - outputs = self.model.generate( - input_tensor, - max_new_tokens=1, - do_sample=True, - temperature=TEMPERATURE, - top_k=TOP_K, - top_p=TOP_P, - pad_token_id=self.tokenizer.eos_token_id, - start_pos=start_pos - ) - - output_token = outputs[0, -1].item() - output_data = np.array([output_token]) - start_pos += 1 - - is_eos = output_token == self.tokenizer.eos_token_id - - return ( - output_data, - json.dumps({"start_pos": start_pos}), - is_eos - ) + # Continue the sequence if inference state exists + past_key_values = None + if inference_state: + past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) + + output, past_key_values = self.model(input_tensor, past_key_values=past_key_values) + + if self.shard.is_last_layer(): + logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) + next_token = torch.argmax(logits[:, -1, :], dim=-1) + output_data = np.array([next_token.item()]) + is_eos = next_token.item() == self.tokenizer.eos_token_id + else: + output_data = output.cpu().numpy() + is_eos = False + + new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) + + if self.debug: + print(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + + return output_data, new_inference_state, is_eos + + def _apply_generation_settings(self, logits, temperature, top_k): + logits = logits / temperature + if top_k > 0: + top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) + logits = logits.scatter(1, top_k_indices, top_k_values) + return logits + + def _load_kv_cache(self, past_key_values_list): + if past_key_values_list is None: + return None + return [torch.tensor(kv, device=self.model.device) for kv in past_key_values_list] + + def _save_kv_cache(self, past_key_values): + return [kv.cpu().tolist() for kv in past_key_values] + + async def ensure_shard(self, shard: Optional[Shard]): + if self.shard == shard: + return + + model_path = Path(f".cache/{shard.model_id}") + if not model_path.exists(): + os.makedirs(model_path, exist_ok=True) + else: + shutil.rmtree(model_path) + os.makedirs(model_path) + + if shard.model_id.lower().find("llama3-8b-sfr") != -1: + num_files = 4 + urls = [ + f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors" + for i in range(num_files) + ] + urls.extend([ + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json", + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json", + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json", + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json", + "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json" + ]) + + output_paths = [ + model_path / f"model-{(i+1):05d}-of-{num_files:05d}.safetensors" + for i in range(num_files) + ] + output_paths.extend([ + model_path / "config.json", + model_path / "model.safetensors.index.json", + model_path / "special_tokens_map.json", + model_path / "tokenizer.json", + model_path / "tokenizer_config.json" + ]) + + await download_files(urls, output_paths) + else: + raise ValueError(f"Unsupported model: {shard.model_id}") + + # Load model and tokenizer from the downloaded files + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32) + self.model = ShardedLLAMAModel(model, shard) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + + self.shard = shard def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): # This method can be implemented if progress tracking is needed - pass \ No newline at end of file + pass diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py new file mode 100644 index 00000000..ba838be8 --- /dev/null +++ b/exo/inference/pytorch/model/hf.py @@ -0,0 +1,52 @@ +# Work in progress on a generic hugging face model sharder +# right now doesn't work with all models + +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM +from exo.inference.shard import Shard +import logging + +class ShardedHuggingFaceModel(nn.Module): + def __init__(self, model_name: str, shard: Shard): + super(ShardedHuggingFaceModel, self).__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.shard = shard + + # Load the model + self.model = AutoModelForCausalLM.from_pretrained(model_name) + + # Only keep layers corresponding to this shard + self.layers = nn.ModuleList([ + self.model.transformer.h[i] for i in range(shard.start_layer, shard.end_layer + 1) + ]) + + logging.info(f"layers: {self.layers}") + + self.model.transformer.wte.to(self.device) + self.model.transformer.wpe.to(self.device) + + def forward(self, input_ids, past_key_values=None): + hidden_states = self._get_initial_hidden_states(input_ids) + hidden_states, new_past_key_values = self._process_layers(hidden_states, past_key_values) + + if self.shard.is_last_layer(): + hidden_states = self.model.transformer.ln_f(hidden_states.to(self.device)) + logits = self.model.lm_head(hidden_states) + return logits, new_past_key_values + else: + return hidden_states, new_past_key_values + + def _get_initial_hidden_states(self, input_ids): + input_embeds = self.model.transformer.wte(input_ids.to(self.device)) + position_embeds = self.model.transformer.wpe(torch.arange(input_ids.shape[1], device=self.device)) + return input_embeds + position_embeds + + def _process_layers(self, hidden_states, past_key_values): + new_past_key_values = [] + for i, layer in enumerate(self.layers): + layer_past = past_key_values[i] if past_key_values else None + hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past) + new_past_key_values.append(new_layer_past) + return hidden_states, new_past_key_values + diff --git a/exo/inference/pytorch/model/llama.py b/exo/inference/pytorch/model/llama.py new file mode 100644 index 00000000..2871e357 --- /dev/null +++ b/exo/inference/pytorch/model/llama.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from exo.inference.shard import Shard + +class ShardedLLAMAModel(nn.Module): + def __init__(self, model, shard: Shard): + super(ShardedLLAMAModel, self).__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.shard = shard + + # Only keep layers corresponding to this shard + self.layers = nn.ModuleList([model.transformer.h[i] for i in range(shard.start_layer, shard.end_layer + 1)]) + + # Move embeddings to the appropriate device + self.model = model + self.model.transformer.wte.to(self.device) + self.model.transformer.wpe.to(self.device) + + def forward(self, input_ids, past_key_values=None): + hidden_states = self._get_initial_hidden_states(input_ids) + hidden_states, new_past_key_values = self._process_layers(hidden_states, past_key_values) + + if self.shard.is_last_layer(): + hidden_states = self.model.transformer.ln_f(hidden_states.to(self.device)) + logits = self.model.lm_head(hidden_states) + return logits, new_past_key_values + else: + return hidden_states, new_past_key_values + + def _get_initial_hidden_states(self, input_ids): + input_embeds = self.model.transformer.wte(input_ids.to(self.device)) + position_embeds = self.model.transformer.wpe(torch.arange(input_ids.shape[1], device=self.device)) + return input_embeds + position_embeds + + def _process_layers(self, hidden_states, past_key_values): + new_past_key_values = [] + for i, layer in enumerate(self.layers): + layer_past = past_key_values[i] if past_key_values else None + hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past) + new_past_key_values.append(new_layer_past) + return hidden_states, new_past_key_values diff --git a/exo/inference/pytorch/test_engine.py b/exo/inference/pytorch/test_engine.py deleted file mode 100644 index 838958d4..00000000 --- a/exo/inference/pytorch/test_engine.py +++ /dev/null @@ -1,17 +0,0 @@ -import unittest -from .inference import PyTorchDynamicShardInferenceEngine -from exo.inference.shard import Shard -import asyncio - -class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): - def test_one(self): - shard = Shard(model_id="mock_model", start_layer=0, end_layer=1, n_layers=2) - engine = PyTorchDynamicShardInferenceEngine() - prompt_resp = asyncio.run( - engine.infer_prompt( - "", - shard, - "Why is the sky blue?") - ) - - self.assertIsNotNone(prompt_resp) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py new file mode 100644 index 00000000..3be15276 --- /dev/null +++ b/exo/inference/pytorch/test_inference_engine.py @@ -0,0 +1,55 @@ +import unittest +import asyncio +from exo.inference.shard import Shard +from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine + +class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): + + @classmethod + def setUpClass(cls): + + # Create a shard + cls.shard = Shard( + model_id="llama3-8b-sfr", + start_layer=0, + end_layer=0, + n_layers=12 + ) + + # Initialize the inference engine + cls.engine = PyTorchDynamicShardInferenceEngine(debug=True) + + def test_infer_prompt(self): + # Prepare the prompt + prompt = "Why is the sky blue?" + + # Run inference + loop = asyncio.get_event_loop() + output_data, new_inference_state, is_eos = loop.run_until_complete( + self.engine.infer_prompt( + request_id="test_request", shard=self.shard, prompt=prompt + ) + ) + + # Assertions + self.assertIsNotNone(output_data) + self.assertIsNotNone(new_inference_state) + self.assertFalse(is_eos) + + # def test_infer_tensor(self): + # # Prepare the input tensor + # input_ids = self.tokenizer.encode("Hello, world!", return_tensors="pt").numpy() + + # # Run inference + # loop = asyncio.get_event_loop() + # output_data, new_inference_state, is_eos = loop.run_until_complete(self.engine.infer_tensor( + # request_id="test_request", shard=self.shard, input_data=input_ids + # )) + + # # Assertions + # self.assertIsNotNone(output_data) + # self.assertIsNotNone(new_inference_state) + # self.assertFalse(is_eos) + +if __name__ == '__main__': + unittest.main() From 4b11fe71ef25a7d2ffc1a07bedabcd5706357f7d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:32:43 -0800 Subject: [PATCH 012/491] rebuilding based on LlamaForCausalLM and redefining its forward --- exo/inference/pytorch/inference.py | 95 +++++++++++++++++++++++----- exo/inference/pytorch/model/llama.py | 71 ++++++++++++++------- 2 files changed, 128 insertions(+), 38 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 306758e7..0fe932bc 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,7 +1,3 @@ -# experimental, based off of tinygrad/inference.py -# utilizing pytorch FSDP for sharding -# look into shard being optional for the inferece - import os import shutil import json @@ -9,18 +5,27 @@ import numpy as np from pathlib import Path from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, LlamaForCausalLM from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files -from exo.inference.pytorch.model.llama import ShardedLLAMAModel # Default settings TEMPERATURE = 0.7 TOP_K = 50 class PyTorchDynamicShardInferenceEngine(InferenceEngine): + """ + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. + """ + def __init__(self, debug: bool = False): + """ + Initialize the inference engine. + + Args: + debug (bool): If True, enables debug logging. Defaults to False. + """ self.shard = None self.debug = debug @@ -30,6 +35,18 @@ async def infer_prompt( shard: Optional[Shard], prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + """ + Perform inference based on a text prompt. + + Args: + request_id (str): Unique identifier for the request. + shard (Optional[Shard]): Shard information for the model. + prompt (str): The input text prompt for inference. + inference_state (Optional[str]): The previous inference state. + + Returns: + Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. + """ await self.ensure_shard(shard) input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) @@ -63,6 +80,18 @@ async def infer_tensor( shard: Optional[Shard], input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + """ + Perform inference based on an input tensor. + + Args: + request_id (str): Unique identifier for the request. + shard (Optional[Shard]): Shard information for the model. + input_data (np.ndarray): The input tensor for inference. + inference_state (Optional[str]): The previous inference state. + + Returns: + Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. + """ await self.ensure_shard(shard) input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) @@ -91,6 +120,17 @@ async def infer_tensor( return output_data, new_inference_state, is_eos def _apply_generation_settings(self, logits, temperature, top_k): + """ + Apply temperature and top_k settings to logits. + + Args: + logits (torch.Tensor): The logits to be adjusted. + temperature (float): The temperature setting for generation. + top_k (int): The top_k setting for generation. + + Returns: + torch.Tensor: The adjusted logits. + """ logits = logits / temperature if top_k > 0: top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) @@ -98,23 +138,47 @@ def _apply_generation_settings(self, logits, temperature, top_k): return logits def _load_kv_cache(self, past_key_values_list): + """ + Load key-value cache from the inference state. + + Args: + past_key_values_list (list): List of past key-value tensors. + + Returns: + list: List of loaded past key-value tensors. + """ if past_key_values_list is None: return None return [torch.tensor(kv, device=self.model.device) for kv in past_key_values_list] def _save_kv_cache(self, past_key_values): + """ + Save key-value cache to the inference state. + + Args: + past_key_values (list): List of past key-value tensors. + + Returns: + list: List of key-value tensors in a format suitable for saving. + """ return [kv.cpu().tolist() for kv in past_key_values] async def ensure_shard(self, shard: Optional[Shard]): + """ + Ensure the model shard is loaded and ready for inference. + + Args: + shard (Optional[Shard]): Shard information for the model. + """ if self.shard == shard: return - model_path = Path(f".cache/{shard.model_id}") + model_path = Path(self.model_name) + models_dir = Path(__file__).parent / "temp_model_dir" + model_path = models_dir / shard.model_id + if not model_path.exists(): os.makedirs(model_path, exist_ok=True) - else: - shutil.rmtree(model_path) - os.makedirs(model_path) if shard.model_id.lower().find("llama3-8b-sfr") != -1: num_files = 4 @@ -147,12 +211,11 @@ async def ensure_shard(self, shard: Optional[Shard]): raise ValueError(f"Unsupported model: {shard.model_id}") # Load model and tokenizer from the downloaded files - model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32) - self.model = ShardedLLAMAModel(model, shard) + # This is written for llama model but need to add in option for others + self.model = LlamaForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32) + self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.shard = shard - - def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): - # This method can be implemented if progress tracking is needed - pass diff --git a/exo/inference/pytorch/model/llama.py b/exo/inference/pytorch/model/llama.py index 2871e357..01a69445 100644 --- a/exo/inference/pytorch/model/llama.py +++ b/exo/inference/pytorch/model/llama.py @@ -1,41 +1,68 @@ import torch import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaForCausalLM from exo.inference.shard import Shard class ShardedLLAMAModel(nn.Module): - def __init__(self, model, shard: Shard): + """ + Sharded LLAMA Model for performing inference with a subset of model layers. + """ + + def __init__(self, model_path: str, shard: Shard): + """ + Initialize the ShardedLLAMAModel. + + Args: + model_path (str): Path to the pretrained model. + shard (Shard): Shard information indicating which layers to include. + """ super(ShardedLLAMAModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - # Only keep layers corresponding to this shard - self.layers = nn.ModuleList([model.transformer.h[i] for i in range(shard.start_layer, shard.end_layer + 1)]) + # Load the full model and move to device + self.full_model = LlamaForCausalLM.from_pretrained(model_path) + self.full_model.to(self.device) - # Move embeddings to the appropriate device - self.model = model - self.model.transformer.wte.to(self.device) - self.model.transformer.wpe.to(self.device) + # Extract only the layers for this shard + self.layers = nn.ModuleList([ + self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) + ]) + + # Embeddings and final layer norm + self.embed_tokens = self.full_model.model.embed_tokens + self.embed_positions = self.full_model.model.embed_positions + self.norm = self.full_model.model.norm + self.lm_head = self.full_model.lm_head def forward(self, input_ids, past_key_values=None): - hidden_states = self._get_initial_hidden_states(input_ids) - hidden_states, new_past_key_values = self._process_layers(hidden_states, past_key_values) + """ + Perform a forward pass through the model. - if self.shard.is_last_layer(): - hidden_states = self.model.transformer.ln_f(hidden_states.to(self.device)) - logits = self.model.lm_head(hidden_states) - return logits, new_past_key_values - else: - return hidden_states, new_past_key_values + Args: + input_ids (torch.Tensor): Input token IDs. + past_key_values (list, optional): List of past key-value states for attention layers. - def _get_initial_hidden_states(self, input_ids): - input_embeds = self.model.transformer.wte(input_ids.to(self.device)) - position_embeds = self.model.transformer.wpe(torch.arange(input_ids.shape[1], device=self.device)) - return input_embeds + position_embeds + Returns: + tuple: Output logits or hidden states and the new past key-values. + """ + if past_key_values is None: + past_key_values = [None] * len(self.layers) - def _process_layers(self, hidden_states, past_key_values): + # Token and position embeddings + hidden_states = self.embed_tokens(input_ids) + self.embed_positions(input_ids) + + # Apply each layer in this shard new_past_key_values = [] for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] if past_key_values else None + layer_past = past_key_values[i] hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past) new_past_key_values.append(new_layer_past) - return hidden_states, new_past_key_values + + if self.shard.is_last_layer(): + # Apply final layer norm and compute logits + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + return logits, new_past_key_values + else: + return hidden_states, new_past_key_values From a0476568cac167dedc60c3e4d9aea3b146ff178a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:34:11 -0800 Subject: [PATCH 013/491] adding back set_on_download_progress --- exo/inference/pytorch/inference.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0fe932bc..d73e3a75 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -219,3 +219,14 @@ async def ensure_shard(self, shard: Optional[Shard]): self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.shard = shard + + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): + """ + Set a callback function to track download progress. + + Args: + on_download_progress (Callable[[int, int], None]): Callback function to track progress. + """ + # must have this function or inference engine breaks + # This method can be implemented if progress tracking is needed + pass From 36f675e99aa8eda00cf794d9bad54ef45ad4b90c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:35:38 -0800 Subject: [PATCH 014/491] fixing model path --- exo/inference/pytorch/inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d73e3a75..8ca34f48 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -173,10 +173,7 @@ async def ensure_shard(self, shard: Optional[Shard]): if self.shard == shard: return - model_path = Path(self.model_name) - models_dir = Path(__file__).parent / "temp_model_dir" - model_path = models_dir / shard.model_id - + model_path = Path(f".cache/{shard.model_id}") if not model_path.exists(): os.makedirs(model_path, exist_ok=True) From 103a6bc3a934cc85cc4d34e8c9550ec59eee1667 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:41:01 -0800 Subject: [PATCH 015/491] trying to fix output changing from numpy to string at end of layer --- exo/inference/pytorch/inference.py | 5 +++-- exo/inference/pytorch/model/llama.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8ca34f48..38ab500c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -9,6 +9,7 @@ from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files +import logging # Default settings TEMPERATURE = 0.7 @@ -19,7 +20,7 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, debug: bool = False): + def __init__(self, debug: bool = True): """ Initialize the inference engine. @@ -70,7 +71,7 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) if self.debug: - print(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + logging.info(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos diff --git a/exo/inference/pytorch/model/llama.py b/exo/inference/pytorch/model/llama.py index 01a69445..b6490844 100644 --- a/exo/inference/pytorch/model/llama.py +++ b/exo/inference/pytorch/model/llama.py @@ -2,6 +2,7 @@ import torch.nn as nn from transformers.models.llama.modeling_llama import LlamaForCausalLM from exo.inference.shard import Shard +from transformers import Cache class ShardedLLAMAModel(nn.Module): """ @@ -41,22 +42,25 @@ def forward(self, input_ids, past_key_values=None): Args: input_ids (torch.Tensor): Input token IDs. - past_key_values (list, optional): List of past key-value states for attention layers. + past_key_values (Cache, optional): Cache object for past key-value states. Returns: tuple: Output logits or hidden states and the new past key-values. """ if past_key_values is None: - past_key_values = [None] * len(self.layers) + past_key_values = Cache() # Token and position embeddings hidden_states = self.embed_tokens(input_ids) + self.embed_positions(input_ids) # Apply each layer in this shard - new_past_key_values = [] + new_past_key_values = Cache() for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] - hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past) + layer_past = past_key_values[i] if past_key_values else None + hidden_states, new_layer_past = layer( + hidden_states, + past_key_values=layer_past + ) new_past_key_values.append(new_layer_past) if self.shard.is_last_layer(): From b7fd4ed7d9f7148818a24db9bb0cb5aaa88c442b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:45:44 -0800 Subject: [PATCH 016/491] using print for debugging for now, fixed using cache for past kvs --- exo/inference/pytorch/inference.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 38ab500c..0840893a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -5,7 +5,7 @@ import numpy as np from pathlib import Path from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, LlamaForCausalLM +from transformers import AutoTokenizer, LlamaForCausalLM, Cache from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files @@ -71,7 +71,7 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) if self.debug: - logging.info(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + print(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos @@ -146,11 +146,14 @@ def _load_kv_cache(self, past_key_values_list): past_key_values_list (list): List of past key-value tensors. Returns: - list: List of loaded past key-value tensors. + Cache: Loaded past key-value cache. """ if past_key_values_list is None: - return None - return [torch.tensor(kv, device=self.model.device) for kv in past_key_values_list] + return Cache() + cache = Cache() + for kv in past_key_values_list: + cache.append(torch.tensor(kv, device=self.model.device)) + return cache def _save_kv_cache(self, past_key_values): """ From 034bd5af743b315ab246844706e40ad951130d0d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:49:25 -0800 Subject: [PATCH 017/491] trying to get some logging out --- exo/inference/pytorch/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0840893a..675b6563 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -10,6 +10,8 @@ from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files import logging +logging.basicConfig() +logging.getLogger("pytorch.inference").setLevel( logging.INFO ) # Default settings TEMPERATURE = 0.7 @@ -71,7 +73,8 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) if self.debug: - print(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + logging.info( + f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos From f5b4056adc19e8f6e80c9ca2eb497015c6acdbd1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:50:24 -0800 Subject: [PATCH 018/491] lowering log level --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 675b6563..df6585b2 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -11,7 +11,7 @@ from exo.inference.pytorch.helpers import download_files import logging logging.basicConfig() -logging.getLogger("pytorch.inference").setLevel( logging.INFO ) +logging.getLogger("pytorch.inference").setLevel(logging.DEBUG) # Default settings TEMPERATURE = 0.7 From 58443557e0ad747fac5889675cce465704d0e9c7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:55:08 -0800 Subject: [PATCH 019/491] trying to get testing to show output still --- exo/inference/pytorch/test_inference_engine.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 3be15276..d2165baf 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -2,12 +2,14 @@ import asyncio from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +import logging +logging.basicConfig() +logging.getLogger("pytorch.inference.test_engine").setLevel(logging.DEBUG) class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): @classmethod def setUpClass(cls): - # Create a shard cls.shard = Shard( model_id="llama3-8b-sfr", @@ -23,6 +25,8 @@ def test_infer_prompt(self): # Prepare the prompt prompt = "Why is the sky blue?" + logging.info(f"Testing infer_prompt with prompt {prompt}") + # Run inference loop = asyncio.get_event_loop() output_data, new_inference_state, is_eos = loop.run_until_complete( From d015088e0cf5847ab2d25bac2887f5ee368c8824 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 22:57:39 -0800 Subject: [PATCH 020/491] logging --- exo/inference/pytorch/test_inference_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index d2165baf..1748c77f 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -3,8 +3,7 @@ from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine import logging -logging.basicConfig() -logging.getLogger("pytorch.inference.test_engine").setLevel(logging.DEBUG) + class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): @@ -56,4 +55,6 @@ def test_infer_prompt(self): # self.assertFalse(is_eos) if __name__ == '__main__': + logging.basicConfig() + logging.getLogger("pytorch.inference.test_engine").setLevel(logging.DEBUG) unittest.main() From f0e51bc91ea873d9d3aff905601ff402fbbb6c97 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 23:09:07 -0800 Subject: [PATCH 021/491] logging --- exo/inference/pytorch/test_inference_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 1748c77f..aa4aab8e 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -21,10 +21,12 @@ def setUpClass(cls): cls.engine = PyTorchDynamicShardInferenceEngine(debug=True) def test_infer_prompt(self): + log = logging.getLogger("pytorch.inference.test_engine") + # Prepare the prompt prompt = "Why is the sky blue?" - logging.info(f"Testing infer_prompt with prompt {prompt}") + log.info(f"Testing infer_prompt with prompt {prompt}") # Run inference loop = asyncio.get_event_loop() From 4d4362101f98372821eb863ab2b42bffe252bc84 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 23:12:11 -0800 Subject: [PATCH 022/491] logging --- exo/inference/pytorch/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index df6585b2..9601773f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -31,6 +31,7 @@ def __init__(self, debug: bool = True): """ self.shard = None self.debug = debug + self.log = logging.getLogger("pytorch.inference") async def infer_prompt( self, @@ -73,7 +74,7 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) if self.debug: - logging.info( + self.log.info( f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos @@ -119,7 +120,7 @@ async def infer_tensor( new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) if self.debug: - print(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos From 002d02bd5eb657a22c6592ddabab6174744e2130 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 23:31:22 -0800 Subject: [PATCH 023/491] updating to go through each layer instead of whole model --- exo/inference/pytorch/inference.py | 3 +-- exo/inference/pytorch/model/llama.py | 34 ++++++++-------------------- 2 files changed, 10 insertions(+), 27 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9601773f..6328339c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,5 +1,4 @@ import os -import shutil import json import torch import numpy as np @@ -10,6 +9,7 @@ from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files import logging + logging.basicConfig() logging.getLogger("pytorch.inference").setLevel(logging.DEBUG) @@ -232,6 +232,5 @@ def set_on_download_progress(self, on_download_progress: Callable[[int, int], No Args: on_download_progress (Callable[[int, int], None]): Callback function to track progress. """ - # must have this function or inference engine breaks # This method can be implemented if progress tracking is needed pass diff --git a/exo/inference/pytorch/model/llama.py b/exo/inference/pytorch/model/llama.py index b6490844..f6427e02 100644 --- a/exo/inference/pytorch/model/llama.py +++ b/exo/inference/pytorch/model/llama.py @@ -2,26 +2,14 @@ import torch.nn as nn from transformers.models.llama.modeling_llama import LlamaForCausalLM from exo.inference.shard import Shard -from transformers import Cache class ShardedLLAMAModel(nn.Module): - """ - Sharded LLAMA Model for performing inference with a subset of model layers. - """ - def __init__(self, model_path: str, shard: Shard): - """ - Initialize the ShardedLLAMAModel. - - Args: - model_path (str): Path to the pretrained model. - shard (Shard): Shard information indicating which layers to include. - """ super(ShardedLLAMAModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - # Load the full model and move to device + # Load the full model self.full_model = LlamaForCausalLM.from_pretrained(model_path) self.full_model.to(self.device) @@ -36,35 +24,31 @@ def __init__(self, model_path: str, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head - def forward(self, input_ids, past_key_values=None): + def forward_layers(self, input_ids, past_key_values=None): """ - Perform a forward pass through the model. + Forward pass through the specified layers. Args: input_ids (torch.Tensor): Input token IDs. - past_key_values (Cache, optional): Cache object for past key-value states. + past_key_values (list, optional): Past key values for caching. Returns: - tuple: Output logits or hidden states and the new past key-values. + tuple: Hidden states and new past key values. """ if past_key_values is None: - past_key_values = Cache() + past_key_values = [None] * len(self.layers) # Token and position embeddings hidden_states = self.embed_tokens(input_ids) + self.embed_positions(input_ids) # Apply each layer in this shard - new_past_key_values = Cache() + new_past_key_values = [] for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] if past_key_values else None - hidden_states, new_layer_past = layer( - hidden_states, - past_key_values=layer_past - ) + layer_past = past_key_values[i] + hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past, use_cache=True) new_past_key_values.append(new_layer_past) if self.shard.is_last_layer(): - # Apply final layer norm and compute logits hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) return logits, new_past_key_values From 1f5c45ebc525647a77e440b6e8c02fd58bb65495 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 23:41:16 -0800 Subject: [PATCH 024/491] fixing forward pass through specific layers --- exo/inference/pytorch/inference.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6328339c..eeb3cc4f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,13 +1,16 @@ +# experimental, based off of tinygrad/inference.py + import os import json import torch import numpy as np from pathlib import Path from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, LlamaForCausalLM, Cache +from transformers import AutoTokenizer, Cache from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files +from exo.inference.pytorch.model.llama import ShardedLLAMAModel import logging logging.basicConfig() @@ -60,7 +63,7 @@ async def infer_prompt( if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -106,7 +109,7 @@ async def infer_tensor( if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model(input_tensor, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -215,12 +218,10 @@ async def ensure_shard(self, shard: Optional[Shard]): else: raise ValueError(f"Unsupported model: {shard.model_id}") - # Load model and tokenizer from the downloaded files - # This is written for llama model but need to add in option for others - self.model = LlamaForCausalLM.from_pretrained( - model_path, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32) + # Load the sharded model + self.model = ShardedLLAMAModel(model_path, shard) + # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.shard = shard From 378871620888734653bd5f14e3bb99d7b6897ab8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 23:46:24 -0800 Subject: [PATCH 025/491] adding pytorch data parallel for multi gpu support --- exo/inference/pytorch/inference.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index eeb3cc4f..4ab0c360 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -219,7 +219,12 @@ async def ensure_shard(self, shard: Optional[Shard]): raise ValueError(f"Unsupported model: {shard.model_id}") # Load the sharded model - self.model = ShardedLLAMAModel(model_path, shard) + sharded_model = ShardedLLAMAModel(model_path, shard) + + # Use DataParallel for multi-GPU support + self.model = torch.nn.DataParallel(sharded_model) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path) From c2be3640f7dcaf512e8327c4ddde9c5abb472ae5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 6 Aug 2024 23:51:42 -0800 Subject: [PATCH 026/491] add in device ids --- exo/inference/pytorch/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 4ab0c360..f0b565bc 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -222,7 +222,8 @@ async def ensure_shard(self, shard: Optional[Shard]): sharded_model = ShardedLLAMAModel(model_path, shard) # Use DataParallel for multi-GPU support - self.model = torch.nn.DataParallel(sharded_model) + device_ids = [i for i in range(torch.cuda.device_count())] + self.model = torch.nn.DataParallel(sharded_model, device_ids=device_ids) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) From ad265d8e9c0afc97196de2c621d1db3628fd3ac4 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 00:02:14 -0800 Subject: [PATCH 027/491] adding distributed data parallel for pytorch to fix single gpu utilization issue --- exo/inference/pytorch/inference.py | 41 ++++++++--- .../pytorch/test_inference_engine.py | 71 +++++++++++-------- 2 files changed, 72 insertions(+), 40 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f0b565bc..b5228560 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -11,6 +11,9 @@ from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files from exo.inference.pytorch.model.llama import ShardedLLAMAModel +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP import logging logging.basicConfig() @@ -25,16 +28,34 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, debug: bool = True): + def __init__(self, debug: bool = True, rank: int = 0, world_size: int = 1): """ Initialize the inference engine. Args: debug (bool): If True, enables debug logging. Defaults to False. + rank (int): Rank of the current process in distributed training. + world_size (int): Total number of processes in distributed training. """ self.shard = None self.debug = debug + self.rank = rank + self.world_size = world_size + self.device = torch.device(f"cuda:{rank}") self.log = logging.getLogger("pytorch.inference") + self.setup_distributed() + + def setup_distributed(self): + """ + Initialize the process group for distributed training. + """ + dist.init_process_group(backend='nccl', init_method='env://', world_size=self.world_size, rank=self.rank) + + def cleanup_distributed(self): + """ + Clean up the process group for distributed training. + """ + dist.destroy_process_group() async def infer_prompt( self, @@ -56,14 +77,14 @@ async def infer_prompt( """ await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) # Continue the sequence if inference state exists past_key_values = None if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model(input_ids, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -102,14 +123,14 @@ async def infer_tensor( """ await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) # Continue the sequence if inference state exists past_key_values = None if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) + output, past_key_values = self.model(input_tensor, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -159,7 +180,7 @@ def _load_kv_cache(self, past_key_values_list): return Cache() cache = Cache() for kv in past_key_values_list: - cache.append(torch.tensor(kv, device=self.model.device)) + cache.append(torch.tensor(kv, device=self.device)) return cache def _save_kv_cache(self, past_key_values): @@ -221,11 +242,9 @@ async def ensure_shard(self, shard: Optional[Shard]): # Load the sharded model sharded_model = ShardedLLAMAModel(model_path, shard) - # Use DataParallel for multi-GPU support - device_ids = [i for i in range(torch.cuda.device_count())] - self.model = torch.nn.DataParallel(sharded_model, device_ids=device_ids) + # Use DistributedDataParallel for multi-GPU support + self.model = DDP(sharded_model.to(self.device), device_ids=[self.rank], output_device=self.rank) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model.to(self.device) # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -240,4 +259,4 @@ def set_on_download_progress(self, on_download_progress: Callable[[int, int], No on_download_progress (Callable[[int, int], None]): Callback function to track progress. """ # This method can be implemented if progress tracking is needed - pass + pass \ No newline at end of file diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index aa4aab8e..0309c5ff 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -1,14 +1,15 @@ import unittest +import torch import asyncio +import torch.multiprocessing as mp from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine -import logging - class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): - @classmethod def setUpClass(cls): + cls.world_size = torch.cuda.device_count() + # Create a shard cls.shard = Shard( model_id="llama3-8b-sfr", @@ -17,46 +18,58 @@ def setUpClass(cls): n_layers=12 ) - # Initialize the inference engine - cls.engine = PyTorchDynamicShardInferenceEngine(debug=True) + def run_engine(rank, world_size, shard, queue): + """ + Run the inference engine in a distributed setting. + """ + # Initialize the engine + engine = PyTorchDynamicShardInferenceEngine(debug=True, rank=rank, world_size=world_size) - def test_infer_prompt(self): - log = logging.getLogger("pytorch.inference.test_engine") + # Run ensure_shard to set up the model + asyncio.run(engine.ensure_shard(shard)) # Prepare the prompt prompt = "Why is the sky blue?" - log.info(f"Testing infer_prompt with prompt {prompt}") - # Run inference - loop = asyncio.get_event_loop() - output_data, new_inference_state, is_eos = loop.run_until_complete( - self.engine.infer_prompt( - request_id="test_request", shard=self.shard, prompt=prompt + output_data, new_inference_state, is_eos = asyncio.run( + engine.infer_prompt( + request_id="test_request", shard=shard, prompt=prompt ) ) + # Put results in the queue to be checked in the test + queue.put((output_data, new_inference_state, is_eos)) + + def test_infer_prompt(self): + """ + Test the inference on a text prompt in a distributed setting. + """ + mp.set_start_method('spawn') + queue = mp.Queue() + + processes = [] + for rank in range(self.world_size): + p = mp.Process(target=self.run_engine, args=(rank, self.world_size, self.shard, queue)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + output_data, new_inference_state, is_eos = queue.get() + # Assertions self.assertIsNotNone(output_data) self.assertIsNotNone(new_inference_state) self.assertFalse(is_eos) - # def test_infer_tensor(self): - # # Prepare the input tensor - # input_ids = self.tokenizer.encode("Hello, world!", return_tensors="pt").numpy() - - # # Run inference - # loop = asyncio.get_event_loop() - # output_data, new_inference_state, is_eos = loop.run_until_complete(self.engine.infer_tensor( - # request_id="test_request", shard=self.shard, input_data=input_ids - # )) - - # # Assertions - # self.assertIsNotNone(output_data) - # self.assertIsNotNone(new_inference_state) - # self.assertFalse(is_eos) + @classmethod + def tearDownClass(cls): + """ + Clean up after the test. + """ + mp.set_start_method('fork', force=True) # Reset the multiprocessing start method to default if __name__ == '__main__': - logging.basicConfig() - logging.getLogger("pytorch.inference.test_engine").setLevel(logging.DEBUG) unittest.main() From 3e21a5d91b6265779bea86166351a67d88763635 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 00:11:15 -0800 Subject: [PATCH 028/491] ddp fixes --- exo/inference/pytorch/inference.py | 7 +- .../pytorch/test_inference_engine.py | 72 ++++++++++++------- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index b5228560..cc437092 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -49,7 +49,12 @@ def setup_distributed(self): """ Initialize the process group for distributed training. """ - dist.init_process_group(backend='nccl', init_method='env://', world_size=self.world_size, rank=self.rank) + dist.init_process_group( + backend='nccl', + init_method='env://', + world_size=self.world_size, + rank=self.rank + ) def cleanup_distributed(self): """ diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 0309c5ff..39fdd54b 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -1,15 +1,58 @@ import unittest import torch -import asyncio import torch.multiprocessing as mp +import torch.distributed as dist +import os +import asyncio from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +def setup(rank, world_size): + """ + Set up the distributed environment. + """ + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group("nccl", rank=rank, world_size=world_size) + +def cleanup(): + """ + Clean up the distributed environment. + """ + dist.destroy_process_group() + +def run_engine(rank, world_size, shard, queue): + """ + Run the inference engine in a distributed setting. + """ + setup(rank, world_size) + + # Initialize the engine + engine = PyTorchDynamicShardInferenceEngine(debug=True) + + # Run ensure_shard to set up the model + asyncio.run(engine.ensure_shard(shard)) + + # Prepare the prompt + prompt = "Why is the sky blue?" + + # Run inference + output_data, new_inference_state, is_eos = asyncio.run( + engine.infer_prompt( + request_id="test_request", shard=shard, prompt=prompt + ) + ) + + # Put results in the queue to be checked in the test + queue.put((output_data, new_inference_state, is_eos)) + + cleanup() + class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): @classmethod def setUpClass(cls): cls.world_size = torch.cuda.device_count() - + # Create a shard cls.shard = Shard( model_id="llama3-8b-sfr", @@ -18,29 +61,6 @@ def setUpClass(cls): n_layers=12 ) - def run_engine(rank, world_size, shard, queue): - """ - Run the inference engine in a distributed setting. - """ - # Initialize the engine - engine = PyTorchDynamicShardInferenceEngine(debug=True, rank=rank, world_size=world_size) - - # Run ensure_shard to set up the model - asyncio.run(engine.ensure_shard(shard)) - - # Prepare the prompt - prompt = "Why is the sky blue?" - - # Run inference - output_data, new_inference_state, is_eos = asyncio.run( - engine.infer_prompt( - request_id="test_request", shard=shard, prompt=prompt - ) - ) - - # Put results in the queue to be checked in the test - queue.put((output_data, new_inference_state, is_eos)) - def test_infer_prompt(self): """ Test the inference on a text prompt in a distributed setting. @@ -50,7 +70,7 @@ def test_infer_prompt(self): processes = [] for rank in range(self.world_size): - p = mp.Process(target=self.run_engine, args=(rank, self.world_size, self.shard, queue)) + p = mp.Process(target=run_engine, args=(rank, self.world_size, self.shard, queue)) p.start() processes.append(p) From 93d838fa34c329997adbea9f0a77d5db75cb1a91 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 00:14:11 -0800 Subject: [PATCH 029/491] update testing --- exo/inference/pytorch/test_inference_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 39fdd54b..7ac659a9 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -2,8 +2,9 @@ import torch import torch.multiprocessing as mp import torch.distributed as dist -import os import asyncio +import os +from transformers import AutoTokenizer from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine @@ -65,7 +66,7 @@ def test_infer_prompt(self): """ Test the inference on a text prompt in a distributed setting. """ - mp.set_start_method('spawn') + mp.set_start_method('spawn', force=True) queue = mp.Queue() processes = [] From 6c0f6af61720d35f47b32c902c1e5f768a334a10 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 00:16:30 -0800 Subject: [PATCH 030/491] fixing init issue --- exo/inference/pytorch/inference.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cc437092..7c5ff126 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -47,20 +47,21 @@ def __init__(self, debug: bool = True, rank: int = 0, world_size: int = 1): def setup_distributed(self): """ - Initialize the process group for distributed training. + Set up the distributed environment. """ - dist.init_process_group( - backend='nccl', - init_method='env://', - world_size=self.world_size, - rank=self.rank - ) + if not dist.is_initialized(): + dist.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + rank=self.rank, + world_size=self.world_size + ) def cleanup_distributed(self): """ - Clean up the process group for distributed training. + Clean up the distributed environment. """ - dist.destroy_process_group() + if dist.is_initialized(): + dist.destroy_process_group() async def infer_prompt( self, From db1786b3d0b28f9749122064c532d81f2e87f524 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 00:38:40 -0800 Subject: [PATCH 031/491] going back to DataParallel, using device map auto for from_pretrained --- exo/inference/pytorch/inference.py | 45 +++++----- .../pytorch/test_inference_engine.py | 86 ++++--------------- 2 files changed, 39 insertions(+), 92 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 7c5ff126..b7cee23e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,19 +1,18 @@ # experimental, based off of tinygrad/inference.py import os +import shutil import json import torch -import numpy as np +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn from pathlib import Path from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, Cache +from transformers import AutoTokenizer, LlamaForCausalLM, Cache from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.helpers import download_files -from exo.inference.pytorch.model.llama import ShardedLLAMAModel -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP import logging logging.basicConfig() @@ -28,22 +27,20 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, debug: bool = True, rank: int = 0, world_size: int = 1): + def __init__(self, debug: bool = True): """ Initialize the inference engine. Args: debug (bool): If True, enables debug logging. Defaults to False. - rank (int): Rank of the current process in distributed training. - world_size (int): Total number of processes in distributed training. """ self.shard = None self.debug = debug - self.rank = rank - self.world_size = world_size - self.device = torch.device(f"cuda:{rank}") self.log = logging.getLogger("pytorch.inference") - self.setup_distributed() + self.device_ids = list(range(torch.cuda.device_count())) + self.rank = int(os.getenv("RANK", "0")) + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def setup_distributed(self): """ @@ -245,18 +242,23 @@ async def ensure_shard(self, shard: Optional[Shard]): else: raise ValueError(f"Unsupported model: {shard.model_id}") - # Load the sharded model - sharded_model = ShardedLLAMAModel(model_path, shard) - - # Use DistributedDataParallel for multi-GPU support - self.model = DDP(sharded_model.to(self.device), device_ids=[self.rank], output_device=self.rank) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Load model and tokenizer from the downloaded files + # This is written for llama model but need to add in option for others + self.model = LlamaForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + device_map="auto" + ) - # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path) + if torch.cuda.device_count() > 1: + self.model = nn.DataParallel(self.model, device_ids=self.device_ids) + + self.model.to(self.device) self.shard = shard + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): """ Set a callback function to track download progress. @@ -264,5 +266,6 @@ def set_on_download_progress(self, on_download_progress: Callable[[int, int], No Args: on_download_progress (Callable[[int, int], None]): Callback function to track progress. """ + # must have this function or inference engine breaks # This method can be implemented if progress tracking is needed - pass \ No newline at end of file + pass diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 7ac659a9..d50e6d7d 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -1,59 +1,13 @@ import unittest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist import asyncio -import os -from transformers import AutoTokenizer from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine -def setup(rank, world_size): - """ - Set up the distributed environment. - """ - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' - dist.init_process_group("nccl", rank=rank, world_size=world_size) - -def cleanup(): - """ - Clean up the distributed environment. - """ - dist.destroy_process_group() - -def run_engine(rank, world_size, shard, queue): - """ - Run the inference engine in a distributed setting. - """ - setup(rank, world_size) - - # Initialize the engine - engine = PyTorchDynamicShardInferenceEngine(debug=True) - - # Run ensure_shard to set up the model - asyncio.run(engine.ensure_shard(shard)) - - # Prepare the prompt - prompt = "Why is the sky blue?" - - # Run inference - output_data, new_inference_state, is_eos = asyncio.run( - engine.infer_prompt( - request_id="test_request", shard=shard, prompt=prompt - ) - ) - - # Put results in the queue to be checked in the test - queue.put((output_data, new_inference_state, is_eos)) - - cleanup() - class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): + @classmethod def setUpClass(cls): - cls.world_size = torch.cuda.device_count() - + # Create a shard cls.shard = Shard( model_id="llama3-8b-sfr", @@ -62,35 +16,25 @@ def setUpClass(cls): n_layers=12 ) - def test_infer_prompt(self): - """ - Test the inference on a text prompt in a distributed setting. - """ - mp.set_start_method('spawn', force=True) - queue = mp.Queue() - - processes = [] - for rank in range(self.world_size): - p = mp.Process(target=run_engine, args=(rank, self.world_size, self.shard, queue)) - p.start() - processes.append(p) - - for p in processes: - p.join() + # Initialize the inference engine + cls.engine = PyTorchDynamicShardInferenceEngine(debug=True) - output_data, new_inference_state, is_eos = queue.get() + def test_infer_prompt(self): + # Prepare the prompt + prompt = "Why is the sky blue?" + + # Run inference + loop = asyncio.get_event_loop() + output_data, new_inference_state, is_eos = loop.run_until_complete( + self.engine.infer_prompt( + request_id="test_request", shard=self.shard, prompt=prompt + ) + ) # Assertions self.assertIsNotNone(output_data) self.assertIsNotNone(new_inference_state) self.assertFalse(is_eos) - @classmethod - def tearDownClass(cls): - """ - Clean up after the test. - """ - mp.set_start_method('fork', force=True) # Reset the multiprocessing start method to default - if __name__ == '__main__': unittest.main() From 9eacb5a556abd31ddf00677021c1b8b311c4e8d3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 00:42:02 -0800 Subject: [PATCH 032/491] numpy fix --- exo/inference/pytorch/inference.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index b7cee23e..fc672dc0 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,12 +1,10 @@ # experimental, based off of tinygrad/inference.py import os -import shutil import json import torch -import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn +import numpy as np from pathlib import Path from typing import Optional, Callable, Tuple from transformers import AutoTokenizer, LlamaForCausalLM, Cache @@ -42,24 +40,6 @@ def __init__(self, debug: bool = True): self.world_size = int(os.getenv("WORLD_SIZE", "1")) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def setup_distributed(self): - """ - Set up the distributed environment. - """ - if not dist.is_initialized(): - dist.init_process_group( - backend="nccl" if torch.cuda.is_available() else "gloo", - rank=self.rank, - world_size=self.world_size - ) - - def cleanup_distributed(self): - """ - Clean up the distributed environment. - """ - if dist.is_initialized(): - dist.destroy_process_group() - async def infer_prompt( self, request_id: str, From 52eb966883b47459b70509d766d9c823b45eb2a7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:06:28 -0800 Subject: [PATCH 033/491] using llama3 and hf repo with token --- exo/inference/pytorch/inference.py | 60 +++++------------------------ exo/inference/pytorch/model/hf.py | 61 ++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 71 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index fc672dc0..8dadc8df 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -7,10 +7,10 @@ import numpy as np from pathlib import Path from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, LlamaForCausalLM, Cache +from transformers import AutoTokenizer, Cache from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.helpers import download_files +from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel import logging logging.basicConfig() @@ -25,7 +25,7 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, debug: bool = True): + def __init__(self, model_name: str, debug: bool = True): """ Initialize the inference engine. @@ -33,9 +33,10 @@ def __init__(self, debug: bool = True): debug (bool): If True, enables debug logging. Defaults to False. """ self.shard = None + self.model = None + self.model_name = model_name if model_name else "meta-llama/Meta-Llama-3-8B" self.debug = debug self.log = logging.getLogger("pytorch.inference") - self.device_ids = list(range(torch.cuda.device_count())) self.rank = int(os.getenv("RANK", "0")) self.world_size = int(os.getenv("WORLD_SIZE", "1")) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -188,57 +189,14 @@ async def ensure_shard(self, shard: Optional[Shard]): if self.shard == shard: return - model_path = Path(f".cache/{shard.model_id}") - if not model_path.exists(): - os.makedirs(model_path, exist_ok=True) - - if shard.model_id.lower().find("llama3-8b-sfr") != -1: - num_files = 4 - urls = [ - f"https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/model-{(i+1):05d}-of-{num_files:05d}.safetensors" - for i in range(num_files) - ] - urls.extend([ - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/config.json", - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/raw/main/model.safetensors.index.json", - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/special_tokens_map.json", - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer.json", - "https://huggingface.co/mlx-community/Meta-Llama-3-8B-Instruct/resolve/main/tokenizer_config.json" - ]) - - output_paths = [ - model_path / f"model-{(i+1):05d}-of-{num_files:05d}.safetensors" - for i in range(num_files) - ] - output_paths.extend([ - model_path / "config.json", - model_path / "model.safetensors.index.json", - model_path / "special_tokens_map.json", - model_path / "tokenizer.json", - model_path / "tokenizer_config.json" - ]) - - await download_files(urls, output_paths) - else: - raise ValueError(f"Unsupported model: {shard.model_id}") - # Load model and tokenizer from the downloaded files # This is written for llama model but need to add in option for others - self.model = LlamaForCausalLM.from_pretrained( - model_path, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto" - ) - - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - - if torch.cuda.device_count() > 1: - self.model = nn.DataParallel(self.model, device_ids=self.device_ids) - - self.model.to(self.device) + if not self.model: + self.model = ShardedHuggingFaceModel(self.model_name, shard) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.shard = shard - def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): """ Set a callback function to track download progress. diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index ba838be8..861f259f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -12,9 +12,21 @@ def __init__(self, model_name: str, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard + self.device_ids = list(range(torch.cuda.device_count())) # Load the model - self.model = AutoModelForCausalLM.from_pretrained(model_name) + if torch.cuda.device_count() > 1: + self.model = nn.DataParallel(AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + device_map="auto" + )) + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + device_map="auto" + ) # Only keep layers corresponding to this shard self.layers = nn.ModuleList([ @@ -23,30 +35,39 @@ def __init__(self, model_name: str, shard: Shard): logging.info(f"layers: {self.layers}") - self.model.transformer.wte.to(self.device) - self.model.transformer.wpe.to(self.device) + # Embeddings and final layer norm + self.embed_tokens = self.full_model.model.embed_tokens + self.embed_positions = self.full_model.model.embed_positions + self.norm = self.full_model.model.norm + self.lm_head = self.full_model.lm_head - def forward(self, input_ids, past_key_values=None): - hidden_states = self._get_initial_hidden_states(input_ids) - hidden_states, new_past_key_values = self._process_layers(hidden_states, past_key_values) + def forward_layers(self, input_ids, past_key_values=None): + """ + Forward pass through the specified layers. - if self.shard.is_last_layer(): - hidden_states = self.model.transformer.ln_f(hidden_states.to(self.device)) - logits = self.model.lm_head(hidden_states) - return logits, new_past_key_values - else: - return hidden_states, new_past_key_values + Args: + input_ids (torch.Tensor): Input token IDs. + past_key_values (list, optional): Past key values for caching. - def _get_initial_hidden_states(self, input_ids): - input_embeds = self.model.transformer.wte(input_ids.to(self.device)) - position_embeds = self.model.transformer.wpe(torch.arange(input_ids.shape[1], device=self.device)) - return input_embeds + position_embeds + Returns: + tuple: Hidden states and new past key values. + """ + if past_key_values is None: + past_key_values = [None] * len(self.layers) - def _process_layers(self, hidden_states, past_key_values): + # Token and position embeddings + hidden_states = self.embed_tokens(input_ids) + self.embed_positions(input_ids) + + # Apply each layer in this shard new_past_key_values = [] for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] if past_key_values else None - hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past) + layer_past = past_key_values[i] + hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past, use_cache=True) new_past_key_values.append(new_layer_past) - return hidden_states, new_past_key_values + if self.shard.is_last_layer(): + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + return logits, new_past_key_values + else: + return hidden_states, new_past_key_values \ No newline at end of file From e9b931fb0cf214845676fbc0f14044cd11dd0ecb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:08:25 -0800 Subject: [PATCH 034/491] fixing test --- exo/inference/pytorch/test_inference_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index d50e6d7d..96889919 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -10,14 +10,17 @@ def setUpClass(cls): # Create a shard cls.shard = Shard( - model_id="llama3-8b-sfr", + model_id="meta-llama/Meta-Llama-3-8B", start_layer=0, end_layer=0, n_layers=12 ) # Initialize the inference engine - cls.engine = PyTorchDynamicShardInferenceEngine(debug=True) + cls.engine = PyTorchDynamicShardInferenceEngine( + cls.shard.model_id, + debug=True + ) def test_infer_prompt(self): # Prepare the prompt From 0f3787091abc2e8637ba1c97497a23f2a3178951 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:14:34 -0800 Subject: [PATCH 035/491] removing nn distributed for less complexity and just using device_map with accelerate --- exo/inference/pytorch/inference.py | 3 +-- exo/inference/pytorch/model/hf.py | 17 +++++------------ 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8dadc8df..f85daf22 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import numpy as np -from pathlib import Path from typing import Optional, Callable, Tuple from transformers import AutoTokenizer, Cache from exo.inference.shard import Shard @@ -194,7 +193,7 @@ async def ensure_shard(self, shard: Optional[Shard]): if not self.model: self.model = ShardedHuggingFaceModel(self.model_name, shard) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - + self.shard = shard def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 861f259f..57896e33 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -15,18 +15,11 @@ def __init__(self, model_name: str, shard: Shard): self.device_ids = list(range(torch.cuda.device_count())) # Load the model - if torch.cuda.device_count() > 1: - self.model = nn.DataParallel(AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto" - )) - else: - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto" - ) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + device_map="auto" + ) # Only keep layers corresponding to this shard self.layers = nn.ModuleList([ From 11085bebe7a1a8f3ad17f525737a7855e9b1e40d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:16:36 -0800 Subject: [PATCH 036/491] fixing layers call --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 57896e33..f2b9cf58 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,7 +23,7 @@ def __init__(self, model_name: str, shard: Shard): # Only keep layers corresponding to this shard self.layers = nn.ModuleList([ - self.model.transformer.h[i] for i in range(shard.start_layer, shard.end_layer + 1) + self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) logging.info(f"layers: {self.layers}") From fca4cd060ec399c65105cbf69345c6009dec2450 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:17:51 -0800 Subject: [PATCH 037/491] fixing model call --- exo/inference/pytorch/model/hf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index f2b9cf58..4fdc3662 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,16 +23,16 @@ def __init__(self, model_name: str, shard: Shard): # Only keep layers corresponding to this shard self.layers = nn.ModuleList([ - self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) + self.model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) logging.info(f"layers: {self.layers}") # Embeddings and final layer norm - self.embed_tokens = self.full_model.model.embed_tokens - self.embed_positions = self.full_model.model.embed_positions - self.norm = self.full_model.model.norm - self.lm_head = self.full_model.lm_head + self.embed_tokens = self.model.model.embed_tokens + self.embed_positions = self.model.model.embed_positions + self.norm = self.model.model.norm + self.lm_head = self.model.lm_head def forward_layers(self, input_ids, past_key_values=None): """ From 3dbf14754b5626662f2d0804b33ce56313ccbb53 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:24:18 -0800 Subject: [PATCH 038/491] fixing model --- exo/inference/pytorch/model/hf.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 4fdc3662..35885b0d 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -22,17 +22,14 @@ def __init__(self, model_name: str, shard: Shard): ) # Only keep layers corresponding to this shard - self.layers = nn.ModuleList([ - self.model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) - ]) - + self.layers = self.model.layers logging.info(f"layers: {self.layers}") # Embeddings and final layer norm - self.embed_tokens = self.model.model.embed_tokens - self.embed_positions = self.model.model.embed_positions - self.norm = self.model.model.norm - self.lm_head = self.model.lm_head + self.embed_tokens = self.model.embed_tokens + self.embed_positions = self.model.embed_positions + self.norm = self.model.norm + self.lm_head = self.lm_head def forward_layers(self, input_ids, past_key_values=None): """ From 904b3c9c30243819331808852308860f2519dd88 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:27:29 -0800 Subject: [PATCH 039/491] fixing get layers --- exo/inference/pytorch/model/hf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 35885b0d..a0be18b9 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -21,8 +21,10 @@ def __init__(self, model_name: str, shard: Shard): device_map="auto" ) - # Only keep layers corresponding to this shard - self.layers = self.model.layers + # Extract only the layers for this shard + self.layers = nn.ModuleList([ + self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) + ]) logging.info(f"layers: {self.layers}") # Embeddings and final layer norm From 61dcf262283b721e4ee15b8e4186344b0bbdf52f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:28:22 -0800 Subject: [PATCH 040/491] fixing get layers --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index a0be18b9..03192c81 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,7 +23,7 @@ def __init__(self, model_name: str, shard: Shard): # Extract only the layers for this shard self.layers = nn.ModuleList([ - self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) + self.model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) logging.info(f"layers: {self.layers}") From c7ef7d76cb1d518e56d46d1ece271060def3aa1d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:29:32 -0800 Subject: [PATCH 041/491] fixing get layers --- exo/inference/pytorch/model/hf.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 03192c81..448a8f8e 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -15,7 +15,7 @@ def __init__(self, model_name: str, shard: Shard): self.device_ids = list(range(torch.cuda.device_count())) # Load the model - self.model = AutoModelForCausalLM.from_pretrained( + self.full_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" @@ -23,15 +23,14 @@ def __init__(self, model_name: str, shard: Shard): # Extract only the layers for this shard self.layers = nn.ModuleList([ - self.model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) + self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) - logging.info(f"layers: {self.layers}") - + # Embeddings and final layer norm - self.embed_tokens = self.model.embed_tokens - self.embed_positions = self.model.embed_positions - self.norm = self.model.norm - self.lm_head = self.lm_head + self.embed_tokens = self.full_model.model.embed_tokens + self.embed_positions = self.full_model.model.embed_positions + self.norm = self.full_model.model.norm + self.lm_head = self.full_model.lm_head def forward_layers(self, input_ids, past_key_values=None): """ From f760cba5a62d28c7f1e7dff3f92672ee5d66e764 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:45:44 -0800 Subject: [PATCH 042/491] updated sharded hf model --- exo/inference/pytorch/model/hf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 448a8f8e..3eeb053a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -28,7 +28,6 @@ def __init__(self, model_name: str, shard: Shard): # Embeddings and final layer norm self.embed_tokens = self.full_model.model.embed_tokens - self.embed_positions = self.full_model.model.embed_positions self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head @@ -46,8 +45,8 @@ def forward_layers(self, input_ids, past_key_values=None): if past_key_values is None: past_key_values = [None] * len(self.layers) - # Token and position embeddings - hidden_states = self.embed_tokens(input_ids) + self.embed_positions(input_ids) + # Token embeddings + hidden_states = self.embed_tokens(input_ids) # Apply each layer in this shard new_past_key_values = [] From 073f094c7102036e43ede640febf206d6dd44802 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 01:56:11 -0800 Subject: [PATCH 043/491] fix model call --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f85daf22..7b72b8eb 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -67,7 +67,7 @@ async def infer_prompt( if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model.full_model(input_ids, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) From f54f8b8df6414685e70feb1074b9ae48cd93aed7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:00:56 -0800 Subject: [PATCH 044/491] trying tokenizer on cpu --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 7b72b8eb..9da7a4ed 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -60,7 +60,7 @@ async def infer_prompt( """ await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cpu") # Continue the sequence if inference state exists past_key_values = None From c6f0cbbeb14ed10fa9cc50659dc22f7f0bd1ad2a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:03:56 -0800 Subject: [PATCH 045/491] testing other models --- exo/inference/pytorch/test_inference_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 96889919..e2b0eaf6 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -10,7 +10,7 @@ def setUpClass(cls): # Create a shard cls.shard = Shard( - model_id="meta-llama/Meta-Llama-3-8B", + model_id="LLMQ/LLaMA-3-8B-GPTQ-4bit-b128", start_layer=0, end_layer=0, n_layers=12 From 999759b2a315e29f5eb3c869f68388ecd610d9bc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:06:46 -0800 Subject: [PATCH 046/491] testing other models --- exo/inference/pytorch/test_inference_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index e2b0eaf6..2cdf2b15 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -10,7 +10,7 @@ def setUpClass(cls): # Create a shard cls.shard = Shard( - model_id="LLMQ/LLaMA-3-8B-GPTQ-4bit-b128", + model_id="hoang1123/llama3.1-8b-sum-trans-gguf-q4_k_m", start_layer=0, end_layer=0, n_layers=12 From cae9efb9c4f2d556a1d62c2470f095fa04f7ab7b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:15:56 -0800 Subject: [PATCH 047/491] testing other models --- exo/inference/pytorch/test_inference_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 2cdf2b15..7fce78d8 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -10,7 +10,7 @@ def setUpClass(cls): # Create a shard cls.shard = Shard( - model_id="hoang1123/llama3.1-8b-sum-trans-gguf-q4_k_m", + model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=12 From d4387a357ed73408c83a09f61822f703ce495a03 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:40:31 -0800 Subject: [PATCH 048/491] taking out unittest --- .../pytorch/test_inference_engine.py | 61 ++++++++----------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 7fce78d8..1279ee1e 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -1,43 +1,34 @@ -import unittest + import asyncio from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine -class TestPyTorchDynamicShardInferenceEngine(unittest.TestCase): - - @classmethod - def setUpClass(cls): - - # Create a shard - cls.shard = Shard( - model_id="meta-llama/Meta-Llama-3.1-8B", - start_layer=0, - end_layer=0, - n_layers=12 - ) - - # Initialize the inference engine - cls.engine = PyTorchDynamicShardInferenceEngine( - cls.shard.model_id, - debug=True - ) - - def test_infer_prompt(self): - # Prepare the prompt - prompt = "Why is the sky blue?" - - # Run inference - loop = asyncio.get_event_loop() - output_data, new_inference_state, is_eos = loop.run_until_complete( - self.engine.infer_prompt( - request_id="test_request", shard=self.shard, prompt=prompt - ) +def main(): + shard = Shard( + model_id="meta-llama/Meta-Llama-3.1-8B", + start_layer=0, + end_layer=0, + n_layers=12 + ) + + engine = PyTorchDynamicShardInferenceEngine( + shard.model_id, + debug=True + ) + + + # Prepare the prompt + prompt = "Why is the sky blue?" + + # Run inference + loop = asyncio.get_event_loop() + output_data, new_inference_state, is_eos = loop.run_until_complete( + engine.infer_prompt( + request_id="test_request", shard=shard, prompt=prompt ) + ) - # Assertions - self.assertIsNotNone(output_data) - self.assertIsNotNone(new_inference_state) - self.assertFalse(is_eos) + assert output_data is not None if __name__ == '__main__': - unittest.main() + main() From f5f5b6f0429ef7603341573a02c9cb8cfaf8bb96 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:41:52 -0800 Subject: [PATCH 049/491] getting to base model for output --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9da7a4ed..db88f792 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -67,7 +67,7 @@ async def infer_prompt( if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.full_model(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model.full_model.model(input_ids, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) From 2a8fc820d80d2161c8160e450fdef0eee01863a2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:43:49 -0800 Subject: [PATCH 050/491] working on output issue --- exo/inference/pytorch/inference.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index db88f792..5843a785 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -69,6 +69,10 @@ async def infer_prompt( output, past_key_values = self.model.full_model.model(input_ids, past_key_values=past_key_values) + if self.debug: + self.log.info( + f"\nInfer Prompt Debug - Request ID: {request_id}\nOutput: {output_data}\nEOS: {self.shard.is_last_layer()}") + if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) next_token = torch.argmax(logits[:, -1, :], dim=-1) @@ -80,9 +84,7 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) - if self.debug: - self.log.info( - f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + return output_data, new_inference_state, is_eos From 9fd5f90274eeb22a7b10e4a03f85906db1e9bf93 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:45:07 -0800 Subject: [PATCH 051/491] working on output issue --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 5843a785..78ebd5c1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -71,7 +71,7 @@ async def infer_prompt( if self.debug: self.log.info( - f"\nInfer Prompt Debug - Request ID: {request_id}\nOutput: {output_data}\nEOS: {self.shard.is_last_layer()}") + f"\nInfer Prompt Debug - Request ID: {request_id}\nOutput: {output}\nEOS: {self.shard.is_last_layer()}") if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) From de3483db65f45adfbeb93bd0c3bf1806fdfaf0e8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:47:36 -0800 Subject: [PATCH 052/491] working on output issue --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 78ebd5c1..01ab9de1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -67,7 +67,7 @@ async def infer_prompt( if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.full_model.model(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) if self.debug: self.log.info( From 7d1598dc2556898c66db277a6e9f5d6a2a681627 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 02:49:19 -0800 Subject: [PATCH 053/491] working on output issue --- exo/inference/pytorch/inference.py | 5 ++++- exo/inference/pytorch/model/hf.py | 11 ++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 01ab9de1..d4e4a340 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -67,7 +67,10 @@ async def infer_prompt( if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers( + input_ids, + past_key_values=past_key_values + ) if self.debug: self.log.info( diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 3eeb053a..0f9c85fe 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, Cache from exo.inference.shard import Shard import logging @@ -43,7 +43,7 @@ def forward_layers(self, input_ids, past_key_values=None): tuple: Hidden states and new past key values. """ if past_key_values is None: - past_key_values = [None] * len(self.layers) + past_key_values = Cache() # Token embeddings hidden_states = self.embed_tokens(input_ids) @@ -52,7 +52,12 @@ def forward_layers(self, input_ids, past_key_values=None): new_past_key_values = [] for i, layer in enumerate(self.layers): layer_past = past_key_values[i] - hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past, use_cache=True) + hidden_states, new_layer_past = layer( + hidden_states, + past_key_values=layer_past, + use_cache=True + ) + new_past_key_values.append(new_layer_past) if self.shard.is_last_layer(): From 6068949f289070f6a6234f14003a81405a219d22 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:15:01 -0800 Subject: [PATCH 054/491] updated kv caching to dynamic --- exo/inference/pytorch/inference.py | 48 ++++++++++++++++-------------- exo/inference/pytorch/model/hf.py | 18 ++++------- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d4e4a340..3cff7738 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -6,7 +6,7 @@ import torch.nn as nn import numpy as np from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, Cache +from transformers import AutoTokenizer, DynamicCache from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel @@ -60,21 +60,14 @@ async def infer_prompt( """ await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cpu") + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) # Continue the sequence if inference state exists past_key_values = None if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.forward_layers( - input_ids, - past_key_values=past_key_values - ) - - if self.debug: - self.log.info( - f"\nInfer Prompt Debug - Request ID: {request_id}\nOutput: {output}\nEOS: {self.shard.is_last_layer()}") + output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -85,9 +78,10 @@ async def infer_prompt( output_data = output.cpu().numpy() is_eos = False - new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) + new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) - + if self.debug: + self.log.info(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos @@ -111,14 +105,14 @@ async def infer_tensor( """ await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) # Continue the sequence if inference state exists past_key_values = None if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model(input_tensor, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -129,7 +123,7 @@ async def infer_tensor( output_data = output.cpu().numpy() is_eos = False - new_inference_state = json.dumps({"past_key_values": self._save_kv_cache(past_key_values)}) + new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) if self.debug: self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") @@ -162,13 +156,17 @@ def _load_kv_cache(self, past_key_values_list): past_key_values_list (list): List of past key-value tensors. Returns: - Cache: Loaded past key-value cache. + DynamicCache: Loaded past key-value cache. """ if past_key_values_list is None: - return Cache() - cache = Cache() - for kv in past_key_values_list: - cache.append(torch.tensor(kv, device=self.device)) + return DynamicCache() + + cache = DynamicCache() + for layer_idx, (key_states, value_states) in enumerate(past_key_values_list): + key_states_tensor = torch.tensor(key_states, device=self.device) + value_states_tensor = torch.tensor(value_states, device=self.device) + cache.update(key_states_tensor, value_states_tensor, layer_idx) + return cache def _save_kv_cache(self, past_key_values): @@ -176,12 +174,18 @@ def _save_kv_cache(self, past_key_values): Save key-value cache to the inference state. Args: - past_key_values (list): List of past key-value tensors. + past_key_values (DynamicCache): Cache object containing past key-value tensors. Returns: list: List of key-value tensors in a format suitable for saving. """ - return [kv.cpu().tolist() for kv in past_key_values] + past_key_values_list = [] + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + past_key_values_list.append((key_states.cpu().tolist(), value_states.cpu().tolist())) + + return past_key_values_list + async def ensure_shard(self, shard: Optional[Shard]): """ diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0f9c85fe..9e454360 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,18 +1,13 @@ -# Work in progress on a generic hugging face model sharder -# right now doesn't work with all models - import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, Cache +from transformers import AutoModelForCausalLM, DynamicCache from exo.inference.shard import Shard -import logging class ShardedHuggingFaceModel(nn.Module): def __init__(self, model_name: str, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - self.device_ids = list(range(torch.cuda.device_count())) # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( @@ -37,19 +32,19 @@ def forward_layers(self, input_ids, past_key_values=None): Args: input_ids (torch.Tensor): Input token IDs. - past_key_values (list, optional): Past key values for caching. + past_key_values (DynamicCache, optional): Past key values for caching. Returns: tuple: Hidden states and new past key values. """ if past_key_values is None: - past_key_values = Cache() + past_key_values = DynamicCache() # Token embeddings hidden_states = self.embed_tokens(input_ids) # Apply each layer in this shard - new_past_key_values = [] + new_past_key_values = DynamicCache() for i, layer in enumerate(self.layers): layer_past = past_key_values[i] hidden_states, new_layer_past = layer( @@ -57,12 +52,11 @@ def forward_layers(self, input_ids, past_key_values=None): past_key_values=layer_past, use_cache=True ) - - new_past_key_values.append(new_layer_past) + new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) return logits, new_past_key_values else: - return hidden_states, new_past_key_values \ No newline at end of file + return hidden_states, new_past_key_values From a9d8d646896b9e079f1a402b032e0da4ea7278c4 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:19:11 -0800 Subject: [PATCH 055/491] updated kv caching to dynamic --- exo/inference/pytorch/inference.py | 18 +++++++++--------- exo/inference/pytorch/model/hf.py | 7 ++++++- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 3cff7738..a665a0c6 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -158,14 +158,12 @@ def _load_kv_cache(self, past_key_values_list): Returns: DynamicCache: Loaded past key-value cache. """ - if past_key_values_list is None: - return DynamicCache() - cache = DynamicCache() - for layer_idx, (key_states, value_states) in enumerate(past_key_values_list): - key_states_tensor = torch.tensor(key_states, device=self.device) - value_states_tensor = torch.tensor(value_states, device=self.device) - cache.update(key_states_tensor, value_states_tensor, layer_idx) + if past_key_values_list is not None: + for layer_idx, (key_states, value_states) in enumerate(past_key_values_list): + key_states_tensor = torch.tensor(key_states, device=self.device) + value_states_tensor = torch.tensor(value_states, device=self.device) + cache.update(key_states_tensor, value_states_tensor, layer_idx) return cache @@ -182,11 +180,13 @@ def _save_kv_cache(self, past_key_values): past_key_values_list = [] for layer_idx in range(len(past_key_values)): key_states, value_states = past_key_values[layer_idx] - past_key_values_list.append((key_states.cpu().tolist(), value_states.cpu().tolist())) + past_key_values_list.append(( + key_states.cpu().tolist(), + value_states.cpu().tolist() + )) return past_key_values_list - async def ensure_shard(self, shard: Optional[Shard]): """ Ensure the model shard is loaded and ready for inference. diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 9e454360..87824aac 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -46,12 +46,17 @@ def forward_layers(self, input_ids, past_key_values=None): # Apply each layer in this shard new_past_key_values = DynamicCache() for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] + if i < len(past_key_values): + layer_past = past_key_values[i] + else: + layer_past = None + hidden_states, new_layer_past = layer( hidden_states, past_key_values=layer_past, use_cache=True ) + new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) if self.shard.is_last_layer(): From 6dca880b9166de3b4971dd9cc0549580ea2b61bb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:27:38 -0800 Subject: [PATCH 056/491] fixing position id error --- exo/inference/pytorch/model/hf.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 87824aac..63df3cdb 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -40,8 +40,11 @@ def forward_layers(self, input_ids, past_key_values=None): if past_key_values is None: past_key_values = DynamicCache() - # Token embeddings - hidden_states = self.embed_tokens(input_ids) + # Token and position embeddings + position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + self.full_model.model.embed_positions(position_ids) # Apply each layer in this shard new_past_key_values = DynamicCache() @@ -50,13 +53,14 @@ def forward_layers(self, input_ids, past_key_values=None): layer_past = past_key_values[i] else: layer_past = None - + hidden_states, new_layer_past = layer( - hidden_states, - past_key_values=layer_past, - use_cache=True + hidden_states, + past_key_values=layer_past, + use_cache=True, + position_ids=position_ids ) - + new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) if self.shard.is_last_layer(): @@ -64,4 +68,4 @@ def forward_layers(self, input_ids, past_key_values=None): logits = self.lm_head(hidden_states) return logits, new_past_key_values else: - return hidden_states, new_past_key_values + return hidden_states, new_past_key_values \ No newline at end of file From c7df760206dd051093b1e8738a3bcfc9acbde193 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:30:06 -0800 Subject: [PATCH 057/491] fixing position id error --- exo/inference/pytorch/model/hf.py | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 63df3cdb..609f963d 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -2,12 +2,14 @@ import torch.nn as nn from transformers import AutoModelForCausalLM, DynamicCache from exo.inference.shard import Shard +import logging class ShardedHuggingFaceModel(nn.Module): def __init__(self, model_name: str, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard + self.device_ids = list(range(torch.cuda.device_count())) # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( @@ -40,32 +42,30 @@ def forward_layers(self, input_ids, past_key_values=None): if past_key_values is None: past_key_values = DynamicCache() - # Token and position embeddings - position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # Token embeddings inputs_embeds = self.embed_tokens(input_ids) + + # Generate position ids if not given + position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + # Apply positional embeddings hidden_states = inputs_embeds + self.full_model.model.embed_positions(position_ids) # Apply each layer in this shard - new_past_key_values = DynamicCache() + new_past_key_values = [] for i, layer in enumerate(self.layers): - if i < len(past_key_values): - layer_past = past_key_values[i] - else: - layer_past = None - + layer_past = past_key_values[i] if i < len(past_key_values) else None hidden_states, new_layer_past = layer( - hidden_states, - past_key_values=layer_past, - use_cache=True, - position_ids=position_ids + hidden_states, + past_key_values=layer_past, + use_cache=True ) - - new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) + new_past_key_values.append(new_layer_past) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) return logits, new_past_key_values else: - return hidden_states, new_past_key_values \ No newline at end of file + return hidden_states, new_past_key_values From 29470a1af28aea2b082eb88881b0d402a6807372 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:37:02 -0800 Subject: [PATCH 058/491] fixing position id error --- exo/inference/pytorch/model/hf.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 609f963d..fd21f028 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -2,7 +2,6 @@ import torch.nn as nn from transformers import AutoModelForCausalLM, DynamicCache from exo.inference.shard import Shard -import logging class ShardedHuggingFaceModel(nn.Module): def __init__(self, model_name: str, shard: Shard): @@ -45,21 +44,20 @@ def forward_layers(self, input_ids, past_key_values=None): # Token embeddings inputs_embeds = self.embed_tokens(input_ids) - # Generate position ids if not given + # Generate position ids position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - # Apply positional embeddings - hidden_states = inputs_embeds + self.full_model.model.embed_positions(position_ids) - # Apply each layer in this shard + hidden_states = inputs_embeds new_past_key_values = [] for i, layer in enumerate(self.layers): layer_past = past_key_values[i] if i < len(past_key_values) else None hidden_states, new_layer_past = layer( - hidden_states, - past_key_values=layer_past, - use_cache=True + hidden_states, + past_key_values=layer_past, + use_cache=True, + position_ids=position_ids ) new_past_key_values.append(new_layer_past) From 2b9330ed3c4293fa04d449d2c9ba931babc1a228 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:38:08 -0800 Subject: [PATCH 059/491] fixing tensor call --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a665a0c6..36f1ab4c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -120,7 +120,7 @@ async def infer_tensor( output_data = np.array([next_token.item()]) is_eos = next_token.item() == self.tokenizer.eos_token_id else: - output_data = output.cpu().numpy() + output_data = output.detach().numpy() is_eos = False new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) From a61be5c23a84ad64299414d79db05ff0588c9116 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:39:06 -0800 Subject: [PATCH 060/491] fixing tensor call --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 36f1ab4c..3644113a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -75,7 +75,7 @@ async def infer_prompt( output_data = np.array([next_token.item()]) is_eos = next_token.item() == self.tokenizer.eos_token_id else: - output_data = output.cpu().numpy() + output_data = output.detach().numpy() is_eos = False new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) From 8f93296b4b22d3accb816ae197a5b8a2482766c7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 03:41:19 -0800 Subject: [PATCH 061/491] fixing tensor call --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 3644113a..c2b10e21 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -75,7 +75,7 @@ async def infer_prompt( output_data = np.array([next_token.item()]) is_eos = next_token.item() == self.tokenizer.eos_token_id else: - output_data = output.detach().numpy() + output_data = output.cpu().detach().numpy() is_eos = False new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) @@ -120,7 +120,7 @@ async def infer_tensor( output_data = np.array([next_token.item()]) is_eos = next_token.item() == self.tokenizer.eos_token_id else: - output_data = output.detach().numpy() + output_data = output.cpu().detach().numpy() is_eos = False new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) From b0287b6673ddbe3d57e633ae5b680086b8abc233 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 04:06:38 -0800 Subject: [PATCH 062/491] fixing cache issue --- exo/inference/pytorch/inference.py | 12 +++++------- exo/inference/pytorch/model/hf.py | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index c2b10e21..6f1ec9a7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -148,7 +148,7 @@ def _apply_generation_settings(self, logits, temperature, top_k): logits = logits.scatter(1, top_k_indices, top_k_values) return logits - def _load_kv_cache(self, past_key_values_list): + def _load_kv_cache(self, past_key_values_list) -> DynamicCache: """ Load key-value cache from the inference state. @@ -158,13 +158,11 @@ def _load_kv_cache(self, past_key_values_list): Returns: DynamicCache: Loaded past key-value cache. """ + if past_key_values_list is None: + return DynamicCache() cache = DynamicCache() - if past_key_values_list is not None: - for layer_idx, (key_states, value_states) in enumerate(past_key_values_list): - key_states_tensor = torch.tensor(key_states, device=self.device) - value_states_tensor = torch.tensor(value_states, device=self.device) - cache.update(key_states_tensor, value_states_tensor, layer_idx) - + for layer_idx, (key, value) in enumerate(past_key_values_list): + cache.update(torch.tensor(key, device=self.device), torch.tensor(value, device=self.device), layer_idx) return cache def _save_kv_cache(self, past_key_values): diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index fd21f028..8f1d79f4 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -50,7 +50,7 @@ def forward_layers(self, input_ids, past_key_values=None): # Apply each layer in this shard hidden_states = inputs_embeds - new_past_key_values = [] + new_past_key_values = DynamicCache() for i, layer in enumerate(self.layers): layer_past = past_key_values[i] if i < len(past_key_values) else None hidden_states, new_layer_past = layer( @@ -59,7 +59,7 @@ def forward_layers(self, input_ids, past_key_values=None): use_cache=True, position_ids=position_ids ) - new_past_key_values.append(new_layer_past) + new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) From d19ac65d99256541db5031ecc0d13e744516279c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 04:09:43 -0800 Subject: [PATCH 063/491] fixing cache issue --- exo/inference/pytorch/inference.py | 24 +++++++++++++----------- exo/inference/pytorch/model/hf.py | 3 ++- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6f1ec9a7..b3a12e3e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -85,19 +85,19 @@ async def infer_prompt( return output_data, new_inference_state, is_eos - async def infer_tensor( - self, - request_id: str, - shard: Optional[Shard], - input_data: np.ndarray, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + async def infer_prompt( + self, + request_id: str, + shard: Optional[Shard], + prompt: str, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: """ - Perform inference based on an input tensor. + Perform inference based on a text prompt. Args: request_id (str): Unique identifier for the request. shard (Optional[Shard]): Shard information for the model. - input_data (np.ndarray): The input tensor for inference. + prompt (str): The input text prompt for inference. inference_state (Optional[str]): The previous inference state. Returns: @@ -105,14 +105,14 @@ async def infer_tensor( """ await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) # Continue the sequence if inference state exists past_key_values = None if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -126,7 +126,8 @@ async def infer_tensor( new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) if self.debug: - self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + self.log.info( + f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos @@ -160,6 +161,7 @@ def _load_kv_cache(self, past_key_values_list) -> DynamicCache: """ if past_key_values_list is None: return DynamicCache() + cache = DynamicCache() for layer_idx, (key, value) in enumerate(past_key_values_list): cache.update(torch.tensor(key, device=self.device), torch.tensor(value, device=self.device), layer_idx) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8f1d79f4..efb98309 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -59,7 +59,8 @@ def forward_layers(self, input_ids, past_key_values=None): use_cache=True, position_ids=position_ids ) - new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) + if new_layer_past is not None: + new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) From f19498efb69c141214d6b4bb7a3993b805213227 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 04:11:28 -0800 Subject: [PATCH 064/491] fixing cache issue --- exo/inference/pytorch/inference.py | 36 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index b3a12e3e..cb359f6a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -41,11 +41,11 @@ def __init__(self, model_name: str, debug: bool = True): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard], - prompt: str, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + self, + request_id: str, + shard: Optional[Shard], + prompt: str, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: """ Perform inference based on a text prompt. @@ -81,23 +81,24 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) if self.debug: - self.log.info(f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + self.log.info( + f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos - async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard], - prompt: str, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + async def infer_tensor( + self, + request_id: str, + shard: Optional[Shard], + input_data: np.ndarray, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: """ - Perform inference based on a text prompt. + Perform inference based on an input tensor. Args: request_id (str): Unique identifier for the request. shard (Optional[Shard]): Shard information for the model. - prompt (str): The input text prompt for inference. + input_data (np.ndarray): The input tensor for inference. inference_state (Optional[str]): The previous inference state. Returns: @@ -105,14 +106,14 @@ async def infer_prompt( """ await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) # Continue the sequence if inference state exists past_key_values = None if inference_state: past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) + output, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) if self.shard.is_last_layer(): logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) @@ -126,8 +127,7 @@ async def infer_prompt( new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) if self.debug: - self.log.info( - f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") return output_data, new_inference_state, is_eos From 142649cb519a3bbd250142a187e9e760b4bd6547 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 04:12:58 -0800 Subject: [PATCH 065/491] cleaning logging --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cb359f6a..62fdc377 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -82,7 +82,7 @@ async def infer_prompt( if self.debug: self.log.info( - f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + f"\nInfer Prompt Debug - Request ID: {request_id}\nOutput: {output_data}\nEOS: {is_eos}") return output_data, new_inference_state, is_eos From 752ebb4bf06157446e46d00dbdd2eec24d090a57 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 13:57:39 -0800 Subject: [PATCH 066/491] updating device capabilities for other NVIDIA cards in current env --- exo/topology/device_capabilities.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index 02e26d72..f64066ef 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -75,6 +75,7 @@ def to_dict(self): # RTX 30 series "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS), "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS), + "NVIDIA GEFORCE RTX 3060 LAPTOP GPU": DeviceFlops(fp32=12.7 * TFLOPS, fp16=25.4 * TFLOPS, int8=50.8 * TFLOPS), "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS), "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS), "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS), @@ -91,6 +92,7 @@ def to_dict(self): "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), + "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), # ... add more devices if needed ... ### AMD GPUs # RX 6000 series From 3916feceb91ad333946d03fd7c8489515dfd08fc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 14:00:59 -0800 Subject: [PATCH 067/491] adding more low level devices --- exo/topology/device_capabilities.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index f64066ef..519a6f16 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -93,6 +93,8 @@ def to_dict(self): "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS), "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), + "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS), + "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS), # ... add more devices if needed ... ### AMD GPUs # RX 6000 series From cd9515f3bdae971439739e1e620b171bec59a509 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 14:16:41 -0800 Subject: [PATCH 068/491] updating pytorch inference, adding pytorch to inference selection --- exo/inference/pytorch/inference.py | 9 +++------ exo/inference/pytorch/model/hf.py | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 62fdc377..97ea56b3 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -24,7 +24,7 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, model_name: str, debug: bool = True): + def __init__(self, debug: bool = True): """ Initialize the inference engine. @@ -33,11 +33,8 @@ def __init__(self, model_name: str, debug: bool = True): """ self.shard = None self.model = None - self.model_name = model_name if model_name else "meta-llama/Meta-Llama-3-8B" self.debug = debug self.log = logging.getLogger("pytorch.inference") - self.rank = int(os.getenv("RANK", "0")) - self.world_size = int(os.getenv("WORLD_SIZE", "1")) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") async def infer_prompt( @@ -200,8 +197,8 @@ async def ensure_shard(self, shard: Optional[Shard]): # Load model and tokenizer from the downloaded files # This is written for llama model but need to add in option for others if not self.model: - self.model = ShardedHuggingFaceModel(self.model_name, shard) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model = ShardedHuggingFaceModel(shard) + self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) self.shard = shard diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index efb98309..b2bc1b78 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -4,7 +4,7 @@ from exo.inference.shard import Shard class ShardedHuggingFaceModel(nn.Module): - def __init__(self, model_name: str, shard: Shard): + def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard @@ -12,7 +12,7 @@ def __init__(self, model_name: str, shard: Shard): # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( - model_name, + shard.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" ) From 3534cbcc2d42f9a98bdfe5d0521b2126272f0c20 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 14:17:08 -0800 Subject: [PATCH 069/491] adding pytorch engine to helpers.py --- exo/helpers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exo/helpers.py b/exo/helpers.py index 2b4027a4..47b4dc95 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -42,6 +42,9 @@ def get_inference_engine(inference_engine_name): tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) return TinygradDynamicShardInferenceEngine() + elif inference_engine_name == "pytorch": + from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine + return PyTorchDynamicShardInferenceEngine() else: raise ValueError(f"Inference engine {inference_engine_name} not supported") From d0b7e99f2bd273f762b3eb58900cebe8e57af647 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 15:03:04 -0800 Subject: [PATCH 070/491] updating inference_state bug --- exo/api/chatgpt_api.py | 1 + exo/helpers.py | 3 +- exo/inference/pytorch/inference.py | 105 ++++++++++------------------- 3 files changed, 37 insertions(+), 72 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 87390b7f..761aabc9 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -16,6 +16,7 @@ ### llama "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), diff --git a/exo/helpers.py b/exo/helpers.py index 47b4dc95..b811a0f9 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -43,8 +43,9 @@ def get_inference_engine(inference_engine_name): return TinygradDynamicShardInferenceEngine() elif inference_engine_name == "pytorch": + # will change from debug being true after testing from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine - return PyTorchDynamicShardInferenceEngine() + return PyTorchDynamicShardInferenceEngine(debug=os.getenv("PYTORCH_DEBUG", default=True)) else: raise ValueError(f"Inference engine {inference_engine_name} not supported") diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 97ea56b3..ed9af7a7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -10,10 +10,6 @@ from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel -import logging - -logging.basicConfig() -logging.getLogger("pytorch.inference").setLevel(logging.DEBUG) # Default settings TEMPERATURE = 0.7 @@ -34,15 +30,15 @@ def __init__(self, debug: bool = True): self.shard = None self.model = None self.debug = debug - self.log = logging.getLogger("pytorch.inference") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard], - prompt: str, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + self, + request_id: str, + shard: Optional[Shard] = None, + prompt: str = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: """ Perform inference based on a text prompt. @@ -50,6 +46,7 @@ async def infer_prompt( request_id (str): Unique identifier for the request. shard (Optional[Shard]): Shard information for the model. prompt (str): The input text prompt for inference. + image_str (Optional[str]): Optional image string for multi-modal models. inference_state (Optional[str]): The previous inference state. Returns: @@ -57,37 +54,27 @@ async def infer_prompt( """ await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) - - # Continue the sequence if inference state exists - past_key_values = None - if inference_state: - past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - - output, past_key_values = self.model.forward_layers(input_ids, past_key_values=past_key_values) + toks = self.tokenizer.encode(prompt) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - if self.shard.is_last_layer(): - logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) - next_token = torch.argmax(logits[:, -1, :], dim=-1) - output_data = np.array([next_token.item()]) - is_eos = next_token.item() == self.tokenizer.eos_token_id - else: - output_data = output.cpu().detach().numpy() - is_eos = False + start_pos = self.model.prefill(self.model, toks[:-1], start_pos=start_pos) + last_tok = toks[-1] - new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) + output_data = np.array([self.model.forward_layers(torch.tensor([[last_tok]], device=self.model.device), start_pos=start_pos, temperature=TEMPERATURE, top_k=TOP_K).tolist()]) + if output_data.size == 1: + start_pos += 1 - if self.debug: - self.log.info( - f"\nInfer Prompt Debug - Request ID: {request_id}\nOutput: {output_data}\nEOS: {is_eos}") - - return output_data, new_inference_state, is_eos + return ( + output_data, + json.dumps({"start_pos": start_pos}), + output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + ) async def infer_tensor( self, request_id: str, - shard: Optional[Shard], - input_data: np.ndarray, + shard: Optional[Shard] = None, + input_data: np.ndarray = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: """ Perform inference based on an input tensor. @@ -103,30 +90,17 @@ async def infer_tensor( """ await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) - - # Continue the sequence if inference state exists - past_key_values = None - if inference_state: - past_key_values = self._load_kv_cache(json.loads(inference_state).get("past_key_values")) - - output, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) - - if self.shard.is_last_layer(): - logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) - next_token = torch.argmax(logits[:, -1, :], dim=-1) - output_data = np.array([next_token.item()]) - is_eos = next_token.item() == self.tokenizer.eos_token_id - else: - output_data = output.cpu().detach().numpy() - is_eos = False + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - new_inference_state = json.dumps({"past_key_values": past_key_values.to_legacy_cache()}) + output_data = np.array([self.model.forward_layers(torch.tensor([input_data], device=self.model.device), start_pos=start_pos, temperature=TEMPERATURE, top_k=TOP_K).tolist()]) + if output_data.size == 1: + start_pos += 1 - if self.debug: - self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") - - return output_data, new_inference_state, is_eos + return ( + output_data, + json.dumps({"start_pos": start_pos}), + output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + ) def _apply_generation_settings(self, logits, temperature, top_k): """ @@ -146,7 +120,7 @@ def _apply_generation_settings(self, logits, temperature, top_k): logits = logits.scatter(1, top_k_indices, top_k_values) return logits - def _load_kv_cache(self, past_key_values_list) -> DynamicCache: + def _load_kv_cache(self, past_key_values_list): """ Load key-value cache from the inference state. @@ -158,10 +132,7 @@ def _load_kv_cache(self, past_key_values_list) -> DynamicCache: """ if past_key_values_list is None: return DynamicCache() - - cache = DynamicCache() - for layer_idx, (key, value) in enumerate(past_key_values_list): - cache.update(torch.tensor(key, device=self.device), torch.tensor(value, device=self.device), layer_idx) + cache = DynamicCache.from_legacy_cache(past_key_values_list) return cache def _save_kv_cache(self, past_key_values): @@ -169,20 +140,12 @@ def _save_kv_cache(self, past_key_values): Save key-value cache to the inference state. Args: - past_key_values (DynamicCache): Cache object containing past key-value tensors. + past_key_values (DynamicCache): Past key-value cache. Returns: list: List of key-value tensors in a format suitable for saving. """ - past_key_values_list = [] - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - past_key_values_list.append(( - key_states.cpu().tolist(), - value_states.cpu().tolist() - )) - - return past_key_values_list + return past_key_values.to_legacy_cache() async def ensure_shard(self, shard: Optional[Shard]): """ From 80ad4d70e6977d2c369b7a8bf70abc185984f50f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 15:07:18 -0800 Subject: [PATCH 071/491] adding prefill --- exo/inference/pytorch/inference.py | 16 ++++++++------ exo/inference/pytorch/model/hf.py | 34 +++++++++++++++++++++++++----- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ed9af7a7..5d32b8ff 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -57,16 +57,18 @@ async def infer_prompt( toks = self.tokenizer.encode(prompt) start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - start_pos = self.model.prefill(self.model, toks[:-1], start_pos=start_pos) - last_tok = toks[-1] + hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) + last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) + + output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) + output_data = output_data.detach().cpu().numpy() - output_data = np.array([self.model.forward_layers(torch.tensor([[last_tok]], device=self.model.device), start_pos=start_pos, temperature=TEMPERATURE, top_k=TOP_K).tolist()]) if output_data.size == 1: start_pos += 1 return ( output_data, - json.dumps({"start_pos": start_pos}), + json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], ) @@ -92,13 +94,15 @@ async def infer_tensor( start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - output_data = np.array([self.model.forward_layers(torch.tensor([input_data], device=self.model.device), start_pos=start_pos, temperature=TEMPERATURE, top_k=TOP_K).tolist()]) + output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), start_pos=start_pos) + output_data = output_data.detach().cpu().numpy() + if output_data.size == 1: start_pos += 1 return ( output_data, - json.dumps({"start_pos": start_pos}), + json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b2bc1b78..edae36f7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -4,15 +4,14 @@ from exo.inference.shard import Shard class ShardedHuggingFaceModel(nn.Module): - def __init__(self, shard: Shard): + def __init__(self, model_name: str, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - self.device_ids = list(range(torch.cuda.device_count())) # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( - shard.model_id, + model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" ) @@ -27,6 +26,32 @@ def __init__(self, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head + def prefill(self, model, tokens, start_pos=0): + """ + Process the initial input tokens and set up the initial hidden states and key-value caches. + """ + # Token embeddings + inputs_embeds = self.embed_tokens(tokens) + + # Generate position ids + position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) + position_ids = position_ids.unsqueeze(0).expand_as(tokens) + + # Apply each layer in this shard + hidden_states = inputs_embeds + past_key_values = [] + for i, layer in enumerate(self.layers): + layer_past = None + hidden_states, new_layer_past = layer( + hidden_states, + past_key_values=layer_past, + use_cache=True, + position_ids=position_ids + ) + past_key_values.append(new_layer_past) + + return hidden_states, past_key_values + def forward_layers(self, input_ids, past_key_values=None): """ Forward pass through the specified layers. @@ -59,8 +84,7 @@ def forward_layers(self, input_ids, past_key_values=None): use_cache=True, position_ids=position_ids ) - if new_layer_past is not None: - new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) + new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) From 2975929d35c963b0e06c815d6b18b4625fa65eef Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 15:10:44 -0800 Subject: [PATCH 072/491] fixing hugging face sharded class --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index edae36f7..3983e00f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -4,14 +4,14 @@ from exo.inference.shard import Shard class ShardedHuggingFaceModel(nn.Module): - def __init__(self, model_name: str, shard: Shard): + def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( - model_name, + shard.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" ) From 2e72367d564f4091dd96242015c2e376bcef7a14 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 15:14:52 -0800 Subject: [PATCH 073/491] fixing tensor shape issue --- exo/inference/pytorch/inference.py | 2 +- exo/inference/pytorch/model/hf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 5d32b8ff..09de6d22 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -94,7 +94,7 @@ async def infer_tensor( start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), start_pos=start_pos) + output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() if output_data.size == 1: diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 3983e00f..1b7e751c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -71,7 +71,7 @@ def forward_layers(self, input_ids, past_key_values=None): # Generate position ids position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1) # Apply each layer in this shard hidden_states = inputs_embeds From 9ed779ab4e050d99ec3ae8ec67b59ad9e006f176 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 15:19:03 -0800 Subject: [PATCH 074/491] fixing tensor shape issue --- exo/api/chatgpt_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 761aabc9..65c6d40e 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -16,7 +16,7 @@ ### llama "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=12), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), From 86994c352b6a0dc2390fd0d5fc17e93a58721e69 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 17:11:46 -0800 Subject: [PATCH 075/491] fixing tensor shape issue --- exo/inference/pytorch/inference.py | 91 ++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 23 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 09de6d22..647f5ac5 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -52,25 +52,48 @@ async def infer_prompt( Returns: Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. """ + # await self.ensure_shard(shard) + + # toks = self.tokenizer.encode(prompt) + # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + + # hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) + # last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) + + # output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) + # output_data = output_data.detach().cpu().numpy() + + # if output_data.size == 1: + # start_pos += 1 + + # return ( + # output_data, + # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), + # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + # ) await self.ensure_shard(shard) - toks = self.tokenizer.encode(prompt) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) + + output = self.model.forward_layers(input_ids) - hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) - last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) + if self.shard.is_last_layer(): + logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) + next_token = torch.argmax(logits[:, -1, :], dim=-1) + output_data = np.array([next_token.item()]) + is_eos = next_token.item() == self.tokenizer.eos_token_id + else: + output_data = output.detach().cpu().numpy() + is_eos = False - output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) - output_data = output_data.detach().cpu().numpy() + new_inference_state = json.dumps({"past_key_values": []}) - if output_data.size == 1: - start_pos += 1 + if self.debug: + self.log.info( + f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") - return ( - output_data, - json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), - output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - ) + return output_data, new_inference_state, is_eos + async def infer_tensor( self, @@ -90,21 +113,43 @@ async def infer_tensor( Returns: Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. """ + # await self.ensure_shard(shard) + + # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + + # output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=past_key_values) + # output_data = output_data.detach().cpu().numpy() + + # if output_data.size == 1: + # start_pos += 1 + + # return ( + # output_data, + # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), + # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + # ) + await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + + output = self.model.forward_layers(input_tensor) + + if self.shard.is_last_layer(): + logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) + next_token = torch.argmax(logits[:, -1, :], dim=-1) + output_data = np.array([next_token.item()]) + is_eos = next_token.item() == self.tokenizer.eos_token_id + else: + output_data = output.detach().cpu().numpy() + is_eos = False - output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=past_key_values) - output_data = output_data.detach().cpu().numpy() + new_inference_state = json.dumps({"past_key_values": []}) - if output_data.size == 1: - start_pos += 1 + if self.debug: + self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") - return ( - output_data, - json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), - output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - ) + return output_data, new_inference_state, is_eos def _apply_generation_settings(self, logits, temperature, top_k): """ From 8b266d81e049ce620271efaa604215499d34dd56 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 17:32:50 -0800 Subject: [PATCH 076/491] fixing tensor shape issue --- exo/inference/pytorch/inference.py | 367 +++++++++++++++++++---------- exo/inference/pytorch/model/hf.py | 132 +++++++++-- 2 files changed, 367 insertions(+), 132 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 647f5ac5..eb26a74f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,12 +1,233 @@ -# experimental, based off of tinygrad/inference.py +# # experimental, based off of tinygrad/inference.py + +# import os +# import json +# import torch +# import torch.nn as nn +# import numpy as np +# from typing import Optional, Callable, Tuple +# from transformers import AutoTokenizer, DynamicCache +# from exo.inference.shard import Shard +# from exo.inference.inference_engine import InferenceEngine +# from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel + +# # Default settings +# TEMPERATURE = 0.7 +# TOP_K = 50 + +# class PyTorchDynamicShardInferenceEngine(InferenceEngine): +# """ +# PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. +# """ + +# def __init__(self, debug: bool = True): +# """ +# Initialize the inference engine. + +# Args: +# debug (bool): If True, enables debug logging. Defaults to False. +# """ +# self.shard = None +# self.model = None +# self.debug = debug +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# async def infer_prompt( +# self, +# request_id: str, +# shard: Optional[Shard] = None, +# prompt: str = "", +# image_str: Optional[str] = None, +# inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: +# """ +# Perform inference based on a text prompt. + +# Args: +# request_id (str): Unique identifier for the request. +# shard (Optional[Shard]): Shard information for the model. +# prompt (str): The input text prompt for inference. +# image_str (Optional[str]): Optional image string for multi-modal models. +# inference_state (Optional[str]): The previous inference state. + +# Returns: +# Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. +# """ +# # await self.ensure_shard(shard) + +# # toks = self.tokenizer.encode(prompt) +# # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + +# # hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) +# # last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) + +# # output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) +# # output_data = output_data.detach().cpu().numpy() + +# # if output_data.size == 1: +# # start_pos += 1 + +# # return ( +# # output_data, +# # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), +# # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], +# # ) +# await self.ensure_shard(shard) + +# start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 +# toks = self.tokenizer.encode(prompt) + +# hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) +# last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) + +# output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) +# output_data = output_data.detach().cpu().numpy() + +# if output_data.size == 1: +# start_pos += 1 + +# return ( +# output_data, +# json.dumps({"start_pos": start_pos, "past_key_values": past_key_values}), +# output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], +# ) + + +# async def infer_tensor( +# self, +# request_id: str, +# shard: Optional[Shard] = None, +# input_data: np.ndarray = None, +# inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: +# """ +# Perform inference based on an input tensor. + +# Args: +# request_id (str): Unique identifier for the request. +# shard (Optional[Shard]): Shard information for the model. +# input_data (np.ndarray): The input tensor for inference. +# inference_state (Optional[str]): The previous inference state. + +# Returns: +# Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. +# """ +# # await self.ensure_shard(shard) + +# # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + +# # output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=past_key_values) +# # output_data = output_data.detach().cpu().numpy() + +# # if output_data.size == 1: +# # start_pos += 1 + +# # return ( +# # output_data, +# # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), +# # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], +# # ) + +# await self.ensure_shard(shard) + +# input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + +# output = self.model.forward_layers(input_tensor) + +# if self.shard.is_last_layer(): +# logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) +# next_token = torch.argmax(logits[:, -1, :], dim=-1) +# output_data = np.array([next_token.item()]) +# is_eos = next_token.item() == self.tokenizer.eos_token_id +# else: +# output_data = output.detach().cpu().numpy() +# is_eos = False + +# new_inference_state = json.dumps({"past_key_values": []}) + +# if self.debug: +# self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + +# return output_data, new_inference_state, is_eos + +# def _apply_generation_settings(self, logits, temperature, top_k): +# """ +# Apply temperature and top_k settings to logits. + +# Args: +# logits (torch.Tensor): The logits to be adjusted. +# temperature (float): The temperature setting for generation. +# top_k (int): The top_k setting for generation. + +# Returns: +# torch.Tensor: The adjusted logits. +# """ +# logits = logits / temperature +# if top_k > 0: +# top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) +# logits = logits.scatter(1, top_k_indices, top_k_values) +# return logits + +# def _load_kv_cache(self, past_key_values_list): +# """ +# Load key-value cache from the inference state. + +# Args: +# past_key_values_list (list): List of past key-value tensors. + +# Returns: +# DynamicCache: Loaded past key-value cache. +# """ +# if past_key_values_list is None: +# return DynamicCache() +# cache = DynamicCache.from_legacy_cache(past_key_values_list) +# return cache + +# def _save_kv_cache(self, past_key_values): +# """ +# Save key-value cache to the inference state. + +# Args: +# past_key_values (DynamicCache): Past key-value cache. + +# Returns: +# list: List of key-value tensors in a format suitable for saving. +# """ +# return past_key_values.to_legacy_cache() + +# async def ensure_shard(self, shard: Optional[Shard]): +# """ +# Ensure the model shard is loaded and ready for inference. + +# Args: +# shard (Optional[Shard]): Shard information for the model. +# """ +# if self.shard == shard: +# return + +# # Load model and tokenizer from the downloaded files +# # This is written for llama model but need to add in option for others +# if not self.model: +# self.model = ShardedHuggingFaceModel(shard) +# self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) + +# self.shard = shard + +# def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): +# """ +# Set a callback function to track download progress. + +# Args: +# on_download_progress (Callable[[int, int], None]): Callback function to track progress. +# """ +# # must have this function or inference engine breaks +# # This method can be implemented if progress tracking is needed +# pass + -import os import json import torch -import torch.nn as nn import numpy as np from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer, DynamicCache +from transformers import AutoTokenizer from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel @@ -52,48 +273,25 @@ async def infer_prompt( Returns: Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. """ - # await self.ensure_shard(shard) - - # toks = self.tokenizer.encode(prompt) - # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - - # hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) - # last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) - - # output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) - # output_data = output_data.detach().cpu().numpy() - - # if output_data.size == 1: - # start_pos += 1 - - # return ( - # output_data, - # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), - # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - # ) await self.ensure_shard(shard) - input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) - - output = self.model.forward_layers(input_ids) - - if self.shard.is_last_layer(): - logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) - next_token = torch.argmax(logits[:, -1, :], dim=-1) - output_data = np.array([next_token.item()]) - is_eos = next_token.item() == self.tokenizer.eos_token_id - else: - output_data = output.detach().cpu().numpy() - is_eos = False + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + toks = self.tokenizer.encode(prompt) + + start_pos = self.model.prefill(torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) + last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) - new_inference_state = json.dumps({"past_key_values": []}) + output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=[]) + output_data = output_data.detach().cpu().numpy() - if self.debug: - self.log.info( - f"Infer Prompt Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") + if output_data.size == 1: + start_pos += 1 - return output_data, new_inference_state, is_eos - + return ( + output_data, + json.dumps({"start_pos": start_pos}), + output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + ) async def infer_tensor( self, @@ -113,88 +311,21 @@ async def infer_tensor( Returns: Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. """ - # await self.ensure_shard(shard) - - # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - - # output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=past_key_values) - # output_data = output_data.detach().cpu().numpy() - - # if output_data.size == 1: - # start_pos += 1 - - # return ( - # output_data, - # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), - # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - # ) - await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - output = self.model.forward_layers(input_tensor) + output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=[]) + output_data = output_data.detach().cpu().numpy() - if self.shard.is_last_layer(): - logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) - next_token = torch.argmax(logits[:, -1, :], dim=-1) - output_data = np.array([next_token.item()]) - is_eos = next_token.item() == self.tokenizer.eos_token_id - else: - output_data = output.detach().cpu().numpy() - is_eos = False + if output_data.size == 1: + start_pos += 1 - new_inference_state = json.dumps({"past_key_values": []}) - - if self.debug: - self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") - - return output_data, new_inference_state, is_eos - - def _apply_generation_settings(self, logits, temperature, top_k): - """ - Apply temperature and top_k settings to logits. - - Args: - logits (torch.Tensor): The logits to be adjusted. - temperature (float): The temperature setting for generation. - top_k (int): The top_k setting for generation. - - Returns: - torch.Tensor: The adjusted logits. - """ - logits = logits / temperature - if top_k > 0: - top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) - logits = logits.scatter(1, top_k_indices, top_k_values) - return logits - - def _load_kv_cache(self, past_key_values_list): - """ - Load key-value cache from the inference state. - - Args: - past_key_values_list (list): List of past key-value tensors. - - Returns: - DynamicCache: Loaded past key-value cache. - """ - if past_key_values_list is None: - return DynamicCache() - cache = DynamicCache.from_legacy_cache(past_key_values_list) - return cache - - def _save_kv_cache(self, past_key_values): - """ - Save key-value cache to the inference state. - - Args: - past_key_values (DynamicCache): Past key-value cache. - - Returns: - list: List of key-value tensors in a format suitable for saving. - """ - return past_key_values.to_legacy_cache() + return ( + output_data, + json.dumps({"start_pos": start_pos}), + output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], + ) async def ensure_shard(self, shard: Optional[Shard]): """ diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 1b7e751c..e85ebc52 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,6 +1,101 @@ +# import torch +# import torch.nn as nn +# from transformers import AutoModelForCausalLM, DynamicCache +# from exo.inference.shard import Shard + +# class ShardedHuggingFaceModel(nn.Module): +# def __init__(self, shard: Shard): +# super(ShardedHuggingFaceModel, self).__init__() +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# self.shard = shard + +# # Load the model +# self.full_model = AutoModelForCausalLM.from_pretrained( +# shard.model_id, +# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, +# device_map="auto" +# ) + +# # Extract only the layers for this shard +# self.layers = nn.ModuleList([ +# self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) +# ]) + +# # Embeddings and final layer norm +# self.embed_tokens = self.full_model.model.embed_tokens +# self.norm = self.full_model.model.norm +# self.lm_head = self.full_model.lm_head + +# def prefill(self, model, tokens, start_pos=0): +# """ +# Process the initial input tokens and set up the initial hidden states and key-value caches. +# """ +# # Token embeddings +# inputs_embeds = self.embed_tokens(tokens) + +# # Generate position ids +# position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) +# position_ids = position_ids.unsqueeze(0).expand_as(tokens) + +# # Apply each layer in this shard +# hidden_states = inputs_embeds +# past_key_values = [] +# for i, layer in enumerate(self.layers): +# layer_past = None +# hidden_states, new_layer_past = layer( +# hidden_states, +# past_key_values=layer_past, +# use_cache=True, +# position_ids=position_ids +# ) +# past_key_values.append(new_layer_past) + +# return hidden_states, past_key_values + +# def forward_layers(self, input_ids, past_key_values=None): +# """ +# Forward pass through the specified layers. + +# Args: +# input_ids (torch.Tensor): Input token IDs. +# past_key_values (DynamicCache, optional): Past key values for caching. + +# Returns: +# tuple: Hidden states and new past key values. +# """ +# if past_key_values is None: +# past_key_values = DynamicCache() + +# # Token embeddings +# inputs_embeds = self.embed_tokens(input_ids) + +# # Generate position ids +# position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) +# position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1) + +# # Apply each layer in this shard +# hidden_states = inputs_embeds +# new_past_key_values = DynamicCache() +# for i, layer in enumerate(self.layers): +# layer_past = past_key_values[i] if i < len(past_key_values) else None +# hidden_states, new_layer_past = layer( +# hidden_states, +# past_key_values=layer_past, +# use_cache=True, +# position_ids=position_ids +# ) +# new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) + +# if self.shard.is_last_layer(): +# hidden_states = self.norm(hidden_states) +# logits = self.lm_head(hidden_states) +# return logits, new_past_key_values +# else: +# return hidden_states, new_past_key_values + import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, DynamicCache +from transformers import AutoModelForCausalLM from exo.inference.shard import Shard class ShardedHuggingFaceModel(nn.Module): @@ -8,6 +103,7 @@ def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard + self.device_ids = list(range(torch.cuda.device_count())) # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( @@ -26,9 +122,16 @@ def __init__(self, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head - def prefill(self, model, tokens, start_pos=0): + def prefill(self, tokens, start_pos=0): """ Process the initial input tokens and set up the initial hidden states and key-value caches. + + Args: + tokens (torch.Tensor): Input tokens. + start_pos (int, optional): Starting position for position ids. Defaults to 0. + + Returns: + int: The updated start position. """ # Token embeddings inputs_embeds = self.embed_tokens(tokens) @@ -39,18 +142,18 @@ def prefill(self, model, tokens, start_pos=0): # Apply each layer in this shard hidden_states = inputs_embeds - past_key_values = [] - for i, layer in enumerate(self.layers): - layer_past = None - hidden_states, new_layer_past = layer( + for layer in self.layers: + hidden_states, _ = layer( hidden_states, - past_key_values=layer_past, + past_key_values=None, use_cache=True, position_ids=position_ids ) - past_key_values.append(new_layer_past) - return hidden_states, past_key_values + # Update start position + start_pos += tokens.shape[-1] + + return start_pos def forward_layers(self, input_ids, past_key_values=None): """ @@ -58,24 +161,24 @@ def forward_layers(self, input_ids, past_key_values=None): Args: input_ids (torch.Tensor): Input token IDs. - past_key_values (DynamicCache, optional): Past key values for caching. + past_key_values (list, optional): Past key values for caching. Returns: tuple: Hidden states and new past key values. """ if past_key_values is None: - past_key_values = DynamicCache() + past_key_values = [] # Token embeddings inputs_embeds = self.embed_tokens(input_ids) # Generate position ids position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # Apply each layer in this shard hidden_states = inputs_embeds - new_past_key_values = DynamicCache() + new_past_key_values = [] for i, layer in enumerate(self.layers): layer_past = past_key_values[i] if i < len(past_key_values) else None hidden_states, new_layer_past = layer( @@ -84,7 +187,7 @@ def forward_layers(self, input_ids, past_key_values=None): use_cache=True, position_ids=position_ids ) - new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) + new_past_key_values.append(new_layer_past) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) @@ -92,3 +195,4 @@ def forward_layers(self, input_ids, past_key_values=None): return logits, new_past_key_values else: return hidden_states, new_past_key_values + From 2be2702f462a7e3fd61597f914cb0d642a19f8e3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 17:44:33 -0800 Subject: [PATCH 077/491] fixing tensor shape issue --- exo/inference/pytorch/model/hf.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e85ebc52..ad5d68c3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -97,6 +97,7 @@ import torch.nn as nn from transformers import AutoModelForCausalLM from exo.inference.shard import Shard +from exo.helpers import DEBUG class ShardedHuggingFaceModel(nn.Module): def __init__(self, shard: Shard): @@ -133,11 +134,17 @@ def prefill(self, tokens, start_pos=0): Returns: int: The updated start position. """ + if DEBUG >=2: + print("\nShardedHuggingFaceModel.prefill called") + # Token embeddings inputs_embeds = self.embed_tokens(tokens) # Generate position ids position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) + + if DEBUG >= 2: + print(f"tokens: {tokens}") position_ids = position_ids.unsqueeze(0).expand_as(tokens) # Apply each layer in this shard From f1943d16c7a2378d79d9b242d63449b405aad192 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 17:48:22 -0800 Subject: [PATCH 078/491] fixing tensor shape issue --- exo/api/chatgpt_api.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 65c6d40e..9ed4a47a 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -87,21 +87,22 @@ def resolve_tinygrad_tokenizer(model_id: str): async def resolve_tokenizer(model_id: str): - try: - if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}") - processor = AutoProcessor.from_pretrained(model_id, use_fast=False) - if not hasattr(processor, 'eos_token_id'): - processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id - if not hasattr(processor, 'encode'): - processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode - if not hasattr(processor, 'decode'): - processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode - return processor - except Exception as e: - if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}") - import traceback + if not model_id == "meta-llama/Meta-Llama-3.1-8B": + try: + if DEBUG >= 2: print(f"Trying AutoProcessor for {model_id}") + processor = AutoProcessor.from_pretrained(model_id, use_fast=False) + if not hasattr(processor, 'eos_token_id'): + processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id + if not hasattr(processor, 'encode'): + processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode + if not hasattr(processor, 'decode'): + processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode + return processor + except Exception as e: + if DEBUG >= 2: print(f"Failed to load processor for {model_id}. Error: {e}") + import traceback - if DEBUG >= 2: print(traceback.format_exc()) + if DEBUG >= 2: print(traceback.format_exc()) try: if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}") From a8d117a4508c0e13276dba3495a19a341d4ea4f9 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 18:05:21 -0800 Subject: [PATCH 079/491] fixing tensor shape issue --- exo/inference/pytorch/model/hf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index ad5d68c3..157a9da7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -142,10 +142,7 @@ def prefill(self, tokens, start_pos=0): # Generate position ids position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) - - if DEBUG >= 2: - print(f"tokens: {tokens}") - position_ids = position_ids.unsqueeze(0).expand_as(tokens) + position_ids = position_ids.unsqueeze(0).expand(tokens.shape[0], -1) # Match the shape of tokens # Apply each layer in this shard hidden_states = inputs_embeds @@ -160,6 +157,9 @@ def prefill(self, tokens, start_pos=0): # Update start position start_pos += tokens.shape[-1] + if DEBUG >= 2: + print(f"\nstart_post: {start_pos}\nposition_ids: {position_ids}") + return start_pos def forward_layers(self, input_ids, past_key_values=None): From 2acebf3a96103938278464391462c9bb7cc57f47 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 18:23:41 -0800 Subject: [PATCH 080/491] fixing model not getting right shard, utilizing past key values --- exo/inference/pytorch/inference.py | 277 +++++------------------------ 1 file changed, 41 insertions(+), 236 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index eb26a74f..b794b225 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,227 +1,4 @@ -# # experimental, based off of tinygrad/inference.py - -# import os -# import json -# import torch -# import torch.nn as nn -# import numpy as np -# from typing import Optional, Callable, Tuple -# from transformers import AutoTokenizer, DynamicCache -# from exo.inference.shard import Shard -# from exo.inference.inference_engine import InferenceEngine -# from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel - -# # Default settings -# TEMPERATURE = 0.7 -# TOP_K = 50 - -# class PyTorchDynamicShardInferenceEngine(InferenceEngine): -# """ -# PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. -# """ - -# def __init__(self, debug: bool = True): -# """ -# Initialize the inference engine. - -# Args: -# debug (bool): If True, enables debug logging. Defaults to False. -# """ -# self.shard = None -# self.model = None -# self.debug = debug -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# async def infer_prompt( -# self, -# request_id: str, -# shard: Optional[Shard] = None, -# prompt: str = "", -# image_str: Optional[str] = None, -# inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: -# """ -# Perform inference based on a text prompt. - -# Args: -# request_id (str): Unique identifier for the request. -# shard (Optional[Shard]): Shard information for the model. -# prompt (str): The input text prompt for inference. -# image_str (Optional[str]): Optional image string for multi-modal models. -# inference_state (Optional[str]): The previous inference state. - -# Returns: -# Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. -# """ -# # await self.ensure_shard(shard) - -# # toks = self.tokenizer.encode(prompt) -# # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - -# # hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) -# # last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) - -# # output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) -# # output_data = output_data.detach().cpu().numpy() - -# # if output_data.size == 1: -# # start_pos += 1 - -# # return ( -# # output_data, -# # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), -# # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], -# # ) -# await self.ensure_shard(shard) - -# start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 -# toks = self.tokenizer.encode(prompt) - -# hidden_states, past_key_values = self.model.prefill(self.model, torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) -# last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) - -# output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) -# output_data = output_data.detach().cpu().numpy() - -# if output_data.size == 1: -# start_pos += 1 - -# return ( -# output_data, -# json.dumps({"start_pos": start_pos, "past_key_values": past_key_values}), -# output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], -# ) - - -# async def infer_tensor( -# self, -# request_id: str, -# shard: Optional[Shard] = None, -# input_data: np.ndarray = None, -# inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: -# """ -# Perform inference based on an input tensor. - -# Args: -# request_id (str): Unique identifier for the request. -# shard (Optional[Shard]): Shard information for the model. -# input_data (np.ndarray): The input tensor for inference. -# inference_state (Optional[str]): The previous inference state. - -# Returns: -# Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. -# """ -# # await self.ensure_shard(shard) - -# # start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - -# # output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=past_key_values) -# # output_data = output_data.detach().cpu().numpy() - -# # if output_data.size == 1: -# # start_pos += 1 - -# # return ( -# # output_data, -# # json.dumps({"start_pos": start_pos, "past_key_values": past_key_values.to_legacy_cache()}), -# # output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], -# # ) - -# await self.ensure_shard(shard) - -# input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) - -# output = self.model.forward_layers(input_tensor) - -# if self.shard.is_last_layer(): -# logits = self._apply_generation_settings(output, TEMPERATURE, TOP_K) -# next_token = torch.argmax(logits[:, -1, :], dim=-1) -# output_data = np.array([next_token.item()]) -# is_eos = next_token.item() == self.tokenizer.eos_token_id -# else: -# output_data = output.detach().cpu().numpy() -# is_eos = False - -# new_inference_state = json.dumps({"past_key_values": []}) - -# if self.debug: -# self.log.info(f"Infer Tensor Debug - Request ID: {request_id}, Output: {output_data}, EOS: {is_eos}") - -# return output_data, new_inference_state, is_eos - -# def _apply_generation_settings(self, logits, temperature, top_k): -# """ -# Apply temperature and top_k settings to logits. - -# Args: -# logits (torch.Tensor): The logits to be adjusted. -# temperature (float): The temperature setting for generation. -# top_k (int): The top_k setting for generation. - -# Returns: -# torch.Tensor: The adjusted logits. -# """ -# logits = logits / temperature -# if top_k > 0: -# top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) -# logits = logits.scatter(1, top_k_indices, top_k_values) -# return logits - -# def _load_kv_cache(self, past_key_values_list): -# """ -# Load key-value cache from the inference state. - -# Args: -# past_key_values_list (list): List of past key-value tensors. - -# Returns: -# DynamicCache: Loaded past key-value cache. -# """ -# if past_key_values_list is None: -# return DynamicCache() -# cache = DynamicCache.from_legacy_cache(past_key_values_list) -# return cache - -# def _save_kv_cache(self, past_key_values): -# """ -# Save key-value cache to the inference state. - -# Args: -# past_key_values (DynamicCache): Past key-value cache. - -# Returns: -# list: List of key-value tensors in a format suitable for saving. -# """ -# return past_key_values.to_legacy_cache() - -# async def ensure_shard(self, shard: Optional[Shard]): -# """ -# Ensure the model shard is loaded and ready for inference. - -# Args: -# shard (Optional[Shard]): Shard information for the model. -# """ -# if self.shard == shard: -# return - -# # Load model and tokenizer from the downloaded files -# # This is written for llama model but need to add in option for others -# if not self.model: -# self.model = ShardedHuggingFaceModel(shard) -# self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) - -# self.shard = shard - -# def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): -# """ -# Set a callback function to track download progress. - -# Args: -# on_download_progress (Callable[[int, int], None]): Callback function to track progress. -# """ -# # must have this function or inference engine breaks -# # This method can be implemented if progress tracking is needed -# pass - +# experimental, based off of tinygrad/inference.py import json import torch @@ -275,13 +52,15 @@ async def infer_prompt( """ await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 toks = self.tokenizer.encode(prompt) - + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + past_key_values_list = json.loads(inference_state).get("past_key_values", None) if inference_state else None + past_key_values = self._load_kv_cache(past_key_values_list) + start_pos = self.model.prefill(torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) - output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=[]) + output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() if output_data.size == 1: @@ -289,7 +68,7 @@ async def infer_prompt( return ( output_data, - json.dumps({"start_pos": start_pos}), + json.dumps({"start_pos": start_pos, "past_key_values": self._save_kv_cache(past_key_values)}), output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], ) @@ -313,9 +92,13 @@ async def infer_tensor( """ await self.ensure_shard(shard) + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 + past_key_values_list = json.loads(inference_state).get("past_key_values", None) if inference_state else None + past_key_values = self._load_kv_cache(past_key_values_list) - output_data, past_key_values = self.model.forward_layers(torch.tensor([input_data], device=self.model.device), past_key_values=[]) + output_data, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() if output_data.size == 1: @@ -323,10 +106,36 @@ async def infer_tensor( return ( output_data, - json.dumps({"start_pos": start_pos}), + json.dumps({"start_pos": start_pos, "past_key_values": self._save_kv_cache(past_key_values)}), output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], ) + def _load_kv_cache(self, past_key_values_list): + """ + Load key-value cache from the inference state. + + Args: + past_key_values_list (list): List of past key-value tensors. + + Returns: + list: List of loaded past key-value tensors. + """ + if past_key_values_list is None: + return [] + return [torch.tensor(kv, device=self.device) for kv in past_key_values_list] + + def _save_kv_cache(self, past_key_values): + """ + Save key-value cache to the inference state. + + Args: + past_key_values (list): List of past key-value tensors. + + Returns: + list: List of key-value tensors in a format suitable for saving. + """ + return [kv.cpu().tolist() for kv in past_key_values] + async def ensure_shard(self, shard: Optional[Shard]): """ Ensure the model shard is loaded and ready for inference. @@ -337,12 +146,8 @@ async def ensure_shard(self, shard: Optional[Shard]): if self.shard == shard: return - # Load model and tokenizer from the downloaded files - # This is written for llama model but need to add in option for others - if not self.model: - self.model = ShardedHuggingFaceModel(shard) - self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) - + self.model = ShardedHuggingFaceModel(shard) + self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) self.shard = shard def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): From 4321c5746ef8e23201bf3d5a5b6555178f28a051 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 18:41:56 -0800 Subject: [PATCH 081/491] fixing return value issue with sendprompt --- exo/inference/pytorch/inference.py | 69 +++++++++----- exo/inference/pytorch/model/hf.py | 142 ++--------------------------- 2 files changed, 52 insertions(+), 159 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index b794b225..f5d7f7b7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -18,7 +18,7 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, debug: bool = True): + def __init__(self, debug: bool = False): """ Initialize the inference engine. @@ -27,6 +27,7 @@ def __init__(self, debug: bool = True): """ self.shard = None self.model = None + self.tokenizer = None self.debug = debug self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -52,25 +53,31 @@ async def infer_prompt( """ await self.ensure_shard(shard) + if self.debug: + print(f"[{request_id}] Processing prompt: {prompt[:50]}...") + toks = self.tokenizer.encode(prompt) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - past_key_values_list = json.loads(inference_state).get("past_key_values", None) if inference_state else None - past_key_values = self._load_kv_cache(past_key_values_list) + state = json.loads(inference_state) if inference_state else {} + start_pos = state.get("start_pos", 0) + past_key_values = self._load_kv_cache(state.get("past_key_values")) - start_pos = self.model.prefill(torch.tensor(toks[:-1], device=self.model.device), start_pos=start_pos) - last_tok = torch.tensor([toks[-1]], device=self.model.device).unsqueeze(0) + start_pos = self.model.prefill(torch.tensor(toks[:-1], device=self.device), start_pos=start_pos) + last_tok = torch.tensor([toks[-1]], device=self.device).unsqueeze(0) output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() - if output_data.size == 1: - start_pos += 1 + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + new_state = { + "start_pos": start_pos + 1, + "past_key_values": self._save_kv_cache(past_key_values) + } + new_inference_state = json.dumps(new_state) + + if self.debug: + print(f"[{request_id}] Output size: {output_data.size}, Is finished: {is_finished}") - return ( - output_data, - json.dumps({"start_pos": start_pos, "past_key_values": self._save_kv_cache(past_key_values)}), - output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - ) + return output_data, new_inference_state, is_finished async def infer_tensor( self, @@ -92,23 +99,29 @@ async def infer_tensor( """ await self.ensure_shard(shard) - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.model.device) + if self.debug: + print(f"[{request_id}] Processing tensor input, shape: {input_data.shape}") - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 - past_key_values_list = json.loads(inference_state).get("past_key_values", None) if inference_state else None - past_key_values = self._load_kv_cache(past_key_values_list) + input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) + + state = json.loads(inference_state) if inference_state else {} + start_pos = state.get("start_pos", 0) + past_key_values = self._load_kv_cache(state.get("past_key_values")) output_data, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() - if output_data.size == 1: - start_pos += 1 + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + new_state = { + "start_pos": start_pos + 1, + "past_key_values": self._save_kv_cache(past_key_values) + } + new_inference_state = json.dumps(new_state) + + if self.debug: + print(f"[{request_id}] Output size: {output_data.size}, Is finished: {is_finished}") - return ( - output_data, - json.dumps({"start_pos": start_pos, "past_key_values": self._save_kv_cache(past_key_values)}), - output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id], - ) + return output_data, new_inference_state, is_finished def _load_kv_cache(self, past_key_values_list): """ @@ -134,6 +147,8 @@ def _save_kv_cache(self, past_key_values): Returns: list: List of key-value tensors in a format suitable for saving. """ + if past_key_values is None: + return [] return [kv.cpu().tolist() for kv in past_key_values] async def ensure_shard(self, shard: Optional[Shard]): @@ -146,10 +161,16 @@ async def ensure_shard(self, shard: Optional[Shard]): if self.shard == shard: return + if self.debug: + print(f"Loading new shard: {shard}") + self.model = ShardedHuggingFaceModel(shard) self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) self.shard = shard + if self.debug: + print(f"Shard loaded successfully: {shard}") + def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): """ Set a callback function to track download progress. diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 157a9da7..151ff656 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,110 +1,13 @@ -# import torch -# import torch.nn as nn -# from transformers import AutoModelForCausalLM, DynamicCache -# from exo.inference.shard import Shard - -# class ShardedHuggingFaceModel(nn.Module): -# def __init__(self, shard: Shard): -# super(ShardedHuggingFaceModel, self).__init__() -# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# self.shard = shard - -# # Load the model -# self.full_model = AutoModelForCausalLM.from_pretrained( -# shard.model_id, -# torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, -# device_map="auto" -# ) - -# # Extract only the layers for this shard -# self.layers = nn.ModuleList([ -# self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) -# ]) - -# # Embeddings and final layer norm -# self.embed_tokens = self.full_model.model.embed_tokens -# self.norm = self.full_model.model.norm -# self.lm_head = self.full_model.lm_head - -# def prefill(self, model, tokens, start_pos=0): -# """ -# Process the initial input tokens and set up the initial hidden states and key-value caches. -# """ -# # Token embeddings -# inputs_embeds = self.embed_tokens(tokens) - -# # Generate position ids -# position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) -# position_ids = position_ids.unsqueeze(0).expand_as(tokens) - -# # Apply each layer in this shard -# hidden_states = inputs_embeds -# past_key_values = [] -# for i, layer in enumerate(self.layers): -# layer_past = None -# hidden_states, new_layer_past = layer( -# hidden_states, -# past_key_values=layer_past, -# use_cache=True, -# position_ids=position_ids -# ) -# past_key_values.append(new_layer_past) - -# return hidden_states, past_key_values - -# def forward_layers(self, input_ids, past_key_values=None): -# """ -# Forward pass through the specified layers. - -# Args: -# input_ids (torch.Tensor): Input token IDs. -# past_key_values (DynamicCache, optional): Past key values for caching. - -# Returns: -# tuple: Hidden states and new past key values. -# """ -# if past_key_values is None: -# past_key_values = DynamicCache() - -# # Token embeddings -# inputs_embeds = self.embed_tokens(input_ids) - -# # Generate position ids -# position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) -# position_ids = position_ids.unsqueeze(0).expand(input_ids.shape[0], -1) - -# # Apply each layer in this shard -# hidden_states = inputs_embeds -# new_past_key_values = DynamicCache() -# for i, layer in enumerate(self.layers): -# layer_past = past_key_values[i] if i < len(past_key_values) else None -# hidden_states, new_layer_past = layer( -# hidden_states, -# past_key_values=layer_past, -# use_cache=True, -# position_ids=position_ids -# ) -# new_past_key_values.update(new_layer_past[0], new_layer_past[1], i) - -# if self.shard.is_last_layer(): -# hidden_states = self.norm(hidden_states) -# logits = self.lm_head(hidden_states) -# return logits, new_past_key_values -# else: -# return hidden_states, new_past_key_values - import torch -import torch.nn as nn from transformers import AutoModelForCausalLM from exo.inference.shard import Shard from exo.helpers import DEBUG -class ShardedHuggingFaceModel(nn.Module): +class ShardedHuggingFaceModel(torch.nn.Module): def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - self.device_ids = list(range(torch.cuda.device_count())) # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( @@ -114,7 +17,7 @@ def __init__(self, shard: Shard): ) # Extract only the layers for this shard - self.layers = nn.ModuleList([ + self.layers = torch.nn.ModuleList([ self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) @@ -124,25 +27,12 @@ def __init__(self, shard: Shard): self.lm_head = self.full_model.lm_head def prefill(self, tokens, start_pos=0): - """ - Process the initial input tokens and set up the initial hidden states and key-value caches. - - Args: - tokens (torch.Tensor): Input tokens. - start_pos (int, optional): Starting position for position ids. Defaults to 0. - - Returns: - int: The updated start position. - """ - if DEBUG >=2: - print("\nShardedHuggingFaceModel.prefill called") - # Token embeddings inputs_embeds = self.embed_tokens(tokens) # Generate position ids position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) - position_ids = position_ids.unsqueeze(0).expand(tokens.shape[0], -1) # Match the shape of tokens + position_ids = position_ids.unsqueeze(0).expand_as(tokens) # Apply each layer in this shard hidden_states = inputs_embeds @@ -154,27 +44,11 @@ def prefill(self, tokens, start_pos=0): position_ids=position_ids ) - # Update start position - start_pos += tokens.shape[-1] - - if DEBUG >= 2: - print(f"\nstart_post: {start_pos}\nposition_ids: {position_ids}") - - return start_pos + return start_pos + tokens.shape[-1] def forward_layers(self, input_ids, past_key_values=None): - """ - Forward pass through the specified layers. - - Args: - input_ids (torch.Tensor): Input token IDs. - past_key_values (list, optional): Past key values for caching. - - Returns: - tuple: Hidden states and new past key values. - """ if past_key_values is None: - past_key_values = [] + past_key_values = [None] * len(self.layers) # Token embeddings inputs_embeds = self.embed_tokens(input_ids) @@ -187,10 +61,9 @@ def forward_layers(self, input_ids, past_key_values=None): hidden_states = inputs_embeds new_past_key_values = [] for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] if i < len(past_key_values) else None hidden_states, new_layer_past = layer( hidden_states, - past_key_values=layer_past, + past_key_values=past_key_values[i], use_cache=True, position_ids=position_ids ) @@ -201,5 +74,4 @@ def forward_layers(self, input_ids, past_key_values=None): logits = self.lm_head(hidden_states) return logits, new_past_key_values else: - return hidden_states, new_past_key_values - + return hidden_states, new_past_key_values \ No newline at end of file From 7ae4856e6422763c6942c7e9addc3210aa6742ac Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 18:57:13 -0800 Subject: [PATCH 082/491] fixing return value issue with sendprompt --- exo/inference/pytorch/inference.py | 14 +++++++------- exo/inference/pytorch/model/hf.py | 9 +++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f5d7f7b7..a3ed7654 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -8,6 +8,7 @@ from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel +from exo.helpers import DEBUG # Default settings TEMPERATURE = 0.7 @@ -28,7 +29,6 @@ def __init__(self, debug: bool = False): self.shard = None self.model = None self.tokenizer = None - self.debug = debug self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") async def infer_prompt( @@ -53,7 +53,7 @@ async def infer_prompt( """ await self.ensure_shard(shard) - if self.debug: + if DEBUG >= 2: print(f"[{request_id}] Processing prompt: {prompt[:50]}...") toks = self.tokenizer.encode(prompt) @@ -74,7 +74,7 @@ async def infer_prompt( } new_inference_state = json.dumps(new_state) - if self.debug: + if DEBUG >= 2: print(f"[{request_id}] Output size: {output_data.size}, Is finished: {is_finished}") return output_data, new_inference_state, is_finished @@ -99,7 +99,7 @@ async def infer_tensor( """ await self.ensure_shard(shard) - if self.debug: + if DEBUG >= 2: print(f"[{request_id}] Processing tensor input, shape: {input_data.shape}") input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) @@ -118,7 +118,7 @@ async def infer_tensor( } new_inference_state = json.dumps(new_state) - if self.debug: + if DEBUG >= 2: print(f"[{request_id}] Output size: {output_data.size}, Is finished: {is_finished}") return output_data, new_inference_state, is_finished @@ -161,14 +161,14 @@ async def ensure_shard(self, shard: Optional[Shard]): if self.shard == shard: return - if self.debug: + if DEBUG >= 2: print(f"Loading new shard: {shard}") self.model = ShardedHuggingFaceModel(shard) self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) self.shard = shard - if self.debug: + if DEBUG >= 2: print(f"Shard loaded successfully: {shard}") def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 151ff656..c9067da9 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -2,10 +2,15 @@ from transformers import AutoModelForCausalLM from exo.inference.shard import Shard from exo.helpers import DEBUG +from typing import Tuple class ShardedHuggingFaceModel(torch.nn.Module): def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() + + if DEBUG >= 2: + print(f"ShardedHuggingFaceModel init with shard {shard}") + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard @@ -26,7 +31,7 @@ def __init__(self, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head - def prefill(self, tokens, start_pos=0): + def prefill(self, tokens, start_pos=0) -> int: # Token embeddings inputs_embeds = self.embed_tokens(tokens) @@ -46,7 +51,7 @@ def prefill(self, tokens, start_pos=0): return start_pos + tokens.shape[-1] - def forward_layers(self, input_ids, past_key_values=None): + def forward_layers(self, input_ids, past_key_values=None) -> Tuple[any, list]: if past_key_values is None: past_key_values = [None] * len(self.layers) From 77a4403eccb43b1f8b1fead8f200b4b8ee62c517 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 19:43:39 -0800 Subject: [PATCH 083/491] fixing prefill stuff and infer_prompt --- exo/inference/pytorch/inference.py | 25 +++++++---- exo/inference/pytorch/model/hf.py | 68 ++++++++++++++++++------------ 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a3ed7654..f8468ac8 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -51,25 +51,33 @@ async def infer_prompt( Returns: Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. """ - await self.ensure_shard(shard) - + async def infer_prompt( + self, + request_id: str, + shard: Optional[Shard] = None, + prompt: str = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 2: print(f"[{request_id}] Processing prompt: {prompt[:50]}...") + await self.ensure_shard(shard) + toks = self.tokenizer.encode(prompt) state = json.loads(inference_state) if inference_state else {} start_pos = state.get("start_pos", 0) past_key_values = self._load_kv_cache(state.get("past_key_values")) - start_pos = self.model.prefill(torch.tensor(toks[:-1], device=self.device), start_pos=start_pos) - last_tok = torch.tensor([toks[-1]], device=self.device).unsqueeze(0) - - output_data, past_key_values = self.model.forward_layers(last_tok, past_key_values=past_key_values) + start_pos = self.model.prefill( + torch.tensor(toks[:-1], device=self.device), start_pos=start_pos) + + output_data, past_key_values = self.model.forward_layers(toks[:, -1:], past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data.shape[1] == 1 and output_data[0, 0, -1] == self.tokenizer.eos_token_id new_state = { - "start_pos": start_pos + 1, + "start_pos": start_pos, "past_key_values": self._save_kv_cache(past_key_values) } new_inference_state = json.dumps(new_state) @@ -79,6 +87,7 @@ async def infer_prompt( return output_data, new_inference_state, is_finished + async def infer_tensor( self, request_id: str, diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c9067da9..803af751 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -31,52 +31,66 @@ def __init__(self, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head - def prefill(self, tokens, start_pos=0) -> int: - # Token embeddings - inputs_embeds = self.embed_tokens(tokens) + def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: + """ + Process the initial input tokens and set up the initial hidden states. + """ + # Assuming tokens is a 1D tensor of token IDs + for token in tokens: + # Convert token to a tensor and get embeddings + token_tensor = torch.tensor([[token]], device=self.device) + inputs_embeds = self.embed_tokens(token_tensor) - # Generate position ids - position_ids = torch.arange(start_pos, start_pos + tokens.shape[-1], dtype=torch.long, device=tokens.device) - position_ids = position_ids.unsqueeze(0).expand_as(tokens) + # Prefill with tokens + for layer in self.layers: + _ = layer( + inputs_embeds, + use_cache=True, + output_attentions=False, + ) + # Update embeddings with layer output + inputs_embeds = layer_outputs[0] - # Apply each layer in this shard - hidden_states = inputs_embeds - for layer in self.layers: - hidden_states, _ = layer( - hidden_states, - past_key_values=None, - use_cache=True, - position_ids=position_ids - ) - - return start_pos + tokens.shape[-1] + # Increment start position + start_pos += 1 + + return start_pos - def forward_layers(self, input_ids, past_key_values=None) -> Tuple[any, list]: + def forward_layers(self, input_ids, past_key_values=None): + """ + Forward pass through the specified layers. + """ if past_key_values is None: past_key_values = [None] * len(self.layers) # Token embeddings - inputs_embeds = self.embed_tokens(input_ids) + hidden_states = self.embed_tokens(input_ids) # Generate position ids - position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + seq_length = input_ids.shape[1] + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(input_ids.shape) # Apply each layer in this shard - hidden_states = inputs_embeds new_past_key_values = [] for i, layer in enumerate(self.layers): - hidden_states, new_layer_past = layer( + layer_outputs = layer( hidden_states, - past_key_values=past_key_values[i], + attention_mask=None, + position_ids=position_ids, + past_key_value=past_key_values[i], use_cache=True, - position_ids=position_ids + output_attentions=False, ) - new_past_key_values.append(new_layer_past) + hidden_states = layer_outputs[0] + new_past_key_values.append(layer_outputs[1]) if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) return logits, new_past_key_values else: - return hidden_states, new_past_key_values \ No newline at end of file + return hidden_states, new_past_key_values + + def is_last_layer(self): + return self.shard.is_last_layer() \ No newline at end of file From dd23891273523411306999036c71a23169e197b1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 20:42:17 -0800 Subject: [PATCH 084/491] fixing none for shape call --- exo/inference/pytorch/model/hf.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 803af751..4917f23e 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -39,18 +39,15 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: for token in tokens: # Convert token to a tensor and get embeddings token_tensor = torch.tensor([[token]], device=self.device) - inputs_embeds = self.embed_tokens(token_tensor) # Prefill with tokens for layer in self.layers: _ = layer( - inputs_embeds, + token_tensor, use_cache=True, output_attentions=False, ) - # Update embeddings with layer output - inputs_embeds = layer_outputs[0] - + # Increment start position start_pos += 1 From 82347462f5f537e7c06eb48c3cbf547b986e967b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 20:44:46 -0800 Subject: [PATCH 085/491] updating test --- exo/inference/pytorch/test_inference_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 1279ee1e..fbc314f0 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -12,8 +12,7 @@ def main(): ) engine = PyTorchDynamicShardInferenceEngine( - shard.model_id, - debug=True + shard ) From e603902fdf20ac3001778d04f1b45b323d5dcd96 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 20:50:15 -0800 Subject: [PATCH 086/491] updating hf --- exo/inference/pytorch/model/hf.py | 36 ++++++++++++------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 4917f23e..d87235fb 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -35,22 +35,30 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: """ Process the initial input tokens and set up the initial hidden states. """ - # Assuming tokens is a 1D tensor of token IDs + # Assuming tokens is a 1D tensor of token IDs for token in tokens: # Convert token to a tensor and get embeddings token_tensor = torch.tensor([[token]], device=self.device) + inputs_embeds = self.embed_tokens(token_tensor) + + if DEBUG >= 2: + print(f"Initial input embeddings shape: {inputs_embeds.shape}") # Prefill with tokens for layer in self.layers: - _ = layer( - token_tensor, + layer_outputs = layer( + inputs_embeds, use_cache=True, output_attentions=False, ) - + inputs_embeds = layer_outputs[0] + + if DEBUG >= 2: + print(f"Layer output shape: {inputs_embeds.shape}") + # Increment start position start_pos += 1 - + return start_pos def forward_layers(self, input_ids, past_key_values=None): @@ -62,19 +70,11 @@ def forward_layers(self, input_ids, past_key_values=None): # Token embeddings hidden_states = self.embed_tokens(input_ids) - - # Generate position ids - seq_length = input_ids.shape[1] - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand(input_ids.shape) - # Apply each layer in this shard new_past_key_values = [] for i, layer in enumerate(self.layers): layer_outputs = layer( hidden_states, - attention_mask=None, - position_ids=position_ids, past_key_value=past_key_values[i], use_cache=True, output_attentions=False, @@ -82,12 +82,4 @@ def forward_layers(self, input_ids, past_key_values=None): hidden_states = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) - if self.shard.is_last_layer(): - hidden_states = self.norm(hidden_states) - logits = self.lm_head(hidden_states) - return logits, new_past_key_values - else: - return hidden_states, new_past_key_values - - def is_last_layer(self): - return self.shard.is_last_layer() \ No newline at end of file + return hidden_states, new_past_key_values \ No newline at end of file From 90eb294b4a611b3f0dff39f6a1c5d3d5aa12b808 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 20:53:52 -0800 Subject: [PATCH 087/491] updating hf --- exo/inference/pytorch/model/hf.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d87235fb..f7e946a2 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -9,7 +9,7 @@ def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() if DEBUG >= 2: - print(f"ShardedHuggingFaceModel init with shard {shard}") + print(f"\nShardedHuggingFaceModel init with shard {shard}") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard @@ -40,21 +40,22 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: # Convert token to a tensor and get embeddings token_tensor = torch.tensor([[token]], device=self.device) inputs_embeds = self.embed_tokens(token_tensor) - if DEBUG >= 2: - print(f"Initial input embeddings shape: {inputs_embeds.shape}") + print(f"\nInitial input embeddings shape: {inputs_embeds.shape}") # Debugging # Prefill with tokens + position_ids = torch.arange(start_pos, start_pos + 1, dtype=torch.long, device=self.device).unsqueeze(0) for layer in self.layers: layer_outputs = layer( inputs_embeds, + position_ids=position_ids, use_cache=True, output_attentions=False, ) inputs_embeds = layer_outputs[0] if DEBUG >= 2: - print(f"Layer output shape: {inputs_embeds.shape}") + print(f"\nLayer output shape: {inputs_embeds.shape}") # Debugging # Increment start position start_pos += 1 @@ -82,4 +83,17 @@ def forward_layers(self, input_ids, past_key_values=None): hidden_states = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) - return hidden_states, new_past_key_values \ No newline at end of file + return hidden_states, new_past_key_values + + def forward(self, input_ids, past_key_values=None): + """ + Forward pass through the model. + """ + hidden_states = self.prefill(input_ids) + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + + if DEBUG >= 2: + print(f"\nLogits shape: {logits.shape}") # Debugging + return logits + \ No newline at end of file From 49cf8e3e5ad1d9c455047db39db991940b791efc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 20:58:28 -0800 Subject: [PATCH 088/491] updating hf --- exo/inference/pytorch/inference.py | 2 +- exo/inference/pytorch/model/hf.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f8468ac8..9d4feb75 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -72,7 +72,7 @@ async def infer_prompt( start_pos = self.model.prefill( torch.tensor(toks[:-1], device=self.device), start_pos=start_pos) - output_data, past_key_values = self.model.forward_layers(toks[:, -1:], past_key_values=past_key_values) + output_data, past_key_values = self.model(toks[:, -1:], past_key_values=past_key_values) output_data = output_data.detach().cpu().numpy() is_finished = output_data.shape[1] == 1 and output_data[0, 0, -1] == self.tokenizer.eos_token_id diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index f7e946a2..71df1f5d 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -89,11 +89,10 @@ def forward(self, input_ids, past_key_values=None): """ Forward pass through the model. """ - hidden_states = self.prefill(input_ids) + hidden_states, new_past_key_values = self.forward_layers(input_ids, past_key_values) hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) if DEBUG >= 2: print(f"\nLogits shape: {logits.shape}") # Debugging - return logits - \ No newline at end of file + return logits, new_past_key_values From 9e7948f1a0819f1279e31236dfcb59340b26b358 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 21:07:51 -0800 Subject: [PATCH 089/491] working on none issue --- exo/inference/pytorch/inference.py | 99 +++++++++++------------------- 1 file changed, 37 insertions(+), 62 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9d4feb75..59c6b3cf 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -32,12 +32,13 @@ def __init__(self, debug: bool = False): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard] = None, - prompt: str = "", - image_str: Optional[str] = None, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + self, + request_id: str, + shard: Optional[Shard] = None, + prompt: str = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: """ Perform inference based on a text prompt. @@ -51,86 +52,60 @@ async def infer_prompt( Returns: Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. """ - async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard] = None, - prompt: str = "", - image_str: Optional[str] = None, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - - if DEBUG >= 2: - print(f"[{request_id}] Processing prompt: {prompt[:50]}...") + # Ensure the shard is loaded await self.ensure_shard(shard) - toks = self.tokenizer.encode(prompt) - state = json.loads(inference_state) if inference_state else {} - start_pos = state.get("start_pos", 0) - past_key_values = self._load_kv_cache(state.get("past_key_values")) + # Tokenize the prompt + toks = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + + # Load the past key values from the inference state if available + past_key_values = self._load_kv_cache(inference_state) + + # Prefill the model with tokens + start_pos = self.model.prefill(toks.squeeze()) - start_pos = self.model.prefill( - torch.tensor(toks[:-1], device=self.device), start_pos=start_pos) - - output_data, past_key_values = self.model(toks[:, -1:], past_key_values=past_key_values) - output_data = output_data.detach().cpu().numpy() + # Run the forward pass through the model layers + output_data, past_key_values = self.model.forward_layers(toks[:, -1:], past_key_values=past_key_values) - is_finished = output_data.shape[1] == 1 and output_data[0, 0, -1] == self.tokenizer.eos_token_id - new_state = { - "start_pos": start_pos, - "past_key_values": self._save_kv_cache(past_key_values) - } - new_inference_state = json.dumps(new_state) + # Save the past key values to the inference state + new_inference_state = self._save_kv_cache(past_key_values) + + is_finished = False # Assuming a mechanism to determine if the sequence is finished if DEBUG >= 2: - print(f"[{request_id}] Output size: {output_data.size}, Is finished: {is_finished}") + print(f"Output data: {output_data}, new inference state: {new_inference_state}, finished: {is_finished}") return output_data, new_inference_state, is_finished - async def infer_tensor( - self, - request_id: str, - shard: Optional[Shard] = None, - input_data: np.ndarray = None, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + self, + input_tensor: torch.Tensor, + shard: Optional[Shard] = None, + past_key_values: Optional[list] = None + ) -> Tuple[torch.Tensor, list]: """ - Perform inference based on an input tensor. + Perform inference based on a tensor input. Args: - request_id (str): Unique identifier for the request. + input_tensor (torch.Tensor): The input tensor for inference. shard (Optional[Shard]): Shard information for the model. - input_data (np.ndarray): The input tensor for inference. - inference_state (Optional[str]): The previous inference state. + past_key_values (Optional[list]): The previous inference state. Returns: - Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. + Tuple[torch.Tensor, list]: The output tensor and new inference state. """ - await self.ensure_shard(shard) - - if DEBUG >= 2: - print(f"[{request_id}] Processing tensor input, shape: {input_data.shape}") - - input_tensor = torch.tensor(input_data).unsqueeze(0).to(self.device) - state = json.loads(inference_state) if inference_state else {} - start_pos = state.get("start_pos", 0) - past_key_values = self._load_kv_cache(state.get("past_key_values")) + # Ensure the shard is loaded + await self.ensure_shard(shard) + # Run the forward pass through the model layers output_data, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) - output_data = output_data.detach().cpu().numpy() - - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] - new_state = { - "start_pos": start_pos + 1, - "past_key_values": self._save_kv_cache(past_key_values) - } - new_inference_state = json.dumps(new_state) if DEBUG >= 2: - print(f"[{request_id}] Output size: {output_data.size}, Is finished: {is_finished}") + print(f"Output data shape: {output_data.shape}") - return output_data, new_inference_state, is_finished + return output_data, past_key_values def _load_kv_cache(self, past_key_values_list): """ From b0320e4fa414b9a6218fbc824260bacbfd233e5d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 21:10:16 -0800 Subject: [PATCH 090/491] working on forward_layers --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 71df1f5d..2a5a953a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -76,7 +76,7 @@ def forward_layers(self, input_ids, past_key_values=None): for i, layer in enumerate(self.layers): layer_outputs = layer( hidden_states, - past_key_value=past_key_values[i], + past_key_value=past_key_values[i] if len(past_key_values) > 0 else None, use_cache=True, output_attentions=False, ) From 6464b40f1e6d01c9df5033cc8d6e45982196e81c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 21:29:07 -0800 Subject: [PATCH 091/491] working on forward_layers bug --- exo/inference/pytorch/inference.py | 30 ++++++++++++++++++++++-------- exo/inference/pytorch/model/hf.py | 16 +++++++++++++--- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 59c6b3cf..25456d38 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -71,19 +71,23 @@ async def infer_prompt( # Save the past key values to the inference state new_inference_state = self._save_kv_cache(past_key_values) - is_finished = False # Assuming a mechanism to determine if the sequence is finished + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: print(f"Output data: {output_data}, new inference state: {new_inference_state}, finished: {is_finished}") - return output_data, new_inference_state, is_finished + return ( + output_data, + json.dumps({"start_pos": start_pos}), + is_finished + ) async def infer_tensor( self, - input_tensor: torch.Tensor, - shard: Optional[Shard] = None, - past_key_values: Optional[list] = None - ) -> Tuple[torch.Tensor, list]: + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: """ Perform inference based on a tensor input. @@ -98,14 +102,24 @@ async def infer_tensor( # Ensure the shard is loaded await self.ensure_shard(shard) + start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 # Run the forward pass through the model layers - output_data, past_key_values = self.model.forward_layers(input_tensor, past_key_values=past_key_values) + output_data, past_key_values = self.model.forward_layers( + input_data, + past_key_values=past_key_values + ) + + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: print(f"Output data shape: {output_data.shape}") - return output_data, past_key_values + return ( + output_data, + json.dumps({"start_pos": start_pos}), + is_finished + ) def _load_kv_cache(self, past_key_values_list): """ diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2a5a953a..10ddc756 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -41,7 +41,7 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: token_tensor = torch.tensor([[token]], device=self.device) inputs_embeds = self.embed_tokens(token_tensor) if DEBUG >= 2: - print(f"\nInitial input embeddings shape: {inputs_embeds.shape}") # Debugging + print(f"\nprefill shape: {inputs_embeds.shape}") # Debugging # Prefill with tokens position_ids = torch.arange(start_pos, start_pos + 1, dtype=torch.long, device=self.device).unsqueeze(0) @@ -62,7 +62,7 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: return start_pos - def forward_layers(self, input_ids, past_key_values=None): + def forward_layers(self, start_pos, input_ids, past_key_values=None): """ Forward pass through the specified layers. """ @@ -72,19 +72,29 @@ def forward_layers(self, input_ids, past_key_values=None): # Token embeddings hidden_states = self.embed_tokens(input_ids) + # Initialize position_ids + position_ids = torch.arange(start_pos, start_pos + input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0) + new_past_key_values = [] for i, layer in enumerate(self.layers): + # Get past key value if available + past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None + + # Forward pass through the layer layer_outputs = layer( hidden_states, - past_key_value=past_key_values[i] if len(past_key_values) > 0 else None, + position_ids=position_ids, + past_key_value=past_key_value, use_cache=True, output_attentions=False, ) + hidden_states = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) return hidden_states, new_past_key_values + def forward(self, input_ids, past_key_values=None): """ Forward pass through the model. From b7fd0c197bea29efdac2a707f64c09ee91a7f255 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:00:26 -0800 Subject: [PATCH 092/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/inference.py | 39 +++++++++--------------------- exo/inference/pytorch/model/hf.py | 36 ++++++++++++++++++--------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 25456d38..2594c3ca 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -39,19 +39,6 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - """ - Perform inference based on a text prompt. - - Args: - request_id (str): Unique identifier for the request. - shard (Optional[Shard]): Shard information for the model. - prompt (str): The input text prompt for inference. - image_str (Optional[str]): Optional image string for multi-modal models. - inference_state (Optional[str]): The previous inference state. - - Returns: - Tuple[np.ndarray, str, bool]: The output data, new inference state, and end-of-sequence flag. - """ # Ensure the shard is loaded await self.ensure_shard(shard) @@ -63,10 +50,18 @@ async def infer_prompt( past_key_values = self._load_kv_cache(inference_state) # Prefill the model with tokens - start_pos = self.model.prefill(toks.squeeze()) + start_pos = self.model.prefill(toks[:-1]) + last_token = toks[-1] # Run the forward pass through the model layers - output_data, past_key_values = self.model.forward_layers(toks[:, -1:], past_key_values=past_key_values) + output_data, past_key_values = self.model.forward_layers( + start_pos, + torch.tensor( + last_token, + device=self.device + ), + past_key_values=past_key_values + ) # Save the past key values to the inference state new_inference_state = self._save_kv_cache(past_key_values) @@ -88,17 +83,6 @@ async def infer_tensor( shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - """ - Perform inference based on a tensor input. - - Args: - input_tensor (torch.Tensor): The input tensor for inference. - shard (Optional[Shard]): Shard information for the model. - past_key_values (Optional[list]): The previous inference state. - - Returns: - Tuple[torch.Tensor, list]: The output tensor and new inference state. - """ # Ensure the shard is loaded await self.ensure_shard(shard) @@ -106,7 +90,8 @@ async def infer_tensor( # Run the forward pass through the model layers output_data, past_key_values = self.model.forward_layers( - input_data, + start_pos, + torch.tensor(input_data), past_key_values=past_key_values ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 10ddc756..b0938a03 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -31,7 +31,7 @@ def __init__(self, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head - def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: + def prefill(self, tokens: list[int], start_pos: int=0) -> int: """ Process the initial input tokens and set up the initial hidden states. """ @@ -39,7 +39,10 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: for token in tokens: # Convert token to a tensor and get embeddings token_tensor = torch.tensor([[token]], device=self.device) - inputs_embeds = self.embed_tokens(token_tensor) + + if self.shard.is_first_layer(): + token_tensor = self.embed_tokens(token_tensor) + if DEBUG >= 2: print(f"\nprefill shape: {inputs_embeds.shape}") # Debugging @@ -62,37 +65,48 @@ def prefill(self, tokens: torch.tensor, start_pos: int=0) -> int: return start_pos - def forward_layers(self, start_pos, input_ids, past_key_values=None): + def forward_layers( + self, + start_pos: int, + in_tensor: torch.tensor, + past_key_values=None + ) -> Tuple[any, list]: + """ Forward pass through the specified layers. """ if past_key_values is None: past_key_values = [None] * len(self.layers) - # Token embeddings - hidden_states = self.embed_tokens(input_ids) - # Initialize position_ids - position_ids = torch.arange(start_pos, start_pos + input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0) + position_ids = torch.arange( + start_pos, + start_pos + in_tensor.size(1), + dtype=torch.long, + device=in_tensor.device + ).unsqueeze(0) new_past_key_values = [] for i, layer in enumerate(self.layers): # Get past key value if available - past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None + if past_key_values and len(past_key_values) > 0: + past_key_value = past_key_values[i] + else: + past_key_value = None # Forward pass through the layer layer_outputs = layer( - hidden_states, + layer_out, position_ids=position_ids, past_key_value=past_key_value, use_cache=True, output_attentions=False, ) - hidden_states = layer_outputs[0] + layer_out = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) - return hidden_states, new_past_key_values + return layer_out, new_past_key_values def forward(self, input_ids, past_key_values=None): From aba82da8d489d2be8bf3324ddc496defca241c5d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:09:01 -0800 Subject: [PATCH 093/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b0938a03..819852e2 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -44,21 +44,10 @@ def prefill(self, tokens: list[int], start_pos: int=0) -> int: token_tensor = self.embed_tokens(token_tensor) if DEBUG >= 2: - print(f"\nprefill shape: {inputs_embeds.shape}") # Debugging + print(f"\ntoken_tensor shape: {token_tensor.shape}") # Prefill with tokens - position_ids = torch.arange(start_pos, start_pos + 1, dtype=torch.long, device=self.device).unsqueeze(0) - for layer in self.layers: - layer_outputs = layer( - inputs_embeds, - position_ids=position_ids, - use_cache=True, - output_attentions=False, - ) - inputs_embeds = layer_outputs[0] - - if DEBUG >= 2: - print(f"\nLayer output shape: {inputs_embeds.shape}") # Debugging + self.forward_layers(start_pos, token_tensor, None) # Increment start position start_pos += 1 From 46ae2c07e1696bc4266f5454a5685bb5d31ea65e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:16:42 -0800 Subject: [PATCH 094/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 819852e2..bb48b054 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -69,13 +69,15 @@ def forward_layers( # Initialize position_ids position_ids = torch.arange( - start_pos, - start_pos + in_tensor.size(1), - dtype=torch.long, - device=in_tensor.device - ).unsqueeze(0) + start_pos, + start_pos + len(past_key_values), + dtype=torch.long, + device=in_tensor + ) + position_ids = position_ids.unsqueeze(0) new_past_key_values = [] + out_tensor = None for i, layer in enumerate(self.layers): # Get past key value if available if past_key_values and len(past_key_values) > 0: @@ -85,17 +87,17 @@ def forward_layers( # Forward pass through the layer layer_outputs = layer( - layer_out, + in_tensor if not out_tensor else out_tensor, position_ids=position_ids, past_key_value=past_key_value, use_cache=True, output_attentions=False, ) - layer_out = layer_outputs[0] + out_tensor = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) - return layer_out, new_past_key_values + return out_tensor, new_past_key_values def forward(self, input_ids, past_key_values=None): From 5c5b6a40533828b5a06238f4b2121cd3012f7c4c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:23:04 -0800 Subject: [PATCH 095/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index bb48b054..e3c6ef95 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -72,7 +72,7 @@ def forward_layers( start_pos, start_pos + len(past_key_values), dtype=torch.long, - device=in_tensor + device=self.device ) position_ids = position_ids.unsqueeze(0) From 7ec1aefa8e99c6114123f519345541a644119194 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:32:42 -0800 Subject: [PATCH 096/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/inference.py | 10 +++++----- exo/inference/pytorch/model/hf.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 2594c3ca..2d10397f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -51,15 +51,15 @@ async def infer_prompt( # Prefill the model with tokens start_pos = self.model.prefill(toks[:-1]) - last_token = toks[-1] + last_token = torch.tensor( + toks[-1], + device=self.device + ) # Run the forward pass through the model layers output_data, past_key_values = self.model.forward_layers( start_pos, - torch.tensor( - last_token, - device=self.device - ), + last_token, past_key_values=past_key_values ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e3c6ef95..055dd2c8 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -22,6 +22,7 @@ def __init__(self, shard: Shard): ) # Extract only the layers for this shard + print(f"\nself.model: {self.model}\n") self.layers = torch.nn.ModuleList([ self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) From cdc7efba7eb8d15e5b995f532049ea62f451d70a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:33:52 -0800 Subject: [PATCH 097/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 055dd2c8..6c6b8e62 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -22,7 +22,7 @@ def __init__(self, shard: Shard): ) # Extract only the layers for this shard - print(f"\nself.model: {self.model}\n") + print(f"\nself.model: {self.full_model.model}\n") self.layers = torch.nn.ModuleList([ self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) ]) From 6934d354769c8bfa5c2a66f3abc1af686471a179 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:41:10 -0800 Subject: [PATCH 098/491] cleaning up code, fixing passing of tensors --- exo/api/chatgpt_api.py | 2 +- exo/inference/pytorch/model/hf.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 9ed4a47a..fb041b3c 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -16,7 +16,7 @@ ### llama "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=12), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 6c6b8e62..8d2c2aad 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -40,9 +40,7 @@ def prefill(self, tokens: list[int], start_pos: int=0) -> int: for token in tokens: # Convert token to a tensor and get embeddings token_tensor = torch.tensor([[token]], device=self.device) - - if self.shard.is_first_layer(): - token_tensor = self.embed_tokens(token_tensor) + token_tensor = self.embed_tokens(token_tensor) if DEBUG >= 2: print(f"\ntoken_tensor shape: {token_tensor.shape}") @@ -65,6 +63,10 @@ def forward_layers( """ Forward pass through the specified layers. """ + # embed in_tensor + in_tensor = self.embed_tokens(in_tensor) + + # check past key values if past_key_values is None: past_key_values = [None] * len(self.layers) From 5f87cb666f84fc448d92a9420b19ac3827591956 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:43:06 -0800 Subject: [PATCH 099/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8d2c2aad..6c609cbc 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -33,6 +33,7 @@ def __init__(self, shard: Shard): self.lm_head = self.full_model.lm_head def prefill(self, tokens: list[int], start_pos: int=0) -> int: + print(f"\nprefill called") """ Process the initial input tokens and set up the initial hidden states. """ @@ -64,7 +65,7 @@ def forward_layers( Forward pass through the specified layers. """ # embed in_tensor - in_tensor = self.embed_tokens(in_tensor) + # in_tensor = self.embed_tokens(in_tensor) # check past key values if past_key_values is None: From 8eae07d9d56747bd8c59a05dba854ea07d05a967 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:46:00 -0800 Subject: [PATCH 100/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 6c609cbc..2ac67653 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -64,8 +64,11 @@ def forward_layers( """ Forward pass through the specified layers. """ - # embed in_tensor - # in_tensor = self.embed_tokens(in_tensor) + # embed tensor if first layer + if self.shard.is_first_layer(): + if DEBUG >= 2: + print(f"Embedding first layer in_tensor {in_tensor.shape()}") + in_tensor = self.embed_tokens(in_tensor) # check past key values if past_key_values is None: From f8f8e54de099f82d57244217b7abe665725ab9ab Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:46:58 -0800 Subject: [PATCH 101/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2ac67653..2c902e54 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -67,7 +67,7 @@ def forward_layers( # embed tensor if first layer if self.shard.is_first_layer(): if DEBUG >= 2: - print(f"Embedding first layer in_tensor {in_tensor.shape()}") + print(f"Embedding first layer in_tensor {in_tensor.shape}") in_tensor = self.embed_tokens(in_tensor) # check past key values From 189760a764efdf82e37e256efbba64c65e5cec8e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:49:44 -0800 Subject: [PATCH 102/491] cleaning up code, fixing passing of tensors --- exo/inference/pytorch/model/hf.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2c902e54..e3f00611 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -93,6 +93,9 @@ def forward_layers( past_key_value = None # Forward pass through the layer + if DEBUG >= 2: + print(f"pass tensor to layer[{i}] {layer}") + layer_outputs = layer( in_tensor if not out_tensor else out_tensor, position_ids=position_ids, @@ -105,16 +108,3 @@ def forward_layers( new_past_key_values.append(layer_outputs[1]) return out_tensor, new_past_key_values - - - def forward(self, input_ids, past_key_values=None): - """ - Forward pass through the model. - """ - hidden_states, new_past_key_values = self.forward_layers(input_ids, past_key_values) - hidden_states = self.norm(hidden_states) - logits = self.lm_head(hidden_states) - - if DEBUG >= 2: - print(f"\nLogits shape: {logits.shape}") # Debugging - return logits, new_past_key_values From 704da617e7f87e19c8b0dcb763cc2d7d5c2246e0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 22:53:50 -0800 Subject: [PATCH 103/491] messing with layers --- exo/inference/pytorch/model/hf.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e3f00611..3508aed6 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,9 +23,14 @@ def __init__(self, shard: Shard): # Extract only the layers for this shard print(f"\nself.model: {self.full_model.model}\n") - self.layers = torch.nn.ModuleList([ - self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) - ]) + self.layers = [] + for i in range(shard.start_layer, shard.end_layer + 1): + if DEBUG >= 2: + print(f"layer[{i}]: {self.full_model.model.layers[i]}") + + self.layers.append(self.full_model.model.layers[i]) + + # self.layers = torch.nn.ModuleList(layer_list) # Embeddings and final layer norm self.embed_tokens = self.full_model.model.embed_tokens @@ -95,7 +100,7 @@ def forward_layers( # Forward pass through the layer if DEBUG >= 2: print(f"pass tensor to layer[{i}] {layer}") - + layer_outputs = layer( in_tensor if not out_tensor else out_tensor, position_ids=position_ids, From 9bad63f8ec92b5bd8b601d55c8b7105c8795cc58 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 23:02:01 -0800 Subject: [PATCH 104/491] messing with layers --- exo/inference/pytorch/model/hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 3508aed6..006447a5 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,6 +23,7 @@ def __init__(self, shard: Shard): # Extract only the layers for this shard print(f"\nself.model: {self.full_model.model}\n") + print(f"\nlayer amount: {len(self.full_model.model.layers)}") self.layers = [] for i in range(shard.start_layer, shard.end_layer + 1): if DEBUG >= 2: From 2ecb3b749953d87ca6b9557c2f40d760b0c2273d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 23:23:23 -0800 Subject: [PATCH 105/491] messing with layers --- exo/inference/pytorch/inference.py | 24 +++------- exo/inference/pytorch/model/hf.py | 77 ++++++++++++++---------------- 2 files changed, 42 insertions(+), 59 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 2d10397f..d20755a1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,38 +44,26 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - toks = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) # Load the past key values from the inference state if available past_key_values = self._load_kv_cache(inference_state) - # Prefill the model with tokens - start_pos = self.model.prefill(toks[:-1]) - last_token = torch.tensor( - toks[-1], - device=self.device - ) - # Run the forward pass through the model layers output_data, past_key_values = self.model.forward_layers( - start_pos, - last_token, + input_ids=toks, past_key_values=past_key_values ) # Save the past key values to the inference state - new_inference_state = self._save_kv_cache(past_key_values) + self._save_kv_cache(past_key_values) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = False # Assuming a mechanism to determine if the sequence is finished if DEBUG >= 2: - print(f"Output data: {output_data}, new inference state: {new_inference_state}, finished: {is_finished}") + print(f"Output data: {output_data}, new inference state: {past_key_values}, finished: {is_finished}") - return ( - output_data, - json.dumps({"start_pos": start_pos}), - is_finished - ) + return output_data, "", is_finished async def infer_tensor( self, diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 006447a5..bca8ce6f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -38,79 +38,74 @@ def __init__(self, shard: Shard): self.norm = self.full_model.model.norm self.lm_head = self.full_model.lm_head - def prefill(self, tokens: list[int], start_pos: int=0) -> int: - print(f"\nprefill called") - """ - Process the initial input tokens and set up the initial hidden states. - """ - # Assuming tokens is a 1D tensor of token IDs - for token in tokens: - # Convert token to a tensor and get embeddings - token_tensor = torch.tensor([[token]], device=self.device) - token_tensor = self.embed_tokens(token_tensor) + # def prefill(self, tokens: list[int], start_pos: int=0) -> int: + # print(f"\nprefill called") + # """ + # Process the initial input tokens and set up the initial hidden states. + # """ + # # Assuming tokens is a 1D tensor of token IDs + # for token in tokens: + # # Convert token to a tensor and get embeddings + # token_tensor = torch.tensor([[token]], device=self.device) + # token_tensor = self.embed_tokens(token_tensor) - if DEBUG >= 2: - print(f"\ntoken_tensor shape: {token_tensor.shape}") + # if DEBUG >= 2: + # print(f"\ntoken_tensor shape: {token_tensor.shape}") - # Prefill with tokens - self.forward_layers(start_pos, token_tensor, None) + # # Prefill with tokens + # self.forward_layers(start_pos, token_tensor, None) - # Increment start position - start_pos += 1 + # # Increment start position + # start_pos += 1 - return start_pos + # return start_pos def forward_layers( - self, - start_pos: int, - in_tensor: torch.tensor, - past_key_values=None - ) -> Tuple[any, list]: - + self, + input_ids: torch.tensor, + past_key_values=None + ) -> Tuple[any, list]: """ Forward pass through the specified layers. """ - # embed tensor if first layer + # Embed tensor if first layer if self.shard.is_first_layer(): if DEBUG >= 2: - print(f"Embedding first layer in_tensor {in_tensor.shape}") - in_tensor = self.embed_tokens(in_tensor) + print(f"Embedding first layer input_ids {input_ids.shape}") + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_ids - # check past key values + # Check past key values if past_key_values is None: past_key_values = [None] * len(self.layers) # Initialize position_ids position_ids = torch.arange( - start_pos, - start_pos + len(past_key_values), - dtype=torch.long, + hidden_states.size(1), + dtype=torch.long, device=self.device - ) - position_ids = position_ids.unsqueeze(0) + ).unsqueeze(0) new_past_key_values = [] - out_tensor = None for i, layer in enumerate(self.layers): # Get past key value if available - if past_key_values and len(past_key_values) > 0: - past_key_value = past_key_values[i] - else: - past_key_value = None + past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None # Forward pass through the layer if DEBUG >= 2: - print(f"pass tensor to layer[{i}] {layer}") + print(f"Pass tensor to layer[{i}] {layer}") layer_outputs = layer( - in_tensor if not out_tensor else out_tensor, + hidden_states, position_ids=position_ids, past_key_value=past_key_value, use_cache=True, output_attentions=False, ) - out_tensor = layer_outputs[0] + hidden_states = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) - return out_tensor, new_past_key_values + return hidden_states, new_past_key_values + From d22439e988e758e17a8dc8af234f991d816bde0c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 23:27:44 -0800 Subject: [PATCH 106/491] fixing kv save --- exo/inference/pytorch/inference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d20755a1..862da6ee 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -120,7 +120,13 @@ def _save_kv_cache(self, past_key_values): """ if past_key_values is None: return [] - return [kv.cpu().tolist() for kv in past_key_values] + + new_cache = [] + for kv in past_key_values: + if kv: + new_cache.append(kv.cpu().tolist()) + + return new_cache async def ensure_shard(self, shard: Optional[Shard]): """ From 7c54803e9fd9f4938909a7c45ed96b9db78e5c73 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 23:45:13 -0800 Subject: [PATCH 107/491] fixing last layer issue --- exo/inference/pytorch/model/hf.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index bca8ce6f..d13c5783 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -107,5 +107,10 @@ def forward_layers( hidden_states = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) - return hidden_states, new_past_key_values + if self.shard.is_last_layer(): + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states).float()[:, -1, :] + return logits + else: + return hidden_states, new_past_key_values From a2478897b3182efdbc0ace84d053526b848d97f1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 7 Aug 2024 23:46:40 -0800 Subject: [PATCH 108/491] fixing last layer issue --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d13c5783..b4a4c90a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -110,7 +110,7 @@ def forward_layers( if self.shard.is_last_layer(): hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states).float()[:, -1, :] - return logits + return logits, new_past_key_values else: return hidden_states, new_past_key_values From d1ea73a95e8e40de055da6b78a6077b879b9b171 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 00:01:18 -0800 Subject: [PATCH 109/491] fixing last layer issue --- exo/inference/pytorch/model/hf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b4a4c90a..acff9756 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -94,7 +94,7 @@ def forward_layers( # Forward pass through the layer if DEBUG >= 2: - print(f"Pass tensor to layer[{i}] {layer}") + print(f"\nPass tensor to layer[{i}] {layer}") layer_outputs = layer( hidden_states, @@ -103,6 +103,9 @@ def forward_layers( use_cache=True, output_attentions=False, ) + + if DEBUG >= 2: + print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs[0] new_past_key_values.append(layer_outputs[1]) From 3a36dfb8d490345bd665d1a7e6257530de57c38b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 00:07:15 -0800 Subject: [PATCH 110/491] fixing last layer issue --- exo/inference/pytorch/model/hf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index acff9756..73d8aab8 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -99,9 +99,8 @@ def forward_layers( layer_outputs = layer( hidden_states, position_ids=position_ids, - past_key_value=past_key_value, - use_cache=True, - output_attentions=False, + # past_key_value=past_key_value, + use_cache=True ) if DEBUG >= 2: From 9596569a55ec793f388f9e72406d78374d67473b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 00:13:39 -0800 Subject: [PATCH 111/491] fixing last layer issue --- exo/inference/pytorch/model/hf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 73d8aab8..63d7b5f9 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -110,8 +110,7 @@ def forward_layers( new_past_key_values.append(layer_outputs[1]) if self.shard.is_last_layer(): - hidden_states = self.norm(hidden_states) - logits = self.lm_head(hidden_states).float()[:, -1, :] + _, logits, _, _, = self.full_model(hidden_states, position_ids=position_ids) return logits, new_past_key_values else: return hidden_states, new_past_key_values From 60aa203d9b8a49c7ccf376c46f21245525d483c2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:27:47 -0800 Subject: [PATCH 112/491] removed new past key values from llama hf model, returning None from model LlamaDecoderLayer, updating infer output to numpy array --- exo/inference/pytorch/inference.py | 28 ++++++++++++++++------------ exo/inference/pytorch/model/hf.py | 22 ++++++++++++---------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 862da6ee..15c7c82f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,6 +1,6 @@ # experimental, based off of tinygrad/inference.py -import json +import numpy as np import torch import numpy as np from typing import Optional, Callable, Tuple @@ -50,20 +50,25 @@ async def infer_prompt( past_key_values = self._load_kv_cache(inference_state) # Run the forward pass through the model layers - output_data, past_key_values = self.model.forward_layers( + # output_data, past_key_values + output_data = self.model.forward_layers( input_ids=toks, - past_key_values=past_key_values + # past_key_values=past_key_values ) # Save the past key values to the inference state self._save_kv_cache(past_key_values) - is_finished = False # Assuming a mechanism to determine if the sequence is finished + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: print(f"Output data: {output_data}, new inference state: {past_key_values}, finished: {is_finished}") - return output_data, "", is_finished + return ( + np.array(output_data), + "", + is_finished + ) async def infer_tensor( self, @@ -74,13 +79,12 @@ async def infer_tensor( # Ensure the shard is loaded await self.ensure_shard(shard) - start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0 # Run the forward pass through the model layers - output_data, past_key_values = self.model.forward_layers( - start_pos, - torch.tensor(input_data), - past_key_values=past_key_values + # output_data, past_key_values + output_data = self.model.forward_layers( + input_ids=torch.tensor(input_data), + # past_key_values=past_key_values ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] @@ -89,8 +93,8 @@ async def infer_tensor( print(f"Output data shape: {output_data.shape}") return ( - output_data, - json.dumps({"start_pos": start_pos}), + np.array(output_data), + "", is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 63d7b5f9..4309734c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -63,10 +63,12 @@ def __init__(self, shard: Shard): def forward_layers( self, input_ids: torch.tensor, - past_key_values=None - ) -> Tuple[any, list]: + #past_key_values: list + ) -> torch.tensor: #-> Tuple[torch.tensor, list]: """ Forward pass through the specified layers. + + Note: past_key_values not working for model, might be a library bug """ # Embed tensor if first layer if self.shard.is_first_layer(): @@ -77,8 +79,8 @@ def forward_layers( hidden_states = input_ids # Check past key values - if past_key_values is None: - past_key_values = [None] * len(self.layers) + # if past_key_values is None: + # past_key_values = [None] * len(self.layers) # Initialize position_ids position_ids = torch.arange( @@ -87,10 +89,10 @@ def forward_layers( device=self.device ).unsqueeze(0) - new_past_key_values = [] + #new_past_key_values = [] for i, layer in enumerate(self.layers): # Get past key value if available - past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None + # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None # Forward pass through the layer if DEBUG >= 2: @@ -100,18 +102,18 @@ def forward_layers( hidden_states, position_ids=position_ids, # past_key_value=past_key_value, - use_cache=True + # use_cache=True ) if DEBUG >= 2: print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs[0] - new_past_key_values.append(layer_outputs[1]) + # new_past_key_values.append(layer_outputs[1]) if self.shard.is_last_layer(): _, logits, _, _, = self.full_model(hidden_states, position_ids=position_ids) - return logits, new_past_key_values + return logits #, new_past_key_values else: - return hidden_states, new_past_key_values + return hidden_states#, new_past_key_values From 217ad64792ef2a01199a807208c6ac8626ab1c40 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:32:50 -0800 Subject: [PATCH 113/491] fixing logits --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 4309734c..7e7c8671 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -112,8 +112,8 @@ def forward_layers( # new_past_key_values.append(layer_outputs[1]) if self.shard.is_last_layer(): - _, logits, _, _, = self.full_model(hidden_states, position_ids=position_ids) - return logits #, new_past_key_values + logits = self.full_model.model.norm(hidden_states) + return logits.flatten() #, new_past_key_values else: return hidden_states#, new_past_key_values From 7974f23057230f23f2b1dd45d83b1148753c01b3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:40:29 -0800 Subject: [PATCH 114/491] offloading tensor to numpy with cpu --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 15c7c82f..c2e2faf9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -65,7 +65,7 @@ async def infer_prompt( print(f"Output data: {output_data}, new inference state: {past_key_values}, finished: {is_finished}") return ( - np.array(output_data), + output_data.cpu().numpy(), "", is_finished ) @@ -93,7 +93,7 @@ async def infer_tensor( print(f"Output data shape: {output_data.shape}") return ( - np.array(output_data), + output_data.cpu().numpy(), "", is_finished ) From 7ea9c5bc4e3420b059b3c757deef9937e1a1b8c2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:44:23 -0800 Subject: [PATCH 115/491] offloading tensor to numpy with cpu --- exo/inference/pytorch/inference.py | 56 +++++++++++++++--------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index c2e2faf9..40edb276 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -65,7 +65,7 @@ async def infer_prompt( print(f"Output data: {output_data}, new inference state: {past_key_values}, finished: {is_finished}") return ( - output_data.cpu().numpy(), + output_data.detach().numpy(), "", is_finished ) @@ -93,44 +93,44 @@ async def infer_tensor( print(f"Output data shape: {output_data.shape}") return ( - output_data.cpu().numpy(), + output_data.detach().numpy(), "", is_finished ) - def _load_kv_cache(self, past_key_values_list): - """ - Load key-value cache from the inference state. + # def _load_kv_cache(self, past_key_values_list): + # """ + # Load key-value cache from the inference state. - Args: - past_key_values_list (list): List of past key-value tensors. + # Args: + # past_key_values_list (list): List of past key-value tensors. - Returns: - list: List of loaded past key-value tensors. - """ - if past_key_values_list is None: - return [] - return [torch.tensor(kv, device=self.device) for kv in past_key_values_list] + # Returns: + # list: List of loaded past key-value tensors. + # """ + # if past_key_values_list is None: + # return [] + # return [torch.tensor(kv, device=self.device) for kv in past_key_values_list] - def _save_kv_cache(self, past_key_values): - """ - Save key-value cache to the inference state. + # def _save_kv_cache(self, past_key_values): + # """ + # Save key-value cache to the inference state. - Args: - past_key_values (list): List of past key-value tensors. + # Args: + # past_key_values (list): List of past key-value tensors. - Returns: - list: List of key-value tensors in a format suitable for saving. - """ - if past_key_values is None: - return [] + # Returns: + # list: List of key-value tensors in a format suitable for saving. + # """ + # if past_key_values is None: + # return [] - new_cache = [] - for kv in past_key_values: - if kv: - new_cache.append(kv.cpu().tolist()) + # new_cache = [] + # for kv in past_key_values: + # if kv: + # new_cache.append(kv.cpu().tolist()) - return new_cache + # return new_cache async def ensure_shard(self, shard: Optional[Shard]): """ From 222612a9714e068924f0d32d87331cee853e7197 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:45:46 -0800 Subject: [PATCH 116/491] removing inference state and key value storage --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 40edb276..0ca9f768 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -47,7 +47,7 @@ async def infer_prompt( toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) # Load the past key values from the inference state if available - past_key_values = self._load_kv_cache(inference_state) + # past_key_values = self._load_kv_cache(inference_state) # Run the forward pass through the model layers # output_data, past_key_values @@ -57,7 +57,7 @@ async def infer_prompt( ) # Save the past key values to the inference state - self._save_kv_cache(past_key_values) + # self._save_kv_cache(past_key_values) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] From 8478ca38ea25d26c6330177a4ccdcacb62ee22cb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:47:52 -0800 Subject: [PATCH 117/491] fixing debug error --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0ca9f768..a3e4e4a9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -62,7 +62,7 @@ async def infer_prompt( is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: - print(f"Output data: {output_data}, new inference state: {past_key_values}, finished: {is_finished}") + print(f"Output data: {output_data} finished: {is_finished}") return ( output_data.detach().numpy(), From c12b64b2a9eca971c9ef57b22de12280157488e8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:49:30 -0800 Subject: [PATCH 118/491] offloading tensor bug --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a3e4e4a9..501f3f3a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -65,7 +65,7 @@ async def infer_prompt( print(f"Output data: {output_data} finished: {is_finished}") return ( - output_data.detach().numpy(), + np.array(output_data.cpu()), "", is_finished ) @@ -93,7 +93,7 @@ async def infer_tensor( print(f"Output data shape: {output_data.shape}") return ( - output_data.detach().numpy(), + np.array(output_data.cpu()), "", is_finished ) From be8c7d41d7171b3a51fbed88a766e5fc8fe78aae Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:52:49 -0800 Subject: [PATCH 119/491] trying no_grad fix --- exo/inference/pytorch/inference.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 501f3f3a..cf66242f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -64,9 +64,12 @@ async def infer_prompt( if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") + with torch.no_grad(): + output_npa = np.array(output_data.cpu()) + return ( - np.array(output_data.cpu()), - "", + output_npa, + "", is_finished ) @@ -92,8 +95,12 @@ async def infer_tensor( if DEBUG >= 2: print(f"Output data shape: {output_data.shape}") + + with torch.no_grad(): + output_npa = np.array(output_data.cpu()) + return ( - np.array(output_data.cpu()), + output_npa, "", is_finished ) From 237ab341f72674fc1372617a5ce4d0bb3912cec5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:55:36 -0800 Subject: [PATCH 120/491] fixing embed error --- exo/inference/pytorch/model/hf.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7e7c8671..240a8480 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -62,7 +62,7 @@ def __init__(self, shard: Shard): def forward_layers( self, - input_ids: torch.tensor, + hidden_states: torch.tensor, #past_key_values: list ) -> torch.tensor: #-> Tuple[torch.tensor, list]: """ @@ -71,12 +71,12 @@ def forward_layers( Note: past_key_values not working for model, might be a library bug """ # Embed tensor if first layer - if self.shard.is_first_layer(): - if DEBUG >= 2: - print(f"Embedding first layer input_ids {input_ids.shape}") - hidden_states = self.embed_tokens(input_ids) - else: - hidden_states = input_ids + # if self.shard.is_first_layer(): + # if DEBUG >= 2: + # print(f"Embedding first layer input_ids {input_ids.shape}") + # hidden_states = self.embed_tokens(input_ids) + # else: + # hidden_states = input_ids # Check past key values # if past_key_values is None: From 015bd4c8590bc71e9a34bf1c31434b3cc98a9836 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 20:59:07 -0800 Subject: [PATCH 121/491] removing embed --- exo/inference/pytorch/model/hf.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 240a8480..adb7d480 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -73,10 +73,15 @@ def forward_layers( # Embed tensor if first layer # if self.shard.is_first_layer(): # if DEBUG >= 2: - # print(f"Embedding first layer input_ids {input_ids.shape}") - # hidden_states = self.embed_tokens(input_ids) + # print(f"Embedding first layer input_ids {hidden_states.shape}") + + # # flatten to 1d and turn to long + # if hidden_states.dim() > 1: + # hidden_states = hidden_states.view(-1) + # hidden_states = hidden_states.long() + # hidden_states = self.embed_tokens(hidden_states) # else: - # hidden_states = input_ids + # hidden_states = hidden_states # Check past key values # if past_key_values is None: From ace849893385f60e595d569b262c506054fb774c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:06:15 -0800 Subject: [PATCH 122/491] cleaning up tokenizer --- exo/inference/pytorch/inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cf66242f..fbe3d975 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,7 +44,7 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + toks = self.tokenizer.encode(prompt, return_tensors="pt").input_ids.to(self.device) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) @@ -52,7 +52,7 @@ async def infer_prompt( # Run the forward pass through the model layers # output_data, past_key_values output_data = self.model.forward_layers( - input_ids=toks, + torch.tensor(toks), # past_key_values=past_key_values ) @@ -86,7 +86,7 @@ async def infer_tensor( # Run the forward pass through the model layers # output_data, past_key_values output_data = self.model.forward_layers( - input_ids=torch.tensor(input_data), + torch.tensor(input_data), # past_key_values=past_key_values ) From 248c2168cd5f953ac1340c7d7537e33cd62e99af Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:07:40 -0800 Subject: [PATCH 123/491] cleaning up tokenizer --- exo/inference/pytorch/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index fbe3d975..05b40f7f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,7 +44,8 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - toks = self.tokenizer.encode(prompt, return_tensors="pt").input_ids.to(self.device) + toks = self.tokenizer.encode(prompt, return_tensors="pt") + #.input_ids.to(self.device) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) From c6ac9a3eca9c1082306f18b127b08c6da9bb81e9 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:11:53 -0800 Subject: [PATCH 124/491] cleaning up tokenizer --- exo/inference/pytorch/inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 05b40f7f..a1937f83 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,8 +44,7 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - toks = self.tokenizer.encode(prompt, return_tensors="pt") - #.input_ids.to(self.device) + toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) From 24cd9f1f6ff46ef598d9022162d4860a5087f47d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:19:46 -0800 Subject: [PATCH 125/491] cleaning up tokenizer --- exo/inference/pytorch/inference.py | 4 ++++ exo/inference/pytorch/model/hf.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a1937f83..8083520e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -45,12 +45,16 @@ async def infer_prompt( # Tokenize the prompt toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + toks_tensor = torch.tensor(toks) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) # Run the forward pass through the model layers # output_data, past_key_values + if DEBUG >= 2: + print(f"toks: {toks}\ntoks_tensor: {toks_tensor}") + output_data = self.model.forward_layers( torch.tensor(toks), # past_key_values=past_key_values diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index adb7d480..3b94cb29 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -26,8 +26,8 @@ def __init__(self, shard: Shard): print(f"\nlayer amount: {len(self.full_model.model.layers)}") self.layers = [] for i in range(shard.start_layer, shard.end_layer + 1): - if DEBUG >= 2: - print(f"layer[{i}]: {self.full_model.model.layers[i]}") + # if DEBUG >= 2: + # print(f"loading layer[{i}]: {self.full_model.model.layers[i]}") self.layers.append(self.full_model.model.layers[i]) From a2b9d79d291e7ebbc85399babfa40c9562c33dc3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:23:43 -0800 Subject: [PATCH 126/491] cleaning up tokenizer --- exo/inference/pytorch/inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8083520e..0219850c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -45,7 +45,7 @@ async def infer_prompt( # Tokenize the prompt toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) - toks_tensor = torch.tensor(toks) + toks_embed = self.model.embed_tokens(torch.tensor(toks)) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) @@ -53,10 +53,10 @@ async def infer_prompt( # Run the forward pass through the model layers # output_data, past_key_values if DEBUG >= 2: - print(f"toks: {toks}\ntoks_tensor: {toks_tensor}") + print(f"toks: {toks}\ntoks_embed: {toks_embed}") output_data = self.model.forward_layers( - torch.tensor(toks), + toks_embed, # past_key_values=past_key_values ) From 935b1559f1f43c55e79acb91f123d2f19855d8bc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:28:10 -0800 Subject: [PATCH 127/491] removing flattening and logit, just return hidden states --- exo/inference/pytorch/model/hf.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 3b94cb29..7632d092 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -116,9 +116,10 @@ def forward_layers( hidden_states = layer_outputs[0] # new_past_key_values.append(layer_outputs[1]) - if self.shard.is_last_layer(): - logits = self.full_model.model.norm(hidden_states) - return logits.flatten() #, new_past_key_values - else: - return hidden_states#, new_past_key_values + return hidden_states + # if self.shard.is_last_layer(): + # logits = self.full_model.model.norm(hidden_states) + # return logits.flatten() #, new_past_key_values + # else: + # return hidden_states#, new_past_key_values From 922fb40ffd72087313921f8c6ecc6c45f033920a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:34:53 -0800 Subject: [PATCH 128/491] fix for last layer and stating when finished --- exo/inference/pytorch/inference.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0219850c..164130fd 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -55,10 +55,14 @@ async def infer_prompt( if DEBUG >= 2: print(f"toks: {toks}\ntoks_embed: {toks_embed}") - output_data = self.model.forward_layers( - toks_embed, - # past_key_values=past_key_values - ) + if shard.is_last_layer(): + output_data = self.full_model.model.norm(toks_embed) + output_data = output_data.flatten() + else: + output_data = self.model.forward_layers( + toks_embed, + # past_key_values=past_key_values + ) # Save the past key values to the inference state # self._save_kv_cache(past_key_values) @@ -89,10 +93,15 @@ async def infer_tensor( # Run the forward pass through the model layers # output_data, past_key_values - output_data = self.model.forward_layers( - torch.tensor(input_data), - # past_key_values=past_key_values - ) + in_tensor = torch.tensor(input_data) + if shard.is_last_layer(): + output_data = self.full_model.model.norm(in_tensor) + output_data = output_data.flatten() + else: + output_data = self.model.forward_layers( + in_tensor, + # past_key_values=past_key_values + ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] From 5564f00d7500f1a11250508b1022219a23986111 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:43:20 -0800 Subject: [PATCH 129/491] making layer adjustments --- exo/inference/pytorch/inference.py | 10 +++++----- exo/inference/pytorch/model/hf.py | 20 ++++++++------------ 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 164130fd..46d037b7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,8 +44,8 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - toks = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) - toks_embed = self.model.embed_tokens(torch.tensor(toks)) + tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + tokens_tensor = torch.tensor(tokens) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) @@ -53,14 +53,14 @@ async def infer_prompt( # Run the forward pass through the model layers # output_data, past_key_values if DEBUG >= 2: - print(f"toks: {toks}\ntoks_embed: {toks_embed}") + print(f"tokens: {tokens}\ntokens_tensor: {tokens_tensor}") if shard.is_last_layer(): - output_data = self.full_model.model.norm(toks_embed) + output_data = self.full_model.model.norm(tokens_tensor) output_data = output_data.flatten() else: output_data = self.model.forward_layers( - toks_embed, + tokens_tensor, # past_key_values=past_key_values ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7632d092..dd12d049 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -62,7 +62,7 @@ def __init__(self, shard: Shard): def forward_layers( self, - hidden_states: torch.tensor, + input_data: torch.tensor, #past_key_values: list ) -> torch.tensor: #-> Tuple[torch.tensor, list]: """ @@ -71,17 +71,12 @@ def forward_layers( Note: past_key_values not working for model, might be a library bug """ # Embed tensor if first layer - # if self.shard.is_first_layer(): - # if DEBUG >= 2: - # print(f"Embedding first layer input_ids {hidden_states.shape}") + if self.shard.is_first_layer(): + if DEBUG >= 2: + print(f"Embedding for first layer {input_data.shape}") - # # flatten to 1d and turn to long - # if hidden_states.dim() > 1: - # hidden_states = hidden_states.view(-1) - # hidden_states = hidden_states.long() - # hidden_states = self.embed_tokens(hidden_states) - # else: - # hidden_states = hidden_states + # flatten to 1d and turn to long + input_data = self.embed_tokens(input_data) # Check past key values # if past_key_values is None: @@ -89,12 +84,13 @@ def forward_layers( # Initialize position_ids position_ids = torch.arange( - hidden_states.size(1), + input_data.size(1), dtype=torch.long, device=self.device ).unsqueeze(0) #new_past_key_values = [] + hidden_states = input_data for i, layer in enumerate(self.layers): # Get past key value if available # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None From d436641c2d3666eb370883f2a2732afbffb1af59 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:45:31 -0800 Subject: [PATCH 130/491] typo --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index dd12d049..66f0a9c4 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -34,9 +34,9 @@ def __init__(self, shard: Shard): # self.layers = torch.nn.ModuleList(layer_list) # Embeddings and final layer norm + # used for doing what forward LlamaModel does in transformers self.embed_tokens = self.full_model.model.embed_tokens self.norm = self.full_model.model.norm - self.lm_head = self.full_model.lm_head # def prefill(self, tokens: list[int], start_pos: int=0) -> int: # print(f"\nprefill called") From 75e8e69c61f84b2d70f84f0c0460bddee35acb0e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:47:37 -0800 Subject: [PATCH 131/491] typo --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 46d037b7..6aeedf80 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -56,7 +56,7 @@ async def infer_prompt( print(f"tokens: {tokens}\ntokens_tensor: {tokens_tensor}") if shard.is_last_layer(): - output_data = self.full_model.model.norm(tokens_tensor) + output_data = self.model.norm(tokens_tensor) output_data = output_data.flatten() else: output_data = self.model.forward_layers( @@ -95,7 +95,7 @@ async def infer_tensor( # output_data, past_key_values in_tensor = torch.tensor(input_data) if shard.is_last_layer(): - output_data = self.full_model.model.norm(in_tensor) + output_data = self.model.norm(in_tensor) output_data = output_data.flatten() else: output_data = self.model.forward_layers( From 9e04930d3899f653b3d58fc2ff9f82804dfaeed0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 21:59:33 -0800 Subject: [PATCH 132/491] debugging and embedding everything --- exo/inference/pytorch/inference.py | 7 +++---- exo/inference/pytorch/model/hf.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6aeedf80..0955ecbe 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -45,7 +45,6 @@ async def infer_prompt( # Tokenize the prompt tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) - tokens_tensor = torch.tensor(tokens) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) @@ -53,14 +52,14 @@ async def infer_prompt( # Run the forward pass through the model layers # output_data, past_key_values if DEBUG >= 2: - print(f"tokens: {tokens}\ntokens_tensor: {tokens_tensor}") + print(f"tokens: {tokens}\n") if shard.is_last_layer(): - output_data = self.model.norm(tokens_tensor) + output_data = self.model.norm(tokens) output_data = output_data.flatten() else: output_data = self.model.forward_layers( - tokens_tensor, + tokens, # past_key_values=past_key_values ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 66f0a9c4..bd3d2be7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -69,14 +69,14 @@ def forward_layers( Forward pass through the specified layers. Note: past_key_values not working for model, might be a library bug - """ - # Embed tensor if first layer - if self.shard.is_first_layer(): - if DEBUG >= 2: - print(f"Embedding for first layer {input_data.shape}") - - # flatten to 1d and turn to long - input_data = self.embed_tokens(input_data) + """ + if DEBUG >= 2: + print(f"forward_layer call\ninput_data: {input_data}") + + # flatten to 1d and turn to long + input_data = self.embed_tokens(input_data) + if DEBUG >= 2: + print(f"embedded input_data {input_data}") # Check past key values # if past_key_values is None: From d7bddc16bf77d4026a05d5e1f2602cbda80197fd Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:00:23 -0800 Subject: [PATCH 133/491] embed before norm --- exo/inference/pytorch/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0955ecbe..73fc7933 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -55,7 +55,8 @@ async def infer_prompt( print(f"tokens: {tokens}\n") if shard.is_last_layer(): - output_data = self.model.norm(tokens) + tokens_embed = self.model.embed_tokens(tokens) + output_data = self.model.norm(tokens_embed) output_data = output_data.flatten() else: output_data = self.model.forward_layers( From f8df8d3702b540cd8af00bef32348d9f71776388 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:02:01 -0800 Subject: [PATCH 134/491] fixing last layer output --- exo/inference/pytorch/inference.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 73fc7933..37c9b849 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -54,15 +54,16 @@ async def infer_prompt( if DEBUG >= 2: print(f"tokens: {tokens}\n") + + + output_data = self.model.forward_layers( + tokens, + # past_key_values=past_key_values + ) + if shard.is_last_layer(): - tokens_embed = self.model.embed_tokens(tokens) - output_data = self.model.norm(tokens_embed) + output_data = self.model.norm(output_data) output_data = output_data.flatten() - else: - output_data = self.model.forward_layers( - tokens, - # past_key_values=past_key_values - ) # Save the past key values to the inference state # self._save_kv_cache(past_key_values) From 9a511569b6bb1a83360c1c5968f12a29ea77ca53 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:06:37 -0800 Subject: [PATCH 135/491] checking layer outputs and inputs --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index bd3d2be7..db61b720 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -97,7 +97,7 @@ def forward_layers( # Forward pass through the layer if DEBUG >= 2: - print(f"\nPass tensor to layer[{i}] {layer}") + print(f"\nhidden_states {hidden_states}") layer_outputs = layer( hidden_states, @@ -107,7 +107,7 @@ def forward_layers( ) if DEBUG >= 2: - print(f"\nlayer_outputs: {layer_outputs}") + print(f"\nlayer_outputs[0]: {layer_outputs[0]}") hidden_states = layer_outputs[0] # new_past_key_values.append(layer_outputs[1]) From 4514427257ffb291ccfdb438eadb0df049c6171a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:08:16 -0800 Subject: [PATCH 136/491] reshaping for norm layer call at end --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 37c9b849..51e07460 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -62,6 +62,7 @@ async def infer_prompt( ) if shard.is_last_layer(): + output_data = output_data.view(1, -1, 4096) output_data = self.model.norm(output_data) output_data = output_data.flatten() From 989f29e8ca9e4ee1bceaf37253821b8f5cf2f49e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:11:19 -0800 Subject: [PATCH 137/491] reshaping for norm layer call at end --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index db61b720..45ee29c5 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -97,7 +97,7 @@ def forward_layers( # Forward pass through the layer if DEBUG >= 2: - print(f"\nhidden_states {hidden_states}") + print(f"\n[layer {i}] hidden_states {hidden_states}") layer_outputs = layer( hidden_states, @@ -107,7 +107,7 @@ def forward_layers( ) if DEBUG >= 2: - print(f"\nlayer_outputs[0]: {layer_outputs[0]}") + print(f"\n[layer {i}] layer_outputs: {layer_outputs[0]}") hidden_states = layer_outputs[0] # new_past_key_values.append(layer_outputs[1]) From 2fe90e87401b90263c55dcc278729716da1b99b9 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:15:50 -0800 Subject: [PATCH 138/491] remove flatten --- exo/inference/pytorch/inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 51e07460..51132946 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -62,9 +62,9 @@ async def infer_prompt( ) if shard.is_last_layer(): - output_data = output_data.view(1, -1, 4096) + # output_data = output_data.view(1, -1, 4096) output_data = self.model.norm(output_data) - output_data = output_data.flatten() + # output_data = output_data.flatten() # Save the past key values to the inference state # self._save_kv_cache(past_key_values) @@ -98,7 +98,7 @@ async def infer_tensor( in_tensor = torch.tensor(input_data) if shard.is_last_layer(): output_data = self.model.norm(in_tensor) - output_data = output_data.flatten() + # output_data = output_data.flatten() else: output_data = self.model.forward_layers( in_tensor, From bc235f749133680e4b993da1b1d2d2a7ce617862 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:21:32 -0800 Subject: [PATCH 139/491] encode tensor to set eos_token_id --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 51132946..ebf7dafd 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,7 +44,7 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + tokens = torch.tensor(self.tokenizer.encode(prompt, return_tensors="pt").input_ids.to(self.device)) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) From c60d6186762e6461e781fc9c7f7d22cd5d82d9e8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:25:02 -0800 Subject: [PATCH 140/491] enable eos_token --- exo/inference/pytorch/inference.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ebf7dafd..ac63ad09 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -168,7 +168,11 @@ async def ensure_shard(self, shard: Optional[Shard]): print(f"Loading new shard: {shard}") self.model = ShardedHuggingFaceModel(shard) - self.tokenizer = AutoTokenizer.from_pretrained(shard.model_id) + self.tokenizer = AutoTokenizer.from_pretrained( + shard.model_id, + add_eos_token=True, + use_fast=True + ) self.shard = shard if DEBUG >= 2: From e96d146da15f20176f09a76a5e67341f05899c7a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 8 Aug 2024 22:26:54 -0800 Subject: [PATCH 141/491] token fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ac63ad09..e1fc791a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -44,7 +44,7 @@ async def infer_prompt( await self.ensure_shard(shard) # Tokenize the prompt - tokens = torch.tensor(self.tokenizer.encode(prompt, return_tensors="pt").input_ids.to(self.device)) + tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) # Load the past key values from the inference state if available # past_key_values = self._load_kv_cache(inference_state) From 8feae93b261a3d498681eb87d0c96cd5835f187b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 14:45:29 -0800 Subject: [PATCH 142/491] adding more logging to fix infinite infer_tensor issue --- exo/inference/pytorch/inference.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index e1fc791a..517a22fd 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,6 +73,7 @@ async def infer_prompt( if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") with torch.no_grad(): output_npa = np.array(output_data.cpu()) @@ -89,13 +90,18 @@ async def infer_tensor( shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + + in_tensor = torch.tensor(input_data) + if DEBUG >= 2: + print(f"input_data: {input_data}\n") + print(f"in_tensor: {in_tensor}\n") # Ensure the shard is loaded await self.ensure_shard(shard) # Run the forward pass through the model layers # output_data, past_key_values - in_tensor = torch.tensor(input_data) + if shard.is_last_layer(): output_data = self.model.norm(in_tensor) # output_data = output_data.flatten() @@ -108,7 +114,8 @@ async def infer_tensor( is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: - print(f"Output data shape: {output_data.shape}") + print(f"Output data: {output_data} finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") with torch.no_grad(): From 8922e0a460b4a04019f2d1259db83b34c1276b23 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 14:52:12 -0800 Subject: [PATCH 143/491] testing if just needing to see if layer is normalized it will be finished --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 517a22fd..d5a5d8ae 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -69,7 +69,7 @@ async def infer_prompt( # Save the past key values to the inference state # self._save_kv_cache(past_key_values) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data.size == 1 if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") @@ -111,7 +111,7 @@ async def infer_tensor( # past_key_values=past_key_values ) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data.size == 1 if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") From 65728857cc0165631b5195d8fdb1ab03f2ea6f4f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 14:59:13 -0800 Subject: [PATCH 144/491] checking normalization, adding last layer check to forward --- exo/inference/pytorch/inference.py | 18 ++++-------------- exo/inference/pytorch/model/hf.py | 9 ++++++++- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d5a5d8ae..b06ac216 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -61,11 +61,6 @@ async def infer_prompt( # past_key_values=past_key_values ) - if shard.is_last_layer(): - # output_data = output_data.view(1, -1, 4096) - output_data = self.model.norm(output_data) - # output_data = output_data.flatten() - # Save the past key values to the inference state # self._save_kv_cache(past_key_values) @@ -102,20 +97,15 @@ async def infer_tensor( # Run the forward pass through the model layers # output_data, past_key_values - if shard.is_last_layer(): - output_data = self.model.norm(in_tensor) - # output_data = output_data.flatten() - else: - output_data = self.model.forward_layers( - in_tensor, - # past_key_values=past_key_values - ) + output_data = self.model.forward_layers( + in_tensor, + # past_key_values=past_key_values + ) is_finished = output_data.size == 1 if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") - print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") with torch.no_grad(): diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 45ee29c5..a412cb02 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -110,7 +110,14 @@ def forward_layers( print(f"\n[layer {i}] layer_outputs: {layer_outputs[0]}") hidden_states = layer_outputs[0] - # new_past_key_values.append(layer_outputs[1]) + + if DEBUG >= 2: + print(f"is last layer? {self.shard.is_last_layer}") + print(f"layer count {self.shard.get_layer_count()}") + + if self.shard.is_last_layer(): + # output_data = output_data.view(1, -1, 4096) + output_data = self.model.norm(hidden_states) return hidden_states # if self.shard.is_last_layer(): From 35b79d2db031a967a62fa0a4d17d991d15bec24c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:02:58 -0800 Subject: [PATCH 145/491] fixing norm utilization, adding logging --- exo/inference/pytorch/model/hf.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index a412cb02..0c7b9c11 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -72,8 +72,10 @@ def forward_layers( """ if DEBUG >= 2: print(f"forward_layer call\ninput_data: {input_data}") + print(f"1 is last layer? {self.shard.is_last_layer}") + print(f"1 layer count {self.shard.get_layer_count()}") - # flatten to 1d and turn to long + # embed data input_data = self.embed_tokens(input_data) if DEBUG >= 2: print(f"embedded input_data {input_data}") @@ -112,12 +114,12 @@ def forward_layers( hidden_states = layer_outputs[0] if DEBUG >= 2: - print(f"is last layer? {self.shard.is_last_layer}") - print(f"layer count {self.shard.get_layer_count()}") + print(f"2 is last layer? {self.shard.is_last_layer}") + print(f"2 layer count {self.shard.get_layer_count()}") if self.shard.is_last_layer(): # output_data = output_data.view(1, -1, 4096) - output_data = self.model.norm(hidden_states) + return self.norm(hidden_states) return hidden_states # if self.shard.is_last_layer(): From ee8b76b6b846312d285f24380ae452d33cb41cfd Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:08:31 -0800 Subject: [PATCH 146/491] adding logging, checking embed error --- exo/inference/pytorch/inference.py | 2 ++ exo/inference/pytorch/model/hf.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index b06ac216..767dbace 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -67,6 +67,7 @@ async def infer_prompt( is_finished = output_data.size == 1 if DEBUG >= 2: + print("infer_prompt called") print(f"Output data: {output_data} finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") @@ -88,6 +89,7 @@ async def infer_tensor( in_tensor = torch.tensor(input_data) if DEBUG >= 2: + print("infer_tensor called") print(f"input_data: {input_data}\n") print(f"in_tensor: {in_tensor}\n") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0c7b9c11..e2cbf43c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -72,8 +72,7 @@ def forward_layers( """ if DEBUG >= 2: print(f"forward_layer call\ninput_data: {input_data}") - print(f"1 is last layer? {self.shard.is_last_layer}") - print(f"1 layer count {self.shard.get_layer_count()}") + print(f"1 shard {self.shard.to_dict()}") # embed data input_data = self.embed_tokens(input_data) @@ -114,8 +113,8 @@ def forward_layers( hidden_states = layer_outputs[0] if DEBUG >= 2: - print(f"2 is last layer? {self.shard.is_last_layer}") - print(f"2 layer count {self.shard.get_layer_count()}") + print(f"2 is last layer? {self.shard.is_last_layer()}") + print(f"2 shard {self.shard.to_dict()}") if self.shard.is_last_layer(): # output_data = output_data.view(1, -1, 4096) From 310e607e67dd0329947cf8ede2e59d31524cf4a6 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:10:42 -0800 Subject: [PATCH 147/491] making it so embed only happens with first layer --- exo/inference/pytorch/model/hf.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e2cbf43c..50c9dd26 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -71,13 +71,16 @@ def forward_layers( Note: past_key_values not working for model, might be a library bug """ if DEBUG >= 2: - print(f"forward_layer call\ninput_data: {input_data}") + print("forward_layer call") + print(f"input_data: {input_data}") print(f"1 shard {self.shard.to_dict()}") # embed data - input_data = self.embed_tokens(input_data) - if DEBUG >= 2: - print(f"embedded input_data {input_data}") + if self.shard.is_first_layer(): + input_data = self.embed_tokens(input_data) + + if DEBUG >= 2: + print(f"embedded input_data {input_data}") # Check past key values # if past_key_values is None: From d7a77e9f5e6632ab98027a7c9c380073af5e895e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:13:52 -0800 Subject: [PATCH 148/491] putting in layer loop --- exo/inference/pytorch/model/hf.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 50c9dd26..b2a34d6a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -75,13 +75,6 @@ def forward_layers( print(f"input_data: {input_data}") print(f"1 shard {self.shard.to_dict()}") - # embed data - if self.shard.is_first_layer(): - input_data = self.embed_tokens(input_data) - - if DEBUG >= 2: - print(f"embedded input_data {input_data}") - # Check past key values # if past_key_values is None: # past_key_values = [None] * len(self.layers) @@ -98,6 +91,12 @@ def forward_layers( for i, layer in enumerate(self.layers): # Get past key value if available # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None + + # embed only at first layer + if i == self.shard.start_layer: + input_data = self.embed_tokens(input_data) + if DEBUG >= 2: + print(f"embedded input_data {input_data}") # Forward pass through the layer if DEBUG >= 2: From ba4f9653494d6cdaf4afc4fc0e624d9520751f74 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:16:06 -0800 Subject: [PATCH 149/491] putting in layer loop --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b2a34d6a..220c1db0 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -94,9 +94,9 @@ def forward_layers( # embed only at first layer if i == self.shard.start_layer: - input_data = self.embed_tokens(input_data) + hidden_states = self.embed_tokens(hidden_states) if DEBUG >= 2: - print(f"embedded input_data {input_data}") + print(f"embedded hidden_states {hidden_states}") # Forward pass through the layer if DEBUG >= 2: From 8fe37002d0a2c4518a3355785d3f361518359967 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:19:08 -0800 Subject: [PATCH 150/491] putting in layer loop --- exo/inference/pytorch/model/hf.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 220c1db0..c3c51994 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -89,19 +89,20 @@ def forward_layers( #new_past_key_values = [] hidden_states = input_data for i, layer in enumerate(self.layers): + # Forward pass through the layer + if DEBUG >= 2: + print(f"\n[layer {i}] {layer}") + print(f"hidden_states {hidden_states}") + # Get past key value if available # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None # embed only at first layer - if i == self.shard.start_layer: + if i == 0: hidden_states = self.embed_tokens(hidden_states) if DEBUG >= 2: print(f"embedded hidden_states {hidden_states}") - # Forward pass through the layer - if DEBUG >= 2: - print(f"\n[layer {i}] hidden_states {hidden_states}") - layer_outputs = layer( hidden_states, position_ids=position_ids, From 5ecda20ef7dd27fef9d2c007d2b75f6c6fa2becb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:23:13 -0800 Subject: [PATCH 151/491] tensor bug --- exo/inference/pytorch/model/hf.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c3c51994..8e8c2c58 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -98,10 +98,9 @@ def forward_layers( # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None # embed only at first layer - if i == 0: - hidden_states = self.embed_tokens(hidden_states) - if DEBUG >= 2: - print(f"embedded hidden_states {hidden_states}") + hidden_states = self.embed_tokens(hidden_states) + if DEBUG >= 2: + print(f"embedded hidden_states {hidden_states}") layer_outputs = layer( hidden_states, From 57843f1220154e458855e035cc7ca127aa01ab51 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:26:48 -0800 Subject: [PATCH 152/491] adding param to forward layer to check where infer is coming from --- exo/inference/pytorch/inference.py | 2 ++ exo/inference/pytorch/model/hf.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 767dbace..50cce140 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -58,6 +58,7 @@ async def infer_prompt( output_data = self.model.forward_layers( tokens, + "prompt" # past_key_values=past_key_values ) @@ -101,6 +102,7 @@ async def infer_tensor( output_data = self.model.forward_layers( in_tensor, + "tensor" # past_key_values=past_key_values ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8e8c2c58..e2925594 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -63,6 +63,7 @@ def __init__(self, shard: Shard): def forward_layers( self, input_data: torch.tensor, + infer_from: str #past_key_values: list ) -> torch.tensor: #-> Tuple[torch.tensor, list]: """ @@ -97,10 +98,14 @@ def forward_layers( # Get past key value if available # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None - # embed only at first layer - hidden_states = self.embed_tokens(hidden_states) - if DEBUG >= 2: - print(f"embedded hidden_states {hidden_states}") + # embed only at first layer and infer prompt + if i == 0 and infer_from == "prompt": + if DEBUG >= 2: + print("first layer and infer_prompt") + + hidden_states = self.embed_tokens(hidden_states) + if DEBUG >= 2: + print(f"embedded hidden_states {hidden_states}") layer_outputs = layer( hidden_states, From 9d52ed7714ae61c1fd0fafb723f66a0bed1316cb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:33:10 -0800 Subject: [PATCH 153/491] infinity processing loop bug --- exo/inference/pytorch/model/hf.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e2925594..5627f1cb 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -99,7 +99,7 @@ def forward_layers( # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None # embed only at first layer and infer prompt - if i == 0 and infer_from == "prompt": + if self.shard.start_layer == i and infer_from == "prompt": if DEBUG >= 2: print("first layer and infer_prompt") @@ -120,11 +120,9 @@ def forward_layers( hidden_states = layer_outputs[0] if DEBUG >= 2: - print(f"2 is last layer? {self.shard.is_last_layer()}") print(f"2 shard {self.shard.to_dict()}") - if self.shard.is_last_layer(): - # output_data = output_data.view(1, -1, 4096) + if i == self.shard.end_layer: return self.norm(hidden_states) return hidden_states From 1eca0e63ca8332e8db517d7d92c85c9dae0038a7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:42:07 -0800 Subject: [PATCH 154/491] infinity processing loop bug --- exo/inference/pytorch/model/hf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 5627f1cb..02e6a7f9 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -123,7 +123,9 @@ def forward_layers( print(f"2 shard {self.shard.to_dict()}") if i == self.shard.end_layer: - return self.norm(hidden_states) + print(f"last layer, normalize hidden states") + hs_norm = self.norm(hidden_states) + return hs_norm return hidden_states # if self.shard.is_last_layer(): From a17bd55cad7faa0ceb51c2b6aa1bdca30e209cd0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:51:06 -0800 Subject: [PATCH 155/491] trying to flatten --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 02e6a7f9..cb87d01b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -65,7 +65,7 @@ def forward_layers( input_data: torch.tensor, infer_from: str #past_key_values: list - ) -> torch.tensor: #-> Tuple[torch.tensor, list]: + ) -> any: #-> Tuple[torch.tensor, list]: """ Forward pass through the specified layers. @@ -125,7 +125,7 @@ def forward_layers( if i == self.shard.end_layer: print(f"last layer, normalize hidden states") hs_norm = self.norm(hidden_states) - return hs_norm + return hs_norm.flatten() return hidden_states # if self.shard.is_last_layer(): From 3085dbcb435422b4e82151ddf9d557f5a7507a17 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:54:45 -0800 Subject: [PATCH 156/491] trying to flatten --- exo/inference/pytorch/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 50cce140..1f3d8087 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -69,7 +69,8 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") - print(f"Output data: {output_data} finished: {is_finished}") + print(f"Output data: {output_data} output_data.size: {output_data.size}") + print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") with torch.no_grad(): From ccea4d216d6fc18f4dfec35d5b3e79164856a0f1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:56:16 -0800 Subject: [PATCH 157/491] testing size fix --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 1f3d8087..f046929e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -65,11 +65,11 @@ async def infer_prompt( # Save the past key values to the inference state # self._save_kv_cache(past_key_values) - is_finished = output_data.size == 1 + is_finished = output_data.size() == 1 if DEBUG >= 2: print("infer_prompt called") - print(f"Output data: {output_data} output_data.size: {output_data.size}") + print(f"Output data: {output_data} output_data.size: {output_data.size()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") From 36ecb1ec3a15667101848465ea493732c203fdce Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 15:58:47 -0800 Subject: [PATCH 158/491] testing size fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f046929e..8083f346 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -65,7 +65,7 @@ async def infer_prompt( # Save the past key values to the inference state # self._save_kv_cache(past_key_values) - is_finished = output_data.size() == 1 + is_finished = len(output_data.size()) == 1 if DEBUG >= 2: print("infer_prompt called") From b928d3e648d9fc57425358542982cd03d3850d0b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:01:41 -0800 Subject: [PATCH 159/491] testing items fix --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8083f346..57a4cf8c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -70,6 +70,7 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") print(f"Output data: {output_data} output_data.size: {output_data.size()}") + print(f"output_data {output_data.items()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") From 06b35b957c684466f75aede2aca17321f22e4363 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:03:35 -0800 Subject: [PATCH 160/491] testing items fix --- exo/inference/pytorch/inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 57a4cf8c..1dbdb9fa 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -62,6 +62,9 @@ async def infer_prompt( # past_key_values=past_key_values ) + with torch.no_grad(): + output_npa = np.array(output_data.cpu()) + # Save the past key values to the inference state # self._save_kv_cache(past_key_values) @@ -70,13 +73,10 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") print(f"Output data: {output_data} output_data.size: {output_data.size()}") - print(f"output_data {output_data.items()}") + print(f"output_data {output_npa.items()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") - with torch.no_grad(): - output_npa = np.array(output_data.cpu()) - return ( output_npa, "", From 3a0bdba5dcb8381e55020c65853fb66f5b5cc8b2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:05:03 -0800 Subject: [PATCH 161/491] testing items fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 1dbdb9fa..396f62b8 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,7 +73,7 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") print(f"Output data: {output_data} output_data.size: {output_data.size()}") - print(f"output_data {output_npa.items()}") + print(f"output_data {output_npa.item()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") From f58566e1a810ab5feca656b03ee1230bad72ec87 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:06:28 -0800 Subject: [PATCH 162/491] testing items fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 396f62b8..9d7cf7aa 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,7 +73,7 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") print(f"Output data: {output_data} output_data.size: {output_data.size()}") - print(f"output_data {output_npa.item()}") + print(f"output_npa {output_npa}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") From ae576269763a491a0b9cbd6f1a31add30bf5c62b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:09:40 -0800 Subject: [PATCH 163/491] testing items fix --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9d7cf7aa..eba51c8a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,6 +73,7 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") print(f"Output data: {output_data} output_data.size: {output_data.size()}") + print(f"output_data {output_data.item()}") print(f"output_npa {output_npa}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") From a6cdd6bcf593ee523836f476bac4b9d3de46031e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:14:04 -0800 Subject: [PATCH 164/491] testing items fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index eba51c8a..1ca2131e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,7 +73,7 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") print(f"Output data: {output_data} output_data.size: {output_data.size()}") - print(f"output_data {output_data.item()}") + print(f"output_data {output_data.squeeze().item()}") print(f"output_npa {output_npa}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") From 749c9547a6f0efa0467b0b2a8da725761574284d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:20:04 -0800 Subject: [PATCH 165/491] testing items fix --- exo/inference/pytorch/model/hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index cb87d01b..4d5a165a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -75,6 +75,7 @@ def forward_layers( print("forward_layer call") print(f"input_data: {input_data}") print(f"1 shard {self.shard.to_dict()}") + print(f"1 is_last_layer {self.shard.is_first_layer()}") # Check past key values # if past_key_values is None: From c8e59a7a2b574ac8c0410d174529b916568b0363 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:22:47 -0800 Subject: [PATCH 166/491] testing items fix --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 4d5a165a..0be2e267 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -75,7 +75,6 @@ def forward_layers( print("forward_layer call") print(f"input_data: {input_data}") print(f"1 shard {self.shard.to_dict()}") - print(f"1 is_last_layer {self.shard.is_first_layer()}") # Check past key values # if past_key_values is None: @@ -128,6 +127,7 @@ def forward_layers( hs_norm = self.norm(hidden_states) return hs_norm.flatten() + print(f"1 is_last_layer {self.shard.is_last_layer()}") return hidden_states # if self.shard.is_last_layer(): # logits = self.full_model.model.norm(hidden_states) From 1d68267e7ca61566f083f395c479d4e2e216e767 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:24:44 -0800 Subject: [PATCH 167/491] testing items fix --- exo/inference/pytorch/model/hf.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0be2e267..865da29b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -119,15 +119,12 @@ def forward_layers( hidden_states = layer_outputs[0] - if DEBUG >= 2: - print(f"2 shard {self.shard.to_dict()}") - - if i == self.shard.end_layer: - print(f"last layer, normalize hidden states") - hs_norm = self.norm(hidden_states) - return hs_norm.flatten() + # if i == self.shard.end_layer: + # print(f"last layer, normalize hidden states") + # hs_norm = self.norm(hidden_states) + # return hs_norm.flatten() - print(f"1 is_last_layer {self.shard.is_last_layer()}") + print(f"2 is_last_layer {self.shard.is_last_layer()}") return hidden_states # if self.shard.is_last_layer(): # logits = self.full_model.model.norm(hidden_states) From 026bbd2f97f15d3cb1323c8b54beacd0410026bc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:28:13 -0800 Subject: [PATCH 168/491] testing items fix --- exo/inference/pytorch/model/hf.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 865da29b..fc2ef894 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -89,6 +89,12 @@ def forward_layers( #new_past_key_values = [] hidden_states = input_data + + if self.shard.is_first_layer(): + hidden_states = self.embed_tokens(hidden_states) + if DEBUG >= 2: + print(f"embedded hidden_states {hidden_states}") + for i, layer in enumerate(self.layers): # Forward pass through the layer if DEBUG >= 2: @@ -99,13 +105,13 @@ def forward_layers( # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None # embed only at first layer and infer prompt - if self.shard.start_layer == i and infer_from == "prompt": - if DEBUG >= 2: - print("first layer and infer_prompt") + # if self.shard.start_layer == i and infer_from == "prompt": + # if DEBUG >= 2: + # print("first layer and infer_prompt") - hidden_states = self.embed_tokens(hidden_states) - if DEBUG >= 2: - print(f"embedded hidden_states {hidden_states}") + # hidden_states = self.embed_tokens(hidden_states) + # if DEBUG >= 2: + # print(f"embedded hidden_states {hidden_states}") layer_outputs = layer( hidden_states, @@ -119,12 +125,11 @@ def forward_layers( hidden_states = layer_outputs[0] - # if i == self.shard.end_layer: - # print(f"last layer, normalize hidden states") - # hs_norm = self.norm(hidden_states) - # return hs_norm.flatten() - print(f"2 is_last_layer {self.shard.is_last_layer()}") + if self.shard.is_last_layer(): + hs_norm = self.norm(hidden_states) + return hs_norm.flatten() + return hidden_states # if self.shard.is_last_layer(): # logits = self.full_model.model.norm(hidden_states) From da79891623901646e8d37d547ea80e9c6343e116 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:31:56 -0800 Subject: [PATCH 169/491] testing items fix --- exo/inference/pytorch/inference.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 1ca2131e..5e1cad06 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -70,13 +70,13 @@ async def infer_prompt( is_finished = len(output_data.size()) == 1 - if DEBUG >= 2: - print("infer_prompt called") - print(f"Output data: {output_data} output_data.size: {output_data.size()}") - print(f"output_data {output_data.squeeze().item()}") - print(f"output_npa {output_npa}") - print(f"finished: {is_finished}") - print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + # if DEBUG >= 2: + # print("infer_prompt called") + # print(f"Output data: {output_data} output_data.size: {output_data.size()}") + # print(f"output_data {output_data.squeeze().item()}") + # print(f"output_npa {output_npa}") + # print(f"finished: {is_finished}") + # print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") return ( output_npa, From 6370f65d4f30c21dc52af61c64dabe8be451d3ad Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:35:36 -0800 Subject: [PATCH 170/491] testing items fix --- exo/inference/pytorch/inference.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 5e1cad06..6e8536ef 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -39,6 +39,8 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 2: + print("infer_prompt called") # Ensure the shard is loaded await self.ensure_shard(shard) @@ -65,6 +67,9 @@ async def infer_prompt( with torch.no_grad(): output_npa = np.array(output_data.cpu()) + if DEBUG >= 2: + print(f"output_npa.size: {output_npa.size}") + # Save the past key values to the inference state # self._save_kv_cache(past_key_values) From e8411aeafc4815e15bf8933c403c1dd69a9dbc3d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:39:59 -0800 Subject: [PATCH 171/491] sending norm to lm_head and output to float --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index fc2ef894..3161bc69 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -128,7 +128,8 @@ def forward_layers( print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) - return hs_norm.flatten() + hs_lm_head = self.full_model.lm_head(hs_norm) + return hs_lm_head.float() return hidden_states # if self.shard.is_last_layer(): From 254d7ac1ec319e8520f6004574c052678c2d8fd2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:43:21 -0800 Subject: [PATCH 172/491] sending norm to lm_head and output to float --- exo/inference/pytorch/model/hf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 3161bc69..8c23a61b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -128,8 +128,13 @@ def forward_layers( print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm) - return hs_lm_head.float() + hs_lm_head = self.full_model.lm_head(hs_norm).float() + + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + print(f"hs_lm_head: {hs_lm_head}") + + return hs_lm_head return hidden_states # if self.shard.is_last_layer(): From f20c7f81d831738c895b3f7bb0c5f6efa2466357 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:55:09 -0800 Subject: [PATCH 173/491] sending norm to lm_head and output to float --- exo/inference/pytorch/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6e8536ef..fc294d16 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -67,8 +67,9 @@ async def infer_prompt( with torch.no_grad(): output_npa = np.array(output_data.cpu()) - if DEBUG >= 2: - print(f"output_npa.size: {output_npa.size}") + if DEBUG >= 2: + print(f"output_data.size(): {output_data.size()}") + print(f"output_npa.size: {output_npa.size}") # Save the past key values to the inference state # self._save_kv_cache(past_key_values) From fb5dff2fbb99fd5570d66463d42d53067e2e32f2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 16:56:42 -0800 Subject: [PATCH 174/491] sending norm to lm_head and output to float --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8c23a61b..c3335628 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -128,7 +128,7 @@ def forward_layers( print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm).float() + hs_lm_head = self.full_model.lm_head(hs_norm).float().flatten() if DEBUG >= 2: print(f"hs_norm: {hs_norm}") From a6a0c2ba83d5c4bdf3ed66f77f4f5866277c7d73 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:18:43 -0800 Subject: [PATCH 175/491] sending norm to lm_head and output to float --- exo/inference/pytorch/inference.py | 12 +++---- exo/inference/pytorch/model/hf.py | 55 ++---------------------------- 2 files changed, 8 insertions(+), 59 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index fc294d16..a28d62a9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -56,11 +56,8 @@ async def infer_prompt( if DEBUG >= 2: print(f"tokens: {tokens}\n") - - output_data = self.model.forward_layers( - tokens, - "prompt" + tokens # past_key_values=past_key_values ) @@ -69,12 +66,14 @@ async def infer_prompt( if DEBUG >= 2: print(f"output_data.size(): {output_data.size()}") + + print(f"output_npa: {output_npa}") print(f"output_npa.size: {output_npa.size}") # Save the past key values to the inference state # self._save_kv_cache(past_key_values) - is_finished = len(output_data.size()) == 1 + is_finished = output_npa.size == 1 # if DEBUG >= 2: # print("infer_prompt called") @@ -110,8 +109,7 @@ async def infer_tensor( # output_data, past_key_values output_data = self.model.forward_layers( - in_tensor, - "tensor" + in_tensor # past_key_values=past_key_values ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c3335628..92f11d2c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -38,34 +38,10 @@ def __init__(self, shard: Shard): self.embed_tokens = self.full_model.model.embed_tokens self.norm = self.full_model.model.norm - # def prefill(self, tokens: list[int], start_pos: int=0) -> int: - # print(f"\nprefill called") - # """ - # Process the initial input tokens and set up the initial hidden states. - # """ - # # Assuming tokens is a 1D tensor of token IDs - # for token in tokens: - # # Convert token to a tensor and get embeddings - # token_tensor = torch.tensor([[token]], device=self.device) - # token_tensor = self.embed_tokens(token_tensor) - - # if DEBUG >= 2: - # print(f"\ntoken_tensor shape: {token_tensor.shape}") - - # # Prefill with tokens - # self.forward_layers(start_pos, token_tensor, None) - - # # Increment start position - # start_pos += 1 - - # return start_pos - def forward_layers( self, - input_data: torch.tensor, - infer_from: str - #past_key_values: list - ) -> any: #-> Tuple[torch.tensor, list]: + input_data: torch.tensor + ) -> any: """ Forward pass through the specified layers. @@ -76,10 +52,6 @@ def forward_layers( print(f"input_data: {input_data}") print(f"1 shard {self.shard.to_dict()}") - # Check past key values - # if past_key_values is None: - # past_key_values = [None] * len(self.layers) - # Initialize position_ids position_ids = torch.arange( input_data.size(1), @@ -87,7 +59,6 @@ def forward_layers( device=self.device ).unsqueeze(0) - #new_past_key_values = [] hidden_states = input_data if self.shard.is_first_layer(): @@ -100,24 +71,10 @@ def forward_layers( if DEBUG >= 2: print(f"\n[layer {i}] {layer}") print(f"hidden_states {hidden_states}") - - # Get past key value if available - # past_key_value = past_key_values[i] if past_key_values and len(past_key_values) > 0 else None - - # embed only at first layer and infer prompt - # if self.shard.start_layer == i and infer_from == "prompt": - # if DEBUG >= 2: - # print("first layer and infer_prompt") - - # hidden_states = self.embed_tokens(hidden_states) - # if DEBUG >= 2: - # print(f"embedded hidden_states {hidden_states}") layer_outputs = layer( hidden_states, - position_ids=position_ids, - # past_key_value=past_key_value, - # use_cache=True + position_ids=position_ids ) if DEBUG >= 2: @@ -137,9 +94,3 @@ def forward_layers( return hs_lm_head return hidden_states - # if self.shard.is_last_layer(): - # logits = self.full_model.model.norm(hidden_states) - # return logits.flatten() #, new_past_key_values - # else: - # return hidden_states#, new_past_key_values - From b24046f0ac41ee4f00c8c67c293c071e17d70b6b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:20:28 -0800 Subject: [PATCH 176/491] finish issue --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 92f11d2c..08f9ca2c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -91,6 +91,6 @@ def forward_layers( print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") - return hs_lm_head + return (hs_lm_head, hidden_states) return hidden_states From 8687fe62a435f576683bd33ba287d985bcaee510 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:25:16 -0800 Subject: [PATCH 177/491] finish issue --- exo/inference/pytorch/inference.py | 2 +- exo/inference/pytorch/model/hf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a28d62a9..707aa53b 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -62,7 +62,7 @@ async def infer_prompt( ) with torch.no_grad(): - output_npa = np.array(output_data.cpu()) + output_npa = np.array([output_data.tolist()]) if DEBUG >= 2: print(f"output_data.size(): {output_data.size()}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 08f9ca2c..92f11d2c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -91,6 +91,6 @@ def forward_layers( print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") - return (hs_lm_head, hidden_states) + return hs_lm_head return hidden_states From 454b205dc5230fdb56297336a4fb7d1037024910 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:40:28 -0800 Subject: [PATCH 178/491] finish issue --- exo/inference/pytorch/model/hf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 92f11d2c..62db005b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -85,12 +85,14 @@ def forward_layers( print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm).float().flatten() + hs_lm_head = self.full_model.lm_head(hs_norm).float() + output_token = torch.argmax(hs_lm_head, dim=-1).cpu().numpy().flatten() if DEBUG >= 2: print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") + print(f"output_token: {output_token}") - return hs_lm_head + return output_token return hidden_states From fb6c43d615bf978c1906ea67f2ee9359bf128946 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:43:32 -0800 Subject: [PATCH 179/491] fixing output data to numpy, logits finish --- exo/inference/pytorch/inference.py | 35 +++--------------------------- 1 file changed, 3 insertions(+), 32 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 707aa53b..7be2ca3c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -10,16 +10,12 @@ from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.helpers import DEBUG -# Default settings -TEMPERATURE = 0.7 -TOP_K = 50 - class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, debug: bool = False): + def __init__(self): """ Initialize the inference engine. @@ -48,43 +44,18 @@ async def infer_prompt( # Tokenize the prompt tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) - # Load the past key values from the inference state if available - # past_key_values = self._load_kv_cache(inference_state) - # Run the forward pass through the model layers - # output_data, past_key_values if DEBUG >= 2: print(f"tokens: {tokens}\n") output_data = self.model.forward_layers( tokens - # past_key_values=past_key_values ) - with torch.no_grad(): - output_npa = np.array([output_data.tolist()]) - - if DEBUG >= 2: - print(f"output_data.size(): {output_data.size()}") - - print(f"output_npa: {output_npa}") - print(f"output_npa.size: {output_npa.size}") - - # Save the past key values to the inference state - # self._save_kv_cache(past_key_values) - - is_finished = output_npa.size == 1 - - # if DEBUG >= 2: - # print("infer_prompt called") - # print(f"Output data: {output_data} output_data.size: {output_data.size()}") - # print(f"output_data {output_data.squeeze().item()}") - # print(f"output_npa {output_npa}") - # print(f"finished: {is_finished}") - # print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + is_finished = output_data.size == 1 return ( - output_npa, + output_data, "", is_finished ) From b6bec5441f65ad9959dcc0f353a5ec04e394658b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:44:41 -0800 Subject: [PATCH 180/491] fixing debug flag issue --- exo/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/helpers.py b/exo/helpers.py index b811a0f9..64940f08 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -45,7 +45,7 @@ def get_inference_engine(inference_engine_name): elif inference_engine_name == "pytorch": # will change from debug being true after testing from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine - return PyTorchDynamicShardInferenceEngine(debug=os.getenv("PYTORCH_DEBUG", default=True)) + return PyTorchDynamicShardInferenceEngine() else: raise ValueError(f"Inference engine {inference_engine_name} not supported") From b85fdecee99065e98cc66c1e792da27d106579e6 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 17:54:48 -0800 Subject: [PATCH 181/491] fixing debug flag issue --- exo/inference/pytorch/model/hf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 62db005b..c3a72cbc 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -54,7 +54,7 @@ def forward_layers( # Initialize position_ids position_ids = torch.arange( - input_data.size(1), + input_data.size(1) if input_data.size > 1 else input_data, dtype=torch.long, device=self.device ).unsqueeze(0) @@ -85,14 +85,14 @@ def forward_layers( print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm).float() - output_token = torch.argmax(hs_lm_head, dim=-1).cpu().numpy().flatten() + hs_lm_head = self.full_model.lm_head(hs_norm).float()[:, -1, :] + # output_token = .cpu().numpy().flatten() if DEBUG >= 2: print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") - print(f"output_token: {output_token}") + # print(f"output_token: {output_token}") - return output_token + return hs_lm_head.cpu().numpy() - return hidden_states + return hidden_states.cpu().numpy() From b075d0f4cdca94ac32109f3a72e3b84609e69783 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:33:26 -0800 Subject: [PATCH 182/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c3a72cbc..a7e443f1 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -54,7 +54,7 @@ def forward_layers( # Initialize position_ids position_ids = torch.arange( - input_data.size(1) if input_data.size > 1 else input_data, + input_data, dtype=torch.long, device=self.device ).unsqueeze(0) @@ -63,8 +63,11 @@ def forward_layers( if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) + position_embeddings = self.rotary_emb(hidden_states) + if DEBUG >= 2: print(f"embedded hidden_states {hidden_states}") + print(f"position_ids: {self.position_embeddings}") for i, layer in enumerate(self.layers): # Forward pass through the layer @@ -74,7 +77,7 @@ def forward_layers( layer_outputs = layer( hidden_states, - position_ids=position_ids + position_embeddings=position_embeddings ) if DEBUG >= 2: From fe604f4d6c81652cbb675713e8ad761e2750397a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:34:48 -0800 Subject: [PATCH 183/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index a7e443f1..8d757ec1 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -52,13 +52,6 @@ def forward_layers( print(f"input_data: {input_data}") print(f"1 shard {self.shard.to_dict()}") - # Initialize position_ids - position_ids = torch.arange( - input_data, - dtype=torch.long, - device=self.device - ).unsqueeze(0) - hidden_states = input_data if self.shard.is_first_layer(): From 9ff520030d745e991b0d215325aeff6cfcef7235 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:36:30 -0800 Subject: [PATCH 184/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8d757ec1..fdf4e76a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -56,7 +56,7 @@ def forward_layers( if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) - position_embeddings = self.rotary_emb(hidden_states) + position_embeddings = self.full_model.model.rotary_emb(hidden_states) if DEBUG >= 2: print(f"embedded hidden_states {hidden_states}") From 955503671207c4d90efc2d67afe0f7c06e4fcef2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:41:01 -0800 Subject: [PATCH 185/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index fdf4e76a..b5e325cd 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -56,7 +56,14 @@ def forward_layers( if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) - position_embeddings = self.full_model.model.rotary_emb(hidden_states) + + batch_size, seq_len = input_data.size() + position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.full_model.model.rotary_emb( + hidden_states, + position_ids + ) if DEBUG >= 2: print(f"embedded hidden_states {hidden_states}") From e76c17420b7855a1b54a098cb30b70bbe1603bd7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:43:54 -0800 Subject: [PATCH 186/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b5e325cd..14c9684f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -53,6 +53,8 @@ def forward_layers( print(f"1 shard {self.shard.to_dict()}") hidden_states = input_data + position_ids = None + position_embeddings = None if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) From 2d94e60bb30c4d31d2ac1013a6329297e55ba0ce Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:45:43 -0800 Subject: [PATCH 187/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 14c9684f..e6de3436 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -69,7 +69,7 @@ def forward_layers( if DEBUG >= 2: print(f"embedded hidden_states {hidden_states}") - print(f"position_ids: {self.position_embeddings}") + print(f"position_ids: {position_embeddings}") for i, layer in enumerate(self.layers): # Forward pass through the layer From 764b8a727a521abbe4600233d6b9731dfe9cc2cd Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 18:48:10 -0800 Subject: [PATCH 188/491] adding position embeddings --- exo/inference/pytorch/model/hf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e6de3436..6e268641 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -97,7 +97,9 @@ def forward_layers( print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") # print(f"output_token: {output_token}") - - return hs_lm_head.cpu().numpy() + with torch.no_grad(): + last_state = hs_lm_head.cpu().numpy() + + return last_state return hidden_states.cpu().numpy() From 9733e8ab45d4acd5860074406073ffda5615c66c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 19:02:53 -0800 Subject: [PATCH 189/491] adding sampling --- exo/inference/pytorch/inference.py | 4 +++ exo/inference/pytorch/model/hf.py | 56 +++++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 7be2ca3c..1a560c0a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -52,6 +52,10 @@ async def infer_prompt( tokens ) + if DEBUG >= 2: + print(f"output_data: {output_data}\n") + print(f"output_data.size {output_data.size}\n") + is_finished = output_data.size == 1 return ( diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 6e268641..c139cf72 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -4,6 +4,46 @@ from exo.helpers import DEBUG from typing import Tuple +def sample_logits(logits, temp=0.85, top_k=25, top_p=0.9, alpha_f=0.1, alpha_p=0.0): + # Apply temperature scaling + if temp > 0: + logits = logits / temp + + # Top-k sampling + if top_k > 0: + top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) + logits = torch.full_like(logits, -float('inf')) + logits.scatter_(-1, top_k_indices, top_k_values) + + # Top-p (nucleus) sampling + if 0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, -float('inf')) + + # Alpha sampling (to discourage repetition) + if alpha_f or alpha_p: + if not hasattr(sample_logits, "alpha_counter"): + setattr(sample_logits, "alpha_counter", torch.zeros_like(logits, dtype=torch.int32).contiguous()) + logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) + + # Sample from the logits + probabilities = F.softmax(logits, dim=-1) + sampled_token = torch.multinomial(probabilities, 1) + + # Update alpha counter + if alpha_f or alpha_p: + sample_logits.alpha_counter = (torch.arange(probabilities.numel(), device=logits.device) == sampled_token).where(sample_logits.alpha_counter + 1, sample_logits.alpha_counter) + + return sampled_token + class ShardedHuggingFaceModel(torch.nn.Module): def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() @@ -90,16 +130,16 @@ def forward_layers( print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm).float()[:, -1, :] - # output_token = .cpu().numpy().flatten() - + hs_lm_head = self.full_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings + output_token = sample_logits(hs_lm_head).cpu().numpy().flatten() + if DEBUG >= 2: print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") - # print(f"output_token: {output_token}") - with torch.no_grad(): - last_state = hs_lm_head.cpu().numpy() - - return last_state + print(f"output_token: {output_token}") + + return output_token return hidden_states.cpu().numpy() From 304659604b3c20a10109ec10cfd59aace075ffab Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 19:04:29 -0800 Subject: [PATCH 190/491] fix import --- exo/inference/pytorch/model/hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c139cf72..88c78ea5 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,4 +1,5 @@ import torch +from torch.nn import functional as F from transformers import AutoModelForCausalLM from exo.inference.shard import Shard from exo.helpers import DEBUG From 2a21d8ce13fc12856094ea681ca4318260bf8cdb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 19:06:53 -0800 Subject: [PATCH 191/491] lm_head fix --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 88c78ea5..8bec7803 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -134,7 +134,8 @@ def forward_layers( hs_lm_head = self.full_model.lm_head(hs_norm).float() # Use the sampling function with default settings - output_token = sample_logits(hs_lm_head).cpu().numpy().flatten() + output_token = sample_logits( + hs_lm_head[:, -1, :]).cpu().numpy().flatten() if DEBUG >= 2: print(f"hs_norm: {hs_norm}") From adf1cf38b83f06982590882b7107ab4f5b0264a5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 19:36:41 -0800 Subject: [PATCH 192/491] sample fix --- exo/inference/pytorch/model/hf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8bec7803..668537c6 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -41,7 +41,9 @@ def sample_logits(logits, temp=0.85, top_k=25, top_p=0.9, alpha_f=0.1, alpha_p=0 # Update alpha counter if alpha_f or alpha_p: - sample_logits.alpha_counter = (torch.arange(probabilities.numel(), device=logits.device) == sampled_token).where(sample_logits.alpha_counter + 1, sample_logits.alpha_counter) + condition = (torch.arange(probabilities.numel(), device=logits.device) == sampled_token) + condition = condition.bool() # Convert condition to boolean tensor + sample_logits.alpha_counter = torch.where(condition, sample_logits.alpha_counter + 1, sample_logits.alpha_counter) return sampled_token From 40f88d45a5da001f8f2797af3cc5a0e3f6e25875 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 19:49:01 -0800 Subject: [PATCH 193/491] eos fix --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 1a560c0a..ff1a3e97 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -56,7 +56,7 @@ async def infer_prompt( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") - is_finished = output_data.size == 1 + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] return ( output_data, @@ -88,7 +88,7 @@ async def infer_tensor( # past_key_values=past_key_values ) - is_finished = output_data.size == 1 + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") From 930944efd5f8446973fb988fc900c2d1ad808a9c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 19:59:26 -0800 Subject: [PATCH 194/491] eos fix --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ff1a3e97..de789586 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -55,6 +55,7 @@ async def infer_prompt( if DEBUG >= 2: print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") + print(f"output_data.item() {output_data.item()}") is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] From d719d8e9359872768aa18f1c9841cd7048e26007 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:01:37 -0800 Subject: [PATCH 195/491] eos fix --- exo/inference/pytorch/inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index de789586..de73831a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -72,6 +72,9 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + if input_data.dim() == 1: + input_data = input_data.unsqueeze(0) + in_tensor = torch.tensor(input_data) if DEBUG >= 2: print("infer_tensor called") From 500f85e72fc264260fc7868841a381e3d3759636 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:02:53 -0800 Subject: [PATCH 196/491] eos fix --- exo/inference/pytorch/inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index de73831a..84ac3036 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -71,11 +71,11 @@ async def infer_tensor( shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - - if input_data.dim() == 1: - input_data = input_data.unsqueeze(0) - + in_tensor = torch.tensor(input_data) + if in_tensor.dim() == 1: + in_tensor = in_tensor.unsqueeze(0) + if DEBUG >= 2: print("infer_tensor called") print(f"input_data: {input_data}\n") From 03d4e866e72659aa67ce6bc56170138bcb3f890d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:06:43 -0800 Subject: [PATCH 197/491] eos fix --- exo/inference/pytorch/inference.py | 10 ++-------- exo/inference/pytorch/model/hf.py | 3 ++- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 84ac3036..e4dc7fe7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,9 +73,7 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) - if in_tensor.dim() == 1: - in_tensor = in_tensor.unsqueeze(0) - + if DEBUG >= 2: print("infer_tensor called") print(f"input_data: {input_data}\n") @@ -97,12 +95,8 @@ async def infer_tensor( if DEBUG >= 2: print(f"Output data: {output_data} finished: {is_finished}") - - with torch.no_grad(): - output_npa = np.array(output_data.cpu()) - return ( - output_npa, + output_data, "", is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 668537c6..0aeac4e3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,4 +1,5 @@ import torch +import numpy as np from torch.nn import functional as F from transformers import AutoModelForCausalLM from exo.inference.shard import Shard @@ -84,7 +85,7 @@ def __init__(self, shard: Shard): def forward_layers( self, input_data: torch.tensor - ) -> any: + ) -> np.ndarray: """ Forward pass through the specified layers. From b2814b45fe8e9176026d45bb4242bf2017f136f5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:17:05 -0800 Subject: [PATCH 198/491] eos fix --- exo/inference/pytorch/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index e4dc7fe7..ffde4472 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -73,7 +73,7 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) - + if DEBUG >= 2: print("infer_tensor called") print(f"input_data: {input_data}\n") @@ -93,7 +93,8 @@ async def infer_tensor( is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: - print(f"Output data: {output_data} finished: {is_finished}") + print(f"finished: {is_finished}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") return ( output_data, From a55c9a39fb70240241b7f64dfae77d80602ebd35 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:18:01 -0800 Subject: [PATCH 199/491] eos fix --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0aeac4e3..1419b30c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -103,7 +103,7 @@ def forward_layers( if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) - batch_size, seq_len = input_data.size() + batch_size, seq_len = hidden_states.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( From ca6734e7860c0fd94a0819aaf5e37aa7eb2e8952 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:21:11 -0800 Subject: [PATCH 200/491] is finished issue --- exo/inference/pytorch/inference.py | 4 ++-- exo/inference/pytorch/model/hf.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ffde4472..4869117f 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -57,7 +57,7 @@ async def infer_prompt( print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 return ( output_data, @@ -90,7 +90,7 @@ async def infer_tensor( # past_key_values=past_key_values ) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 if DEBUG >= 2: print(f"finished: {is_finished}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 1419b30c..0aeac4e3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -103,7 +103,7 @@ def forward_layers( if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) - batch_size, seq_len = hidden_states.size() + batch_size, seq_len = input_data.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( From 609c9199bdeafe176c1166a1e3b5efb3d7728ea0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:27:10 -0800 Subject: [PATCH 201/491] is finished issue --- exo/inference/pytorch/inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 4869117f..a083ef8c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -56,6 +56,10 @@ async def infer_prompt( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") + print(f"finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + print(f"output_data[-1] {output_data[-1]}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 From b47932cff02b817b8a46a1adb5af351eb2548dda Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:28:27 -0800 Subject: [PATCH 202/491] is finished issue --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a083ef8c..3030972a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -52,6 +52,8 @@ async def infer_prompt( tokens ) + is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 + if DEBUG >= 2: print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") @@ -61,8 +63,6 @@ async def infer_prompt( print(f"output_data[-1] {output_data[-1]}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") - is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 - return ( output_data, "", From c878b2415a754145c391c4840c2a95b8fa827e44 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:44:10 -0800 Subject: [PATCH 203/491] is finished issue --- exo/inference/pytorch/model/hf.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0aeac4e3..56cc5fbe 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -84,8 +84,9 @@ def __init__(self, shard: Shard): def forward_layers( self, - input_data: torch.tensor - ) -> np.ndarray: + input_data: torch.tensor, + past_kvs = [] + ) -> Tuple[np.ndarray, list]: """ Forward pass through the specified layers. @@ -99,6 +100,7 @@ def forward_layers( hidden_states = input_data position_ids = None position_embeddings = None + present_kvs = [] if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) @@ -123,13 +125,16 @@ def forward_layers( layer_outputs = layer( hidden_states, - position_embeddings=position_embeddings + position_embeddings=position_embeddings, + past_key_values=past_kvs[i] if past_kvs else None, + use_cache=True ) if DEBUG >= 2: print(f"\n[layer {i}] layer_outputs: {layer_outputs[0]}") hidden_states = layer_outputs[0] + present_kvs = layer_outputs[1] print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): @@ -145,6 +150,6 @@ def forward_layers( print(f"hs_lm_head: {hs_lm_head}") print(f"output_token: {output_token}") - return output_token + return (output_token, present_kvs) - return hidden_states.cpu().numpy() + return (hidden_states.cpu().numpy(), present_kvs) From e006edd2aa37b0e27b43ca128d3e74c16ff4b634 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:47:21 -0800 Subject: [PATCH 204/491] working on caching --- exo/inference/pytorch/model/hf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 56cc5fbe..c7f9aa61 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -122,6 +122,7 @@ def forward_layers( if DEBUG >= 2: print(f"\n[layer {i}] {layer}") print(f"hidden_states {hidden_states}") + print(f"past_kvs {past_kvs}") layer_outputs = layer( hidden_states, @@ -136,6 +137,9 @@ def forward_layers( hidden_states = layer_outputs[0] present_kvs = layer_outputs[1] + if DEBUG >= 2: + print(f"present_kvs {present_kvs}") + print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) From d3b1106cdd70092d69e3de3fd350db2a9082a902 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:53:19 -0800 Subject: [PATCH 205/491] working on caching --- exo/inference/pytorch/model/hf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c7f9aa61..80aba11c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,7 +1,7 @@ import torch import numpy as np from torch.nn import functional as F -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, LlamaConfig from exo.inference.shard import Shard from exo.helpers import DEBUG from typing import Tuple @@ -58,11 +58,16 @@ def __init__(self, shard: Shard): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard + # Load the model with the configuration for caching + self.config = LlamaConfig.from_pretrained(shard.model_id) + self.config.use_cache = True # Enable caching + # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto" + device_map="auto", + config=self.config ) # Extract only the layers for this shard From bb9e9b0be3352054ea97cc80807e7cc790d7498a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 20:58:01 -0800 Subject: [PATCH 206/491] working on caching --- exo/inference/pytorch/inference.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 3030972a..c177b31c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -3,6 +3,7 @@ import numpy as np import torch import numpy as np +import json from typing import Optional, Callable, Tuple from transformers import AutoTokenizer from exo.inference.shard import Shard @@ -48,7 +49,7 @@ async def infer_prompt( if DEBUG >= 2: print(f"tokens: {tokens}\n") - output_data = self.model.forward_layers( + output_data, inference_state = self.model.forward_layers( tokens ) @@ -65,7 +66,7 @@ async def infer_prompt( return ( output_data, - "", + json.loads(inference_state), is_finished ) @@ -89,7 +90,7 @@ async def infer_tensor( # Run the forward pass through the model layers # output_data, past_key_values - output_data = self.model.forward_layers( + output_data, inference_state = self.model.forward_layers( in_tensor # past_key_values=past_key_values ) @@ -102,7 +103,7 @@ async def infer_tensor( return ( output_data, - "", + json.loads(inference_state), is_finished ) From 063b11d5fcc46b3701a0f85aff295fc16596c88c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 21:02:39 -0800 Subject: [PATCH 207/491] working on caching --- exo/inference/pytorch/inference.py | 61 ++++++++---------------------- 1 file changed, 15 insertions(+), 46 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index c177b31c..6f296785 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -39,18 +39,17 @@ async def infer_prompt( if DEBUG >= 2: print("infer_prompt called") - # Ensure the shard is loaded await self.ensure_shard(shard) - # Tokenize the prompt + inference_state = json.loads(inference_state) tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) - # Run the forward pass through the model layers if DEBUG >= 2: print(f"tokens: {tokens}\n") output_data, inference_state = self.model.forward_layers( - tokens + tokens, + inference_state ) is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 @@ -59,6 +58,7 @@ async def infer_prompt( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") + print(f"inference_state: {inference_state}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") @@ -66,7 +66,7 @@ async def infer_prompt( return ( output_data, - json.loads(inference_state), + json.dumps(inference_state), is_finished ) @@ -78,69 +78,38 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) + inference_state = json.loads(inference_state) if DEBUG >= 2: print("infer_tensor called") print(f"input_data: {input_data}\n") print(f"in_tensor: {in_tensor}\n") - # Ensure the shard is loaded await self.ensure_shard(shard) - # Run the forward pass through the model layers - # output_data, past_key_values - output_data, inference_state = self.model.forward_layers( - in_tensor - # past_key_values=past_key_values + in_tensor, + inference_state ) is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 if DEBUG >= 2: + print(f"output_data: {output_data}\n") + print(f"output_data.size {output_data.size}\n") + print(f"output_data.item() {output_data.item()}") + print(f"inference_state: {inference_state}") print(f"finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + print(f"output_data[-1] {output_data[-1]}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") return ( output_data, - json.loads(inference_state), + json.dumps(inference_state), is_finished ) - # def _load_kv_cache(self, past_key_values_list): - # """ - # Load key-value cache from the inference state. - - # Args: - # past_key_values_list (list): List of past key-value tensors. - - # Returns: - # list: List of loaded past key-value tensors. - # """ - # if past_key_values_list is None: - # return [] - # return [torch.tensor(kv, device=self.device) for kv in past_key_values_list] - - # def _save_kv_cache(self, past_key_values): - # """ - # Save key-value cache to the inference state. - - # Args: - # past_key_values (list): List of past key-value tensors. - - # Returns: - # list: List of key-value tensors in a format suitable for saving. - # """ - # if past_key_values is None: - # return [] - - # new_cache = [] - # for kv in past_key_values: - # if kv: - # new_cache.append(kv.cpu().tolist()) - - # return new_cache - async def ensure_shard(self, shard: Optional[Shard]): """ Ensure the model shard is loaded and ready for inference. From af8d9fe5dedacd33581ddf6a7b44d6cac1358e03 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 21:04:14 -0800 Subject: [PATCH 208/491] working on caching --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6f296785..433bfb7d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -41,7 +41,7 @@ async def infer_prompt( await self.ensure_shard(shard) - inference_state = json.loads(inference_state) + inference_state = json.loads(inference_state) if inference_state else "" tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) if DEBUG >= 2: @@ -78,7 +78,7 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) - inference_state = json.loads(inference_state) + inference_state = json.loads(inference_state) if inference_state else "" if DEBUG >= 2: print("infer_tensor called") From f3e07eb08e7f2ffae6dadf92c14db610bedffea0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 22:26:48 -0800 Subject: [PATCH 209/491] changing token to chat template --- exo/inference/pytorch/inference.py | 9 +++++- exo/inference/pytorch/model/hf.py | 44 ++-------------------------- exo/inference/pytorch/model/utils.py | 44 ++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 43 deletions(-) create mode 100644 exo/inference/pytorch/model/utils.py diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 433bfb7d..51ef3b12 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -42,7 +42,14 @@ async def infer_prompt( await self.ensure_shard(shard) inference_state = json.loads(inference_state) if inference_state else "" - tokens = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + tokens = self.tokenizer.apply_chat_template( + conversation=[{ + "role": "user", + "content": prompt + }], + tokenize=True, + add_generation_prompt=False, + ) if DEBUG >= 2: print(f"tokens: {tokens}\n") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 80aba11c..42f2357f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,52 +1,12 @@ import torch import numpy as np -from torch.nn import functional as F + from transformers import AutoModelForCausalLM, LlamaConfig from exo.inference.shard import Shard from exo.helpers import DEBUG from typing import Tuple -def sample_logits(logits, temp=0.85, top_k=25, top_p=0.9, alpha_f=0.1, alpha_p=0.0): - # Apply temperature scaling - if temp > 0: - logits = logits / temp - - # Top-k sampling - if top_k > 0: - top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) - logits = torch.full_like(logits, -float('inf')) - logits.scatter_(-1, top_k_indices, top_k_values) - - # Top-p (nucleus) sampling - if 0 < top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) - logits = logits.masked_fill(indices_to_remove, -float('inf')) - - # Alpha sampling (to discourage repetition) - if alpha_f or alpha_p: - if not hasattr(sample_logits, "alpha_counter"): - setattr(sample_logits, "alpha_counter", torch.zeros_like(logits, dtype=torch.int32).contiguous()) - logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) - - # Sample from the logits - probabilities = F.softmax(logits, dim=-1) - sampled_token = torch.multinomial(probabilities, 1) - - # Update alpha counter - if alpha_f or alpha_p: - condition = (torch.arange(probabilities.numel(), device=logits.device) == sampled_token) - condition = condition.bool() # Convert condition to boolean tensor - sample_logits.alpha_counter = torch.where(condition, sample_logits.alpha_counter + 1, sample_logits.alpha_counter) - - return sampled_token +from .utils import sample_logits class ShardedHuggingFaceModel(torch.nn.Module): def __init__(self, shard: Shard): diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py new file mode 100644 index 00000000..fe7ece45 --- /dev/null +++ b/exo/inference/pytorch/model/utils.py @@ -0,0 +1,44 @@ +import torch +from torch.nn import functional as F + +def sample_logits(logits, temp=0.85, top_k=25, top_p=0.9, alpha_f=0.1, alpha_p=0.0): + # Apply temperature scaling + if temp > 0: + logits = logits / temp + + # Top-k sampling + if top_k > 0: + top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) + logits = torch.full_like(logits, -float('inf')) + logits.scatter_(-1, top_k_indices, top_k_values) + + # Top-p (nucleus) sampling + if 0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, -float('inf')) + + # Alpha sampling (to discourage repetition) + if alpha_f or alpha_p: + if not hasattr(sample_logits, "alpha_counter"): + setattr(sample_logits, "alpha_counter", torch.zeros_like(logits, dtype=torch.int32).contiguous()) + logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) + + # Sample from the logits + probabilities = F.softmax(logits, dim=-1) + sampled_token = torch.multinomial(probabilities, 1) + + # Update alpha counter + if alpha_f or alpha_p: + condition = (torch.arange(probabilities.numel(), device=logits.device) == sampled_token) + condition = condition.bool() # Convert condition to boolean tensor + sample_logits.alpha_counter = torch.where(condition, sample_logits.alpha_counter + 1, sample_logits.alpha_counter) + + return sampled_token \ No newline at end of file From bb65e46a5b98197080d9b84f4d8d8c1100980f84 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 22:28:46 -0800 Subject: [PATCH 210/491] changing token to chat template --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 51ef3b12..213d52b0 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -42,14 +42,14 @@ async def infer_prompt( await self.ensure_shard(shard) inference_state = json.loads(inference_state) if inference_state else "" - tokens = self.tokenizer.apply_chat_template( + tokens = torch.tensor(self.tokenizer.apply_chat_template( conversation=[{ "role": "user", "content": prompt }], tokenize=True, add_generation_prompt=False, - ) + )) if DEBUG >= 2: print(f"tokens: {tokens}\n") From 8dea6ae1ffb3ba1ba41002ea48d33855b9f8743f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 22:32:05 -0800 Subject: [PATCH 211/491] changing token to chat template --- exo/inference/pytorch/model/hf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 42f2357f..a9a216be 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -70,6 +70,9 @@ def forward_layers( if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) + if DEBUG >= 2: + print(f"hidden_states: {hidden_states}") + batch_size, seq_len = input_data.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) From a55f4560798d80fac049711e55e1688448cf9308 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 22:44:53 -0800 Subject: [PATCH 212/491] letting model general position_ids --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index a9a216be..0e15e5f1 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -73,8 +73,8 @@ def forward_layers( if DEBUG >= 2: print(f"hidden_states: {hidden_states}") - batch_size, seq_len = input_data.size() - position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) + # batch_size, seq_len = input_data.size() + # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( hidden_states, From de1d7331cfeb410609e0fda5f7a4720064145d30 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 22:54:19 -0800 Subject: [PATCH 213/491] working on kvs --- exo/inference/pytorch/model/hf.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0e15e5f1..e961e2df 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,7 +1,7 @@ import torch import numpy as np -from transformers import AutoModelForCausalLM, LlamaConfig +from transformers import AutoModelForCausalLM, LlamaConfig, DynamicCache, Cache from exo.inference.shard import Shard from exo.helpers import DEBUG from typing import Tuple @@ -50,7 +50,7 @@ def __init__(self, shard: Shard): def forward_layers( self, input_data: torch.tensor, - past_kvs = [] + past_kvs: Cache = DynamicCache() ) -> Tuple[np.ndarray, list]: """ Forward pass through the specified layers. @@ -65,7 +65,7 @@ def forward_layers( hidden_states = input_data position_ids = None position_embeddings = None - present_kvs = [] + present_kvs = DynamicCache() if self.shard.is_first_layer(): hidden_states = self.embed_tokens(hidden_states) @@ -76,14 +76,14 @@ def forward_layers( # batch_size, seq_len = input_data.size() # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.full_model.model.rotary_emb( - hidden_states, - position_ids - ) + # position_embeddings = self.full_model.model.rotary_emb( + # hidden_states, + # position_ids + # ) - if DEBUG >= 2: - print(f"embedded hidden_states {hidden_states}") - print(f"position_ids: {position_embeddings}") + # if DEBUG >= 2: + # print(f"embedded hidden_states {hidden_states}") + # print(f"position_ids: {position_embeddings}") for i, layer in enumerate(self.layers): # Forward pass through the layer @@ -94,13 +94,13 @@ def forward_layers( layer_outputs = layer( hidden_states, - position_embeddings=position_embeddings, - past_key_values=past_kvs[i] if past_kvs else None, + # position_embeddings=position_embeddings, + past_key_values=past_kvs, use_cache=True ) if DEBUG >= 2: - print(f"\n[layer {i}] layer_outputs: {layer_outputs[0]}") + print(f"\n[layer {i}] layer_outputs: {layer_outputs}") hidden_states = layer_outputs[0] present_kvs = layer_outputs[1] From 0d44195b83af7db686eee72d3bbafff444659e05 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 23:25:37 -0800 Subject: [PATCH 214/491] working on kvs --- exo/inference/pytorch/model/hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e961e2df..8bac47ac 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -72,6 +72,7 @@ def forward_layers( if DEBUG >= 2: print(f"hidden_states: {hidden_states}") + print(f"hidden_states.size(): {hidden_states.size()}") # batch_size, seq_len = input_data.size() # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) From 683f547e554579f5c85753a42145142da7f63715 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 23:27:12 -0800 Subject: [PATCH 215/491] working on kvs --- exo/inference/pytorch/model/hf.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8bac47ac..2e0ba375 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,13 +74,13 @@ def forward_layers( print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - # batch_size, seq_len = input_data.size() - # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) + batch_size, seq_len = input_data.size() + position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) - # position_embeddings = self.full_model.model.rotary_emb( - # hidden_states, - # position_ids - # ) + position_embeddings = self.full_model.model.rotary_emb( + hidden_states, + position_ids + ) # if DEBUG >= 2: # print(f"embedded hidden_states {hidden_states}") @@ -95,7 +95,7 @@ def forward_layers( layer_outputs = layer( hidden_states, - # position_embeddings=position_embeddings, + position_embeddings=position_embeddings, past_key_values=past_kvs, use_cache=True ) From 3e858621d2b72de92071914a6365687a1ef3f358 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 23:28:27 -0800 Subject: [PATCH 216/491] working on kvs --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2e0ba375..ed40e20f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -68,7 +68,7 @@ def forward_layers( present_kvs = DynamicCache() if self.shard.is_first_layer(): - hidden_states = self.embed_tokens(hidden_states) + if DEBUG >= 2: print(f"hidden_states: {hidden_states}") From 380feff6a720ec14789dd2ec6f50b841a29a3ae0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 23:30:34 -0800 Subject: [PATCH 217/491] working on kvs --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index ed40e20f..b133a57a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -68,13 +68,13 @@ def forward_layers( present_kvs = DynamicCache() if self.shard.is_first_layer(): - + hidden_states = self.embed_tokens(hidden_states) if DEBUG >= 2: print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - batch_size, seq_len = input_data.size() + batch_size, seq_len = hidden_states.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( From 5fe49f0f6f6deacb7d4129fc1bfbcfa334822d06 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 23:34:34 -0800 Subject: [PATCH 218/491] working on kvs --- exo/inference/pytorch/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 213d52b0..f1151df4 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -42,14 +42,15 @@ async def infer_prompt( await self.ensure_shard(shard) inference_state = json.loads(inference_state) if inference_state else "" - tokens = torch.tensor(self.tokenizer.apply_chat_template( + tokens = self.tokenizer.apply_chat_template( conversation=[{ "role": "user", "content": prompt }], tokenize=True, add_generation_prompt=False, - )) + return_tensors="pt" + ) if DEBUG >= 2: print(f"tokens: {tokens}\n") From a2a76d9f697ca2acf49578e46d2114504d900e37 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 9 Aug 2024 23:35:49 -0800 Subject: [PATCH 219/491] working on kvs --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b133a57a..2e0ba375 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,7 +74,7 @@ def forward_layers( print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - batch_size, seq_len = hidden_states.size() + batch_size, seq_len = input_data.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( From 48f2de6361b02633d84fc7429f125a8576a03812 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 09:33:13 -0800 Subject: [PATCH 220/491] fix main debug error --- main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index e66aa689..0e5401c2 100644 --- a/main.py +++ b/main.py @@ -51,7 +51,9 @@ chatgpt_api_endpoints=[f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip in get_all_ip_addresses()] web_chat_urls=[f"http://{ip}:{args.chatgpt_api_port}" for ip in get_all_ip_addresses()] if DEBUG >= 0: - print(f"Chat interface started:\n{'\n'.join([' - ' + terminal_link(web_chat_url) for web_chat_url in web_chat_urls])}") + links = '\n'.join([' - ' + terminal_link(web_chat_url) for web_chat_url in web_chat_urls]) + print(f"Chat interface started:\n{links}") + print(f"ChatGPT API endpoint served at:\n{'\n'.join([' - ' + terminal_link(chatgpt_api_endpoint) for chatgpt_api_endpoint in chatgpt_api_endpoints])}") node = StandardNode( args.node_id, From 68d9c46cc9aa31c99b3fc9af3153b558c2817718 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 09:35:01 -0800 Subject: [PATCH 221/491] fix main debug error --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 0e5401c2..92ea2dac 100644 --- a/main.py +++ b/main.py @@ -53,8 +53,8 @@ if DEBUG >= 0: links = '\n'.join([' - ' + terminal_link(web_chat_url) for web_chat_url in web_chat_urls]) print(f"Chat interface started:\n{links}") - - print(f"ChatGPT API endpoint served at:\n{'\n'.join([' - ' + terminal_link(chatgpt_api_endpoint) for chatgpt_api_endpoint in chatgpt_api_endpoints])}") + api_links = '\n'.join([' - ' + terminal_link(chatgpt_api_endpoint) for chatgpt_api_endpoint in chatgpt_api_endpoints]) + print(f"ChatGPT API endpoint served at:\n{api_links}") node = StandardNode( args.node_id, None, From 7306cfcf3a40f6f7af4b03c5529c416c8b294f48 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 09:58:01 -0800 Subject: [PATCH 222/491] finish error --- exo/inference/pytorch/inference.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f1151df4..063ae9de 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -48,10 +48,13 @@ async def infer_prompt( "content": prompt }], tokenize=True, - add_generation_prompt=False, + padding=True, + add_generation_prompt=True, return_tensors="pt" ) + # tokens = self.tokenizer.encode(prompt, return_tensors="pt") + if DEBUG >= 2: print(f"tokens: {tokens}\n") @@ -140,15 +143,4 @@ async def ensure_shard(self, shard: Optional[Shard]): self.shard = shard if DEBUG >= 2: - print(f"Shard loaded successfully: {shard}") - - def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]): - """ - Set a callback function to track download progress. - - Args: - on_download_progress (Callable[[int, int], None]): Callback function to track progress. - """ - # must have this function or inference engine breaks - # This method can be implemented if progress tracking is needed - pass + print(f"Shard loaded successfully: {shard}") \ No newline at end of file From 09eb58a3bb06020d5e88d4070cab6187e98f1025 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 10:39:33 -0800 Subject: [PATCH 223/491] finish error --- exo/inference/pytorch/inference.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 063ae9de..73b2df25 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -42,18 +42,7 @@ async def infer_prompt( await self.ensure_shard(shard) inference_state = json.loads(inference_state) if inference_state else "" - tokens = self.tokenizer.apply_chat_template( - conversation=[{ - "role": "user", - "content": prompt - }], - tokenize=True, - padding=True, - add_generation_prompt=True, - return_tensors="pt" - ) - - # tokens = self.tokenizer.encode(prompt, return_tensors="pt") + tokens = self.tokenizer.encode(prompt, return_tensors="pt") if DEBUG >= 2: print(f"tokens: {tokens}\n") From 9c1873787047aeb5ecc26ad1895f3e0afaf336d6 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 10:47:28 -0800 Subject: [PATCH 224/491] finish error --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 73b2df25..316f08d9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -52,7 +52,7 @@ async def infer_prompt( inference_state ) - is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 + is_finished = output_data.size == 1 if DEBUG >= 2: print(f"output_data: {output_data}\n") From bf8ed2a3fc412fbe2af2fa85106876ebe6006a06 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 10:49:42 -0800 Subject: [PATCH 225/491] trying to manipulate sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index fe7ece45..9bda8b83 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -1,7 +1,7 @@ import torch from torch.nn import functional as F -def sample_logits(logits, temp=0.85, top_k=25, top_p=0.9, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): # Apply temperature scaling if temp > 0: logits = logits / temp From ae413d250d637e01ef472fb3647fa01a69df6ed7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 10:53:48 -0800 Subject: [PATCH 226/491] using resolve tokenizer from api --- exo/inference/pytorch/inference.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 316f08d9..4ee5d2bc 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -5,10 +5,10 @@ import numpy as np import json from typing import Optional, Callable, Tuple -from transformers import AutoTokenizer from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel +from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG class PyTorchDynamicShardInferenceEngine(InferenceEngine): @@ -124,11 +124,7 @@ async def ensure_shard(self, shard: Optional[Shard]): print(f"Loading new shard: {shard}") self.model = ShardedHuggingFaceModel(shard) - self.tokenizer = AutoTokenizer.from_pretrained( - shard.model_id, - add_eos_token=True, - use_fast=True - ) + self.tokenizer = resolve_tokenizer(shard.model_id) self.shard = shard if DEBUG >= 2: From 5e8663d14112eb1702cb7132eb623683078f9f7e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 10:57:11 -0800 Subject: [PATCH 227/491] using resolve tokenizer from api --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 4ee5d2bc..4e210d7c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -124,7 +124,7 @@ async def ensure_shard(self, shard: Optional[Shard]): print(f"Loading new shard: {shard}") self.model = ShardedHuggingFaceModel(shard) - self.tokenizer = resolve_tokenizer(shard.model_id) + self.tokenizer = await resolve_tokenizer(shard.model_id) self.shard = shard if DEBUG >= 2: From 9e3a4a7aab9cea6c764821d6d831d880a0103078 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 11:03:52 -0800 Subject: [PATCH 228/491] working on eot error --- exo/inference/pytorch/inference.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 4e210d7c..cde37d5d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -52,7 +52,11 @@ async def infer_prompt( inference_state ) - is_finished = output_data.size == 1 + is_finished = output_data.size == 1 #and output_data.item() in [self.tokenizer.eos_token_id] + + if is_finished: + print(f"token from llm decode: {self.tokenizer.decode(output_data)}") + if DEBUG >= 2: print(f"output_data: {output_data}\n") @@ -92,7 +96,7 @@ async def infer_tensor( inference_state ) - is_finished = output_data[-1] == self.tokenizer.eos_token_id and output_data.size == 1 + is_finished = output_data.size == 1 #and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: print(f"output_data: {output_data}\n") From 6f90f43286da07672f65adce4835a06fd5ab7f56 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 12:15:48 -0800 Subject: [PATCH 229/491] working on eot error --- exo/inference/pytorch/model/hf.py | 6 +++--- exo/inference/pytorch/model/utils.py | 14 ++++++-------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2e0ba375..e363f82b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -82,9 +82,9 @@ def forward_layers( position_ids ) - # if DEBUG >= 2: - # print(f"embedded hidden_states {hidden_states}") - # print(f"position_ids: {position_embeddings}") + if DEBUG >= 2: + print(f"embedded hidden_states {hidden_states}") + print(f"position_ids: {position_embeddings}") for i, layer in enumerate(self.layers): # Forward pass through the layer diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 9bda8b83..074d2bf7 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -8,8 +8,8 @@ def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0. # Top-k sampling if top_k > 0: - top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) - logits = torch.full_like(logits, -float('inf')) + top_k_values, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1) + logits = torch.full_like(logits, float('-inf')) logits.scatter_(-1, top_k_indices, top_k_values) # Top-p (nucleus) sampling @@ -22,13 +22,13 @@ def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0. sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 - indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) - logits = logits.masked_fill(indices_to_remove, -float('inf')) + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = float('-inf') # Alpha sampling (to discourage repetition) if alpha_f or alpha_p: if not hasattr(sample_logits, "alpha_counter"): - setattr(sample_logits, "alpha_counter", torch.zeros_like(logits, dtype=torch.int32).contiguous()) + sample_logits.alpha_counter = torch.zeros_like(logits, dtype=torch.int32) logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) # Sample from the logits @@ -37,8 +37,6 @@ def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0. # Update alpha counter if alpha_f or alpha_p: - condition = (torch.arange(probabilities.numel(), device=logits.device) == sampled_token) - condition = condition.bool() # Convert condition to boolean tensor - sample_logits.alpha_counter = torch.where(condition, sample_logits.alpha_counter + 1, sample_logits.alpha_counter) + sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) return sampled_token \ No newline at end of file From 07c146c5bd3d3f4782fa4aa0c8343dad588c56b8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 12:19:10 -0800 Subject: [PATCH 230/491] working on eot error --- exo/inference/pytorch/model/utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 074d2bf7..411dc449 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -2,13 +2,18 @@ from torch.nn import functional as F def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): + # Ensure logits is a 2D tensor + if logits.dim() == 1: + logits = logits.unsqueeze(0) + # Apply temperature scaling if temp > 0: logits = logits / temp # Top-k sampling if top_k > 0: - top_k_values, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1) + top_k = min(top_k, logits.size(-1)) + top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) logits = torch.full_like(logits, float('-inf')) logits.scatter_(-1, top_k_indices, top_k_values) @@ -23,7 +28,7 @@ def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0. sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[indices_to_remove] = float('-inf') + logits[torch.arange(logits.size(0)).unsqueeze(1), indices_to_remove] = float('-inf') # Alpha sampling (to discourage repetition) if alpha_f or alpha_p: @@ -39,4 +44,4 @@ def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0. if alpha_f or alpha_p: sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) - return sampled_token \ No newline at end of file + return sampled_token.squeeze() \ No newline at end of file From da2a97bc424728b543f99257dc829ed303d0cc4f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 12:29:30 -0800 Subject: [PATCH 231/491] working on eot error --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cde37d5d..7ab4e214 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -52,7 +52,7 @@ async def infer_prompt( inference_state ) - is_finished = output_data.size == 1 #and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if is_finished: print(f"token from llm decode: {self.tokenizer.decode(output_data)}") @@ -96,7 +96,7 @@ async def infer_tensor( inference_state ) - is_finished = output_data.size == 1 #and output_data.item() in [self.tokenizer.eos_token_id] + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 2: print(f"output_data: {output_data}\n") From f54990a121db9a52c8775cefad9dd1cb2c7aeab8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 12:33:19 -0800 Subject: [PATCH 232/491] working on eot error --- exo/inference/pytorch/model/hf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e363f82b..0c8d308b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,7 +74,10 @@ def forward_layers( print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - batch_size, seq_len = input_data.size() + if hidden_states.size == 2: + batch_size, seq_len = hidden_states.size() + else: + batch_size, seq_len = input_data.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( From e96579f20385fc64e0cb894f9ed87d93e9974695 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 12:46:22 -0800 Subject: [PATCH 233/491] working on eot error --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0c8d308b..7ca4448b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,7 +74,7 @@ def forward_layers( print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - if hidden_states.size == 2: + if hidden_states.size() == 2: batch_size, seq_len = hidden_states.size() else: batch_size, seq_len = input_data.size() From 25361ecba59b9d29bd9895d6e80ca6150bf7d337 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 12:54:37 -0800 Subject: [PATCH 234/491] working on eot error --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7ca4448b..7127c890 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,7 +74,7 @@ def forward_layers( print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - if hidden_states.size() == 2: + if len(hidden_states.size()) == 2: batch_size, seq_len = hidden_states.size() else: batch_size, seq_len = input_data.size() From 84a0323718907d0e16d137da2b8d54d24dc12b7f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:36:16 -0800 Subject: [PATCH 235/491] eot issue, update sampling --- exo/inference/pytorch/model/hf.py | 5 +- exo/inference/pytorch/model/utils.py | 78 +++++++++++++++++++++------- 2 files changed, 59 insertions(+), 24 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7127c890..e363f82b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,10 +74,7 @@ def forward_layers( print(f"hidden_states: {hidden_states}") print(f"hidden_states.size(): {hidden_states.size()}") - if len(hidden_states.size()) == 2: - batch_size, seq_len = hidden_states.size() - else: - batch_size, seq_len = input_data.size() + batch_size, seq_len = input_data.size() position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) position_embeddings = self.full_model.model.rotary_emb( diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 411dc449..5e61d818 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -1,36 +1,73 @@ import torch from torch.nn import functional as F -def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): - # Ensure logits is a 2D tensor - if logits.dim() == 1: - logits = logits.unsqueeze(0) +def top_p_sampling(logits, top_p: float, temperature: float = 1.0): + """ + Perform top-p sampling (nucleus sampling) on logits. + + Args: + logits (torch.Tensor): The logits distribution to sample from. + top_p (float): The cumulative probability threshold for nucleus sampling. + temperature (float): Sampling temperature. + Returns: + torch.Tensor: The selected token indices. + """ # Apply temperature scaling - if temp > 0: - logits = logits / temp + logits = logits / temperature + + # Sort the logits in descending order + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + + # Calculate cumulative probabilities + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Create a mask to remove logits with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # Mask the logits + sorted_logits[sorted_indices_to_remove] = -float('Inf') + + # Sample from the filtered distribution + probabilities = torch.softmax(sorted_logits, dim=-1) + sampled_token = torch.multinomial(probabilities, 1) + + # Convert to original index order + return sorted_indices.gather(-1, sampled_token) + +def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): + """ + Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. - # Top-k sampling + Args: + logits (torch.Tensor): The logits distribution to sample from. + temp (float): Temperature for scaling logits. + top_k (int): The number of top tokens to consider for sampling. + top_p (float): The cumulative probability threshold for nucleus sampling. + alpha_f (float): Penalty factor for repetition. + alpha_p (float): Penalty for selecting already selected tokens. + + Returns: + torch.Tensor: The selected token indices. + """ + # Return argmax for deterministic output at low temperature + if temp < 1e-6: + return logits.argmax(dim=-1) + + # Apply Top-k sampling if specified if top_k > 0: top_k = min(top_k, logits.size(-1)) top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) logits = torch.full_like(logits, float('-inf')) logits.scatter_(-1, top_k_indices, top_k_values) - # Top-p (nucleus) sampling + # Apply Top-p (nucleus) sampling if specified if 0 < top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 + logits = top_p_sampling(logits, top_p, temp) - indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[torch.arange(logits.size(0)).unsqueeze(1), indices_to_remove] = float('-inf') - - # Alpha sampling (to discourage repetition) + # Apply alpha sampling to discourage repetition if alpha_f or alpha_p: if not hasattr(sample_logits, "alpha_counter"): sample_logits.alpha_counter = torch.zeros_like(logits, dtype=torch.int32) @@ -44,4 +81,5 @@ def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0. if alpha_f or alpha_p: sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) - return sampled_token.squeeze() \ No newline at end of file + return sampled_token.squeeze() + From fffc0345ceecc96b315fb998ecdd3905946b939d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:40:18 -0800 Subject: [PATCH 236/491] ensure 2d --- exo/inference/pytorch/inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 7ab4e214..1eb59c7b 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -84,6 +84,10 @@ async def infer_tensor( in_tensor = torch.tensor(input_data) inference_state = json.loads(inference_state) if inference_state else "" + # Ensure input_data is 2D: [batch_size, seq_len] + if input_data.dim() == 1: + input_data = input_data.unsqueeze(0) # Add a batch dimension: [1, seq_len] + if DEBUG >= 2: print("infer_tensor called") print(f"input_data: {input_data}\n") From 3fbc7d238d3540405fcb0dcf88494e2530b3c6cb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:41:40 -0800 Subject: [PATCH 237/491] ensure 2d --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 1eb59c7b..9a52bafd 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -85,8 +85,8 @@ async def infer_tensor( inference_state = json.loads(inference_state) if inference_state else "" # Ensure input_data is 2D: [batch_size, seq_len] - if input_data.dim() == 1: - input_data = input_data.unsqueeze(0) # Add a batch dimension: [1, seq_len] + if in_tensor.dim() == 1: + in_tensor = in_tensor.unsqueeze(0) # Add a batch dimension: [1, seq_len] if DEBUG >= 2: print("infer_tensor called") From df2144e693d48af1218de7d97ed878d811a6c664 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:47:01 -0800 Subject: [PATCH 238/491] doing a fraction of temp for top p --- exo/inference/pytorch/model/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 5e61d818..1a8ae772 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -14,7 +14,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): torch.Tensor: The selected token indices. """ # Apply temperature scaling - logits = logits / temperature + logits = logits * (1/temperature) # Sort the logits in descending order sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.0, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.8, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 9b37ee12fce752a541f62bc888785708e7cbfc3c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:53:08 -0800 Subject: [PATCH 239/491] doing a fraction of temp for top p --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 1a8ae772..8379c025 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -14,7 +14,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): torch.Tensor: The selected token indices. """ # Apply temperature scaling - logits = logits * (1/temperature) + logits = logits/temperature # Sort the logits in descending order sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) From 49eead3d2809e017486816f8adb85e61ea9e4204 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:55:03 -0800 Subject: [PATCH 240/491] trying other sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 8379c025..f0a403fa 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.8, top_k=15, top_p=0.9, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.8, top_k=0, top_p=1.0, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 1192971f117e5c3b44fedd2c1781026a96725a03 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 13:57:33 -0800 Subject: [PATCH 241/491] trying other sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index f0a403fa..92b6d35c 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.8, top_k=0, top_p=1.0, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.8, top_k=32, top_p=1.0, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 8bc583b6c52c20bbb140439ac3595c855eb2aa25 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:03:09 -0800 Subject: [PATCH 242/491] trying other sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 92b6d35c..59ffe6aa 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.8, top_k=32, top_p=1.0, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.9, top_k=50, top_p=0.9, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 9afe07fc2e019309ebff7981626f6bd0b6e91b15 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:07:58 -0800 Subject: [PATCH 243/491] trying other sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 59ffe6aa..b217ea49 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.9, top_k=50, top_p=0.9, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=1.0, top_k=50, top_p=0.95, alpha_f=0.5, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 954682b0a45501dfd6881b82a5a7a50b0f8e83a8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:20:16 -0800 Subject: [PATCH 244/491] trying other sampling --- exo/inference/pytorch/model/hf.py | 63 +++++++++++++++++-------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index e363f82b..b40a3bcf 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -22,23 +22,29 @@ def __init__(self, shard: Shard): self.config = LlamaConfig.from_pretrained(shard.model_id) self.config.use_cache = True # Enable caching + # Extract only the layers for this shard + # get layers up to end layer + self.config.config.num_hidden_layers = shard.end_layer + 1 + # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto", - config=self.config + config={ + "use_cache" + } ) - # Extract only the layers for this shard - print(f"\nself.model: {self.full_model.model}\n") - print(f"\nlayer amount: {len(self.full_model.model.layers)}") - self.layers = [] - for i in range(shard.start_layer, shard.end_layer + 1): - # if DEBUG >= 2: - # print(f"loading layer[{i}]: {self.full_model.model.layers[i]}") + + + # self.layered_model = self.full_model.model() + # self.layers = [] + # for i in range(shard.start_layer, shard.end_layer + 1): + # # if DEBUG >= 2: + # # print(f"loading layer[{i}]: {self.full_model.model.layers[i]}") - self.layers.append(self.full_model.model.layers[i]) + # self.layers.append(self.full_model.model.layers[i]) # self.layers = torch.nn.ModuleList(layer_list) @@ -86,28 +92,27 @@ def forward_layers( print(f"embedded hidden_states {hidden_states}") print(f"position_ids: {position_embeddings}") - for i, layer in enumerate(self.layers): - # Forward pass through the layer - if DEBUG >= 2: - print(f"\n[layer {i}] {layer}") - print(f"hidden_states {hidden_states}") - print(f"past_kvs {past_kvs}") - - layer_outputs = layer( - hidden_states, - position_embeddings=position_embeddings, - past_key_values=past_kvs, - use_cache=True - ) + # Forward pass through the layer + if DEBUG >= 2: + print(f"\n[layer model] {self.full_model.model}") + print(f"hidden_states {hidden_states}") + print(f"past_kvs {past_kvs}") + + layer_outputs = self.full_model.model( + hidden_states, + position_embeddings=position_embeddings, + past_key_values=past_kvs, + use_cache=True + ) - if DEBUG >= 2: - print(f"\n[layer {i}] layer_outputs: {layer_outputs}") - - hidden_states = layer_outputs[0] - present_kvs = layer_outputs[1] + if DEBUG >= 2: + print(f"\nlayer_outputs: {layer_outputs}") + + hidden_states = layer_outputs[0] + present_kvs = layer_outputs[1] - if DEBUG >= 2: - print(f"present_kvs {present_kvs}") + if DEBUG >= 2: + print(f"present_kvs {present_kvs}") print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): From 0547fa37a65db0336a5b338ca10528225ee37e04 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:29:09 -0800 Subject: [PATCH 245/491] fix typo --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index b40a3bcf..436653e7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -24,7 +24,7 @@ def __init__(self, shard: Shard): # Extract only the layers for this shard # get layers up to end layer - self.config.config.num_hidden_layers = shard.end_layer + 1 + self.config.num_hidden_layers = shard.end_layer + 1 # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( From 88c94a86d6e64cbb9bf297aa08257d2966bcdc30 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:31:30 -0800 Subject: [PATCH 246/491] use input embeds --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 436653e7..1136d172 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -100,7 +100,7 @@ def forward_layers( layer_outputs = self.full_model.model( hidden_states, - position_embeddings=position_embeddings, + inputs_embeds=position_embeddings, past_key_values=past_kvs, use_cache=True ) From 4160c40f10df1eac78065534a1c03de284220d4b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:34:32 -0800 Subject: [PATCH 247/491] use position ids --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 1136d172..802fc438 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -100,7 +100,8 @@ def forward_layers( layer_outputs = self.full_model.model( hidden_states, - inputs_embeds=position_embeddings, + position_ids=position_ids + # inputs_embeds=position_embeddings, past_key_values=past_kvs, use_cache=True ) From a6e3c15d92cb0778c94e530d435e7a9d92b323f2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:34:58 -0800 Subject: [PATCH 248/491] use position ids --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 802fc438..04946d49 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -100,7 +100,7 @@ def forward_layers( layer_outputs = self.full_model.model( hidden_states, - position_ids=position_ids + position_ids=position_ids, # inputs_embeds=position_embeddings, past_key_values=past_kvs, use_cache=True From 6f3435e9c2dde1313a4645184a0cf7041ac34bbf Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:37:17 -0800 Subject: [PATCH 249/491] use position ids --- exo/inference/pytorch/model/hf.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 04946d49..5e90b206 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -70,27 +70,27 @@ def forward_layers( hidden_states = input_data position_ids = None - position_embeddings = None + # position_embeddings = None present_kvs = DynamicCache() - if self.shard.is_first_layer(): - hidden_states = self.embed_tokens(hidden_states) + # if self.shard.is_first_layer(): + # hidden_states = self.embed_tokens(hidden_states) - if DEBUG >= 2: - print(f"hidden_states: {hidden_states}") - print(f"hidden_states.size(): {hidden_states.size()}") + # if DEBUG >= 2: + # print(f"hidden_states: {hidden_states}") + # print(f"hidden_states.size(): {hidden_states.size()}") - batch_size, seq_len = input_data.size() - position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) + # batch_size, seq_len = input_data.size() + # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) - position_embeddings = self.full_model.model.rotary_emb( - hidden_states, - position_ids - ) + # position_embeddings = self.full_model.model.rotary_emb( + # hidden_states, + # position_ids + # ) - if DEBUG >= 2: - print(f"embedded hidden_states {hidden_states}") - print(f"position_ids: {position_embeddings}") + # if DEBUG >= 2: + # print(f"embedded hidden_states {hidden_states}") + # print(f"position_ids: {position_embeddings}") # Forward pass through the layer if DEBUG >= 2: From 157d0ddb0492a5bb28d487d54b17a8b820063726 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:37:32 -0800 Subject: [PATCH 250/491] rmv embed fix --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 5e90b206..60a1b52f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -69,7 +69,7 @@ def forward_layers( print(f"1 shard {self.shard.to_dict()}") hidden_states = input_data - position_ids = None + # position_ids = None # position_embeddings = None present_kvs = DynamicCache() @@ -100,7 +100,7 @@ def forward_layers( layer_outputs = self.full_model.model( hidden_states, - position_ids=position_ids, + # position_ids=position_ids, # inputs_embeds=position_embeddings, past_key_values=past_kvs, use_cache=True From 0d12ef5175dc86195e9387f4b6f69e5b93e02010 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:42:09 -0800 Subject: [PATCH 251/491] layer test --- exo/inference/pytorch/model/hf.py | 42 +++---------------------------- 1 file changed, 3 insertions(+), 39 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 60a1b52f..7f88db04 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -24,7 +24,7 @@ def __init__(self, shard: Shard): # Extract only the layers for this shard # get layers up to end layer - self.config.num_hidden_layers = shard.end_layer + 1 + self.config.num_hidden_layers = 2 # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( @@ -35,18 +35,6 @@ def __init__(self, shard: Shard): "use_cache" } ) - - - - # self.layered_model = self.full_model.model() - # self.layers = [] - # for i in range(shard.start_layer, shard.end_layer + 1): - # # if DEBUG >= 2: - # # print(f"loading layer[{i}]: {self.full_model.model.layers[i]}") - - # self.layers.append(self.full_model.model.layers[i]) - - # self.layers = torch.nn.ModuleList(layer_list) # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers @@ -66,37 +54,16 @@ def forward_layers( if DEBUG >= 2: print("forward_layer call") print(f"input_data: {input_data}") - print(f"1 shard {self.shard.to_dict()}") + print(f"shard {self.shard.to_dict()}") hidden_states = input_data - # position_ids = None - # position_embeddings = None present_kvs = DynamicCache() - # if self.shard.is_first_layer(): - # hidden_states = self.embed_tokens(hidden_states) - - # if DEBUG >= 2: - # print(f"hidden_states: {hidden_states}") - # print(f"hidden_states.size(): {hidden_states.size()}") - - # batch_size, seq_len = input_data.size() - # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) - - # position_embeddings = self.full_model.model.rotary_emb( - # hidden_states, - # position_ids - # ) - - # if DEBUG >= 2: - # print(f"embedded hidden_states {hidden_states}") - # print(f"position_ids: {position_embeddings}") - # Forward pass through the layer if DEBUG >= 2: print(f"\n[layer model] {self.full_model.model}") print(f"hidden_states {hidden_states}") - print(f"past_kvs {past_kvs}") + # print(f"past_kvs {past_kvs}") layer_outputs = self.full_model.model( hidden_states, @@ -112,9 +79,6 @@ def forward_layers( hidden_states = layer_outputs[0] present_kvs = layer_outputs[1] - if DEBUG >= 2: - print(f"present_kvs {present_kvs}") - print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) From 73fc9bf4f083626ce243d45793e33bb77888bb26 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:46:38 -0800 Subject: [PATCH 252/491] layer test --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7f88db04..74b0cf4f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -76,8 +76,8 @@ def forward_layers( if DEBUG >= 2: print(f"\nlayer_outputs: {layer_outputs}") - hidden_states = layer_outputs[0] - present_kvs = layer_outputs[1] + hidden_states = layer_outputs.last_hidden_state + present_kvs = layer_outputs.past_key_values print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): From 2cdf81305d709566d05aa7308473d804e0cc7daf Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 14:51:51 -0800 Subject: [PATCH 253/491] layer test --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 74b0cf4f..37148711 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -73,7 +73,7 @@ def forward_layers( use_cache=True ) - if DEBUG >= 2: + if DEBUG >= 4: print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs.last_hidden_state From 0b3ec2cc7d33a8dab4197862b7f1b8b831922192 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 15:01:17 -0800 Subject: [PATCH 254/491] layer test --- exo/inference/pytorch/inference.py | 4 ++-- exo/inference/pytorch/model/hf.py | 7 ++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9a52bafd..98fe9cc1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -106,7 +106,7 @@ async def infer_tensor( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - print(f"inference_state: {inference_state}") + print(f"inference_state: {inference_state.size()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") @@ -114,7 +114,7 @@ async def infer_tensor( return ( output_data, - json.dumps(inference_state), + json.dumps(inference_state.cpu().numpy()), is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 37148711..d15c85bd 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -27,13 +27,10 @@ def __init__(self, shard: Shard): self.config.num_hidden_layers = 2 # Load the model - self.full_model = AutoModelForCausalLM.from_pretrained( + self.full_model = AutoModelForCausalLM(self.config).from_pretrained( shard.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto", - config={ - "use_cache" - } + device_map="auto" ) # Embeddings and final layer norm From fb1e0a6bfb09b06adb2927e50d804e2ac1112884 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 15:04:36 -0800 Subject: [PATCH 255/491] layer test --- exo/inference/pytorch/model/hf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d15c85bd..60340edf 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -27,12 +27,14 @@ def __init__(self, shard: Shard): self.config.num_hidden_layers = 2 # Load the model - self.full_model = AutoModelForCausalLM(self.config).from_pretrained( + self.full_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" ) + self.full_model.config = self.config + # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers self.embed_tokens = self.full_model.model.embed_tokens From f662ad39a30064bbd1e8e10e59d2c59c6c4b2c41 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 15:08:10 -0800 Subject: [PATCH 256/491] layer test --- exo/inference/pytorch/inference.py | 8 ++++---- exo/inference/pytorch/model/utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 98fe9cc1..7540cbaa 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -41,7 +41,7 @@ async def infer_prompt( await self.ensure_shard(shard) - inference_state = json.loads(inference_state) if inference_state else "" + inference_state = json.loads(torch.tensor(inference_state)) if inference_state else "" tokens = self.tokenizer.encode(prompt, return_tensors="pt") if DEBUG >= 2: @@ -62,7 +62,7 @@ async def infer_prompt( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - print(f"inference_state: {inference_state}") + print(f"inference_state: {inference_state.size()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") @@ -70,7 +70,7 @@ async def infer_prompt( return ( output_data, - json.dumps(inference_state), + json.dumps(inference_state.cpu().numpy()), is_finished ) @@ -82,7 +82,7 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) - inference_state = json.loads(inference_state) if inference_state else "" + inference_state = json.loads(torch.tensor(inference_state)) if inference_state else "" # Ensure input_data is 2D: [batch_size, seq_len] if in_tensor.dim() == 1: diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index b217ea49..581a94b5 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=1.0, top_k=50, top_p=0.95, alpha_f=0.5, alpha_p=0.0): +def sample_logits(logits, temp=1.0, top_k=20, top_p=0.95, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 1b27532d5edda06aa24869e435a49bb33a60d822 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:00:36 -0800 Subject: [PATCH 257/491] layer test --- exo/inference/pytorch/inference.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 7540cbaa..8b3a73f9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -10,6 +10,7 @@ from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG +from transformers import DynamicCache class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ @@ -41,7 +42,14 @@ async def infer_prompt( await self.ensure_shard(shard) - inference_state = json.loads(torch.tensor(inference_state)) if inference_state else "" + # need to make this so inference_state is not a string + if inference_state: + inference_state = DynamicCache.from_legacy_cache( + json.loads(torch.tensor(inference_state)) + ) + else: + inference_state = DynamicCache() + tokens = self.tokenizer.encode(prompt, return_tensors="pt") if DEBUG >= 2: From 0876c79be0823bf315da06ee60638f652313aab6 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:02:48 -0800 Subject: [PATCH 258/491] layer test --- exo/inference/pytorch/inference.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8b3a73f9..9b881353 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -70,7 +70,7 @@ async def infer_prompt( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - print(f"inference_state: {inference_state.size()}") + print(f"inference_state.get_max_length(): {inference_state.get_max_length()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") @@ -90,7 +90,13 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) - inference_state = json.loads(torch.tensor(inference_state)) if inference_state else "" + + if inference_state: + inference_state = DynamicCache.from_legacy_cache( + json.loads(torch.tensor(inference_state)) + ) + else: + inference_state = DynamicCache() # Ensure input_data is 2D: [batch_size, seq_len] if in_tensor.dim() == 1: @@ -114,7 +120,7 @@ async def infer_tensor( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - print(f"inference_state: {inference_state.size()}") + print(f"inference_state.get_max_length(): {inference_state.get_max_length()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") From e5c56c4a4ee354c8ce06e2424cf9f1844054450c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:12:40 -0800 Subject: [PATCH 259/491] layer test --- exo/inference/pytorch/inference.py | 34 ++++++++---------------------- exo/inference/pytorch/model/hf.py | 14 ++++++------ 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9b881353..a8c72f27 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -43,21 +43,15 @@ async def infer_prompt( await self.ensure_shard(shard) # need to make this so inference_state is not a string - if inference_state: - inference_state = DynamicCache.from_legacy_cache( - json.loads(torch.tensor(inference_state)) - ) - else: - inference_state = DynamicCache() - + # cant use it with dynamic cache + tokens = self.tokenizer.encode(prompt, return_tensors="pt") if DEBUG >= 2: print(f"tokens: {tokens}\n") - output_data, inference_state = self.model.forward_layers( - tokens, - inference_state + output_data = self.model.forward_layers( + tokens ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] @@ -70,7 +64,6 @@ async def infer_prompt( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - print(f"inference_state.get_max_length(): {inference_state.get_max_length()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") @@ -78,7 +71,7 @@ async def infer_prompt( return ( output_data, - json.dumps(inference_state.cpu().numpy()), + "", is_finished ) @@ -90,14 +83,7 @@ async def infer_tensor( inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: in_tensor = torch.tensor(input_data) - - if inference_state: - inference_state = DynamicCache.from_legacy_cache( - json.loads(torch.tensor(inference_state)) - ) - else: - inference_state = DynamicCache() - + # Ensure input_data is 2D: [batch_size, seq_len] if in_tensor.dim() == 1: in_tensor = in_tensor.unsqueeze(0) # Add a batch dimension: [1, seq_len] @@ -109,9 +95,8 @@ async def infer_tensor( await self.ensure_shard(shard) - output_data, inference_state = self.model.forward_layers( - in_tensor, - inference_state + output_data = self.model.forward_layers( + in_tensor ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] @@ -120,7 +105,6 @@ async def infer_tensor( print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"output_data.item() {output_data.item()}") - print(f"inference_state.get_max_length(): {inference_state.get_max_length()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") @@ -128,7 +112,7 @@ async def infer_tensor( return ( output_data, - json.dumps(inference_state.cpu().numpy()), + "", is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 60340edf..ca1c6416 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -40,10 +40,11 @@ def __init__(self, shard: Shard): self.embed_tokens = self.full_model.model.embed_tokens self.norm = self.full_model.model.norm + self.past_key_values = DynamicCache() + def forward_layers( self, - input_data: torch.tensor, - past_kvs: Cache = DynamicCache() + input_data: torch.tensor ) -> Tuple[np.ndarray, list]: """ Forward pass through the specified layers. @@ -56,7 +57,6 @@ def forward_layers( print(f"shard {self.shard.to_dict()}") hidden_states = input_data - present_kvs = DynamicCache() # Forward pass through the layer if DEBUG >= 2: @@ -68,7 +68,7 @@ def forward_layers( hidden_states, # position_ids=position_ids, # inputs_embeds=position_embeddings, - past_key_values=past_kvs, + past_key_values=self.past_key_values, use_cache=True ) @@ -76,7 +76,7 @@ def forward_layers( print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs.last_hidden_state - present_kvs = layer_outputs.past_key_values + self.past_key_values = layer_outputs.past_key_values print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): @@ -92,6 +92,6 @@ def forward_layers( print(f"hs_lm_head: {hs_lm_head}") print(f"output_token: {output_token}") - return (output_token, present_kvs) + return output_token - return (hidden_states.cpu().numpy(), present_kvs) + return hidden_states.cpu().numpy() From 696c3bb55dc95f15ad0e34b235e7bfe7cf574414 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:16:11 -0800 Subject: [PATCH 260/491] layer test --- exo/inference/pytorch/model/hf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index ca1c6416..d0108c5b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -40,7 +40,7 @@ def __init__(self, shard: Shard): self.embed_tokens = self.full_model.model.embed_tokens self.norm = self.full_model.model.norm - self.past_key_values = DynamicCache() + # self.past_key_values = DynamicCache() def forward_layers( self, @@ -68,15 +68,15 @@ def forward_layers( hidden_states, # position_ids=position_ids, # inputs_embeds=position_embeddings, - past_key_values=self.past_key_values, - use_cache=True + # past_key_values=self.past_key_values, + # use_cache=True # not enough vram for using cache ;_; ) - if DEBUG >= 4: + if DEBUG >= 2: print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs.last_hidden_state - self.past_key_values = layer_outputs.past_key_values + # self.past_key_values = layer_outputs.past_key_values print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): From e807a6376907b10c4de88c179a8c29c945808623 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:16:34 -0800 Subject: [PATCH 261/491] layer test --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 581a94b5..2a80260a 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=1.0, top_k=20, top_p=0.95, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.6, top_k=20, top_p=0.95, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 0b589ee2c4c3a8b42a5f576cc6de0f774a97137c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:17:48 -0800 Subject: [PATCH 262/491] layer test --- exo/inference/pytorch/model/hf.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d0108c5b..cd81588e 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -18,13 +18,7 @@ def __init__(self, shard: Shard): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - # Load the model with the configuration for caching - self.config = LlamaConfig.from_pretrained(shard.model_id) - self.config.use_cache = True # Enable caching - - # Extract only the layers for this shard - # get layers up to end layer - self.config.num_hidden_layers = 2 + # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( @@ -33,7 +27,15 @@ def __init__(self, shard: Shard): device_map="auto" ) - self.full_model.config = self.config + # set model config to restrict layers and enable caching + self.config = LlamaConfig.from_pretrained(shard.model_id) + # self.config.use_cache = True # Enable caching + + # Extract only the layers for this shard + # get layers up to end layer + self.config.num_hidden_layers = 2 + + self.full_model.model.config = self.config # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers From ced3879daf6eb6bd185ea72a46e9ee2ba441ca61 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:25:10 -0800 Subject: [PATCH 263/491] layer test --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 2a80260a..1073b82f 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.6, top_k=20, top_p=0.95, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.0, top_k=15, top_p=0.8, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From c6693a8673ea2aaa6449493c0503e6cd68c29912 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:25:18 -0800 Subject: [PATCH 264/491] layer test --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index cd81588e..258daf5b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -34,8 +34,9 @@ def __init__(self, shard: Shard): # Extract only the layers for this shard # get layers up to end layer self.config.num_hidden_layers = 2 - self.full_model.model.config = self.config + if DEBUG >= 2: + print(f"full_model.model layer: {len(self.full_model.layers)}") # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers From 30e971dccf569a275a7f4f7472f9bec443fd7a61 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:26:10 -0800 Subject: [PATCH 265/491] layer test --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 258daf5b..4a04e132 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -36,7 +36,7 @@ def __init__(self, shard: Shard): self.config.num_hidden_layers = 2 self.full_model.model.config = self.config if DEBUG >= 2: - print(f"full_model.model layer: {len(self.full_model.layers)}") + print(f"full_model.model layer: {len(self.full_model.model.layers)}") # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers From 8b4e62492df516de2814446162b8f9b6ce5b2ffc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:33:21 -0800 Subject: [PATCH 266/491] fixing layer issue --- exo/inference/pytorch/model/hf.py | 20 ++++++++++++-------- exo/inference/pytorch/model/utils.py | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 4a04e132..8a769c4e 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -27,14 +27,18 @@ def __init__(self, shard: Shard): device_map="auto" ) - # set model config to restrict layers and enable caching - self.config = LlamaConfig.from_pretrained(shard.model_id) - # self.config.use_cache = True # Enable caching - - # Extract only the layers for this shard - # get layers up to end layer - self.config.num_hidden_layers = 2 - self.full_model.model.config = self.config + # using llamaconfig not working setting layers manually + layers = [] + for i in range(shard.start_layer, shard.end_layer + 1): + layer = self.full_model.model.layers[i] + + if DEBUG >= 2: + print(f"Loading layers[{i}]") + + layers.append(layer) + + self.full_model.model.layer = layers + if DEBUG >= 2: print(f"full_model.model layer: {len(self.full_model.model.layers)}") diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 1073b82f..9f034ccc 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.0, top_k=15, top_p=0.8, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.1, top_k=15, top_p=0.8, alpha_f=0.2, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From bcb499e86962492cbf8e842fc4671456359fa549 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:35:24 -0800 Subject: [PATCH 267/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 2 +- exo/inference/pytorch/model/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8a769c4e..5bf3f19b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -29,7 +29,7 @@ def __init__(self, shard: Shard): # using llamaconfig not working setting layers manually layers = [] - for i in range(shard.start_layer, shard.end_layer + 1): + for i in range(shard.start_layer, 2): layer = self.full_model.model.layers[i] if DEBUG >= 2: diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 9f034ccc..ecc5c6a7 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.1, top_k=15, top_p=0.8, alpha_f=0.2, alpha_p=0.0): +def sample_logits(logits, temp=0.01, top_k=15, top_p=0.8, alpha_f=0.2, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 724c6c437d730c7c88cf71f648c1cc4a677dd06c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:40:42 -0800 Subject: [PATCH 268/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 5bf3f19b..2c47ee61 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -37,7 +37,7 @@ def __init__(self, shard: Shard): layers.append(layer) - self.full_model.model.layer = layers + self.full_model.model.layers = layers if DEBUG >= 2: print(f"full_model.model layer: {len(self.full_model.model.layers)}") From e23f3f7df1aee472fc09819cb5b3055febd434a1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:44:32 -0800 Subject: [PATCH 269/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2c47ee61..0a1153a3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,4 +1,5 @@ import torch +import torch.nn as nn import numpy as np from transformers import AutoModelForCausalLM, LlamaConfig, DynamicCache, Cache @@ -37,7 +38,7 @@ def __init__(self, shard: Shard): layers.append(layer) - self.full_model.model.layers = layers + self.full_model.model.layers = nn.ModuleList(layers) if DEBUG >= 2: print(f"full_model.model layer: {len(self.full_model.model.layers)}") From 7f13a6d7c630f9a74a9ef854fb819c2ab60ffd08 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:49:00 -0800 Subject: [PATCH 270/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0a1153a3..f85ce70b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -30,7 +30,7 @@ def __init__(self, shard: Shard): # using llamaconfig not working setting layers manually layers = [] - for i in range(shard.start_layer, 2): + for i in range(5, 10): layer = self.full_model.model.layers[i] if DEBUG >= 2: @@ -69,7 +69,7 @@ def forward_layers( # Forward pass through the layer if DEBUG >= 2: print(f"\n[layer model] {self.full_model.model}") - print(f"hidden_states {hidden_states}") + print(f"IN hidden_states {hidden_states}") # print(f"past_kvs {past_kvs}") layer_outputs = self.full_model.model( @@ -81,7 +81,8 @@ def forward_layers( ) if DEBUG >= 2: - print(f"\nlayer_outputs: {layer_outputs}") + print(f"OUT hidden_states {hidden_states}") + # print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs.last_hidden_state # self.past_key_values = layer_outputs.past_key_values From ec92328d5e8c82bff75e2f1f0f4cda031fea5e4d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:52:46 -0800 Subject: [PATCH 271/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index f85ce70b..9343c70c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,6 +74,7 @@ def forward_layers( layer_outputs = self.full_model.model( hidden_states, + past_key_values=None # position_ids=position_ids, # inputs_embeds=position_embeddings, # past_key_values=self.past_key_values, From fc3d2248be6401890b45a706c46e8ce1cb20c743 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 16:58:49 -0800 Subject: [PATCH 272/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 9343c70c..28ecd3c0 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,7 +74,7 @@ def forward_layers( layer_outputs = self.full_model.model( hidden_states, - past_key_values=None + layer_idx=5 # position_ids=position_ids, # inputs_embeds=position_embeddings, # past_key_values=self.past_key_values, From f14a3397b3d2ba8a37d062808706aa6f1193c688 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:00:02 -0800 Subject: [PATCH 273/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 28ecd3c0..755a2fe8 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -72,9 +72,9 @@ def forward_layers( print(f"IN hidden_states {hidden_states}") # print(f"past_kvs {past_kvs}") + self.full_model.model.layer_idx = 5 layer_outputs = self.full_model.model( - hidden_states, - layer_idx=5 + hidden_states # position_ids=position_ids, # inputs_embeds=position_embeddings, # past_key_values=self.past_key_values, From 4f4a9e11f80846cadfdcaf6666241329ec36162c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:07:39 -0800 Subject: [PATCH 274/491] temp and layer test --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 755a2fe8..f8964a70 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -74,11 +74,11 @@ def forward_layers( self.full_model.model.layer_idx = 5 layer_outputs = self.full_model.model( - hidden_states + hidden_states, # position_ids=position_ids, # inputs_embeds=position_embeddings, # past_key_values=self.past_key_values, - # use_cache=True # not enough vram for using cache ;_; + use_cache=False # not enough vram for using cache ;_; ) if DEBUG >= 2: From 3da44b3c70785131102c451eaf00330d0089e4a8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:09:28 -0800 Subject: [PATCH 275/491] change temp --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index ecc5c6a7..d25de200 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.01, top_k=15, top_p=0.8, alpha_f=0.2, alpha_p=0.0): +def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.2, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 0a4a0038f36b1abe52c41282888f1e812d3f4435 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:10:54 -0800 Subject: [PATCH 276/491] change temp and alpha --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index d25de200..e47c50df 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.2, alpha_p=0.0): +def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.0, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From e351501f33fd0a4ff474068da77ced77296dce79 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:15:12 -0800 Subject: [PATCH 277/491] change temp and alpha --- exo/inference/pytorch/model/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index e47c50df..510b951f 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.0, alpha_p=0.0): +def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. @@ -68,7 +68,7 @@ def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.0, alpha_p= logits = top_p_sampling(logits, top_p, temp) # Apply alpha sampling to discourage repetition - if alpha_f or alpha_p: + if alpha_f > 0.0 or alpha_p > 0.0: if not hasattr(sample_logits, "alpha_counter"): sample_logits.alpha_counter = torch.zeros_like(logits, dtype=torch.int32) logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) @@ -78,7 +78,7 @@ def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.0, alpha_p= sampled_token = torch.multinomial(probabilities, 1) # Update alpha counter - if alpha_f or alpha_p: + if alpha_f > 0.0 or alpha_p > 0.0: sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) return sampled_token.squeeze() From 325156741e55848681655b0f19ff856d41353d5f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:16:25 -0800 Subject: [PATCH 278/491] change temp --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 510b951f..b043a45b 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.001, top_k=15, top_p=0.8, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.0, top_k=15, top_p=0.8, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 16e4f7ec5f33ddc7f27802e509c3fe4cce87ce1e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:17:35 -0800 Subject: [PATCH 279/491] change temp --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index b043a45b..b8426e51 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -37,7 +37,7 @@ def top_p_sampling(logits, top_p: float, temperature: float = 1.0): # Convert to original index order return sorted_indices.gather(-1, sampled_token) -def sample_logits(logits, temp=0.0, top_k=15, top_p=0.8, alpha_f=0.1, alpha_p=0.0): +def sample_logits(logits, temp=0.3, top_k=85, top_p=2.0, alpha_f=0.1, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. From 608392765854c5986b772b24a47286be384f5ac2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:19:55 -0800 Subject: [PATCH 280/491] change temp --- exo/inference/pytorch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index f8964a70..bbfa0d90 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -30,7 +30,7 @@ def __init__(self, shard: Shard): # using llamaconfig not working setting layers manually layers = [] - for i in range(5, 10): + for i in range(shard.start_layer, shard.end_layer + 1): layer = self.full_model.model.layers[i] if DEBUG >= 2: From 5b02fd1d2f621dce1bd59e95ef7fb4151d80cae4 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:30:31 -0800 Subject: [PATCH 281/491] change sampling --- exo/inference/pytorch/model/utils.py | 142 ++++++++++++++++++--------- 1 file changed, 94 insertions(+), 48 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index b8426e51..b2ea5847 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -1,60 +1,108 @@ import torch from torch.nn import functional as F -def top_p_sampling(logits, top_p: float, temperature: float = 1.0): - """ - Perform top-p sampling (nucleus sampling) on logits. +# def top_p_sampling(logits, top_p: float, temperature: float = 1.0): +# """ +# Perform top-p sampling (nucleus sampling) on logits. - Args: - logits (torch.Tensor): The logits distribution to sample from. - top_p (float): The cumulative probability threshold for nucleus sampling. - temperature (float): Sampling temperature. +# Args: +# logits (torch.Tensor): The logits distribution to sample from. +# top_p (float): The cumulative probability threshold for nucleus sampling. +# temperature (float): Sampling temperature. - Returns: - torch.Tensor: The selected token indices. - """ - # Apply temperature scaling - logits = logits/temperature +# Returns: +# torch.Tensor: The selected token indices. +# """ +# # Apply temperature scaling +# logits = logits/temperature - # Sort the logits in descending order - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) +# # Sort the logits in descending order +# sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - # Calculate cumulative probabilities - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) +# # Calculate cumulative probabilities +# cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - # Create a mask to remove logits with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 +# # Create a mask to remove logits with cumulative probability above the threshold +# sorted_indices_to_remove = cumulative_probs > top_p +# sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() +# sorted_indices_to_remove[..., 0] = 0 - # Mask the logits - sorted_logits[sorted_indices_to_remove] = -float('Inf') +# # Mask the logits +# sorted_logits[sorted_indices_to_remove] = -float('Inf') - # Sample from the filtered distribution - probabilities = torch.softmax(sorted_logits, dim=-1) - sampled_token = torch.multinomial(probabilities, 1) +# # Sample from the filtered distribution +# probabilities = torch.softmax(sorted_logits, dim=-1) +# sampled_token = torch.multinomial(probabilities, 1) - # Convert to original index order - return sorted_indices.gather(-1, sampled_token) +# # Convert to original index order +# return sorted_indices.gather(-1, sampled_token) + +# def sample_logits(logits, temp=0.3, top_k=85, top_p=2.0, alpha_f=0.1, alpha_p=0.0): +# """ +# Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. + +# Args: +# logits (torch.Tensor): The logits distribution to sample from. +# temp (float): Temperature for scaling logits. +# top_k (int): The number of top tokens to consider for sampling. +# top_p (float): The cumulative probability threshold for nucleus sampling. +# alpha_f (float): Penalty factor for repetition. +# alpha_p (float): Penalty for selecting already selected tokens. + +# Returns: +# torch.Tensor: The selected token indices. +# """ +# # Return argmax for deterministic output at low temperature +# if temp < 1e-6: +# return logits.argmax(dim=-1) + +# # Apply Top-k sampling if specified +# if top_k > 0: +# top_k = min(top_k, logits.size(-1)) +# top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) +# logits = torch.full_like(logits, float('-inf')) +# logits.scatter_(-1, top_k_indices, top_k_values) -def sample_logits(logits, temp=0.3, top_k=85, top_p=2.0, alpha_f=0.1, alpha_p=0.0): +# # Apply Top-p (nucleus) sampling if specified +# if 0 < top_p < 1.0: +# logits = top_p_sampling(logits, top_p, temp) + +# # Apply alpha sampling to discourage repetition +# if alpha_f > 0.0 or alpha_p > 0.0: +# if not hasattr(sample_logits, "alpha_counter"): +# sample_logits.alpha_counter = torch.zeros_like(logits, dtype=torch.int32) +# logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) + +# # Sample from the logits +# probabilities = F.softmax(logits, dim=-1) +# sampled_token = torch.multinomial(probabilities, 1) + +# # Update alpha counter +# if alpha_f > 0.0 or alpha_p > 0.0: +# sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) + +# return sampled_token.squeeze() + +def sample_logits(logits, temperature=1.0, top_k=50, top_p=0.95): """ - Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. + Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. Args: logits (torch.Tensor): The logits distribution to sample from. - temp (float): Temperature for scaling logits. + temperature (float): Temperature for scaling logits. top_k (int): The number of top tokens to consider for sampling. top_p (float): The cumulative probability threshold for nucleus sampling. - alpha_f (float): Penalty factor for repetition. - alpha_p (float): Penalty for selecting already selected tokens. Returns: - torch.Tensor: The selected token indices. + torch.Tensor: The selected token index. """ - # Return argmax for deterministic output at low temperature - if temp < 1e-6: - return logits.argmax(dim=-1) + + # Ensure logits are in a floating-point format + logits = logits.float() + + # Apply temperature scaling + if temperature != 1.0: + logits = logits / temperature # Apply Top-k sampling if specified if top_k > 0: @@ -65,21 +113,19 @@ def sample_logits(logits, temp=0.3, top_k=85, top_p=2.0, alpha_f=0.1, alpha_p=0. # Apply Top-p (nucleus) sampling if specified if 0 < top_p < 1.0: - logits = top_p_sampling(logits, top_p, temp) + sorted_logits, _ = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 - # Apply alpha sampling to discourage repetition - if alpha_f > 0.0 or alpha_p > 0.0: - if not hasattr(sample_logits, "alpha_counter"): - sample_logits.alpha_counter = torch.zeros_like(logits, dtype=torch.int32) - logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) + sorted_logits[sorted_indices_to_remove] = -float('Inf') + logits = sorted_logits # Sample from the logits probabilities = F.softmax(logits, dim=-1) sampled_token = torch.multinomial(probabilities, 1) - # Update alpha counter - if alpha_f > 0.0 or alpha_p > 0.0: - sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) - - return sampled_token.squeeze() - + return sampled_token.squeeze() \ No newline at end of file From 9805ac2ca229c49ac1b35b6c7ca2bd5d6a2a4e88 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:33:21 -0800 Subject: [PATCH 282/491] change sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index b2ea5847..64f30bd2 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -83,7 +83,7 @@ # return sampled_token.squeeze() -def sample_logits(logits, temperature=1.0, top_k=50, top_p=0.95): +def sample_logits(logits, temperature=1.0, top_k=15, top_p=0.95): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. From 8da3114b36911bc460fd62d00589918052fc2278 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:35:02 -0800 Subject: [PATCH 283/491] change sampling --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 64f30bd2..163e61fb 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -83,7 +83,7 @@ # return sampled_token.squeeze() -def sample_logits(logits, temperature=1.0, top_k=15, top_p=0.95): +def sample_logits(logits, temperature=0.6, top_k=15, top_p=0.55): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. From c62dd2d6ce589e516b0defe76cb08a4110b19e37 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:37:50 -0800 Subject: [PATCH 284/491] change sampling --- exo/inference/pytorch/model/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 163e61fb..54fd6245 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -83,7 +83,7 @@ # return sampled_token.squeeze() -def sample_logits(logits, temperature=0.6, top_k=15, top_p=0.55): +def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. @@ -102,7 +102,7 @@ def sample_logits(logits, temperature=0.6, top_k=15, top_p=0.55): # Apply temperature scaling if temperature != 1.0: - logits = logits / temperature + logits = logits * (1 / temperature) # Apply Top-k sampling if specified if top_k > 0: From 0f7f96dc848aa81edbef59641d0679ecfb7adb72 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:41:01 -0800 Subject: [PATCH 285/491] change sampling --- exo/inference/pytorch/model/utils.py | 46 +++++++++++++++------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 54fd6245..4cb704d2 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -101,28 +101,30 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): logits = logits.float() # Apply temperature scaling - if temperature != 1.0: - logits = logits * (1 / temperature) - - # Apply Top-k sampling if specified - if top_k > 0: - top_k = min(top_k, logits.size(-1)) - top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) - logits = torch.full_like(logits, float('-inf')) - logits.scatter_(-1, top_k_indices, top_k_values) - - # Apply Top-p (nucleus) sampling if specified - if 0 < top_p < 1.0: - sorted_logits, _ = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - sorted_logits[sorted_indices_to_remove] = -float('Inf') - logits = sorted_logits + if temperature == 0: + logits = logits.argmax(dim=-1) + else: + # Apply Top-k sampling if specified + if top_k > 0: + top_k = min(top_k, logits.size(-1)) + top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) + logits = torch.full_like(logits, float('-inf')) + logits.scatter_(-1, top_k_indices, top_k_values) + + # Apply Top-p (nucleus) sampling if specified + if 0 < top_p < 1.0: + sorted_logits, _ = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + sorted_logits[sorted_indices_to_remove] = -float('Inf') + logits = sorted_logits + else: + logits = logits * (1/temperature) # Sample from the logits probabilities = F.softmax(logits, dim=-1) From fc3661939359cd9641016c72bdb35c6b218585a9 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:43:22 -0800 Subject: [PATCH 286/491] change sampling --- exo/inference/pytorch/model/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 4cb704d2..ab305523 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -97,8 +97,8 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): torch.Tensor: The selected token index. """ - # Ensure logits are in a floating-point format - logits = logits.float() + # Ensure logits are long + logits = logits.long() # Apply temperature scaling if temperature == 0: From b5f98d58dc4ea09b71a29c84501ea845d540651f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:46:27 -0800 Subject: [PATCH 287/491] remove softmax --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index ab305523..89a4522b 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -127,7 +127,7 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): logits = logits * (1/temperature) # Sample from the logits - probabilities = F.softmax(logits, dim=-1) + # probabilities = F.softmax(logits, dim=-1) sampled_token = torch.multinomial(probabilities, 1) return sampled_token.squeeze() \ No newline at end of file From 52d608f372c899446690797d141597b67a98c4e5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:47:31 -0800 Subject: [PATCH 288/491] remove softmax --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 89a4522b..b681a40b 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -128,6 +128,6 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): # Sample from the logits # probabilities = F.softmax(logits, dim=-1) - sampled_token = torch.multinomial(probabilities, 1) + sampled_token = torch.multinomial(logits, 1) return sampled_token.squeeze() \ No newline at end of file From b17a9ab39d55f33eadb631c0022cd748f15db98a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:48:54 -0800 Subject: [PATCH 289/491] float long issue --- exo/inference/pytorch/model/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index b681a40b..7b1a0735 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -97,8 +97,8 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): torch.Tensor: The selected token index. """ - # Ensure logits are long - logits = logits.long() + # Ensure logits are float + logits = logits.float() # Apply temperature scaling if temperature == 0: From 69552e050897c152e6d3471713d21624bcf30449 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:50:13 -0800 Subject: [PATCH 290/491] float long issue --- exo/inference/pytorch/model/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 7b1a0735..db2d890e 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -101,8 +101,8 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): logits = logits.float() # Apply temperature scaling - if temperature == 0: - logits = logits.argmax(dim=-1) + if temperature == 0.0: + logits = logits.argmax(dim=-1).float() else: # Apply Top-k sampling if specified if top_k > 0: From 1ee8a10371b404af348396ff00d616a92f2817e3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 17:52:37 -0800 Subject: [PATCH 291/491] float long issue --- exo/inference/pytorch/model/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index db2d890e..07b228c6 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -127,7 +127,7 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): logits = logits * (1/temperature) # Sample from the logits - # probabilities = F.softmax(logits, dim=-1) - sampled_token = torch.multinomial(logits, 1) + probabilities = F.softmax(logits, dim=-1) + sampled_token = torch.multinomial(probabilities, 1) return sampled_token.squeeze() \ No newline at end of file From 1d9f48286cf7db7e0019ae1da314da12bde5706c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 18:00:11 -0800 Subject: [PATCH 292/491] float long issue --- exo/inference/pytorch/model/utils.py | 68 ++++++++++++++++------------ 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 07b228c6..e17c0fb1 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -83,7 +83,7 @@ # return sampled_token.squeeze() -def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): +def sample_logits(ogits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. @@ -92,42 +92,54 @@ def sample_logits(logits, temperature=0.0, top_k=0, top_p=1.0): temperature (float): Temperature for scaling logits. top_k (int): The number of top tokens to consider for sampling. top_p (float): The cumulative probability threshold for nucleus sampling. + alpha_f (float): Penalty factor for repetition frequency. + alpha_p (float): Penalty for repeated selection. Returns: torch.Tensor: The selected token index. """ - # Ensure logits are float + # Ensure logits are float logits = logits.float() + # If temperature is very low, just use argmax + if temperature < 1e-6: + return logits.argmax(dim=-1) + + # Alpha sampling (adjusting logits based on past selections) + if alpha_f > 0.0 or alpha_p > 0.0: + logits -= (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) + + # Replace NaNs with -inf to prevent softmax issues + logits = torch.where(torch.isnan(logits), torch.full_like(logits, -float('inf')), logits) + # Apply temperature scaling - if temperature == 0.0: - logits = logits.argmax(dim=-1).float() - else: - # Apply Top-k sampling if specified - if top_k > 0: - top_k = min(top_k, logits.size(-1)) - top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) - logits = torch.full_like(logits, float('-inf')) - logits.scatter_(-1, top_k_indices, top_k_values) - - # Apply Top-p (nucleus) sampling if specified - if 0 < top_p < 1.0: - sorted_logits, _ = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - sorted_logits[sorted_indices_to_remove] = -float('Inf') - logits = sorted_logits - else: - logits = logits * (1/temperature) - - # Sample from the logits + logits = logits / temperature + + # Top-k sampling + if top_k > 0: + top_k = min(top_k, logits.size(-1)) + top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) + logits = torch.full_like(logits, -float('inf')) + logits.scatter_(-1, top_k_indices, top_k_values) + + # Top-p sampling + if 0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + sorted_logits[sorted_indices_to_remove] = -float('inf') + logits = sorted_logits + + # Apply softmax to get probabilities probabilities = F.softmax(logits, dim=-1) + + # Sample from the probabilities sampled_token = torch.multinomial(probabilities, 1) return sampled_token.squeeze() \ No newline at end of file From 2ca9689b014297776b77c44505d05c49d5c0d83e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 18:01:07 -0800 Subject: [PATCH 293/491] float long issue --- exo/inference/pytorch/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index e17c0fb1..8ab499fc 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -83,7 +83,7 @@ # return sampled_token.squeeze() -def sample_logits(ogits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0): +def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. From 0b8c9f2d6c01e51b5bf524dd4db10d5522c78d33 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 18:28:14 -0800 Subject: [PATCH 294/491] cleaning up utils.py --- exo/inference/pytorch/model/utils.py | 82 ---------------------------- 1 file changed, 82 deletions(-) diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index 8ab499fc..d56be5d8 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -1,88 +1,6 @@ import torch from torch.nn import functional as F -# def top_p_sampling(logits, top_p: float, temperature: float = 1.0): -# """ -# Perform top-p sampling (nucleus sampling) on logits. - -# Args: -# logits (torch.Tensor): The logits distribution to sample from. -# top_p (float): The cumulative probability threshold for nucleus sampling. -# temperature (float): Sampling temperature. - -# Returns: -# torch.Tensor: The selected token indices. -# """ -# # Apply temperature scaling -# logits = logits/temperature - -# # Sort the logits in descending order -# sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - -# # Calculate cumulative probabilities -# cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - -# # Create a mask to remove logits with cumulative probability above the threshold -# sorted_indices_to_remove = cumulative_probs > top_p -# sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() -# sorted_indices_to_remove[..., 0] = 0 - -# # Mask the logits -# sorted_logits[sorted_indices_to_remove] = -float('Inf') - -# # Sample from the filtered distribution -# probabilities = torch.softmax(sorted_logits, dim=-1) -# sampled_token = torch.multinomial(probabilities, 1) - -# # Convert to original index order -# return sorted_indices.gather(-1, sampled_token) - -# def sample_logits(logits, temp=0.3, top_k=85, top_p=2.0, alpha_f=0.1, alpha_p=0.0): -# """ -# Sample tokens from logits using temperature, top-k, top-p, and alpha sampling. - -# Args: -# logits (torch.Tensor): The logits distribution to sample from. -# temp (float): Temperature for scaling logits. -# top_k (int): The number of top tokens to consider for sampling. -# top_p (float): The cumulative probability threshold for nucleus sampling. -# alpha_f (float): Penalty factor for repetition. -# alpha_p (float): Penalty for selecting already selected tokens. - -# Returns: -# torch.Tensor: The selected token indices. -# """ -# # Return argmax for deterministic output at low temperature -# if temp < 1e-6: -# return logits.argmax(dim=-1) - -# # Apply Top-k sampling if specified -# if top_k > 0: -# top_k = min(top_k, logits.size(-1)) -# top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) -# logits = torch.full_like(logits, float('-inf')) -# logits.scatter_(-1, top_k_indices, top_k_values) - -# # Apply Top-p (nucleus) sampling if specified -# if 0 < top_p < 1.0: -# logits = top_p_sampling(logits, top_p, temp) - -# # Apply alpha sampling to discourage repetition -# if alpha_f > 0.0 or alpha_p > 0.0: -# if not hasattr(sample_logits, "alpha_counter"): -# sample_logits.alpha_counter = torch.zeros_like(logits, dtype=torch.int32) -# logits = logits - (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) - -# # Sample from the logits -# probabilities = F.softmax(logits, dim=-1) -# sampled_token = torch.multinomial(probabilities, 1) - -# # Update alpha counter -# if alpha_f > 0.0 or alpha_p > 0.0: -# sample_logits.alpha_counter.scatter_(-1, sampled_token, sample_logits.alpha_counter.gather(-1, sampled_token) + 1) - -# return sampled_token.squeeze() - def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. From 94de83f59111793875f2045796fef8652139a52b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 10 Aug 2024 18:29:55 -0800 Subject: [PATCH 295/491] removing broken llama.py --- exo/inference/pytorch/model/llama.py | 56 ---------------------------- 1 file changed, 56 deletions(-) delete mode 100644 exo/inference/pytorch/model/llama.py diff --git a/exo/inference/pytorch/model/llama.py b/exo/inference/pytorch/model/llama.py deleted file mode 100644 index f6427e02..00000000 --- a/exo/inference/pytorch/model/llama.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import torch.nn as nn -from transformers.models.llama.modeling_llama import LlamaForCausalLM -from exo.inference.shard import Shard - -class ShardedLLAMAModel(nn.Module): - def __init__(self, model_path: str, shard: Shard): - super(ShardedLLAMAModel, self).__init__() - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.shard = shard - - # Load the full model - self.full_model = LlamaForCausalLM.from_pretrained(model_path) - self.full_model.to(self.device) - - # Extract only the layers for this shard - self.layers = nn.ModuleList([ - self.full_model.model.layers[i] for i in range(shard.start_layer, shard.end_layer + 1) - ]) - - # Embeddings and final layer norm - self.embed_tokens = self.full_model.model.embed_tokens - self.embed_positions = self.full_model.model.embed_positions - self.norm = self.full_model.model.norm - self.lm_head = self.full_model.lm_head - - def forward_layers(self, input_ids, past_key_values=None): - """ - Forward pass through the specified layers. - - Args: - input_ids (torch.Tensor): Input token IDs. - past_key_values (list, optional): Past key values for caching. - - Returns: - tuple: Hidden states and new past key values. - """ - if past_key_values is None: - past_key_values = [None] * len(self.layers) - - # Token and position embeddings - hidden_states = self.embed_tokens(input_ids) + self.embed_positions(input_ids) - - # Apply each layer in this shard - new_past_key_values = [] - for i, layer in enumerate(self.layers): - layer_past = past_key_values[i] - hidden_states, new_layer_past = layer(hidden_states, past_key_values=layer_past, use_cache=True) - new_past_key_values.append(new_layer_past) - - if self.shard.is_last_layer(): - hidden_states = self.norm(hidden_states) - logits = self.lm_head(hidden_states) - return logits, new_past_key_values - else: - return hidden_states, new_past_key_values From 226a0acd8149184d38e365fc7d7e49960a896f13 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 24 Aug 2024 18:23:39 -0800 Subject: [PATCH 296/491] removing unittest, update inference return type, fixing converting tensor to numpy --- exo/inference/pytorch/README.md | 18 ++++++++++++++++ exo/inference/pytorch/inference.py | 13 ++++++++---- exo/inference/pytorch/model/hf.py | 7 +++---- .../pytorch/test_build_transformer.py | 21 ------------------- .../pytorch/test_inference_engine.py | 6 ++---- 5 files changed, 32 insertions(+), 33 deletions(-) create mode 100644 exo/inference/pytorch/README.md delete mode 100644 exo/inference/pytorch/test_build_transformer.py diff --git a/exo/inference/pytorch/README.md b/exo/inference/pytorch/README.md new file mode 100644 index 00000000..8cb0ce07 --- /dev/null +++ b/exo/inference/pytorch/README.md @@ -0,0 +1,18 @@ +# PyTorch & HuggingFace inference engine +Experimental, still under development + + +## Install +Install needed py modules, make sure to be using CUDA 12.4 for the PyTorch install + +```console +$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 +$ pip install transformers accelerate +``` + +After installing accelerate you get hit with a dependency error, for now ignore until we can fix this as exo works fine with 1.26.4 + +```console +ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. +exo 0.0.1 requires numpy==2.0.0, but you have numpy 1.26.4 which is incompatible. +``` \ No newline at end of file diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a8c72f27..dd6434ca 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -63,11 +63,14 @@ async def infer_prompt( if DEBUG >= 2: print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") - print(f"output_data.item() {output_data.item()}") + print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") - print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + + if output_data.size == 1: + print(f"size 1 output_data.item() {output_data.item()}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") return ( output_data, @@ -104,11 +107,13 @@ async def infer_tensor( if DEBUG >= 2: print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") - print(f"output_data.item() {output_data.item()}") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") - print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + + if output_data.size == 1: + print(f"size 1 output_data.item() {output_data.item()}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") return ( output_data, diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index bbfa0d90..96f573b7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -19,8 +19,6 @@ def __init__(self, shard: Shard): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard - - # Load the model self.full_model = AutoModelForCausalLM.from_pretrained( shard.model_id, @@ -53,9 +51,10 @@ def __init__(self, shard: Shard): def forward_layers( self, input_data: torch.tensor - ) -> Tuple[np.ndarray, list]: + ) -> np.ndarray: """ Forward pass through the specified layers. + This is without caching Note: past_key_values not working for model, might be a library bug """ @@ -104,4 +103,4 @@ def forward_layers( return output_token - return hidden_states.cpu().numpy() + return hidden_states.cpu().detach().numpy() \ No newline at end of file diff --git a/exo/inference/pytorch/test_build_transformer.py b/exo/inference/pytorch/test_build_transformer.py deleted file mode 100644 index cdbfa6fc..00000000 --- a/exo/inference/pytorch/test_build_transformer.py +++ /dev/null @@ -1,21 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock -from pathlib import Path -import torch -from exo.inference.shard import Shard -from exo.inference.pytorch.helpers import build_transformer - -class TestBuildTransformer(unittest.TestCase): - - def test_build_transformer(self): - # Call the build_transformer function - model = build_transformer( - "gpt2", - quantize=True, - device="cuda" - ) - - self.assertIsNotNone(model) - -if __name__ == '__main__': - unittest.main() diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index fbc314f0..f1eaf31e 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -8,12 +8,10 @@ def main(): model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, - n_layers=12 + n_layers=32 ) - engine = PyTorchDynamicShardInferenceEngine( - shard - ) + engine = PyTorchDynamicShardInferenceEngine() # Prepare the prompt From e11bebd7f534391f825a5df1aab3fe5bba0ec604 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 24 Aug 2024 20:53:30 -0800 Subject: [PATCH 297/491] adding nvidia quadro and t1000 support --- exo/topology/device_capabilities.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index 6b8de77f..ba81a08b 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -97,6 +97,9 @@ def to_dict(self): "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), + "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS), + "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS), # ... add more devices if needed ... ### AMD GPUs # RX 6000 series From 778cb6ef03bc3f2451820006c21dd9a4fa2688e7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 24 Aug 2024 21:06:14 -0800 Subject: [PATCH 298/491] updating test, updating model selection for smaller quant llama3 model --- exo/inference/pytorch/test_inference_engine.py | 2 +- exo/models.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index f1eaf31e..b1e5b56a 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -19,7 +19,7 @@ def main(): # Run inference loop = asyncio.get_event_loop() - output_data, new_inference_state, is_eos = loop.run_until_complete( + output_data, _, _ = loop.run_until_complete( engine.infer_prompt( request_id="test_request", shard=shard, prompt=prompt ) diff --git a/exo/models.py b/exo/models.py index d355e88d..1ad4df21 100644 --- a/exo/models.py +++ b/exo/models.py @@ -5,6 +5,7 @@ "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), @@ -19,6 +20,10 @@ "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), }, + "llama-3-2B-Base": { + "TinygradDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=32), + }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, From 56aae50cac1ff6bfdb18d5d4b8d67b69ccb53db6 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 24 Aug 2024 21:41:21 -0800 Subject: [PATCH 299/491] added updating model options to update_deps.py --- tinychat/examples/tinychat/index.html | 24 +++++++++---------- tinychat/examples/tinychat/update_deps.py | 29 ++++++++++++++++++++++- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/tinychat/examples/tinychat/index.html b/tinychat/examples/tinychat/index.html index 6136864f..b437b098 100644 --- a/tinychat/examples/tinychat/index.html +++ b/tinychat/examples/tinychat/index.html @@ -18,8 +18,8 @@ - - + + @@ -27,16 +27,16 @@
+ + + + + + + + + +
Date: Sat, 24 Aug 2024 21:52:38 -0800 Subject: [PATCH 300/491] updating inference class init to take shard, updating pytorch test_inference_engine.py, adding in pytorch option for inference engine --- exo/inference/pytorch/inference.py | 4 +- .../pytorch/test_inference_engine.py | 52 +++++++++++++++++-- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index dd6434ca..cc33b6bf 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -17,14 +17,14 @@ class PyTorchDynamicShardInferenceEngine(InferenceEngine): PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self): + def __init__(self, shard): """ Initialize the inference engine. Args: debug (bool): If True, enables debug logging. Defaults to False. """ - self.shard = None + self.shard = shard self.model = None self.tokenizer = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index b1e5b56a..ffb5a10f 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -2,8 +2,48 @@ import asyncio from exo.inference.shard import Shard from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.helpers import DEBUG +import os +import numpy as np + +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): + prompt = "In a single word only, what is the last name of the current president of the USA?" + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), + input_data=resp_full, + inference_state=inference_state_full, + ) -def main(): + pp = 15 + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + input_data=resp1, + inference_state=inference_state_1, + ) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), + input_data=resp2, + inference_state=inference_state_2, + ) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + input_data=resp3, + inference_state=inference_state_3, + ) + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + +def single_test(): shard = Shard( model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, @@ -11,7 +51,7 @@ def main(): n_layers=32 ) - engine = PyTorchDynamicShardInferenceEngine() + engine = PyTorchDynamicShardInferenceEngine(shard) # Prepare the prompt @@ -28,4 +68,10 @@ def main(): assert output_data is not None if __name__ == '__main__': - main() + # single_test() + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "andrijdavid/Llama3-2B-Base", + )) + From aa769cae4b9c1e71072ace3640b96e8470e6713b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 24 Aug 2024 21:55:07 -0800 Subject: [PATCH 301/491] adding updates for inference_engine.py --- exo/inference/inference_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index b9465493..ad612c75 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -27,5 +27,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) return TinygradDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "pytorch": + from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine + return PyTorchDynamicShardInferenceEngine(shard_downloader) else: raise ValueError(f"Inference engine {inference_engine_name} not supported") From 08e8b41895255ba12ac873fc060d2a27b79bb747 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 24 Aug 2024 22:05:18 -0800 Subject: [PATCH 302/491] reducing layer amount for llama3-2b-base --- exo/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/models.py b/exo/models.py index 1ad4df21..0b7b48d6 100644 --- a/exo/models.py +++ b/exo/models.py @@ -21,8 +21,8 @@ "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), }, "llama-3-2B-Base": { - "TinygradDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=32), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=5), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=5), }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, From dd2812b81642cf909f8fd04c321dbaefb27b286d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 25 Aug 2024 14:20:06 -0800 Subject: [PATCH 303/491] fixing gpu tensor to numpy conversion issues, updating top_p_sampling along with adding in torch topk, added in better random distribution selection for when top_p is too low or high, started work on forward_layer_cached but infer functions need to be changed and take any and not a string --- exo/api/chatgpt_api.py | 21 +++- exo/inference/pytorch/inference.py | 27 +++-- exo/inference/pytorch/model/hf.py | 140 +++++++++++++++++++++----- exo/inference/pytorch/model/utils.py | 102 +++++++++++-------- exo/models.py | 11 +- tinychat/examples/tinychat/index.html | 15 +-- 6 files changed, 226 insertions(+), 90 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index d9af9458..017d71c7 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -113,8 +113,27 @@ def remap_messages(messages: List[Message]) -> List[Message]: def build_prompt(tokenizer, _messages: List[Message]): + if len(_messages) == 1: + user_msg = _messages[0] + + # get instruct sys message + sys_msg = Message(role="system", content="You are a helpful assistant.") + + # restructure for sys_msg to go first + _messages = [sys_msg, user_msg] + messages = remap_messages(_messages) - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + if DEBUG >= 3: + print(f"prompt: {str(prompt)}") + for msg in messages: + print(f"chat role: {msg.role}\ncontent: {msg.content}") + image_str = None for message in messages: if not isinstance(message.content, list): diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cc33b6bf..01a80d32 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -37,9 +37,6 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 2: - print("infer_prompt called") - await self.ensure_shard(shard) # need to make this so inference_state is not a string @@ -47,9 +44,6 @@ async def infer_prompt( tokens = self.tokenizer.encode(prompt, return_tensors="pt") - if DEBUG >= 2: - print(f"tokens: {tokens}\n") - output_data = self.model.forward_layers( tokens ) @@ -60,7 +54,9 @@ async def infer_prompt( print(f"token from llm decode: {self.tokenizer.decode(output_data)}") - if DEBUG >= 2: + if DEBUG >= 4: + print("infer_prompt called") + print(f"tokens: {tokens}\n") print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") @@ -91,7 +87,7 @@ async def infer_tensor( if in_tensor.dim() == 1: in_tensor = in_tensor.unsqueeze(0) # Add a batch dimension: [1, seq_len] - if DEBUG >= 2: + if DEBUG >= 4: print("infer_tensor called") print(f"input_data: {input_data}\n") print(f"in_tensor: {in_tensor}\n") @@ -104,7 +100,7 @@ async def infer_tensor( is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] - if DEBUG >= 2: + if DEBUG >= 4: print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"finished: {is_finished}") @@ -131,12 +127,21 @@ async def ensure_shard(self, shard: Optional[Shard]): if self.shard == shard: return - if DEBUG >= 2: + if DEBUG >= 4: print(f"Loading new shard: {shard}") + # if self.model: + # if DEBUG >= 2: + # print(f"\nCLEARING MODEL {self.shard.model_id}\n") + + # # delete model and free up memory to reload + # self.model.cpu() + # del self.model + # torch.cuda.empty_cache() + self.model = ShardedHuggingFaceModel(shard) self.tokenizer = await resolve_tokenizer(shard.model_id) self.shard = shard - if DEBUG >= 2: + if DEBUG >= 4: print(f"Shard loaded successfully: {shard}") \ No newline at end of file diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 96f573b7..c2064921 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -2,29 +2,45 @@ import torch.nn as nn import numpy as np -from transformers import AutoModelForCausalLM, LlamaConfig, DynamicCache, Cache +from transformers import AutoModelForCausalLM, BitsAndBytesConfig, DynamicCache, Cache from exo.inference.shard import Shard from exo.helpers import DEBUG from typing import Tuple from .utils import sample_logits +TOP_P = 0.75 #0.95 +TOP_K = 20 +TEMP = 0.8 + class ShardedHuggingFaceModel(torch.nn.Module): def __init__(self, shard: Shard): super(ShardedHuggingFaceModel, self).__init__() - if DEBUG >= 2: - print(f"\nShardedHuggingFaceModel init with shard {shard}") + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.shard = shard # Load the model - self.full_model = AutoModelForCausalLM.from_pretrained( - shard.model_id, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, - device_map="auto" - ) + try: + self.full_model = AutoModelForCausalLM.from_pretrained( + shard.model_id, + torch_dtype="auto", + device_map="auto", + # offload_buffers=True + ) + # .to(self.device) + except Exception as err: + print(f"Error loading model: {err}") + raise + + if DEBUG >= 2: + print(f"\nShardedHuggingFaceModel init with shard {shard}") + print(f"self.full_model: {self.full_model}") + print(f"self.full_model.model: {self.full_model.model}") # using llamaconfig not working setting layers manually layers = [] @@ -37,6 +53,7 @@ def __init__(self, shard: Shard): layers.append(layer) self.full_model.model.layers = nn.ModuleList(layers) + # .to(self.device) if DEBUG >= 2: print(f"full_model.model layer: {len(self.full_model.model.layers)}") @@ -46,8 +63,6 @@ def __init__(self, shard: Shard): self.embed_tokens = self.full_model.model.embed_tokens self.norm = self.full_model.model.norm - # self.past_key_values = DynamicCache() - def forward_layers( self, input_data: torch.tensor @@ -69,23 +84,98 @@ def forward_layers( if DEBUG >= 2: print(f"\n[layer model] {self.full_model.model}") print(f"IN hidden_states {hidden_states}") - # print(f"past_kvs {past_kvs}") - self.full_model.model.layer_idx = 5 + layer_outputs = self.full_model.model( + hidden_states.to(self.device), + use_cache=False + ) + + if DEBUG >= 2: + print(f"OUT hidden_states {layer_outputs.last_hidden_state}") + + hidden_states = layer_outputs.last_hidden_state + + print(f"2 is_last_layer {self.shard.is_last_layer()}") + if self.shard.is_last_layer(): + hs_norm = self.norm(hidden_states) + hs_lm_head = self.full_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings + with torch.no_grad(): + output_token = sample_logits( + hs_lm_head[:, -1, :], + TEMP, + TOP_P, + TOP_K + ).cpu().numpy().flatten() + + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + print(f"hs_lm_head: {hs_lm_head}") + print(f"output_token: {output_token}") + + return output_token + + return hidden_states.cpu().numpy() + + def forward_layers_cached( + self, + input_data: torch.tensor, + past_kvs: Cache = DynamicCache() + ) -> Tuple[np.ndarray, list]: + """ + Forward pass through the specified layers. + With caching + + Note: past_key_values not working for model, might be a library bug + """ + if DEBUG >= 2: + print("forward_layer call") + print(f"input_data: {input_data}") + print(f"shard {self.shard.to_dict()}") + + hidden_states = input_data + position_ids = None + position_embeddings = None + + if self.shard.is_first_layer(): + hidden_states = self.embed_tokens(hidden_states) + + if DEBUG >= 2: + print(f"hidden_states: {hidden_states}") + print(f"hidden_states.size(): {hidden_states.size()}") + + batch_size, seq_len = input_data.size() + position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) + + position_embeddings = self.full_model.model.rotary_emb( + hidden_states, + position_ids + ) + + # if DEBUG >= 2: + # print(f"embedded hidden_states {hidden_states}") + # print(f"position_ids: {position_embeddings}") + + + # Forward pass through the layer + if DEBUG >= 2: + print(f"IN hidden_states {hidden_states}") + print(f"past_kvs {past_kvs}") + layer_outputs = self.full_model.model( hidden_states, - # position_ids=position_ids, - # inputs_embeds=position_embeddings, - # past_key_values=self.past_key_values, - use_cache=False # not enough vram for using cache ;_; + position_ids=position_ids, + inputs_embeds=position_embeddings, + past_key_values=past_kvs, + use_cache=True ) if DEBUG >= 2: - print(f"OUT hidden_states {hidden_states}") - # print(f"\nlayer_outputs: {layer_outputs}") + print(f"\nlayer_outputs: {layer_outputs}") hidden_states = layer_outputs.last_hidden_state - # self.past_key_values = layer_outputs.past_key_values + present_kvs = layer_outputs.past_key_values print(f"2 is_last_layer {self.shard.is_last_layer()}") if self.shard.is_last_layer(): @@ -94,13 +184,17 @@ def forward_layers( # Use the sampling function with default settings output_token = sample_logits( - hs_lm_head[:, -1, :]).cpu().numpy().flatten() + hs_lm_head[:, -1, :], + TEMP, + TOP_P, + TOP_K + ).numpy() if DEBUG >= 2: print(f"hs_norm: {hs_norm}") print(f"hs_lm_head: {hs_lm_head}") print(f"output_token: {output_token}") - return output_token + return (output_token, present_kvs) - return hidden_states.cpu().detach().numpy() \ No newline at end of file + return (hidden_states.numpy(), present_kvs) \ No newline at end of file diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/utils.py index d56be5d8..df84b397 100644 --- a/exo/inference/pytorch/model/utils.py +++ b/exo/inference/pytorch/model/utils.py @@ -1,17 +1,59 @@ import torch from torch.nn import functional as F -def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alpha_p=0.0): +def top_p_sampling(scaled_logits: torch.Tensor, top_p: float) -> torch.Tensor: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + scaled_logits (torch.Tensor): The scaled logits from the model's output. + top_p (float): The cumulative probability threshold for top-p filtering. + temp (float): Temperature parameter for softmax distribution reshaping. + + Returns: + torch.Tensor: Token selected based on the top-p criterion. + + Ref: + https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L67C1-L97C17 + """ + scaled_logits = torch.where(torch.isnan(scaled_logits), torch.zeros_like(scaled_logits), scaled_logits) + scaled_logits = torch.where(torch.isinf(scaled_logits), torch.full_like(scaled_logits, 1e6), scaled_logits) + + probs = torch.softmax(scaled_logits, dim=-1) + + sorted_probs, sorted_indices = torch.sort( + probs, + descending=True, + dim=-1 + ) + + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + mask = cumulative_probs > top_p + + top_probs = torch.where(mask, torch.zeros_like(sorted_probs), sorted_probs) + sum_probs = top_probs.sum(dim=-1, keepdim=True) + top_probs = torch.where(sum_probs > 0, top_probs / sum_probs, torch.ones_like(top_probs) / top_probs.size(-1)) + + if torch.isnan(top_probs).any() or torch.isinf(top_probs).any(): + print("Warning: Top probabilities contain NaN or Inf values after normalization") + top_probs = torch.where(torch.isnan(top_probs) | torch.isinf(top_probs), + 1.0 / top_probs.size(-1), + top_probs) + + sorted_token = torch.multinomial(top_probs, num_samples=1) + + token = sorted_indices.gather(-1, sorted_token) + + return token.squeeze(-1) + +def sample_logits(logits, temp, top_p, top_k): """ Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. Args: logits (torch.Tensor): The logits distribution to sample from. - temperature (float): Temperature for scaling logits. - top_k (int): The number of top tokens to consider for sampling. + temp (float): temp for scaling logits. top_p (float): The cumulative probability threshold for nucleus sampling. - alpha_f (float): Penalty factor for repetition frequency. - alpha_p (float): Penalty for repeated selection. Returns: torch.Tensor: The selected token index. @@ -20,44 +62,22 @@ def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, alpha_f=0.0, alph # Ensure logits are float logits = logits.float() - # If temperature is very low, just use argmax - if temperature < 1e-6: + # If temp is very low, just use argmax + if temp == 0: return logits.argmax(dim=-1) + + scaled_logits = logits/temp - # Alpha sampling (adjusting logits based on past selections) - if alpha_f > 0.0 or alpha_p > 0.0: - logits -= (sample_logits.alpha_counter * alpha_f + (sample_logits.alpha_counter > 0) * alpha_p) - - # Replace NaNs with -inf to prevent softmax issues - logits = torch.where(torch.isnan(logits), torch.full_like(logits, -float('inf')), logits) - - # Apply temperature scaling - logits = logits / temperature - - # Top-k sampling + # top k if top_k > 0: - top_k = min(top_k, logits.size(-1)) - top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1) - logits = torch.full_like(logits, -float('inf')) - logits.scatter_(-1, top_k_indices, top_k_values) - + top_values, top_indices = torch.topk(scaled_logits, top_k, dim=-1) + scaled_logits = torch.zeros_like(logits).scatter_(-1, top_indices, top_values) + # Top-p sampling if 0 < top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - sorted_logits[sorted_indices_to_remove] = -float('inf') - logits = sorted_logits - - # Apply softmax to get probabilities - probabilities = F.softmax(logits, dim=-1) - - # Sample from the probabilities - sampled_token = torch.multinomial(probabilities, 1) - - return sampled_token.squeeze() \ No newline at end of file + return top_p_sampling(scaled_logits, top_p) + else: + # random distribution selection + probs = torch.softmax(scaled_logits, dim=-1) + rand_sample = torch.distributions.Categorical(probs) + return rand_sample.sample().squeeze() \ No newline at end of file diff --git a/exo/models.py b/exo/models.py index 0b7b48d6..72a5b566 100644 --- a/exo/models.py +++ b/exo/models.py @@ -21,8 +21,10 @@ "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), }, "llama-3-2B-Base": { - "TinygradDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=5), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=5), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=6), + }, + "llama-3-1B-Base": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, @@ -31,4 +33,9 @@ "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),}, ### llava "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, + ### qwen + "Qwen2-0.5B-Instruct": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), + }, + } diff --git a/tinychat/examples/tinychat/index.html b/tinychat/examples/tinychat/index.html index b437b098..8ff4c64c 100644 --- a/tinychat/examples/tinychat/index.html +++ b/tinychat/examples/tinychat/index.html @@ -19,24 +19,15 @@ - + - +
+
Tuple[np.ndarray, str, bool]: + await self.ensure_shard(shard) # need to make this so inference_state is not a string @@ -44,16 +52,27 @@ async def infer_prompt( tokens = self.tokenizer.encode(prompt, return_tensors="pt") - output_data = self.model.forward_layers( - tokens - ) + if self.use_cache: + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data) for data in cache_dict['value_cache']] + + output_data, current_kvs = self.model.forward( + tokens, + past_kv, + use_cache=True + ) + else: + output_data = self.model.forward( + tokens, + use_cache=False + ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] - if is_finished: - print(f"token from llm decode: {self.tokenizer.decode(output_data)}") - - if DEBUG >= 4: print("infer_prompt called") print(f"tokens: {tokens}\n") @@ -68,9 +87,17 @@ async def infer_prompt( print(f"size 1 output_data.item() {output_data.item()}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + if self.use_cache: + # legacy_cache = current_kvs.to_legacy_cache() + print(current_kvs.key_cache) + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } + return ( output_data, - "", + json.dumps(cache_dict) if self.use_cache else "", is_finished ) @@ -79,28 +106,38 @@ async def infer_tensor( request_id: str, shard: Shard, input_data: np.ndarray, - inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: - - in_tensor = torch.tensor(input_data) - - # Ensure input_data is 2D: [batch_size, seq_len] - if in_tensor.dim() == 1: - in_tensor = in_tensor.unsqueeze(0) # Add a batch dimension: [1, seq_len] - - if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}\n") - print(f"in_tensor: {in_tensor}\n") + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: await self.ensure_shard(shard) - output_data = self.model.forward_layers( - in_tensor - ) + in_tensor = torch.tensor(input_data) + + if self.use_cache: + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data) for data in cache_dict['value_cache']] + + output_data, current_kvs = self.model.forward( + in_tensor, + past_kv, + use_cache=True + ) + else: + output_data = self.model.forward( + in_tensor, + use_cache=False + ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}\n") + print(f"in_tensor: {in_tensor}\n") print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"finished: {is_finished}") @@ -111,9 +148,16 @@ async def infer_tensor( print(f"size 1 output_data.item() {output_data.item()}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + if self.use_cache: + legacy_cache = current_kvs.to_legacy_cache() + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in legacy_cache.key_cache], + 'value_cache': [tensor.tolist() for tensor in legacy_cache.value_cache] + } + return ( output_data, - "", + json.dumps(cache_dict) if self.use_cache else "", is_finished ) @@ -139,9 +183,9 @@ async def ensure_shard(self, shard: Optional[Shard]): # del self.model # torch.cuda.empty_cache() - self.model = ShardedHuggingFaceModel(shard) - self.tokenizer = await resolve_tokenizer(shard.model_id) self.shard = shard + self.tokenizer = await resolve_tokenizer(shard.model_id) + self.model = ShardedHuggingFaceModel(shard, self.tokenizer) if DEBUG >= 4: print(f"Shard loaded successfully: {shard}") \ No newline at end of file diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c2064921..a8345466 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,11 +1,12 @@ import torch import torch.nn as nn import numpy as np +import re from transformers import AutoModelForCausalLM, BitsAndBytesConfig, DynamicCache, Cache from exo.inference.shard import Shard from exo.helpers import DEBUG -from typing import Tuple +from typing import Tuple, Optional, Union, List from .utils import sample_logits @@ -14,7 +15,7 @@ TEMP = 0.8 class ShardedHuggingFaceModel(torch.nn.Module): - def __init__(self, shard: Shard): + def __init__(self, shard: Shard, tokenizer: any): super(ShardedHuggingFaceModel, self).__init__() if torch.cuda.is_available(): @@ -23,178 +24,282 @@ def __init__(self, shard: Shard): self.device = torch.device("cpu") self.shard = shard + self.tokenizer = tokenizer # Load the model try: - self.full_model = AutoModelForCausalLM.from_pretrained( + self.llm_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype="auto", device_map="auto", # offload_buffers=True ) - # .to(self.device) + + self.base_model = self.llm_model.model except Exception as err: print(f"Error loading model: {err}") raise if DEBUG >= 2: print(f"\nShardedHuggingFaceModel init with shard {shard}") - print(f"self.full_model: {self.full_model}") - print(f"self.full_model.model: {self.full_model.model}") + print(f"self.llm_model: {self.llm_model}") + print(f"self.llm_model.model: {self.llm_model.model}") - # using llamaconfig not working setting layers manually + # load layers from base model to use layers = [] for i in range(shard.start_layer, shard.end_layer + 1): - layer = self.full_model.model.layers[i] + layer = self.llm_model.model.layers[i] if DEBUG >= 2: print(f"Loading layers[{i}]") layers.append(layer) - self.full_model.model.layers = nn.ModuleList(layers) - # .to(self.device) + self.layers = nn.ModuleList(layers).to(self.device) if DEBUG >= 2: - print(f"full_model.model layer: {len(self.full_model.model.layers)}") + print(f"full_model.model layer: {len(self.llm_model.model.layers)}") # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers - self.embed_tokens = self.full_model.model.embed_tokens - self.norm = self.full_model.model.norm - - def forward_layers( + self.norm = self.llm_model.model.norm + self.lm_head = self.llm_model.lm_head + + def forward( self, - input_data: torch.tensor - ) -> np.ndarray: + input_ids: torch.tensor, + past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: bool = True + ) -> Tuple[np.ndarray, any]: """ - Forward pass through the specified layers. - This is without caching + Forward through layers using the base model - Note: past_key_values not working for model, might be a library bug - """ - if DEBUG >= 2: - print("forward_layer call") - print(f"input_data: {input_data}") - print(f"shard {self.shard.to_dict()}") + Args: + input_ids: tensor input + past_kvs: past key value stores for cache + use_cache: use cache + + Returns: + hidden_states: numpy of states between layers + or logits: numpy of normalization and linearization of last hidden state + past_kvs: DynamicCache of past key values if use_cache is true + + Ref: + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 + """ - hidden_states = input_data + if self.shard.is_first_layer(): + inputs_embeds = self.base_model.embed_tokens(input_ids.to(self.device)) - # Forward pass through the layer - if DEBUG >= 2: - print(f"\n[layer model] {self.full_model.model}") - print(f"IN hidden_states {hidden_states}") - - layer_outputs = self.full_model.model( - hidden_states.to(self.device), - use_cache=False - ) + if use_cache: + past_kvs = DynamicCache.from_legacy_cache(past_kvs) - if DEBUG >= 2: - print(f"OUT hidden_states {layer_outputs.last_hidden_state}") - - hidden_states = layer_outputs.last_hidden_state + past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) - print(f"2 is_last_layer {self.shard.is_last_layer()}") - if self.shard.is_last_layer(): - hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm).float() + position_ids = cache_position.unsqueeze(0) + + hidden_states = inputs_embeds + + # progress through layers + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + position_ids=position_ids, + past_key_value=past_kvs, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + next_kvs = layer_outputs[1] + + if DEBUG >= 3: + print(f"hidden_state: {hidden_states}") + print(f"next_kvs: {next_kvs}") - # Use the sampling function with default settings + if self.shard.is_last_layer(): + norm = self.norm(hidden_states) + lm_head = self.lm_head(norm).float() + with torch.no_grad(): - output_token = sample_logits( - hs_lm_head[:, -1, :], + logits = sample_logits( + lm_head[:, -1, :], TEMP, TOP_P, TOP_K ).cpu().numpy().flatten() - if DEBUG >= 2: - print(f"hs_norm: {hs_norm}") - print(f"hs_lm_head: {hs_lm_head}") - print(f"output_token: {output_token}") + if DEBUG >= 3: + print( + self.tokenizer.batch_decode( + logits, + skip_special_tokens=True + )[0] + ) - return output_token - - return hidden_states.cpu().numpy() - - def forward_layers_cached( - self, - input_data: torch.tensor, - past_kvs: Cache = DynamicCache() - ) -> Tuple[np.ndarray, list]: - """ - Forward pass through the specified layers. - With caching - - Note: past_key_values not working for model, might be a library bug - """ - if DEBUG >= 2: - print("forward_layer call") - print(f"input_data: {input_data}") - print(f"shard {self.shard.to_dict()}") - - hidden_states = input_data - position_ids = None - position_embeddings = None + return (logits, next_kvs) - if self.shard.is_first_layer(): - hidden_states = self.embed_tokens(hidden_states) - - if DEBUG >= 2: - print(f"hidden_states: {hidden_states}") - print(f"hidden_states.size(): {hidden_states.size()}") - - batch_size, seq_len = input_data.size() - position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) - - position_embeddings = self.full_model.model.rotary_emb( - hidden_states, - position_ids - ) + return ( + hidden_states.cpu().numpy(), + next_kvs + ) - # if DEBUG >= 2: - # print(f"embedded hidden_states {hidden_states}") - # print(f"position_ids: {position_embeddings}") + # def forward_layers( + # self, + # input_data: torch.tensor + # ) -> np.ndarray: + # """ + # Forward pass through the specified layers. + # This is without caching + + # Note: past_key_values not working for model, might be a library bug + # """ + # if DEBUG >= 2: + # print("forward_layer call") + # print(f"input_data: {input_data}") + # print(f"shard {self.shard.to_dict()}") + + # hidden_states = input_data + + # # Forward pass through the layer + # if DEBUG >= 2: + # print(f"\n[layer model] {self.llm_model.model}") + # print(f"IN hidden_states {hidden_states}") + + # layer_outputs = self.llm_model.model( + # hidden_states.to(self.device), + # use_cache=False + # ) + # if DEBUG >= 2: + # print(f"OUT hidden_states {layer_outputs.last_hidden_state}") - # Forward pass through the layer - if DEBUG >= 2: - print(f"IN hidden_states {hidden_states}") - print(f"past_kvs {past_kvs}") + # hidden_states = layer_outputs.last_hidden_state + + # print(f"2 is_last_layer {self.shard.is_last_layer()}") + # if self.shard.is_last_layer(): + # hs_norm = self.norm(hidden_states) + # hs_lm_head = self.llm_model.lm_head(hs_norm).float() + + # # Use the sampling function with default settings + # with torch.no_grad(): + # output_token = sample_logits( + # hs_lm_head[:, -1, :], + # TEMP, + # TOP_P, + # TOP_K + # ).cpu().numpy().flatten() + + # if DEBUG >= 2: + # print(f"hs_norm: {hs_norm}") + # print(f"hs_lm_head: {hs_lm_head}") + # print(f"output_token: {output_token}") + + # return output_token - layer_outputs = self.full_model.model( - hidden_states, - position_ids=position_ids, - inputs_embeds=position_embeddings, - past_key_values=past_kvs, - use_cache=True - ) + # return hidden_states.cpu().numpy() + + # def forward_layers_cached( + # self, + # input_data: torch.tensor, + # past_kvs + # ) -> Tuple[np.ndarray, list]: + # """ + # Forward pass through the specified layers. + # With caching + + # Note: past_key_values not working for model, might be a library bug + # """ + + # if not past_kvs: + # past_kvs = DynamicCache() + # else: + # past_kvs = DynamicCache.from_legacy_cache(past_kvs) + + # if DEBUG >= 2: + # print("forward_layer call") + # print(f"input_data: {input_data}") + # print(f"shard {self.shard.to_dict()}") + # print(f"past_kvs: {past_kvs}") + + # input_ids = input_data.to(self.device) + # position_ids = None + # # position_embeddings = None + + # inputs_embeds = self.embed_tokens(input_ids) + + # if self.shard.is_first_layer(): + # hidden_states = self.embed_tokens(hidden_states) + + # if DEBUG >= 2: + # print(f"hidden_states: {hidden_states}") + # print(f"hidden_states.size(): {hidden_states.size()}") + + # batch_size, seq_len = input_data.size() + # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) + + # # check if model does not have rotary emb + # # have to apply rotary per model + # # embedding seems very model specific and using position_ids + # # seems more universal, even though some give warning about it + # # if re.match(r"Qwen|qwen", self.shard.model_id): + # # import transformers.models.qwen2.modeling_qwen2 as qwen2 + # # position_embeddings = + # # q=hidden_states, + # # position_ids=position_ids + # # ) + # # else: + # # position_embeddings = self.llm_model.model.rotary_emb( + # # hidden_states, + # # position_ids + # # ) + + # # if DEBUG >= 2: + # # print(f"embedded hidden_states {hidden_states}") + # # print(f"position_ids: {position_embeddings}") - if DEBUG >= 2: - print(f"\nlayer_outputs: {layer_outputs}") - hidden_states = layer_outputs.last_hidden_state - present_kvs = layer_outputs.past_key_values - - print(f"2 is_last_layer {self.shard.is_last_layer()}") - if self.shard.is_last_layer(): - hs_norm = self.norm(hidden_states) - hs_lm_head = self.full_model.lm_head(hs_norm).float() - - # Use the sampling function with default settings - output_token = sample_logits( - hs_lm_head[:, -1, :], - TEMP, - TOP_P, - TOP_K - ).numpy() - - if DEBUG >= 2: - print(f"hs_norm: {hs_norm}") - print(f"hs_lm_head: {hs_lm_head}") - print(f"output_token: {output_token}") - - return (output_token, present_kvs) + # # Forward pass through the layer + # if DEBUG >= 2: + # print(f"IN hidden_states {hidden_states}") + # print(f"past_kvs {past_kvs}") + + # layer_outputs = self.llm_model.model( + # hidden_states, + # position_ids=position_ids, + # past_key_values=past_kvs, + # use_cache=True + # ) + + # if DEBUG >= 2: + # print(f"\nlayer_outputs: {layer_outputs}") + + # hidden_states = layer_outputs.last_hidden_state + # present_kvs = layer_outputs.past_key_values + + # print(f"2 is_last_layer {self.shard.is_last_layer()}") + # if self.shard.is_last_layer(): + # hs_norm = self.norm(hidden_states) + # hs_lm_head = self.llm_model.lm_head(hs_norm).float() + + # # Use the sampling function with default settings + # with torch.no_grad(): + # output_token = sample_logits( + # hs_lm_head[:, -1, :], + # TEMP, + # TOP_P, + # TOP_K + # ).cpu().numpy().flatten() + + # if DEBUG >= 2: + # print(f"hs_norm: {hs_norm}") + # print(f"hs_lm_head: {hs_lm_head}") + # print(f"output_token: {output_token}") + + # return (output_token, present_kvs) - return (hidden_states.numpy(), present_kvs) \ No newline at end of file + # return (hidden_states.cpu().numpy(), present_kvs) \ No newline at end of file From 3beea222d43bdce04b448abc43e7bd2ddcb61a6d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 03:08:47 -0800 Subject: [PATCH 305/491] updates to caching, stuck on issue with infer_prompt and infer_tensor where data from infer_prompt is not complete --- exo/inference/pytorch/inference.py | 56 ++++++++++++------ exo/inference/pytorch/model/hf.py | 91 ++++++++++++++++++------------ 2 files changed, 93 insertions(+), 54 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 30842174..95d71f33 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -12,8 +12,6 @@ from exo.helpers import DEBUG from transformers import DynamicCache -from exo.inference.pytorch.model.utils import sample_logits - class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. @@ -50,15 +48,24 @@ async def infer_prompt( # need to make this so inference_state is not a string # cant use it with dynamic cache - tokens = self.tokenizer.encode(prompt, return_tensors="pt") + tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + tokens = self.model.embed_tokens(tokens) + current_kvs = None + if DEBUG >= 4: + print("infer_prompt called") + print(f"tokens: {tokens}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + if self.use_cache: # convert inference_state or cache from json to DynamicCache past_kv = DynamicCache() if inference_state != None: cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data) for data in cache_dict['value_cache']] + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] output_data, current_kvs = self.model.forward( tokens, @@ -74,8 +81,6 @@ async def infer_prompt( is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 4: - print("infer_prompt called") - print(f"tokens: {tokens}\n") print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") @@ -88,8 +93,6 @@ async def infer_prompt( print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") if self.use_cache: - # legacy_cache = current_kvs.to_legacy_cache() - print(current_kvs.key_cache) cache_dict = { 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] @@ -111,15 +114,35 @@ async def infer_tensor( await self.ensure_shard(shard) - in_tensor = torch.tensor(input_data) + current_kvs = None + + in_tensor = torch.tensor( + input_data, + device=self.device + ) + + if in_tensor.dim() == 1: + in_tensor = in_tensor.unsqueeze(1) + + in_tensor = self.model.embed_tokens(in_tensor) + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"input_data.size: {input_data.size}") + print(f"input_tensor: {in_tensor}\n") + print(f"shard: {self.shard}") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + if self.use_cache: # convert inference_state or cache from json to DynamicCache past_kv = DynamicCache() if inference_state != None: cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data) for data in cache_dict['value_cache']] + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] output_data, current_kvs = self.model.forward( in_tensor, @@ -135,8 +158,6 @@ async def infer_tensor( is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}\n") print(f"in_tensor: {in_tensor}\n") print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") @@ -148,11 +169,10 @@ async def infer_tensor( print(f"size 1 output_data.item() {output_data.item()}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") - if self.use_cache: - legacy_cache = current_kvs.to_legacy_cache() + if self.use_cache and current_kvs: cache_dict = { - 'key_cache': [tensor.tolist() for tensor in legacy_cache.key_cache], - 'value_cache': [tensor.tolist() for tensor in legacy_cache.value_cache] + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] } return ( diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index a8345466..fd907b24 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -8,7 +8,7 @@ from exo.helpers import DEBUG from typing import Tuple, Optional, Union, List -from .utils import sample_logits +from exo.inference.pytorch.model.utils import sample_logits TOP_P = 0.75 #0.95 TOP_K = 20 @@ -30,7 +30,7 @@ def __init__(self, shard: Shard, tokenizer: any): try: self.llm_model = AutoModelForCausalLM.from_pretrained( shard.model_id, - torch_dtype="auto", + torch_dtype=torch.float32, device_map="auto", # offload_buffers=True ) @@ -64,6 +64,7 @@ def __init__(self, shard: Shard, tokenizer: any): # used for doing what forward LlamaModel does in transformers self.norm = self.llm_model.model.norm self.lm_head = self.llm_model.lm_head + self.embed_tokens = self.base_model.embed_tokens def forward( self, @@ -88,65 +89,83 @@ def forward( https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 """ - - if self.shard.is_first_layer(): - inputs_embeds = self.base_model.embed_tokens(input_ids.to(self.device)) - - if use_cache: - past_kvs = DynamicCache.from_legacy_cache(past_kvs) - + if DEBUG >= 4: + print("forward called") + print(f"input_ids: {input_ids}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + if use_cache: + past_kvs = DynamicCache.from_legacy_cache(past_kvs) past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + + # if self.shard.is_first_layer(): + # inputs_embeds = self.embed_tokens(input_ids) + + # cache_position = torch.arange( + # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + # ).to(self.device) - position_ids = cache_position.unsqueeze(0) + # position_ids = cache_position.unsqueeze(0).to(self.device) - hidden_states = inputs_embeds + # hidden_states = inputs_embeds # progress through layers for decoder_layer in self.layers: + if DEBUG >= 4: + print("Going through layer") + print(f"{decoder_layer}") + layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, + input_ids, + # position_ids=position_ids, past_key_value=past_kvs, use_cache=use_cache, - cache_position=cache_position, + # cache_position=cache_position, ) hidden_states = layer_outputs[0] - next_kvs = layer_outputs[1] + if use_cache: + next_kvs = layer_outputs[1] if DEBUG >= 3: print(f"hidden_state: {hidden_states}") print(f"next_kvs: {next_kvs}") - + if self.shard.is_last_layer(): - norm = self.norm(hidden_states) - lm_head = self.lm_head(norm).float() - + hs_norm = self.norm(hidden_states) + hs_lm_head = self.llm_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings with torch.no_grad(): - logits = sample_logits( - lm_head[:, -1, :], + output_token = sample_logits( + hs_lm_head[:, -1, :], TEMP, TOP_P, TOP_K ).cpu().numpy().flatten() - if DEBUG >= 3: - print( - self.tokenizer.batch_decode( - logits, - skip_special_tokens=True - )[0] - ) + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + print(f"hs_lm_head: {hs_lm_head}") + print(f"output_token: {output_token}") - return (logits, next_kvs) + if use_cache: + return (output_token, next_kvs) + + return output_token + + with torch.no_grad(): + out_hidden_states = hidden_states.cpu().numpy() - return ( - hidden_states.cpu().numpy(), - next_kvs - ) + if use_cache: + return ( + out_hidden_states, + next_kvs + ) + + return out_hidden_states # def forward_layers( # self, From 87a14ca7be29768d498290b48c5de0fb76cc61a4 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 03:23:35 -0800 Subject: [PATCH 306/491] trying to fix infer problems --- exo/inference/pytorch/inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 95d71f33..ad6e8f3a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -121,10 +121,10 @@ async def infer_tensor( device=self.device ) - if in_tensor.dim() == 1: - in_tensor = in_tensor.unsqueeze(1) + # if in_tensor.dim() == 1: + # in_tensor = in_tensor.unsqueeze(1) - in_tensor = self.model.embed_tokens(in_tensor) + # in_tensor = self.model.embed_tokens(in_tensor) if DEBUG >= 4: print("infer_tensor called") From 356bf2f56dfc984406bb80c4dfc8fac11c644e64 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 13:53:38 -0800 Subject: [PATCH 307/491] switched everything to use caching, did more prep for encoding the token/logit coming from infer_tensor to infer_prompt, running into OOM issues trying on server --- exo/inference/pytorch/inference.py | 103 ++++++++---------- exo/inference/pytorch/model/hf.py | 48 +++----- .../pytorch/test_inference_engine.py | 86 +++++++++------ 3 files changed, 116 insertions(+), 121 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ad6e8f3a..200b1c4c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -12,6 +12,8 @@ from exo.helpers import DEBUG from transformers import DynamicCache + + class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. @@ -29,11 +31,6 @@ def __init__(self, shard): self.tokenizer = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if os.getenv("TORCH_CACHED") == "True": - self.use_cache = True - else: - self.use_cache = False - async def infer_prompt( self, request_id: str, @@ -59,24 +56,17 @@ async def infer_prompt( print(f"is_first_layer: {self.shard.is_first_layer()}") print(f"is_last_layer: {self.shard.is_last_layer()}") - if self.use_cache: - # convert inference_state or cache from json to DynamicCache - past_kv = DynamicCache() - if inference_state != None: - cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] - - output_data, current_kvs = self.model.forward( - tokens, - past_kv, - use_cache=True - ) - else: - output_data = self.model.forward( - tokens, - use_cache=False - ) + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] + + output_data, current_kvs = self.model.forward( + tokens, + past_kv + ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] @@ -92,15 +82,14 @@ async def infer_prompt( print(f"size 1 output_data.item() {output_data.item()}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") - if self.use_cache: - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] - } + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } return ( output_data, - json.dumps(cache_dict) if self.use_cache else "", + json.dumps(cache_dict), is_finished ) @@ -116,15 +105,13 @@ async def infer_tensor( current_kvs = None - in_tensor = torch.tensor( - input_data, - device=self.device - ) - - # if in_tensor.dim() == 1: - # in_tensor = in_tensor.unsqueeze(1) + if input_data.size == 1: + in_tensor = torch.tensor( + input_data, + device=self.device + ).unsqueeze(0).long() - # in_tensor = self.model.embed_tokens(in_tensor) + in_tensor = self.model.embed_tokens(in_tensor) if DEBUG >= 4: print("infer_tensor called") @@ -136,24 +123,26 @@ async def infer_tensor( print(f"is_first_layer: {self.shard.is_first_layer()}") print(f"is_last_layer: {self.shard.is_last_layer()}") - if self.use_cache: - # convert inference_state or cache from json to DynamicCache - past_kv = DynamicCache() - if inference_state != None: + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + try: cache_dict = json.loads(inference_state) past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] - output_data, current_kvs = self.model.forward( - in_tensor, - past_kv, - use_cache=True - ) - else: - output_data = self.model.forward( - in_tensor, - use_cache=False - ) + if DEBUG >= 4: + print("Loaded past_kv from JSON") + print(f"past_kv: {past_kv}") + print(f"past_kv.key_cache len: {len(past_kv.key_cache)}") + print(f"past_kv.value_cache len: {len(past_kv.value_cache)}") + except json.JSONDecodeError: + print(f"ERROR DECODING INFERENCE STATE") + + output_data, current_kvs = self.model.forward( + in_tensor, + past_kv + ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] @@ -169,15 +158,15 @@ async def infer_tensor( print(f"size 1 output_data.item() {output_data.item()}") print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") - if self.use_cache and current_kvs: - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] - } + + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } return ( output_data, - json.dumps(cache_dict) if self.use_cache else "", + json.dumps(cache_dict), is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index fd907b24..cdb7d7c0 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -55,7 +55,7 @@ def __init__(self, shard: Shard, tokenizer: any): layers.append(layer) - self.layers = nn.ModuleList(layers).to(self.device) + self.layers = nn.ModuleList(layers) if DEBUG >= 2: print(f"full_model.model layer: {len(self.llm_model.model.layers)}") @@ -70,7 +70,6 @@ def forward( self, input_ids: torch.tensor, past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_cache: bool = True ) -> Tuple[np.ndarray, any]: """ Forward through layers using the base model @@ -96,20 +95,16 @@ def forward( print(f"is_first_layer: {self.shard.is_first_layer()}") print(f"is_last_layer: {self.shard.is_last_layer()}") - if use_cache: - past_kvs = DynamicCache.from_legacy_cache(past_kvs) - past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 - - # if self.shard.is_first_layer(): - # inputs_embeds = self.embed_tokens(input_ids) - - # cache_position = torch.arange( - # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - # ).to(self.device) + past_kvs = DynamicCache.from_legacy_cache(past_kvs) + past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 - # position_ids = cache_position.unsqueeze(0).to(self.device) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + input_ids.shape[1], + device=input_ids.device + ).to(self.device) - # hidden_states = inputs_embeds + position_ids = cache_position.unsqueeze(0).to(self.device) # progress through layers for decoder_layer in self.layers: @@ -119,15 +114,14 @@ def forward( layer_outputs = decoder_layer( input_ids, - # position_ids=position_ids, + position_ids=position_ids, past_key_value=past_kvs, - use_cache=use_cache, - # cache_position=cache_position, + use_cache=True, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_kvs = layer_outputs[1] + next_kvs = layer_outputs[1] if DEBUG >= 3: print(f"hidden_state: {hidden_states}") @@ -151,21 +145,15 @@ def forward( print(f"hs_lm_head: {hs_lm_head}") print(f"output_token: {output_token}") - if use_cache: - return (output_token, next_kvs) - - return output_token + return (output_token, next_kvs) with torch.no_grad(): out_hidden_states = hidden_states.cpu().numpy() - if use_cache: - return ( - out_hidden_states, - next_kvs - ) - - return out_hidden_states + return ( + out_hidden_states, + next_kvs + ) # def forward_layers( # self, diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index ffb5a10f..725130e2 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -9,33 +9,68 @@ import os import numpy as np -async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): + # prompt = "Why is the sky blue?" prompt = "In a single word only, what is the last name of the current president of the USA?" - resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) + + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( "A", - shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), + shard=shard, input_data=resp_full, inference_state=inference_state_full, ) - pp = 15 - resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt) + pp = int(n_layers/2) + resp_shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=pp, + n_layers=n_layers + ) + + resp_shard2 = Shard( + model_id=model_id, + start_layer=pp + 1, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( + "B", + shard=resp_shard, + prompt=prompt + ) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( "B", - shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + shard=resp_shard2, input_data=resp1, inference_state=inference_state_1, ) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( "B", - shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), + shard=resp_shard, input_data=resp2, inference_state=inference_state_2, ) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( "B", - shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + shard=resp_shard2, input_data=resp3, inference_state=inference_state_3, ) @@ -43,35 +78,18 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(resp_full, resp2) assert np.array_equal(next_resp_full, resp4) -def single_test(): - shard = Shard( - model_id="meta-llama/Meta-Llama-3.1-8B", - start_layer=0, - end_layer=0, - n_layers=32 - ) - - engine = PyTorchDynamicShardInferenceEngine(shard) - - - # Prepare the prompt - prompt = "Why is the sky blue?" - - # Run inference - loop = asyncio.get_event_loop() - output_data, _, _ = loop.run_until_complete( - engine.infer_prompt( - request_id="test_request", shard=shard, prompt=prompt - ) - ) - - assert output_data is not None - if __name__ == '__main__': - # single_test() + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 25 + # )) + asyncio.run(test_inference_engine( PyTorchDynamicShardInferenceEngine(HFShardDownloader()), PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "andrijdavid/Llama3-2B-Base", + "andrijdavid/Llama3-1B-Base", + 3 )) From aa8903285006c8c75c15b774373ccf38786dfc32 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 14:03:41 -0800 Subject: [PATCH 308/491] fixing test --- exo/inference/pytorch/test_inference_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 725130e2..7b839652 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -1,7 +1,7 @@ import asyncio from exo.inference.shard import Shard -from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +from .inference import PyTorchDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine from exo.inference.shard import Shard From b9331d70319a0b928e8877b231798692538c4899 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 14:05:01 -0800 Subject: [PATCH 309/491] adding init py for old python versions --- exo/inference/pytorch/__init__.py | 0 exo/inference/pytorch/model/__init__.py | 0 exo/inference/pytorch/test_inference_engine.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 exo/inference/pytorch/__init__.py create mode 100644 exo/inference/pytorch/model/__init__.py diff --git a/exo/inference/pytorch/__init__.py b/exo/inference/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/pytorch/model/__init__.py b/exo/inference/pytorch/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 7b839652..725130e2 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -1,7 +1,7 @@ import asyncio from exo.inference.shard import Shard -from .inference import PyTorchDynamicShardInferenceEngine +from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine from exo.inference.shard import Shard From 2c7aa9c7b818b23cbbfaa30f6108eeb515904f90 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 14:59:59 -0800 Subject: [PATCH 310/491] update readme and add in init pys --- exo/download/__init__.py | 0 exo/download/hf/__init__.py | 0 exo/inference/pytorch/README.md | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 exo/download/__init__.py create mode 100644 exo/download/hf/__init__.py diff --git a/exo/download/__init__.py b/exo/download/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/download/hf/__init__.py b/exo/download/hf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/pytorch/README.md b/exo/inference/pytorch/README.md index 8cb0ce07..f87ee389 100644 --- a/exo/inference/pytorch/README.md +++ b/exo/inference/pytorch/README.md @@ -6,7 +6,7 @@ Experimental, still under development Install needed py modules, make sure to be using CUDA 12.4 for the PyTorch install ```console -$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 +$ pip install torch --index-url https://download.pytorch.org/whl/cu124 $ pip install transformers accelerate ``` From 6da3e942173fc20778fae2ca9aaeb4bd97d567c0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 15:04:24 -0800 Subject: [PATCH 311/491] adding more tests --- .../pytorch/test_inference_engine.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 725130e2..d12aaf01 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -79,13 +79,15 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2-0.5B-Instruct", - # 25 - # )) + print(f"\n\n -------- TEST QWEN2 -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 25 + )) + print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") asyncio.run(test_inference_engine( PyTorchDynamicShardInferenceEngine(HFShardDownloader()), PyTorchDynamicShardInferenceEngine(HFShardDownloader()), @@ -93,3 +95,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e 3 )) + print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Meta-Llama-3.1-8B", + 32 + )) + From d0bc93c1471e28d509c9e9836953b65b9cc9c8e7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 15:07:28 -0800 Subject: [PATCH 312/491] adding more try catch to move through tests --- .../pytorch/test_inference_engine.py | 53 +++++++++++-------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index d12aaf01..e540516c 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -80,26 +80,35 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e if __name__ == '__main__': print(f"\n\n -------- TEST QWEN2 -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2-0.5B-Instruct", - 25 - )) - - print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "andrijdavid/Llama3-1B-Base", - 3 - )) - - print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "meta-llama/Meta-Llama-3.1-8B", - 32 - )) + try: + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + try: + print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "andrijdavid/Llama3-1B-Base", + 3 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + + try: + print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Meta-Llama-3.1-8B", + 32 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") From 0e221b27f8947074d9ccda641272f10d6c297543 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 15:12:52 -0800 Subject: [PATCH 313/491] tests --- .../pytorch/test_inference_engine.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index e540516c..c71d3070 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -80,26 +80,26 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e if __name__ == '__main__': print(f"\n\n -------- TEST QWEN2 -------- \n\n") - try: - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2-0.5B-Instruct", - 24 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") - - try: - print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "andrijdavid/Llama3-1B-Base", - 3 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + # try: + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "andrijdavid/Llama3-1B-Base", + # 3 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") try: print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") From 9fc9fdb166dfe7cf6b4bd794a190b03ef3873cee Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 15:25:43 -0800 Subject: [PATCH 314/491] added position embeddings, update test --- exo/inference/pytorch/model/hf.py | 33 ++++++++++++------- .../pytorch/test_inference_engine.py | 6 +++- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index cdb7d7c0..484e6c4d 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -43,12 +43,12 @@ def __init__(self, shard: Shard, tokenizer: any): if DEBUG >= 2: print(f"\nShardedHuggingFaceModel init with shard {shard}") print(f"self.llm_model: {self.llm_model}") - print(f"self.llm_model.model: {self.llm_model.model}") + print(f"self.base_model: {self.base_model}") # load layers from base model to use layers = [] for i in range(shard.start_layer, shard.end_layer + 1): - layer = self.llm_model.model.layers[i] + layer = self.base_model.layers[i] if DEBUG >= 2: print(f"Loading layers[{i}]") @@ -58,11 +58,11 @@ def __init__(self, shard: Shard, tokenizer: any): self.layers = nn.ModuleList(layers) if DEBUG >= 2: - print(f"full_model.model layer: {len(self.llm_model.model.layers)}") + print(f"full_model.model layer: {len(self.base_model.layers)}") # Embeddings and final layer norm # used for doing what forward LlamaModel does in transformers - self.norm = self.llm_model.model.norm + self.norm = self.base_model.norm self.lm_head = self.llm_model.lm_head self.embed_tokens = self.base_model.embed_tokens @@ -106,6 +106,15 @@ def forward( position_ids = cache_position.unsqueeze(0).to(self.device) + try: + position_embeddings = self.base_model.rotary_emb( + input_ids, + position_ids + ) + except Exception as err: + print(f"rotary_emb not found in base_model") + position_embeddings = None + # progress through layers for decoder_layer in self.layers: if DEBUG >= 4: @@ -114,7 +123,8 @@ def forward( layer_outputs = decoder_layer( input_ids, - position_ids=position_ids, + position_ids=position_ids if not position_embeddings else None, + position_embeddings=position_embeddings, past_key_value=past_kvs, use_cache=True, cache_position=cache_position, @@ -124,8 +134,7 @@ def forward( next_kvs = layer_outputs[1] if DEBUG >= 3: - print(f"hidden_state: {hidden_states}") - print(f"next_kvs: {next_kvs}") + print(f"layer_outputs {layer_outputs}") if self.shard.is_last_layer(): hs_norm = self.norm(hidden_states) @@ -138,7 +147,7 @@ def forward( TEMP, TOP_P, TOP_K - ).cpu().numpy().flatten() + ).numpy(force=True).flatten() if DEBUG >= 2: print(f"hs_norm: {hs_norm}") @@ -174,10 +183,10 @@ def forward( # # Forward pass through the layer # if DEBUG >= 2: - # print(f"\n[layer model] {self.llm_model.model}") + # print(f"\n[layer model] {self.base_model}") # print(f"IN hidden_states {hidden_states}") - # layer_outputs = self.llm_model.model( + # layer_outputs = self.base_model( # hidden_states.to(self.device), # use_cache=False # ) @@ -260,7 +269,7 @@ def forward( # # position_ids=position_ids # # ) # # else: - # # position_embeddings = self.llm_model.model.rotary_emb( + # # position_embeddings = self.base_model.rotary_emb( # # hidden_states, # # position_ids # # ) @@ -275,7 +284,7 @@ def forward( # print(f"IN hidden_states {hidden_states}") # print(f"past_kvs {past_kvs}") - # layer_outputs = self.llm_model.model( + # layer_outputs = self.base_model( # hidden_states, # position_ids=position_ids, # past_key_values=past_kvs, diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index c71d3070..e4e0e078 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -26,6 +26,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e prompt=prompt ) + print(f"resp_full: {resp_full}") + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( "A", shard=shard, @@ -33,6 +35,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_full, ) + print(f"next_resp_full: {next_resp_full}") + pp = int(n_layers/2) resp_shard = Shard( model_id=model_id, @@ -79,8 +83,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - print(f"\n\n -------- TEST QWEN2 -------- \n\n") # try: + # print(f"\n\n -------- TEST QWEN2 -------- \n\n") # asyncio.run(test_inference_engine( # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), From 2635b4c7218c650e93289ab6dcc90d43af2e2d17 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 15:30:57 -0800 Subject: [PATCH 315/491] tests --- .../pytorch/test_inference_engine.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index e4e0e078..15337d53 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -94,25 +94,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") - # try: - # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "andrijdavid/Llama3-1B-Base", - # 3 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - try: - print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") asyncio.run(test_inference_engine( PyTorchDynamicShardInferenceEngine(HFShardDownloader()), PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "meta-llama/Meta-Llama-3.1-8B", - 32 + "andrijdavid/Llama3-1B-Base", + 3 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "meta-llama/Meta-Llama-3.1-8B", + # 32 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") From 86e89eb8ddf2ac2933a317a4e688b2c59449ca1d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 26 Aug 2024 18:35:32 -0800 Subject: [PATCH 316/491] adding back tests --- .../pytorch/test_inference_engine.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 15337d53..b690f02e 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -83,16 +83,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - # try: - # print(f"\n\n -------- TEST QWEN2 -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2-0.5B-Instruct", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + try: + print(f"\n\n -------- TEST QWEN2 -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") try: print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") @@ -105,14 +105,14 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e except Exception as err: print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - # try: - # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "meta-llama/Meta-Llama-3.1-8B", - # 32 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + try: + print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Meta-Llama-3.1-8B", + 32 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") From 64fbacd6af05434d434bc5362d19602d9e88fe0b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 05:46:36 -0800 Subject: [PATCH 317/491] adding another test --- .../pytorch/test_inference_engine.py | 61 +++++++++++-------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index b690f02e..8d02e634 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -83,36 +83,47 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - try: - print(f"\n\n -------- TEST QWEN2 -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2-0.5B-Instruct", - 24 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + # try: + # print(f"\n\n -------- TEST QWEN2 -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "andrijdavid/Llama3-1B-Base", + # 3 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "meta-llama/Meta-Llama-3.1-8B", + # 32 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") try: - print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") asyncio.run(test_inference_engine( PyTorchDynamicShardInferenceEngine(HFShardDownloader()), PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "andrijdavid/Llama3-1B-Base", - 3 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - - try: - print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "meta-llama/Meta-Llama-3.1-8B", - 32 + "Chickaboo/ChickaQ-Large", + 24 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") From 0d9313016f2a3d7d5b3cdc3354888caf632ae116 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 08:07:05 -0800 Subject: [PATCH 318/491] added gc collect to remove gpu, fixed tokenizers warning --- .gitignore | 3 + exo/inference/pytorch/README.md | 10 +- exo/inference/pytorch/inference.py | 30 +++-- exo/inference/pytorch/model/hf.py | 186 ++--------------------------- exo/inference/tokenizers.py | 6 +- 5 files changed, 45 insertions(+), 190 deletions(-) diff --git a/.gitignore b/.gitignore index 44892139..f5609f31 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,6 @@ cython_debug/ #.idea/ **/*.xcodeproj/* + +# PyTorch interface +.offload diff --git a/exo/inference/pytorch/README.md b/exo/inference/pytorch/README.md index f87ee389..670c8df6 100644 --- a/exo/inference/pytorch/README.md +++ b/exo/inference/pytorch/README.md @@ -15,4 +15,12 @@ After installing accelerate you get hit with a dependency error, for now ignore ```console ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. exo 0.0.1 requires numpy==2.0.0, but you have numpy 1.26.4 which is incompatible. -``` \ No newline at end of file +``` + +## Low VRAM Notes + +- When trying to do disk_offload getting the error "Cannot copy out of meta tensor; no data!", looking up the error it is tied to (low vram)[https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13087#issuecomment-2080272004] + +## Multiple GPU in 1 Notes +### Running multiple GPUs on 1 machine +- Getting error "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)" diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 200b1c4c..014b7169 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,18 +1,16 @@ # experimental, based off of tinygrad/inference.py -import os import numpy as np import torch import numpy as np import json -from typing import Optional, Callable, Tuple +from typing import Optional, Tuple from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG from transformers import DynamicCache - - +from accelerate import disk_offload class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ @@ -183,14 +181,24 @@ async def ensure_shard(self, shard: Optional[Shard]): if DEBUG >= 4: print(f"Loading new shard: {shard}") - # if self.model: - # if DEBUG >= 2: - # print(f"\nCLEARING MODEL {self.shard.model_id}\n") + if self.model: + if DEBUG >= 2: + print(f"\nCLEARING MODEL {shard.model_id}\n") + print(f"before allocated: {torch.cuda.memory_allocated()}") + print(f"before reserved: {torch.cuda.memory_reserved()}") - # # delete model and free up memory to reload - # self.model.cpu() - # del self.model - # torch.cuda.empty_cache() + # delete model and free up memory to reload + # self.model.cuda() + # disk_offload(model=self.model, offload_dir="./.offload") + import gc + + del self.model + gc.collect() + torch.cuda.empty_cache() + + if DEBUG >= 2: + print(f"after allocated: {torch.cuda.memory_allocated()}") + print(f"after reserved: {torch.cuda.memory_reserved()}") self.shard = shard self.tokenizer = await resolve_tokenizer(shard.model_id) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 484e6c4d..9d8990d7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,13 +1,9 @@ import torch -import torch.nn as nn import numpy as np -import re - -from transformers import AutoModelForCausalLM, BitsAndBytesConfig, DynamicCache, Cache +from transformers import AutoModelForCausalLM, DynamicCache, Cache from exo.inference.shard import Shard from exo.helpers import DEBUG from typing import Tuple, Optional, Union, List - from exo.inference.pytorch.model.utils import sample_logits TOP_P = 0.75 #0.95 @@ -32,9 +28,11 @@ def __init__(self, shard: Shard, tokenizer: any): shard.model_id, torch_dtype=torch.float32, device_map="auto", - # offload_buffers=True + offload_buffers=True ) + # disk_offload(model=self.llm_model, offload_dir="./.offload") + self.base_model = self.llm_model.model except Exception as err: print(f"Error loading model: {err}") @@ -45,18 +43,6 @@ def __init__(self, shard: Shard, tokenizer: any): print(f"self.llm_model: {self.llm_model}") print(f"self.base_model: {self.base_model}") - # load layers from base model to use - layers = [] - for i in range(shard.start_layer, shard.end_layer + 1): - layer = self.base_model.layers[i] - - if DEBUG >= 2: - print(f"Loading layers[{i}]") - - layers.append(layer) - - self.layers = nn.ModuleList(layers) - if DEBUG >= 2: print(f"full_model.model layer: {len(self.base_model.layers)}") @@ -116,7 +102,9 @@ def forward( position_embeddings = None # progress through layers - for decoder_layer in self.layers: + for i in range(self.shard.start_layer, self.shard.end_layer + 1): + decoder_layer = self.base_model.layers[i] + if DEBUG >= 4: print("Going through layer") print(f"{decoder_layer}") @@ -157,165 +145,9 @@ def forward( return (output_token, next_kvs) with torch.no_grad(): - out_hidden_states = hidden_states.cpu().numpy() + out_hidden_states = hidden_states.numpy(force=True) return ( out_hidden_states, next_kvs - ) - - # def forward_layers( - # self, - # input_data: torch.tensor - # ) -> np.ndarray: - # """ - # Forward pass through the specified layers. - # This is without caching - - # Note: past_key_values not working for model, might be a library bug - # """ - # if DEBUG >= 2: - # print("forward_layer call") - # print(f"input_data: {input_data}") - # print(f"shard {self.shard.to_dict()}") - - # hidden_states = input_data - - # # Forward pass through the layer - # if DEBUG >= 2: - # print(f"\n[layer model] {self.base_model}") - # print(f"IN hidden_states {hidden_states}") - - # layer_outputs = self.base_model( - # hidden_states.to(self.device), - # use_cache=False - # ) - - # if DEBUG >= 2: - # print(f"OUT hidden_states {layer_outputs.last_hidden_state}") - - # hidden_states = layer_outputs.last_hidden_state - - # print(f"2 is_last_layer {self.shard.is_last_layer()}") - # if self.shard.is_last_layer(): - # hs_norm = self.norm(hidden_states) - # hs_lm_head = self.llm_model.lm_head(hs_norm).float() - - # # Use the sampling function with default settings - # with torch.no_grad(): - # output_token = sample_logits( - # hs_lm_head[:, -1, :], - # TEMP, - # TOP_P, - # TOP_K - # ).cpu().numpy().flatten() - - # if DEBUG >= 2: - # print(f"hs_norm: {hs_norm}") - # print(f"hs_lm_head: {hs_lm_head}") - # print(f"output_token: {output_token}") - - # return output_token - - # return hidden_states.cpu().numpy() - - # def forward_layers_cached( - # self, - # input_data: torch.tensor, - # past_kvs - # ) -> Tuple[np.ndarray, list]: - # """ - # Forward pass through the specified layers. - # With caching - - # Note: past_key_values not working for model, might be a library bug - # """ - - # if not past_kvs: - # past_kvs = DynamicCache() - # else: - # past_kvs = DynamicCache.from_legacy_cache(past_kvs) - - # if DEBUG >= 2: - # print("forward_layer call") - # print(f"input_data: {input_data}") - # print(f"shard {self.shard.to_dict()}") - # print(f"past_kvs: {past_kvs}") - - # input_ids = input_data.to(self.device) - # position_ids = None - # # position_embeddings = None - - # inputs_embeds = self.embed_tokens(input_ids) - - # if self.shard.is_first_layer(): - # hidden_states = self.embed_tokens(hidden_states) - - # if DEBUG >= 2: - # print(f"hidden_states: {hidden_states}") - # print(f"hidden_states.size(): {hidden_states.size()}") - - # batch_size, seq_len = input_data.size() - # position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device).unsqueeze(0).expand(batch_size, -1) - - # # check if model does not have rotary emb - # # have to apply rotary per model - # # embedding seems very model specific and using position_ids - # # seems more universal, even though some give warning about it - # # if re.match(r"Qwen|qwen", self.shard.model_id): - # # import transformers.models.qwen2.modeling_qwen2 as qwen2 - # # position_embeddings = - # # q=hidden_states, - # # position_ids=position_ids - # # ) - # # else: - # # position_embeddings = self.base_model.rotary_emb( - # # hidden_states, - # # position_ids - # # ) - - # # if DEBUG >= 2: - # # print(f"embedded hidden_states {hidden_states}") - # # print(f"position_ids: {position_embeddings}") - - - # # Forward pass through the layer - # if DEBUG >= 2: - # print(f"IN hidden_states {hidden_states}") - # print(f"past_kvs {past_kvs}") - - # layer_outputs = self.base_model( - # hidden_states, - # position_ids=position_ids, - # past_key_values=past_kvs, - # use_cache=True - # ) - - # if DEBUG >= 2: - # print(f"\nlayer_outputs: {layer_outputs}") - - # hidden_states = layer_outputs.last_hidden_state - # present_kvs = layer_outputs.past_key_values - - # print(f"2 is_last_layer {self.shard.is_last_layer()}") - # if self.shard.is_last_layer(): - # hs_norm = self.norm(hidden_states) - # hs_lm_head = self.llm_model.lm_head(hs_norm).float() - - # # Use the sampling function with default settings - # with torch.no_grad(): - # output_token = sample_logits( - # hs_lm_head[:, -1, :], - # TEMP, - # TOP_P, - # TOP_K - # ).cpu().numpy().flatten() - - # if DEBUG >= 2: - # print(f"hs_norm: {hs_norm}") - # print(f"hs_lm_head: {hs_lm_head}") - # print(f"output_token: {output_token}") - - # return (output_token, present_kvs) - - # return (hidden_states.cpu().numpy(), present_kvs) \ No newline at end of file + ) \ No newline at end of file diff --git a/exo/inference/tokenizers.py b/exo/inference/tokenizers.py index e0bc332d..9accd943 100644 --- a/exo/inference/tokenizers.py +++ b/exo/inference/tokenizers.py @@ -19,7 +19,11 @@ async def resolve_tokenizer(model_id: str): async def _resolve_tokenizer(model_id_or_local_path: str): try: if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}") - processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in model_id_or_local_path else False) + if "Mistral-Large" in str(model_id_or_local_path): + use_fast = True + else: + use_fast = False + processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=use_fast) if not hasattr(processor, 'eos_token_id'): processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id if not hasattr(processor, 'encode'): From 0ae716de1e5d977ae1248bfc2bab201f315fb8b8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 08:23:53 -0800 Subject: [PATCH 319/491] fixing device --- exo/inference/pytorch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 9d8990d7..8c0f0d53 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -29,7 +29,7 @@ def __init__(self, shard: Shard, tokenizer: any): torch_dtype=torch.float32, device_map="auto", offload_buffers=True - ) + ).to(self.device) # disk_offload(model=self.llm_model, offload_dir="./.offload") @@ -116,7 +116,7 @@ def forward( past_key_value=past_kvs, use_cache=True, cache_position=cache_position, - ) + ).to(self.device) hidden_states = layer_outputs[0] next_kvs = layer_outputs[1] From 7705639ec91e1e2bcaf02eabe594b45314ef818f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 12:07:34 -0800 Subject: [PATCH 320/491] adding smaller model test --- exo/inference/pytorch/inference.py | 9 +- exo/inference/pytorch/model/hf.py | 6 +- .../pytorch/test_inference_engine.py | 90 +++++++++++-------- 3 files changed, 61 insertions(+), 44 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 014b7169..878fb5fd 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -108,6 +108,11 @@ async def infer_tensor( input_data, device=self.device ).unsqueeze(0).long() + else: + in_tensor = torch.tensor( + input_data, + device=self.device + ).long() in_tensor = self.model.embed_tokens(in_tensor) @@ -175,8 +180,8 @@ async def ensure_shard(self, shard: Optional[Shard]): Args: shard (Optional[Shard]): Shard information for the model. """ - if self.shard == shard: - return + # if self.shard == shard: + # return if DEBUG >= 4: print(f"Loading new shard: {shard}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 8c0f0d53..f3572dc5 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -28,8 +28,8 @@ def __init__(self, shard: Shard, tokenizer: any): shard.model_id, torch_dtype=torch.float32, device_map="auto", - offload_buffers=True - ).to(self.device) + # offload_buffers=True + ) # disk_offload(model=self.llm_model, offload_dir="./.offload") @@ -116,7 +116,7 @@ def forward( past_key_value=past_kvs, use_cache=True, cache_position=cache_position, - ).to(self.device) + ) hidden_states = layer_outputs[0] next_kvs = layer_outputs[1] diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 8d02e634..bacf53bc 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -13,31 +13,32 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # prompt = "Why is the sky blue?" prompt = "In a single word only, what is the last name of the current president of the USA?" - shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=n_layers-1, - n_layers=n_layers - ) - - resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( - "A", - shard=shard, - prompt=prompt - ) - - print(f"resp_full: {resp_full}") - - next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( - "A", - shard=shard, - input_data=resp_full, - inference_state=inference_state_full, - ) - - print(f"next_resp_full: {next_resp_full}") + # shard = Shard( + # model_id=model_id, + # start_layer=0, + # end_layer=n_layers-1, + # n_layers=n_layers + # ) + + # resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + # "A", + # shard=shard, + # prompt=prompt + # ) + + # print(f"resp_full: {resp_full}") + + # next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + # "A", + # shard=shard, + # input_data=resp_full, + # inference_state=inference_state_full, + # ) + + # print(f"next_resp_full: {next_resp_full}") pp = int(n_layers/2) + resp_shard = Shard( model_id=model_id, start_layer=0, @@ -65,19 +66,19 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_1, ) - resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( - "B", - shard=resp_shard, - input_data=resp2, - inference_state=inference_state_2, - ) + # resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + # "B", + # shard=resp_shard, + # input_data=resp2, + # inference_state=inference_state_2, + # ) - resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp3, - inference_state=inference_state_3, - ) + # resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + # "B", + # shard=resp_shard2, + # input_data=resp3, + # inference_state=inference_state_3, + # ) assert np.array_equal(resp_full, resp2) assert np.array_equal(next_resp_full, resp4) @@ -116,14 +117,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + # try: + # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Chickaboo/ChickaQ-Large", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") + try: - print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") + print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") asyncio.run(test_inference_engine( PyTorchDynamicShardInferenceEngine(HFShardDownloader()), PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Chickaboo/ChickaQ-Large", - 24 + "ambrosfitz/TinyLlama-1.1B-Chat-yawp", + 22 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") From 81d597db2b4075b4725308881ed2e577df9bcb5e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 16:30:30 -0800 Subject: [PATCH 321/491] testing --- exo/inference/pytorch/inference.py | 18 ++++++++---------- exo/inference/pytorch/model/hf.py | 2 ++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 878fb5fd..c6ba8e52 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -103,16 +103,14 @@ async def infer_tensor( current_kvs = None - if input_data.size == 1: - in_tensor = torch.tensor( - input_data, - device=self.device - ).unsqueeze(0).long() - else: - in_tensor = torch.tensor( - input_data, - device=self.device - ).long() + # if input_data.size == 1: + # in_tensor = torch.from_numpy( + # input_data + # ).unsqueeze(0).long() + # else: + in_tensor = torch.from_numpy( + input_data + ).long() in_tensor = self.model.embed_tokens(in_tensor) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index f3572dc5..aa2873c5 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -108,6 +108,8 @@ def forward( if DEBUG >= 4: print("Going through layer") print(f"{decoder_layer}") + print("input_ids") + print(f"{input_ids}") layer_outputs = decoder_layer( input_ids, From f1d3e311790962f390e044840c238e293549588d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 16:38:01 -0800 Subject: [PATCH 322/491] added tinyllama --- exo/inference/pytorch/inference.py | 16 ++++++++-------- exo/models.py | 3 +++ tinychat/examples/tinychat/index.html | 2 +- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index c6ba8e52..ba834eb6 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -103,14 +103,14 @@ async def infer_tensor( current_kvs = None - # if input_data.size == 1: - # in_tensor = torch.from_numpy( - # input_data - # ).unsqueeze(0).long() - # else: - in_tensor = torch.from_numpy( - input_data - ).long() + if input_data.size == 1: + in_tensor = torch.from_numpy( + input_data, + ).unsqueeze(0).long().to(self.device) + else: + in_tensor = torch.from_numpy( + input_data + ).long().to(self.device) in_tensor = self.model.embed_tokens(in_tensor) diff --git a/exo/models.py b/exo/models.py index 72a5b566..137b881c 100644 --- a/exo/models.py +++ b/exo/models.py @@ -26,6 +26,9 @@ "llama-3-1B-Base": { "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, + "TinyLlama-1.1B-Chat-yaw": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="ambrosfitz/TinyLlama-1.1B-Chat-yawp", start_layer=0, end_layer=0, n_layers=22), + }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, diff --git a/tinychat/examples/tinychat/index.html b/tinychat/examples/tinychat/index.html index 8ff4c64c..350cea17 100644 --- a/tinychat/examples/tinychat/index.html +++ b/tinychat/examples/tinychat/index.html @@ -27,7 +27,7 @@
+
Date: Tue, 27 Aug 2024 18:12:12 -0800 Subject: [PATCH 325/491] adding A10, adding test --- .../pytorch/test_inference_engine.py | 20 +++++++++---------- exo/topology/device_capabilities.py | 1 + 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 41ff337a..4bad37c2 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -106,16 +106,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - # try: - # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "meta-llama/Meta-Llama-3.1-8B", - # 32 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + try: + print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Meta-Llama-3.1-8B", + 32 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") # try: # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index 51db53ef..bed1b510 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -108,6 +108,7 @@ def to_dict(self): "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS), "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS), + "NVIDIA A10": DeviceFlops(fp32=31.2 * TFLOPS, fp16=62.5 * TFLOPS, int8=2.5 * TFLOPS), # ... add more devices if needed ... ### AMD GPUs # RX 6000 series From ed5bea79251af4661d8d89a8c5dd2ba91542dc7d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 27 Aug 2024 18:20:56 -0800 Subject: [PATCH 326/491] removing reloading of shard, changing temp and top_p --- exo/inference/pytorch/inference.py | 23 ++--------------------- exo/inference/pytorch/model/hf.py | 4 ++-- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ba834eb6..063a9e4a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -178,31 +178,12 @@ async def ensure_shard(self, shard: Optional[Shard]): Args: shard (Optional[Shard]): Shard information for the model. """ - # if self.shard == shard: - # return + if self.shard == shard: + return if DEBUG >= 4: print(f"Loading new shard: {shard}") - if self.model: - if DEBUG >= 2: - print(f"\nCLEARING MODEL {shard.model_id}\n") - print(f"before allocated: {torch.cuda.memory_allocated()}") - print(f"before reserved: {torch.cuda.memory_reserved()}") - - # delete model and free up memory to reload - # self.model.cuda() - # disk_offload(model=self.model, offload_dir="./.offload") - import gc - - del self.model - gc.collect() - torch.cuda.empty_cache() - - if DEBUG >= 2: - print(f"after allocated: {torch.cuda.memory_allocated()}") - print(f"after reserved: {torch.cuda.memory_reserved()}") - self.shard = shard self.tokenizer = await resolve_tokenizer(shard.model_id) self.model = ShardedHuggingFaceModel(shard, self.tokenizer) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 074cc53c..ed9e6ae1 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -7,8 +7,8 @@ from exo.inference.pytorch.model.utils import sample_logits TOP_P = 0.9 #0.95 -TOP_K = 20 -TEMP = 0.8 +TOP_K = 25 +TEMP = 0.85 class ShardedHuggingFaceModel(torch.nn.Module): def __init__(self, shard: Shard, tokenizer: any): From 032c9b1db7ba3df80826703f08010d312e18174d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 31 Aug 2024 21:21:12 -0800 Subject: [PATCH 327/491] rewrite of sharded model using new split testing of huggingface models --- exo/api/chatgpt_api.py | 16 +- exo/inference/pytorch/inference.py | 96 +++++---- .../pytorch/model/archive/hf_manual.py | 203 ++++++++++++++++++ .../pytorch/model/{ => archive}/utils.py | 0 exo/inference/pytorch/model/hf.py | 149 +++---------- .../pytorch/test_inference_engine.py | 51 +++-- exo/inference/pytorch/test_inference_loop.py | 105 +++++++++ exo/inference/pytorch/test_split_model.py | 108 ++++++++++ 8 files changed, 533 insertions(+), 195 deletions(-) create mode 100644 exo/inference/pytorch/model/archive/hf_manual.py rename exo/inference/pytorch/model/{ => archive}/utils.py (100%) create mode 100644 exo/inference/pytorch/test_inference_loop.py create mode 100644 exo/inference/pytorch/test_split_model.py diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 1abda85f..2619d163 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -58,6 +58,9 @@ def generate_completion( "finish_reason": finish_reason, }], } + + if DEBUG >= 3: + print(f"completion: {completion}") if not stream: completion["usage"] = { @@ -113,16 +116,9 @@ def remap_messages(messages: List[Message]) -> List[Message]: def build_prompt(tokenizer, _messages: List[Message]): - if len(_messages) == 1: - user_msg = _messages[0] - - # get instruct sys message - sys_msg = Message(role="system", content="You are a helpful assistant.") - - # restructure for sys_msg to go first - _messages = [sys_msg, user_msg] - messages = remap_messages(_messages) + if DEBUG >= 3: + print(f"messages: {messages}") prompt = tokenizer.apply_chat_template( messages, tokenize=False, @@ -140,7 +136,7 @@ def build_prompt(tokenizer, _messages: List[Message]): continue for content in message.content: - # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 + # note: wae only support one image at time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 # follows the convention in https://platform.openai.com/docs/guides/vision if isinstance(content, dict) and content.get("type", None) == "image": image_str = content.get("image", None) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 063a9e4a..9334153c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -6,26 +6,28 @@ from typing import Optional, Tuple from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel +from exo.inference.pytorch.model.archive.hf_manual import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG from transformers import DynamicCache from accelerate import disk_offload +from exo.download.shard_download import ShardDownloader class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. """ - def __init__(self, shard): + def __init__(self, shard_downloader: ShardDownloader): """ Initialize the inference engine. Args: debug (bool): If True, enables debug logging. Defaults to False. """ - self.shard = shard - self.model = None + self.shard = None + self.shard_downloader = shard_downloader + self.stateful_sharded_model = None self.tokenizer = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -37,33 +39,33 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") await self.ensure_shard(shard) # need to make this so inference_state is not a string # cant use it with dynamic cache - tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) - tokens = self.model.embed_tokens(tokens) + inputs = self.tokenizer(prompt, return_tensors="pt") + input_ids = inputs.input_ids.to(self.device) + + # add pad token if none + if self.tokenizer.pad_token == None: + self.tokenizer.add_special_tokens({"pad_token":""}) + self.stateful_sharded_model.base_model.resize_token_embeddings(len(self.tokenizer)) + current_kvs = None if DEBUG >= 4: - print("infer_prompt called") - print(f"tokens: {tokens}\n") + print(f"tokens: {input_ids}\n") print(f"layer_count: {self.shard.get_layer_count()}") print(f"is_first_layer: {self.shard.is_first_layer()}") print(f"is_last_layer: {self.shard.is_last_layer()}") - # convert inference_state or cache from json to DynamicCache - past_kv = DynamicCache() - if inference_state != None: - cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] - - output_data, current_kvs = self.model.forward( - tokens, - past_kv + output_data = self.stateful_sharded_model.forward( + input_ids ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] @@ -98,31 +100,26 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 3: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"input_data.size: {input_data.size}") + print(f"input_data.shape: {input_data.shape}") + print(f"shard: {self.shard}") await self.ensure_shard(shard) current_kvs = None + if input_data.size == 1: - in_tensor = torch.from_numpy( - input_data, - ).unsqueeze(0).long().to(self.device) + in_tensor = torch.tensor([[input_data.item()]]).to(self.device) else: - in_tensor = torch.from_numpy( - input_data - ).long().to(self.device) + in_tensor = torch.tensor(input_data).to(self.device) - in_tensor = self.model.embed_tokens(in_tensor) - - if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}") - print(f"input_data.size: {input_data.size}") - print(f"input_tensor: {in_tensor}\n") - print(f"shard: {self.shard}") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") + # in_tensor = torch.tensor(input_data).to(self.device) + + # in_tensor = self.stateful_sharded_model.embed_tokens(in_tensor) # convert inference_state or cache from json to DynamicCache past_kv = DynamicCache() @@ -131,29 +128,33 @@ async def infer_tensor( cache_dict = json.loads(inference_state) past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] - - if DEBUG >= 4: - print("Loaded past_kv from JSON") - print(f"past_kv: {past_kv}") - print(f"past_kv.key_cache len: {len(past_kv.key_cache)}") - print(f"past_kv.value_cache len: {len(past_kv.value_cache)}") + past_kv_length = past_kv[0][0].shape[2] except json.JSONDecodeError: print(f"ERROR DECODING INFERENCE STATE") - output_data, current_kvs = self.model.forward( + if DEBUG >= 3: + # print(f"input_tensor: {in_tensor}") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + print(f"input_data.shape: {input_data.shape}") + + print(f"in_tensor: {in_tensor}") + output_data, current_kvs = self.stateful_sharded_model.forward( in_tensor, + None, past_kv ) is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] - if DEBUG >= 4: - print(f"in_tensor: {in_tensor}\n") + if DEBUG >= 3: print(f"output_data: {output_data}\n") print(f"output_data.size {output_data.size}\n") print(f"finished: {is_finished}") print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") print(f"output_data[-1] {output_data[-1]}") + print("====================================================") if output_data.size == 1: print(f"size 1 output_data.item() {output_data.item()}") @@ -184,9 +185,12 @@ async def ensure_shard(self, shard: Optional[Shard]): if DEBUG >= 4: print(f"Loading new shard: {shard}") - self.shard = shard + # need to build in shard downloader + # model_path = await self.shard_downloader.ensure_shard(shard) + self.tokenizer = await resolve_tokenizer(shard.model_id) - self.model = ShardedHuggingFaceModel(shard, self.tokenizer) + self.stateful_sharded_model = ShardedHuggingFaceModel(shard) + self.shard = shard if DEBUG >= 4: print(f"Shard loaded successfully: {shard}") \ No newline at end of file diff --git a/exo/inference/pytorch/model/archive/hf_manual.py b/exo/inference/pytorch/model/archive/hf_manual.py new file mode 100644 index 00000000..e5af2eaf --- /dev/null +++ b/exo/inference/pytorch/model/archive/hf_manual.py @@ -0,0 +1,203 @@ +# Attempted version to recreate manually using LlamaModel and others +# BROKEN +import torch +import numpy as np +from transformers import AutoModelForCausalLM, DynamicCache, Cache, AutoModel +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from typing import Tuple, Optional, Union, List +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from exo.inference.pytorch.model.archive.utils import sample_logits + +TOP_P = 0.7 #0.95 +TOP_K = 50 +TEMP = 0.01 + + +class ShardedHuggingFaceModel(torch.nn.Module): + def __init__(self, shard: Shard): + super(ShardedHuggingFaceModel, self).__init__() + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.shard = shard + + # Load the model + try: + self.base_model = AutoModel.from_pretrained( + shard.model_id, + torch_dtype=torch.float32, + device_map="auto", + # offload_buffers=True + ) + + # disk_offload(model=self.base_model, offload_dir="./.offload") + except Exception as err: + print(f"Error loading model: {err}") + raise + + if DEBUG >= 2: + print(f"\nShardedHuggingFaceModel init with shard {shard}") + print(f"self.base_model: {self.base_model}") + + # Embeddings and final layer norm + # used for doing what forward LlamaModel does in transformers + self.norm = self.base_model.norm + self.lm_head = torch.nn.Linear( + self.base_model.config.hidden_size, + self.base_model.config.vocab_size, + bias=False + ).to(self.device) + self.embed_tokens = self.base_model.embed_tokens + + def forward( + self, + input_ids: torch.tensor, + attention_mask: torch.tensor = None, + past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + ) -> Tuple[np.ndarray, any]: + """ + Forward through layers using the base model + + Args: + input_ids: tensor input + attention_mask: attention mask from tokenizer + past_kvs: past key value stores for cache + + Returns: + hidden_states: numpy of states between layers + or logits: numpy of normalization and linearization of last hidden state + past_kvs: DynamicCache of past key values if use_cache is true + + Ref: + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 + """ + if DEBUG >= 4: + print("forward called") + print(f"input_ids: {input_ids}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + if self.shard.is_first_layer(): + if DEBUG >= 2: + print("first layer, embed") + print(f"input_ids: {input_ids}") + input_ids = self.embed_tokens(input_ids) + + if DEBUG >= 2: + print(f"embeded input_ids: {input_ids}") + + if attention_mask == None: + # get attention mask + past_kv_length = len(past_kvs) + batch_size, seq_length = input_ids.shape[:2] + attention_mask = _prepare_4d_causal_attention_mask( + None, (batch_size, seq_length), input_ids, past_kv_length + ) + + past_kvs = DynamicCache.from_legacy_cache(past_kvs) + past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + input_ids.shape[1], + device=self.device + ) + + position_ids = cache_position.unsqueeze(0).to(self.device) + + try: + position_embeddings = self.base_model.rotary_emb( + input_ids, + position_ids + ) + except Exception as err: + print(f"rotary_emb not found in base_model") + position_embeddings = None + + causal_mask = self.base_model._update_causal_mask( + attention_mask, + input_ids, + cache_position, + past_kvs, + self.base_model.config.output_attentions + ) + + # progress through layers + for i in range(self.shard.start_layer, self.shard.end_layer + 1): + decoder_layer = self.base_model.layers[i] + + if DEBUG >= 4: + print("Going through layer") + print(f"{decoder_layer}") + print("input_ids") + print(f"{input_ids}") + print("causal_mask") + print(f"{causal_mask}") + + try: + layer_outputs = decoder_layer( + input_ids, + attention_mask=causal_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_value=past_kvs, + use_cache=True, + cache_position=cache_position, + output_logits=True + ) + except Exception as err: + print(f"Going through layer failed: {err}") + print(err.__traceback__.tb_lineno) + raise + + hidden_states = layer_outputs[0] + next_kvs = layer_outputs[1] + + if DEBUG >= 3: + print(f"layer_outputs {layer_outputs}") + print(layer_outputs[1:]) + + if self.shard.is_last_layer(): + hs_norm = self.norm(hidden_states).to(self.device) + # hs_lm_head = self.base_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings + with torch.no_grad(): + logits = self.lm_head( + hs_norm[:, -1:, :] + ).to(self.device).float() + + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + # print(f"hs_lm_head: {hs_lm_head}") + print(f"logits: {logits}") + print(f"logits.shape: {logits.shape}") + + # output_token = sample_logits( + # logits, + # TEMP, + # TOP_P, + # TOP_K + # ).unsqueeze(0).unsqueeze(0).long() + + output_token = torch.distributions.Categorical( + logits=logits + ).sample(sample_shape=(1,)) + + if DEBUG >= 2: + print(f"output_token: {output_token}") + + return (output_token.numpy(force=True), next_kvs) + + with torch.no_grad(): + out_hidden_states = hidden_states.float().numpy(force=True) + + return ( + out_hidden_states, + next_kvs + ) \ No newline at end of file diff --git a/exo/inference/pytorch/model/utils.py b/exo/inference/pytorch/model/archive/utils.py similarity index 100% rename from exo/inference/pytorch/model/utils.py rename to exo/inference/pytorch/model/archive/utils.py diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index ed9e6ae1..0812af6e 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,155 +1,56 @@ import torch +import torch.nn as nn import numpy as np -from transformers import AutoModelForCausalLM, DynamicCache, Cache +from transformers import AutoModelForCausalLM from exo.inference.shard import Shard from exo.helpers import DEBUG +from exo.inference.inference_engine import InferenceEngine +from exo.download.shard_download import ShardDownloader from typing import Tuple, Optional, Union, List -from exo.inference.pytorch.model.utils import sample_logits -TOP_P = 0.9 #0.95 -TOP_K = 25 -TEMP = 0.85 - -class ShardedHuggingFaceModel(torch.nn.Module): - def __init__(self, shard: Shard, tokenizer: any): - super(ShardedHuggingFaceModel, self).__init__() +class ShardedHuggingFaceModel(InferenceEngine): + def __init__(self, shard: Shard): + self.shard = shard if torch.cuda.is_available(): self.device = torch.device("cuda") - else: + self.torch_dtype = torch.float32 + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + self.torch_dtype = torch.float32 + else: self.device = torch.device("cpu") + self.torch_dtype = torch.float16 - self.shard = shard - self.tokenizer = tokenizer - - # Load the model try: - self.llm_model = AutoModelForCausalLM.from_pretrained( + self.base_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=torch.float32, - device_map="auto", - # offload_buffers=True + device_map="auto" ) - - # disk_offload(model=self.llm_model, offload_dir="./.offload") - - self.base_model = self.llm_model.model except Exception as err: - print(f"Error loading model: {err}") + print(f"error loading model: {err}") raise - if DEBUG >= 2: - print(f"\nShardedHuggingFaceModel init with shard {shard}") - print(f"self.llm_model: {self.llm_model}") - print(f"self.base_model: {self.base_model}") - - if DEBUG >= 2: - print(f"full_model.model layer: {len(self.base_model.layers)}") - - # Embeddings and final layer norm - # used for doing what forward LlamaModel does in transformers - self.norm = self.base_model.norm - self.lm_head = self.llm_model.lm_head - self.embed_tokens = self.base_model.embed_tokens - def forward( self, - input_ids: torch.tensor, - past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + input_ids: torch.tensor ) -> Tuple[np.ndarray, any]: """ Forward through layers using the base model Args: input_ids: tensor input - past_kvs: past key value stores for cache - use_cache: use cache - - Returns: - hidden_states: numpy of states between layers - or logits: numpy of normalization and linearization of last hidden state - past_kvs: DynamicCache of past key values if use_cache is true - Ref: - https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 - https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 + Returns: + generator_ids: token ids from generation """ - if DEBUG >= 4: - print("forward called") - print(f"input_ids: {input_ids}\n") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") - - past_kvs = DynamicCache.from_legacy_cache(past_kvs) - past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 - - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + input_ids.shape[1], - device=input_ids.device - ).to(self.device) - - position_ids = cache_position.unsqueeze(0).to(self.device) - - try: - position_embeddings = self.base_model.rotary_emb( - input_ids, - position_ids - ) - except Exception as err: - print(f"rotary_emb not found in base_model") - position_embeddings = None - - # progress through layers - for i in range(self.shard.start_layer, self.shard.end_layer + 1): - decoder_layer = self.base_model.layers[i] - - if DEBUG >= 4: - print("Going through layer") - print(f"{decoder_layer}") - print("input_ids") - print(f"{input_ids}") - - layer_outputs = decoder_layer( - input_ids, - position_ids=position_ids if not position_embeddings else None, - position_embeddings=position_embeddings, - past_key_value=past_kvs, - use_cache=True, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - next_kvs = layer_outputs[1] - - if DEBUG >= 3: - print(f"layer_outputs {layer_outputs}") - - if self.shard.is_last_layer(): - hs_norm = self.norm(hidden_states) - hs_lm_head = self.llm_model.lm_head(hs_norm).float() - - # Use the sampling function with default settings - with torch.no_grad(): - output_token = sample_logits( - hs_lm_head[:, -1, :], - TEMP, - TOP_P, - TOP_K - ).numpy(force=True).flatten() - if DEBUG >= 2: - print(f"hs_norm: {hs_norm}") - print(f"hs_lm_head: {hs_lm_head}") - print(f"output_token: {output_token}") + torch_dtype = + self.model = AutoModelForCausalLM.from_pretrained( + self.shard.model_id, + torch_dtype=torch.float32, + device_map="auto", + ) - return (output_token, next_kvs) - - with torch.no_grad(): - out_hidden_states = hidden_states.numpy(force=True) - return ( - out_hidden_states, - next_kvs - ) \ No newline at end of file diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/test_inference_engine.py index 4bad37c2..9b8a19ef 100644 --- a/exo/inference/pytorch/test_inference_engine.py +++ b/exo/inference/pytorch/test_inference_engine.py @@ -26,7 +26,9 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e prompt=prompt ) - print(f"resp_full: {resp_full}") + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( "A", @@ -35,7 +37,9 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_full, ) - print(f"next_resp_full: {next_resp_full}") + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") pp = int(n_layers/2) @@ -59,6 +63,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e prompt=prompt ) + print("\n------------resp1---------------\n") + print(resp1) + print("\n------------resp1---------------\n") + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( "B", shard=resp_shard2, @@ -66,6 +75,10 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_1, ) + print("\n------------resp2---------------\n") + print(resp2) + print("\n------------resp2---------------\n") + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( "B", shard=resp_shard, @@ -73,6 +86,10 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_2, ) + print("\n------------resp3---------------\n") + print(resp3) + print("\n------------resp3---------------\n") + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( "B", shard=resp_shard2, @@ -80,6 +97,10 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e inference_state=inference_state_3, ) + print("\n------------resp4---------------\n") + print(resp4) + print("\n------------resp4---------------\n") + assert np.array_equal(resp_full, resp2) assert np.array_equal(next_resp_full, resp4) @@ -106,16 +127,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - try: - print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "meta-llama/Meta-Llama-3.1-8B", - 32 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + # try: + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "meta-llama/Meta-Llama-3.1-8B", + # 32 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") # try: # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") @@ -129,13 +150,13 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") try: - print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") + print(f"\n\n --------- TEST TinyLlama/TinyLlama_v1.1 -------\n\n") asyncio.run(test_inference_engine( PyTorchDynamicShardInferenceEngine(HFShardDownloader()), PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "ambrosfitz/TinyLlama-1.1B-Chat-yawp", + "TinyLlama/TinyLlama_v1.1", 22 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! TinyLlama/TinyLlama_v1.1 TEST FAILED \n{err}\n") diff --git a/exo/inference/pytorch/test_inference_loop.py b/exo/inference/pytorch/test_inference_loop.py new file mode 100644 index 00000000..a61b4342 --- /dev/null +++ b/exo/inference/pytorch/test_inference_loop.py @@ -0,0 +1,105 @@ + +import asyncio +from exo.inference.shard import Shard +from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.helpers import DEBUG +import os +import numpy as np + +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): + # prompt = "Why is the sky blue?" + prompt = "In a single word only, what is the last name of the current president of the USA?" + + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + next_resp_full = resp_full + is_finished = False + while not is_finished: + next_resp_full, _next_inference_state_full, is_finished = await inference_engine_1.infer_tensor( + "A", + shard=shard, + input_data=next_resp_full, + inference_state=inference_state_full, + ) + + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") + + + + +if __name__ == '__main__': + # try: + # print(f"\n\n -------- TEST QWEN2 -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "andrijdavid/Llama3-1B-Base", + # 3 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "meta-llama/Meta-Llama-3.1-8B", + # 32 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Chickaboo/ChickaQ-Large", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") + + try: + print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "ambrosfitz/TinyLlama-1.1B-Chat-yawp", + 22 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") + diff --git a/exo/inference/pytorch/test_split_model.py b/exo/inference/pytorch/test_split_model.py new file mode 100644 index 00000000..35104fc1 --- /dev/null +++ b/exo/inference/pytorch/test_split_model.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import asyncio +from transformers import AutoModelForCausalLM, AutoConfig +from exo.api.chatgpt_api import resolve_tokenizer + +async def model_split_test(prompt: str, model_id: str, layers: int): + # inference + tokenizer = await resolve_tokenizer(model_id) + max_length = tokenizer.model_max_length + + # get full model + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float32, + device_map="auto", + ) + + half_layers = int(layers/2) + + # Create a copy of all the layers + model_layers = model.model.layers + copy_layers = [] + for i in range(half_layers): + print(f"Copying layer {i}") + layer_to_copy = model_layers[i] + print(layer_to_copy) + + copy_layers.append(layer_to_copy) + + # load half layers back into model + module_copy_list = nn.ModuleList(copy_layers).to("cuda") + model.model.layers.load_state_dict( + module_copy_list.state_dict(), + strict=False + ) + + # generate first half + inputs = tokenizer(prompt, return_tensors="pt") + fhalf_generate_ids = model.generate( + inputs.input_ids.to("cuda"), + max_new_tokens=max_length/2 + ).to("cuda") + + print("fhalf_generate_ids") + print(fhalf_generate_ids) + + # generate other half + copy_layers = [] + for i in range(half_layers, layers): + print(f"Copying layer {i}") + layer_to_copy = model_layers[i] + print(layer_to_copy) + + copy_layers.append(layer_to_copy) + + # load half layers back into model + module_copy_list = nn.ModuleList(copy_layers).to("cuda") + model.model.layers.load_state_dict( + module_copy_list.state_dict(), + strict=False + ) + + # generate second half with first half + shalf_generate_ids = model.generate( + fhalf_generate_ids + ).to("cuda") + + print("generate_ids") + print(shalf_generate_ids) + print(tokenizer.eos_token_id) + + # decode second half + decode = tokenizer.batch_decode( + shalf_generate_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + )[0] + + print("decode") + print(decode) + +if __name__ == "__main__": + prompt = "In a single word only, what is the last name of the current president of the USA?" + + print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") + model_id = "TinyLlama/TinyLlama_v1.1" + model_layers = 22 + + asyncio.run( + model_split_test( + prompt=prompt, + model_id=model_id, + layers=model_layers + ) + ) + + print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") + model_id = "meta-llama/Meta-Llama-3.1-8B" + model_layers = 32 + + asyncio.run( + model_split_test( + prompt=prompt, + model_id=model_id, + layers=model_layers + ) + ) From 626b2235074d7f047c0de7ff6b087cab30b27bcb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 1 Sep 2024 11:17:09 -0800 Subject: [PATCH 328/491] building out new hf.py class, testing qwen and llama3 8b --- exo/inference/pytorch/model/hf.py | 28 +++-- exo/inference/pytorch/test_split_model.py | 121 ++++++++++++++++------ 2 files changed, 108 insertions(+), 41 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 0812af6e..2a5eefd3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -9,7 +9,7 @@ from typing import Tuple, Optional, Union, List class ShardedHuggingFaceModel(InferenceEngine): - def __init__(self, shard: Shard): + def __init__(self, shard: Shard, ): self.shard = shard if torch.cuda.is_available(): @@ -25,11 +25,23 @@ def __init__(self, shard: Shard): try: self.base_model = AutoModelForCausalLM.from_pretrained( shard.model_id, - torch_dtype=torch.float32, + torch_dtype=self.torch_dtype, device_map="auto" ) + + # build layers from shard + layers = self.base_model.model.layers + copy_layers = nn.ModuleList( + [layers[i] for i in range(self.shard.start_layer, self.shard.end_layer + 1)] + ) + + # apply layers back to model + self.base_model.model.layers.load_state_dict( + copy_layers.state_dict(), + strict=False + ) except Exception as err: - print(f"error loading model: {err}") + print(f"error loading and splitting model: {err}") raise def forward( @@ -46,11 +58,7 @@ def forward( generator_ids: token ids from generation """ - torch_dtype = - self.model = AutoModelForCausalLM.from_pretrained( - self.shard.model_id, - torch_dtype=torch.float32, - device_map="auto", - ) - + generate_ids = self.base_model.generate( + input_ids, + ) \ No newline at end of file diff --git a/exo/inference/pytorch/test_split_model.py b/exo/inference/pytorch/test_split_model.py index 35104fc1..4046bb21 100644 --- a/exo/inference/pytorch/test_split_model.py +++ b/exo/inference/pytorch/test_split_model.py @@ -1,20 +1,36 @@ import torch import torch.nn as nn import asyncio -from transformers import AutoModelForCausalLM, AutoConfig +import gc +from transformers import AutoModelForCausalLM, AutoConfig, Qwen2ForCausalLM from exo.api.chatgpt_api import resolve_tokenizer +import re async def model_split_test(prompt: str, model_id: str, layers: int): # inference tokenizer = await resolve_tokenizer(model_id) - max_length = tokenizer.model_max_length + max_length = 512 #tokenizer.model_max_length # get full model - model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.float32, - device_map="auto", - ) + if re.match(r"^Qwen|qwen", model_id): + model = Qwen2ForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float32, + device_map="auto", + # attn_implementation="eager" + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float32, + device_map="auto", + ) + + # add pad token if none + # this is for llama based models, will add a check + if tokenizer.pad_token == None and re.match(r"Llama|llama", model_id): + tokenizer.add_special_tokens({"pad_token":""}) + model.resize_token_embeddings(len(tokenizer)) half_layers = int(layers/2) @@ -22,12 +38,14 @@ async def model_split_test(prompt: str, model_id: str, layers: int): model_layers = model.model.layers copy_layers = [] for i in range(half_layers): - print(f"Copying layer {i}") + # print(f"Copying layer {i}") layer_to_copy = model_layers[i] - print(layer_to_copy) + # print(layer_to_copy) copy_layers.append(layer_to_copy) + print(f"loading {len(copy_layers)} layers back to model") + # load half layers back into model module_copy_list = nn.ModuleList(copy_layers).to("cuda") model.model.layers.load_state_dict( @@ -36,24 +54,43 @@ async def model_split_test(prompt: str, model_id: str, layers: int): ) # generate first half - inputs = tokenizer(prompt, return_tensors="pt") + messages = [{"role": "user", "content": prompt}] + txt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + print(f"Generating from chat template\n{txt}") + + inputs = tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to("cuda") + input_attention_mask = inputs.attention_mask.to("cuda") fhalf_generate_ids = model.generate( - inputs.input_ids.to("cuda"), - max_new_tokens=max_length/2 + input_ids, + # attention_mask=input_attention_mask, + max_length=int(max_length/2), + output_hidden_states=True + # output_attentions=True ).to("cuda") print("fhalf_generate_ids") print(fhalf_generate_ids) + # nptest = fhalf_generate_ids.numpy(force=True) + # print(f"nptest: {nptest}") + # generate other half copy_layers = [] for i in range(half_layers, layers): - print(f"Copying layer {i}") + # print(f"Copying layer {i}") layer_to_copy = model_layers[i] - print(layer_to_copy) + # print(layer_to_copy) copy_layers.append(layer_to_copy) + print(f"loading {len(copy_layers)} layers back to model") + # load half layers back into model module_copy_list = nn.ModuleList(copy_layers).to("cuda") model.model.layers.load_state_dict( @@ -62,13 +99,16 @@ async def model_split_test(prompt: str, model_id: str, layers: int): ) # generate second half with first half + print(f"Generating from hidden layers output fhalf_generate_ids") shalf_generate_ids = model.generate( - fhalf_generate_ids + fhalf_generate_ids, + # attention_mask=input_attention_mask, + max_length=max_length ).to("cuda") - print("generate_ids") + print("shalf_generate_ids") print(shalf_generate_ids) - print(tokenizer.eos_token_id) + # print(tokenizer.eos_token_id) # decode second half decode = tokenizer.batch_decode( @@ -80,12 +120,42 @@ async def model_split_test(prompt: str, model_id: str, layers: int): print("decode") print(decode) + # free model from memory + del model + gc.collect() + torch.cuda.empty_cache() + + if __name__ == "__main__": prompt = "In a single word only, what is the last name of the current president of the USA?" - print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") - model_id = "TinyLlama/TinyLlama_v1.1" - model_layers = 22 + # print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") + # model_id = "TinyLlama/TinyLlama_v1.1" + # model_layers = 22 + + # asyncio.run( + # model_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + # ) + + # print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") + # model_id = "meta-llama/Meta-Llama-3.1-8B" + # model_layers = 32 + + # asyncio.run( + # model_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + # ) + + print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") + model_id = "Qwen/Qwen2-0.5B-Instruct" + model_layers = 24 asyncio.run( model_split_test( @@ -95,14 +165,3 @@ async def model_split_test(prompt: str, model_id: str, layers: int): ) ) - print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") - model_id = "meta-llama/Meta-Llama-3.1-8B" - model_layers = 32 - - asyncio.run( - model_split_test( - prompt=prompt, - model_id=model_id, - layers=model_layers - ) - ) From f983e9347eac426ed357f011e8ee9e391d84c007 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 3 Sep 2024 19:26:36 -0800 Subject: [PATCH 329/491] trying to load in weights but transformers/pytorch doesnt allow that unless wanting to rebuild the whole model --- exo/inference/pytorch/.gitignore | 1 + exo/inference/pytorch/model/archive/utils.py | 83 ------ exo/inference/pytorch/test_simple_model.py | 40 +++ exo/inference/pytorch/test_split_model.py | 277 +++++++++++-------- exo/inference/pytorch/test_weight_load.py | 206 ++++++++++++++ exo/inference/pytorch/utils.py | 185 +++++++++++++ 6 files changed, 597 insertions(+), 195 deletions(-) create mode 100644 exo/inference/pytorch/.gitignore delete mode 100644 exo/inference/pytorch/model/archive/utils.py create mode 100644 exo/inference/pytorch/test_simple_model.py create mode 100644 exo/inference/pytorch/test_weight_load.py create mode 100644 exo/inference/pytorch/utils.py diff --git a/exo/inference/pytorch/.gitignore b/exo/inference/pytorch/.gitignore new file mode 100644 index 00000000..8fce6030 --- /dev/null +++ b/exo/inference/pytorch/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/exo/inference/pytorch/model/archive/utils.py b/exo/inference/pytorch/model/archive/utils.py deleted file mode 100644 index df84b397..00000000 --- a/exo/inference/pytorch/model/archive/utils.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -from torch.nn import functional as F - -def top_p_sampling(scaled_logits: torch.Tensor, top_p: float) -> torch.Tensor: - """ - Apply top-p (nucleus) sampling to logits. - - Args: - scaled_logits (torch.Tensor): The scaled logits from the model's output. - top_p (float): The cumulative probability threshold for top-p filtering. - temp (float): Temperature parameter for softmax distribution reshaping. - - Returns: - torch.Tensor: Token selected based on the top-p criterion. - - Ref: - https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L67C1-L97C17 - """ - scaled_logits = torch.where(torch.isnan(scaled_logits), torch.zeros_like(scaled_logits), scaled_logits) - scaled_logits = torch.where(torch.isinf(scaled_logits), torch.full_like(scaled_logits, 1e6), scaled_logits) - - probs = torch.softmax(scaled_logits, dim=-1) - - sorted_probs, sorted_indices = torch.sort( - probs, - descending=True, - dim=-1 - ) - - cumulative_probs = torch.cumsum(sorted_probs, dim=-1) - mask = cumulative_probs > top_p - - top_probs = torch.where(mask, torch.zeros_like(sorted_probs), sorted_probs) - sum_probs = top_probs.sum(dim=-1, keepdim=True) - top_probs = torch.where(sum_probs > 0, top_probs / sum_probs, torch.ones_like(top_probs) / top_probs.size(-1)) - - if torch.isnan(top_probs).any() or torch.isinf(top_probs).any(): - print("Warning: Top probabilities contain NaN or Inf values after normalization") - top_probs = torch.where(torch.isnan(top_probs) | torch.isinf(top_probs), - 1.0 / top_probs.size(-1), - top_probs) - - sorted_token = torch.multinomial(top_probs, num_samples=1) - - token = sorted_indices.gather(-1, sorted_token) - - return token.squeeze(-1) - -def sample_logits(logits, temp, top_p, top_k): - """ - Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. - - Args: - logits (torch.Tensor): The logits distribution to sample from. - temp (float): temp for scaling logits. - top_p (float): The cumulative probability threshold for nucleus sampling. - - Returns: - torch.Tensor: The selected token index. - """ - - # Ensure logits are float - logits = logits.float() - - # If temp is very low, just use argmax - if temp == 0: - return logits.argmax(dim=-1) - - scaled_logits = logits/temp - - # top k - if top_k > 0: - top_values, top_indices = torch.topk(scaled_logits, top_k, dim=-1) - scaled_logits = torch.zeros_like(logits).scatter_(-1, top_indices, top_values) - - # Top-p sampling - if 0 < top_p < 1.0: - return top_p_sampling(scaled_logits, top_p) - else: - # random distribution selection - probs = torch.softmax(scaled_logits, dim=-1) - rand_sample = torch.distributions.Categorical(probs) - return rand_sample.sample().squeeze() \ No newline at end of file diff --git a/exo/inference/pytorch/test_simple_model.py b/exo/inference/pytorch/test_simple_model.py new file mode 100644 index 00000000..81009d08 --- /dev/null +++ b/exo/inference/pytorch/test_simple_model.py @@ -0,0 +1,40 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +device = "cuda" # the device to load the model onto + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen2-0.5B-Instruct", + torch_dtype="auto", + device_map="auto" +) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + +prompt = "In a single word only, what is the last name of the current president of the USA?" + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} +] +text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True +) +model_inputs = tokenizer([text], return_tensors="pt").to(device) + +generated_ids = model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + max_new_tokens=512, + do_sample=True, + top_k=20 + #num_beams=5, + #early_stopping=True +) +generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) +] + +response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + +print(f"Prompt: {prompt}\n") +print(f"Response: {response}\n") diff --git a/exo/inference/pytorch/test_split_model.py b/exo/inference/pytorch/test_split_model.py index 4046bb21..242e5f48 100644 --- a/exo/inference/pytorch/test_split_model.py +++ b/exo/inference/pytorch/test_split_model.py @@ -4,9 +4,73 @@ import gc from transformers import AutoModelForCausalLM, AutoConfig, Qwen2ForCausalLM from exo.api.chatgpt_api import resolve_tokenizer +from typing import Tuple, Optional import re +from exo.inference.pytorch.utils import sample_logits, top_k_sampling + +TEMP = 0.6 +TOP_K = 60 + +class OnionHuggingFaceLM(): + def __init__(self, layers, is_last=False): + self.layers = layers + self.is_last = is_last + + def forward( + self, + model, + input_ids: torch.tensor=None, + hidden_states: torch.tensor=None, + attention_mask: torch.tensor=None, + **kwargs + ) -> Tuple[Optional[torch.tensor], Optional[torch.tensor]]: + + # set base model + base_model = model.model + + if input_ids is not None and hidden_states is not None: + print("You must either pass a hidden_state or input_ids but not both") + assert ValueError + + if input_ids is not None: + # embed + hidden_states = base_model.embed_tokens(input_ids) + position_ids = torch.arange( + 0, + input_ids.size(1), + device=input_ids.device + ).unsqueeze(0) + + if hidden_states is not None: + hidden_states = hidden_states + position_ids = torch.arange( + 0, + hidden_states.size(1), + device=hidden_states.device + ).unsqueeze(0) + + for layer in self.layers: + print(f"Processing hidden state from layer\n{layer}\n") + hidden_states = layer( + hidden_states, + position_ids=position_ids + )[0] + + if self.is_last: + norm_states = base_model.norm(hidden_states).to("cuda") + logits = model.lm_head(norm_states).to("cuda") + + return (None, logits) + + return (hidden_states, None) + +async def model_half_split_test(prompt: str, model_id: str, layers: int): + """ + Test for splitting in half + """ + + half_layers = int(layers / 2) -async def model_split_test(prompt: str, model_id: str, layers: int): # inference tokenizer = await resolve_tokenizer(model_id) max_length = 512 #tokenizer.model_max_length @@ -15,43 +79,31 @@ async def model_split_test(prompt: str, model_id: str, layers: int): if re.match(r"^Qwen|qwen", model_id): model = Qwen2ForCausalLM.from_pretrained( model_id, - torch_dtype=torch.float32, + torch_dtype="auto", device_map="auto", # attn_implementation="eager" + # low_cpu_mem_usage=True ) else: model = AutoModelForCausalLM.from_pretrained( model_id, - torch_dtype=torch.float32, + torch_dtype="auto", device_map="auto", + # low_cpu_mem_usage=True ) - # add pad token if none - # this is for llama based models, will add a check - if tokenizer.pad_token == None and re.match(r"Llama|llama", model_id): - tokenizer.add_special_tokens({"pad_token":""}) - model.resize_token_embeddings(len(tokenizer)) + print(model.hf_device_map) - half_layers = int(layers/2) + # add pad token if none, depending on model + #if tokenizer.pad_token == None: + # if re.match(r"Llama|llama", model_id): + # tokenizer.add_special_tokens({"pad_token":""}) + # model.resize_token_embeddings(len(tokenizer)) - # Create a copy of all the layers - model_layers = model.model.layers - copy_layers = [] - for i in range(half_layers): - # print(f"Copying layer {i}") - layer_to_copy = model_layers[i] - # print(layer_to_copy) + shard_layers = nn.ModuleList(model.model.layers[:half_layers])#.to("cuda") + sharded_model = OnionHuggingFaceLM(layers=shard_layers) - copy_layers.append(layer_to_copy) - - print(f"loading {len(copy_layers)} layers back to model") - - # load half layers back into model - module_copy_list = nn.ModuleList(copy_layers).to("cuda") - model.model.layers.load_state_dict( - module_copy_list.state_dict(), - strict=False - ) + print(model) # generate first half messages = [{"role": "user", "content": prompt}] @@ -66,59 +118,60 @@ async def model_split_test(prompt: str, model_id: str, layers: int): inputs = tokenizer([txt], return_tensors="pt") input_ids = inputs.input_ids.to("cuda") input_attention_mask = inputs.attention_mask.to("cuda") - fhalf_generate_ids = model.generate( - input_ids, - # attention_mask=input_attention_mask, - max_length=int(max_length/2), - output_hidden_states=True - # output_attentions=True - ).to("cuda") - - print("fhalf_generate_ids") - print(fhalf_generate_ids) - - # nptest = fhalf_generate_ids.numpy(force=True) - # print(f"nptest: {nptest}") - - # generate other half - copy_layers = [] - for i in range(half_layers, layers): - # print(f"Copying layer {i}") - layer_to_copy = model_layers[i] - # print(layer_to_copy) - - copy_layers.append(layer_to_copy) - - print(f"loading {len(copy_layers)} layers back to model") - - # load half layers back into model - module_copy_list = nn.ModuleList(copy_layers).to("cuda") - model.model.layers.load_state_dict( - module_copy_list.state_dict(), - strict=False + + # add if first layer of model check + shard_hidden_states, shard_logits = sharded_model.forward( + model=model, + input_ids=input_ids ) - # generate second half with first half - print(f"Generating from hidden layers output fhalf_generate_ids") - shalf_generate_ids = model.generate( - fhalf_generate_ids, - # attention_mask=input_attention_mask, - max_length=max_length - ).to("cuda") - - print("shalf_generate_ids") - print(shalf_generate_ids) - # print(tokenizer.eos_token_id) + print(f"shard_hidden_states\n{shard_hidden_states}") + print(f"shard_logits\n{shard_logits}") + + + # second half + print("Using first half hidden state for last half of model") + shard_layers = nn.ModuleList(model.model.layers[half_layers:]).to("cuda") + sharded_model.layers = shard_layers + sharded_model.is_last = True - # decode second half - decode = tokenizer.batch_decode( - shalf_generate_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False - )[0] + if shard_hidden_states is not None: + # add if last layer of model or in the middle check + shard_hidden_states, shard_logits = sharded_model.forward( + model=model, + hidden_states=shard_hidden_states + ) - print("decode") - print(decode) + print(f"shard_hidden_states\n{shard_hidden_states}") + print(f"shard_logits\n{shard_logits}") + else: + print("Sharded hidden states not found, error") + raise ValueError + + + print("generate from logits") + if shard_logits is not None: + print(shard_logits.dim()) + #print(shard_logits[0]) + + generated_ids = sample_logits(shard_logits, 0.1, 0.95, 30) + #generated_ids = torch.argmax(shard_logits/0.7, dim=-1) + #generated_ids = model.generate(logits) + + print("generated_ids") + print(generated_ids) + + generated_text = tokenizer.batch_decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + )[0] + + print("Generated text:") + print(generated_text) + else: + print("Sharded logits missing from last layer run, error") + raise ValueError # free model from memory del model @@ -127,41 +180,41 @@ async def model_split_test(prompt: str, model_id: str, layers: int): if __name__ == "__main__": - prompt = "In a single word only, what is the last name of the current president of the USA?" - - # print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") - # model_id = "TinyLlama/TinyLlama_v1.1" - # model_layers = 22 - - # asyncio.run( - # model_split_test( - # prompt=prompt, - # model_id=model_id, - # layers=model_layers - # ) - # ) - - # print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") - # model_id = "meta-llama/Meta-Llama-3.1-8B" - # model_layers = 32 - - # asyncio.run( - # model_split_test( - # prompt=prompt, - # model_id=model_id, - # layers=model_layers - # ) - # ) - - print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") - model_id = "Qwen/Qwen2-0.5B-Instruct" - model_layers = 24 - - asyncio.run( - model_split_test( - prompt=prompt, - model_id=model_id, - layers=model_layers - ) - ) + prompt = "In a single word only, what is the last name of the current president of the USA?" + + print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") + model_id = "TinyLlama/TinyLlama_v1.1" + model_layers = 22 + + asyncio.run( + model_half_split_test( + prompt=prompt, + model_id=model_id, + layers=model_layers + ) + ) + + #print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") + #model_id = "meta-llama/Meta-Llama-3.1-8B" + #model_layers = 32 + + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) + + #print("\n-------- Test Qwen/Qwen2-57B-A14B-Instruct ----------\n") + #model_id = "Qwen/Qwen2-57B-A14B-Instruct" + #model_layers = 28 + + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) diff --git a/exo/inference/pytorch/test_weight_load.py b/exo/inference/pytorch/test_weight_load.py new file mode 100644 index 00000000..7eb8142f --- /dev/null +++ b/exo/inference/pytorch/test_weight_load.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +import asyncio +import gc +import json +from transformers import AutoConfig, AutoModel +from safetensors import safe_open +from typing import Tuple, Optional +import re +from exo.inference.pytorch.utils import sample_logits, top_k_sampling +from exo.api.chatgpt_api import resolve_tokenizer + +TEMP = 0.6 +TOP_K = 60 + +class OnionHuggingFaceLM(): + def __init__(self, layers, safetensor_index_file, safetensor_directory, is_last=False): + self.layers = layers + self.is_last = is_last + self.safetensor_index_file = safetensor_index_file + self.safetensor_directory = safetensor_directory + + # Load the safetensor index JSON + with open(safetensor_index_file, "r") as f: + self.index_data = json.load(f) + self.weight_map = self.index_data['weight_map'] + self.safetensors_metadata = self.index_data['safetensors_metadata'] + + def load_layer_weights(self, model, layer_index): + layer_tensors = {} + for param_name, file_name in self.weight_map.items(): + if param_name.startswith(f"model.layers.{layer_index}"): + file_path = f"{self.safetensor_directory}/{file_name}" + print(f"loading safetensor\n{file_path}\nfor layer\n{layer_index}") + offsets = self.safetensors_metadata[file_name]['offsets'][param_name] + dtype = self.safetensors_metadata[file_name]['dtype'] + shape = self.safetensors_metadata[file_name]['shape'] + + with safe_open(file_path, framework="pt", device="cuda") as f: + tensor = f.get_tensor_slice(offsets[0], offsets[1]) + tensor = tensor.view(shape) # Reshape to the correct shape + + layer_tensors[param_name] = tensor + + # Assign these tensors to the model's layer + for param_name, tensor in layer_tensors.items(): + param_pointer = model + param_parts = param_name.split('.') + for attr in param_parts[:-1]: + if attr.isdigit(): + attr = int(attr) + param_pointer = getattr(param_pointer, attr) + setattr(param_pointer, param_parts[-1], tensor) + + def forward( + self, + model, + input_ids: torch.tensor=None, + hidden_states: torch.tensor=None, + attention_mask: torch.tensor=None, + **kwargs + ) -> Tuple[Optional[torch.tensor], Optional[torch.tensor]]: + + base_model = model.model + + if input_ids is not None and hidden_states is not None: + print("You must either pass a hidden_state or input_ids but not both") + raise ValueError + + if input_ids is not None: + hidden_states = base_model.embed_tokens(input_ids) + position_ids = torch.arange( + 0, + input_ids.size(1), + device=input_ids.device + ).unsqueeze(0) + + if hidden_states is not None: + position_ids = torch.arange( + 0, + hidden_states.size(1), + device=hidden_states.device + ).unsqueeze(0) + + for idx, layer in enumerate(self.layers): + print(f"Loading weights for layer {idx}") + self.load_layer_weights(model, idx) # Load weights for the current layer + print(f"Processing hidden state from layer {idx}\n") + hidden_states = layer( + hidden_states, + position_ids=position_ids + )[0] + + if self.is_last: + norm_states = base_model.norm(hidden_states).to("cuda") + logits = model.lm_head(norm_states).to("cuda") + + return (None, logits) + + return (hidden_states, None) + +async def model_half_split_test( + prompt: str, + model_id: str, + layers: int, + safetensor_index_file: str, + safetensor_directory: str): + + half_layers = int(layers / 2) + + print("loading tokenizer") + tokenizer = await resolve_tokenizer(model_id) + max_length = 512 + + print("loading config and model") + config = AutoConfig.from_pretrained(model_id, local_files_only=True) + model = AutoModel.from_config(config).to("cuda") + + print(model.hf_device_map) + + shard_layers = nn.ModuleList(model.model.layers[:half_layers]) + sharded_model = OnionHuggingFaceLM( + layers=shard_layers, + safetensor_index_file=safetensor_index_file, + safetensor_directory=safetensor_directory + ) + + print(model) + + messages = [{"role": "user", "content": prompt}] + txt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + print(f"Generating from chat template\n{txt}") + + inputs = tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to("cuda") + input_attention_mask = inputs.attention_mask.to("cuda") + + shard_hidden_states, shard_logits = sharded_model.forward( + model=model, + input_ids=input_ids + ) + + print(f"shard_hidden_states\n{shard_hidden_states}") + print(f"shard_logits\n{shard_logits}") + + print("Using first half hidden state for last half of model") + shard_layers = nn.ModuleList(model.model.layers[half_layers:]).to("cuda") + sharded_model.layers = shard_layers + sharded_model.is_last = True + + if shard_hidden_states is not None: + shard_hidden_states, shard_logits = sharded_model.forward( + model=model, + hidden_states=shard_hidden_states + ) + + print(f"shard_hidden_states\n{shard_hidden_states}") + print(f"shard_logits\n{shard_logits}") + else: + print("Sharded hidden states not found, error") + raise ValueError + + print("generate from logits") + if shard_logits is not None: + generated_ids = sample_logits(shard_logits, TEMP, 0.95, TOP_K) + print("generated_ids") + print(generated_ids) + + generated_text = tokenizer.batch_decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + )[0] + + print("Generated text:") + print(generated_text) + else: + print("Sharded logits missing from last layer run, error") + raise ValueError + + del model + gc.collect() + torch.cuda.empty_cache() + +if __name__ == "__main__": + prompt = "In a single word only, what is the last name of the current president of the USA?" + + print("\n-------- Test Qwen/Qwen2-7B-Instruct ----------\n") + model_id = "Qwen/Qwen2-7B-Instruct" + model_layers = 22 + + asyncio.run( + model_half_split_test( + prompt=prompt, + model_id=model_id, + layers=model_layers, + safetensor_index_file="./data/qwen2_7B_Instruct/model.safetensors.index.json", + safetensor_directory="./data/qwen2_7B_Instruct/" + ) + ) + diff --git a/exo/inference/pytorch/utils.py b/exo/inference/pytorch/utils.py new file mode 100644 index 00000000..e4062da9 --- /dev/null +++ b/exo/inference/pytorch/utils.py @@ -0,0 +1,185 @@ +import torch +from torch.nn import functional as F + +def top_k_sampling(logits, thres): + num_logits = logits.shape[-1] + val, ind = torch.topk(logits, thres, dim=-1, largest=True, sorted=True) + mask = torch.zeros_like(logits) + mask.scatter_(-1, ind, 1) + logits = logits * mask + + return logits + +def top_p_sampling(logits, thres): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + print(f"top_p_sampling sorted_logits\n{sorted_logits}\nsorted_indices {sorted_indices}") + softmax_logits = F.softmax(sorted_logits, dim=-1) + print(f"top_p_sampling\nsoftmax_logits {softmax_logits}") + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + print(f"top_p_sampling\n{cumulative_probs}") + + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > thres + + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove) + print(f"top_p_sampling\nindicies_to_remove: {indices_to_remove}") + logits[indices_to_remove] = float('-inf') + return logits + +def sample_logits(logits, temp, top_p, top_k): + """ + Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. + + Args: + logits (torch.Tensor): The logits distribution to sample from. + temp (float): temp for scaling logits. + top_p (float): The cumulative probability threshold for nucleus sampling. + + Returns: + torch.Tensor: The selected token index. + """ + # If temp is very low, just use argmax + if temp == 0: + return logits.argmax(dim=-1) + + print(f"logits {logits}") + + scaled_logits = logits/temp + + print(f"scaled_logits: {scaled_logits}") + + if 0 < top_p < 1.0: + top_p_logits = top_p_sampling(scaled_logits, top_p) + print(f"top_p logits {top_p_logits}") + if top_k > 0: + top_k_logits = top_k_sampling(top_p_logits, top_k) + return top_k_logits.argmax(dim=-1) + elif top_k > 0: + top_k_logits = top_k_sampling(logits, top_k) + print(f"top_k logits {top_k_logits}") + return top_k_logits.argmax(dim=-1) + + return scaled_logits.argmax(dim=-1) + + +# from tinygrad llama model sample +def sample(logits: torch.Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 1, "only works on 1D tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.numel(), "k must be between 0 and numel" + + # If temperature is very low, just use argmax + if temp < 1e-6: + return logits.argmax().reshape(1) + + # Alpha sampling + if af or ap: + if not hasattr(sample, "alpha_counter"): + sample.alpha_counter = torch.zeros_like(logits, dtype=torch.int32).contiguous() + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0).float() * ap) + + # Replace NaNs with -inf + logits = torch.where(logits != logits, torch.tensor(-float("inf"), device=logits.device), logits) + + # Apply softmax after temperature scaling + t = F.softmax(logits / temp, dim=-1) + + counter = torch.arange(t.numel(), device=logits.device).contiguous() + counter2 = torch.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous() + + # Top-k sampling + if k: + output = torch.zeros(k, device=logits.device).contiguous() + output_indices = torch.zeros(k, device=logits.device, dtype=torch.int32).contiguous() + + for i in range(k): + t_max = t.max() + t_argmax = (t.numel() - ((t == t_max) * counter2).max() - 1).to(torch.int) + output[i] = t_max + output_indices[i] = t_argmax + t = torch.where(counter == t_argmax, torch.tensor(0.0, device=logits.device), t) + + # Approximate top-p sampling + output_cumsum = output.flip(dims=(0,)).cumsum(dim=0).flip(dims=(0,)) + t.sum() + mask = output_cumsum >= (1 - p) + output = output * mask.float() + output_indices = output_indices * mask.int() + + # Sample from the distribution + output_idx = output.multinomial(num_samples=1) + output_token = output_indices[output_idx] + else: + output_token = t.multinomial(num_samples=1) + + # Increase alpha counter + if af or ap: + sample.alpha_counter = torch.where(counter == output_token, sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + + +def sample_3d(logits: torch.Tensor, temp: float, k: int, p: float, af: float, ap: float): + assert logits.ndim == 3, "only works on 3D tensors" + assert 0 <= p <= 1, "p must be between 0 and 1" + assert 0 <= k <= logits.shape[-1], "k must be between 0 and the last dimension size" + + batch_size, seq_len, vocab_size = logits.shape + + # If temperature is very low, just use argmax + if temp < 1e-6: + return logits.argmax(dim=-1) + + # Alpha sampling + if af or ap: + if not hasattr(sample, "alpha_counter"): + sample.alpha_counter = torch.zeros_like(logits, dtype=torch.int32).contiguous() + logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0).float() * ap) + + # Replace NaNs with -inf + logits = torch.where(logits != logits, torch.tensor(-float("inf"), device=logits.device), logits) + + # Apply softmax after temperature scaling + t = F.softmax(logits / temp, dim=-1) + + counter = torch.arange(vocab_size, device=logits.device).unsqueeze(0).unsqueeze(0).expand_as(t).contiguous() + counter2 = torch.arange(vocab_size - 1, -1, -1, device=logits.device).unsqueeze(0).unsqueeze(0).expand_as(t).contiguous() + + # Top-k sampling + if k: + output = torch.zeros((batch_size, seq_len, k), device=logits.device).contiguous() + output_indices = torch.zeros((batch_size, seq_len, k), device=logits.device, dtype=torch.int32).contiguous() + + for i in range(k): + t_max, _ = t.max(dim=-1, keepdim=True) + t_argmax = (vocab_size - ((t == t_max) * counter2).max(dim=-1, keepdim=True)[0] - 1).to(torch.int) + output[:, :, i] = t_max.squeeze(-1) + output_indices[:, :, i] = t_argmax.squeeze(-1) + t = torch.where(counter == t_argmax, torch.tensor(0.0, device=logits.device), t) + + # Approximate top-p sampling + output_cumsum = output.flip(dims=(-1,)).cumsum(dim=-1).flip(dims=(-1,)) + t.sum(dim=-1, keepdim=True) + mask = output_cumsum >= (1 - p) + output = output * mask.float() + output_indices = output_indices * mask.int() + + # Sample from the distribution + output_flat = output.view(batch_size * seq_len, -1) + output_idx = output_flat.multinomial(num_samples=1).squeeze(-1) + output_indices_flat = output_indices.view(batch_size * seq_len, -1) + output_token = output_indices_flat.gather(dim=-1, index=output_idx.unsqueeze(-1)).view(batch_size, seq_len) + else: + output_flat = t.view(batch_size * seq_len, -1) + output_token = output_flat.multinomial(num_samples=1).view(batch_size, seq_len) + + # Increase alpha counter + if af or ap: + sample.alpha_counter = torch.where(counter == output_token.unsqueeze(-1), sample.alpha_counter + 1, sample.alpha_counter) + + return output_token + From d142be047ef3f47b516f4badfc0e03b07673f8dc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 13 Sep 2024 11:10:56 -0800 Subject: [PATCH 330/491] adding more testing, refining logit selection --- exo/inference/pytorch/model/hf.py | 13 +- exo/inference/pytorch/test_simple_model.py | 6 +- exo/inference/pytorch/test_split_model.py | 339 +++++++++++++-------- 3 files changed, 224 insertions(+), 134 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2a5eefd3..7ef80665 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -44,21 +44,18 @@ def __init__(self, shard: Shard, ): print(f"error loading and splitting model: {err}") raise - def forward( + def run( self, input_ids: torch.tensor ) -> Tuple[np.ndarray, any]: """ - Forward through layers using the base model + Run through a set of model layers Args: input_ids: tensor input + this could be tokens or hidden states from other layers Returns: - generator_ids: token ids from generation + layer_outputs: dict + layer output including hidden states, key values or logits """ - - generate_ids = self.base_model.generate( - input_ids, - - ) \ No newline at end of file diff --git a/exo/inference/pytorch/test_simple_model.py b/exo/inference/pytorch/test_simple_model.py index 81009d08..1b08a180 100644 --- a/exo/inference/pytorch/test_simple_model.py +++ b/exo/inference/pytorch/test_simple_model.py @@ -21,12 +21,16 @@ ) model_inputs = tokenizer([text], return_tensors="pt").to(device) +print(f"model_inputs:\n{model_inputs}") + +print(f"generation_config:\n{model.generation_config}") + generated_ids = model.generate( model_inputs.input_ids, attention_mask=model_inputs.attention_mask, max_new_tokens=512, do_sample=True, - top_k=20 + #top_k=20, #num_beams=5, #early_stopping=True ) diff --git a/exo/inference/pytorch/test_split_model.py b/exo/inference/pytorch/test_split_model.py index 242e5f48..d5ceb755 100644 --- a/exo/inference/pytorch/test_split_model.py +++ b/exo/inference/pytorch/test_split_model.py @@ -2,7 +2,28 @@ import torch.nn as nn import asyncio import gc -from transformers import AutoModelForCausalLM, AutoConfig, Qwen2ForCausalLM +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + DynamicCache, + Cache, + LogitsProcessorList, + #MinLengthLogitsProcessor, + LogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper, + StoppingCriteriaList, + MaxLengthCriteria, + MaxTimeCriteria +) + +from transformers.generation.configuration_utils import ( + GenerationConfig, + GenerationMode +) + from exo.api.chatgpt_api import resolve_tokenizer from typing import Tuple, Optional import re @@ -19,50 +40,94 @@ def __init__(self, layers, is_last=False): def forward( self, model, - input_ids: torch.tensor=None, - hidden_states: torch.tensor=None, + llm_model, + input_ids: Optional[torch.tensor], + hidden_states: Optional[torch.tensor], attention_mask: torch.tensor=None, + past_key_values: Cache=DynamicCache(), **kwargs - ) -> Tuple[Optional[torch.tensor], Optional[torch.tensor]]: + ) -> Tuple[Optional[torch.tensor], Optional[torch.tensor], Optional[Cache]]: + + """ + Generate hidden states or logits via passing through set amount of layers of a model + To be passed only input_ids OR hidden_state and not both. This is for connecting the model + layer to generate a complete output + + Args: + input_ids: tensor Optional + hidden_states: tensor Optional - # set base model - base_model = model.model + Returns: + Tuple of + - hidden_states: tensor Optional + - logits: tensor Optional + + """ + is_first = False if input_ids is not None and hidden_states is not None: - print("You must either pass a hidden_state or input_ids but not both") - assert ValueError + raise ValueError if input_ids is not None: - # embed - hidden_states = base_model.embed_tokens(input_ids) - position_ids = torch.arange( - 0, - input_ids.size(1), - device=input_ids.device - ).unsqueeze(0) + # embed input_ids + input_ids = model.embed_tokens(input_ids) + # calculate position_ids + batch_size, seq_length = input_ids.shape[:2] + + is_first = True if hidden_states is not None: - hidden_states = hidden_states - position_ids = torch.arange( - 0, - hidden_states.size(1), - device=hidden_states.device - ).unsqueeze(0) + batch_size, seq_length = hidden_states.shape[:2] + + # cache + past_key_values_length = len(past_key_values) + cache_position = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=input_ids.device if input_ids is not None else hidden_states.device + ) + + position_ids = cache_position.unsqueeze(0) + + if is_first: + model_inputs = llm_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + position_ids=position_ids, + cache_position=cache_position, + attention_mask=attention_mask + ) + + print(f"model_inputs\n{model_inputs}") + for layer in self.layers: - print(f"Processing hidden state from layer\n{layer}\n") - hidden_states = layer( - hidden_states, - position_ids=position_ids - )[0] + layer_input = input_ids if input_ids is not None else hidden_states + #print(f"INPUT: \n{layer_input}\n") + #print(f"POSITION_IDS: \n{position_ids}\n") + #print(f"LAYER: \n{layer}\n") + layer_outputs = layer( + model_inputs["input_ids"], + position_ids=model_inputs["position_ids"], + #attention_mask=model_inputs["attention_mask"], + past_key_values=model_inputs["past_key_values"], + return_dict=True, + use_cache=True + ) + + hidden_states = layer_outputs[0] + past_key_values = layer_outputs[1] if self.is_last: - norm_states = base_model.norm(hidden_states).to("cuda") - logits = model.lm_head(norm_states).to("cuda") + norm_states = model.norm(hidden_states) + + # lm_head + logits = llm_model.lm_head(norm_states).to("cuda") - return (None, logits) + return (None, logits, past_key_values) - return (hidden_states, None) + return (hidden_states, None, past_key_values) async def model_half_split_test(prompt: str, model_id: str, layers: int): """ @@ -72,40 +137,28 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): half_layers = int(layers / 2) # inference - tokenizer = await resolve_tokenizer(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) max_length = 512 #tokenizer.model_max_length - # get full model - if re.match(r"^Qwen|qwen", model_id): - model = Qwen2ForCausalLM.from_pretrained( - model_id, - torch_dtype="auto", - device_map="auto", - # attn_implementation="eager" - # low_cpu_mem_usage=True - ) - else: - model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype="auto", - device_map="auto", - # low_cpu_mem_usage=True - ) - - print(model.hf_device_map) + # get llm model + llm_model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype="auto", + device_map="auto", + use_cache=True + ) + + # get base model + model = llm_model.model # add pad token if none, depending on model - #if tokenizer.pad_token == None: - # if re.match(r"Llama|llama", model_id): - # tokenizer.add_special_tokens({"pad_token":""}) - # model.resize_token_embeddings(len(tokenizer)) + if tokenizer.pad_token == None: + if re.match(r"Llama|llama", model_id): + tokenizer.add_special_tokens({"pad_token":""}) + model.resize_token_embeddings(len(tokenizer)) - shard_layers = nn.ModuleList(model.model.layers[:half_layers])#.to("cuda") - sharded_model = OnionHuggingFaceLM(layers=shard_layers) - - print(model) - - # generate first half + + # generate input_ids messages = [{"role": "user", "content": prompt}] txt = tokenizer.apply_chat_template( messages, @@ -113,65 +166,100 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): add_generation_prompt=True ) - print(f"Generating from chat template\n{txt}") - inputs = tokenizer([txt], return_tensors="pt") input_ids = inputs.input_ids.to("cuda") - input_attention_mask = inputs.attention_mask.to("cuda") + input_attention_mask = inputs.attention_mask.to("cuda") + batch_size, seq_length = input_ids.shape[:2] + + is_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + logit_runs = 1 - # add if first layer of model check - shard_hidden_states, shard_logits = sharded_model.forward( - model=model, - input_ids=input_ids - ) + raw_logits = None - print(f"shard_hidden_states\n{shard_hidden_states}") - print(f"shard_logits\n{shard_logits}") + while not is_finished: + print(f"\n\nLOGIT RUN {logit_runs}\n\n") + print(f"input_ids:\n{input_ids}\n") + print(input_ids.shape) - # second half - print("Using first half hidden state for last half of model") - shard_layers = nn.ModuleList(model.model.layers[half_layers:]).to("cuda") - sharded_model.layers = shard_layers - sharded_model.is_last = True + #shard_layers = nn.ModuleList(model.layers[:half_layers])#.to("cuda") + shard_layers = nn.ModuleList(model.layers) + sharded_model = OnionHuggingFaceLM(layers=shard_layers) + sharded_model.is_last = True - if shard_hidden_states is not None: - # add if last layer of model or in the middle check - shard_hidden_states, shard_logits = sharded_model.forward( + # generate first half + # add if first layer of model check + shard_hidden_states, shard_logits, shard_past_kvs = sharded_model.forward( model=model, - hidden_states=shard_hidden_states + llm_model=llm_model, + attention_mask=input_attention_mask, + input_ids=input_ids, + hidden_states=None ) - print(f"shard_hidden_states\n{shard_hidden_states}") - print(f"shard_logits\n{shard_logits}") - else: - print("Sharded hidden states not found, error") - raise ValueError - + # second half + #sharded_model.layers = nn.ModuleList(model.layers[half_layers:]) + #sharded_model.is_last = True - print("generate from logits") - if shard_logits is not None: - print(shard_logits.dim()) - #print(shard_logits[0]) + #shard_hidden_states, shard_logits, shard_past_kvs = sharded_model.forward( + # model=model, + # llm_model=llm_model, + # input_ids=None, + # hidden_states=shard_hidden_states, + # past_key_values=shard_past_kvs + #) - generated_ids = sample_logits(shard_logits, 0.1, 0.95, 30) - #generated_ids = torch.argmax(shard_logits/0.7, dim=-1) - #generated_ids = model.generate(logits) + # this part of the generation and _sample functions for transformers GenerationMixin + # ref: https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/generation/utils.py#L1301 - print("generated_ids") - print(generated_ids) + # clone logit sample + logits = shard_logits[:, -1, :].clone().float() + + raw_logits = logits + + # distribute + logits_processor = LogitsProcessorList([ + TopKLogitsWarper(35), + TemperatureLogitsWarper(0.6), + TopPLogitsWarper(0.8) + ]) + + stopping_critera = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=50), + MaxTimeCriteria(max_time=10.0), + ] + ) + + next_token_scores = logits_processor(input_ids, logits) + + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + #next_tokens = torch.argmax(next_token_scores, dim=-1) + + # get inputs ready incase not finished + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + unfinished_sequences = unfinished_sequences & ~stopping_critera(input_ids, None) + is_finished = unfinished_sequences.max() == 0 + + print(f"is_finished?:\n{is_finished}\n") + + logit_runs += 1 + + del logits + del shard_logits - generated_text = tokenizer.batch_decode( - generated_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False - )[0] + print(f"model.generation_config\n{llm_model.generation_config}") - print("Generated text:") - print(generated_text) - else: - print("Sharded logits missing from last layer run, error") - raise ValueError + generated_text = tokenizer.batch_decode( + input_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + )[0] + + print(f"generated_text:\n{generated_text}\n") # free model from memory del model @@ -180,19 +268,20 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): if __name__ == "__main__": - prompt = "In a single word only, what is the last name of the current president of the USA?" + #prompt = "In a single word only, what is the last name of the current president of the USA?" + prompt = "In a single word only, what is the color of an apple?" - print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") - model_id = "TinyLlama/TinyLlama_v1.1" - model_layers = 22 + #print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") + #model_id = "TinyLlama/TinyLlama_v1.1" + #model_layers = 22 - asyncio.run( - model_half_split_test( - prompt=prompt, - model_id=model_id, - layers=model_layers - ) - ) + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) #print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") #model_id = "meta-llama/Meta-Llama-3.1-8B" @@ -206,15 +295,15 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): # ) #) - #print("\n-------- Test Qwen/Qwen2-57B-A14B-Instruct ----------\n") - #model_id = "Qwen/Qwen2-57B-A14B-Instruct" - #model_layers = 28 - - #asyncio.run( - # model_half_split_test( - # prompt=prompt, - # model_id=model_id, - # layers=model_layers - # ) - #) + print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") + model_id = "Qwen/Qwen2-0.5B-Instruct" + model_layers = 24 + + asyncio.run( + model_half_split_test( + prompt=prompt, + model_id=model_id, + layers=model_layers + ) + ) From be8d7fbaf6f2a3cb200c15a22e4ac06701ed6506 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 15 Sep 2024 00:19:16 -0800 Subject: [PATCH 331/491] working split model test, updating class --- exo/inference/pytorch/test_split_model.py | 182 ++++++++++++++-------- 1 file changed, 121 insertions(+), 61 deletions(-) diff --git a/exo/inference/pytorch/test_split_model.py b/exo/inference/pytorch/test_split_model.py index d5ceb755..42e1642a 100644 --- a/exo/inference/pytorch/test_split_model.py +++ b/exo/inference/pytorch/test_split_model.py @@ -24,8 +24,14 @@ GenerationMode ) +# llama +from transformers.models.llama.modeling_llama import LlamaModel + +# qwen2 +from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + from exo.api.chatgpt_api import resolve_tokenizer -from typing import Tuple, Optional +from typing import Tuple, Optional, Union, List import re from exo.inference.pytorch.utils import sample_logits, top_k_sampling @@ -36,17 +42,27 @@ class OnionHuggingFaceLM(): def __init__(self, layers, is_last=False): self.layers = layers self.is_last = is_last + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.input_embed = None + self.causal_mask = None + self.position_embeddings = None + self.attention_mask = None + self.input_ids = None + self.hidden_states = None + self.next_decoder_cache = None def forward( self, model, llm_model, - input_ids: Optional[torch.tensor], - hidden_states: Optional[torch.tensor], - attention_mask: torch.tensor=None, - past_key_values: Cache=DynamicCache(), + input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, + attention_mask: Optional[torch.tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, **kwargs - ) -> Tuple[Optional[torch.tensor], Optional[torch.tensor], Optional[Cache]]: + ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: """ Generate hidden states or logits via passing through set amount of layers of a model @@ -54,80 +70,123 @@ def forward( layer to generate a complete output Args: + model: base llm model tramsformers class + llm_model: llm chat model class input_ids: tensor Optional hidden_states: tensor Optional Returns: Tuple of - hidden_states: tensor Optional + - past_key_values - logits: tensor Optional """ - is_first = False + output_attentions = False # outputting attention not needed + use_legacy_cache = False # some models still use legacy kv store if input_ids is not None and hidden_states is not None: raise ValueError - if input_ids is not None: - # embed input_ids - input_ids = model.embed_tokens(input_ids) - # calculate position_ids - batch_size, seq_length = input_ids.shape[:2] + if hidden_states is not None: + self.hidden_states = hidden_states - is_first = True + if input_ids is not None: + self.input_ids = input_ids - if hidden_states is not None: - batch_size, seq_length = hidden_states.shape[:2] + # embed input_ids + self.inputs_embeds = model.embed_tokens(self.input_ids) - # cache - past_key_values_length = len(past_key_values) - cache_position = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=input_ids.device if input_ids is not None else hidden_states.device - ) + # cache + if past_key_values and not isinstance(past_key_values, Cache): + print("Using legacy cache") + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + + # causal mask + self.attention_mask = attention_mask + self.causal_mask = model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + output_attentions + ) + + #print(f"causal_mask.dim(): {self.causal_mask.dim()}") - position_ids = cache_position.unsqueeze(0) + print(f"\ncausal_mask:{self.causal_mask}\n\n") - if is_first: + # embed positions, some models require and some dont + if isinstance(model, LlamaModel): + self.position_embeddings = model.rotary_emb( + self.inputs_embeds, + position_ids + ) + model_inputs = llm_model.prepare_inputs_for_generation( - input_ids, + self.input_ids, past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, position_ids=position_ids, - cache_position=cache_position, - attention_mask=attention_mask + cache_position=cache_position ) print(f"model_inputs\n{model_inputs}") + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] + + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + self.hidden_states, + attention_mask=self.causal_mask, + position_ids=self.position_ids, + past_key_values=self.past_key_values, + use_cache=True, + cache_position=self.cache_position - for layer in self.layers: - layer_input = input_ids if input_ids is not None else hidden_states - #print(f"INPUT: \n{layer_input}\n") - #print(f"POSITION_IDS: \n{position_ids}\n") - #print(f"LAYER: \n{layer}\n") - layer_outputs = layer( - model_inputs["input_ids"], - position_ids=model_inputs["position_ids"], - #attention_mask=model_inputs["attention_mask"], - past_key_values=model_inputs["past_key_values"], - return_dict=True, - use_cache=True ) - hidden_states = layer_outputs[0] - past_key_values = layer_outputs[1] + self.hidden_states = layer_outputs[0] + self.next_decoder_cache = layer_outputs[1] if self.is_last: - norm_states = model.norm(hidden_states) + self.hidden_states = model.norm(self.hidden_states) + + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache # lm_head - logits = llm_model.lm_head(norm_states).to("cuda") + logits = llm_model.lm_head(self.hidden_states).to("cuda") - return (None, logits, past_key_values) + return ( + None, + None, + logits + ) - return (hidden_states, None, past_key_values) + return ( + self.hidden_states, + self.past_key_values, + None + ) async def model_half_split_test(prompt: str, model_id: str, layers: int): """ @@ -183,14 +242,15 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): print(f"input_ids:\n{input_ids}\n") print(input_ids.shape) - #shard_layers = nn.ModuleList(model.layers[:half_layers])#.to("cuda") - shard_layers = nn.ModuleList(model.layers) + print("\n first half of layers") + shard_layers = nn.ModuleList(model.layers[:half_layers])#.to("cuda") + #shard_layers = nn.ModuleList(model.layers) sharded_model = OnionHuggingFaceLM(layers=shard_layers) - sharded_model.is_last = True + #sharded_model.is_last = True # generate first half # add if first layer of model check - shard_hidden_states, shard_logits, shard_past_kvs = sharded_model.forward( + shard_hidden_states, shard_past_kvs, shard_logits = sharded_model.forward( model=model, llm_model=llm_model, attention_mask=input_attention_mask, @@ -199,16 +259,16 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): ) # second half - #sharded_model.layers = nn.ModuleList(model.layers[half_layers:]) - #sharded_model.is_last = True + print(f"\n second half of layers") + sharded_model.layers = nn.ModuleList(model.layers[half_layers:]) + sharded_model.is_last = True - #shard_hidden_states, shard_logits, shard_past_kvs = sharded_model.forward( - # model=model, - # llm_model=llm_model, - # input_ids=None, - # hidden_states=shard_hidden_states, - # past_key_values=shard_past_kvs - #) + shard_hidden_states, shard_past_kvs, shard_logits = sharded_model.forward( + model=model, + llm_model=llm_model, + hidden_states=shard_hidden_states, + past_key_values=shard_past_kvs + ) # this part of the generation and _sample functions for transformers GenerationMixin # ref: https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/generation/utils.py#L1301 @@ -268,8 +328,8 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): if __name__ == "__main__": - #prompt = "In a single word only, what is the last name of the current president of the USA?" - prompt = "In a single word only, what is the color of an apple?" + prompt = "In a single word only, what is the last name of the current president of the USA?" + #prompt = "In a single word only, what is the color of an apple?" #print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") #model_id = "TinyLlama/TinyLlama_v1.1" From 9d1ecdd60f90bd37f5be426199b7a1b7e24dd4fc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 15 Sep 2024 05:07:39 -0800 Subject: [PATCH 332/491] working on class and inference engine updates --- exo/inference/pytorch/inference.py | 49 ++++--- exo/inference/pytorch/model/hf.py | 206 ++++++++++++++++++++++++----- 2 files changed, 206 insertions(+), 49 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 9334153c..a1df7966 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -13,6 +13,11 @@ from accelerate import disk_offload from exo.download.shard_download import ShardDownloader +# model value options +TOP_K = 35 +TEMP = 0.6 +TOP_P = 0.8 + class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. @@ -29,7 +34,20 @@ def __init__(self, shard_downloader: ShardDownloader): self.shard_downloader = shard_downloader self.stateful_sharded_model = None self.tokenizer = None - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # setup cuda device + if torch.cuda.is_available(): + self.device = torch.device("cuda") + self.torch_dtype = torch.float32 + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + self.torch_dtype = torch.float32 + else: + self.device = torch.device("cpu") + self.torch_dtype = torch.float16 + + # setup unfinished sequence + self.unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=self.device) async def infer_prompt( self, @@ -41,25 +59,24 @@ async def infer_prompt( ) -> Tuple[np.ndarray, str, bool]: if DEBUG >= 4: print("infer_prompt called") - print(f"prompt: {prompt}") await self.ensure_shard(shard) - # need to make this so inference_state is not a string - # cant use it with dynamic cache - - inputs = self.tokenizer(prompt, return_tensors="pt") - input_ids = inputs.input_ids.to(self.device) - - # add pad token if none - if self.tokenizer.pad_token == None: - self.tokenizer.add_special_tokens({"pad_token":""}) - self.stateful_sharded_model.base_model.resize_token_embeddings(len(self.tokenizer)) - - current_kvs = None + # setup prompt input + messages = [{"role": "user", "content": prompt}] + txt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + inputs = self.tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to("cuda") + input_attention_mask = inputs.attention_mask.to("cuda") + batch_size, seq_length = input_ids.shape[:2] + if DEBUG >= 4: - print(f"tokens: {input_ids}\n") + print(f"input_ids: {input_ids}\n") print(f"layer_count: {self.shard.get_layer_count()}") print(f"is_first_layer: {self.shard.is_first_layer()}") print(f"is_last_layer: {self.shard.is_last_layer()}") @@ -193,4 +210,4 @@ async def ensure_shard(self, shard: Optional[Shard]): self.shard = shard if DEBUG >= 4: - print(f"Shard loaded successfully: {shard}") \ No newline at end of file + print(f"Shard loaded successfully: {shard}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7ef80665..2805ae67 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,61 +1,201 @@ import torch import torch.nn as nn import numpy as np -from transformers import AutoModelForCausalLM +import gc +from typing import Tuple, Optional, Union, List + from exo.inference.shard import Shard from exo.helpers import DEBUG from exo.inference.inference_engine import InferenceEngine from exo.download.shard_download import ShardDownloader -from typing import Tuple, Optional, Union, List + +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + DynamicCache, + Cache, + LogitsProcessorList, + #MinLengthLogitsProcessor, + LogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper, + StoppingCriteriaList, + MaxLengthCriteria, + MaxTimeCriteria +) + +from transformers.generation.configuration_utils import ( + GenerationConfig, + GenerationMode +) + +# llama +from transformers.models.llama.modeling_llama import LlamaModel + +# qwen2 +from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + class ShardedHuggingFaceModel(InferenceEngine): def __init__(self, shard: Shard, ): + # class vars self.shard = shard + self.hidden_states = None + self.input_ids = None + self.inputs_embeds = None + self.attention_mask = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.causal_mask = None - if torch.cuda.is_available(): - self.device = torch.device("cuda") - self.torch_dtype = torch.float32 - elif torch.backends.mps.is_available(): - self.device = torch.device("mps") - self.torch_dtype = torch.float32 - else: - self.device = torch.device("cpu") - self.torch_dtype = torch.float16 - + # setup pytorch and transformer llm try: self.base_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=self.torch_dtype, device_map="auto" - ) - - # build layers from shard - layers = self.base_model.model.layers - copy_layers = nn.ModuleList( - [layers[i] for i in range(self.shard.start_layer, self.shard.end_layer + 1)] - ) - - # apply layers back to model - self.base_model.model.layers.load_state_dict( - copy_layers.state_dict(), - strict=False - ) + ) except Exception as err: print(f"error loading and splitting model: {err}") raise - def run( + + def forward( self, - input_ids: torch.tensor - ) -> Tuple[np.ndarray, any]: + shard: Optional[Shard] = None, + model, + llm_model, + input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, + attention_mask: Optional[torch.tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_legacy_cache: Optional[bool] = False + ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: + """ - Run through a set of model layers + Generate hidden states or logits via passing through set amount of layers of a model + To be passed only input_ids OR hidden_state and not both. This is for connecting the model + layer to generate a complete output Args: - input_ids: tensor input - this could be tokens or hidden states from other layers + model: base llm model tramsformers class + llm_model: llm chat model class + input_ids: tensor optional + hidden_states: tensor optional + attention_mask: tensor optional + past_key_values: Cache or list[tensor] optional + use_legacy_cache: bool optional Returns: - layer_outputs: dict - layer output including hidden states, key values or logits + Tuple of + - hidden_states: tensor optional + - past_key_values: Cache or list[tensor] optional + - logits: tensor Optional + """ + + if input_ids is not None and hidden_states is not None: + raise ValueError + + if hidden_states is not None: + self.hidden_states = hidden_states + + if input_ids is not None: + self.input_ids = input_ids + + # embed input_ids + self.inputs_embeds = model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + print("Using legacy cache") + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + output_attentions + ) + + # embed positions, some models require and some dont + if isinstance(model, LlamaModel): + self.position_embeddings = model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + # prepare inputs for decoder layers + model_inputs = llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) + + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] + + # run through decoder layers + layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) + for i in layer_amt: + decoder_layer = self.model.layers[i] + layer_outputs = decoder_layer( + self.hidden_states, + attention_mask=self.causal_mask, + position_ids=self.position_ids, + past_key_values=self.past_key_values, + use_cache=True, + cache_position=self.cache_position + ) + + self.hidden_states = layer_outputs[0] + self.next_decoder_cache = layer_outputs[1] + + + # handle last layer to get logits + if self.is_last: + self.hidden_states = model.norm(self.hidden_states) + + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = llm_model.lm_head(self.hidden_states).to(self.device) + + return ( + None, + None, + logits + ) + + return ( + self.hidden_states, + self.past_key_values, + None + ) + From 4b0df06dabaa345e238006c8af29483ac64244a5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 15 Sep 2024 07:23:16 -0800 Subject: [PATCH 333/491] building out inference engine test --- exo/inference/pytorch/inference.py | 131 +++++++----------- exo/inference/pytorch/model/hf.py | 90 +++++++++--- exo/inference/pytorch/tests/__init__.py | 0 .../{ => tests}/test_inference_engine.py | 0 .../{ => tests}/test_inference_loop.py | 40 +++--- .../pytorch/{ => tests}/test_simple_model.py | 0 .../pytorch/{ => tests}/test_split_model.py | 1 - .../pytorch/{ => tests}/test_weight_load.py | 0 exo/inference/pytorch/{ => tests}/utils.py | 0 9 files changed, 139 insertions(+), 123 deletions(-) create mode 100644 exo/inference/pytorch/tests/__init__.py rename exo/inference/pytorch/{ => tests}/test_inference_engine.py (100%) rename exo/inference/pytorch/{ => tests}/test_inference_loop.py (77%) rename exo/inference/pytorch/{ => tests}/test_simple_model.py (100%) rename exo/inference/pytorch/{ => tests}/test_split_model.py (99%) rename exo/inference/pytorch/{ => tests}/test_weight_load.py (100%) rename exo/inference/pytorch/{ => tests}/utils.py (100%) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a1df7966..3bb7afd7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,12 +1,11 @@ # experimental, based off of tinygrad/inference.py import numpy as np import torch -import numpy as np import json from typing import Optional, Tuple from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.model.archive.hf_manual import ShardedHuggingFaceModel +from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG from transformers import DynamicCache @@ -47,7 +46,7 @@ def __init__(self, shard_downloader: ShardDownloader): self.torch_dtype = torch.float16 # setup unfinished sequence - self.unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=self.device) + self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) async def infer_prompt( self, @@ -71,41 +70,42 @@ async def infer_prompt( ) inputs = self.tokenizer([txt], return_tensors="pt") - input_ids = inputs.input_ids.to("cuda") - input_attention_mask = inputs.attention_mask.to("cuda") + input_ids = inputs.input_ids.to(self.device) + input_attention_mask = inputs.attention_mask.to(self.device) batch_size, seq_length = input_ids.shape[:2] - + + if DEBUG >= 4: print(f"input_ids: {input_ids}\n") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") - output_data = self.stateful_sharded_model.forward( - input_ids + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=input_ids, + attention_mask=input_attention_mask ) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] - if DEBUG >= 4: - print(f"output_data: {output_data}\n") - print(f"output_data.size {output_data.size}\n") - - print(f"finished: {is_finished}") - print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") - print(f"output_data[-1] {output_data[-1]}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") - if output_data.size == 1: - print(f"size 1 output_data.item() {output_data.item()}") - print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + if shard_logits is not None: + input_ids = self.stateful_sharded_model.logits_sample(input_ids, shard_logits) + print(input_ids) + + if shard_past_kvs is not None: + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in shard_past_kvs_kvs.value_cache] + } + else: + cache_dict = None - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] - } + stopping_critera = self.stateful_sharded_model.stopping_critera + self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) + is_finished = self.unfinished_sequences.max() == 0 return ( - output_data, + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states, json.dumps(cache_dict), is_finished ) @@ -117,7 +117,7 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 3: + if DEBUG >= 4: print("infer_tensor called") print(f"input_data: {input_data}") print(f"input_data.size: {input_data.size}") @@ -126,69 +126,34 @@ async def infer_tensor( await self.ensure_shard(shard) - current_kvs = None - - - if input_data.size == 1: - in_tensor = torch.tensor([[input_data.item()]]).to(self.device) - else: - in_tensor = torch.tensor(input_data).to(self.device) - - # in_tensor = torch.tensor(input_data).to(self.device) - - # in_tensor = self.stateful_sharded_model.embed_tokens(in_tensor) - - # convert inference_state or cache from json to DynamicCache - past_kv = DynamicCache() - if inference_state != None: - try: - cache_dict = json.loads(inference_state) - past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] - past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] - past_kv_length = past_kv[0][0].shape[2] - except json.JSONDecodeError: - print(f"ERROR DECODING INFERENCE STATE") - - if DEBUG >= 3: - # print(f"input_tensor: {in_tensor}") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") - print(f"input_data.shape: {input_data.shape}") + hidden_states = torch.tensor(input_data) - print(f"in_tensor: {in_tensor}") - output_data, current_kvs = self.stateful_sharded_model.forward( - in_tensor, - None, - past_kv + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + hidden_states=hidden_states ) - is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] - - if DEBUG >= 3: - print(f"output_data: {output_data}\n") - print(f"output_data.size {output_data.size}\n") - print(f"finished: {is_finished}") - print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") - print(f"output_data[-1] {output_data[-1]}") - print("====================================================") - - if output_data.size == 1: - print(f"size 1 output_data.item() {output_data.item()}") - print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + if shard_logits is not None: + input_ids = self.stateful_sharded_model.logits_sample(shard_logits) + + if shard_past_kvs is not None: + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in shard_past_kvs_kvs.value_cache] + } + else: + cache_dict = None - - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] - } + stopping_critera = self.stateful_sharded_model.stopping_critera + self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) + is_finished = self.unfinished_sequences.max() == 0 return ( - output_data, + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states, json.dumps(cache_dict), is_finished ) - + + async def ensure_shard(self, shard: Optional[Shard]): """ Ensure the model shard is loaded and ready for inference. @@ -206,7 +171,7 @@ async def ensure_shard(self, shard: Optional[Shard]): # model_path = await self.shard_downloader.ensure_shard(shard) self.tokenizer = await resolve_tokenizer(shard.model_id) - self.stateful_sharded_model = ShardedHuggingFaceModel(shard) + self.stateful_sharded_model = ShardedHuggingFaceModel(shard, self.device, self.torch_dtype) self.shard = shard if DEBUG >= 4: diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 2805ae67..be928aa3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -38,8 +38,8 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model -class ShardedHuggingFaceModel(InferenceEngine): - def __init__(self, shard: Shard, ): +class ShardedHuggingFaceModel: + def __init__(self, shard: Shard, device, dtype): # class vars self.shard = shard self.hidden_states = None @@ -50,15 +50,35 @@ def __init__(self, shard: Shard, ): self.past_key_values = None self.cache_position = None self.position_ids = None - self.causal_mask = None + self.causal_mask = None + + # setup logit processors + self.logits_processor = LogitsProcessorList([ + TopKLogitsWarper(35), + TemperatureLogitsWarper(0.6), + TopPLogitsWarper(0.8) + ]) + + # setup stopping critera for generation + self.stopping_critera = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=50), + MaxTimeCriteria(max_time=10.0), + ] + ) + + self.device = device + self.torch_dtype = dtype # setup pytorch and transformer llm try: - self.base_model = AutoModelForCausalLM.from_pretrained( + self.llm_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=self.torch_dtype, device_map="auto" - ) + ) + + self.model = self.llm_model.model except Exception as err: print(f"error loading and splitting model: {err}") raise @@ -67,8 +87,6 @@ def __init__(self, shard: Shard, ): def forward( self, shard: Optional[Shard] = None, - model, - llm_model, input_ids: Optional[torch.tensor] = None, hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, @@ -108,7 +126,7 @@ def forward( self.input_ids = input_ids # embed input_ids - self.inputs_embeds = model.embed_tokens(self.input_ids) + self.inputs_embeds = self.model.embed_tokens(self.input_ids) # cache if past_key_values and not isinstance(past_key_values, Cache): @@ -128,23 +146,23 @@ def forward( # casual mask and attention_mask self.attention_mask = attention_mask - self.causal_mask = model._update_causal_mask( + self.causal_mask = self.model._update_causal_mask( None, self.inputs_embeds, cache_position, past_key_values, - output_attentions + False # dont out attentions ) # embed positions, some models require and some dont - if isinstance(model, LlamaModel): - self.position_embeddings = model.rotary_emb( + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( self.inputs_embeds, position_ids ) # prepare inputs for decoder layers - model_inputs = llm_model.prepare_inputs_for_generation( + model_inputs = self.llm_model.prepare_inputs_for_generation( self.input_ids, past_key_values=past_key_values, attention_mask=self.attention_mask, @@ -175,27 +193,61 @@ def forward( self.next_decoder_cache = layer_outputs[1] - # handle last layer to get logits - if self.is_last: - self.hidden_states = model.norm(self.hidden_states) - + # handle last layer to get logits + if self.shard.is_last_layer(): + self.hidden_states = self.model.norm(self.hidden_states) if use_legacy_cache: self.past_key_values = self.next_decoder_cache.to_legacy_cache() else: self.past_key_values = self.next_decoder_cache # lm_head - logits = llm_model.lm_head(self.hidden_states).to(self.device) + logits = self.llm_model.lm_head(self.hidden_states).to(self.device) return ( None, None, logits ) - + print("199") return ( self.hidden_states, self.past_key_values, None ) + + def logits_sample( + self, + input_ids: torch.tensor, + logits: torch.tensor, + use_max: Optional[bool] = False + ) -> torch.tensor: + """ + Get a sample of the logits from end of model run + + Args: + logits: tensor + use_max: bool, if function should sample with argmax + + Returns: + input_ids: tensor + """ + + # get a single cloned logit + logits = logits[:, 1, :].clone().float() + + + next_token_scores = self.logits_processor(input_ids, logits) + + if not use_max: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # get inputs_ids from token sample + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + return input_ids + diff --git a/exo/inference/pytorch/tests/__init__.py b/exo/inference/pytorch/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/exo/inference/pytorch/test_inference_engine.py b/exo/inference/pytorch/tests/test_inference_engine.py similarity index 100% rename from exo/inference/pytorch/test_inference_engine.py rename to exo/inference/pytorch/tests/test_inference_engine.py diff --git a/exo/inference/pytorch/test_inference_loop.py b/exo/inference/pytorch/tests/test_inference_loop.py similarity index 77% rename from exo/inference/pytorch/test_inference_loop.py rename to exo/inference/pytorch/tests/test_inference_loop.py index a61b4342..b9cdd005 100644 --- a/exo/inference/pytorch/test_inference_loop.py +++ b/exo/inference/pytorch/tests/test_inference_loop.py @@ -48,16 +48,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e if __name__ == '__main__': - # try: - # print(f"\n\n -------- TEST QWEN2 -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2-0.5B-Instruct", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + try: + print(f"\n\n -------- TEST Qwen/Qwen2-0.5B-Instruct -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") # try: # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") @@ -92,14 +92,14 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") - try: - print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "ambrosfitz/TinyLlama-1.1B-Chat-yawp", - 22 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") + #try: + # print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "ambrosfitz/TinyLlama-1.1B-Chat-yawp", + # 22 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") diff --git a/exo/inference/pytorch/test_simple_model.py b/exo/inference/pytorch/tests/test_simple_model.py similarity index 100% rename from exo/inference/pytorch/test_simple_model.py rename to exo/inference/pytorch/tests/test_simple_model.py diff --git a/exo/inference/pytorch/test_split_model.py b/exo/inference/pytorch/tests/test_split_model.py similarity index 99% rename from exo/inference/pytorch/test_split_model.py rename to exo/inference/pytorch/tests/test_split_model.py index 42e1642a..7830b53e 100644 --- a/exo/inference/pytorch/test_split_model.py +++ b/exo/inference/pytorch/tests/test_split_model.py @@ -33,7 +33,6 @@ from exo.api.chatgpt_api import resolve_tokenizer from typing import Tuple, Optional, Union, List import re -from exo.inference.pytorch.utils import sample_logits, top_k_sampling TEMP = 0.6 TOP_K = 60 diff --git a/exo/inference/pytorch/test_weight_load.py b/exo/inference/pytorch/tests/test_weight_load.py similarity index 100% rename from exo/inference/pytorch/test_weight_load.py rename to exo/inference/pytorch/tests/test_weight_load.py diff --git a/exo/inference/pytorch/utils.py b/exo/inference/pytorch/tests/utils.py similarity index 100% rename from exo/inference/pytorch/utils.py rename to exo/inference/pytorch/tests/utils.py From 623468caee02c98395956fcfb87141e6cf705bbb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 15 Sep 2024 16:52:28 -0800 Subject: [PATCH 334/491] adding working tests, update to forward function to just use input_ids, update infer to pass logits and hidden_states --- exo/inference/pytorch/inference.py | 41 ++++- exo/inference/pytorch/model/hf.py | 144 ++++++++++-------- .../pytorch/tests/test_inference_engine.py | 56 ++++--- .../pytorch/tests/test_inference_loop.py | 28 ++-- 4 files changed, 161 insertions(+), 108 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 3bb7afd7..5ce6f2a2 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -2,6 +2,7 @@ import numpy as np import torch import json +import gc from typing import Optional, Tuple from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine @@ -75,12 +76,19 @@ async def infer_prompt( batch_size, seq_length = input_ids.shape[:2] + if inference_state is not None: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + else: + past_kvs = None + + if DEBUG >= 4: print(f"input_ids: {input_ids}\n") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=input_ids, - attention_mask=input_attention_mask + attention_mask=input_attention_mask, + past_key_values=past_kvs ) if DEBUG >= 4: @@ -105,7 +113,7 @@ async def infer_prompt( is_finished = self.unfinished_sequences.max() == 0 return ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states, + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), json.dumps(cache_dict), is_finished ) @@ -126,14 +134,29 @@ async def infer_tensor( await self.ensure_shard(shard) - hidden_states = torch.tensor(input_data) + if input_data.size == 1: + hidden_states = torch.tensor(input_data).to(self.device) + hidden_states = hidden_states.unsqueeze(0) + else: + hidden_states = torch.tensor(input_data).long().to(self.device) + + if inference_state is not None: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + else: + past_kvs = None + + if DEBUG >= 4: + print(f"hidden_states: {hidden_states}") + print(f"inference_state: {inference_state}") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - hidden_states=hidden_states + input_ids=hidden_states, + past_key_values=past_kvs, + infer_tensor=True ) if shard_logits is not None: - input_ids = self.stateful_sharded_model.logits_sample(shard_logits) + input_ids = self.stateful_sharded_model.logits_sample(hidden_states, shard_logits) if shard_past_kvs is not None: cache_dict = { @@ -148,7 +171,7 @@ async def infer_tensor( is_finished = self.unfinished_sequences.max() == 0 return ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states, + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), json.dumps(cache_dict), is_finished ) @@ -170,6 +193,12 @@ async def ensure_shard(self, shard: Optional[Shard]): # need to build in shard downloader # model_path = await self.shard_downloader.ensure_shard(shard) + if self.stateful_sharded_model: + print("Deleting model") + del self.stateful_sharded_model + # gc.collect() + # torch.cuda.empty_cache() + self.tokenizer = await resolve_tokenizer(shard.model_id) self.stateful_sharded_model = ShardedHuggingFaceModel(shard, self.device, self.torch_dtype) self.shard = shard diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index be928aa3..a59eb1a7 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn import numpy as np -import gc from typing import Tuple, Optional, Union, List from exo.inference.shard import Shard @@ -75,7 +74,8 @@ def __init__(self, shard: Shard, device, dtype): self.llm_model = AutoModelForCausalLM.from_pretrained( shard.model_id, torch_dtype=self.torch_dtype, - device_map="auto" + device_map="auto", + offload_buffers=True ) self.model = self.llm_model.model @@ -88,10 +88,10 @@ def forward( self, shard: Optional[Shard] = None, input_ids: Optional[torch.tensor] = None, - hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_legacy_cache: Optional[bool] = False + use_legacy_cache: Optional[bool] = False, + infer_tensor: Optional[bool] = False ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: """ @@ -116,70 +116,79 @@ def forward( """ - if input_ids is not None and hidden_states is not None: - raise ValueError - - if hidden_states is not None: - self.hidden_states = hidden_states - - if input_ids is not None: - self.input_ids = input_ids - - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) - - # cache - if past_key_values and not isinstance(past_key_values, Cache): - print("Using legacy cache") - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) + self.input_ids = input_ids + + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + #if DEBUG >= 4: + # print("forward called") + # print(f"input_ids: {self.input_ids}") + # print(f"inputs_embeds: {self.inputs_embeds}") + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + print("Using legacy cache") + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) - # position id - position_ids = cache_position.unsqueeze(0) + # position id + position_ids = cache_position.unsqueeze(0) + + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + False # dont out attentions + ) - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) - - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( - self.inputs_embeds, - position_ids - ) - - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position + position_ids ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + self.hidden_states = self.inputs_embeds if not infer_tensor else self.input_ids + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] # run through decoder layers - layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) + layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) + + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"model_inputs: {model_inputs}") + print(f"layer_amt: {layer_amt}") + for i in layer_amt: decoder_layer = self.model.layers[i] + if DEBUG >= 5: + print("decoder_layer before") + print(f"decoder_layer: {decoder_layer}") + print(f"hidden_states: {self.hidden_states}") + layer_outputs = decoder_layer( self.hidden_states, attention_mask=self.causal_mask, @@ -192,8 +201,14 @@ def forward( self.hidden_states = layer_outputs[0] self.next_decoder_cache = layer_outputs[1] + if DEBUG >= 5: + print("decoder_layer after") + print(f"hidden_states: {self.hidden_states}") + print(f"next_decoder_cache: {self.next_decoder_cache}") + # handle last layer to get logits + # shard is last layer says true at the start and not detecting last layer correctly if self.shard.is_last_layer(): self.hidden_states = self.model.norm(self.hidden_states) if use_legacy_cache: @@ -209,6 +224,7 @@ def forward( None, logits ) + print("199") return ( self.hidden_states, @@ -223,7 +239,7 @@ def logits_sample( use_max: Optional[bool] = False ) -> torch.tensor: """ - Get a sample of the logits from end of model run + Get a sample of the logits from end of model run for next token Args: logits: tensor @@ -234,7 +250,7 @@ def logits_sample( """ # get a single cloned logit - logits = logits[:, 1, :].clone().float() + logits = logits[:, -1, :].clone().float() next_token_scores = self.logits_processor(input_ids, logits) @@ -245,9 +261,11 @@ def logits_sample( else: next_tokens = torch.argmax(next_token_scores, dim=-1) + print(f"next_tokens: {next_tokens[:, None]}") + # get inputs_ids from token sample - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - return input_ids + return next_tokens[:, None] diff --git a/exo/inference/pytorch/tests/test_inference_engine.py b/exo/inference/pytorch/tests/test_inference_engine.py index 9b8a19ef..7e64c137 100644 --- a/exo/inference/pytorch/tests/test_inference_engine.py +++ b/exo/inference/pytorch/tests/test_inference_engine.py @@ -8,8 +8,14 @@ from exo.helpers import DEBUG import os import numpy as np +import time + +async def test_inference_engine( + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): -async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): # prompt = "Why is the sky blue?" prompt = "In a single word only, what is the last name of the current president of the USA?" @@ -30,6 +36,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e print(resp_full) print("\n------------resp_full---------------\n") + time.sleep(5) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( "A", shard=shard, @@ -41,8 +49,10 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e print(next_resp_full) print("\n------------next_resp_full---------------\n") + time.sleep(5) + pp = int(n_layers/2) - + resp_shard = Shard( model_id=model_id, start_layer=0, @@ -67,6 +77,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e print(resp1) print("\n------------resp1---------------\n") + time.sleep(5) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( "B", @@ -105,16 +117,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - # try: - # print(f"\n\n -------- TEST QWEN2 -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2-0.5B-Instruct", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + try: + print(f"\n\n -------- TEST QWEN2 -------- \n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2-0.5B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") # try: # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") @@ -149,14 +161,14 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e # except Exception as err: # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") - try: - print(f"\n\n --------- TEST TinyLlama/TinyLlama_v1.1 -------\n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "TinyLlama/TinyLlama_v1.1", - 22 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! TinyLlama/TinyLlama_v1.1 TEST FAILED \n{err}\n") + #try: + # print(f"\n\n --------- TEST TinyLlama/TinyLlama_v1.1 -------\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "TinyLlama/TinyLlama_v1.1", + # 22 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! TinyLlama/TinyLlama_v1.1 TEST FAILED \n{err}\n") diff --git a/exo/inference/pytorch/tests/test_inference_loop.py b/exo/inference/pytorch/tests/test_inference_loop.py index b9cdd005..d9b038d8 100644 --- a/exo/inference/pytorch/tests/test_inference_loop.py +++ b/exo/inference/pytorch/tests/test_inference_loop.py @@ -21,8 +21,8 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e ) resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( - "A", - shard=shard, + "A", + shard=shard, prompt=prompt ) @@ -30,22 +30,16 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e print(resp_full) print("\n------------resp_full---------------\n") - next_resp_full = resp_full - is_finished = False - while not is_finished: - next_resp_full, _next_inference_state_full, is_finished = await inference_engine_1.infer_tensor( - "A", - shard=shard, - input_data=next_resp_full, - inference_state=inference_state_full, - ) - - print("\n------------next_resp_full---------------\n") - print(next_resp_full) - print("\n------------next_resp_full---------------\n") - + next_resp_full, _next_inference_state_full, is_finished = await inference_engine_1.infer_tensor( + "A", + shard=shard, + input_data=resp_full, + inference_state=inference_state_full, + ) - + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") if __name__ == '__main__': try: From 19b322dee8d94ff7969a182be65245f59f4e1eb8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 13:10:58 -0800 Subject: [PATCH 335/491] cleaning up code and tests, debugging and adding in cleaned up logging, added model value options --- exo/inference/pytorch/.gitignore | 1 + exo/inference/pytorch/inference.py | 33 ++- exo/inference/pytorch/model/hf.py | 48 ++-- .../pytorch/tests/test_inference_loop.py | 99 --------- .../pytorch/tests/test_split_model.py | 3 +- .../pytorch/tests/test_weight_load.py | 206 ------------------ 6 files changed, 52 insertions(+), 338 deletions(-) delete mode 100644 exo/inference/pytorch/tests/test_inference_loop.py delete mode 100644 exo/inference/pytorch/tests/test_weight_load.py diff --git a/exo/inference/pytorch/.gitignore b/exo/inference/pytorch/.gitignore index 8fce6030..6d76c24d 100644 --- a/exo/inference/pytorch/.gitignore +++ b/exo/inference/pytorch/.gitignore @@ -1 +1,2 @@ data/ +model/archive/ diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 5ce6f2a2..3e6dc266 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -14,13 +14,15 @@ from exo.download.shard_download import ShardDownloader # model value options -TOP_K = 35 -TEMP = 0.6 -TOP_P = 0.8 +TOP_K = 25 +TEMP = 0.7 +TOP_P = 0.9 +MAX_LENGTH = 125 +MAX_TIME = 10.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ - PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. """ def __init__(self, shard_downloader: ShardDownloader): @@ -190,17 +192,24 @@ async def ensure_shard(self, shard: Optional[Shard]): if DEBUG >= 4: print(f"Loading new shard: {shard}") - # need to build in shard downloader + # -- TO DO -- + # Build in shard downloader but requires pulling + # apart how TrainedModel loads weight in its __init__ + # function in the transformer library # model_path = await self.shard_downloader.ensure_shard(shard) - if self.stateful_sharded_model: - print("Deleting model") - del self.stateful_sharded_model - # gc.collect() - # torch.cuda.empty_cache() - self.tokenizer = await resolve_tokenizer(shard.model_id) - self.stateful_sharded_model = ShardedHuggingFaceModel(shard, self.device, self.torch_dtype) + self.stateful_sharded_model = ShardedHuggingFaceModel( + shard=shard, + device=self.device, + dtype=self.torch_dtype, + top_k=TOP_K, + temp=TEMP, + top_p=TOP_P, + max_length=MAX_LENGTH, + max_time=MAX_TIME + ) + self.shard = shard if DEBUG >= 4: diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index a59eb1a7..93fc7ad3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -38,7 +38,17 @@ class ShardedHuggingFaceModel: - def __init__(self, shard: Shard, device, dtype): + def __init__( + self, + shard: Shard, + device, + dtype, + top_k: int = 25, + temp: float = 0.7, + top_p: float = 0.9, + max_length: int = 50, + max_time: float = 10.0 + ): # class vars self.shard = shard self.hidden_states = None @@ -53,16 +63,16 @@ def __init__(self, shard: Shard, device, dtype): # setup logit processors self.logits_processor = LogitsProcessorList([ - TopKLogitsWarper(35), - TemperatureLogitsWarper(0.6), - TopPLogitsWarper(0.8) + TopKLogitsWarper(top_k), + TemperatureLogitsWarper(temp), + TopPLogitsWarper(top_p) ]) # setup stopping critera for generation self.stopping_critera = StoppingCriteriaList( [ - MaxLengthCriteria(max_length=50), - MaxTimeCriteria(max_time=10.0), + MaxLengthCriteria(max_length=max_length), + MaxTimeCriteria(max_time=max_time), ] ) @@ -103,10 +113,10 @@ def forward( model: base llm model tramsformers class llm_model: llm chat model class input_ids: tensor optional - hidden_states: tensor optional attention_mask: tensor optional past_key_values: Cache or list[tensor] optional - use_legacy_cache: bool optional + use_legacy_cache: bool optional + infer_tensor: bool optional, lets forward know to handle tensors Returns: Tuple of @@ -120,15 +130,9 @@ def forward( # embed input_ids self.inputs_embeds = self.model.embed_tokens(self.input_ids) - - #if DEBUG >= 4: - # print("forward called") - # print(f"input_ids: {self.input_ids}") - # print(f"inputs_embeds: {self.inputs_embeds}") - + # cache if past_key_values and not isinstance(past_key_values, Cache): - print("Using legacy cache") use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) @@ -219,13 +223,19 @@ def forward( # lm_head logits = self.llm_model.lm_head(self.hidden_states).to(self.device) + if DEBUG >= 4: + print(f"logits: {logits}") + return ( None, None, logits ) - print("199") + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"past_key_values: {self.past_key_values}") + return ( self.hidden_states, self.past_key_values, @@ -261,10 +271,8 @@ def logits_sample( else: next_tokens = torch.argmax(next_token_scores, dim=-1) - print(f"next_tokens: {next_tokens[:, None]}") - - # get inputs_ids from token sample - # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if DEBUG >= 4: + print(f"next_tokens: {next_tokens[:, None]}") return next_tokens[:, None] diff --git a/exo/inference/pytorch/tests/test_inference_loop.py b/exo/inference/pytorch/tests/test_inference_loop.py deleted file mode 100644 index d9b038d8..00000000 --- a/exo/inference/pytorch/tests/test_inference_loop.py +++ /dev/null @@ -1,99 +0,0 @@ - -import asyncio -from exo.inference.shard import Shard -from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine -from exo.download.hf.hf_shard_download import HFShardDownloader -from exo.inference.inference_engine import InferenceEngine -from exo.inference.shard import Shard -from exo.helpers import DEBUG -import os -import numpy as np - -async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): - # prompt = "Why is the sky blue?" - prompt = "In a single word only, what is the last name of the current president of the USA?" - - shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=n_layers-1, - n_layers=n_layers - ) - - resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( - "A", - shard=shard, - prompt=prompt - ) - - print("\n------------resp_full---------------\n") - print(resp_full) - print("\n------------resp_full---------------\n") - - next_resp_full, _next_inference_state_full, is_finished = await inference_engine_1.infer_tensor( - "A", - shard=shard, - input_data=resp_full, - inference_state=inference_state_full, - ) - - print("\n------------next_resp_full---------------\n") - print(next_resp_full) - print("\n------------next_resp_full---------------\n") - -if __name__ == '__main__': - try: - print(f"\n\n -------- TEST Qwen/Qwen2-0.5B-Instruct -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2-0.5B-Instruct", - 24 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "andrijdavid/Llama3-1B-Base", - # 3 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "meta-llama/Meta-Llama-3.1-8B", - # 32 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Chickaboo/ChickaQ-Large", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") - - #try: - # print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "ambrosfitz/TinyLlama-1.1B-Chat-yawp", - # 22 - # )) - #except Exception as err: - # print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") - diff --git a/exo/inference/pytorch/tests/test_split_model.py b/exo/inference/pytorch/tests/test_split_model.py index 7830b53e..827bdec2 100644 --- a/exo/inference/pytorch/tests/test_split_model.py +++ b/exo/inference/pytorch/tests/test_split_model.py @@ -327,7 +327,8 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): if __name__ == "__main__": - prompt = "In a single word only, what is the last name of the current president of the USA?" + #prompt = "In a single word only, what is the last name of the current president of the USA?" + prompt = "What color is the sky? Explain why" #prompt = "In a single word only, what is the color of an apple?" #print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") diff --git a/exo/inference/pytorch/tests/test_weight_load.py b/exo/inference/pytorch/tests/test_weight_load.py deleted file mode 100644 index 7eb8142f..00000000 --- a/exo/inference/pytorch/tests/test_weight_load.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch -import torch.nn as nn -import asyncio -import gc -import json -from transformers import AutoConfig, AutoModel -from safetensors import safe_open -from typing import Tuple, Optional -import re -from exo.inference.pytorch.utils import sample_logits, top_k_sampling -from exo.api.chatgpt_api import resolve_tokenizer - -TEMP = 0.6 -TOP_K = 60 - -class OnionHuggingFaceLM(): - def __init__(self, layers, safetensor_index_file, safetensor_directory, is_last=False): - self.layers = layers - self.is_last = is_last - self.safetensor_index_file = safetensor_index_file - self.safetensor_directory = safetensor_directory - - # Load the safetensor index JSON - with open(safetensor_index_file, "r") as f: - self.index_data = json.load(f) - self.weight_map = self.index_data['weight_map'] - self.safetensors_metadata = self.index_data['safetensors_metadata'] - - def load_layer_weights(self, model, layer_index): - layer_tensors = {} - for param_name, file_name in self.weight_map.items(): - if param_name.startswith(f"model.layers.{layer_index}"): - file_path = f"{self.safetensor_directory}/{file_name}" - print(f"loading safetensor\n{file_path}\nfor layer\n{layer_index}") - offsets = self.safetensors_metadata[file_name]['offsets'][param_name] - dtype = self.safetensors_metadata[file_name]['dtype'] - shape = self.safetensors_metadata[file_name]['shape'] - - with safe_open(file_path, framework="pt", device="cuda") as f: - tensor = f.get_tensor_slice(offsets[0], offsets[1]) - tensor = tensor.view(shape) # Reshape to the correct shape - - layer_tensors[param_name] = tensor - - # Assign these tensors to the model's layer - for param_name, tensor in layer_tensors.items(): - param_pointer = model - param_parts = param_name.split('.') - for attr in param_parts[:-1]: - if attr.isdigit(): - attr = int(attr) - param_pointer = getattr(param_pointer, attr) - setattr(param_pointer, param_parts[-1], tensor) - - def forward( - self, - model, - input_ids: torch.tensor=None, - hidden_states: torch.tensor=None, - attention_mask: torch.tensor=None, - **kwargs - ) -> Tuple[Optional[torch.tensor], Optional[torch.tensor]]: - - base_model = model.model - - if input_ids is not None and hidden_states is not None: - print("You must either pass a hidden_state or input_ids but not both") - raise ValueError - - if input_ids is not None: - hidden_states = base_model.embed_tokens(input_ids) - position_ids = torch.arange( - 0, - input_ids.size(1), - device=input_ids.device - ).unsqueeze(0) - - if hidden_states is not None: - position_ids = torch.arange( - 0, - hidden_states.size(1), - device=hidden_states.device - ).unsqueeze(0) - - for idx, layer in enumerate(self.layers): - print(f"Loading weights for layer {idx}") - self.load_layer_weights(model, idx) # Load weights for the current layer - print(f"Processing hidden state from layer {idx}\n") - hidden_states = layer( - hidden_states, - position_ids=position_ids - )[0] - - if self.is_last: - norm_states = base_model.norm(hidden_states).to("cuda") - logits = model.lm_head(norm_states).to("cuda") - - return (None, logits) - - return (hidden_states, None) - -async def model_half_split_test( - prompt: str, - model_id: str, - layers: int, - safetensor_index_file: str, - safetensor_directory: str): - - half_layers = int(layers / 2) - - print("loading tokenizer") - tokenizer = await resolve_tokenizer(model_id) - max_length = 512 - - print("loading config and model") - config = AutoConfig.from_pretrained(model_id, local_files_only=True) - model = AutoModel.from_config(config).to("cuda") - - print(model.hf_device_map) - - shard_layers = nn.ModuleList(model.model.layers[:half_layers]) - sharded_model = OnionHuggingFaceLM( - layers=shard_layers, - safetensor_index_file=safetensor_index_file, - safetensor_directory=safetensor_directory - ) - - print(model) - - messages = [{"role": "user", "content": prompt}] - txt = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - - print(f"Generating from chat template\n{txt}") - - inputs = tokenizer([txt], return_tensors="pt") - input_ids = inputs.input_ids.to("cuda") - input_attention_mask = inputs.attention_mask.to("cuda") - - shard_hidden_states, shard_logits = sharded_model.forward( - model=model, - input_ids=input_ids - ) - - print(f"shard_hidden_states\n{shard_hidden_states}") - print(f"shard_logits\n{shard_logits}") - - print("Using first half hidden state for last half of model") - shard_layers = nn.ModuleList(model.model.layers[half_layers:]).to("cuda") - sharded_model.layers = shard_layers - sharded_model.is_last = True - - if shard_hidden_states is not None: - shard_hidden_states, shard_logits = sharded_model.forward( - model=model, - hidden_states=shard_hidden_states - ) - - print(f"shard_hidden_states\n{shard_hidden_states}") - print(f"shard_logits\n{shard_logits}") - else: - print("Sharded hidden states not found, error") - raise ValueError - - print("generate from logits") - if shard_logits is not None: - generated_ids = sample_logits(shard_logits, TEMP, 0.95, TOP_K) - print("generated_ids") - print(generated_ids) - - generated_text = tokenizer.batch_decode( - generated_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False - )[0] - - print("Generated text:") - print(generated_text) - else: - print("Sharded logits missing from last layer run, error") - raise ValueError - - del model - gc.collect() - torch.cuda.empty_cache() - -if __name__ == "__main__": - prompt = "In a single word only, what is the last name of the current president of the USA?" - - print("\n-------- Test Qwen/Qwen2-7B-Instruct ----------\n") - model_id = "Qwen/Qwen2-7B-Instruct" - model_layers = 22 - - asyncio.run( - model_half_split_test( - prompt=prompt, - model_id=model_id, - layers=model_layers, - safetensor_index_file="./data/qwen2_7B_Instruct/model.safetensors.index.json", - safetensor_directory="./data/qwen2_7B_Instruct/" - ) - ) - From cc2c14cf87b83c3cde2a3c8a166fa5126bd54e92 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 15:26:26 -0800 Subject: [PATCH 336/491] getting infer and stop token issues --- exo/inference/pytorch/inference.py | 7 +++---- exo/inference/pytorch/model/hf.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 3e6dc266..d14132e0 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -18,7 +18,7 @@ TEMP = 0.7 TOP_P = 0.9 MAX_LENGTH = 125 -MAX_TIME = 10.0 +MAX_TIME = 60.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): """ @@ -99,8 +99,7 @@ async def infer_prompt( print(f"\nshard_logits: {shard_logits}") if shard_logits is not None: - input_ids = self.stateful_sharded_model.logits_sample(input_ids, shard_logits) - print(input_ids) + input_ids = self.stateful_sharded_model.logits_sample(shard_logits) if shard_past_kvs is not None: cache_dict = { @@ -158,7 +157,7 @@ async def infer_tensor( ) if shard_logits is not None: - input_ids = self.stateful_sharded_model.logits_sample(hidden_states, shard_logits) + input_ids = self.stateful_sharded_model.logits_sample(shard_logits) if shard_past_kvs is not None: cache_dict = { diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 93fc7ad3..bf40f919 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -71,7 +71,7 @@ def __init__( # setup stopping critera for generation self.stopping_critera = StoppingCriteriaList( [ - MaxLengthCriteria(max_length=max_length), + #MaxLengthCriteria(max_length=max_length), MaxTimeCriteria(max_time=max_time), ] ) @@ -173,7 +173,7 @@ def forward( cache_position=cache_position ) - self.hidden_states = self.inputs_embeds if not infer_tensor else self.input_ids + self.hidden_states = self.inputs_embeds self.position_ids = model_inputs["position_ids"] self.cache_position = model_inputs["cache_position"] self.past_key_values = model_inputs["past_key_values"] @@ -244,7 +244,6 @@ def forward( def logits_sample( self, - input_ids: torch.tensor, logits: torch.tensor, use_max: Optional[bool] = False ) -> torch.tensor: @@ -263,17 +262,21 @@ def logits_sample( logits = logits[:, -1, :].clone().float() - next_token_scores = self.logits_processor(input_ids, logits) + next_token_scores = self.logits_processor(self.input_ids, logits) if not use_max: probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + next_tokens = torch.multinomial(probs, num_samples=1) else: next_tokens = torch.argmax(next_token_scores, dim=-1) if DEBUG >= 4: + print(f"input_ids: {self.input_ids}") print(f"next_tokens: {next_tokens[:, None]}") - return next_tokens[:, None] + input_ids = torch.cat([self.input_ids, next_tokens[:, None].squeeze(-1)], dim=-1) + + return input_ids + #return next_tokens[:, None].squeeze(-1) From 583629c0c3349b9fa6daebd5b9ceb7319290e387 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 16:35:29 -0800 Subject: [PATCH 337/491] add tracking of next token and other logits into the full input_ids so it will stop, context was also dropping due to not having all the logits --- exo/api/chatgpt_api.py | 9 +++++++- exo/inference/pytorch/inference.py | 35 ++++++++++++++++++------------ exo/inference/pytorch/model/hf.py | 14 +++++------- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 2619d163..320377fb 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -70,9 +70,16 @@ def generate_completion( } choice = completion["choices"][0] + print(f"\nchoice {choice}") if object_type.startswith("chat.completion"): key_name = "delta" if stream else "message" - choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)} + + token_decode = tokenizer.batch_decode( + tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + ) + choice[key_name] = {"role": "assistant", "content": token_decode} elif object_type == "text_completion": choice["text"] = tokenizer.decode(tokens) else: diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d14132e0..f98ac4d0 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -14,8 +14,8 @@ from exo.download.shard_download import ShardDownloader # model value options -TOP_K = 25 -TEMP = 0.7 +TOP_K = 20 +TEMP = 0.6 TOP_P = 0.9 MAX_LENGTH = 125 MAX_TIME = 60.0 @@ -37,6 +37,11 @@ def __init__(self, shard_downloader: ShardDownloader): self.stateful_sharded_model = None self.tokenizer = None + # the whole history with new logits need to + # be passed to the model to reach the end token + # even with caching + self.past_input_ids = None + # setup cuda device if torch.cuda.is_available(): self.device = torch.device("cuda") @@ -99,8 +104,10 @@ async def infer_prompt( print(f"\nshard_logits: {shard_logits}") if shard_logits is not None: - input_ids = self.stateful_sharded_model.logits_sample(shard_logits) - + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + input_ids = next_token + if shard_past_kvs is not None: cache_dict = { 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], @@ -111,7 +118,10 @@ async def infer_prompt( stopping_critera = self.stateful_sharded_model.stopping_critera self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - is_finished = self.unfinished_sequences.max() == 0 + is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id + + if is_finished: + self.past_input_ids = None return ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), @@ -134,12 +144,9 @@ async def infer_tensor( print(f"shard: {self.shard}") await self.ensure_shard(shard) - - if input_data.size == 1: - hidden_states = torch.tensor(input_data).to(self.device) - hidden_states = hidden_states.unsqueeze(0) - else: - hidden_states = torch.tensor(input_data).long().to(self.device) + + input_ids = torch.tensor(input_data).long().to(self.device) + self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) if inference_state is not None: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) @@ -147,11 +154,11 @@ async def infer_tensor( past_kvs = None if DEBUG >= 4: - print(f"hidden_states: {hidden_states}") + print(f"input_ids: {input_ids}") print(f"inference_state: {inference_state}") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - input_ids=hidden_states, + input_ids=self.past_input_ids, past_key_values=past_kvs, infer_tensor=True ) @@ -169,7 +176,7 @@ async def infer_tensor( stopping_critera = self.stateful_sharded_model.stopping_critera self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - is_finished = self.unfinished_sequences.max() == 0 + is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id return ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index bf40f919..066d643c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -255,28 +255,24 @@ def logits_sample( use_max: bool, if function should sample with argmax Returns: - input_ids: tensor + next_token: tensor """ # get a single cloned logit logits = logits[:, -1, :].clone().float() - next_token_scores = self.logits_processor(self.input_ids, logits) if not use_max: probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1) + next_token = torch.multinomial(probs, num_samples=1) else: - next_tokens = torch.argmax(next_token_scores, dim=-1) + next_token = torch.argmax(next_token_scores, dim=-1) if DEBUG >= 4: print(f"input_ids: {self.input_ids}") - print(f"next_tokens: {next_tokens[:, None]}") + print(f"next_token: {next_token}") - input_ids = torch.cat([self.input_ids, next_tokens[:, None].squeeze(-1)], dim=-1) - - return input_ids - #return next_tokens[:, None].squeeze(-1) + return next_token[:, None].squeeze(-1) From 7ec5bb8409db41f0a411538e78ff5a0226b4bc33 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 17:19:03 -0800 Subject: [PATCH 338/491] grpc testing --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f98ac4d0..5d3fa506 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -124,7 +124,7 @@ async def infer_prompt( self.past_input_ids = None return ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + input_ids.numpy(force=True), json.dumps(cache_dict), is_finished ) @@ -179,7 +179,7 @@ async def infer_tensor( is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id return ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + input_ids.numpy(force=True), json.dumps(cache_dict), is_finished ) From 5903e6342d241fdfa0fd9853873038d176fb0ad3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 17:37:15 -0800 Subject: [PATCH 339/491] grpc testing --- exo/inference/pytorch/inference.py | 31 +++++++++++++++++++------ exo/inference/pytorch/model/hf.py | 3 +-- exo/networking/grpc/grpc_peer_handle.py | 3 +++ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 5d3fa506..8a9f32a7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -123,12 +123,17 @@ async def infer_prompt( if is_finished: self.past_input_ids = None - return ( - input_ids.numpy(force=True), + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), json.dumps(cache_dict), is_finished ) + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + async def infer_tensor( self, request_id: str, @@ -146,7 +151,11 @@ async def infer_tensor( await self.ensure_shard(shard) input_ids = torch.tensor(input_data).long().to(self.device) - self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) + + if self.past_input_ids is not None: + self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) + else: + self.past_input_ids = input_ids if inference_state is not None: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) @@ -159,8 +168,7 @@ async def infer_tensor( shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, - past_key_values=past_kvs, - infer_tensor=True + past_key_values=past_kvs ) if shard_logits is not None: @@ -178,11 +186,20 @@ async def infer_tensor( self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id - return ( - input_ids.numpy(force=True), + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), json.dumps(cache_dict), is_finished ) + + print(f"return_values: {return_values}") + + return return_values async def ensure_shard(self, shard: Optional[Shard]): diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 066d643c..d040418f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -100,8 +100,7 @@ def forward( input_ids: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_legacy_cache: Optional[bool] = False, - infer_tensor: Optional[bool] = False + use_legacy_cache: Optional[bool] = False ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: """ diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 0629dc77..757de9fa 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -11,6 +11,7 @@ from exo.topology.topology import Topology from exo.topology.device_capabilities import DeviceCapabilities +from exo.helpers import DEBUG class GRPCPeerHandle(PeerHandle): def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities): @@ -52,6 +53,8 @@ async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] request_id=request_id, inference_state=inference_state, ) + + print(f"request: {request}") response = await self.stub.SendPrompt(request) if not response.tensor_data or not response.shape or not response.dtype: From e7a3fd0da3740ca832936c4691bf235e37108892 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 17:40:41 -0800 Subject: [PATCH 340/491] grpc testing --- exo/inference/pytorch/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8a9f32a7..fce410fc 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -197,7 +197,8 @@ async def infer_tensor( is_finished ) - print(f"return_values: {return_values}") + if DEBUG >= 4: + print(f"return_values: {return_values}") return return_values From f6eec5ab5e6671bb928bc968cdba784d3cfdf22d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 18:00:06 -0800 Subject: [PATCH 341/491] grpc testing --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index fce410fc..cc6a4854 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -117,6 +117,7 @@ async def infer_prompt( cache_dict = None stopping_critera = self.stateful_sharded_model.stopping_critera + print("set stopping critera") self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id From d441a51d9ab317dc04bf544b6090ce59b7dea4a5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 18:03:29 -0800 Subject: [PATCH 342/491] grpc testing --- exo/inference/pytorch/inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cc6a4854..27cf634e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -106,6 +106,10 @@ async def infer_prompt( if shard_logits is not None: next_token = self.stateful_sharded_model.logits_sample(shard_logits) self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + + stopping_critera = self.stateful_sharded_model.stopping_critera + print("set stopping critera") + self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) input_ids = next_token if shard_past_kvs is not None: @@ -116,9 +120,6 @@ async def infer_prompt( else: cache_dict = None - stopping_critera = self.stateful_sharded_model.stopping_critera - print("set stopping critera") - self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id if is_finished: From e7f6dcb68227d7892c6b4d2889b725849bdd6bcb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 18:08:01 -0800 Subject: [PATCH 343/491] grpc testing --- exo/inference/pytorch/inference.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 27cf634e..bb50dc0c 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -106,10 +106,6 @@ async def infer_prompt( if shard_logits is not None: next_token = self.stateful_sharded_model.logits_sample(shard_logits) self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) - - stopping_critera = self.stateful_sharded_model.stopping_critera - print("set stopping critera") - self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) input_ids = next_token if shard_past_kvs is not None: @@ -120,11 +116,16 @@ async def infer_prompt( else: cache_dict = None + stopping_critera = self.stateful_sharded_model.stopping_critera + print("set stopping critera") + self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id if is_finished: self.past_input_ids = None + print(f"shard as numpy: {shard_hidden_states.numpy(force=True)}") + return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), json.dumps(cache_dict), From ba5b00566b8f2652cd9188dc9b26f00636a449f2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 18:37:35 -0800 Subject: [PATCH 344/491] grpc testing --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index bb50dc0c..49991410 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -124,7 +124,7 @@ async def infer_prompt( if is_finished: self.past_input_ids = None - print(f"shard as numpy: {shard_hidden_states.numpy(force=True)}") + print(f"shard as numpy: {shard_hidden_states.detach().cpu().numpy()}") return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), From 6242d762c113e65414c7e6a901316707944b9338 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 18:42:40 -0800 Subject: [PATCH 345/491] grpc testing --- exo/inference/pytorch/model/hf.py | 36 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d040418f..61928fe3 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -212,24 +212,24 @@ def forward( # handle last layer to get logits # shard is last layer says true at the start and not detecting last layer correctly - if self.shard.is_last_layer(): - self.hidden_states = self.model.norm(self.hidden_states) - if use_legacy_cache: - self.past_key_values = self.next_decoder_cache.to_legacy_cache() - else: - self.past_key_values = self.next_decoder_cache - - # lm_head - logits = self.llm_model.lm_head(self.hidden_states).to(self.device) - - if DEBUG >= 4: - print(f"logits: {logits}") - - return ( - None, - None, - logits - ) + #if self.shard.is_last_layer(): + self.hidden_states = self.model.norm(self.hidden_states) + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = self.llm_model.lm_head(self.hidden_states).to(self.device) + + if DEBUG >= 4: + print(f"logits: {logits}") + + return ( + None, + None, + logits + ) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") From 563073104a7376b9163cbae86278113498bf7275 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 16 Sep 2024 18:45:17 -0800 Subject: [PATCH 346/491] grpc testing --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 49991410..f25b980d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -124,7 +124,7 @@ async def infer_prompt( if is_finished: self.past_input_ids = None - print(f"shard as numpy: {shard_hidden_states.detach().cpu().numpy()}") + #print(f"shard as numpy: {shard_hidden_states.detach().cpu().numpy()}") return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), From 4a292680814039cd07d551ea78d56095cc179f49 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 00:46:21 -0800 Subject: [PATCH 347/491] testing passing hidden states in inference_state --- exo/inference/pytorch/inference.py | 55 +++++++----- exo/inference/pytorch/model/hf.py | 133 +++++++++++++++-------------- 2 files changed, 104 insertions(+), 84 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f25b980d..ff9f5c8d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -66,6 +66,8 @@ async def infer_prompt( ) -> Tuple[np.ndarray, str, bool]: if DEBUG >= 4: print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") await self.ensure_shard(shard) @@ -103,6 +105,10 @@ async def infer_prompt( print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") + hidden_dict = None + if shard_hidden_states is not None: + hidden_dict = {"hidden_states": shard_hidden_states.tolist()} + if shard_logits is not None: next_token = self.stateful_sharded_model.logits_sample(shard_logits) self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) @@ -117,18 +123,15 @@ async def infer_prompt( cache_dict = None stopping_critera = self.stateful_sharded_model.stopping_critera - print("set stopping critera") self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id if is_finished: self.past_input_ids = None - #print(f"shard as numpy: {shard_hidden_states.detach().cpu().numpy()}") - return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps(cache_dict), + input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps([cache_dict, hidden_dict]), is_finished ) @@ -147,33 +150,47 @@ async def infer_tensor( if DEBUG >= 4: print("infer_tensor called") print(f"input_data: {input_data}") - print(f"input_data.size: {input_data.size}") - print(f"input_data.shape: {input_data.shape}") - print(f"shard: {self.shard}") + print(f"shard: {shard}") await self.ensure_shard(shard) - - input_ids = torch.tensor(input_data).long().to(self.device) - if self.past_input_ids is not None: - self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) - else: - self.past_input_ids = input_ids + infer_state = json.loads(inference_state) if inference_state else None - if inference_state is not None: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + # if in the middle of generation, pass an empty (1,0) array + # while using hidden_states passed via inference_state + hidden_states = None + if input_data.shape == (1,0) and infer_state is not None: + # set hidden_states to input_ids + hidden_states = torch.tensor(infer_state[1]["hidden_states"]) + input_ids = torch.tensor([[]]).to(self.device) # empty tensor else: - past_kvs = None + input_ids = torch.tensor(input_data).long().to(self.device) + + if self.past_input_ids is not None: + self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) + else: + self.past_input_ids = input_ids + + if inference_state is not None: + past_kvs = DynamicCache.from_legacy_cache(infer_state[0]) + else: + past_kvs = None if DEBUG >= 4: print(f"input_ids: {input_ids}") print(f"inference_state: {inference_state}") + print(f"infer_state: {infer_state}") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, + hidden_states=hidden_states, past_key_values=past_kvs ) + hidden_dict = None + if shard_hidden_states is not None: + hidden_dict = {"hidden_states": shard_hidden_states.tolist()} + if shard_logits is not None: input_ids = self.stateful_sharded_model.logits_sample(shard_logits) @@ -195,8 +212,8 @@ async def infer_tensor( print(f"\nshard_logits: {shard_logits}") return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps(cache_dict), + input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps([cache_dict, hidden_dict]), is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 61928fe3..7898a5de 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -98,6 +98,7 @@ def forward( self, shard: Optional[Shard] = None, input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, use_legacy_cache: Optional[bool] = False @@ -124,58 +125,60 @@ def forward( - logits: tensor Optional """ + if hidden_states is not None: + self.hidden_states = hidden_states + else: + self.input_ids = input_ids - self.input_ids = input_ids - - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) - # cache - if past_key_values and not isinstance(past_key_values, Cache): - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) - - # position id - position_ids = cache_position.unsqueeze(0) - - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, - self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) - - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, self.inputs_embeds, - position_ids + cache_position, + past_key_values, + False # dont out attentions ) - - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position - ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) + + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] # run through decoder layers layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) @@ -212,24 +215,24 @@ def forward( # handle last layer to get logits # shard is last layer says true at the start and not detecting last layer correctly - #if self.shard.is_last_layer(): - self.hidden_states = self.model.norm(self.hidden_states) - if use_legacy_cache: - self.past_key_values = self.next_decoder_cache.to_legacy_cache() - else: - self.past_key_values = self.next_decoder_cache - - # lm_head - logits = self.llm_model.lm_head(self.hidden_states).to(self.device) - - if DEBUG >= 4: - print(f"logits: {logits}") - - return ( - None, - None, - logits - ) + if self.shard.is_last_layer(): + self.hidden_states = self.model.norm(self.hidden_states) + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = self.llm_model.lm_head(self.hidden_states).to(self.device) + + if DEBUG >= 4: + print(f"logits: {logits}") + + return ( + None, + None, + logits + ) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") From 2daf65f78f15df57accc7b54582d4ba67c6afa93 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 00:52:42 -0800 Subject: [PATCH 348/491] testing passing hidden states in inference_state --- exo/inference/pytorch/inference.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ff9f5c8d..75017208 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -126,8 +126,12 @@ async def infer_prompt( self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id - if is_finished: - self.past_input_ids = None + out_infer_state = json.dumps([cache_dict, hidden_dict]) + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + print(f"\nout_infer_state: {out_infer_state}") return_values = ( input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), @@ -206,14 +210,16 @@ async def infer_tensor( self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id + out_infer_state = json.dumps([cache_dict, hidden_dict]) if DEBUG >= 4: print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") + print(f"\nout_infer_state: {out_infer_state}") return_values = ( input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps([cache_dict, hidden_dict]), + out_infer_state, is_finished ) From 36d5cde3805e5faccee5d029dc4543271e84b59d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 01:17:15 -0800 Subject: [PATCH 349/491] fixing scalar issue, reversing passing hidden_states --- exo/inference/pytorch/inference.py | 64 ++++++++++++++---------------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 75017208..f4d3e323 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -70,7 +70,7 @@ async def infer_prompt( print(f"shard: {shard}") await self.ensure_shard(shard) - + # setup prompt input messages = [{"role": "user", "content": prompt}] txt = self.tokenizer.apply_chat_template( @@ -90,7 +90,6 @@ async def infer_prompt( else: past_kvs = None - if DEBUG >= 4: print(f"input_ids: {input_ids}\n") @@ -105,10 +104,7 @@ async def infer_prompt( print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") - hidden_dict = None - if shard_hidden_states is not None: - hidden_dict = {"hidden_states": shard_hidden_states.tolist()} - + next_token = None if shard_logits is not None: next_token = self.stateful_sharded_model.logits_sample(shard_logits) self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) @@ -124,18 +120,22 @@ async def infer_prompt( stopping_critera = self.stateful_sharded_model.stopping_critera self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id - out_infer_state = json.dumps([cache_dict, hidden_dict]) + hit_eos = False + if next_token is not None: + hit_eos = next_token.item() == self.tokenizer.eos_token_id + + is_finished = self.unfinished_sequences.max() == 0 or hit_eos + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") - print(f"\nout_infer_state: {out_infer_state}") return_values = ( - input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps([cache_dict, hidden_dict]), + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps(cache_dict), is_finished ) @@ -158,27 +158,17 @@ async def infer_tensor( await self.ensure_shard(shard) - infer_state = json.loads(inference_state) if inference_state else None + input_ids = torch.tensor(input_data).long().to(self.device) - # if in the middle of generation, pass an empty (1,0) array - # while using hidden_states passed via inference_state - hidden_states = None - if input_data.shape == (1,0) and infer_state is not None: - # set hidden_states to input_ids - hidden_states = torch.tensor(infer_state[1]["hidden_states"]) - input_ids = torch.tensor([[]]).to(self.device) # empty tensor + if self.past_input_ids is not None: + self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) else: - input_ids = torch.tensor(input_data).long().to(self.device) + self.past_input_ids = input_ids - if self.past_input_ids is not None: - self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) - else: - self.past_input_ids = input_ids - - if inference_state is not None: - past_kvs = DynamicCache.from_legacy_cache(infer_state[0]) - else: - past_kvs = None + if inference_state is not None: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + else: + past_kvs = None if DEBUG >= 4: print(f"input_ids: {input_ids}") @@ -195,8 +185,10 @@ async def infer_tensor( if shard_hidden_states is not None: hidden_dict = {"hidden_states": shard_hidden_states.tolist()} + next_token is not None if shard_logits is not None: - input_ids = self.stateful_sharded_model.logits_sample(shard_logits) + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + input_ids = next_token if shard_past_kvs is not None: cache_dict = { @@ -208,18 +200,22 @@ async def infer_tensor( stopping_critera = self.stateful_sharded_model.stopping_critera self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - is_finished = self.unfinished_sequences.max() == 0 or input_ids.item() == self.tokenizer.eos_token_id + + hit_eos = False + if next_token is not None: + hit_eos = next_token.item() == self.tokenizer.eos_token_id + + is_finished = self.unfinished_sequences.max() == 0 or hit_eos - out_infer_state = json.dumps([cache_dict, hidden_dict]) if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") - print(f"\nout_infer_state: {out_infer_state}") return_values = ( input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), - out_infer_state, + json.dumps(cache_dict), is_finished ) From 6917f303b2b549e8ca4ec027231d199ea05b8b1a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 01:19:40 -0800 Subject: [PATCH 350/491] inference bug fix, grpc testing --- exo/inference/pytorch/inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f4d3e323..0d148bc8 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -173,7 +173,6 @@ async def infer_tensor( if DEBUG >= 4: print(f"input_ids: {input_ids}") print(f"inference_state: {inference_state}") - print(f"infer_state: {infer_state}") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, From adab336431fb64483634e0c330f4a8acfae90f36 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 01:25:59 -0800 Subject: [PATCH 351/491] inference bug fix, grpc testing --- exo/inference/pytorch/inference.py | 1 - exo/inference/pytorch/model/hf.py | 95 +++++++++++++++--------------- 2 files changed, 46 insertions(+), 50 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0d148bc8..e8920e08 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -176,7 +176,6 @@ async def infer_tensor( shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, - hidden_states=hidden_states, past_key_values=past_kvs ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 7898a5de..d040418f 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -98,7 +98,6 @@ def forward( self, shard: Optional[Shard] = None, input_ids: Optional[torch.tensor] = None, - hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, use_legacy_cache: Optional[bool] = False @@ -125,60 +124,58 @@ def forward( - logits: tensor Optional """ - if hidden_states is not None: - self.hidden_states = hidden_states - else: - self.input_ids = input_ids - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) + self.input_ids = input_ids + + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) - # cache - if past_key_values and not isinstance(past_key_values, Cache): - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) - - # position id - position_ids = cache_position.unsqueeze(0) + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + False # dont out attentions + ) - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) - - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( - self.inputs_embeds, - position_ids - ) - - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position + position_ids ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] # run through decoder layers layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) From 73146dd0c862f18bfda73c5068482fb079b70022 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 01:47:12 -0800 Subject: [PATCH 352/491] fixing hf model for hidden_states --- exo/inference/pytorch/model/hf.py | 98 ++++++++++++++++--------------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d040418f..c734f5da 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -98,6 +98,7 @@ def forward( self, shard: Optional[Shard] = None, input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, use_legacy_cache: Optional[bool] = False @@ -124,58 +125,61 @@ def forward( - logits: tensor Optional """ - - self.input_ids = input_ids - - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) - - # cache - if past_key_values and not isinstance(past_key_values, Cache): - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) - # position id - position_ids = cache_position.unsqueeze(0) - - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, - self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) + if hidden_states is not None: + self.hidden_states = hidden_states + else: + self.input_ids = input_ids + + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, self.inputs_embeds, - position_ids + cache_position, + past_key_values, + False # dont out attentions + ) + + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position ) - - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position - ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] # run through decoder layers layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) From 929386d0049baf1c348c1e1dad7cb194f49c3632 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 01:49:19 -0800 Subject: [PATCH 353/491] fixing hf model for hidden_states --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index e8920e08..0d148bc8 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -176,6 +176,7 @@ async def infer_tensor( shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, + hidden_states=hidden_states, past_key_values=past_kvs ) From 32b8f67af5d426011d536b54589b2e8888548135 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 01:58:13 -0800 Subject: [PATCH 354/491] fixing hf model for hidden_states --- exo/inference/pytorch/inference.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0d148bc8..683a8123 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -160,10 +160,15 @@ async def infer_tensor( input_ids = torch.tensor(input_data).long().to(self.device) - if self.past_input_ids is not None: - self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) + # detect if hidden_states or not + hidden_states = None + if input_ids.size()[-1] > 1: + hidden_states = input_ids else: - self.past_input_ids = input_ids + if self.past_input_ids is not None: + self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) + else: + self.past_input_ids = input_ids if inference_state is not None: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) From c86facb91529d2f883ed0b06a0b97bd9dc278d87 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:01:33 -0800 Subject: [PATCH 355/491] fixing hf model for hidden_states --- exo/inference/pytorch/model/hf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index c734f5da..edd105e5 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -126,6 +126,8 @@ def forward( """ + model_inputs = None + if hidden_states is not None: self.hidden_states = hidden_states else: From d15b20d551df00821516f9bd3bc1945ace10feaf Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:03:17 -0800 Subject: [PATCH 356/491] fixing hf model for hidden_states --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 683a8123..cefd0170 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -189,7 +189,7 @@ async def infer_tensor( if shard_hidden_states is not None: hidden_dict = {"hidden_states": shard_hidden_states.tolist()} - next_token is not None + next_token = None if shard_logits is not None: next_token = self.stateful_sharded_model.logits_sample(shard_logits) input_ids = next_token From 5e41bc4d5315e2664cb9f0a1bfe3565fa7c66df3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:06:33 -0800 Subject: [PATCH 357/491] fixing hf model for hidden_states --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index cefd0170..b2414677 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -113,7 +113,7 @@ async def infer_prompt( if shard_past_kvs is not None: cache_dict = { 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in shard_past_kvs_kvs.value_cache] + 'value_cache': [tensor.tolist() for tensor in shard_past_kvs.value_cache] } else: cache_dict = None @@ -197,7 +197,7 @@ async def infer_tensor( if shard_past_kvs is not None: cache_dict = { 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in shard_past_kvs_kvs.value_cache] + 'value_cache': [tensor.tolist() for tensor in shard_past_kvs.value_cache] } else: cache_dict = None From b29c5f807aa0b5c37db0cae3046720f3a6137710 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:21:30 -0800 Subject: [PATCH 358/491] fixing hf model for hidden_states --- exo/inference/pytorch/inference.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index b2414677..033e02f1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -68,6 +68,7 @@ async def infer_prompt( print("infer_prompt called") print(f"prompt: {prompt}") print(f"shard: {shard}") + print(f"inference_state: {inference_state}") await self.ensure_shard(shard) @@ -118,14 +119,9 @@ async def infer_prompt( else: cache_dict = None - stopping_critera = self.stateful_sharded_model.stopping_critera - self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - - hit_eos = False + is_finished = False if next_token is not None: - hit_eos = next_token.item() == self.tokenizer.eos_token_id - - is_finished = self.unfinished_sequences.max() == 0 or hit_eos + is_finished = next_token.item() == self.tokenizer.eos_token_id if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") @@ -155,6 +151,7 @@ async def infer_tensor( print("infer_tensor called") print(f"input_data: {input_data}") print(f"shard: {shard}") + print(f"inference_state: {inference_state}") await self.ensure_shard(shard) @@ -202,14 +199,14 @@ async def infer_tensor( else: cache_dict = None - stopping_critera = self.stateful_sharded_model.stopping_critera - self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) + #stopping_critera = self.stateful_sharded_model.stopping_critera + #self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - hit_eos = False + is_finished = False if next_token is not None: - hit_eos = next_token.item() == self.tokenizer.eos_token_id + is_finished = next_token.item() == self.tokenizer.eos_token_id - is_finished = self.unfinished_sequences.max() == 0 or hit_eos + #is_finished = self.unfinished_sequences.max() == 0 or hit_eos if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") From ddaa79c5354cc05595756cf97645be033efc5b06 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:32:24 -0800 Subject: [PATCH 359/491] fixing kvcache issue --- exo/inference/pytorch/inference.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 033e02f1..4c334f3a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -87,7 +87,11 @@ async def infer_prompt( if inference_state is not None: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + infer_state = json.loads(inference_state) + if len(infer_state["key_cache"]) == 0: + past_kvs = DynamicCache() + else: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) else: past_kvs = None @@ -168,7 +172,11 @@ async def infer_tensor( self.past_input_ids = input_ids if inference_state is not None: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + infer_state = json.loads(inference_state) + if len(infer_state["key_cache"]) == 0: + past_kvs = DynamicCache() + else: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) else: past_kvs = None From 3164d3859aab1f8f20674495b6eb8b2a14afc8fb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:36:53 -0800 Subject: [PATCH 360/491] fixing kvcache issue --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 4c334f3a..a2f7efb4 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -88,7 +88,7 @@ async def infer_prompt( if inference_state is not None: infer_state = json.loads(inference_state) - if len(infer_state["key_cache"]) == 0: + if not infer_state: past_kvs = DynamicCache() else: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) @@ -173,7 +173,7 @@ async def infer_tensor( if inference_state is not None: infer_state = json.loads(inference_state) - if len(infer_state["key_cache"]) == 0: + if not infer_state: past_kvs = DynamicCache() else: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) From e8532bc5934f30d41c500b6c1d16549799c921bf Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:43:03 -0800 Subject: [PATCH 361/491] fixing kvcache issue --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a2f7efb4..d98717b9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -88,7 +88,7 @@ async def infer_prompt( if inference_state is not None: infer_state = json.loads(inference_state) - if not infer_state: + if not infer_state or (infer_state and len(infer_state["key_cache"] == 0)): past_kvs = DynamicCache() else: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) @@ -173,7 +173,7 @@ async def infer_tensor( if inference_state is not None: infer_state = json.loads(inference_state) - if not infer_state: + if not infer_state or (infer_state and len(infer_state["key_cache"] == 0)): past_kvs = DynamicCache() else: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) From 6a5b8db2d657d344e3fa4b56139f4442807888c4 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 02:45:16 -0800 Subject: [PATCH 362/491] fixing kvcache issue --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index d98717b9..ecb02798 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -88,7 +88,7 @@ async def infer_prompt( if inference_state is not None: infer_state = json.loads(inference_state) - if not infer_state or (infer_state and len(infer_state["key_cache"] == 0)): + if not infer_state or (infer_state and len(infer_state["key_cache"]) == 0): past_kvs = DynamicCache() else: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) @@ -173,7 +173,7 @@ async def infer_tensor( if inference_state is not None: infer_state = json.loads(inference_state) - if not infer_state or (infer_state and len(infer_state["key_cache"] == 0)): + if not infer_state or (infer_state and len(infer_state["key_cache"]) == 0): past_kvs = DynamicCache() else: past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) From 515687d76c7eb951441f1ca0da06f64994358d36 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 17 Sep 2024 23:38:39 -0800 Subject: [PATCH 363/491] working on passing past input_ids between infers and nodes --- exo/inference/pytorch/inference.py | 54 +++++++++++++++++++++++------- exo/inference/pytorch/model/hf.py | 9 ++++- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index ecb02798..7fc09a75 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -161,25 +161,48 @@ async def infer_tensor( input_ids = torch.tensor(input_data).long().to(self.device) + # setup cache and cached input_ids + past_kvs = None + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + # setup cache + cached_kvs = infer_state[0] + if not cached_kvs or (cached_kvs and len(cached_kvs["key_cache"]) == 0): + past_kvs = DynamicCache() + else: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + + # setup cached input_ids with one coming in, if any cached + cached_iids = infer_state[1] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + cached_iids_tensor = torch.tensor(cached_iids["input_ids"]).to(self.device) + past_iids = torch.cat([cached_iids_tensor, input_ids], dim=-1).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"past_kvs: {past_kvs}") + print(f"cached_iids: {cached_iids}") + # detect if hidden_states or not hidden_states = None if input_ids.size()[-1] > 1: hidden_states = input_ids + self.past_input_ids = None else: - if self.past_input_ids is not None: - self.past_input_ids = torch.cat([self.past_input_ids, input_ids], dim=-1) + if past_iids is not None: + self.past_input_ids = past_iids else: self.past_input_ids = input_ids - - if inference_state is not None: - infer_state = json.loads(inference_state) - if not infer_state or (infer_state and len(infer_state["key_cache"]) == 0): - past_kvs = DynamicCache() - else: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) - else: - past_kvs = None - + if DEBUG >= 4: print(f"input_ids: {input_ids}") print(f"inference_state: {inference_state}") @@ -199,6 +222,7 @@ async def infer_tensor( next_token = self.stateful_sharded_model.logits_sample(shard_logits) input_ids = next_token + #cache if shard_past_kvs is not None: cache_dict = { 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], @@ -207,6 +231,10 @@ async def infer_tensor( else: cache_dict = None + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, input_ids], dim=-1).to(self.device) + cached_iids = {"input_ids": next_cached_logits.tolist()} + #stopping_critera = self.stateful_sharded_model.stopping_critera #self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) @@ -224,7 +252,7 @@ async def infer_tensor( return_values = ( input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps(cache_dict), + json.dumps([cache_dict, cached_iids]), is_finished ) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index edd105e5..62e30a08 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -101,7 +101,7 @@ def forward( hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_legacy_cache: Optional[bool] = False + use_legacy_cache: bool = False ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: """ @@ -198,6 +198,11 @@ def forward( print(f"decoder_layer: {decoder_layer}") print(f"hidden_states: {self.hidden_states}") + # TODO: fix caching as decoder layer is not returning + # present_key_value from attention layer on models + # might have some other generation functions needed to do it + # see https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2917 + # for qwen2 exhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L291 layer_outputs = decoder_layer( self.hidden_states, attention_mask=self.causal_mask, @@ -212,6 +217,8 @@ def forward( if DEBUG >= 5: print("decoder_layer after") + print(f"layer_outputs: {layer_outputs}\n") + print(f"self.next_decoder_cache: {self.next_decoder_cache}") print(f"hidden_states: {self.hidden_states}") print(f"next_decoder_cache: {self.next_decoder_cache}") From 92ebdd5f0d1bf464dfe53cc671c18fa33cbc309b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 19 Sep 2024 03:34:03 -0800 Subject: [PATCH 364/491] implemented infer caching and passing cache information via inference_state --- exo/inference/pytorch/inference.py | 110 +++++++++++++++++------------ exo/inference/pytorch/model/hf.py | 5 +- 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 7fc09a75..302597a3 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -9,7 +9,7 @@ from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG -from transformers import DynamicCache +from transformers import DynamicCache, Cache from accelerate import disk_offload from exo.download.shard_download import ShardDownloader @@ -56,6 +56,47 @@ def __init__(self, shard_downloader: ShardDownloader): # setup unfinished sequence self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) + def infer_caching( + self, + inference_state: Optional[str] = None + ) -> Tuple[Optional[Cache], Optional[torch.tensor], Optional[dict]]: + """ + inference caching for past_kvs and cached input_ids + user json inference_state + """ + # setup cache and cached input_ids + past_kvs = None + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + # setup cache + cached_kvs = infer_state[0] + if not cached_kvs or (cached_kvs and len(cached_kvs["key_cache"]) == 0): + past_kvs = DynamicCache() + else: + past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + + # setup cached input_ids with one coming in, if any cached + cached_iids = infer_state[1] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"past_kvs: {past_kvs}") + print(f"cached_iids: {cached_iids}") + + return (past_kvs, past_iids, cached_iids) + + async def infer_prompt( self, request_id: str, @@ -86,20 +127,19 @@ async def infer_prompt( batch_size, seq_length = input_ids.shape[:2] - if inference_state is not None: - infer_state = json.loads(inference_state) - if not infer_state or (infer_state and len(infer_state["key_cache"]) == 0): - past_kvs = DynamicCache() - else: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) + # get cache from inference_state + past_kvs, past_iids, cached_iids = self.infer_caching(inference_state) + + if past_iids is not None: + self.past_input_ids = past_iids, else: - past_kvs = None + self.past_input_ids = input_ids if DEBUG >= 4: - print(f"input_ids: {input_ids}\n") + print(f"past_input_ids: {self.past_input_ids}\n") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - input_ids=input_ids, + input_ids=self.past_input_ids, attention_mask=input_attention_mask, past_key_values=past_kvs ) @@ -115,6 +155,7 @@ async def infer_prompt( self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) input_ids = next_token + # cache if shard_past_kvs is not None: cache_dict = { 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], @@ -123,6 +164,9 @@ async def infer_prompt( else: cache_dict = None + if self.past_input_ids is not None: + cached_iids = {"input_ids": self.past_input_ids.tolist()} + is_finished = False if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id @@ -135,7 +179,7 @@ async def infer_prompt( return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps(cache_dict), + json.dumps([cache_dict, cached_iids]), is_finished ) @@ -161,47 +205,19 @@ async def infer_tensor( input_ids = torch.tensor(input_data).long().to(self.device) - # setup cache and cached input_ids - past_kvs = None - past_iids = None - cached_iids = None - if inference_state is not None: - try: - infer_state = json.loads(inference_state) - except ValueError: - infer_state = None - - if infer_state is not None: - # setup cache - cached_kvs = infer_state[0] - if not cached_kvs or (cached_kvs and len(cached_kvs["key_cache"]) == 0): - past_kvs = DynamicCache() - else: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) - - # setup cached input_ids with one coming in, if any cached - cached_iids = infer_state[1] - if cached_iids is not None: - past_iids = None - if len(cached_iids) > 0: - cached_iids_tensor = torch.tensor(cached_iids["input_ids"]).to(self.device) - past_iids = torch.cat([cached_iids_tensor, input_ids], dim=-1).to(self.device) - cached_iids = {"input_ids": past_iids.tolist()} - - if DEBUG >= 4: - print(f"past_kvs: {past_kvs}") - print(f"cached_iids: {cached_iids}") - + # get cache from inference_state + past_kvs, past_iids, cached_iids = self.infer_caching(inference_state) + # detect if hidden_states or not hidden_states = None if input_ids.size()[-1] > 1: hidden_states = input_ids - self.past_input_ids = None + #self.past_input_ids = None + #else: + if past_iids is not None: + self.past_input_ids = past_iids else: - if past_iids is not None: - self.past_input_ids = past_iids - else: - self.past_input_ids = input_ids + self.past_input_ids = input_ids if DEBUG >= 4: print(f"input_ids: {input_ids}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 62e30a08..1481c40c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -127,6 +127,7 @@ def forward( """ model_inputs = None + self.hidden_states = None if hidden_states is not None: self.hidden_states = hidden_states @@ -183,12 +184,14 @@ def forward( self.cache_position = model_inputs["cache_position"] self.past_key_values = model_inputs["past_key_values"] + if DEBUG >= 4: + print(f"model_inputs: {model_inputs}") + # run through decoder layers layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") - print(f"model_inputs: {model_inputs}") print(f"layer_amt: {layer_amt}") for i in layer_amt: From f0795bd17cfa1cbf1e491c1dfc2da7cb3a6d0824 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 19 Sep 2024 03:59:15 -0800 Subject: [PATCH 365/491] removing dynamic cache passing in inference_state as model does its own, added cleaning out cached_iids when process is finished --- exo/inference/pytorch/inference.py | 76 ++++++++++-------------------- 1 file changed, 24 insertions(+), 52 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 302597a3..fa16ccb5 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -59,13 +59,11 @@ def __init__(self, shard_downloader: ShardDownloader): def infer_caching( self, inference_state: Optional[str] = None - ) -> Tuple[Optional[Cache], Optional[torch.tensor], Optional[dict]]: + ) -> Tuple[Optional[torch.tensor], Optional[dict]]: """ - inference caching for past_kvs and cached input_ids - user json inference_state + inference caching from inference_state json """ # setup cache and cached input_ids - past_kvs = None past_iids = None cached_iids = None if inference_state is not None: @@ -75,15 +73,7 @@ def infer_caching( infer_state = None if infer_state is not None: - # setup cache - cached_kvs = infer_state[0] - if not cached_kvs or (cached_kvs and len(cached_kvs["key_cache"]) == 0): - past_kvs = DynamicCache() - else: - past_kvs = DynamicCache.from_legacy_cache(json.loads(inference_state)) - - # setup cached input_ids with one coming in, if any cached - cached_iids = infer_state[1] + cached_iids = infer_state["cached_iids"] if cached_iids is not None: past_iids = None if len(cached_iids) > 0: @@ -91,10 +81,9 @@ def infer_caching( cached_iids = {"input_ids": past_iids.tolist()} if DEBUG >= 4: - print(f"past_kvs: {past_kvs}") print(f"cached_iids: {cached_iids}") - return (past_kvs, past_iids, cached_iids) + return (past_iids, cached_iids) async def infer_prompt( @@ -126,9 +115,8 @@ async def infer_prompt( input_attention_mask = inputs.attention_mask.to(self.device) batch_size, seq_length = input_ids.shape[:2] - # get cache from inference_state - past_kvs, past_iids, cached_iids = self.infer_caching(inference_state) + past_iids, cached_iids = self.infer_caching(inference_state) if past_iids is not None: self.past_input_ids = past_iids, @@ -140,8 +128,7 @@ async def infer_prompt( shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, - attention_mask=input_attention_mask, - past_key_values=past_kvs + attention_mask=input_attention_mask ) if DEBUG >= 4: @@ -155,15 +142,6 @@ async def infer_prompt( self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) input_ids = next_token - # cache - if shard_past_kvs is not None: - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in shard_past_kvs.value_cache] - } - else: - cache_dict = None - if self.past_input_ids is not None: cached_iids = {"input_ids": self.past_input_ids.tolist()} @@ -171,6 +149,10 @@ async def infer_prompt( if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id + if is_finished: + # clear cache + cached_iids = {"input_ids": []} + if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") @@ -179,7 +161,7 @@ async def infer_prompt( return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps([cache_dict, cached_iids]), + json.dumps({"cached_iids": cached_iids}), is_finished ) @@ -206,18 +188,18 @@ async def infer_tensor( input_ids = torch.tensor(input_data).long().to(self.device) # get cache from inference_state - past_kvs, past_iids, cached_iids = self.infer_caching(inference_state) + past_iids, cached_iids = self.infer_caching(inference_state) # detect if hidden_states or not hidden_states = None + self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids - #self.past_input_ids = None - #else: - if past_iids is not None: - self.past_input_ids = past_iids else: - self.past_input_ids = input_ids + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids if DEBUG >= 4: print(f"input_ids: {input_ids}") @@ -225,8 +207,7 @@ async def infer_tensor( shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, - hidden_states=hidden_states, - past_key_values=past_kvs + hidden_states=hidden_states ) hidden_dict = None @@ -238,27 +219,18 @@ async def infer_tensor( next_token = self.stateful_sharded_model.logits_sample(shard_logits) input_ids = next_token - #cache - if shard_past_kvs is not None: - cache_dict = { - 'key_cache': [tensor.tolist() for tensor in shard_past_kvs.key_cache], - 'value_cache': [tensor.tolist() for tensor in shard_past_kvs.value_cache] - } - else: - cache_dict = None - + #cache if self.past_input_ids is not None: next_cached_logits = torch.cat([self.past_input_ids, input_ids], dim=-1).to(self.device) cached_iids = {"input_ids": next_cached_logits.tolist()} - #stopping_critera = self.stateful_sharded_model.stopping_critera - #self.unfinished_sequences = self.unfinished_sequences & ~stopping_critera(input_ids, None) - is_finished = False if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id - #is_finished = self.unfinished_sequences.max() == 0 or hit_eos + if is_finished: + # clear cache + cached_iids = {"input_ids": []} if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") @@ -267,8 +239,8 @@ async def infer_tensor( print(f"\nshard_logits: {shard_logits}") return_values = ( - input_ids.numpy(force=True), #if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps([cache_dict, cached_iids]), + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), is_finished ) From b8f15a0e66c6a7a820b116f50b945cdafb8ea953 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 19 Sep 2024 04:00:14 -0800 Subject: [PATCH 366/491] removed clearning cache on infer prompt and only on finished infer tensor --- exo/inference/pytorch/inference.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index fa16ccb5..441fc3ed 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -149,10 +149,6 @@ async def infer_prompt( if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id - if is_finished: - # clear cache - cached_iids = {"input_ids": []} - if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") From d0f3cb77022dc024b1bfc9af58e729c0e5b2166b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 19 Sep 2024 04:08:36 -0800 Subject: [PATCH 367/491] hidden state dropping between nodes issue --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 441fc3ed..835ee54a 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -181,7 +181,7 @@ async def infer_tensor( await self.ensure_shard(shard) - input_ids = torch.tensor(input_data).long().to(self.device) + input_ids = torch.tensor(input_data).to(self.device) # get cache from inference_state past_iids, cached_iids = self.infer_caching(inference_state) From fa6f26350a7fcf524f49f9008b7f08e87d876d19 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 19 Sep 2024 04:22:04 -0800 Subject: [PATCH 368/491] hidden state dropping between nodes issue --- exo/inference/pytorch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 835ee54a..0e23e972 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -216,8 +216,8 @@ async def infer_tensor( input_ids = next_token #cache - if self.past_input_ids is not None: - next_cached_logits = torch.cat([self.past_input_ids, input_ids], dim=-1).to(self.device) + if next_token is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) cached_iids = {"input_ids": next_cached_logits.tolist()} is_finished = False From 2b0e7b56f72b64e99e0ee32d65ed72bff6baf6e2 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 19 Sep 2024 04:26:45 -0800 Subject: [PATCH 369/491] hidden state dropping between nodes issue --- exo/inference/pytorch/inference.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 0e23e972..f3036e78 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -217,7 +217,11 @@ async def infer_tensor( #cache if next_token is not None: - next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + elif past_iids is not None: + next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) + cached_iids = {"input_ids": next_cached_logits.tolist()} is_finished = False From cee3e311809390c8d36c0360f3b5a763acb58b99 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 2 Oct 2024 10:37:17 -0800 Subject: [PATCH 370/491] cleaning up code, removing helpers.py --- exo/inference/pytorch/helpers.py | 24 --- exo/inference/pytorch/inference.py | 4 +- .../pytorch/model/archive/hf_manual.py | 203 ------------------ exo/inference/pytorch/model/hf.py | 28 +-- 4 files changed, 2 insertions(+), 257 deletions(-) delete mode 100644 exo/inference/pytorch/helpers.py delete mode 100644 exo/inference/pytorch/model/archive/hf_manual.py diff --git a/exo/inference/pytorch/helpers.py b/exo/inference/pytorch/helpers.py deleted file mode 100644 index addea2db..00000000 --- a/exo/inference/pytorch/helpers.py +++ /dev/null @@ -1,24 +0,0 @@ -# Helper functions for pytorch inference -# Some code coming from tinygrad but written towards pytorch - -import asyncio -import aiohttp -from tqdm import tqdm -from pathlib import Path -from typing import List - -async def fetch_file_async(session, url: str, output_path: Path): - async with session.get(url) as response: - response.raise_for_status() - with open(output_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - f.write(chunk) - -async def download_files(urls: List[str], output_paths: List[Path]): - async with aiohttp.ClientSession() as session: - tasks = [] - for url, output_path in zip(urls, output_paths): - tasks.append(fetch_file_async(session, url, output_path)) - - for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Downloading files"): - await f diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index f3036e78..94cea100 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -2,15 +2,13 @@ import numpy as np import torch import json -import gc + from typing import Optional, Tuple from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG -from transformers import DynamicCache, Cache -from accelerate import disk_offload from exo.download.shard_download import ShardDownloader # model value options diff --git a/exo/inference/pytorch/model/archive/hf_manual.py b/exo/inference/pytorch/model/archive/hf_manual.py deleted file mode 100644 index e5af2eaf..00000000 --- a/exo/inference/pytorch/model/archive/hf_manual.py +++ /dev/null @@ -1,203 +0,0 @@ -# Attempted version to recreate manually using LlamaModel and others -# BROKEN -import torch -import numpy as np -from transformers import AutoModelForCausalLM, DynamicCache, Cache, AutoModel -from exo.inference.shard import Shard -from exo.helpers import DEBUG -from typing import Tuple, Optional, Union, List -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from exo.inference.pytorch.model.archive.utils import sample_logits - -TOP_P = 0.7 #0.95 -TOP_K = 50 -TEMP = 0.01 - - -class ShardedHuggingFaceModel(torch.nn.Module): - def __init__(self, shard: Shard): - super(ShardedHuggingFaceModel, self).__init__() - - if torch.cuda.is_available(): - self.device = torch.device("cuda") - else: - self.device = torch.device("cpu") - - self.shard = shard - - # Load the model - try: - self.base_model = AutoModel.from_pretrained( - shard.model_id, - torch_dtype=torch.float32, - device_map="auto", - # offload_buffers=True - ) - - # disk_offload(model=self.base_model, offload_dir="./.offload") - except Exception as err: - print(f"Error loading model: {err}") - raise - - if DEBUG >= 2: - print(f"\nShardedHuggingFaceModel init with shard {shard}") - print(f"self.base_model: {self.base_model}") - - # Embeddings and final layer norm - # used for doing what forward LlamaModel does in transformers - self.norm = self.base_model.norm - self.lm_head = torch.nn.Linear( - self.base_model.config.hidden_size, - self.base_model.config.vocab_size, - bias=False - ).to(self.device) - self.embed_tokens = self.base_model.embed_tokens - - def forward( - self, - input_ids: torch.tensor, - attention_mask: torch.tensor = None, - past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - ) -> Tuple[np.ndarray, any]: - """ - Forward through layers using the base model - - Args: - input_ids: tensor input - attention_mask: attention mask from tokenizer - past_kvs: past key value stores for cache - - Returns: - hidden_states: numpy of states between layers - or logits: numpy of normalization and linearization of last hidden state - past_kvs: DynamicCache of past key values if use_cache is true - - Ref: - https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 - https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 - """ - if DEBUG >= 4: - print("forward called") - print(f"input_ids: {input_ids}\n") - print(f"layer_count: {self.shard.get_layer_count()}") - print(f"is_first_layer: {self.shard.is_first_layer()}") - print(f"is_last_layer: {self.shard.is_last_layer()}") - - if self.shard.is_first_layer(): - if DEBUG >= 2: - print("first layer, embed") - print(f"input_ids: {input_ids}") - input_ids = self.embed_tokens(input_ids) - - if DEBUG >= 2: - print(f"embeded input_ids: {input_ids}") - - if attention_mask == None: - # get attention mask - past_kv_length = len(past_kvs) - batch_size, seq_length = input_ids.shape[:2] - attention_mask = _prepare_4d_causal_attention_mask( - None, (batch_size, seq_length), input_ids, past_kv_length - ) - - past_kvs = DynamicCache.from_legacy_cache(past_kvs) - past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + input_ids.shape[1], - device=self.device - ) - - position_ids = cache_position.unsqueeze(0).to(self.device) - - try: - position_embeddings = self.base_model.rotary_emb( - input_ids, - position_ids - ) - except Exception as err: - print(f"rotary_emb not found in base_model") - position_embeddings = None - - causal_mask = self.base_model._update_causal_mask( - attention_mask, - input_ids, - cache_position, - past_kvs, - self.base_model.config.output_attentions - ) - - # progress through layers - for i in range(self.shard.start_layer, self.shard.end_layer + 1): - decoder_layer = self.base_model.layers[i] - - if DEBUG >= 4: - print("Going through layer") - print(f"{decoder_layer}") - print("input_ids") - print(f"{input_ids}") - print("causal_mask") - print(f"{causal_mask}") - - try: - layer_outputs = decoder_layer( - input_ids, - attention_mask=causal_mask, - position_ids=position_ids, - position_embeddings=position_embeddings, - past_key_value=past_kvs, - use_cache=True, - cache_position=cache_position, - output_logits=True - ) - except Exception as err: - print(f"Going through layer failed: {err}") - print(err.__traceback__.tb_lineno) - raise - - hidden_states = layer_outputs[0] - next_kvs = layer_outputs[1] - - if DEBUG >= 3: - print(f"layer_outputs {layer_outputs}") - print(layer_outputs[1:]) - - if self.shard.is_last_layer(): - hs_norm = self.norm(hidden_states).to(self.device) - # hs_lm_head = self.base_model.lm_head(hs_norm).float() - - # Use the sampling function with default settings - with torch.no_grad(): - logits = self.lm_head( - hs_norm[:, -1:, :] - ).to(self.device).float() - - if DEBUG >= 2: - print(f"hs_norm: {hs_norm}") - # print(f"hs_lm_head: {hs_lm_head}") - print(f"logits: {logits}") - print(f"logits.shape: {logits.shape}") - - # output_token = sample_logits( - # logits, - # TEMP, - # TOP_P, - # TOP_K - # ).unsqueeze(0).unsqueeze(0).long() - - output_token = torch.distributions.Categorical( - logits=logits - ).sample(sample_shape=(1,)) - - if DEBUG >= 2: - print(f"output_token: {output_token}") - - return (output_token.numpy(force=True), next_kvs) - - with torch.no_grad(): - out_hidden_states = hidden_states.float().numpy(force=True) - - return ( - out_hidden_states, - next_kvs - ) \ No newline at end of file diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 1481c40c..1b617d7c 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -5,38 +5,20 @@ from exo.inference.shard import Shard from exo.helpers import DEBUG -from exo.inference.inference_engine import InferenceEngine -from exo.download.shard_download import ShardDownloader from transformers import ( - AutoModel, AutoModelForCausalLM, - AutoTokenizer, DynamicCache, Cache, LogitsProcessorList, - #MinLengthLogitsProcessor, - LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, - TemperatureLogitsWarper, - StoppingCriteriaList, - MaxLengthCriteria, - MaxTimeCriteria -) - -from transformers.generation.configuration_utils import ( - GenerationConfig, - GenerationMode + TemperatureLogitsWarper ) # llama from transformers.models.llama.modeling_llama import LlamaModel -# qwen2 -from transformers.models.qwen2.modeling_qwen2 import Qwen2Model - - class ShardedHuggingFaceModel: def __init__( self, @@ -68,14 +50,6 @@ def __init__( TopPLogitsWarper(top_p) ]) - # setup stopping critera for generation - self.stopping_critera = StoppingCriteriaList( - [ - #MaxLengthCriteria(max_length=max_length), - MaxTimeCriteria(max_time=max_time), - ] - ) - self.device = device self.torch_dtype = dtype From 57e14e8cbf6e3d2dbae632a70b8b53c4ec14efcb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 11:08:19 -0800 Subject: [PATCH 371/491] adding needed libs to setup.py, fixing 4 space to 2 space issue, adding in hf downloader to inference engine, testing --- .gitignore | 3 + exo/inference/pytorch/inference.py | 533 +++++++++--------- exo/inference/pytorch/model/hf.py | 30 +- .../pytorch/tests/test_inference_engine.py | 283 ++++------ .../pytorch/tests/test_split_model.py | 25 +- exo/models.py | 12 +- exo/tinychat/index.html | 1 + setup.py | 2 + 8 files changed, 438 insertions(+), 451 deletions(-) diff --git a/.gitignore b/.gitignore index f5609f31..33907f70 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,6 @@ cython_debug/ # PyTorch interface .offload + +# neovim/vim settings +.vimrc diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 94cea100..2f87c1b1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,4 +1,5 @@ # experimental, based off of tinygrad/inference.py +import os import numpy as np import torch import json @@ -9,9 +10,9 @@ from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG -from exo.download.shard_download import ShardDownloader +from exo.download.hf.hf_shard_download import HFShardDownloader -# model value options +# model value options TOP_K = 20 TEMP = 0.6 TOP_P = 0.9 @@ -19,267 +20,273 @@ MAX_TIME = 60.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): + """ + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. + """ + + def __init__(self, shard_downloader: HFShardDownloader): + """ + Initialize the inference engine. + + Args: + debug (bool): If True, enables debug logging. Defaults to False. """ - PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. + self.shard = None + self.shard_downloader = shard_downloader + self.stateful_sharded_model = None + self.tokenizer = None + + # the whole history with new logits need to + # be passed to the model to reach the end token + # even with caching + self.past_input_ids = None + + # setup cuda device + if os.environ.get("PYTORCH_DEVICE"): + pytorch_device = os.environ["PYTOCH_DEVICE"] + if pytorch_device not in ["cuda", "mps", "cpu"]: + pytorch_device = "cpu" + + self.device = pytorch_device + self.torch_dtype = torch.float32 if pytorch_device != "cpu" else torch.float16 + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + self.torch_dtype = torch.float32 + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + self.torch_dtype = torch.float32 + else: + self.device = torch.device("cpu") + self.torch_dtype = torch.float16 + + # setup unfinished sequence + self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) + + def infer_caching( + self, + inference_state: Optional[str] = None + ) -> Tuple[Optional[torch.tensor], Optional[dict]]: """ + inference caching from inference_state json + """ + # setup cache and cached input_ids + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + cached_iids = infer_state["cached_iids"] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"cached_iids: {cached_iids}") + + return (past_iids, cached_iids) + + async def infer_prompt( + self, + request_id: Optional[str] = None, + shard: Optional[Shard] = None, + prompt: Optional[str] = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + # setup prompt input + messages = [{"role": "user", "content": prompt}] + txt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + inputs = self.tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to(self.device) + input_attention_mask = inputs.attention_mask.to(self.device) + batch_size, seq_length = input_ids.shape[:2] + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + if past_iids is not None: + self.past_input_ids = past_iids, + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"past_input_ids: {self.past_input_ids}\n") + + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=self.past_input_ids, + attention_mask=input_attention_mask + ) + + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + next_token = None + if shard_logits is not None: + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + input_ids = next_token + + if self.past_input_ids is not None: + cached_iids = {"input_ids": self.past_input_ids.tolist()} + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + input_ids = torch.tensor(input_data).to(self.device) + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + # detect if hidden_states or not + hidden_states = None + self.past_input_ids = None + if input_ids.size()[-1] > 1: + hidden_states = input_ids + else: + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"input_ids: {input_ids}") + print(f"inference_state: {inference_state}") + + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=self.past_input_ids, + hidden_states=hidden_states + ) + + hidden_dict = None + if shard_hidden_states is not None: + hidden_dict = {"hidden_states": shard_hidden_states.tolist()} + + next_token = None + if shard_logits is not None: + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + input_ids = next_token + + #cache + if next_token is not None: + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + elif past_iids is not None: + next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) + + cached_iids = {"input_ids": next_cached_logits.tolist()} + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if is_finished: + # clear cache + cached_iids = {"input_ids": []} + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + + async def ensure_shard(self, shard: Optional[Shard]): + """ + Ensure the model shard is loaded and ready for inference. - def __init__(self, shard_downloader: ShardDownloader): - """ - Initialize the inference engine. - - Args: - debug (bool): If True, enables debug logging. Defaults to False. - """ - self.shard = None - self.shard_downloader = shard_downloader - self.stateful_sharded_model = None - self.tokenizer = None - - # the whole history with new logits need to - # be passed to the model to reach the end token - # even with caching - self.past_input_ids = None - - # setup cuda device - if torch.cuda.is_available(): - self.device = torch.device("cuda") - self.torch_dtype = torch.float32 - elif torch.backends.mps.is_available(): - self.device = torch.device("mps") - self.torch_dtype = torch.float32 - else: - self.device = torch.device("cpu") - self.torch_dtype = torch.float16 - - # setup unfinished sequence - self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) - - def infer_caching( - self, - inference_state: Optional[str] = None - ) -> Tuple[Optional[torch.tensor], Optional[dict]]: - """ - inference caching from inference_state json - """ - # setup cache and cached input_ids - past_iids = None - cached_iids = None - if inference_state is not None: - try: - infer_state = json.loads(inference_state) - except ValueError: - infer_state = None - - if infer_state is not None: - cached_iids = infer_state["cached_iids"] - if cached_iids is not None: - past_iids = None - if len(cached_iids) > 0: - past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) - cached_iids = {"input_ids": past_iids.tolist()} - - if DEBUG >= 4: - print(f"cached_iids: {cached_iids}") - - return (past_iids, cached_iids) - - - async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard] = None, - prompt: str = "", - image_str: Optional[str] = None, - inference_state: Optional[str] = None - ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: - print("infer_prompt called") - print(f"prompt: {prompt}") - print(f"shard: {shard}") - print(f"inference_state: {inference_state}") - - await self.ensure_shard(shard) - - # setup prompt input - messages = [{"role": "user", "content": prompt}] - txt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - - inputs = self.tokenizer([txt], return_tensors="pt") - input_ids = inputs.input_ids.to(self.device) - input_attention_mask = inputs.attention_mask.to(self.device) - batch_size, seq_length = input_ids.shape[:2] - - # get cache from inference_state - past_iids, cached_iids = self.infer_caching(inference_state) - - if past_iids is not None: - self.past_input_ids = past_iids, - else: - self.past_input_ids = input_ids - - if DEBUG >= 4: - print(f"past_input_ids: {self.past_input_ids}\n") - - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - input_ids=self.past_input_ids, - attention_mask=input_attention_mask - ) - - if DEBUG >= 4: - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - next_token = None - if shard_logits is not None: - next_token = self.stateful_sharded_model.logits_sample(shard_logits) - self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) - input_ids = next_token - - if self.past_input_ids is not None: - cached_iids = {"input_ids": self.past_input_ids.tolist()} - - is_finished = False - if next_token is not None: - is_finished = next_token.item() == self.tokenizer.eos_token_id - - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), - is_finished - ) - - if DEBUG >= 4: - print(f"return_values: {return_values}") - - return return_values - - async def infer_tensor( - self, - request_id: str, - shard: Shard, - input_data: np.ndarray, - inference_state: Optional[str] = None - ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}") - print(f"shard: {shard}") - print(f"inference_state: {inference_state}") - - await self.ensure_shard(shard) - - input_ids = torch.tensor(input_data).to(self.device) - - # get cache from inference_state - past_iids, cached_iids = self.infer_caching(inference_state) - - # detect if hidden_states or not - hidden_states = None - self.past_input_ids = None - if input_ids.size()[-1] > 1: - hidden_states = input_ids - else: - if past_iids is not None: - self.past_input_ids = past_iids - else: - self.past_input_ids = input_ids - - if DEBUG >= 4: - print(f"input_ids: {input_ids}") - print(f"inference_state: {inference_state}") - - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - input_ids=self.past_input_ids, - hidden_states=hidden_states - ) - - hidden_dict = None - if shard_hidden_states is not None: - hidden_dict = {"hidden_states": shard_hidden_states.tolist()} - - next_token = None - if shard_logits is not None: - next_token = self.stateful_sharded_model.logits_sample(shard_logits) - input_ids = next_token - - #cache - if next_token is not None: - if self.past_input_ids is not None: - next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) - elif past_iids is not None: - next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) - - cached_iids = {"input_ids": next_cached_logits.tolist()} - - is_finished = False - if next_token is not None: - is_finished = next_token.item() == self.tokenizer.eos_token_id - - if is_finished: - # clear cache - cached_iids = {"input_ids": []} - - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), - is_finished - ) - - if DEBUG >= 4: - print(f"return_values: {return_values}") - - return return_values - - - async def ensure_shard(self, shard: Optional[Shard]): - """ - Ensure the model shard is loaded and ready for inference. - - Args: - shard (Optional[Shard]): Shard information for the model. - """ - if self.shard == shard: - return - - if DEBUG >= 4: - print(f"Loading new shard: {shard}") - - # -- TO DO -- - # Build in shard downloader but requires pulling - # apart how TrainedModel loads weight in its __init__ - # function in the transformer library - # model_path = await self.shard_downloader.ensure_shard(shard) - - self.tokenizer = await resolve_tokenizer(shard.model_id) - self.stateful_sharded_model = ShardedHuggingFaceModel( - shard=shard, - device=self.device, - dtype=self.torch_dtype, - top_k=TOP_K, - temp=TEMP, - top_p=TOP_P, - max_length=MAX_LENGTH, - max_time=MAX_TIME - ) - - self.shard = shard - - if DEBUG >= 4: - print(f"Shard loaded successfully: {shard}") + Args: + shard (Optional[Shard]): Shard information for the model. + """ + if self.shard == shard: + return + + if DEBUG >= 4: + print(f"Loading new shard: {shard}") + + model_path = await self.shard_downloader.ensure_shard(shard) + if DEBUG >= 4: + print(f"model_path: {model_path}") + + self.tokenizer = await resolve_tokenizer(shard.model_id) + self.stateful_sharded_model = ShardedHuggingFaceModel( + shard=shard, + local_model_path=model_path, + device=self.device, + dtype=self.torch_dtype, + top_k=TOP_K, + temp=TEMP, + top_p=TOP_P, + max_length=MAX_LENGTH, + max_time=MAX_TIME + ) + + self.shard = shard + + if DEBUG >= 4: + print(f"Shard loaded successfully: {shard}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 1b617d7c..38cd85c2 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,7 +23,8 @@ class ShardedHuggingFaceModel: def __init__( self, shard: Shard, - device, + local_model_path, + device, dtype, top_k: int = 25, temp: float = 0.7, @@ -31,19 +32,20 @@ def __init__( max_length: int = 50, max_time: float = 10.0 ): - # class vars + # class vars self.shard = shard - self.hidden_states = None + self.hidden_states = None self.input_ids = None self.inputs_embeds = None self.attention_mask = None - self.position_embeddings = None - self.past_key_values = None - self.cache_position = None - self.position_ids = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None self.causal_mask = None + self.local_model_path = local_model_path - # setup logit processors + # setup logit processors self.logits_processor = LogitsProcessorList([ TopKLogitsWarper(top_k), TemperatureLogitsWarper(temp), @@ -56,13 +58,13 @@ def __init__( # setup pytorch and transformer llm try: self.llm_model = AutoModelForCausalLM.from_pretrained( - shard.model_id, + pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.torch_dtype, device_map="auto", offload_buffers=True ) - self.model = self.llm_model.model + self.model = self.llm_model.model except Exception as err: print(f"error loading and splitting model: {err}") raise @@ -70,7 +72,6 @@ def __init__( def forward( self, - shard: Optional[Shard] = None, input_ids: Optional[torch.tensor] = None, hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, @@ -93,7 +94,7 @@ def forward( infer_tensor: bool optional, lets forward know to handle tensors Returns: - Tuple of + Tuple of - hidden_states: tensor optional - past_key_values: Cache or list[tensor] optional - logits: tensor Optional @@ -199,9 +200,8 @@ def forward( print(f"hidden_states: {self.hidden_states}") print(f"next_decoder_cache: {self.next_decoder_cache}") - # handle last layer to get logits - # shard is last layer says true at the start and not detecting last layer correctly + # shard is last layer says true at the start and not detecting last layer correctly if self.shard.is_last_layer(): self.hidden_states = self.model.norm(self.hidden_states) if use_legacy_cache: @@ -209,7 +209,7 @@ def forward( else: self.past_key_values = self.next_decoder_cache - # lm_head + # lm_head logits = self.llm_model.lm_head(self.hidden_states).to(self.device) if DEBUG >= 4: diff --git a/exo/inference/pytorch/tests/test_inference_engine.py b/exo/inference/pytorch/tests/test_inference_engine.py index 7e64c137..854d9b9c 100644 --- a/exo/inference/pytorch/tests/test_inference_engine.py +++ b/exo/inference/pytorch/tests/test_inference_engine.py @@ -11,164 +11,131 @@ import time async def test_inference_engine( - inference_engine_1: InferenceEngine, - inference_engine_2: InferenceEngine, - model_id: str, - n_layers: int): - - # prompt = "Why is the sky blue?" - prompt = "In a single word only, what is the last name of the current president of the USA?" - - shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=n_layers-1, - n_layers=n_layers - ) - - resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( - "A", - shard=shard, - prompt=prompt - ) - - print("\n------------resp_full---------------\n") - print(resp_full) - print("\n------------resp_full---------------\n") - - time.sleep(5) - - next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( - "A", - shard=shard, - input_data=resp_full, - inference_state=inference_state_full, - ) - - print("\n------------next_resp_full---------------\n") - print(next_resp_full) - print("\n------------next_resp_full---------------\n") - - time.sleep(5) - - pp = int(n_layers/2) - - resp_shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=pp, - n_layers=n_layers - ) - - resp_shard2 = Shard( - model_id=model_id, - start_layer=pp + 1, - end_layer=n_layers-1, - n_layers=n_layers - ) - - resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( - "B", - shard=resp_shard, - prompt=prompt - ) - - print("\n------------resp1---------------\n") - print(resp1) - print("\n------------resp1---------------\n") - - time.sleep(5) - - - resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp1, - inference_state=inference_state_1, - ) - - print("\n------------resp2---------------\n") - print(resp2) - print("\n------------resp2---------------\n") - - resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( - "B", - shard=resp_shard, - input_data=resp2, - inference_state=inference_state_2, - ) - - print("\n------------resp3---------------\n") - print(resp3) - print("\n------------resp3---------------\n") - - resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp3, - inference_state=inference_state_3, - ) - - print("\n------------resp4---------------\n") - print(resp4) - print("\n------------resp4---------------\n") - - assert np.array_equal(resp_full, resp2) - assert np.array_equal(next_resp_full, resp4) + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): + + # prompt = "Why is the sky blue?" + prompt = "In a single word only, what is the last name of the current president of the USA?" + + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) + + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=shard, + input_data=resp_full, + inference_state=inference_state_full, + ) + + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") + + time.sleep(5) + + pp = int(n_layers/2) + + resp_shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=pp, + n_layers=n_layers + ) + + resp_shard2 = Shard( + model_id=model_id, + start_layer=pp + 1, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( + "B", + shard=resp_shard, + prompt=prompt + ) + + print("\n------------resp1---------------\n") + print(resp1) + print("\n------------resp1---------------\n") + + time.sleep(5) + + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp1, + inference_state=inference_state_1, + ) + + print("\n------------resp2---------------\n") + print(resp2) + print("\n------------resp2---------------\n") + + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=resp_shard, + input_data=resp2, + inference_state=inference_state_2, + ) + + print("\n------------resp3---------------\n") + print(resp3) + print("\n------------resp3---------------\n") + + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp3, + inference_state=inference_state_3, + ) + + print("\n------------resp4---------------\n") + print(resp4) + print("\n------------resp4---------------\n") + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - try: - print(f"\n\n -------- TEST QWEN2 -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2-0.5B-Instruct", - 24 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "andrijdavid/Llama3-1B-Base", - # 3 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "meta-llama/Meta-Llama-3.1-8B", - # 32 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Chickaboo/ChickaQ-Large", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") - - #try: - # print(f"\n\n --------- TEST TinyLlama/TinyLlama_v1.1 -------\n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "TinyLlama/TinyLlama_v1.1", - # 22 - # )) - #except Exception as err: - # print(f"\n\n !!!!!!!!!!! TinyLlama/TinyLlama_v1.1 TEST FAILED \n{err}\n") + # try: + # print("\n\n -------- TEST QWEN2 -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + try: + print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Llama-3.2-1B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") diff --git a/exo/inference/pytorch/tests/test_split_model.py b/exo/inference/pytorch/tests/test_split_model.py index 827bdec2..157a215d 100644 --- a/exo/inference/pytorch/tests/test_split_model.py +++ b/exo/inference/pytorch/tests/test_split_model.py @@ -3,14 +3,11 @@ import asyncio import gc from transformers import ( - AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache, Cache, LogitsProcessorList, - #MinLengthLogitsProcessor, - LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, @@ -286,8 +283,8 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): stopping_critera = StoppingCriteriaList( [ - MaxLengthCriteria(max_length=50), - MaxTimeCriteria(max_time=10.0), + MaxLengthCriteria(max_length=255), + MaxTimeCriteria(max_time=100.0), ] ) @@ -355,9 +352,21 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): # ) #) - print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") - model_id = "Qwen/Qwen2-0.5B-Instruct" - model_layers = 24 + #print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") + #model_id = "Qwen/Qwen2-0.5B-Instruct" + #model_layers = 24 + + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) + + print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") + model_id = "meta-llama/Llama-3.2-1B-Instruct" + model_layers = 32 asyncio.run( model_half_split_test( diff --git a/exo/models.py b/exo/models.py index 67ea81c4..6f69960e 100644 --- a/exo/models.py +++ b/exo/models.py @@ -36,8 +36,8 @@ "llama-3-1B-Base": { "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, - "TinyLlama-1.1B-Chat-yaw": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="ambrosfitz/TinyLlama-1.1B-Chat-yawp", start_layer=0, end_layer=0, n_layers=22), + "meta-llama/Llama-3.2-1B-Instruct": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24), }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, @@ -47,11 +47,6 @@ "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),}, ### llava "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, - ### qwen - "Qwen2-0.5B-Instruct": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), - }, - ### qwen "qwen-2.5-coder-1.5b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), @@ -74,4 +69,7 @@ "qwen-2.5-math-72b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), }, + "Qwen2-0.5B-Instruct": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), + }, } diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index 9cad69d5..c00d2b0a 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -38,6 +38,7 @@ + diff --git a/setup.py b/setup.py index 75d570e9..8401167b 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,8 @@ "transformers==4.43.3", "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad", + "torch==2.4.0+cu124", + "accelerate=0.33.0" ] # Add macOS-specific packages if on Darwin (macOS) From 9fe3ec63dd26b78d9c27e3bcb17f72a79c7ee977 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 11:51:33 -0800 Subject: [PATCH 372/491] cleaning up code, added pytorch engine to llama 3.2 1b model shard in models.py, removed old 3.2 1b model shard, moving to test server for more vram --- exo/inference/pytorch/inference.py | 21 ++++++++++++--------- exo/inference/pytorch/model/hf.py | 2 +- exo/models.py | 4 +--- exo/tinychat/index.html | 1 - setup.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 2f87c1b1..8264aae8 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -12,6 +12,9 @@ from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader +# llama +from transformers.models.llama.modeling_llama import LlamaModel + # model value options TOP_K = 20 TEMP = 0.6 @@ -52,7 +55,7 @@ def __init__(self, shard_downloader: HFShardDownloader): if torch.cuda.is_available(): self.device = torch.device("cuda") - self.torch_dtype = torch.float32 + self.torch_dtype = torch.float16 elif torch.backends.mps.is_available(): self.device = torch.device("mps") self.torch_dtype = torch.float32 @@ -105,10 +108,10 @@ async def infer_prompt( print(f"prompt: {prompt}") print(f"shard: {shard}") print(f"inference_state: {inference_state}") - + await self.ensure_shard(shard) - - # setup prompt input + + # setup prompt input messages = [{"role": "user", "content": prompt}] txt = self.tokenizer.apply_chat_template( messages, @@ -174,9 +177,9 @@ async def infer_prompt( async def infer_tensor( self, - request_id: str, - shard: Shard, - input_data: np.ndarray, + request_id: str, + shard: Shard, + input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: if DEBUG >= 4: @@ -192,13 +195,13 @@ async def infer_tensor( # get cache from inference_state past_iids, cached_iids = self.infer_caching(inference_state) - # detect if hidden_states or not + # detect if hidden_states or not hidden_states = None self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids else: - if past_iids is not None: + if past_iids is not None: self.past_input_ids = past_iids else: self.past_input_ids = input_ids diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 38cd85c2..57a1590b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -16,7 +16,7 @@ TemperatureLogitsWarper ) -# llama +# llama from transformers.models.llama.modeling_llama import LlamaModel class ShardedHuggingFaceModel: diff --git a/exo/models.py b/exo/models.py index 6f69960e..2f1e7d10 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,6 +4,7 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), @@ -36,9 +37,6 @@ "llama-3-1B-Base": { "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, - "meta-llama/Llama-3.2-1B-Instruct": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24), - }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index c00d2b0a..9cad69d5 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -38,7 +38,6 @@ - diff --git a/setup.py b/setup.py index 8401167b..b23485a7 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad", "torch==2.4.0+cu124", - "accelerate=0.33.0" + "accelerate" ] # Add macOS-specific packages if on Darwin (macOS) From b44f6e975f4b77642fba46a5fde601499c0bb52b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 13:08:14 -0800 Subject: [PATCH 373/491] updating pytorch requirement --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b23485a7..f8ae17ee 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ "transformers==4.43.3", "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad", - "torch==2.4.0+cu124", + "torch==2.4.0", "accelerate" ] From 936e60a42dfbb1b2e6a94168f3237e3dec42635a Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 15:41:26 -0800 Subject: [PATCH 374/491] trying tokenizer fixes for llama3.1 --- exo/inference/pytorch/inference.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8264aae8..1791e0c9 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -8,10 +8,11 @@ from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel -from exo.api.chatgpt_api import resolve_tokenizer +from exo.inference.tokenizers import resolve_tokenizer from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader +from transformers import AutoTokenizer # llama from transformers.models.llama.modeling_llama import LlamaModel @@ -176,7 +177,7 @@ async def infer_prompt( return return_values async def infer_tensor( - self, + self, request_id: str, shard: Shard, input_data: np.ndarray, @@ -259,7 +260,7 @@ async def infer_tensor( return return_values - async def ensure_shard(self, shard: Optional[Shard]): + async def ensure_shard(self, shard: Shard): """ Ensure the model shard is loaded and ready for inference. @@ -276,7 +277,6 @@ async def ensure_shard(self, shard: Optional[Shard]): if DEBUG >= 4: print(f"model_path: {model_path}") - self.tokenizer = await resolve_tokenizer(shard.model_id) self.stateful_sharded_model = ShardedHuggingFaceModel( shard=shard, local_model_path=model_path, @@ -288,8 +288,15 @@ async def ensure_shard(self, shard: Optional[Shard]): max_length=MAX_LENGTH, max_time=MAX_TIME ) - self.shard = shard + if isinstance(self.stateful_sharded_model.model, LlamaModel): + self.tokenizer = AutoTokenizer.from_pretrained( + model_path if model_path is not None else shard.model_id, + trust_remote_code=True + ) + else: + self.tokenizer = await resolve_tokenizer(shard.model_id) + if DEBUG >= 4: print(f"Shard loaded successfully: {shard}") From 43c3c627b506dd188ba755cba59eedd3a4b9e9da Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 15:51:16 -0800 Subject: [PATCH 375/491] detecting 3.1 for adding padding token and using autotokenizer for llama models --- exo/inference/pytorch/inference.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 1791e0c9..020bd865 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,5 +1,6 @@ # experimental, based off of tinygrad/inference.py -import os +import os +import re import numpy as np import torch import json @@ -295,6 +296,10 @@ async def ensure_shard(self, shard: Shard): model_path if model_path is not None else shard.model_id, trust_remote_code=True ) + + if len(re.findall(r"3\.1", shard.model_id)) > 0: + self.tokenizer.add_special_tokens({"pad_token":""}) + else: self.tokenizer = await resolve_tokenizer(shard.model_id) From 75a29f464f9e8b1cd3d8c085cb7da97ab085c1ba Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 16:26:55 -0800 Subject: [PATCH 376/491] updating models.py to use instruct version --- exo/inference/pytorch/inference.py | 20 ++++++++++---------- exo/models.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 020bd865..01841ba4 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -13,7 +13,7 @@ from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader -from transformers import AutoTokenizer +from tokenizers import Tokenizer # llama from transformers.models.llama.modeling_llama import LlamaModel @@ -114,14 +114,14 @@ async def infer_prompt( await self.ensure_shard(shard) # setup prompt input - messages = [{"role": "user", "content": prompt}] - txt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - - inputs = self.tokenizer([txt], return_tensors="pt") + #messages = [{"role": "user", "content": prompt}] + #txt = self.tokenizer.apply_chat_template( + # messages, + # tokenize=False, + # add_generation_prompt=True + #) + + inputs = self.tokenizer([prompt], return_tensors="pt") input_ids = inputs.input_ids.to(self.device) input_attention_mask = inputs.attention_mask.to(self.device) batch_size, seq_length = input_ids.shape[:2] @@ -292,7 +292,7 @@ async def ensure_shard(self, shard: Shard): self.shard = shard if isinstance(self.stateful_sharded_model.model, LlamaModel): - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer = Tokenizer.from_pretrained( model_path if model_path is not None else shard.model_id, trust_remote_code=True ) diff --git a/exo/models.py b/exo/models.py index bb9ccf5c..29e9a7d6 100644 --- a/exo/models.py +++ b/exo/models.py @@ -12,7 +12,7 @@ "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), From e407404fa4c0af11cc5a01a5ce6fd1696e1a7a57 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 16:41:40 -0800 Subject: [PATCH 377/491] fixing autotokenizer --- exo/inference/pytorch/inference.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 01841ba4..8bbc2579 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -13,7 +13,7 @@ from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader -from tokenizers import Tokenizer +from transformers import AutoTokenizer # llama from transformers.models.llama.modeling_llama import LlamaModel @@ -71,7 +71,7 @@ def __init__(self, shard_downloader: HFShardDownloader): def infer_caching( self, inference_state: Optional[str] = None - ) -> Tuple[Optional[torch.tensor], Optional[dict]]: + ) -> Tuple[Optional[torch.Tensor], Optional[dict]]: """ inference caching from inference_state json """ @@ -99,9 +99,9 @@ def infer_caching( async def infer_prompt( self, - request_id: Optional[str] = None, - shard: Optional[Shard] = None, - prompt: Optional[str] = "", + request_id: str, + shard: Shard, + prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: @@ -113,14 +113,6 @@ async def infer_prompt( await self.ensure_shard(shard) - # setup prompt input - #messages = [{"role": "user", "content": prompt}] - #txt = self.tokenizer.apply_chat_template( - # messages, - # tokenize=False, - # add_generation_prompt=True - #) - inputs = self.tokenizer([prompt], return_tensors="pt") input_ids = inputs.input_ids.to(self.device) input_attention_mask = inputs.attention_mask.to(self.device) @@ -196,7 +188,7 @@ async def infer_tensor( # get cache from inference_state past_iids, cached_iids = self.infer_caching(inference_state) - + # detect if hidden_states or not hidden_states = None self.past_input_ids = None @@ -292,7 +284,7 @@ async def ensure_shard(self, shard: Shard): self.shard = shard if isinstance(self.stateful_sharded_model.model, LlamaModel): - self.tokenizer = Tokenizer.from_pretrained( + self.tokenizer = AutoTokenizer.from_pretrained( model_path if model_path is not None else shard.model_id, trust_remote_code=True ) From 668668f5439dc59cfa632dca0427b88e26e3a382 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 18:17:02 -0800 Subject: [PATCH 378/491] making it so position and cache is computed every forward on hf model --- exo/inference/pytorch/model/hf.py | 499 +++++++++++++++--------------- 1 file changed, 247 insertions(+), 252 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 57a1590b..547ed040 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -7,261 +7,256 @@ from exo.helpers import DEBUG from transformers import ( - AutoModelForCausalLM, - DynamicCache, - Cache, - LogitsProcessorList, - TopKLogitsWarper, - TopPLogitsWarper, - TemperatureLogitsWarper + AutoModelForCausalLM, + DynamicCache, + Cache, + LogitsProcessorList, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper ) # llama from transformers.models.llama.modeling_llama import LlamaModel class ShardedHuggingFaceModel: - def __init__( - self, - shard: Shard, - local_model_path, - device, - dtype, - top_k: int = 25, - temp: float = 0.7, - top_p: float = 0.9, - max_length: int = 50, - max_time: float = 10.0 - ): - # class vars - self.shard = shard - self.hidden_states = None - self.input_ids = None - self.inputs_embeds = None - self.attention_mask = None - self.position_embeddings = None - self.past_key_values = None - self.cache_position = None - self.position_ids = None - self.causal_mask = None - self.local_model_path = local_model_path - - # setup logit processors - self.logits_processor = LogitsProcessorList([ - TopKLogitsWarper(top_k), - TemperatureLogitsWarper(temp), - TopPLogitsWarper(top_p) - ]) - - self.device = device - self.torch_dtype = dtype - - # setup pytorch and transformer llm - try: - self.llm_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=self.local_model_path, - torch_dtype=self.torch_dtype, - device_map="auto", - offload_buffers=True - ) - - self.model = self.llm_model.model - except Exception as err: - print(f"error loading and splitting model: {err}") - raise - - - def forward( - self, - input_ids: Optional[torch.tensor] = None, - hidden_states: Optional[torch.tensor] = None, - attention_mask: Optional[torch.tensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - use_legacy_cache: bool = False - ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: - - """ - Generate hidden states or logits via passing through set amount of layers of a model - To be passed only input_ids OR hidden_state and not both. This is for connecting the model - layer to generate a complete output - - Args: - model: base llm model tramsformers class - llm_model: llm chat model class - input_ids: tensor optional - attention_mask: tensor optional - past_key_values: Cache or list[tensor] optional - use_legacy_cache: bool optional - infer_tensor: bool optional, lets forward know to handle tensors - - Returns: - Tuple of - - hidden_states: tensor optional - - past_key_values: Cache or list[tensor] optional - - logits: tensor Optional - - """ - - model_inputs = None - self.hidden_states = None - - if hidden_states is not None: - self.hidden_states = hidden_states - else: - self.input_ids = input_ids - - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) - - # cache - if past_key_values and not isinstance(past_key_values, Cache): - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) - - # position id - position_ids = cache_position.unsqueeze(0) - - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, - self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) - - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( - self.inputs_embeds, - position_ids - ) - - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position - ) - - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] - - if DEBUG >= 4: - print(f"model_inputs: {model_inputs}") - - # run through decoder layers - layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) - - if DEBUG >= 4: - print(f"hidden_states: {self.hidden_states}") - print(f"layer_amt: {layer_amt}") - - for i in layer_amt: - decoder_layer = self.model.layers[i] - if DEBUG >= 5: - print("decoder_layer before") - print(f"decoder_layer: {decoder_layer}") - print(f"hidden_states: {self.hidden_states}") - - # TODO: fix caching as decoder layer is not returning - # present_key_value from attention layer on models - # might have some other generation functions needed to do it - # see https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2917 - # for qwen2 exhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L291 - layer_outputs = decoder_layer( - self.hidden_states, - attention_mask=self.causal_mask, - position_ids=self.position_ids, - past_key_values=self.past_key_values, - use_cache=True, - cache_position=self.cache_position - ) - - self.hidden_states = layer_outputs[0] - self.next_decoder_cache = layer_outputs[1] - - if DEBUG >= 5: - print("decoder_layer after") - print(f"layer_outputs: {layer_outputs}\n") - print(f"self.next_decoder_cache: {self.next_decoder_cache}") - print(f"hidden_states: {self.hidden_states}") - print(f"next_decoder_cache: {self.next_decoder_cache}") - - # handle last layer to get logits - # shard is last layer says true at the start and not detecting last layer correctly - if self.shard.is_last_layer(): - self.hidden_states = self.model.norm(self.hidden_states) - if use_legacy_cache: - self.past_key_values = self.next_decoder_cache.to_legacy_cache() - else: - self.past_key_values = self.next_decoder_cache - - # lm_head - logits = self.llm_model.lm_head(self.hidden_states).to(self.device) - - if DEBUG >= 4: - print(f"logits: {logits}") - - return ( - None, - None, - logits - ) - - if DEBUG >= 4: - print(f"hidden_states: {self.hidden_states}") - print(f"past_key_values: {self.past_key_values}") - - return ( - self.hidden_states, - self.past_key_values, - None - ) - - def logits_sample( - self, - logits: torch.tensor, - use_max: Optional[bool] = False - ) -> torch.tensor: - """ - Get a sample of the logits from end of model run for next token - - Args: - logits: tensor - use_max: bool, if function should sample with argmax - - Returns: - next_token: tensor - """ - - # get a single cloned logit - logits = logits[:, -1, :].clone().float() - - next_token_scores = self.logits_processor(self.input_ids, logits) - - if not use_max: - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - else: - next_token = torch.argmax(next_token_scores, dim=-1) - - if DEBUG >= 4: - print(f"input_ids: {self.input_ids}") - print(f"next_token: {next_token}") - - return next_token[:, None].squeeze(-1) - - + def __init__( + self, + shard: Shard, + local_model_path, + device, + dtype, + top_k: int = 25, + temp: float = 0.7, + top_p: float = 0.9, + max_length: int = 50, + max_time: float = 10.0 + ): + # class vars + self.shard = shard + self.hidden_states = None + self.input_ids = None + self.inputs_embeds = None + self.attention_mask = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.causal_mask = None + self.local_model_path = local_model_path + + # setup logit processors + self.logits_processor = LogitsProcessorList([ + TopKLogitsWarper(top_k), + TemperatureLogitsWarper(temp), + TopPLogitsWarper(top_p) + ]) + + self.device = device + self.torch_dtype = dtype + + # setup pytorch and transformer llm + try: + self.llm_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=self.local_model_path, + torch_dtype=self.torch_dtype, + device_map="auto", + offload_buffers=True + ) + + self.model = self.llm_model.model + except Exception as err: + print(f"error loading and splitting model: {err}") + raise + + + def forward( + self, + input_ids: Optional[torch.tensor] = None, + hidden_states: Optional[torch.tensor] = None, + attention_mask: Optional[torch.tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_legacy_cache: bool = False + ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: + + """ + Generate hidden states or logits via passing through set amount of layers of a model + To be passed only input_ids OR hidden_state and not both. This is for connecting the model + layer to generate a complete output + + Args: + model: base llm model tramsformers class + llm_model: llm chat model class + input_ids: tensor optional + attention_mask: tensor optional + past_key_values: Cache or list[tensor] optional + use_legacy_cache: bool optional + infer_tensor: bool optional, lets forward know to handle tensors + + Returns: + Tuple of + - hidden_states: tensor optional + - past_key_values: Cache or list[tensor] optional + - logits: tensor Optional + + """ + model_inputs = None + self.hidden_states = hidden_states + self.input_ids = input_ids + + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + False # dont out attentions + ) + + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) + + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] + + if DEBUG >= 4: + print(f"model_inputs: {model_inputs}") + + # run through decoder layers + layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) + + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"layer_amt: {layer_amt}") + + for i in layer_amt: + decoder_layer = self.model.layers[i] + if DEBUG >= 5: + print("decoder_layer before") + print(f"decoder_layer: {decoder_layer}") + print(f"hidden_states: {self.hidden_states}") + print(f"position_ids: {self.position_ids}") + print(f"position_embeddings: {self.position_embeddings}") + + # TODO: fix caching as decoder layer is not returning + # present_key_value from attention layer on models + # might have some other generation functions needed to do it + # see https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2917 + # for qwen2 exhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L291 + layer_outputs = decoder_layer( + self.hidden_states, + attention_mask=self.causal_mask, + position_ids=self.position_ids, + past_key_values=self.past_key_values, + use_cache=True, + cache_position=self.cache_position + ) + + self.hidden_states = layer_outputs[0] + self.next_decoder_cache = layer_outputs[1] + + if DEBUG >= 5: + print("decoder_layer after") + print(f"layer_outputs: {layer_outputs}\n") + print(f"self.next_decoder_cache: {self.next_decoder_cache}") + print(f"hidden_states: {self.hidden_states}") + print(f"next_decoder_cache: {self.next_decoder_cache}") + + # handle last layer to get logits + # shard is last layer says true at the start and not detecting last layer correctly + if self.shard.is_last_layer(): + self.hidden_states = self.model.norm(self.hidden_states) + if use_legacy_cache: + self.past_key_values = self.next_decoder_cache.to_legacy_cache() + else: + self.past_key_values = self.next_decoder_cache + + # lm_head + logits = self.llm_model.lm_head(self.hidden_states).to(self.device) + + if DEBUG >= 4: + print(f"logits: {logits}") + + return ( + None, + None, + logits + ) + + if DEBUG >= 4: + print(f"hidden_states: {self.hidden_states}") + print(f"past_key_values: {self.past_key_values}") + + return ( + self.hidden_states, + self.past_key_values, + None + ) + + def logits_sample( + self, + logits: torch.tensor, + use_max: Optional[bool] = False + ) -> torch.tensor: + """ + Get a sample of the logits from end of model run for next token + + Args: + logits: tensor + use_max: bool, if function should sample with argmax + + Returns: + next_token: tensor + """ + + # get a single cloned logit + logits = logits[:, -1, :].clone().float() + + next_token_scores = self.logits_processor(self.input_ids, logits) + + if not use_max: + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_token_scores, dim=-1) + + if DEBUG >= 4: + print(f"input_ids: {self.input_ids}") + print(f"next_token: {next_token}") + + return next_token[:, None].squeeze(-1) From 4e356f8dac26f7378e276807fabd2fa773b540b8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 18:20:23 -0800 Subject: [PATCH 379/491] loading cached input_ids when passing hidden states --- exo/inference/pytorch/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 8bbc2579..15760994 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -194,6 +194,7 @@ async def infer_tensor( self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids + self.past_input_ids = past_iids else: if past_iids is not None: self.past_input_ids = past_iids From a5ef04a9ecdb3f9185f665654c19118240af85b5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 18:33:43 -0800 Subject: [PATCH 380/491] loading cached iids from infer state fix --- exo/inference/pytorch/inference.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 15760994..6a14f1f7 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -105,7 +105,7 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: + if DEBUG >= 2: print("infer_prompt called") print(f"prompt: {prompt}") print(f"shard: {shard}") @@ -126,7 +126,7 @@ async def infer_prompt( else: self.past_input_ids = input_ids - if DEBUG >= 4: + if DEBUG >= 2: print(f"past_input_ids: {self.past_input_ids}\n") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( @@ -134,7 +134,7 @@ async def infer_prompt( attention_mask=input_attention_mask ) - if DEBUG >= 4: + if DEBUG >= 2: print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") @@ -152,7 +152,7 @@ async def infer_prompt( if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id - if DEBUG >= 4: + if DEBUG >= 2: print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") @@ -164,7 +164,7 @@ async def infer_prompt( is_finished ) - if DEBUG >= 4: + if DEBUG >= 2: print(f"return_values: {return_values}") return return_values @@ -176,7 +176,7 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: + if DEBUG >= 2: print("infer_tensor called") print(f"input_data: {input_data}") print(f"shard: {shard}") @@ -194,15 +194,16 @@ async def infer_tensor( self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids - self.past_input_ids = past_iids + self.past_input_ids = torch.tensor(cached_iids["input_ids"]) else: if past_iids is not None: self.past_input_ids = past_iids else: self.past_input_ids = input_ids - if DEBUG >= 4: + if DEBUG >= 2: print(f"input_ids: {input_ids}") + print(f"hidden_state: {hidden_states}") print(f"inference_state: {inference_state}") shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( @@ -236,7 +237,7 @@ async def infer_tensor( # clear cache cached_iids = {"input_ids": []} - if DEBUG >= 4: + if DEBUG >= 2: print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") @@ -248,7 +249,7 @@ async def infer_tensor( is_finished ) - if DEBUG >= 4: + if DEBUG >= 2: print(f"return_values: {return_values}") return return_values @@ -264,11 +265,11 @@ async def ensure_shard(self, shard: Shard): if self.shard == shard: return - if DEBUG >= 4: + if DEBUG >= 2: print(f"Loading new shard: {shard}") model_path = await self.shard_downloader.ensure_shard(shard) - if DEBUG >= 4: + if DEBUG >= 2: print(f"model_path: {model_path}") self.stateful_sharded_model = ShardedHuggingFaceModel( From e888baa1537c98cd450ee28db29c83bc4e7320be Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 18:35:12 -0800 Subject: [PATCH 381/491] device fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 6a14f1f7..fbf99a6e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -194,7 +194,7 @@ async def infer_tensor( self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids - self.past_input_ids = torch.tensor(cached_iids["input_ids"]) + self.past_input_ids = torch.tensor(cached_iids["input_ids"]).to(self.device) else: if past_iids is not None: self.past_input_ids = past_iids From 7d9eb17d22a832dccb359ba48bd08cd650dc33ec Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 18:42:45 -0800 Subject: [PATCH 382/491] position id fix --- exo/inference/pytorch/model/hf.py | 91 ++++++++++++++++--------------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 547ed040..3f88afb2 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -104,58 +104,59 @@ def forward( self.hidden_states = hidden_states self.input_ids = input_ids - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) - - # cache - if past_key_values and not isinstance(past_key_values, Cache): - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) + if self.hidden_states is None or self.position_ids is None: + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) - # position id - position_ids = cache_position.unsqueeze(0) - - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, - self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) + # position id + position_ids = cache_position.unsqueeze(0) - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, self.inputs_embeds, - position_ids + cache_position, + past_key_values, + False # dont out attentions ) - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position - ) + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=position_ids, + cache_position=cache_position + ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] - if DEBUG >= 4: - print(f"model_inputs: {model_inputs}") + if DEBUG >= 4: + print(f"model_inputs: {model_inputs}") # run through decoder layers layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) From 11986289e7095d4229e4ddd7d0330817b3aecbcd Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 19:00:11 -0800 Subject: [PATCH 383/491] fixing inference instance state issues between nodes --- exo/inference/pytorch/inference.py | 4 ++-- exo/inference/pytorch/model/hf.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index fbf99a6e..03995975 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -194,7 +194,7 @@ async def infer_tensor( self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids - self.past_input_ids = torch.tensor(cached_iids["input_ids"]).to(self.device) + self.past_input_ids = past_iids else: if past_iids is not None: self.past_input_ids = past_iids @@ -202,7 +202,7 @@ async def infer_tensor( self.past_input_ids = input_ids if DEBUG >= 2: - print(f"input_ids: {input_ids}") + print(f"past_input_ids: {self.past_input_ids}") print(f"hidden_state: {hidden_states}") print(f"inference_state: {inference_state}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 3f88afb2..786899e2 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -104,6 +104,12 @@ def forward( self.hidden_states = hidden_states self.input_ids = input_ids + if DEBUG >= 2: + print("hf forward called") + print(f"hidden_states: {self.hidden_states}") + print(f"input_ids: {self.input_ids}") + print(f"self.position_ids: {self.position_ids}") + if self.hidden_states is None or self.position_ids is None: # embed input_ids self.inputs_embeds = self.model.embed_tokens(self.input_ids) From d25b7ac0398b44b5f0aed51d9922b70513af96fd Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 19:38:31 -0800 Subject: [PATCH 384/491] node testing --- exo/models.py | 2 +- exo/tinychat/index.html | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/models.py b/exo/models.py index 29e9a7d6..7d6bc30a 100644 --- a/exo/models.py +++ b/exo/models.py @@ -67,7 +67,7 @@ "qwen-2.5-math-72b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), }, - "Qwen2-0.5B-Instruct": { + "qwen2-0.5b-instruct": { "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), }, } diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index 706255bc..e9be9218 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -50,6 +50,7 @@ +
Date: Sun, 6 Oct 2024 20:26:38 -0800 Subject: [PATCH 386/491] node inference fix --- exo/inference/pytorch/model/hf.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 80ec8f78..d5f8f68b 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -109,8 +109,15 @@ def forward( print(f"hidden_states: {self.hidden_states}") print(f"input_ids: {self.input_ids}") print(f"self.position_ids: {self.position_ids}") + print(f"past_key_values: {past_key_values}") + + # skip if there is a hidden state with position_ids already calculated + # if there is hidden states and no position_ids, will need to be calculated + # this is not needed for Qwen model but Llama requires it + if (self.hidden_states is None or + (self.hidden_states is not None and self.position_ids is None) + ): - if self.hidden_states is None: # embed input_ids self.inputs_embeds = self.model.embed_tokens(self.input_ids) @@ -228,6 +235,8 @@ def forward( if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") print(f"past_key_values: {self.past_key_values}") + print(f"position_ids: {self.position_ids}") + print(f"input_ids: {self.input_ids}") return ( self.hidden_states, From 77a52a57eefd789e094a550a7a6d9640a67c844d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 20:33:04 -0800 Subject: [PATCH 387/491] node inference fix --- exo/inference/pytorch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index d5f8f68b..5ba8615a 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -232,7 +232,8 @@ def forward( logits ) - if DEBUG >= 4: + if DEBUG >= 2: + print("hf out [no logit]") print(f"hidden_states: {self.hidden_states}") print(f"past_key_values: {self.past_key_values}") print(f"position_ids: {self.position_ids}") From 2b3397f459e22bf91423bba36324ed74ce6d4236 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 21:07:21 -0800 Subject: [PATCH 388/491] node inference fix --- exo/inference/pytorch/inference.py | 2 +- exo/inference/pytorch/model/hf.py | 38 ++++++++++++++---------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 03995975..971eaa7d 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -194,7 +194,7 @@ async def infer_tensor( self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids - self.past_input_ids = past_iids + #self.past_input_ids = past_iids else: if past_iids is not None: self.past_input_ids = past_iids diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 5ba8615a..f73a9712 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -114,28 +114,26 @@ def forward( # skip if there is a hidden state with position_ids already calculated # if there is hidden states and no position_ids, will need to be calculated # this is not needed for Qwen model but Llama requires it - if (self.hidden_states is None or - (self.hidden_states is not None and self.position_ids is None) - ): - - # embed input_ids - self.inputs_embeds = self.model.embed_tokens(self.input_ids) - - # cache - if past_key_values and not isinstance(past_key_values, Cache): - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) - # position id - position_ids = cache_position.unsqueeze(0) + # embed input_ids + self.inputs_embeds = self.model.embed_tokens(self.input_ids) + + # cache + if past_key_values and not isinstance(past_key_values, Cache): + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + self.inputs_embeds.shape[1], + device=self.inputs_embeds.device + ) + + # position id + position_ids = cache_position.unsqueeze(0) + if self.hidden_states is None: # casual mask and attention_mask self.attention_mask = attention_mask self.causal_mask = self.model._update_causal_mask( From 2e588afaa4493adc6fc3c2ccb4eb6be77a282939 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 21:10:18 -0800 Subject: [PATCH 389/491] node inference fix --- exo/inference/pytorch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 971eaa7d..03995975 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -194,7 +194,7 @@ async def infer_tensor( self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids - #self.past_input_ids = past_iids + self.past_input_ids = past_iids else: if past_iids is not None: self.past_input_ids = past_iids From e2eba0592aecdc205dd087c571a1384987617fef Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 21:13:11 -0800 Subject: [PATCH 390/491] node inference fix --- exo/inference/pytorch/model/hf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index f73a9712..872ea7e2 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -104,14 +104,6 @@ def forward( self.hidden_states = hidden_states self.input_ids = input_ids - if DEBUG >= 2: - print("hf forward called") - print(f"hidden_states: {self.hidden_states}") - print(f"input_ids: {self.input_ids}") - print(f"self.position_ids: {self.position_ids}") - print(f"past_key_values: {past_key_values}") - - # skip if there is a hidden state with position_ids already calculated # if there is hidden states and no position_ids, will need to be calculated # this is not needed for Qwen model but Llama requires it @@ -133,6 +125,14 @@ def forward( # position id position_ids = cache_position.unsqueeze(0) + if DEBUG >= 2: + print("hf forward called") + print(f"hidden_states: {self.hidden_states}") + print(f"input_ids: {self.input_ids}") + print(f"self.position_ids: {self.position_ids}") + print(f"past_key_values: {past_key_values}") + + if self.hidden_states is None: # casual mask and attention_mask self.attention_mask = attention_mask From d7699ebaf939da9f660dc39de983d399ef155d84 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 21:16:23 -0800 Subject: [PATCH 391/491] node inference fix --- exo/inference/pytorch/model/hf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 872ea7e2..eb00e933 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -123,13 +123,14 @@ def forward( ) # position id - position_ids = cache_position.unsqueeze(0) + self.position_ids = cache_position.unsqueeze(0) if DEBUG >= 2: print("hf forward called") print(f"hidden_states: {self.hidden_states}") print(f"input_ids: {self.input_ids}") - print(f"self.position_ids: {self.position_ids}") + print(f"input_embeds: {self.inputs_embeds}") + print(f"position_ids: {self.position_ids}") print(f"past_key_values: {past_key_values}") @@ -148,7 +149,7 @@ def forward( if isinstance(self.model, LlamaModel): self.position_embeddings = self.model.rotary_emb( self.inputs_embeds, - position_ids + self.position_ids ) # prepare inputs for decoder layers @@ -157,7 +158,7 @@ def forward( past_key_values=past_key_values, attention_mask=self.attention_mask, inputs_embeds=self.inputs_embeds, - position_ids=position_ids, + position_ids=self.position_ids, cache_position=cache_position ) From bd9bf4f1afd3046730fc534ccadcc43bc066d67e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 21:28:38 -0800 Subject: [PATCH 392/491] inference between nodes fixed by always calculating position id and input embed from input_ids cache, working on vram mem management --- exo/inference/pytorch/model/hf.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index eb00e933..9c30d8ae 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -72,13 +72,12 @@ def __init__( def forward( self, - input_ids: Optional[torch.tensor] = None, - hidden_states: Optional[torch.tensor] = None, - attention_mask: Optional[torch.tensor] = None, + input_ids: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, use_legacy_cache: bool = False - ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: - + ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: """ Generate hidden states or logits via passing through set amount of layers of a model To be passed only input_ids OR hidden_state and not both. This is for connecting the model @@ -125,7 +124,7 @@ def forward( # position id self.position_ids = cache_position.unsqueeze(0) - if DEBUG >= 2: + if DEBUG >= 4: print("hf forward called") print(f"hidden_states: {self.hidden_states}") print(f"input_ids: {self.input_ids}") @@ -231,7 +230,7 @@ def forward( logits ) - if DEBUG >= 2: + if DEBUG >= 4: print("hf out [no logit]") print(f"hidden_states: {self.hidden_states}") print(f"past_key_values: {self.past_key_values}") @@ -246,18 +245,18 @@ def forward( def logits_sample( self, - logits: torch.tensor, + logits: torch.Tensor, use_max: Optional[bool] = False - ) -> torch.tensor: + ) -> torch.Tensor: """ Get a sample of the logits from end of model run for next token - + Args: - logits: tensor + logits: tensor use_max: bool, if function should sample with argmax Returns: - next_token: tensor + next_token: tensor """ # get a single cloned logit @@ -273,6 +272,6 @@ def logits_sample( if DEBUG >= 4: print(f"input_ids: {self.input_ids}") - print(f"next_token: {next_token}") + print(f"next_token: {next_token}") return next_token[:, None].squeeze(-1) From 913a00859e284aae186d8e1867f2afb3b0b26bfc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 7 Oct 2024 08:10:01 -0800 Subject: [PATCH 393/491] cleaning up code --- exo/inference/pytorch/inference.py | 56 +++++++++++++----------------- exo/inference/pytorch/model/hf.py | 19 +++++----- 2 files changed, 32 insertions(+), 43 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 03995975..11f8eddb 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -34,7 +34,7 @@ def __init__(self, shard_downloader: HFShardDownloader): Initialize the inference engine. Args: - debug (bool): If True, enables debug logging. Defaults to False. + shard_downloader: Model and weights sharding download """ self.shard = None self.shard_downloader = shard_downloader @@ -49,15 +49,15 @@ def __init__(self, shard_downloader: HFShardDownloader): # setup cuda device if os.environ.get("PYTORCH_DEVICE"): pytorch_device = os.environ["PYTOCH_DEVICE"] - if pytorch_device not in ["cuda", "mps", "cpu"]: + if pytorch_device not in ["cuda", "mps"]: pytorch_device = "cpu" self.device = pytorch_device - self.torch_dtype = torch.float32 if pytorch_device != "cpu" else torch.float16 + self.torch_dtype = torch.float16 if pytorch_device != "cpu" else torch.float32 if torch.cuda.is_available(): self.device = torch.device("cuda") - self.torch_dtype = torch.float16 + self.torch_dtype = torch.float32 elif torch.backends.mps.is_available(): self.device = torch.device("mps") self.torch_dtype = torch.float32 @@ -101,11 +101,11 @@ async def infer_prompt( self, request_id: str, shard: Shard, - prompt: str, - image_str: Optional[str] = None, + prompt: str, + image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 2: + if DEBUG >= 4: print("infer_prompt called") print(f"prompt: {prompt}") print(f"shard: {shard}") @@ -115,26 +115,25 @@ async def infer_prompt( inputs = self.tokenizer([prompt], return_tensors="pt") input_ids = inputs.input_ids.to(self.device) - input_attention_mask = inputs.attention_mask.to(self.device) - batch_size, seq_length = input_ids.shape[:2] + input_attention_mask = inputs.attention_mask.to(self.device) # get cache from inference_state past_iids, cached_iids = self.infer_caching(inference_state) if past_iids is not None: - self.past_input_ids = past_iids, + self.past_input_ids = past_iids else: self.past_input_ids = input_ids - if DEBUG >= 2: + if DEBUG >= 4: print(f"past_input_ids: {self.past_input_ids}\n") - + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, attention_mask=input_attention_mask ) - if DEBUG >= 2: + if DEBUG >= 4: print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") print(f"\nshard_logits: {shard_logits}") @@ -152,7 +151,7 @@ async def infer_prompt( if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id - if DEBUG >= 2: + if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") @@ -164,7 +163,7 @@ async def infer_prompt( is_finished ) - if DEBUG >= 2: + if DEBUG >= 4: print(f"return_values: {return_values}") return return_values @@ -176,7 +175,7 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 2: + if DEBUG >= 4: print("infer_tensor called") print(f"input_data: {input_data}") print(f"shard: {shard}") @@ -201,7 +200,7 @@ async def infer_tensor( else: self.past_input_ids = input_ids - if DEBUG >= 2: + if DEBUG >= 4: print(f"past_input_ids: {self.past_input_ids}") print(f"hidden_state: {hidden_states}") print(f"inference_state: {inference_state}") @@ -211,22 +210,18 @@ async def infer_tensor( hidden_states=hidden_states ) - hidden_dict = None - if shard_hidden_states is not None: - hidden_dict = {"hidden_states": shard_hidden_states.tolist()} - - next_token = None + next_token = None if shard_logits is not None: next_token = self.stateful_sharded_model.logits_sample(shard_logits) input_ids = next_token - + #cache if next_token is not None: if self.past_input_ids is not None: next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) elif past_iids is not None: next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) - + cached_iids = {"input_ids": next_cached_logits.tolist()} is_finished = False @@ -234,10 +229,10 @@ async def infer_tensor( is_finished = next_token.item() == self.tokenizer.eos_token_id if is_finished: - # clear cache + # clear cache cached_iids = {"input_ids": []} - if DEBUG >= 2: + if DEBUG >= 4: print(f"\ninput_ids: {input_ids}") print(f"\nshard_hidden_states: {shard_hidden_states}\n") print(f"\nshard_past_kvs {shard_past_kvs}\n") @@ -245,16 +240,15 @@ async def infer_tensor( return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), + json.dumps({"cached_iids": cached_iids}), is_finished ) - if DEBUG >= 2: + if DEBUG >= 4: print(f"return_values: {return_values}") return return_values - async def ensure_shard(self, shard: Shard): """ Ensure the model shard is loaded and ready for inference. @@ -265,12 +259,10 @@ async def ensure_shard(self, shard: Shard): if self.shard == shard: return - if DEBUG >= 2: + if DEBUG >= 4: print(f"Loading new shard: {shard}") model_path = await self.shard_downloader.ensure_shard(shard) - if DEBUG >= 2: - print(f"model_path: {model_path}") self.stateful_sharded_model = ShardedHuggingFaceModel( shard=shard, diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 9c30d8ae..d6038f04 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import numpy as np from typing import Tuple, Optional, Union, List from exo.inference.shard import Shard @@ -69,7 +68,6 @@ def __init__( print(f"error loading and splitting model: {err}") raise - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -84,8 +82,8 @@ def forward( layer to generate a complete output Args: - model: base llm model tramsformers class - llm_model: llm chat model class + model: base llm model tramsformers class + llm_model: llm chat model class input_ids: tensor optional attention_mask: tensor optional past_key_values: Cache or list[tensor] optional @@ -95,7 +93,7 @@ def forward( Returns: Tuple of - hidden_states: tensor optional - - past_key_values: Cache or list[tensor] optional + - past_key_values: Cache or list[tensor] optional - logits: tensor Optional """ @@ -132,7 +130,6 @@ def forward( print(f"position_ids: {self.position_ids}") print(f"past_key_values: {past_key_values}") - if self.hidden_states is None: # casual mask and attention_mask self.attention_mask = attention_mask @@ -169,7 +166,7 @@ def forward( if DEBUG >= 4: print(f"model_inputs: {model_inputs}") - # run through decoder layers + # run through decoder layers layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) if DEBUG >= 4: @@ -186,8 +183,8 @@ def forward( print(f"position_embeddings: {self.position_embeddings}") # TODO: fix caching as decoder layer is not returning - # present_key_value from attention layer on models - # might have some other generation functions needed to do it + # present_key_value from attention layer on models + # might have some other generation functions needed to do it # see https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2917 # for qwen2 exhttps://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py#L291 layer_outputs = decoder_layer( @@ -217,8 +214,8 @@ def forward( self.past_key_values = self.next_decoder_cache.to_legacy_cache() else: self.past_key_values = self.next_decoder_cache - - # lm_head + + # lm_head logits = self.llm_model.lm_head(self.hidden_states).to(self.device) if DEBUG >= 4: From b518f73fcf4e803431b619a50012226aa7b92e78 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 7 Oct 2024 12:06:29 -0800 Subject: [PATCH 394/491] comma and other text issue fix --- exo/api/chatgpt_api.py | 9 +-------- exo/inference/pytorch/inference.py | 1 + 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index fe8cc590..9a65deae 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -71,16 +71,9 @@ def generate_completion( } choice = completion["choices"][0] - print(f"\nchoice {choice}") if object_type.startswith("chat.completion"): key_name = "delta" if stream else "message" - - token_decode = tokenizer.batch_decode( - tokens, - skip_special_tokens=True, - clean_up_tokenization_spaces=False - ) - choice[key_name] = {"role": "assistant", "content": token_decode} + choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)} elif object_type == "text_completion": choice["text"] = tokenizer.decode(tokens) else: diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 11f8eddb..676e3162 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -14,6 +14,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader from transformers import AutoTokenizer + # llama from transformers.models.llama.modeling_llama import LlamaModel From 9d2477952769a3415c60350e6291d12567ba8aef Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 8 Oct 2024 23:17:26 -0800 Subject: [PATCH 395/491] adding threadpooling to forward and logit sampling --- exo/inference/pytorch/inference.py | 66 ++++++++++++++++++++++++------ 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 676e3162..a613015e 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,11 +1,14 @@ # experimental, based off of tinygrad/inference.py +import asyncio import os import re import numpy as np import torch import json +import functools +from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, List from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel @@ -13,7 +16,7 @@ from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader -from transformers import AutoTokenizer +from transformers import AutoTokenizer, Cache # llama from transformers.models.llama.modeling_llama import LlamaModel @@ -39,8 +42,6 @@ def __init__(self, shard_downloader: HFShardDownloader): """ self.shard = None self.shard_downloader = shard_downloader - self.stateful_sharded_model = None - self.tokenizer = None # the whole history with new logits need to # be passed to the model to reach the end token @@ -59,15 +60,15 @@ def __init__(self, shard_downloader: HFShardDownloader): if torch.cuda.is_available(): self.device = torch.device("cuda") self.torch_dtype = torch.float32 - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): self.device = torch.device("mps") self.torch_dtype = torch.float32 else: self.device = torch.device("cpu") self.torch_dtype = torch.float16 - # setup unfinished sequence - self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) + # setup threadding + torch.set_num_threads(torch.get_num_threads()) def infer_caching( self, @@ -98,6 +99,44 @@ def infer_caching( return (past_iids, cached_iids) + async def async_forward( + self, + input_ids: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None + ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: + + loop = asyncio.get_running_loop() + + forward_partial = functools.partial( + self.stateful_sharded_model.forward, + input_ids=input_ids, + hidden_states=hidden_states, + attention_mask=attention_mask + ) + + with ThreadPoolExecutor() as pool: + result = await loop.run_in_executor(pool, forward_partial) + + return result + + async def async_logit_sample( + self, + logits: torch.Tensor + ) -> torch.Tensor: + + loop = asyncio.get_running_loop() + + sample_partial = functools.partial( + self.stateful_sharded_model.logits_sample, + logits=logits + ) + + with ThreadPoolExecutor() as pool: + result = await loop.run_in_executor(pool, sample_partial) + + return result + async def infer_prompt( self, request_id: str, @@ -129,7 +168,7 @@ async def infer_prompt( if DEBUG >= 4: print(f"past_input_ids: {self.past_input_ids}\n") - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( input_ids=self.past_input_ids, attention_mask=input_attention_mask ) @@ -141,7 +180,7 @@ async def infer_prompt( next_token = None if shard_logits is not None: - next_token = self.stateful_sharded_model.logits_sample(shard_logits) + next_token = await self.async_logit_sample(shard_logits) self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) input_ids = next_token @@ -206,24 +245,27 @@ async def infer_tensor( print(f"hidden_state: {hidden_states}") print(f"inference_state: {inference_state}") - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( input_ids=self.past_input_ids, hidden_states=hidden_states ) next_token = None if shard_logits is not None: - next_token = self.stateful_sharded_model.logits_sample(shard_logits) + next_token = await self.async_logit_sample(shard_logits) input_ids = next_token #cache + next_cached_logits = None if next_token is not None: if self.past_input_ids is not None: next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) elif past_iids is not None: next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) - cached_iids = {"input_ids": next_cached_logits.tolist()} + cached_iids = { + "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else [] + } is_finished = False if next_token is not None: From d4fb74fa7e8a09348f56f0b237cd492179e30cac Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 10 Oct 2024 13:09:58 -0700 Subject: [PATCH 396/491] rename (PyTorch, pytorch) -> (Torch, torch) --- exo/inference/inference_engine.py | 6 +++--- exo/inference/{pytorch => torch}/.gitignore | 0 exo/inference/{pytorch => torch}/README.md | 0 exo/inference/{pytorch => torch}/__init__.py | 0 exo/inference/{pytorch => torch}/inference.py | 18 +++++++++--------- .../{pytorch => torch}/model/__init__.py | 0 exo/inference/{pytorch => torch}/model/hf.py | 0 .../{pytorch => torch}/tests/__init__.py | 0 .../tests/test_inference_engine.py | 10 +++++----- .../tests/test_simple_model.py | 0 .../tests/test_split_model.py | 0 .../{pytorch => torch}/tests/utils.py | 0 exo/models.py | 10 +++++----- 13 files changed, 22 insertions(+), 22 deletions(-) rename exo/inference/{pytorch => torch}/.gitignore (100%) rename exo/inference/{pytorch => torch}/README.md (100%) rename exo/inference/{pytorch => torch}/__init__.py (100%) rename exo/inference/{pytorch => torch}/inference.py (94%) rename exo/inference/{pytorch => torch}/model/__init__.py (100%) rename exo/inference/{pytorch => torch}/model/hf.py (100%) rename exo/inference/{pytorch => torch}/tests/__init__.py (100%) rename exo/inference/{pytorch => torch}/tests/test_inference_engine.py (90%) rename exo/inference/{pytorch => torch}/tests/test_simple_model.py (100%) rename exo/inference/{pytorch => torch}/tests/test_split_model.py (100%) rename exo/inference/{pytorch => torch}/tests/utils.py (100%) diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index 2b98adbe..7fd7528b 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -27,8 +27,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) return TinygradDynamicShardInferenceEngine(shard_downloader) - elif inference_engine_name == "pytorch": - from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine - return PyTorchDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "torch": + from exo.inference.torch.inference import TorchDynamicShardInferenceEngine + return TorchDynamicShardInferenceEngine(shard_downloader) else: raise ValueError(f"Inference engine {inference_engine_name} not supported") diff --git a/exo/inference/pytorch/.gitignore b/exo/inference/torch/.gitignore similarity index 100% rename from exo/inference/pytorch/.gitignore rename to exo/inference/torch/.gitignore diff --git a/exo/inference/pytorch/README.md b/exo/inference/torch/README.md similarity index 100% rename from exo/inference/pytorch/README.md rename to exo/inference/torch/README.md diff --git a/exo/inference/pytorch/__init__.py b/exo/inference/torch/__init__.py similarity index 100% rename from exo/inference/pytorch/__init__.py rename to exo/inference/torch/__init__.py diff --git a/exo/inference/pytorch/inference.py b/exo/inference/torch/inference.py similarity index 94% rename from exo/inference/pytorch/inference.py rename to exo/inference/torch/inference.py index a613015e..093724c1 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/torch/inference.py @@ -11,7 +11,7 @@ from typing import Optional, Tuple, Union, List from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel +from exo.inference.torch.model.hf import ShardedHuggingFaceModel from exo.inference.tokenizers import resolve_tokenizer from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader @@ -28,9 +28,9 @@ MAX_LENGTH = 125 MAX_TIME = 60.0 -class PyTorchDynamicShardInferenceEngine(InferenceEngine): +class TorchDynamicShardInferenceEngine(InferenceEngine): """ - PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. + Torch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. """ def __init__(self, shard_downloader: HFShardDownloader): @@ -49,13 +49,13 @@ def __init__(self, shard_downloader: HFShardDownloader): self.past_input_ids = None # setup cuda device - if os.environ.get("PYTORCH_DEVICE"): - pytorch_device = os.environ["PYTOCH_DEVICE"] - if pytorch_device not in ["cuda", "mps"]: - pytorch_device = "cpu" + if os.environ.get("TORCH_DEVICE"): + torch_device = os.environ["PYTOCH_DEVICE"] + if torch_device not in ["cuda", "mps"]: + torch_device = "cpu" - self.device = pytorch_device - self.torch_dtype = torch.float16 if pytorch_device != "cpu" else torch.float32 + self.device = torch_device + self.torch_dtype = torch.float16 if torch_device != "cpu" else torch.float32 if torch.cuda.is_available(): self.device = torch.device("cuda") diff --git a/exo/inference/pytorch/model/__init__.py b/exo/inference/torch/model/__init__.py similarity index 100% rename from exo/inference/pytorch/model/__init__.py rename to exo/inference/torch/model/__init__.py diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/torch/model/hf.py similarity index 100% rename from exo/inference/pytorch/model/hf.py rename to exo/inference/torch/model/hf.py diff --git a/exo/inference/pytorch/tests/__init__.py b/exo/inference/torch/tests/__init__.py similarity index 100% rename from exo/inference/pytorch/tests/__init__.py rename to exo/inference/torch/tests/__init__.py diff --git a/exo/inference/pytorch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py similarity index 90% rename from exo/inference/pytorch/tests/test_inference_engine.py rename to exo/inference/torch/tests/test_inference_engine.py index 854d9b9c..e8b0b14c 100644 --- a/exo/inference/pytorch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -1,7 +1,7 @@ import asyncio from exo.inference.shard import Shard -from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +from exo.inference.torch.inference import TorchDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine from exo.inference.shard import Shard @@ -120,8 +120,8 @@ async def test_inference_engine( # try: # print("\n\n -------- TEST QWEN2 -------- \n\n") # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # TorchDynamicShardInferenceEngine(HFShardDownloader()), # "Qwen/Qwen2-0.5B-Instruct", # 24 # )) @@ -131,8 +131,8 @@ async def test_inference_engine( try: print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + TorchDynamicShardInferenceEngine(HFShardDownloader()), + TorchDynamicShardInferenceEngine(HFShardDownloader()), "meta-llama/Llama-3.2-1B-Instruct", 24 )) diff --git a/exo/inference/pytorch/tests/test_simple_model.py b/exo/inference/torch/tests/test_simple_model.py similarity index 100% rename from exo/inference/pytorch/tests/test_simple_model.py rename to exo/inference/torch/tests/test_simple_model.py diff --git a/exo/inference/pytorch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py similarity index 100% rename from exo/inference/pytorch/tests/test_split_model.py rename to exo/inference/torch/tests/test_split_model.py diff --git a/exo/inference/pytorch/tests/utils.py b/exo/inference/torch/tests/utils.py similarity index 100% rename from exo/inference/pytorch/tests/utils.py rename to exo/inference/torch/tests/utils.py diff --git a/exo/models.py b/exo/models.py index 7d6bc30a..b6a7092b 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,7 +4,7 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), + "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), @@ -12,7 +12,7 @@ "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), + "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), @@ -32,10 +32,10 @@ "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), }, "llama-3-2B-Base": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=6), + "TorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=6), }, "llama-3-1B-Base": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), + "TorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, @@ -68,6 +68,6 @@ "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), }, "qwen2-0.5b-instruct": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), + "TorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), }, } From edf1c3d0003b488ab27443fadf217fe554c5faa2 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 10 Oct 2024 13:13:20 -0700 Subject: [PATCH 397/491] add ci jobs for chatgpt_api_integration_test_torch_linux_cpu and chatgpt_api_integration_test_torch_mac --- .circleci/config.yml | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index ba5f5968..92b4d6e1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -178,6 +178,50 @@ jobs: inference_engine: mlx model_id: llama-3.1-8b + chatgpt_api_integration_test_torch_linux_cpu: + machine: + image: ubuntu-2404:2024.08.1 + resource_class: large + steps: + - checkout + - run: + name: Set up Python + command: | + brew install python@3.12 + python3.12 -m venv env + source env/bin/activate + - run: + name: Install dependencies + command: | + source env/bin/activate + pip install --upgrade pip + pip install . + - run_chatgpt_api_test: + inference_engine: torch + model_id: llama-3.2-1b + + chatgpt_api_integration_test_torch_mac: + macos: + xcode: "15.4.0" + resource_class: macos.m1.large.gen1 + steps: + - checkout + - run: + name: Set up Python + command: | + brew install python@3.12 + python3.12 -m venv env + source env/bin/activate + - run: + name: Install dependencies + command: | + source env/bin/activate + pip install --upgrade pip + pip install . + - run_chatgpt_api_test: + inference_engine: torch + model_id: llama-3.2-1b + test_macos_m1: macos: xcode: "15.4.0" From 0fd6711723debb03524b7c0878655e16fb2942b2 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 10 Oct 2024 13:17:49 -0700 Subject: [PATCH 398/491] add ci jobs for chatgpt_api_integration_test_torch_linux_cpu and chatgpt_api_integration_test_torch_mac --- .circleci/config.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 92b4d6e1..c9adbd2c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -260,4 +260,6 @@ workflows: - discovery_integration_test - chatgpt_api_integration_test_mlx - test_macos_m1 + - chatgpt_api_integration_test_torch_linux_cpu + - chatgpt_api_integration_test_torch_mac # - chatgpt_api_integration_test_tinygrad \ No newline at end of file From a4feeab9bd955c178fb75ec9c8a560e9d39a1bae Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 10 Oct 2024 13:26:18 -0700 Subject: [PATCH 399/491] ci filters --- .circleci/config.yml | 66 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c9adbd2c..e9b23b6f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -256,10 +256,62 @@ workflows: version: 2 build_and_test: jobs: - - unit_test - - discovery_integration_test - - chatgpt_api_integration_test_mlx - - test_macos_m1 - - chatgpt_api_integration_test_torch_linux_cpu - - chatgpt_api_integration_test_torch_mac - # - chatgpt_api_integration_test_tinygrad \ No newline at end of file + - approve_run: + type: approval + requires: [] + filters: + branches: + ignore: main + - unit_test: + requires: + - approve_run + - discovery_integration_test: + requires: + - approve_run + - chatgpt_api_integration_test_mlx: + requires: + - approve_run + - test_macos_m1: + requires: + - approve_run + - chatgpt_api_integration_test_torch_linux_cpu: + requires: + - approve_run + - chatgpt_api_integration_test_torch_mac: + requires: + - approve_run + # - chatgpt_api_integration_test_tinygrad: + # requires: + # - approve_run + + # Run jobs without approval on the main branch + main_branch_workflow: + jobs: + - unit_test: + filters: + branches: + only: main + - discovery_integration_test: + filters: + branches: + only: main + - chatgpt_api_integration_test_mlx: + filters: + branches: + only: main + - test_macos_m1: + filters: + branches: + only: main + - chatgpt_api_integration_test_torch_linux_cpu: + filters: + branches: + only: main + - chatgpt_api_integration_test_torch_mac: + filters: + branches: + only: main + # - chatgpt_api_integration_test_tinygrad: + # filters: + # branches: + # only: main \ No newline at end of file From 55fd48247d985a9b18ef3bda992a2d6e293b6330 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 10 Oct 2024 13:34:34 -0700 Subject: [PATCH 400/491] rm comments --- .circleci/config.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e9b23b6f..71fada3d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -280,9 +280,6 @@ workflows: - chatgpt_api_integration_test_torch_mac: requires: - approve_run - # - chatgpt_api_integration_test_tinygrad: - # requires: - # - approve_run # Run jobs without approval on the main branch main_branch_workflow: @@ -311,7 +308,3 @@ workflows: filters: branches: only: main - # - chatgpt_api_integration_test_tinygrad: - # filters: - # branches: - # only: main \ No newline at end of file From da39519fe8735e20286d0dcb72b621811767e3da Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 10 Oct 2024 13:38:32 -0700 Subject: [PATCH 401/491] ci --- .circleci/config.yml | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 71fada3d..b1438335 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -281,7 +281,23 @@ workflows: requires: - approve_run - # Run jobs without approval on the main branch + # Workflow for forked PRs without approval + forked_pr_workflow: + jobs: + - unit_test + - discovery_integration_test + - chatgpt_api_integration_test_mlx + - test_macos_m1 + - chatgpt_api_integration_test_torch_linux_cpu + - chatgpt_api_integration_test_torch_mac + # The trigger condition ensures this workflow runs for forked PRs + triggers: + - type: pull_request + filters: + branches: + ignore: main + + # Existing workflow for main branch main_branch_workflow: jobs: - unit_test: @@ -307,4 +323,4 @@ workflows: - chatgpt_api_integration_test_torch_mac: filters: branches: - only: main + only: main \ No newline at end of file From 5eb6c34fb2e6a4e82ba680b468a662a97f5d9509 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 10 Oct 2024 17:32:51 -0800 Subject: [PATCH 402/491] fixed torch device selection --- exo/inference/pytorch/README.md | 31 +++++++----------------------- exo/inference/pytorch/inference.py | 26 ++++++++++++------------- 2 files changed, 19 insertions(+), 38 deletions(-) diff --git a/exo/inference/pytorch/README.md b/exo/inference/pytorch/README.md index 670c8df6..5cbeeef6 100644 --- a/exo/inference/pytorch/README.md +++ b/exo/inference/pytorch/README.md @@ -1,26 +1,9 @@ # PyTorch & HuggingFace inference engine -Experimental, still under development - -## Install -Install needed py modules, make sure to be using CUDA 12.4 for the PyTorch install - -```console -$ pip install torch --index-url https://download.pytorch.org/whl/cu124 -$ pip install transformers accelerate -``` - -After installing accelerate you get hit with a dependency error, for now ignore until we can fix this as exo works fine with 1.26.4 - -```console -ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. -exo 0.0.1 requires numpy==2.0.0, but you have numpy 1.26.4 which is incompatible. -``` - -## Low VRAM Notes - -- When trying to do disk_offload getting the error "Cannot copy out of meta tensor; no data!", looking up the error it is tied to (low vram)[https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13087#issuecomment-2080272004] - -## Multiple GPU in 1 Notes -### Running multiple GPUs on 1 machine -- Getting error "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)" +## Notes/Issues +### 10/10/2024 +- To select a pytorch device via environment variables, set the variable TORCH_DEVICE +- - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM +- - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) +- - Looking into adding mobile device support properly +- If device is not CPU the data type defaults to float32 else float16. diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index a613015e..04bda300 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -2,12 +2,14 @@ import asyncio import os import re -import numpy as np -import torch import json import functools from concurrent.futures import ThreadPoolExecutor +import numpy as np + +import torch + from typing import Optional, Tuple, Union, List from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine @@ -49,23 +51,19 @@ def __init__(self, shard_downloader: HFShardDownloader): self.past_input_ids = None # setup cuda device - if os.environ.get("PYTORCH_DEVICE"): - pytorch_device = os.environ["PYTOCH_DEVICE"] - if pytorch_device not in ["cuda", "mps"]: - pytorch_device = "cpu" - - self.device = pytorch_device - self.torch_dtype = torch.float16 if pytorch_device != "cpu" else torch.float32 - - if torch.cuda.is_available(): + if os.environ.get("TORCH_DEVICE"): + self.device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): self.device = torch.device("cuda") - self.torch_dtype = torch.float32 elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): self.device = torch.device("mps") - self.torch_dtype = torch.float32 else: self.device = torch.device("cpu") - self.torch_dtype = torch.float16 + + torch.set_default_device(self.device) + + # setup cude dtype + self.torch_dtype = torch.float32 if self.device != torch.device('cpu') else torch.float16 # setup threadding torch.set_num_threads(torch.get_num_threads()) From 18d41ebf79ebd288a4fe3e1d571d420d54ff6a80 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 10 Oct 2024 17:46:12 -0800 Subject: [PATCH 403/491] fixing imports --- exo/inference/pytorch/inference.py | 332 +++++++++++++++++++++++++++++ 1 file changed, 332 insertions(+) create mode 100644 exo/inference/pytorch/inference.py diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py new file mode 100644 index 00000000..0684aa13 --- /dev/null +++ b/exo/inference/pytorch/inference.py @@ -0,0 +1,332 @@ +# experimental, based off of tinygrad/inference.py +import asyncio +import os +import re +import json +import functools +from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Tuple, Union, List + +import numpy as np +import torch + +from exo.inference.shard import Shard +from exo.inference.inference_engine import InferenceEngine +from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel +from exo.inference.tokenizers import resolve_tokenizer +from exo.helpers import DEBUG +from exo.download.hf.hf_shard_download import HFShardDownloader + +from transformers import AutoTokenizer, Cache +# llama +from transformers.models.llama.modeling_llama import LlamaModel + +# model value options +TOP_K = 20 +TEMP = 0.6 +TOP_P = 0.9 +MAX_LENGTH = 125 +MAX_TIME = 60.0 + +class PyTorchDynamicShardInferenceEngine(InferenceEngine): + """ + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. + """ + + def __init__(self, shard_downloader: HFShardDownloader): + """ + Initialize the inference engine. + + Args: + shard_downloader: Model and weights sharding download + """ + self.shard = None + self.shard_downloader = shard_downloader + + # the whole history with new logits need to + # be passed to the model to reach the end token + # even with caching + self.past_input_ids = None + + # setup cuda device + if os.environ.get("TORCH_DEVICE"): + self.device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + torch.set_default_device(self.device) + + # setup cude dtype + self.torch_dtype = torch.float32 if self.device != torch.device('cpu') else torch.float16 + + # setup threadding + torch.set_num_threads(torch.get_num_threads()) + + def infer_caching( + self, + inference_state: Optional[str] = None + ) -> Tuple[Optional[torch.Tensor], Optional[dict]]: + """ + inference caching from inference_state json + """ + # setup cache and cached input_ids + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + cached_iids = infer_state["cached_iids"] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"cached_iids: {cached_iids}") + + return (past_iids, cached_iids) + + async def async_forward( + self, + input_ids: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None + ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: + + loop = asyncio.get_running_loop() + + forward_partial = functools.partial( + self.stateful_sharded_model.forward, + input_ids=input_ids, + hidden_states=hidden_states, + attention_mask=attention_mask + ) + + with ThreadPoolExecutor() as pool: + result = await loop.run_in_executor(pool, forward_partial) + + return result + + async def async_logit_sample( + self, + logits: torch.Tensor + ) -> torch.Tensor: + + loop = asyncio.get_running_loop() + + sample_partial = functools.partial( + self.stateful_sharded_model.logits_sample, + logits=logits + ) + + with ThreadPoolExecutor() as pool: + result = await loop.run_in_executor(pool, sample_partial) + + return result + + async def infer_prompt( + self, + request_id: str, + shard: Shard, + prompt: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + inputs = self.tokenizer([prompt], return_tensors="pt") + input_ids = inputs.input_ids.to(self.device) + input_attention_mask = inputs.attention_mask.to(self.device) + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"past_input_ids: {self.past_input_ids}\n") + + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( + input_ids=self.past_input_ids, + attention_mask=input_attention_mask + ) + + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + next_token = None + if shard_logits is not None: + next_token = await self.async_logit_sample(shard_logits) + self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + input_ids = next_token + + if self.past_input_ids is not None: + cached_iids = {"input_ids": self.past_input_ids.tolist()} + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + input_ids = torch.tensor(input_data).to(self.device) + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + # detect if hidden_states or not + hidden_states = None + self.past_input_ids = None + if input_ids.size()[-1] > 1: + hidden_states = input_ids + self.past_input_ids = past_iids + else: + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"past_input_ids: {self.past_input_ids}") + print(f"hidden_state: {hidden_states}") + print(f"inference_state: {inference_state}") + + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( + input_ids=self.past_input_ids, + hidden_states=hidden_states + ) + + next_token = None + if shard_logits is not None: + next_token = await self.async_logit_sample(shard_logits) + input_ids = next_token + + #cache + next_cached_logits = None + if next_token is not None: + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + elif past_iids is not None: + next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) + + cached_iids = { + "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else [] + } + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if is_finished: + # clear cache + cached_iids = {"input_ids": []} + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def ensure_shard(self, shard: Shard): + """ + Ensure the model shard is loaded and ready for inference. + + Args: + shard (Optional[Shard]): Shard information for the model. + """ + if self.shard == shard: + return + + if DEBUG >= 4: + print(f"Loading new shard: {shard}") + + model_path = await self.shard_downloader.ensure_shard(shard) + + self.stateful_sharded_model = ShardedHuggingFaceModel( + shard=shard, + local_model_path=model_path, + device=self.device, + dtype=self.torch_dtype, + top_k=TOP_K, + temp=TEMP, + top_p=TOP_P, + max_length=MAX_LENGTH, + max_time=MAX_TIME + ) + self.shard = shard + + if isinstance(self.stateful_sharded_model.model, LlamaModel): + self.tokenizer = AutoTokenizer.from_pretrained( + model_path if model_path is not None else shard.model_id, + trust_remote_code=True + ) + + if len(re.findall(r"3\.1", shard.model_id)) > 0: + self.tokenizer.add_special_tokens({"pad_token":""}) + + else: + self.tokenizer = await resolve_tokenizer(shard.model_id) + + if DEBUG >= 4: + print(f"Shard loaded successfully: {shard}") From 9ecbf0c0f138950b5adb0839cb8a5e948dc10897 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 10 Oct 2024 17:48:40 -0800 Subject: [PATCH 404/491] fixing chatgpt_api mistake --- exo/api/chatgpt_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index befc1b43..7b7be502 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -125,7 +125,7 @@ def build_prompt(tokenizer, _messages: List[Message]): continue for content in message.content: - # note: wae only support one image at time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 + # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 # follows the convention in https://platform.openai.com/docs/guides/vision if isinstance(content, dict) and content.get("type", None) == "image": image_str = content.get("image", None) From dae2cbe6b6b1683deac126b58f66813a5b162cd1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 10 Oct 2024 17:54:07 -0800 Subject: [PATCH 405/491] removing old pytorch folder --- exo/inference/pytorch/inference.py | 332 ----------------------------- 1 file changed, 332 deletions(-) delete mode 100644 exo/inference/pytorch/inference.py diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py deleted file mode 100644 index 0684aa13..00000000 --- a/exo/inference/pytorch/inference.py +++ /dev/null @@ -1,332 +0,0 @@ -# experimental, based off of tinygrad/inference.py -import asyncio -import os -import re -import json -import functools -from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Tuple, Union, List - -import numpy as np -import torch - -from exo.inference.shard import Shard -from exo.inference.inference_engine import InferenceEngine -from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel -from exo.inference.tokenizers import resolve_tokenizer -from exo.helpers import DEBUG -from exo.download.hf.hf_shard_download import HFShardDownloader - -from transformers import AutoTokenizer, Cache -# llama -from transformers.models.llama.modeling_llama import LlamaModel - -# model value options -TOP_K = 20 -TEMP = 0.6 -TOP_P = 0.9 -MAX_LENGTH = 125 -MAX_TIME = 60.0 - -class PyTorchDynamicShardInferenceEngine(InferenceEngine): - """ - PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. - """ - - def __init__(self, shard_downloader: HFShardDownloader): - """ - Initialize the inference engine. - - Args: - shard_downloader: Model and weights sharding download - """ - self.shard = None - self.shard_downloader = shard_downloader - - # the whole history with new logits need to - # be passed to the model to reach the end token - # even with caching - self.past_input_ids = None - - # setup cuda device - if os.environ.get("TORCH_DEVICE"): - self.device = torch.device(os.environ["TORCH_DEVICE"]) - elif torch.cuda.is_available(): - self.device = torch.device("cuda") - elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): - self.device = torch.device("mps") - else: - self.device = torch.device("cpu") - - torch.set_default_device(self.device) - - # setup cude dtype - self.torch_dtype = torch.float32 if self.device != torch.device('cpu') else torch.float16 - - # setup threadding - torch.set_num_threads(torch.get_num_threads()) - - def infer_caching( - self, - inference_state: Optional[str] = None - ) -> Tuple[Optional[torch.Tensor], Optional[dict]]: - """ - inference caching from inference_state json - """ - # setup cache and cached input_ids - past_iids = None - cached_iids = None - if inference_state is not None: - try: - infer_state = json.loads(inference_state) - except ValueError: - infer_state = None - - if infer_state is not None: - cached_iids = infer_state["cached_iids"] - if cached_iids is not None: - past_iids = None - if len(cached_iids) > 0: - past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) - cached_iids = {"input_ids": past_iids.tolist()} - - if DEBUG >= 4: - print(f"cached_iids: {cached_iids}") - - return (past_iids, cached_iids) - - async def async_forward( - self, - input_ids: Optional[torch.Tensor] = None, - hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None - ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: - - loop = asyncio.get_running_loop() - - forward_partial = functools.partial( - self.stateful_sharded_model.forward, - input_ids=input_ids, - hidden_states=hidden_states, - attention_mask=attention_mask - ) - - with ThreadPoolExecutor() as pool: - result = await loop.run_in_executor(pool, forward_partial) - - return result - - async def async_logit_sample( - self, - logits: torch.Tensor - ) -> torch.Tensor: - - loop = asyncio.get_running_loop() - - sample_partial = functools.partial( - self.stateful_sharded_model.logits_sample, - logits=logits - ) - - with ThreadPoolExecutor() as pool: - result = await loop.run_in_executor(pool, sample_partial) - - return result - - async def infer_prompt( - self, - request_id: str, - shard: Shard, - prompt: str, - image_str: Optional[str] = None, - inference_state: Optional[str] = None - ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: - print("infer_prompt called") - print(f"prompt: {prompt}") - print(f"shard: {shard}") - print(f"inference_state: {inference_state}") - - await self.ensure_shard(shard) - - inputs = self.tokenizer([prompt], return_tensors="pt") - input_ids = inputs.input_ids.to(self.device) - input_attention_mask = inputs.attention_mask.to(self.device) - - # get cache from inference_state - past_iids, cached_iids = self.infer_caching(inference_state) - - if past_iids is not None: - self.past_input_ids = past_iids - else: - self.past_input_ids = input_ids - - if DEBUG >= 4: - print(f"past_input_ids: {self.past_input_ids}\n") - - shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( - input_ids=self.past_input_ids, - attention_mask=input_attention_mask - ) - - if DEBUG >= 4: - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - next_token = None - if shard_logits is not None: - next_token = await self.async_logit_sample(shard_logits) - self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) - input_ids = next_token - - if self.past_input_ids is not None: - cached_iids = {"input_ids": self.past_input_ids.tolist()} - - is_finished = False - if next_token is not None: - is_finished = next_token.item() == self.tokenizer.eos_token_id - - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), - is_finished - ) - - if DEBUG >= 4: - print(f"return_values: {return_values}") - - return return_values - - async def infer_tensor( - self, - request_id: str, - shard: Shard, - input_data: np.ndarray, - inference_state: Optional[str] = None - ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}") - print(f"shard: {shard}") - print(f"inference_state: {inference_state}") - - await self.ensure_shard(shard) - - input_ids = torch.tensor(input_data).to(self.device) - - # get cache from inference_state - past_iids, cached_iids = self.infer_caching(inference_state) - - # detect if hidden_states or not - hidden_states = None - self.past_input_ids = None - if input_ids.size()[-1] > 1: - hidden_states = input_ids - self.past_input_ids = past_iids - else: - if past_iids is not None: - self.past_input_ids = past_iids - else: - self.past_input_ids = input_ids - - if DEBUG >= 4: - print(f"past_input_ids: {self.past_input_ids}") - print(f"hidden_state: {hidden_states}") - print(f"inference_state: {inference_state}") - - shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( - input_ids=self.past_input_ids, - hidden_states=hidden_states - ) - - next_token = None - if shard_logits is not None: - next_token = await self.async_logit_sample(shard_logits) - input_ids = next_token - - #cache - next_cached_logits = None - if next_token is not None: - if self.past_input_ids is not None: - next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) - elif past_iids is not None: - next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) - - cached_iids = { - "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else [] - } - - is_finished = False - if next_token is not None: - is_finished = next_token.item() == self.tokenizer.eos_token_id - - if is_finished: - # clear cache - cached_iids = {"input_ids": []} - - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), - is_finished - ) - - if DEBUG >= 4: - print(f"return_values: {return_values}") - - return return_values - - async def ensure_shard(self, shard: Shard): - """ - Ensure the model shard is loaded and ready for inference. - - Args: - shard (Optional[Shard]): Shard information for the model. - """ - if self.shard == shard: - return - - if DEBUG >= 4: - print(f"Loading new shard: {shard}") - - model_path = await self.shard_downloader.ensure_shard(shard) - - self.stateful_sharded_model = ShardedHuggingFaceModel( - shard=shard, - local_model_path=model_path, - device=self.device, - dtype=self.torch_dtype, - top_k=TOP_K, - temp=TEMP, - top_p=TOP_P, - max_length=MAX_LENGTH, - max_time=MAX_TIME - ) - self.shard = shard - - if isinstance(self.stateful_sharded_model.model, LlamaModel): - self.tokenizer = AutoTokenizer.from_pretrained( - model_path if model_path is not None else shard.model_id, - trust_remote_code=True - ) - - if len(re.findall(r"3\.1", shard.model_id)) > 0: - self.tokenizer.add_special_tokens({"pad_token":""}) - - else: - self.tokenizer = await resolve_tokenizer(shard.model_id) - - if DEBUG >= 4: - print(f"Shard loaded successfully: {shard}") From 55ae0271d995e073cc7df10b47075289fd607ac0 Mon Sep 17 00:00:00 2001 From: Vincent C Date: Thu, 10 Oct 2024 17:56:33 -0800 Subject: [PATCH 406/491] Update README.md cleaning up readme --- exo/inference/torch/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md index 5cbeeef6..59e73f7c 100644 --- a/exo/inference/torch/README.md +++ b/exo/inference/torch/README.md @@ -3,7 +3,7 @@ ## Notes/Issues ### 10/10/2024 - To select a pytorch device via environment variables, set the variable TORCH_DEVICE -- - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM -- - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) -- - Looking into adding mobile device support properly + - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM + - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) + - Looking into adding mobile device support properly - If device is not CPU the data type defaults to float32 else float16. From 4b6a86d8f9bcce40df4177bf188e6c195375af6e Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Fri, 11 Oct 2024 16:38:45 -0700 Subject: [PATCH 407/491] set all torch models in models.py --- exo/models.py | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/exo/models.py b/exo/models.py index b6a7092b..3cefe889 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,32 +4,39 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), - "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-3B-Instruct", start_layer=0, end_layer=0, n_layers=28), }, "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), - "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), }, "llama-3.1-70b-bf16": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80), "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3.1-405b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), }, - "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),}, "llama-3-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32), + "TorchDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), }, "llama-3-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), + "TorchDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3-70B-Instruct", start_layer=0, end_layer=0, n_layers=80), }, "llama-3-2B-Base": { "TorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=6), @@ -38,34 +45,55 @@ "TorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, ### mistral - "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, - "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, + "mistral-nemo": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Mistral-Nemo-Instruct-2407", start_layer=0, end_layer=0, n_layers=40), + }, + "mistral-large": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), + "TorchDynamicShardInferenceEngine": Shard(model_id="mistralai/Mistral-Large-Instruct-2407", start_layer=0, end_layer=0, n_layers=88), + }, ### deepseek - "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),}, - "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),}, + "deepseek-coder-v2-lite": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), + "TorchDynamicShardInferenceEngine": Shard(model_id="deepseek-ai/DeepSeek-V2-Lite", start_layer=0, end_layer=0, n_layers=27), + }, + "deepseek-coder-v2.5": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60), + "TorchDynamicShardInferenceEngine": Shard(model_id="deepseek-ai/DeepSeek-V2.5", start_layer=0, end_layer=0, n_layers=60), + }, ### llava - "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, + "llava-1.5-7b-hf": { + "MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), + }, ### qwen "qwen-2.5-coder-1.5b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Coder-1.5B-Instruct", start_layer=0, end_layer=0, n_layers=28), }, "qwen-2.5-coder-7b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Coder-7B-Instruct", start_layer=0, end_layer=0, n_layers=28), }, "qwen-2.5-7b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-7B-Instruct", start_layer=0, end_layer=0, n_layers=28), }, "qwen-2.5-math-7b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Math-7B-Instruct", start_layer=0, end_layer=0, n_layers=28), }, "qwen-2.5-14b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-14B-Instruct", start_layer=0, end_layer=0, n_layers=48), }, "qwen-2.5-72b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-72B-Instruct", start_layer=0, end_layer=0, n_layers=80), }, "qwen-2.5-math-72b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Math-72B-Instruct", start_layer=0, end_layer=0, n_layers=80), }, "qwen2-0.5b-instruct": { "TorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), From 830d33d5e17ebdcd41167dc40d8d1c02b104bcca Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Fri, 11 Oct 2024 16:39:12 -0700 Subject: [PATCH 408/491] in torch, explicitly set the device when initilaizing the model --- exo/inference/torch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index d6038f04..7e154aeb 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -59,7 +59,7 @@ def __init__( self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.torch_dtype, - device_map="auto", + device_map={"": self.device}, offload_buffers=True ) From 074dfe3dc965db3838f89ca1d05cb486eddf02ed Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Fri, 11 Oct 2024 16:39:16 -0700 Subject: [PATCH 409/491] spacing --- exo/inference/torch/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index c5eddabd..bdb71f64 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -114,7 +114,7 @@ async def async_forward( ) with ThreadPoolExecutor() as pool: - result = await loop.run_in_executor(pool, forward_partial) + result = await loop.run_in_executor(pool, forward_partial) return result From d9cfcc4c20555981ae88188a42270604240777eb Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Fri, 11 Oct 2024 16:53:13 -0700 Subject: [PATCH 410/491] add model mlx-community/Qwen2-0.5B-Instruct-4bit --- exo/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/models.py b/exo/models.py index 3cefe889..fe608f05 100644 --- a/exo/models.py +++ b/exo/models.py @@ -96,6 +96,7 @@ "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Math-72B-Instruct", start_layer=0, end_layer=0, n_layers=80), }, "qwen2-0.5b-instruct": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2-0.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=24), "TorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), }, } From 2c056b4fc71e7b57ec0c615364bd35e81f406a66 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 12 Oct 2024 02:56:16 -0800 Subject: [PATCH 411/491] code changes from PR feedback, working on splitting of weights --- exo/inference/torch/inference.py | 123 +++--- exo/inference/torch/model/hf.py | 56 +-- .../torch/tests/test_inference_engine.py | 13 +- .../torch/tests/test_simple_model.py | 13 +- exo/inference/torch/tests/test_split_model.py | 378 ------------------ exo/models.py | 4 +- exo/networking/grpc/grpc_peer_handle.py | 4 +- setup.py | 2 +- 8 files changed, 125 insertions(+), 468 deletions(-) delete mode 100644 exo/inference/torch/tests/test_split_model.py diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index c5eddabd..d3f4e853 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -1,7 +1,6 @@ # experimental, based off of tinygrad/inference.py import asyncio import os -import re import json import functools from concurrent.futures import ThreadPoolExecutor @@ -17,18 +16,14 @@ from exo.inference.tokenizers import resolve_tokenizer from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.hf.hf_helpers import get_weight_map -from transformers import AutoTokenizer, Cache - -# llama -from transformers.models.llama.modeling_llama import LlamaModel +from transformers import Cache # model value options TOP_K = 20 TEMP = 0.6 TOP_P = 0.9 -MAX_LENGTH = 125 -MAX_TIME = 60.0 class TorchDynamicShardInferenceEngine(InferenceEngine): """ @@ -63,7 +58,7 @@ def __init__(self, shard_downloader: HFShardDownloader): torch.set_default_device(self.device) # setup cude dtype - self.torch_dtype = torch.float32 if self.device != torch.device('cpu') else torch.float16 + self.dtype = torch.get_default_dtype() # setup threadding torch.set_num_threads(torch.get_num_threads()) @@ -103,18 +98,30 @@ async def async_forward( hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: + """ + Asynchronously performs the forward pass using a stateful sharded model. - loop = asyncio.get_running_loop() + Args: + input_ids (torch.Tensor, optional): Input token IDs for the model. If not provided, `hidden_states` must be used. + hidden_states (torch.Tensor, optional): Precomputed hidden states to be used instead of `input_ids`. + attention_mask (torch.Tensor, optional): Mask to prevent attention on padding token indices. - forward_partial = functools.partial( - self.stateful_sharded_model.forward, - input_ids=input_ids, - hidden_states=hidden_states, - attention_mask=attention_mask - ) + Returns: + A tuple containing: + + - shard_hidden_states (torch.Tensor, optional): Hidden states resulting from the forward pass. + - shard_past_kvs (list(torch.FloatTensor), optional): List of past key-value tensors (cache) used in the model. + - shard_logits (torch.Tensor, optional): The logits computed during the forward pass. + """ + loop = asyncio.get_running_loop() with ThreadPoolExecutor() as pool: - result = await loop.run_in_executor(pool, forward_partial) + result = await loop.run_in_executor(pool, functools.partial( + self.stateful_sharded_model.forward, + input_ids=input_ids, + hidden_states=hidden_states, + attention_mask=attention_mask + )) return result @@ -122,16 +129,22 @@ async def async_logit_sample( self, logits: torch.Tensor ) -> torch.Tensor: + """ + Asynchronously samples logits using the model's logit sampling method. - loop = asyncio.get_running_loop() + Args: + logits (torch.Tensor): The logits produced by the model for sampling. - sample_partial = functools.partial( - self.stateful_sharded_model.logits_sample, - logits=logits - ) + Returns: + next_logit (torch.Tensor): The next logit samples from given logis + """ + loop = asyncio.get_running_loop() with ThreadPoolExecutor() as pool: - result = await loop.run_in_executor(pool, sample_partial) + result = await loop.run_in_executor(pool, functools.partial( + self.stateful_sharded_model.logits_sample, + logits=logits + )) return result @@ -143,6 +156,23 @@ async def infer_prompt( image_str: Optional[str] = None, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + """ + Asynchronously processes a prompt using the specified shard and returns the inference result. + + Args: + request_id (str): The unique identifier for the request. + shard (Shard): The model shard used for inference. + prompt (str): The text prompt to be processed by the model. + image_str (str, optional): A base64 encoded image string to be optionally used in the inference. Defaults to None. + inference_state (str, optional): The cached inference state for resuming or continuing inference. Defaults to None. + + Returns: + A tuple containing: + + - input_ids (np.ndarray): The processed token IDs as a NumPy array if logits were generated. Otherwise, it returns hidden states. + - cache_json (str): A JSON string containing the cached input IDs for further inference steps. + - is_finished (bool): A boolean indicating whether the model has reached the end-of-sequence (EOS) token. + """ if DEBUG >= 4: print("infer_prompt called") print(f"prompt: {prompt}") @@ -182,6 +212,9 @@ async def infer_prompt( self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) input_ids = next_token + if DEBUG >= 4: + print(f"\nnext_token: {next_token}") + if self.past_input_ids is not None: cached_iids = {"input_ids": self.past_input_ids.tolist()} @@ -189,12 +222,6 @@ async def infer_prompt( if next_token is not None: is_finished = next_token.item() == self.tokenizer.eos_token_id - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - return_values = ( input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), json.dumps({"cached_iids": cached_iids}), @@ -213,6 +240,22 @@ async def infer_tensor( input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: + """ + Asynchronously processes input tensor data using the specified shard and returns the inference result. + + Args: + request_id (str): The unique identifier for the request. + shard (Shard): The model shard used for inference. + input_data (np.ndarray): The input data in NumPy array format to be processed by the model. + inference_state (str, optional): The cached inference state for resuming or continuing inference. Defaults to None. + + Returns: + A tuple containing: + + - input_ids (np.ndarray): The processed token IDs as a NumPy array if logits were generated. Otherwise, it returns hidden states. + - cache_json (str): A JSON string containing the cached input IDs for further inference steps. + - is_finished (bool): A boolean indicating whether the model has reached the end-of-sequence (EOS) token. + """ if DEBUG >= 4: print("infer_tensor called") print(f"input_data: {input_data}") @@ -239,9 +282,9 @@ async def infer_tensor( self.past_input_ids = input_ids if DEBUG >= 4: - print(f"past_input_ids: {self.past_input_ids}") - print(f"hidden_state: {hidden_states}") - print(f"inference_state: {inference_state}") + print(f"\npast_input_ids: {self.past_input_ids}") + print(f"\nhidden_state: {hidden_states}") + print(f"\ninference_state: {inference_state}") shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( input_ids=self.past_input_ids, @@ -309,26 +352,14 @@ async def ensure_shard(self, shard: Shard): shard=shard, local_model_path=model_path, device=self.device, - dtype=self.torch_dtype, + dtype=self.dtype, top_k=TOP_K, temp=TEMP, - top_p=TOP_P, - max_length=MAX_LENGTH, - max_time=MAX_TIME + top_p=TOP_P ) self.shard = shard - if isinstance(self.stateful_sharded_model.model, LlamaModel): - self.tokenizer = AutoTokenizer.from_pretrained( - model_path if model_path is not None else shard.model_id, - trust_remote_code=True - ) - - if len(re.findall(r"3\.1", shard.model_id)) > 0: - self.tokenizer.add_special_tokens({"pad_token":""}) - - else: - self.tokenizer = await resolve_tokenizer(shard.model_id) + self.tokenizer = await resolve_tokenizer(shard.model_id) if DEBUG >= 4: print(f"Shard loaded successfully: {shard}") diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index d6038f04..254c1dd3 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -27,10 +27,21 @@ def __init__( dtype, top_k: int = 25, temp: float = 0.7, - top_p: float = 0.9, - max_length: int = 50, - max_time: float = 10.0 + top_p: float = 0.9 ): + """ + Initializes the ShardedHuggingFaceModel with a specified shard, model path, and device. + + Args: + shard (Shard): The model shard containing the start and end layers. + local_model_path (str): The local path to the model. + device (str): The device on which to run the model, e.g., "cuda" or "cpu". + dtype (torch.dtype): The data type (precision) to be used for model computations. + top_k (int, optional): The number of top tokens to consider for sampling. Defaults to 25. + temp (float, optional): The temperature for softmax sampling. Defaults to 0.7. + top_p (float, optional): The cumulative probability threshold for nucleus sampling. Defaults to 0.9. + """ + # class vars self.shard = shard self.hidden_states = None @@ -52,14 +63,14 @@ def __init__( ]) self.device = device - self.torch_dtype = dtype + self.dtype = dtype # setup pytorch and transformer llm try: self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, - torch_dtype=self.torch_dtype, - device_map="auto", + torch_dtype=self.dtype, + device_map={"", self.device}, offload_buffers=True ) @@ -77,25 +88,20 @@ def forward( use_legacy_cache: bool = False ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]: """ - Generate hidden states or logits via passing through set amount of layers of a model - To be passed only input_ids OR hidden_state and not both. This is for connecting the model - layer to generate a complete output + Performs a forward pass through the model shard, computing hidden states, past key values, and logits. Args: - model: base llm model tramsformers class - llm_model: llm chat model class - input_ids: tensor optional - attention_mask: tensor optional - past_key_values: Cache or list[tensor] optional - use_legacy_cache: bool optional - infer_tensor: bool optional, lets forward know to handle tensors + input_ids (torch.Tensor, optional): The input token IDs for the model. Either input_ids or hidden_states must be provided. + hidden_states (torch.Tensor, optional): The hidden states of the model at the current layer. + attention_mask (torch.Tensor, optional): The attention mask to prevent attending to padding tokens. + past_key_values (Union[Cache, List[torch.FloatTensor]], optional): Cached past key values for fast autoregressive generation. + use_legacy_cache (bool, optional): Whether to use the legacy cache format for past key values. Defaults to False. Returns: - Tuple of - - hidden_states: tensor optional - - past_key_values: Cache or list[tensor] optional - - logits: tensor Optional - + Tuple: + - hidden_states (torch.Tensor, optional): The hidden states after the forward pass. + - past_key_values (Union[Cache, List[torch.FloatTensor]], optional): The updated past key values. + - logits (torch.Tensor, optional): The logits produced by the model if the last layer is processed. """ model_inputs = None self.hidden_states = hidden_states @@ -246,14 +252,14 @@ def logits_sample( use_max: Optional[bool] = False ) -> torch.Tensor: """ - Get a sample of the logits from end of model run for next token + Samples the next token from the model's output logits, either by using argmax or probabilistic sampling. Args: - logits: tensor - use_max: bool, if function should sample with argmax + logits (torch.Tensor): The logits output from the model's final layer. + use_max (bool, optional): If True, uses torch.argmax to select the next token from logits. Defaults to False. Returns: - next_token: tensor + torch.Tensor: The next predicted token. """ # get a single cloned logit diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index e8b0b14c..b326a0b6 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -1,12 +1,12 @@ +""" +Test inference engine and model sharding +""" import asyncio from exo.inference.shard import Shard from exo.inference.torch.inference import TorchDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine -from exo.inference.shard import Shard -from exo.helpers import DEBUG -import os import numpy as np import time @@ -15,8 +15,7 @@ async def test_inference_engine( inference_engine_2: InferenceEngine, model_id: str, n_layers: int): - - # prompt = "Why is the sky blue?" + prompt = "In a single word only, what is the last name of the current president of the USA?" shard = Shard( @@ -129,11 +128,11 @@ async def test_inference_engine( # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") try: - print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") + print("\n-------- Test unsloth/Llama-3.2-1B-Instruct ----------\n") asyncio.run(test_inference_engine( TorchDynamicShardInferenceEngine(HFShardDownloader()), TorchDynamicShardInferenceEngine(HFShardDownloader()), - "meta-llama/Llama-3.2-1B-Instruct", + "unsloth/Llama-3.2-1B-Instruct", 24 )) except Exception as err: diff --git a/exo/inference/torch/tests/test_simple_model.py b/exo/inference/torch/tests/test_simple_model.py index 1b08a180..2a36717f 100644 --- a/exo/inference/torch/tests/test_simple_model.py +++ b/exo/inference/torch/tests/test_simple_model.py @@ -1,5 +1,8 @@ +""" +Simple model test using basic pytorch/huggingface LLM model loading, inference and generation +with logit sampling +""" from transformers import AutoModelForCausalLM, AutoTokenizer -device = "cuda" # the device to load the model onto model = AutoModelForCausalLM.from_pretrained( "Qwen/Qwen2-0.5B-Instruct", @@ -19,7 +22,7 @@ tokenize=False, add_generation_prompt=True ) -model_inputs = tokenizer([text], return_tensors="pt").to(device) +model_inputs = tokenizer([text], return_tensors="pt") print(f"model_inputs:\n{model_inputs}") @@ -29,11 +32,9 @@ model_inputs.input_ids, attention_mask=model_inputs.attention_mask, max_new_tokens=512, - do_sample=True, - #top_k=20, - #num_beams=5, - #early_stopping=True + do_sample=True ) + generated_ids = [ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) ] diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py deleted file mode 100644 index 157a215d..00000000 --- a/exo/inference/torch/tests/test_split_model.py +++ /dev/null @@ -1,378 +0,0 @@ -import torch -import torch.nn as nn -import asyncio -import gc -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - DynamicCache, - Cache, - LogitsProcessorList, - TopKLogitsWarper, - TopPLogitsWarper, - TemperatureLogitsWarper, - StoppingCriteriaList, - MaxLengthCriteria, - MaxTimeCriteria -) - -from transformers.generation.configuration_utils import ( - GenerationConfig, - GenerationMode -) - -# llama -from transformers.models.llama.modeling_llama import LlamaModel - -# qwen2 -from transformers.models.qwen2.modeling_qwen2 import Qwen2Model - -from exo.api.chatgpt_api import resolve_tokenizer -from typing import Tuple, Optional, Union, List -import re - -TEMP = 0.6 -TOP_K = 60 - -class OnionHuggingFaceLM(): - def __init__(self, layers, is_last=False): - self.layers = layers - self.is_last = is_last - self.past_key_values = None - self.cache_position = None - self.position_ids = None - self.input_embed = None - self.causal_mask = None - self.position_embeddings = None - self.attention_mask = None - self.input_ids = None - self.hidden_states = None - self.next_decoder_cache = None - - def forward( - self, - model, - llm_model, - input_ids: Optional[torch.tensor] = None, - hidden_states: Optional[torch.tensor] = None, - attention_mask: Optional[torch.tensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - **kwargs - ) -> Tuple[Optional[torch.tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.tensor]]: - - """ - Generate hidden states or logits via passing through set amount of layers of a model - To be passed only input_ids OR hidden_state and not both. This is for connecting the model - layer to generate a complete output - - Args: - model: base llm model tramsformers class - llm_model: llm chat model class - input_ids: tensor Optional - hidden_states: tensor Optional - - Returns: - Tuple of - - hidden_states: tensor Optional - - past_key_values - - logits: tensor Optional - - """ - output_attentions = False # outputting attention not needed - use_legacy_cache = False # some models still use legacy kv store - - if input_ids is not None and hidden_states is not None: - raise ValueError - - if hidden_states is not None: - self.hidden_states = hidden_states - - if input_ids is not None: - self.input_ids = input_ids - - # embed input_ids - self.inputs_embeds = model.embed_tokens(self.input_ids) - - # cache - if past_key_values and not isinstance(past_key_values, Cache): - print("Using legacy cache") - use_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + self.inputs_embeds.shape[1], - device=self.inputs_embeds.device - ) - - # position id - position_ids = cache_position.unsqueeze(0) - - # causal mask - self.attention_mask = attention_mask - self.causal_mask = model._update_causal_mask( - None, - self.inputs_embeds, - cache_position, - past_key_values, - output_attentions - ) - - #print(f"causal_mask.dim(): {self.causal_mask.dim()}") - - print(f"\ncausal_mask:{self.causal_mask}\n\n") - - # embed positions, some models require and some dont - if isinstance(model, LlamaModel): - self.position_embeddings = model.rotary_emb( - self.inputs_embeds, - position_ids - ) - - model_inputs = llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=position_ids, - cache_position=cache_position - ) - - print(f"model_inputs\n{model_inputs}") - - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] - - - for decoder_layer in self.layers: - layer_outputs = decoder_layer( - self.hidden_states, - attention_mask=self.causal_mask, - position_ids=self.position_ids, - past_key_values=self.past_key_values, - use_cache=True, - cache_position=self.cache_position - - ) - - self.hidden_states = layer_outputs[0] - self.next_decoder_cache = layer_outputs[1] - - if self.is_last: - self.hidden_states = model.norm(self.hidden_states) - - if use_legacy_cache: - self.past_key_values = self.next_decoder_cache.to_legacy_cache() - else: - self.past_key_values = self.next_decoder_cache - - # lm_head - logits = llm_model.lm_head(self.hidden_states).to("cuda") - - return ( - None, - None, - logits - ) - - return ( - self.hidden_states, - self.past_key_values, - None - ) - -async def model_half_split_test(prompt: str, model_id: str, layers: int): - """ - Test for splitting in half - """ - - half_layers = int(layers / 2) - - # inference - tokenizer = AutoTokenizer.from_pretrained(model_id) - max_length = 512 #tokenizer.model_max_length - - # get llm model - llm_model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype="auto", - device_map="auto", - use_cache=True - ) - - # get base model - model = llm_model.model - - # add pad token if none, depending on model - if tokenizer.pad_token == None: - if re.match(r"Llama|llama", model_id): - tokenizer.add_special_tokens({"pad_token":""}) - model.resize_token_embeddings(len(tokenizer)) - - - # generate input_ids - messages = [{"role": "user", "content": prompt}] - txt = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - - inputs = tokenizer([txt], return_tensors="pt") - input_ids = inputs.input_ids.to("cuda") - input_attention_mask = inputs.attention_mask.to("cuda") - batch_size, seq_length = input_ids.shape[:2] - - is_finished = False - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - logit_runs = 1 - - raw_logits = None - - while not is_finished: - print(f"\n\nLOGIT RUN {logit_runs}\n\n") - - print(f"input_ids:\n{input_ids}\n") - print(input_ids.shape) - - print("\n first half of layers") - shard_layers = nn.ModuleList(model.layers[:half_layers])#.to("cuda") - #shard_layers = nn.ModuleList(model.layers) - sharded_model = OnionHuggingFaceLM(layers=shard_layers) - #sharded_model.is_last = True - - # generate first half - # add if first layer of model check - shard_hidden_states, shard_past_kvs, shard_logits = sharded_model.forward( - model=model, - llm_model=llm_model, - attention_mask=input_attention_mask, - input_ids=input_ids, - hidden_states=None - ) - - # second half - print(f"\n second half of layers") - sharded_model.layers = nn.ModuleList(model.layers[half_layers:]) - sharded_model.is_last = True - - shard_hidden_states, shard_past_kvs, shard_logits = sharded_model.forward( - model=model, - llm_model=llm_model, - hidden_states=shard_hidden_states, - past_key_values=shard_past_kvs - ) - - # this part of the generation and _sample functions for transformers GenerationMixin - # ref: https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/generation/utils.py#L1301 - - # clone logit sample - logits = shard_logits[:, -1, :].clone().float() - - raw_logits = logits - - # distribute - logits_processor = LogitsProcessorList([ - TopKLogitsWarper(35), - TemperatureLogitsWarper(0.6), - TopPLogitsWarper(0.8) - ]) - - stopping_critera = StoppingCriteriaList( - [ - MaxLengthCriteria(max_length=255), - MaxTimeCriteria(max_time=100.0), - ] - ) - - next_token_scores = logits_processor(input_ids, logits) - - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - #next_tokens = torch.argmax(next_token_scores, dim=-1) - - # get inputs ready incase not finished - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - - unfinished_sequences = unfinished_sequences & ~stopping_critera(input_ids, None) - is_finished = unfinished_sequences.max() == 0 - - print(f"is_finished?:\n{is_finished}\n") - - logit_runs += 1 - - del logits - del shard_logits - - print(f"model.generation_config\n{llm_model.generation_config}") - - generated_text = tokenizer.batch_decode( - input_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False - )[0] - - print(f"generated_text:\n{generated_text}\n") - - # free model from memory - del model - gc.collect() - torch.cuda.empty_cache() - - -if __name__ == "__main__": - #prompt = "In a single word only, what is the last name of the current president of the USA?" - prompt = "What color is the sky? Explain why" - #prompt = "In a single word only, what is the color of an apple?" - - #print("\n-------- Test TinyLlama/TinyLlama_v1.1 ----------\n") - #model_id = "TinyLlama/TinyLlama_v1.1" - #model_layers = 22 - - #asyncio.run( - # model_half_split_test( - # prompt=prompt, - # model_id=model_id, - # layers=model_layers - # ) - #) - - #print("\n-------- Test meta-llama/Meta-Llama-3.1-8B ----------\n") - #model_id = "meta-llama/Meta-Llama-3.1-8B" - #model_layers = 32 - - #asyncio.run( - # model_half_split_test( - # prompt=prompt, - # model_id=model_id, - # layers=model_layers - # ) - #) - - #print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") - #model_id = "Qwen/Qwen2-0.5B-Instruct" - #model_layers = 24 - - #asyncio.run( - # model_half_split_test( - # prompt=prompt, - # model_id=model_id, - # layers=model_layers - # ) - #) - - print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") - model_id = "meta-llama/Llama-3.2-1B-Instruct" - model_layers = 32 - - asyncio.run( - model_half_split_test( - prompt=prompt, - model_id=model_id, - layers=model_layers - ) - ) - diff --git a/exo/models.py b/exo/models.py index b6a7092b..d5d69164 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,7 +4,7 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), - "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), @@ -12,7 +12,7 @@ "llama-3.1-8b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), - "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Meta-Llama-3.1-8B-Instruct", start_layer=0, end_layer=0, n_layers=32), }, "llama-3.1-70b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 6e8e586e..14a01f7d 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -12,8 +12,6 @@ from exo.topology.device_capabilities import DeviceCapabilities from exo.helpers import DEBUG -from exo.helpers import DEBUG - class GRPCPeerHandle(PeerHandle): def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities): self._id = _id @@ -78,7 +76,6 @@ async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] inference_state=inference_state, ) - print(f"request: {request}") response = await self.stub.SendPrompt(request) if not response.tensor_data or not response.shape or not response.dtype: @@ -98,6 +95,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Option request_id=request_id, inference_state=inference_state, ) + response = await self.stub.SendTensor(request) if not response.tensor_data or not response.shape or not response.dtype: diff --git a/setup.py b/setup.py index 432a1c3d..5a4e04bc 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad", "torch==2.4.0", - "accelerate" + "accelerate==0.34.2" ] # Add macOS-specific packages if on Darwin (macOS) From 83a723b9e6b156d54a9906f655ffe8d041062108 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 12 Oct 2024 21:09:06 -0800 Subject: [PATCH 412/491] doing more work toward individual safetensor loading, adding back device mapping auto --- exo/inference/torch/inference.py | 5 ++ exo/inference/torch/model/hf.py | 2 +- .../torch/tests/test_inference_engine.py | 18 ++++- exo/inference/torch/tests/test_split_model.py | 71 +++++++++++++++++++ 4 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 exo/inference/torch/tests/test_split_model.py diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index d3f4e853..52db0c0a 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -348,6 +348,11 @@ async def ensure_shard(self, shard: Shard): model_path = await self.shard_downloader.ensure_shard(shard) + # get model weight map + model_wm = await get_weight_map(repo_id=shard.model_id) + + print(f"model_wm: {model_wm}") + self.stateful_sharded_model = ShardedHuggingFaceModel( shard=shard, local_model_path=model_path, diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 254c1dd3..9d524de0 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -70,7 +70,7 @@ def __init__( self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, - device_map={"", self.device}, + device_map="auto", offload_buffers=True ) diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index b326a0b6..a03c5c9a 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -127,14 +127,26 @@ async def test_inference_engine( # except Exception as err: # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + #try: + # print("\n-------- Test unsloth/Llama-3.2-1B-Instruct ----------\n") + # asyncio.run(test_inference_engine( + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # "unsloth/Llama-3.2-1B-Instruct", + # 24 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + try: - print("\n-------- Test unsloth/Llama-3.2-1B-Instruct ----------\n") + print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") asyncio.run(test_inference_engine( TorchDynamicShardInferenceEngine(HFShardDownloader()), TorchDynamicShardInferenceEngine(HFShardDownloader()), - "unsloth/Llama-3.2-1B-Instruct", - 24 + "unsloth/Meta-Llama-3.1-8B-Instruct", + 32 )) except Exception as err: print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py new file mode 100644 index 00000000..95f0694d --- /dev/null +++ b/exo/inference/torch/tests/test_split_model.py @@ -0,0 +1,71 @@ +""" +Testing of loading model by layer +""" +import asyncio +import re + +from exo.download.hf.hf_helpers import get_weight_map +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard + +from typing import Optional, Union, Tuple + +from transformers import AutoModel + +async def load_model( + repo_id: str, + shard: Shard +) -> Optional[AutoModel]: + """ + load model by layer and safetensors + """ + + shard_downloader = HFShardDownloader() + model_path = await shard_downloader.ensure_shard(shard) + weight_map = await get_weight_map(repo_id) + + if weight_map: + for wname, wtensor in weight_map.items(): + # get layer number + layer_rgx = r'^model\.layers\.(\d+)\.(\w+)\.(\w+)$' + layer_found = re.findall(layer_rgx, wname) + if layer_found: + try: + layer_idx = int(layer_found[0][0]) + print(f"layer_idx: {layer_idx}") + if shard.start_layer <= layer_idx <= shard.end_layer: + print(f"wtensor: {wtensor}") + + # move to local .tmp folder that can be removed later + # check if files not already there, if there, reuse + # create automodel with rest of layers + # lm_head needed at end + except Exception as err: + print(f"err: {err}") + +async def test_split_model(model_id: str, n_layers: int): + """ + Test to load split models + """ + + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers + ) + + await load_model( + model_id, + shard + ) + +if __name__ == "__main__": + try: + print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + asyncio.run(test_split_model( + "unsloth/Meta-Llama-3.1-8B-Instruct", + 32 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") From 47be250dbe743d3624f3ae2195ed50e9a2704373 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 13 Oct 2024 04:58:25 -0800 Subject: [PATCH 413/491] working on split model, moving to server for more vram --- exo/inference/torch/tests/test_split_model.py | 145 ++++++++++++++---- 1 file changed, 113 insertions(+), 32 deletions(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 95f0694d..df0d5497 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -3,69 +3,150 @@ """ import asyncio import re +import json +import os +from pathlib import Path +from typing import Optional -from exo.download.hf.hf_helpers import get_weight_map +from exo.download.hf.hf_helpers import ( + get_weight_map, + download_repo_files +) from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.shard import Shard -from typing import Optional, Union, Tuple - -from transformers import AutoModel +from transformers import AutoModelForCausalLM async def load_model( repo_id: str, - shard: Shard -) -> Optional[AutoModel]: + shard: Shard, + model_path: Path, + weight_map: Optional[dict] +) -> Optional[AutoModelForCausalLM]: """ load model by layer and safetensors + return causal llm automodel with only requested layers, if weight maps + if no weight map, return and load the whole model """ - - shard_downloader = HFShardDownloader() - model_path = await shard_downloader.ensure_shard(shard) - weight_map = await get_weight_map(repo_id) - + print("load_model called") if weight_map: + layer_weight_map = {} + skip_layers = [] + for wname, wtensor in weight_map.items(): # get layer number - layer_rgx = r'^model\.layers\.(\d+)\.(\w+)\.(\w+)$' + layer_rgx = r'^model\.layers\.(\d+)\.*' layer_found = re.findall(layer_rgx, wname) + print(f"wname: {wname}") if layer_found: - try: - layer_idx = int(layer_found[0][0]) - print(f"layer_idx: {layer_idx}") - if shard.start_layer <= layer_idx <= shard.end_layer: - print(f"wtensor: {wtensor}") - - # move to local .tmp folder that can be removed later - # check if files not already there, if there, reuse - # create automodel with rest of layers - # lm_head needed at end - except Exception as err: - print(f"err: {err}") - -async def test_split_model(model_id: str, n_layers: int): + print(f"layer_found: {layer_found}") + # slice up layer map to start and end layers + # from shard + layer_idx = int(layer_found[0]) + if shard.start_layer <= layer_idx <= shard.end_layer: + layer_weight_map[wname] = wtensor + else: + skip_layers.append(wname) + print(f"SKIPPING LAYER {layer_idx}") + + if wname not in skip_layers: + print(f"adding non-layer: {wname}") + layer_weight_map[wname] = wtensor + + # will manipulate current model.safetensors.index.json + # but set back at end of inference + print(layer_weight_map) + + # rewrite model.safetensors.index.json + try: + model_st_snapshot = model_path/"model.safetensors.index.json" + # call download repo files again to reload original safetensors json + os.remove(model_st_snapshot) + + await download_repo_files( + repo_id=shard.model_id, + revision="main", + allow_patterns="model.safetensors.index.json") + + mst_json = {} + with open(model_st_snapshot, "r") as mst_file: + mst_json = json.load(mst_file) + + mst_json["weight_map"] = layer_weight_map + + print(f"mst_json: {json.dumps(mst_json, indent=4)}") + + with open(model_st_snapshot, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) + print(f"{model_st_snapshot} rewritten with {shard.n_layers} weights") + except Exception as err: + print(f"err: {err}") + raise + + else: + print("weight_map not found, loading whole model") + + # load model with layer edits + # or whole model if no weight_map + shard_model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map="auto", + offload_buffers=True + ) + + return shard_model + + +async def test_split_model( + model_id: str, + start_layer: int, + end_layer: int, + n_layers: int +): """ Test to load split models """ shard = Shard( model_id=model_id, - start_layer=0, - end_layer=n_layers-1, + start_layer=start_layer, + end_layer=end_layer-1, n_layers=n_layers ) + print(f"loading shard: {shard}") + shard_downloader = HFShardDownloader() + model_path = await shard_downloader.ensure_shard(shard) + weight_map = await get_weight_map(model_id) + await load_model( model_id, - shard + shard, + model_path, + weight_map ) if __name__ == "__main__": + #Qwen/Qwen2.5-3B try: - print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + print("\n-------- Test Qwen/Qwen2.5-3B ----------\n") asyncio.run(test_split_model( - "unsloth/Meta-Llama-3.1-8B-Instruct", - 32 + "Qwen/Qwen2.5-3B", + 0, + 1, + 36 )) except Exception as err: print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + + # unsloth/Meta-Llama-3.1-8B-Instruct + #try: + # print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + # asyncio.run(test_split_model( + # "unsloth/Meta-Llama-3.1-8B-Instruct", + # 0, + # 1, + # 32 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") From ea0d4b154e827aedd05c97afbeb595451d88a5b8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 13 Oct 2024 05:59:26 -0800 Subject: [PATCH 414/491] change to hf downloader as was not getting all safetensor files --- exo/download/hf/hf_helpers.py | 34 ++++++++++++------- exo/inference/torch/tests/test_split_model.py | 33 +++++++++++------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index a548df2e..4c22763d 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -394,19 +394,27 @@ def extract_layer_num(tensor_name: str) -> Optional[int]: def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: - default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"]) + default_patterns = set([ + "*.json", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + "*.safetensors" + ]) + shard_specific_patterns = set() - if weight_map: - for tensor_name, filename in weight_map.items(): - layer_num = extract_layer_num(tensor_name) - if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: - shard_specific_patterns.add(filename) - sorted_file_names = sorted(weight_map.values()) - if shard.is_first_layer(): - shard_specific_patterns.add(sorted_file_names[0]) - elif shard.is_last_layer(): - shard_specific_patterns.add(sorted_file_names[-1]) - else: - shard_specific_patterns = set("*.safetensors") + #if weight_map: + # for tensor_name, filename in weight_map.items(): + # layer_num = extract_layer_num(tensor_name) + # if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: + # shard_specific_patterns.add(filename) + # sorted_file_names = sorted(weight_map.values()) + # if shard.is_first_layer(): + # shard_specific_patterns.add(sorted_file_names[0]) + # elif shard.is_last_layer(): + # shard_specific_patterns.add(sorted_file_names[-1]) + #else: + #shard_specific_patterns = set("*.safetensors") if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") return list(default_patterns | shard_specific_patterns) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index df0d5497..46a3b6a4 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -29,6 +29,8 @@ async def load_model( if no weight map, return and load the whole model """ print("load_model called") + model_st_snapshot = model_path/"model.safetensors.index.json" + if weight_map: layer_weight_map = {} skip_layers = [] @@ -59,14 +61,13 @@ async def load_model( # rewrite model.safetensors.index.json try: - model_st_snapshot = model_path/"model.safetensors.index.json" # call download repo files again to reload original safetensors json - os.remove(model_st_snapshot) + #os.remove(model_st_snapshot) - await download_repo_files( - repo_id=shard.model_id, - revision="main", - allow_patterns="model.safetensors.index.json") + #await download_repo_files( + # repo_id=shard.model_id, + # revision="main", + # allow_patterns="model.safetensors.index.json") mst_json = {} with open(model_st_snapshot, "r") as mst_file: @@ -76,9 +77,11 @@ async def load_model( print(f"mst_json: {json.dumps(mst_json, indent=4)}") - with open(model_st_snapshot, "w") as mst_file: - json.dump(mst_json, mst_file, indent=4) - print(f"{model_st_snapshot} rewritten with {shard.n_layers} weights") + os.remove(model_st_snapshot) + + with open(model_st_snapshot, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) + print(f"{model_st_snapshot} rewritten with {shard.n_layers} weights") except Exception as err: print(f"err: {err}") raise @@ -94,6 +97,9 @@ async def load_model( offload_buffers=True ) + # have to clear out edited model safetensors mst_json + os.remove(model_st_snapshot) + return shard_model @@ -114,6 +120,9 @@ async def test_split_model( n_layers=n_layers ) + # remove old weight json if present + + print(f"loading shard: {shard}") shard_downloader = HFShardDownloader() model_path = await shard_downloader.ensure_shard(shard) @@ -129,11 +138,11 @@ async def test_split_model( if __name__ == "__main__": #Qwen/Qwen2.5-3B try: - print("\n-------- Test Qwen/Qwen2.5-3B ----------\n") + print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n") asyncio.run(test_split_model( - "Qwen/Qwen2.5-3B", + "Qwen/Qwen2.5-3B-Instruct", 0, - 1, + 18, 36 )) except Exception as err: From 30b799174ae30721ee22637c48eb723129d6d67c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 13 Oct 2024 07:17:16 -0800 Subject: [PATCH 415/491] splitting model still work in progress as transformers still seems to try to load more than needed even witha modified safetensor json file, finished up PR main updates but will continue on this one --- exo/download/hf/hf_helpers.py | 34 ++++++--------- exo/inference/torch/tests/test_split_model.py | 42 +++++++------------ 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index 4c22763d..a548df2e 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -394,27 +394,19 @@ def extract_layer_num(tensor_name: str) -> Optional[int]: def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: - default_patterns = set([ - "*.json", - "*.py", - "tokenizer.model", - "*.tiktoken", - "*.txt", - "*.safetensors" - ]) - + default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"]) shard_specific_patterns = set() - #if weight_map: - # for tensor_name, filename in weight_map.items(): - # layer_num = extract_layer_num(tensor_name) - # if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: - # shard_specific_patterns.add(filename) - # sorted_file_names = sorted(weight_map.values()) - # if shard.is_first_layer(): - # shard_specific_patterns.add(sorted_file_names[0]) - # elif shard.is_last_layer(): - # shard_specific_patterns.add(sorted_file_names[-1]) - #else: - #shard_specific_patterns = set("*.safetensors") + if weight_map: + for tensor_name, filename in weight_map.items(): + layer_num = extract_layer_num(tensor_name) + if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: + shard_specific_patterns.add(filename) + sorted_file_names = sorted(weight_map.values()) + if shard.is_first_layer(): + shard_specific_patterns.add(sorted_file_names[0]) + elif shard.is_last_layer(): + shard_specific_patterns.add(sorted_file_names[-1]) + else: + shard_specific_patterns = set("*.safetensors") if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") return list(default_patterns | shard_specific_patterns) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 46a3b6a4..cab9f221 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -8,10 +8,7 @@ from pathlib import Path from typing import Optional -from exo.download.hf.hf_helpers import ( - get_weight_map, - download_repo_files -) +from exo.download.hf.hf_helpers import get_weight_map from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.shard import Shard @@ -33,7 +30,7 @@ async def load_model( if weight_map: layer_weight_map = {} - skip_layers = [] + non_layer_weights = [] for wname, wtensor in weight_map.items(): # get layer number @@ -47,28 +44,20 @@ async def load_model( layer_idx = int(layer_found[0]) if shard.start_layer <= layer_idx <= shard.end_layer: layer_weight_map[wname] = wtensor - else: - skip_layers.append(wname) - print(f"SKIPPING LAYER {layer_idx}") - - if wname not in skip_layers: - print(f"adding non-layer: {wname}") - layer_weight_map[wname] = wtensor - - # will manipulate current model.safetensors.index.json - # but set back at end of inference - print(layer_weight_map) + else: + non_layer_weights.append((wname, wtensor)) + + if shard.is_first_layer(): + # this assumes at max only one first weight non-layer for model + first_weight = non_layer_weights[0] + layer_weight_map[first_weight[0]] = first_weight[1] + elif shard.is_last_layer(): + last_weights = non_layer_weights[1:] + for last_weight in last_weights: + layer_weight_map[last_weight[0]] = last_weight[1] # rewrite model.safetensors.index.json try: - # call download repo files again to reload original safetensors json - #os.remove(model_st_snapshot) - - #await download_repo_files( - # repo_id=shard.model_id, - # revision="main", - # allow_patterns="model.safetensors.index.json") - mst_json = {} with open(model_st_snapshot, "r") as mst_file: mst_json = json.load(mst_file) @@ -81,7 +70,6 @@ async def load_model( with open(model_st_snapshot, "w") as mst_file: json.dump(mst_json, mst_file, indent=4) - print(f"{model_st_snapshot} rewritten with {shard.n_layers} weights") except Exception as err: print(f"err: {err}") raise @@ -109,7 +97,7 @@ async def test_split_model( end_layer: int, n_layers: int ): - """ + """ Test to load split models """ @@ -142,7 +130,7 @@ async def test_split_model( asyncio.run(test_split_model( "Qwen/Qwen2.5-3B-Instruct", 0, - 18, + 3, 36 )) except Exception as err: From 3a2c431c3102b2a9e2a4bf934bca4e15f001b1e5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 13 Oct 2024 07:35:42 -0800 Subject: [PATCH 416/491] updating readme --- exo/inference/torch/README.md | 41 ++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md index 5cbeeef6..c80d35c2 100644 --- a/exo/inference/torch/README.md +++ b/exo/inference/torch/README.md @@ -1,9 +1,44 @@ # PyTorch & HuggingFace inference engine +## Tech + +Tested on + +```bash +# Laptop/PC +Distributor ID: Pop +Description: Pop!_OS 22.04 LTS +Release: 22.04 +Codename: jammy +CUDA Version: 12.4 +Nvidia Driver Version: 550.107.02 + +GPU 1: Nvidia GeForce RTX 3060 6GB Laptop +``` +```bash +# Server +Distributor ID: Pop +Description: Pop!_OS 22.04 LTS +Release: 22.04 +Codename: jammy +CUDA Version: 12.4 +Nvidia Driver Version: 550.90.07 + +GPU 1: NVIDIA T1000 8GB +GPU 2: NVIDIA Quadro M2000 4GB +GPU 3: NVIDIA Quadro M2000 4GB +GPU 4: NVIDIA Quadro P400 2GB +GPU 5: NVIDIA Quadro P400 2GB +``` + + ## Notes/Issues ### 10/10/2024 - To select a pytorch device via environment variables, set the variable TORCH_DEVICE -- - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM -- - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) -- - Looking into adding mobile device support properly + - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM + - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) +- Looking into adding mobile device support properly - If device is not CPU the data type defaults to float32 else float16. + +### 10/13/2024 +Still working on split model development (see test_split_model.py). Right now, it seems to do it but still transformers is loading more in the RAM and GPU as it loads up a larger models (causing an OOM). Will research and add to next update. Right now, tests are added and are in development. From 6c6e7b2c3bffa0f8bd9c59b01339f1ccb92d9140 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 13 Oct 2024 21:41:04 -0800 Subject: [PATCH 417/491] successful splitting model test with only loading needed weights, implementing it in main inference code --- exo/inference/torch/tests/test_split_model.py | 85 ++++++++++++++++--- 1 file changed, 71 insertions(+), 14 deletions(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index cab9f221..c8c4f3a7 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -8,17 +8,35 @@ from pathlib import Path from typing import Optional +import torch + +from transformers.modeling_utils import offload_weight + from exo.download.hf.hf_helpers import get_weight_map from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.shard import Shard -from transformers import AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +def print_ram_stats(): + if torch.cuda.is_available(): + allocated_memory = torch.cuda.memory_allocated() + max_memory = torch.cuda.max_memory_allocated() + cached_memory = torch.cuda.memory_reserved() + + print("Cuda stats") + print(f'Allocated memory: {allocated_memory / 1024**2} MB') + print(f'Max allocated memory: {max_memory / 1024**2} MB') + print(f'Cached memory: {cached_memory / 1024**2} MB') + + async def load_model( repo_id: str, shard: Shard, model_path: Path, - weight_map: Optional[dict] + weight_map: Optional[dict], + device: Optional[str] = "cuda" ) -> Optional[AutoModelForCausalLM]: """ load model by layer and safetensors @@ -61,7 +79,6 @@ async def load_model( mst_json = {} with open(model_st_snapshot, "r") as mst_file: mst_json = json.load(mst_file) - mst_json["weight_map"] = layer_weight_map print(f"mst_json: {json.dumps(mst_json, indent=4)}") @@ -77,27 +94,70 @@ async def load_model( else: print("weight_map not found, loading whole model") + # setup the weight range for init_weights + shard_num_hidden_layers = shard.end_layer - shard.start_layer + print(f"Setting up LLM config with {shard_num_hidden_layers} hidden layers") + llm_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=model_path, + device_map="cuda", + offload_buffers=True, + local_files_only=True, + num_hidden_layers=shard_num_hidden_layers + ) + # load model with layer edits # or whole model if no weight_map - shard_model = AutoModelForCausalLM.from_pretrained( - model_path, - device_map="auto", - offload_buffers=True + print(f"Loading sharded AutoModelForCausalLM from {model_path}") + shard_model = AutoModelForCausalLM.from_config(llm_config).to(device) + + print("Loading tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=model_path, + local_files_only=True, ) + print_ram_stats() + + prompt = "In a single word only, what color is a red apple?" + + model_inputs = tokenizer( + [prompt], + return_tensors="pt" + ) + + generated_ids = shard_model.generate( + model_inputs.input_ids.to(device), + attention_mask=model_inputs.attention_mask.to(device), + max_new_tokens=512, + do_sample=True + ) + + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip( + model_inputs.input_ids, + generated_ids + ) + ] + + response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + print(f"Prompt: {prompt}\n") + print(f"Response: {response}\n") + + print_ram_stats() + # have to clear out edited model safetensors mst_json os.remove(model_st_snapshot) return shard_model - async def test_split_model( model_id: str, start_layer: int, end_layer: int, n_layers: int ): - """ + """ Test to load split models """ @@ -108,9 +168,6 @@ async def test_split_model( n_layers=n_layers ) - # remove old weight json if present - - print(f"loading shard: {shard}") shard_downloader = HFShardDownloader() model_path = await shard_downloader.ensure_shard(shard) @@ -130,11 +187,11 @@ async def test_split_model( asyncio.run(test_split_model( "Qwen/Qwen2.5-3B-Instruct", 0, - 3, + 6, 36 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! Qwen/Qwen2.5-3B-Instruct TEST FAILED \n{err}\n") # unsloth/Meta-Llama-3.1-8B-Instruct #try: From aacdeb595e8a47e9d8a0d45e8e3ac84c92efa4e5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 00:44:19 -0800 Subject: [PATCH 418/491] adding model sharding to inference engine, doing testing with inference engine and sharding --- exo/inference/torch/inference.py | 11 +- exo/inference/torch/model/hf.py | 110 ++++++++++-- .../torch/tests/test_inference_engine.py | 157 +++++++++--------- exo/inference/torch/tests/test_split_model.py | 42 ++--- exo/inference/torch/utils.py | 49 ++++++ 5 files changed, 252 insertions(+), 117 deletions(-) create mode 100644 exo/inference/torch/utils.py diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index 52db0c0a..63f29284 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -60,8 +60,11 @@ def __init__(self, shard_downloader: HFShardDownloader): # setup cude dtype self.dtype = torch.get_default_dtype() - # setup threadding - torch.set_num_threads(torch.get_num_threads()) + # setup device_map + if os.environ.get("TORCH_DEVICE_MAP"): + self.device_map = os.environ["TORCH_DEVICE_MAP"] + else: + self.device_map = str(self.device) def infer_caching( self, @@ -351,13 +354,13 @@ async def ensure_shard(self, shard: Shard): # get model weight map model_wm = await get_weight_map(repo_id=shard.model_id) - print(f"model_wm: {model_wm}") - self.stateful_sharded_model = ShardedHuggingFaceModel( shard=shard, local_model_path=model_path, + weight_map=model_wm, device=self.device, dtype=self.dtype, + device_map=self.device_map, top_k=TOP_K, temp=TEMP, top_p=TOP_P diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 9d524de0..c188a97a 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -1,12 +1,18 @@ +import os +import json +from typing import Tuple, Optional, Union, List +from pathlib import Path + import torch import torch.nn as nn -from typing import Tuple, Optional, Union, List from exo.inference.shard import Shard from exo.helpers import DEBUG +from exo.inference.torch.utils import extract_layers from transformers import ( - AutoModelForCausalLM, + AutoConfig, + AutoModelForCausalLM, DynamicCache, Cache, LogitsProcessorList, @@ -22,12 +28,15 @@ class ShardedHuggingFaceModel: def __init__( self, shard: Shard, - local_model_path, - device, - dtype, + local_model_path: Path, + weight_map: Optional[dict], + device: torch.device, + dtype: torch.dtype, + device_map: str, top_k: int = 25, temp: float = 0.7, - top_p: float = 0.9 + top_p: float = 0.9, + offload_buffers: bool = True ): """ Initializes the ShardedHuggingFaceModel with a specified shard, model path, and device. @@ -64,21 +73,96 @@ def __init__( self.device = device self.dtype = dtype + self.device_map = device_map + + self.offload_buffers = offload_buffers + + self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json" # setup pytorch and transformer llm try: - self.llm_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=self.local_model_path, - torch_dtype=self.dtype, - device_map="auto", - offload_buffers=True - ) + if weight_map: + self.llm_model_config = self.load_sharded_model( + shard, + weight_map, + offload_buffers=self.offload_buffers + ) - self.model = self.llm_model.model + # clear out edited safetensor json + # this is needed because shard downloader just + # appends and not redownloads the file + os.remove(self.model_safetensors_path) + else: + self.llm_model_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.local_model_path, + torch_dtype=self.dtype, + device_map=self.device_map, + offload_buffers=self.offload_buffers + ) + + self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) + + self.model = self.llm_model.model.to(self.device) except Exception as err: print(f"error loading and splitting model: {err}") raise + def load_sharded_model( + self, + shard: Shard, + weight_map: dict, + offload_buffers: bool + ) -> AutoConfig: + """ + Loads sharded version of model where only needed + weights are loaded for necessary layers + + Args: + + Returns: + """ + if DEBUG >= 4: + print("load_sharded_model called") + print(f"shard: {shard}") + + # break out layers per shard range + layer_weight_map = extract_layers( + weight_map, + shard + ) + + # rewrite model.safetensors.index.json for only needed layers + try: + mst_json = {} + with open(self.model_safetensors_path, "r") as mst_file: + mst_json = json.load(mst_file) + mst_json["weight_map"] = layer_weight_map + + if DEBUG >= 4: + print(f"rewritten safetensor index \n{json.dumps(mst_json, indent=4)}") + + os.remove(self.model_safetensors_path) + + with open(self.model_safetensors_path, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) + except Exception as err: + print(f"err: {err}") + raise + + # load model + try: + shard_num_hidden_layers = shard.end_layer - shard.start_layer + return AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.local_model_path, + device_map=self.device_map, + offload_buffers=offload_buffers, + local_files_only=True, + num_hidden_layers=shard_num_hidden_layers + ) + except Exception as err: + print(f"err: {err}") + raise + def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index a03c5c9a..e102af69 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -15,59 +15,57 @@ async def test_inference_engine( inference_engine_2: InferenceEngine, model_id: str, n_layers: int): - + prompt = "In a single word only, what is the last name of the current president of the USA?" - shard = Shard( +# shard = Shard( +# model_id=model_id, +# start_layer=0, +# end_layer=n_layers-1, +# n_layers=n_layers +# ) +# +# resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( +# "A", +# shard=shard, +# prompt=prompt +# ) +# +# print("\n------------resp_full---------------\n") +# print(resp_full) +# print("\n------------resp_full---------------\n") +# +# time.sleep(5) +# +# next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( +# "A", +# shard=shard, +# input_data=resp_full, +# inference_state=inference_state_full, +# ) +# +# print("\n------------next_resp_full---------------\n") +# print(next_resp_full) +# print("\n------------next_resp_full---------------\n") +# +# time.sleep(5) + + resp_shard = Shard( model_id=model_id, start_layer=0, - end_layer=n_layers-1, + end_layer=1, n_layers=n_layers ) - resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( - "A", - shard=shard, - prompt=prompt - ) - - print("\n------------resp_full---------------\n") - print(resp_full) - print("\n------------resp_full---------------\n") - - time.sleep(5) - - next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( - "A", - shard=shard, - input_data=resp_full, - inference_state=inference_state_full, - ) - - print("\n------------next_resp_full---------------\n") - print(next_resp_full) - print("\n------------next_resp_full---------------\n") - - time.sleep(5) - - pp = int(n_layers/2) - - resp_shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=pp, - n_layers=n_layers - ) - - resp_shard2 = Shard( - model_id=model_id, - start_layer=pp + 1, - end_layer=n_layers-1, - n_layers=n_layers - ) + #resp_shard2 = Shard( + # model_id=model_id, + # start_layer=3, + # end_layer=5, + # n_layers=n_layers + #) resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( - "B", + "B", shard=resp_shard, prompt=prompt ) @@ -78,42 +76,41 @@ async def test_inference_engine( time.sleep(5) - - resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp1, - inference_state=inference_state_1, - ) - - print("\n------------resp2---------------\n") - print(resp2) - print("\n------------resp2---------------\n") - - resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( - "B", - shard=resp_shard, - input_data=resp2, - inference_state=inference_state_2, - ) - - print("\n------------resp3---------------\n") - print(resp3) - print("\n------------resp3---------------\n") - - resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp3, - inference_state=inference_state_3, - ) - - print("\n------------resp4---------------\n") - print(resp4) - print("\n------------resp4---------------\n") - - assert np.array_equal(resp_full, resp2) - assert np.array_equal(next_resp_full, resp4) + #resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + # "B", + # shard=resp_shard2, + # input_data=resp1, + # inference_state=inference_state_1, + #) + + #print("\n------------resp2---------------\n") + #print(resp2) + #print("\n------------resp2---------------\n") + + #resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + # "B", + # shard=resp_shard, + # input_data=resp2, + # inference_state=inference_state_2, + #) + + #print("\n------------resp3---------------\n") + #print(resp3) + #print("\n------------resp3---------------\n") + + #resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + # "B", + # shard=resp_shard2, + # input_data=resp3, + # inference_state=inference_state_3, + #) + + #print("\n------------resp4---------------\n") + #print(resp4) + #print("\n------------resp4---------------\n") + + #assert np.array_equal(resp_full, resp2) + #assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': # try: diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index c8c4f3a7..0afc3ed3 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -29,9 +29,7 @@ def print_ram_stats(): print(f'Max allocated memory: {max_memory / 1024**2} MB') print(f'Cached memory: {cached_memory / 1024**2} MB') - - -async def load_model( +def load_model( repo_id: str, shard: Shard, model_path: Path, @@ -65,6 +63,10 @@ async def load_model( else: non_layer_weights.append((wname, wtensor)) + non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1]) + + print(f"sorted non_layer_weights: {non_layer_weights}") + if shard.is_first_layer(): # this assumes at max only one first weight non-layer for model first_weight = non_layer_weights[0] @@ -173,7 +175,7 @@ async def test_split_model( model_path = await shard_downloader.ensure_shard(shard) weight_map = await get_weight_map(model_id) - await load_model( + load_model( model_id, shard, model_path, @@ -182,25 +184,25 @@ async def test_split_model( if __name__ == "__main__": #Qwen/Qwen2.5-3B + #try: + # print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n") + # asyncio.run(test_split_model( + # "Qwen/Qwen2.5-3B-Instruct", + # 0, + # 6, + # 36 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! Qwen/Qwen2.5-3B-Instruct TEST FAILED \n{err}\n") + + # unsloth/Meta-Llama-3.1-8B-Instruct try: - print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n") + print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") asyncio.run(test_split_model( - "Qwen/Qwen2.5-3B-Instruct", + "unsloth/Meta-Llama-3.1-8B-Instruct", 0, 6, - 36 + 32 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! Qwen/Qwen2.5-3B-Instruct TEST FAILED \n{err}\n") - - # unsloth/Meta-Llama-3.1-8B-Instruct - #try: - # print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") - # asyncio.run(test_split_model( - # "unsloth/Meta-Llama-3.1-8B-Instruct", - # 0, - # 1, - # 32 - # )) - #except Exception as err: - # print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") diff --git a/exo/inference/torch/utils.py b/exo/inference/torch/utils.py new file mode 100644 index 00000000..994daeeb --- /dev/null +++ b/exo/inference/torch/utils.py @@ -0,0 +1,49 @@ +""" +Utility functions to be used by inference engine +and model +""" +import re + +from exo.inference.shard import Shard + +def extract_layers( + weight_map: dict, + shard: Shard +) -> dict: + """ + Extract layers from weight map in range + + Args: + + Returns: + """ + + layer_rgx = r'^model\.layers\.(\d+)\.*' + layer_weight_map = {} + non_layer_weights = [] + + for wname, wtensor in weight_map.items(): + layer_found = re.findall(layer_rgx, wname) + if layer_found: + layer_idx = int(layer_found[0]) + if shard.start_layer <= layer_idx <= shard.end_layer: + layer_weight_map[wname] = wtensor + else: + non_layer_weights.append((wname, wtensor)) + + non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1]) + + print(non_layer_weights) + print(f"first: {shard.is_first_layer()}") + print(f"last: {shard.is_last_layer()}") + + if shard.is_first_layer(): + # this assumes at max only one first weight non-layer for model + first_weight = non_layer_weights[0] + layer_weight_map[first_weight[0]] = first_weight[1] + elif shard.is_last_layer(): + last_weights = non_layer_weights[1:] + for last_weight in last_weights: + layer_weight_map[last_weight[0]] = last_weight[1] + + return layer_weight_map From ce702d1301a5bd59a1b53a65c7c8e21f4bd768b7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 00:54:02 -0800 Subject: [PATCH 419/491] fixing layer range issue --- exo/inference/torch/model/hf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index c188a97a..a8def187 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -261,11 +261,13 @@ def forward( if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") + print(f"model layer amt: {len(self.model.layers)}") print(f"layer_amt: {layer_amt}") for i in layer_amt: decoder_layer = self.model.layers[i] if DEBUG >= 5: + print(f"layer #{i}") print("decoder_layer before") print(f"decoder_layer: {decoder_layer}") print(f"hidden_states: {self.hidden_states}") From e387a797fef33b40def3faef6ceb9239806b7308 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 01:00:56 -0800 Subject: [PATCH 420/491] fixing layer range issue --- exo/inference/torch/model/hf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index a8def187..ebe12d7d 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -152,6 +152,8 @@ def load_sharded_model( # load model try: shard_num_hidden_layers = shard.end_layer - shard.start_layer + if DEBUG >= 4: + print(f"config with {shard_num_hidden_layers} layers") return AutoConfig.from_pretrained( pretrained_model_name_or_path=self.local_model_path, device_map=self.device_map, From e0ba2bb4182b4c2afe54a538d5dc7642c0d1ad4b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 01:03:53 -0800 Subject: [PATCH 421/491] fixing layer range issue --- exo/inference/torch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index ebe12d7d..ed3ca2a9 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -259,7 +259,7 @@ def forward( print(f"model_inputs: {model_inputs}") # run through decoder layers - layer_amt = range(self.shard.start_layer, self.shard.end_layer + 1) + layer_amt = range(self.shard.start_layer, self.shard.end_layer) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") From 5b9638f249a4137c0d1e93fc3a81edfc71731aec Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 01:07:14 -0800 Subject: [PATCH 422/491] checking if ram over usaage even if reducing layers on large models --- exo/inference/torch/tests/test_inference_engine.py | 8 ++++++-- exo/inference/torch/utils.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index e102af69..93960afe 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -1,14 +1,16 @@ """ Test inference engine and model sharding """ - +import time import asyncio + from exo.inference.shard import Shard from exo.inference.torch.inference import TorchDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine +from exo.inference.torch.utils import print_ram_stats + import numpy as np -import time async def test_inference_engine( inference_engine_1: InferenceEngine, @@ -64,6 +66,7 @@ async def test_inference_engine( # n_layers=n_layers #) + print_ram_stats() resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( "B", shard=resp_shard, @@ -74,6 +77,7 @@ async def test_inference_engine( print(resp1) print("\n------------resp1---------------\n") + print_ram_stats() time.sleep(5) #resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( diff --git a/exo/inference/torch/utils.py b/exo/inference/torch/utils.py index 994daeeb..4ad92253 100644 --- a/exo/inference/torch/utils.py +++ b/exo/inference/torch/utils.py @@ -6,6 +6,8 @@ from exo.inference.shard import Shard +import torch + def extract_layers( weight_map: dict, shard: Shard @@ -47,3 +49,14 @@ def extract_layers( layer_weight_map[last_weight[0]] = last_weight[1] return layer_weight_map + +def print_ram_stats(): + if torch.cuda.is_available(): + allocated_memory = torch.cuda.memory_allocated() + max_memory = torch.cuda.max_memory_allocated() + cached_memory = torch.cuda.memory_reserved() + + print("Cuda stats") + print(f'Allocated memory: {allocated_memory / 1024**2} MB') + print(f'Max allocated memory: {max_memory / 1024**2} MB') + print(f'Cached memory: {cached_memory / 1024**2} MB') From 664f29f3e720f89744ef05e4b881d5662675e528 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 01:23:18 -0800 Subject: [PATCH 423/491] half layer inference engine testing --- .../torch/tests/test_inference_engine.py | 74 ++++++++++--------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index 93960afe..2d68001f 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -52,19 +52,21 @@ async def test_inference_engine( # # time.sleep(5) + half_layer = int(n_layers/2) + resp_shard = Shard( model_id=model_id, start_layer=0, - end_layer=1, + end_layer=half_layer, n_layers=n_layers ) - #resp_shard2 = Shard( - # model_id=model_id, - # start_layer=3, - # end_layer=5, - # n_layers=n_layers - #) + resp_shard2 = Shard( + model_id=model_id, + start_layer=half_layer+1, + end_layer=n_layers-1, + n_layers=n_layers + ) print_ram_stats() resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( @@ -80,16 +82,16 @@ async def test_inference_engine( print_ram_stats() time.sleep(5) - #resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( - # "B", - # shard=resp_shard2, - # input_data=resp1, - # inference_state=inference_state_1, - #) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp1, + inference_state=inference_state_1, + ) - #print("\n------------resp2---------------\n") - #print(resp2) - #print("\n------------resp2---------------\n") + print("\n------------resp2---------------\n") + print(resp2) + print("\n------------resp2---------------\n") #resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( # "B", @@ -117,16 +119,16 @@ async def test_inference_engine( #assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - # try: - # print("\n\n -------- TEST QWEN2 -------- \n\n") - # asyncio.run(test_inference_engine( - # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2-0.5B-Instruct", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + try: + print("\n\n -------- TEST Qwen/Qwen2.5-3B-Instruct -------- \n\n") + asyncio.run(test_inference_engine( + TorchDynamicShardInferenceEngine(HFShardDownloader()), + TorchDynamicShardInferenceEngine(HFShardDownloader()), + "Qwen/Qwen2.5-3B-Instruct", + 36 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") #try: # print("\n-------- Test unsloth/Llama-3.2-1B-Instruct ----------\n") @@ -139,15 +141,15 @@ async def test_inference_engine( #except Exception as err: # print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") - try: - print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") - asyncio.run(test_inference_engine( - TorchDynamicShardInferenceEngine(HFShardDownloader()), - TorchDynamicShardInferenceEngine(HFShardDownloader()), - "unsloth/Meta-Llama-3.1-8B-Instruct", - 32 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + #try: + # print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + # asyncio.run(test_inference_engine( + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # "unsloth/Meta-Llama-3.1-8B-Instruct", + # 32 + # )) + #except Exception as err: + # print(f"\n\n !!!!!!!!!!! unsloth/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") From 2591fab98103666910626d3b89513122a8fbef87 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 01:28:43 -0800 Subject: [PATCH 424/491] fixing layer amount with sharded modeling --- exo/inference/torch/model/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index ed3ca2a9..3b6fefc2 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -259,7 +259,7 @@ def forward( print(f"model_inputs: {model_inputs}") # run through decoder layers - layer_amt = range(self.shard.start_layer, self.shard.end_layer) + layer_amt = range(self.shard.end_layer - self.shard.start_layer) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") From 99dac57b17df542bfb670a98d90bbb8a7410ea01 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 01:36:55 -0800 Subject: [PATCH 425/491] adding qwen2.5 3B for testing --- exo/models.py | 4 ++++ exo/tinychat/index.html | 1 + 2 files changed, 5 insertions(+) diff --git a/exo/models.py b/exo/models.py index fe608f05..a8a482df 100644 --- a/exo/models.py +++ b/exo/models.py @@ -71,6 +71,10 @@ "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Coder-1.5B-Instruct", start_layer=0, end_layer=0, n_layers=28), }, + "qwen-2.5-3B-Instruct": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=36), + "TorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2.5-3B-Instruct", start_layer=0, end_layer=0, n_layers=36), + }, "qwen-2.5-coder-7b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Qwen2.5-Coder-7B-Instruct", start_layer=0, end_layer=0, n_layers=28), diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index e9be9218..8d4d4ee8 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -44,6 +44,7 @@ + From 493cd3e38bcd4eedc35d6f0f8e2956014f8d22f8 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 02:00:02 -0800 Subject: [PATCH 426/491] updating inference engine test --- .../torch/tests/test_inference_engine.py | 144 ++++++++---------- 1 file changed, 65 insertions(+), 79 deletions(-) diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index 2d68001f..2d24c8b2 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -8,7 +8,6 @@ from exo.inference.torch.inference import TorchDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine -from exo.inference.torch.utils import print_ram_stats import numpy as np @@ -20,37 +19,37 @@ async def test_inference_engine( prompt = "In a single word only, what is the last name of the current president of the USA?" -# shard = Shard( -# model_id=model_id, -# start_layer=0, -# end_layer=n_layers-1, -# n_layers=n_layers -# ) -# -# resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( -# "A", -# shard=shard, -# prompt=prompt -# ) -# -# print("\n------------resp_full---------------\n") -# print(resp_full) -# print("\n------------resp_full---------------\n") -# -# time.sleep(5) -# -# next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( -# "A", -# shard=shard, -# input_data=resp_full, -# inference_state=inference_state_full, -# ) -# -# print("\n------------next_resp_full---------------\n") -# print(next_resp_full) -# print("\n------------next_resp_full---------------\n") -# -# time.sleep(5) + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) + + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=shard, + input_data=resp_full, + inference_state=inference_state_full, + ) + + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") + + time.sleep(5) half_layer = int(n_layers/2) @@ -68,7 +67,6 @@ async def test_inference_engine( n_layers=n_layers ) - print_ram_stats() resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( "B", shard=resp_shard, @@ -79,7 +77,6 @@ async def test_inference_engine( print(resp1) print("\n------------resp1---------------\n") - print_ram_stats() time.sleep(5) resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( @@ -93,63 +90,52 @@ async def test_inference_engine( print(resp2) print("\n------------resp2---------------\n") - #resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( - # "B", - # shard=resp_shard, - # input_data=resp2, - # inference_state=inference_state_2, - #) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=resp_shard, + input_data=resp2, + inference_state=inference_state_2, + ) - #print("\n------------resp3---------------\n") - #print(resp3) - #print("\n------------resp3---------------\n") + print("\n------------resp3---------------\n") + print(resp3) + print("\n------------resp3---------------\n") - #resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( - # "B", - # shard=resp_shard2, - # input_data=resp3, - # inference_state=inference_state_3, - #) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp3, + inference_state=inference_state_3, + ) - #print("\n------------resp4---------------\n") - #print(resp4) - #print("\n------------resp4---------------\n") + print("\n------------resp4---------------\n") + print(resp4) + print("\n------------resp4---------------\n") - #assert np.array_equal(resp_full, resp2) - #assert np.array_equal(next_resp_full, resp4) + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - try: - print("\n\n -------- TEST Qwen/Qwen2.5-3B-Instruct -------- \n\n") - asyncio.run(test_inference_engine( - TorchDynamicShardInferenceEngine(HFShardDownloader()), - TorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2.5-3B-Instruct", - 36 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") - #try: - # print("\n-------- Test unsloth/Llama-3.2-1B-Instruct ----------\n") + # print("\n\n -------- TEST Qwen/Qwen2.5-3B-Instruct -------- \n\n") # asyncio.run(test_inference_engine( # TorchDynamicShardInferenceEngine(HFShardDownloader()), # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # "unsloth/Llama-3.2-1B-Instruct", - # 24 + # "Qwen/Qwen2.5-3B-Instruct", + # 36 # )) #except Exception as err: - # print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + # print(f"\n!!!! QWEN2 TEST FAILED \n{err}\n") - #try: - # print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") - # asyncio.run(test_inference_engine( - # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # "unsloth/Meta-Llama-3.1-8B-Instruct", - # 32 - # )) - #except Exception as err: - # print(f"\n\n !!!!!!!!!!! unsloth/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") + try: + print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + asyncio.run(test_inference_engine( + TorchDynamicShardInferenceEngine(HFShardDownloader()), + TorchDynamicShardInferenceEngine(HFShardDownloader()), + "unsloth/Meta-Llama-3.1-8B-Instruct", + 32 + )) + except Exception as err: + print(f"\n!!!! unsloth/Meta-Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") From de232946eaa5a023710372a6e0cc6a351185af99 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 14 Oct 2024 02:05:09 -0800 Subject: [PATCH 427/491] cleaning up utils and split model --- exo/inference/torch/tests/test_split_model.py | 15 +++---------- exo/inference/torch/utils.py | 22 ++++++++++--------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 0afc3ed3..935f74df 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -15,20 +15,10 @@ from exo.download.hf.hf_helpers import get_weight_map from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.shard import Shard +from exo.inference.torch.utils import print_cuda_vram_stats from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -def print_ram_stats(): - if torch.cuda.is_available(): - allocated_memory = torch.cuda.memory_allocated() - max_memory = torch.cuda.max_memory_allocated() - cached_memory = torch.cuda.memory_reserved() - - print("Cuda stats") - print(f'Allocated memory: {allocated_memory / 1024**2} MB') - print(f'Max allocated memory: {max_memory / 1024**2} MB') - print(f'Cached memory: {cached_memory / 1024**2} MB') - def load_model( repo_id: str, shard: Shard, @@ -118,7 +108,8 @@ def load_model( local_files_only=True, ) - print_ram_stats() + if torch.cuda.is_available() and device == "cuda": + print_cuda_vram_stats() prompt = "In a single word only, what color is a red apple?" diff --git a/exo/inference/torch/utils.py b/exo/inference/torch/utils.py index 4ad92253..e5fc80e5 100644 --- a/exo/inference/torch/utils.py +++ b/exo/inference/torch/utils.py @@ -50,13 +50,15 @@ def extract_layers( return layer_weight_map -def print_ram_stats(): - if torch.cuda.is_available(): - allocated_memory = torch.cuda.memory_allocated() - max_memory = torch.cuda.max_memory_allocated() - cached_memory = torch.cuda.memory_reserved() - - print("Cuda stats") - print(f'Allocated memory: {allocated_memory / 1024**2} MB') - print(f'Max allocated memory: {max_memory / 1024**2} MB') - print(f'Cached memory: {cached_memory / 1024**2} MB') +def print_cuda_vram_stats(): + """ + Prints CUDA VRAM stats being used by pytorch + """ + allocated_memory = torch.cuda.memory_allocated() + max_memory = torch.cuda.max_memory_allocated() + cached_memory = torch.cuda.memory_reserved() + + print("CUDA stats") + print(f'Allocated memory: {allocated_memory / 1024**2} MB') + print(f'Max allocated memory: {max_memory / 1024**2} MB') + print(f'Cached memory: {cached_memory / 1024**2} MB') From e7470b1a9ae643a6e41d7eabfc44c63f077c6afa Mon Sep 17 00:00:00 2001 From: Daniel Newman Date: Tue, 15 Oct 2024 13:55:33 -0400 Subject: [PATCH 428/491] bugfix in llm setup --- exo/inference/torch/model/hf.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 3b6fefc2..20f17ee1 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -92,17 +92,18 @@ def __init__( # this is needed because shard downloader just # appends and not redownloads the file os.remove(self.model_safetensors_path) + + self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) + self.model = self.llm_model.model.to(self.device) else: - self.llm_model_config = AutoConfig.from_pretrained( + self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, device_map=self.device_map, offload_buffers=self.offload_buffers ) - - self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) - - self.model = self.llm_model.model.to(self.device) + self.model = self.llm_model.model + except Exception as err: print(f"error loading and splitting model: {err}") raise From f5a1cef3ab509f8b13e0ae49365ec61c7ef4a306 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Wed, 16 Oct 2024 15:02:27 -0700 Subject: [PATCH 429/491] handle range not satisfiable edge case --- exo/download/hf/hf_helpers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index a548df2e..3d5349e5 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -173,6 +173,8 @@ async def download_file( if progress_callback: await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) return + if DEBUG >= 2: print(f"Range not satisfiable {file_path=} {total_size=} {downloaded_size=}") + return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) except ValueError: if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) From 751bd1c3bbee3579e562538da85a3383d6332043 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 15:16:20 -0800 Subject: [PATCH 430/491] updating to use automodelforcausallm instead of autoconfig --- exo/inference/torch/model/hf.py | 16 ++++++++++------ exo/inference/torch/utils.py | 4 ---- exo/models.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 3b6fefc2..78e457c8 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -82,6 +82,7 @@ def __init__( # setup pytorch and transformer llm try: if weight_map: + print("loading shard model") self.llm_model_config = self.load_sharded_model( shard, weight_map, @@ -92,15 +93,18 @@ def __init__( # this is needed because shard downloader just # appends and not redownloads the file os.remove(self.model_safetensors_path) + + self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) else: - self.llm_model_config = AutoConfig.from_pretrained( + print("loading full model") + self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, device_map=self.device_map, - offload_buffers=self.offload_buffers - ) + offload_buffers=True + ).to(self.device) - self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) + self.model = self.llm_model.model.to(self.device) except Exception as err: @@ -112,7 +116,7 @@ def load_sharded_model( shard: Shard, weight_map: dict, offload_buffers: bool - ) -> AutoConfig: + ) -> AutoModelForCausalLM: """ Loads sharded version of model where only needed weights are loaded for necessary layers @@ -154,7 +158,7 @@ def load_sharded_model( shard_num_hidden_layers = shard.end_layer - shard.start_layer if DEBUG >= 4: print(f"config with {shard_num_hidden_layers} layers") - return AutoConfig.from_pretrained( + return AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, device_map=self.device_map, offload_buffers=offload_buffers, diff --git a/exo/inference/torch/utils.py b/exo/inference/torch/utils.py index e5fc80e5..b9c4f148 100644 --- a/exo/inference/torch/utils.py +++ b/exo/inference/torch/utils.py @@ -35,10 +35,6 @@ def extract_layers( non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1]) - print(non_layer_weights) - print(f"first: {shard.is_first_layer()}") - print(f"last: {shard.is_last_layer()}") - if shard.is_first_layer(): # this assumes at max only one first weight non-layer for model first_weight = non_layer_weights[0] diff --git a/exo/models.py b/exo/models.py index a8a482df..ab797706 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,7 +4,7 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), - "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), + "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), From 7d866d81d768752a73ac68c47744c2623d5c58ae Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 15:18:05 -0800 Subject: [PATCH 431/491] removing meta model --- exo/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/models.py b/exo/models.py index ab797706..a8a482df 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,7 +4,7 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), - "TorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), + "TorchDynamicShardInferenceEngine": Shard(model_id="unsloth/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), From 253237b1adbeec0b00aa8df79ff8749e1b18bc63 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 15:24:28 -0800 Subject: [PATCH 432/491] updating split model test --- exo/inference/torch/tests/test_split_model.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 935f74df..2783e7f0 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -89,18 +89,17 @@ def load_model( # setup the weight range for init_weights shard_num_hidden_layers = shard.end_layer - shard.start_layer print(f"Setting up LLM config with {shard_num_hidden_layers} hidden layers") - llm_config = AutoConfig.from_pretrained( + + # load model with layer edits + # or whole model if no weight_map + print(f"Loading sharded AutoModelForCausalLM from {model_path}") + shard_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=model_path, device_map="cuda", offload_buffers=True, local_files_only=True, num_hidden_layers=shard_num_hidden_layers - ) - - # load model with layer edits - # or whole model if no weight_map - print(f"Loading sharded AutoModelForCausalLM from {model_path}") - shard_model = AutoModelForCausalLM.from_config(llm_config).to(device) + ).to(device) print("Loading tokenizer") tokenizer = AutoTokenizer.from_pretrained( From e46ffa4a213d6e386fa5d692813abcc821cf0b7f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 15:42:20 -0800 Subject: [PATCH 433/491] updating split model test --- exo/inference/torch/tests/test_split_model.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 2783e7f0..4717311c 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -34,6 +34,28 @@ def load_model( print("load_model called") model_st_snapshot = model_path/"model.safetensors.index.json" + if device: + device = device + elif os.environ.get("TORCH_DEVICE"): + device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + else: + device = torch.device("cpu") + + torch.set_default_device(device) + + # setup cude dtype + dtype = torch.get_default_dtype() + + # setup device_map + if os.environ.get("TORCH_DEVICE_MAP"): + device_map = os.environ["TORCH_DEVICE_MAP"] + else: + device_map = str(device) + if weight_map: layer_weight_map = {} non_layer_weights = [] @@ -95,7 +117,8 @@ def load_model( print(f"Loading sharded AutoModelForCausalLM from {model_path}") shard_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=model_path, - device_map="cuda", + device_map=device_map, + dtype=dtype, offload_buffers=True, local_files_only=True, num_hidden_layers=shard_num_hidden_layers From 476b6babbad113c6b41de8e3cf0469b4e7f71e21 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 15:49:02 -0800 Subject: [PATCH 434/491] automodel fix --- exo/inference/torch/model/hf.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 78e457c8..038fb6d1 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -83,7 +83,7 @@ def __init__( try: if weight_map: print("loading shard model") - self.llm_model_config = self.load_sharded_model( + self.llm_model = self.load_sharded_model( shard, weight_map, offload_buffers=self.offload_buffers @@ -93,8 +93,6 @@ def __init__( # this is needed because shard downloader just # appends and not redownloads the file os.remove(self.model_safetensors_path) - - self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) else: print("loading full model") self.llm_model = AutoModelForCausalLM.from_pretrained( @@ -104,8 +102,6 @@ def __init__( offload_buffers=True ).to(self.device) - - self.model = self.llm_model.model.to(self.device) except Exception as err: print(f"error loading and splitting model: {err}") @@ -161,10 +157,11 @@ def load_sharded_model( return AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, device_map=self.device_map, + torch_dtype=self.dtype, offload_buffers=offload_buffers, local_files_only=True, num_hidden_layers=shard_num_hidden_layers - ) + ).to(self.device) except Exception as err: print(f"err: {err}") raise From f7e02e9edbf912044b74d632d347808c2a43ca1b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 15:53:42 -0800 Subject: [PATCH 435/491] fixing split model test --- exo/inference/torch/tests/test_split_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 4717311c..68e9e95e 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -118,7 +118,7 @@ def load_model( shard_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=model_path, device_map=device_map, - dtype=dtype, + torch_dtype=dtype, offload_buffers=True, local_files_only=True, num_hidden_layers=shard_num_hidden_layers From bd6322f870fbf9303844c635d4c54e049d745c5b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 16:00:34 -0800 Subject: [PATCH 436/491] pytorch offload buffers error --- exo/inference/torch/tests/test_split_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 68e9e95e..bfdbfb49 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -119,7 +119,7 @@ def load_model( pretrained_model_name_or_path=model_path, device_map=device_map, torch_dtype=dtype, - offload_buffers=True, + offload_buffers=False, local_files_only=True, num_hidden_layers=shard_num_hidden_layers ).to(device) From c51bd916716ca708cde421e6c0a2bbe9022adacb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 16:08:37 -0800 Subject: [PATCH 437/491] device_map any issue with split model --- exo/inference/torch/tests/test_split_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index bfdbfb49..25a49538 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -119,10 +119,10 @@ def load_model( pretrained_model_name_or_path=model_path, device_map=device_map, torch_dtype=dtype, - offload_buffers=False, + offload_buffers=True, local_files_only=True, num_hidden_layers=shard_num_hidden_layers - ).to(device) + ) print("Loading tokenizer") tokenizer = AutoTokenizer.from_pretrained( @@ -218,4 +218,4 @@ async def test_split_model( 32 )) except Exception as err: - print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") + print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") From 4a2aef40e375b920622bb9198ce640a698121383 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 16:40:31 -0800 Subject: [PATCH 438/491] updating split model test --- exo/inference/torch/model/hf.py | 12 ++++++--- exo/inference/torch/tests/test_split_model.py | 25 ++++++++----------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 038fb6d1..1850469b 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -11,7 +11,6 @@ from exo.inference.torch.utils import extract_layers from transformers import ( - AutoConfig, AutoModelForCausalLM, DynamicCache, Cache, @@ -154,14 +153,21 @@ def load_sharded_model( shard_num_hidden_layers = shard.end_layer - shard.start_layer if DEBUG >= 4: print(f"config with {shard_num_hidden_layers} layers") - return AutoModelForCausalLM.from_pretrained( + + llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, device_map=self.device_map, torch_dtype=self.dtype, offload_buffers=offload_buffers, local_files_only=True, num_hidden_layers=shard_num_hidden_layers - ).to(self.device) + ) + + if self.device_map == "auto": + return llm_model + else: + return llm_model.to(self.device) + except Exception as err: print(f"err: {err}") raise diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 25a49538..183406a5 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -17,14 +17,13 @@ from exo.inference.shard import Shard from exo.inference.torch.utils import print_cuda_vram_stats -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer def load_model( - repo_id: str, shard: Shard, model_path: Path, weight_map: Optional[dict], - device: Optional[str] = "cuda" + device: Optional[torch.device] = torch.device("cpu") ) -> Optional[AutoModelForCausalLM]: """ load model by layer and safetensors @@ -34,16 +33,12 @@ def load_model( print("load_model called") model_st_snapshot = model_path/"model.safetensors.index.json" - if device: - device = device - elif os.environ.get("TORCH_DEVICE"): + if os.environ.get("TORCH_DEVICE"): device = torch.device(os.environ["TORCH_DEVICE"]) elif torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): device = torch.device("mps") - else: - device = torch.device("cpu") torch.set_default_device(device) @@ -122,7 +117,7 @@ def load_model( offload_buffers=True, local_files_only=True, num_hidden_layers=shard_num_hidden_layers - ) + ).to(device) print("Loading tokenizer") tokenizer = AutoTokenizer.from_pretrained( @@ -159,8 +154,6 @@ def load_model( print(f"Prompt: {prompt}\n") print(f"Response: {response}\n") - print_ram_stats() - # have to clear out edited model safetensors mst_json os.remove(model_st_snapshot) @@ -189,13 +182,15 @@ async def test_split_model( weight_map = await get_weight_map(model_id) load_model( - model_id, shard, model_path, weight_map ) if __name__ == "__main__": + n_layers = int(os.environ["N_LAYERS"]) if os.environ.get("N_LAYERS") else 32 + start_layer = int(os.environ["START_LAYER"]) if os.environ.get("START_LAYER") else 0 + end_layer = int(os.environ["END_LAYER"]) if os.environ.get("END_LAYER") else int(n_layers/2) #Qwen/Qwen2.5-3B #try: # print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n") @@ -213,9 +208,9 @@ async def test_split_model( print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") asyncio.run(test_split_model( "unsloth/Meta-Llama-3.1-8B-Instruct", - 0, - 6, - 32 + start_layer, + end_layer, + n_layers )) except Exception as err: print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") From 79f0763d59fcddb20ce024e85f8a8f151621eb2e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 17:10:58 -0800 Subject: [PATCH 439/491] fixing split model issue --- exo/inference/torch/model/hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 1850469b..eb6957c3 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -86,7 +86,7 @@ def __init__( shard, weight_map, offload_buffers=self.offload_buffers - ) + ).to(self.device) # clear out edited safetensor json # this is needed because shard downloader just @@ -311,7 +311,7 @@ def forward( # shard is last layer says true at the start and not detecting last layer correctly if self.shard.is_last_layer(): self.hidden_states = self.model.norm(self.hidden_states) - if use_legacy_cache: + if use_legacy_cache and self.next_decoder_cache is not None: self.past_key_values = self.next_decoder_cache.to_legacy_cache() else: self.past_key_values = self.next_decoder_cache From cbbc9cf1aeeb9fdee496ce4ea4da2d67d44144dd Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 17:16:31 -0800 Subject: [PATCH 440/491] fixing node issues --- exo/inference/torch/inference.py | 1 + exo/inference/torch/model/hf.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index 63f29284..5459664c 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -91,6 +91,7 @@ def infer_caching( cached_iids = {"input_ids": past_iids.tolist()} if DEBUG >= 4: + print(f"cached_iids len: {len(cached_iids)}") print(f"cached_iids: {cached_iids}") return (past_iids, cached_iids) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index eb6957c3..c7153b5d 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -98,7 +98,7 @@ def __init__( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, device_map=self.device_map, - offload_buffers=True + offload_buffers=offload_buffers ).to(self.device) self.model = self.llm_model.model.to(self.device) From 58cebabd85c8354e570d486a4b31f56b1ce8b1d3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 18:06:45 -0800 Subject: [PATCH 441/491] fixing node issues --- exo/inference/torch/inference.py | 4 ++++ exo/inference/torch/model/hf.py | 7 ++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index 5459664c..47ccc599 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -127,6 +127,10 @@ async def async_forward( attention_mask=attention_mask )) + if DEBUG >=4 : + print("async_forward") + print(f"result: {result}") + return result async def async_logit_sample( diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index c7153b5d..2dab660c 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -86,7 +86,7 @@ def __init__( shard, weight_map, offload_buffers=self.offload_buffers - ).to(self.device) + ) # clear out edited safetensor json # this is needed because shard downloader just @@ -163,10 +163,7 @@ def load_sharded_model( num_hidden_layers=shard_num_hidden_layers ) - if self.device_map == "auto": - return llm_model - else: - return llm_model.to(self.device) + return llm_model.to(self.device) except Exception as err: print(f"err: {err}") From 7f9b1bb1833a8d05780ec3a2ce316bbe86662957 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 18:11:10 -0800 Subject: [PATCH 442/491] fixing node issues --- exo/inference/torch/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index 47ccc599..f89f2367 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -127,11 +127,11 @@ async def async_forward( attention_mask=attention_mask )) - if DEBUG >=4 : + if DEBUG >=4: print("async_forward") print(f"result: {result}") - return result + return result[0], result[1], result[2] async def async_logit_sample( self, From c3adec5bc2719ffe90e9a1984cb2436a2ce62b65 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 18:24:51 -0800 Subject: [PATCH 443/491] fixing node issues --- exo/inference/torch/inference.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index f89f2367..8c543850 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -204,10 +204,14 @@ async def infer_prompt( if DEBUG >= 4: print(f"past_input_ids: {self.past_input_ids}\n") - shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( input_ids=self.past_input_ids, attention_mask=input_attention_mask ) + #shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( + # input_ids=self.past_input_ids, + # attention_mask=input_attention_mask + #) if DEBUG >= 4: print(f"\nshard_hidden_states: {shard_hidden_states}\n") From c8e6acc49de461b7a3fa9c3699f1ae80e0cb840f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 18:35:39 -0800 Subject: [PATCH 444/491] fixing node issues --- exo/inference/torch/inference.py | 6 +--- exo/inference/torch/model/hf.py | 60 ++++++++++++++++---------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index 8c543850..f89f2367 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -204,14 +204,10 @@ async def infer_prompt( if DEBUG >= 4: print(f"past_input_ids: {self.past_input_ids}\n") - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( input_ids=self.past_input_ids, attention_mask=input_attention_mask ) - #shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward( - # input_ids=self.past_input_ids, - # attention_mask=input_attention_mask - #) if DEBUG >= 4: print(f"\nshard_hidden_states: {shard_hidden_states}\n") diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 2dab660c..21d5b345 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -226,41 +226,41 @@ def forward( print(f"position_ids: {self.position_ids}") print(f"past_key_values: {past_key_values}") - if self.hidden_states is None: - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, + #if self.hidden_states is None: + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, + self.inputs_embeds, + cache_position, + past_key_values, + False # dont out attentions + ) + + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions + self.position_ids ) - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( - self.inputs_embeds, - self.position_ids - ) - - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=self.position_ids, - cache_position=cache_position - ) + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=self.position_ids, + cache_position=cache_position + ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] - if DEBUG >= 4: - print(f"model_inputs: {model_inputs}") + if DEBUG >= 4: + print(f"model_inputs: {model_inputs}") # run through decoder layers layer_amt = range(self.shard.end_layer - self.shard.start_layer) From df028e2219b7dcda0cae7b983026a65788d790cc Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 18:51:19 -0800 Subject: [PATCH 445/491] fixing node issues, range issue --- exo/inference/torch/model/hf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 21d5b345..63614730 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -62,6 +62,7 @@ def __init__( self.position_ids = None self.causal_mask = None self.local_model_path = local_model_path + self.is_sharded_model = False # setup logit processors self.logits_processor = LogitsProcessorList([ @@ -88,6 +89,8 @@ def __init__( offload_buffers=self.offload_buffers ) + self.is_sharded_model = True + # clear out edited safetensor json # this is needed because shard downloader just # appends and not redownloads the file @@ -263,7 +266,10 @@ def forward( print(f"model_inputs: {model_inputs}") # run through decoder layers - layer_amt = range(self.shard.end_layer - self.shard.start_layer) + if self.is_sharded_model: + layer_amt = range(self.shard.end_layer - self.shard.start_layer) + else: + layer_amt = range(self.shard.start_layer, self.shard.end_layer) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") From e5a1939694a332004e30b57f0cfc29a3a3f8e7f3 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 16 Oct 2024 18:55:27 -0800 Subject: [PATCH 446/491] fixing node issues, range issue --- exo/inference/torch/model/hf.py | 56 ++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 63614730..2d543ce5 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -229,38 +229,38 @@ def forward( print(f"position_ids: {self.position_ids}") print(f"past_key_values: {past_key_values}") - #if self.hidden_states is None: - # casual mask and attention_mask - self.attention_mask = attention_mask - self.causal_mask = self.model._update_causal_mask( - None, - self.inputs_embeds, - cache_position, - past_key_values, - False # dont out attentions - ) - - # embed positions, some models require and some dont - if isinstance(self.model, LlamaModel): - self.position_embeddings = self.model.rotary_emb( + if self.hidden_states is None: + # casual mask and attention_mask + self.attention_mask = attention_mask + self.causal_mask = self.model._update_causal_mask( + None, self.inputs_embeds, - self.position_ids + cache_position, + past_key_values, + False # dont out attentions ) - # prepare inputs for decoder layers - model_inputs = self.llm_model.prepare_inputs_for_generation( - self.input_ids, - past_key_values=past_key_values, - attention_mask=self.attention_mask, - inputs_embeds=self.inputs_embeds, - position_ids=self.position_ids, - cache_position=cache_position - ) + # embed positions, some models require and some dont + if isinstance(self.model, LlamaModel): + self.position_embeddings = self.model.rotary_emb( + self.inputs_embeds, + self.position_ids + ) + + # prepare inputs for decoder layers + model_inputs = self.llm_model.prepare_inputs_for_generation( + self.input_ids, + past_key_values=past_key_values, + attention_mask=self.attention_mask, + inputs_embeds=self.inputs_embeds, + position_ids=self.position_ids, + cache_position=cache_position + ) - self.hidden_states = self.inputs_embeds - self.position_ids = model_inputs["position_ids"] - self.cache_position = model_inputs["cache_position"] - self.past_key_values = model_inputs["past_key_values"] + self.hidden_states = self.inputs_embeds + self.position_ids = model_inputs["position_ids"] + self.cache_position = model_inputs["cache_position"] + self.past_key_values = model_inputs["past_key_values"] if DEBUG >= 4: print(f"model_inputs: {model_inputs}") From d07b825cb8e953c339671ba92c3920a5bdcdb2a9 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 18 Oct 2024 12:06:29 -0800 Subject: [PATCH 447/491] adding num hidden layers manipulation for all models --- exo/inference/torch/model/hf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 70fee8e5..e0d459f8 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -104,7 +104,8 @@ def __init__( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, device_map=self.device_map, - offload_buffers=offload_buffers + offload_buffers=offload_buffers, + num_hidden_layers=int(shard.end_layer - shard.start_layer) ).to(self.device) self.model = self.llm_model.model.to(self.device) From a840e7fc3563677d55dac021fed5cae68d246a5f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 18 Oct 2024 12:08:43 -0800 Subject: [PATCH 448/491] updating to use shard_num_hidden_layers --- exo/inference/torch/model/hf.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index e0d459f8..2b925214 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -99,13 +99,14 @@ def __init__( self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) self.model = self.llm_model.model.to(self.device) else: - print("loading full model") + shard_num_hidden_layers = shard.end_layer - shard.start_layer + print(f"loading safetensor in {shard_num_hidden_layers} layer model") self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, device_map=self.device_map, offload_buffers=offload_buffers, - num_hidden_layers=int(shard.end_layer - shard.start_layer) + num_hidden_layers=shard_num_hidden_layers ).to(self.device) self.model = self.llm_model.model.to(self.device) From 52fa3f877fcb35a7d3dcc18aba3f374c583dab85 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 18 Oct 2024 12:31:30 -0800 Subject: [PATCH 449/491] adding in better layer manipulation --- exo/inference/torch/model/hf.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 2b925214..a8b6e50f 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -89,15 +89,12 @@ def __init__( offload_buffers=self.offload_buffers ) - self.is_sharded_model = True + # self.is_sharded_model = True # clear out edited safetensor json # this is needed because shard downloader just # appends and not redownloads the file os.remove(self.model_safetensors_path) - - self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device) - self.model = self.llm_model.model.to(self.device) else: shard_num_hidden_layers = shard.end_layer - shard.start_layer print(f"loading safetensor in {shard_num_hidden_layers} layer model") @@ -271,10 +268,10 @@ def forward( print(f"model_inputs: {model_inputs}") # run through decoder layers - if self.is_sharded_model: - layer_amt = range(self.shard.end_layer - self.shard.start_layer) - else: - layer_amt = range(self.shard.start_layer, self.shard.end_layer) + # if self.is_sharded_model: + layer_amt = range(self.shard.end_layer - self.shard.start_layer) + # else: + # layer_amt = range(self.shard.start_layer, self.shard.end_layer) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") From ec49e316c28c272ef71a0f7cfd83d4e4fc576e54 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 18 Oct 2024 16:53:45 -0800 Subject: [PATCH 450/491] adding in safe tensor sharding, generate model.safetensors.index.json from single safetensor, starting safetensor sharding test --- .../torch/model/hf_safe_tensor_shard.py | 89 +++++++++++++ .../torch/tests/test_safetensor_json.py | 120 ++++++++++++++++++ .../torch/tests/test_safetensor_shard.py | 3 + exo/inference/torch/tests/test_split_model.py | 2 - 4 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 exo/inference/torch/model/hf_safe_tensor_shard.py create mode 100644 exo/inference/torch/tests/test_safetensor_json.py create mode 100644 exo/inference/torch/tests/test_safetensor_shard.py diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py new file mode 100644 index 00000000..caa23e4a --- /dev/null +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -0,0 +1,89 @@ +""" +HuggingFace Safetensor Shard +Sharding of safetensors to only use weights of models needed +""" +import os +import shutil +from safetensors import safe_open +from safetensors.torch import save_file + +class HFSafeTensorShard: + def __init__(self, model_folder, start_layer, end_layer): + self.model_folder = model_folder + self.start_layer = start_layer + self.end_layer = end_layer + self.safetensor_path = self.get_safetensor_path() + self.backup_path = self.safetensor_path + ".backup" + + def get_safetensor_path(self): + try: + for file_name in os.listdir(self.model_folder): + if file_name.endswith(".safetensors"): + return os.path.join(self.model_folder, file_name) + raise FileNotFoundError("No safetensors file found in the provided model folder.") + except Exception as err: + print(f"Error in get_safetensor_path: {err}") + raise + + def backup_safetensor(self): + try: + if not os.path.exists(self.backup_path): + shutil.copy(self.safetensor_path, self.backup_path) + print(f"Backup created at {self.backup_path}") + else: + print("Backup already exists. Skipping backup.") + except Exception as err: + print(f"Error in backup_safetensor: {err}") + raise + + def modify_safetensor(self): + # Ensure the safetensor is backed up before modifying + self.backup_safetensor() + + try: + with safe_open(self.safetensor_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + new_tensors = {} + + # Iterate over tensors, including only those within the specified layer range + for key in f.keys(): + layer_number = self.extract_layer_number(key) + if self.start_layer <= layer_number <= self.end_layer: + new_tensors[key] = f.get_tensor(key) + else: + print(f"Excluding layer {layer_number}: {key}") + + # Save the modified safetensor + save_file(new_tensors, self.safetensor_path, metadata) + print(f"Safetensor modified and saved to {self.safetensor_path}") + except Exception as err: + print(f"Error modifying safetensor: {err}") + raise + + def extract_layer_number(self, key): + """ + Extract the layer number from a tensor key. + This function assumes keys follow the format 'transformer.h..'. + """ + try: + parts = key.split(".") + layer_idx = next(i for i, part in enumerate(parts) if part.startswith("h")) + return int(parts[layer_idx + 1]) + except (IndexError, ValueError) as err: + print(f"Error extracting layer number from key '{key}': {err}") + return -1 + + def restore_backup(self): + """ + Restore the original safetensor from the backup file. + This is useful when you want to reset to the original before making new modifications. + """ + try: + if os.path.exists(self.backup_path): + shutil.copy(self.backup_path, self.safetensor_path) + print(f"Safetensor restored from backup at {self.backup_path}") + else: + print("No backup found. Cannot restore.") + except Exception as err: + print(f"Error in restore_backup: {err}") + raise diff --git a/exo/inference/torch/tests/test_safetensor_json.py b/exo/inference/torch/tests/test_safetensor_json.py new file mode 100644 index 00000000..3ec02c71 --- /dev/null +++ b/exo/inference/torch/tests/test_safetensor_json.py @@ -0,0 +1,120 @@ +""" +Create a model.safetensors.index.json from safetensors +""" +import json +import os + +import asyncio + +from safetensors import safe_open + +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard + +import torch + +def create_safetensor_index(safetensor_files: list, index_file: str): + """ + Creates a model.safetensors.index.json file from a list of safetensor files. + + Args: + safetensor_files (list): List of paths to the safetensor files. + index_file (str): Path where the index JSON file should be saved. + + Raises: + ValueError: If an unsupported data type is encountered. + """ + if safetensor_files: + # Initialize the metadata and weight_map + metadata = { + "metadata": { + "total_size": 0 + }, + "weight_map": {} + } + + for safetensor_file in safetensor_files: + # Use the safetensor file name as the shard_name + shard_name = os.path.basename(safetensor_file) + + # Open the safetensor file to read the metadata + with safe_open(safetensor_file, framework="pt") as f: + # Get tensor names + tensor_names = f.keys() + + # Collect metadata for each tensor + for name in tensor_names: + tensor_data = f.get_tensor(name) + print(f"tensor_data: {tensor_data}") + shape = tensor_data.shape + dtype = tensor_data.dtype + print(f"shape: {shape}") + print(f"dtype: {str(dtype) == "torch.bfloat16"}") + + # Calculate the tensor size in bytes based on dtype + total_elements = 1 + for dim in shape: + total_elements *= dim + + if dtype == torch.float32: + element_size = 4 + elif dtype == torch.float16 or dtype == torch.bfloat16: + element_size = 2 + # Extend this to support more data types if needed + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + tensor_size = total_elements * element_size + metadata["metadata"]["total_size"] += tensor_size + + # Add to weight_map, mapping the tensor to the shard (file) name + metadata["weight_map"][name] = shard_name + + # Write the metadata and weight map to the index file + with open(index_file, "w") as f: + json.dump(metadata, f, indent=4) + + print(f"Index file created: {index_file}") + else: + print("No safetensor files provided.") + + +async def main(): + """ + Main asynchronous function to download the model shard and create an index file for safetensors. + + This function downloads a model shard from Hugging Face, identifies safetensor files, and + generates a corresponding index file using the `create_safetensor_index` function. + """ + start_layer = 3 + end_layer = 5 + + # Create a Shard object + shard = Shard( + model_id="meta-llama/Llama-3.2-1B-Instruct", + start_layer=start_layer, + end_layer=end_layer-1, + n_layers=32 + ) + + print(f"Loading shard: {shard}") + shard_downloader = HFShardDownloader() + + # Ensure shard is downloaded + model_path = await shard_downloader.ensure_shard(shard) + + # Collect all safetensor files from the model path + safetensor_files = [ + os.path.join(model_path, file_name) + for file_name in os.listdir(model_path) if file_name.endswith(".safetensors") + ] + + # Create the index file + if safetensor_files: + create_safetensor_index(safetensor_files, os.path.join(model_path, "model.safetensors.index.json")) + else: + print("No safetensor files found in the model path.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/exo/inference/torch/tests/test_safetensor_shard.py b/exo/inference/torch/tests/test_safetensor_shard.py new file mode 100644 index 00000000..d18e3a95 --- /dev/null +++ b/exo/inference/torch/tests/test_safetensor_shard.py @@ -0,0 +1,3 @@ +""" +Sharding safetensor +""" \ No newline at end of file diff --git a/exo/inference/torch/tests/test_split_model.py b/exo/inference/torch/tests/test_split_model.py index 183406a5..197a7c07 100644 --- a/exo/inference/torch/tests/test_split_model.py +++ b/exo/inference/torch/tests/test_split_model.py @@ -10,8 +10,6 @@ import torch -from transformers.modeling_utils import offload_weight - from exo.download.hf.hf_helpers import get_weight_map from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.shard import Shard From f45b51444bb9bf6c1e7bf2fd853dc4895c8c9e79 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 19 Oct 2024 04:51:29 -0800 Subject: [PATCH 451/491] implementing sharding tests, fixing bugs with safetensor recompile --- exo/inference/torch/model/hf.py | 23 +-- .../torch/model/hf_safe_tensor_shard.py | 190 ++++++++++++++---- .../torch/tests/test_safetensor_shard.py | 42 +++- 3 files changed, 199 insertions(+), 56 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index a8b6e50f..1850469b 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -62,7 +62,6 @@ def __init__( self.position_ids = None self.causal_mask = None self.local_model_path = local_model_path - self.is_sharded_model = False # setup logit processors self.logits_processor = LogitsProcessorList([ @@ -89,21 +88,17 @@ def __init__( offload_buffers=self.offload_buffers ) - # self.is_sharded_model = True - # clear out edited safetensor json # this is needed because shard downloader just # appends and not redownloads the file os.remove(self.model_safetensors_path) else: - shard_num_hidden_layers = shard.end_layer - shard.start_layer - print(f"loading safetensor in {shard_num_hidden_layers} layer model") + print("loading full model") self.llm_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.dtype, device_map=self.device_map, - offload_buffers=offload_buffers, - num_hidden_layers=shard_num_hidden_layers + offload_buffers=True ).to(self.device) self.model = self.llm_model.model.to(self.device) @@ -168,7 +163,10 @@ def load_sharded_model( num_hidden_layers=shard_num_hidden_layers ) - return llm_model.to(self.device) + if self.device_map == "auto": + return llm_model + else: + return llm_model.to(self.device) except Exception as err: print(f"err: {err}") @@ -264,14 +262,11 @@ def forward( self.cache_position = model_inputs["cache_position"] self.past_key_values = model_inputs["past_key_values"] - if DEBUG >= 4: - print(f"model_inputs: {model_inputs}") + if DEBUG >= 4: + print(f"model_inputs: {model_inputs}") # run through decoder layers - # if self.is_sharded_model: layer_amt = range(self.shard.end_layer - self.shard.start_layer) - # else: - # layer_amt = range(self.shard.start_layer, self.shard.end_layer) if DEBUG >= 4: print(f"hidden_states: {self.hidden_states}") @@ -316,7 +311,7 @@ def forward( # shard is last layer says true at the start and not detecting last layer correctly if self.shard.is_last_layer(): self.hidden_states = self.model.norm(self.hidden_states) - if use_legacy_cache and self.next_decoder_cache is not None: + if use_legacy_cache: self.past_key_values = self.next_decoder_cache.to_legacy_cache() else: self.past_key_values = self.next_decoder_cache diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index caa23e4a..e0737292 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -4,58 +4,72 @@ """ import os import shutil +import json + +from pathlib import Path + from safetensors import safe_open from safetensors.torch import save_file +import torch + +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from exo.inference.torch.utils import extract_layers + class HFSafeTensorShard: - def __init__(self, model_folder, start_layer, end_layer): - self.model_folder = model_folder - self.start_layer = start_layer - self.end_layer = end_layer - self.safetensor_path = self.get_safetensor_path() - self.backup_path = self.safetensor_path + ".backup" - - def get_safetensor_path(self): + def __init__(self, model_path: Path, shard: Shard): + self.model_path = model_path + self.shard = shard + self.safetensors_path = self.get_safetensors() + + def get_safetensors(self) -> list: + safetensors_path = [] try: - for file_name in os.listdir(self.model_folder): + for file_name in os.listdir(self.model_path): if file_name.endswith(".safetensors"): - return os.path.join(self.model_folder, file_name) - raise FileNotFoundError("No safetensors file found in the provided model folder.") + safetensors_path.append(os.path.join(self.model_path, file_name)) except Exception as err: print(f"Error in get_safetensor_path: {err}") raise + return safetensors_path + def backup_safetensor(self): try: - if not os.path.exists(self.backup_path): - shutil.copy(self.safetensor_path, self.backup_path) - print(f"Backup created at {self.backup_path}") - else: - print("Backup already exists. Skipping backup.") + for safetensor_path in self.safetensors_path: + backup_path = safetensor_path+".backup" + if not os.path.exists(backup_path): + shutil.copy(safetensor_path, backup_path) + print(f"Backup created at {backup_path}") + else: + print("Backup already exists. Skipping backup.") except Exception as err: print(f"Error in backup_safetensor: {err}") raise def modify_safetensor(self): # Ensure the safetensor is backed up before modifying - self.backup_safetensor() - try: - with safe_open(self.safetensor_path, framework="pt", device="cpu") as f: - metadata = f.metadata() - new_tensors = {} - - # Iterate over tensors, including only those within the specified layer range - for key in f.keys(): - layer_number = self.extract_layer_number(key) - if self.start_layer <= layer_number <= self.end_layer: - new_tensors[key] = f.get_tensor(key) - else: - print(f"Excluding layer {layer_number}: {key}") - - # Save the modified safetensor - save_file(new_tensors, self.safetensor_path, metadata) - print(f"Safetensor modified and saved to {self.safetensor_path}") + self.backup_safetensor() + + for safetensor_path in self.safetensors_path: + with safe_open(safetensor_path, framework="pt") as f: + metadata = f.metadata() + new_tensors = {} + + # Iterate over tensors, including only those within the specified layer range + print(f"\n{f.keys()}\n") + for key in f.keys(): + layer_number = self.extract_layer_number(key) + if self.shard.start_layer <= layer_number <= self.shard.end_layer: + if DEBUG >= 4: + print(f"modify_safetensor [{layer_number}] extracting {key}") + new_tensors[key] = f.get_tensor(key) + + # Save the modified safetensor + save_file(new_tensors, safetensor_path, metadata) + print(f"Safetensor modified and saved to {safetensor_path}") except Exception as err: print(f"Error modifying safetensor: {err}") raise @@ -63,12 +77,16 @@ def modify_safetensor(self): def extract_layer_number(self, key): """ Extract the layer number from a tensor key. - This function assumes keys follow the format 'transformer.h..'. + This function assumes keys follow the format 'model.layers..'. """ try: parts = key.split(".") - layer_idx = next(i for i, part in enumerate(parts) if part.startswith("h")) - return int(parts[layer_idx + 1]) + layer_idx = 0 + if parts[0] == "model" and parts[1] == "layers": + layer_idx = int(parts[2]) + return layer_idx + #layer_idx = next(i for i, part in enumerate(parts) if part.startswith("h")) + #return int(parts[layer_idx + 1]) except (IndexError, ValueError) as err: print(f"Error extracting layer number from key '{key}': {err}") return -1 @@ -79,11 +97,101 @@ def restore_backup(self): This is useful when you want to reset to the original before making new modifications. """ try: - if os.path.exists(self.backup_path): - shutil.copy(self.backup_path, self.safetensor_path) - print(f"Safetensor restored from backup at {self.backup_path}") - else: - print("No backup found. Cannot restore.") + for safetensor_path in self.safetensors_path: + backup_path = safetensor_path+".backup" + if os.path.exists(backup_path): + shutil.copy(backup_path, safetensor_path) + print(f"Safetensor restored from backup at {backup_path}") + else: + print("No backup found. Cannot restore.") except Exception as err: print(f"Error in restore_backup: {err}") raise + + def create_safetensor_index(self): + """ + Creates a model.safetensors.index.json file from a list of safetensor files. + + Args: + + Raises: + """ + if self.safetensors_path: + # initialize the metadata and weight_map + metadata = { + "metadata": { + "total_size": 0 + }, + "weight_map": {} + } + + for safetensor_file in self.safetensors_path: + # use the safetensor file name as the shard_name + shard_name = os.path.basename(safetensor_file) + + # open the safetensor file to read the metadata + with safe_open(safetensor_file, framework="pt") as f: + # get tensor names + tensor_names = f.keys() + + # collect metadata for each tensor + for name in tensor_names: + tensor_data = f.get_tensor(name) + print(f"tensor_data: {tensor_data}") + shape = tensor_data.shape + dtype = tensor_data.dtype + print(f"shape: {shape}") + print(f"dtype: {str(dtype) == "torch.bfloat16"}") + + # calculate the tensor size in bytes based on dtype + total_elements = 1 + for dim in shape: + total_elements *= dim + + if dtype == torch.float32: + element_size = 4 + elif dtype == torch.float16 or dtype == torch.bfloat16: + element_size = 2 + # extend this to support more data types if needed + else: + raise ValueError(f"unsupported dtype: {dtype}") + + tensor_size = total_elements * element_size + metadata["metadata"]["total_size"] += tensor_size + + # add to weight_map, mapping the tensor to the shard (file) name + metadata["weight_map"][name] = shard_name + + # write the metadata and weight map to the index file + with open(f"{self.model_path}/model.safetensors.index.json", "w") as f: + json.dump(metadata, f, indent=4) + + print("model.safetensors.index.json created") + else: + print("No safetensor files provided.") + + def shard_safetensor_index(self, weight_map): + layer_weight_map = extract_layers( + weight_map, + self.shard + ) + + # rewrite model.safetensors.index.json for only needed layers + try: + mst_json = {} + for safetensor_path in self.safetensors_path: + with open(safetensor_path, "r") as mst_file: + mst_json = json.load(mst_file) + mst_json["weight_map"] = layer_weight_map + + if DEBUG >= 4: + print(f"rewritten safetensor index \n{json.dumps(mst_json, indent=4)}") + + os.remove(safetensor_path) + + with open(safetensor_path, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) + except Exception as err: + print(f"err: {err}") + raise + diff --git a/exo/inference/torch/tests/test_safetensor_shard.py b/exo/inference/torch/tests/test_safetensor_shard.py index d18e3a95..14821248 100644 --- a/exo/inference/torch/tests/test_safetensor_shard.py +++ b/exo/inference/torch/tests/test_safetensor_shard.py @@ -1,3 +1,43 @@ """ Sharding safetensor -""" \ No newline at end of file +""" + +import asyncio + +from exo.inference.shard import Shard +from exo.inference.torch.model.hf_safe_tensor_shard import HFSafeTensorShard +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.hf.hf_helpers import get_weight_map + + +async def main(): + start_layer = 3 + end_layer = 5 + + # Create a Shard object + shard = Shard( + model_id="meta-llama/Llama-3.2-1B-Instruct", + start_layer=start_layer, + end_layer=end_layer-1, + n_layers=32 + ) + + print(f"Loading shard: {shard}") + shard_downloader = HFShardDownloader() + + # Ensure shard is downloaded + model_path = await shard_downloader.ensure_shard(shard) + + # weight map, if any + model_wm = await get_weight_map( + repo_id=shard.model_id + ) + + tensor_shard = HFSafeTensorShard(model_path, shard) + tensor_shard.modify_safetensor() + tensor_shard.create_safetensor_index() + tensor_shard.restore_backup() + + +if __name__ == "__main__": + asyncio.run(main()) From f90c24a2b67c4aa68493e2be8470c7be304af716 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 20 Oct 2024 04:03:33 -0800 Subject: [PATCH 452/491] adding safetensor sharding, implementing it into model inference engine --- exo/inference/torch/model/hf.py | 118 +++++++----------- .../torch/model/hf_safe_tensor_shard.py | 102 +++++++++------ .../torch/tests/test_inference_engine.py | 2 +- .../torch/tests/test_safetensor_shard.py | 30 ++++- .../torch/tests/test_simple_model.py | 83 ++++++------ 5 files changed, 176 insertions(+), 159 deletions(-) diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/model/hf.py index 1850469b..f15d5d19 100644 --- a/exo/inference/torch/model/hf.py +++ b/exo/inference/torch/model/hf.py @@ -9,6 +9,7 @@ from exo.inference.shard import Shard from exo.helpers import DEBUG from exo.inference.torch.utils import extract_layers +from exo.inference.torch.model.hf_safe_tensor_shard import HFSafeTensorShard from transformers import ( AutoModelForCausalLM, @@ -52,17 +53,17 @@ def __init__( # class vars self.shard = shard - self.hidden_states = None - self.input_ids = None - self.inputs_embeds = None - self.attention_mask = None - self.position_embeddings = None - self.past_key_values = None - self.cache_position = None - self.position_ids = None - self.causal_mask = None self.local_model_path = local_model_path - + self.weight_map = weight_map + self.device = device + self.dtype = dtype + self.device_map = device_map + self.offload_buffers = offload_buffers + self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json" + self.safetensor_sharder = HFSafeTensorShard( + self.local_model_path, + self.shard + ) # setup logit processors self.logits_processor = LogitsProcessorList([ TopKLogitsWarper(top_k), @@ -70,87 +71,47 @@ def __init__( TopPLogitsWarper(top_p) ]) - self.device = device - self.dtype = dtype - self.device_map = device_map - - self.offload_buffers = offload_buffers - - self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json" - - # setup pytorch and transformer llm + # setup sharded llm try: - if weight_map: - print("loading shard model") - self.llm_model = self.load_sharded_model( - shard, - weight_map, - offload_buffers=self.offload_buffers - ) - - # clear out edited safetensor json - # this is needed because shard downloader just - # appends and not redownloads the file - os.remove(self.model_safetensors_path) - else: - print("loading full model") - self.llm_model = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=self.local_model_path, - torch_dtype=self.dtype, - device_map=self.device_map, - offload_buffers=True - ).to(self.device) - + self.llm_model = self.load_sharded_model() self.model = self.llm_model.model.to(self.device) + + # restore originals for next run, if one + self.safetensor_sharder.restore_backups() except Exception as err: - print(f"error loading and splitting model: {err}") + print(f"error loading and sharding model: {err}") raise - def load_sharded_model( - self, - shard: Shard, - weight_map: dict, - offload_buffers: bool - ) -> AutoModelForCausalLM: + # forward variables + self.hidden_states = None + self.input_ids = None + self.inputs_embeds = None + self.attention_mask = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None + self.causal_mask = None + + def load_sharded_model(self) -> AutoModelForCausalLM: """ Loads sharded version of model where only needed weights are loaded for necessary layers - Args: - Returns: + llm_model (AutoModelForCausalLM) - sharded llm model with only needed layers loaded """ if DEBUG >= 4: print("load_sharded_model called") - print(f"shard: {shard}") - - # break out layers per shard range - layer_weight_map = extract_layers( - weight_map, - shard - ) - # rewrite model.safetensors.index.json for only needed layers - try: - mst_json = {} - with open(self.model_safetensors_path, "r") as mst_file: - mst_json = json.load(mst_file) - mst_json["weight_map"] = layer_weight_map - - if DEBUG >= 4: - print(f"rewritten safetensor index \n{json.dumps(mst_json, indent=4)}") - - os.remove(self.model_safetensors_path) - - with open(self.model_safetensors_path, "w") as mst_file: - json.dump(mst_json, mst_file, indent=4) - except Exception as err: - print(f"err: {err}") - raise + # modify safetensor + self.safetensor_sharder.modify_safetensor() + self.safetensor_sharder.create_safetensor_index() + self.safetensor_sharder.shard_safetensor_index(self.weight_map) # load model try: - shard_num_hidden_layers = shard.end_layer - shard.start_layer + shard_num_hidden_layers = (self.shard.end_layer - self.shard.start_layer) + 1 if DEBUG >= 4: print(f"config with {shard_num_hidden_layers} layers") @@ -158,11 +119,16 @@ def load_sharded_model( pretrained_model_name_or_path=self.local_model_path, device_map=self.device_map, torch_dtype=self.dtype, - offload_buffers=offload_buffers, + offload_buffers=self.offload_buffers, local_files_only=True, - num_hidden_layers=shard_num_hidden_layers + num_hidden_layers=shard_num_hidden_layers, + use_safetensors=True, + low_cpu_mem_usage=True ) + # restore backup for next run + self.safetensor_sharder.restore_backups() + if self.device_map == "auto": return llm_model else: diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index e0737292..95162d3f 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -19,9 +19,10 @@ class HFSafeTensorShard: def __init__(self, model_path: Path, shard: Shard): - self.model_path = model_path + self.model_path = model_path self.shard = shard self.safetensors_path = self.get_safetensors() + self.safetensor_index_path = f"{self.model_path}/model.safetensors.index.json" def get_safetensors(self) -> list: safetensors_path = [] @@ -41,7 +42,7 @@ def backup_safetensor(self): backup_path = safetensor_path+".backup" if not os.path.exists(backup_path): shutil.copy(safetensor_path, backup_path) - print(f"Backup created at {backup_path}") + print(f"Backup created at {backup_path}") else: print("Backup already exists. Skipping backup.") except Exception as err: @@ -49,27 +50,40 @@ def backup_safetensor(self): raise def modify_safetensor(self): - # Ensure the safetensor is backed up before modifying + """ + Extract needed weights for layers from safetensor files + and create a new safetensor with same names + """ try: self.backup_safetensor() - + safetensor_is_used = False for safetensor_path in self.safetensors_path: + initial_size = os.path.getsize(safetensor_path) with safe_open(safetensor_path, framework="pt") as f: metadata = f.metadata() new_tensors = {} # Iterate over tensors, including only those within the specified layer range - print(f"\n{f.keys()}\n") for key in f.keys(): layer_number = self.extract_layer_number(key) if self.shard.start_layer <= layer_number <= self.shard.end_layer: if DEBUG >= 4: print(f"modify_safetensor [{layer_number}] extracting {key}") new_tensors[key] = f.get_tensor(key) - + safetensor_is_used = True + # Save the modified safetensor - save_file(new_tensors, safetensor_path, metadata) - print(f"Safetensor modified and saved to {safetensor_path}") + if safetensor_is_used: + save_file(new_tensors, safetensor_path, metadata) + modified_size = os.path.getsize(safetensor_path) + + print(f"Safetensor modified and saved to {safetensor_path}") + print(f"Initial size: {initial_size / (1024**3):.2f} GB") + print(f"Modified size: {modified_size / (1024**3):.2f} GB") + else: + # remove unused safetensors + os.remove(safetensor_path) + print(f"Removed safetensor: {safetensor_path}") except Exception as err: print(f"Error modifying safetensor: {err}") raise @@ -91,23 +105,6 @@ def extract_layer_number(self, key): print(f"Error extracting layer number from key '{key}': {err}") return -1 - def restore_backup(self): - """ - Restore the original safetensor from the backup file. - This is useful when you want to reset to the original before making new modifications. - """ - try: - for safetensor_path in self.safetensors_path: - backup_path = safetensor_path+".backup" - if os.path.exists(backup_path): - shutil.copy(backup_path, safetensor_path) - print(f"Safetensor restored from backup at {backup_path}") - else: - print("No backup found. Cannot restore.") - except Exception as err: - print(f"Error in restore_backup: {err}") - raise - def create_safetensor_index(self): """ Creates a model.safetensors.index.json file from a list of safetensor files. @@ -116,6 +113,12 @@ def create_safetensor_index(self): Raises: """ + if os.path.exists(self.safetensor_index_path): + backup_index_path = f"{self.model_path}/model.safetensors.index.json.backup" + if not os.path.exists(backup_index_path): + shutil.copy(self.safetensor_index_path, backup_index_path) + print(f"backed up index json {self.safetensor_index_path}") + if self.safetensors_path: # initialize the metadata and weight_map metadata = { @@ -130,18 +133,15 @@ def create_safetensor_index(self): shard_name = os.path.basename(safetensor_file) # open the safetensor file to read the metadata - with safe_open(safetensor_file, framework="pt") as f: + with safe_open(safetensor_file, framework="pt", device="cpu") as f: # get tensor names tensor_names = f.keys() # collect metadata for each tensor for name in tensor_names: tensor_data = f.get_tensor(name) - print(f"tensor_data: {tensor_data}") shape = tensor_data.shape dtype = tensor_data.dtype - print(f"shape: {shape}") - print(f"dtype: {str(dtype) == "torch.bfloat16"}") # calculate the tensor size in bytes based on dtype total_elements = 1 @@ -163,7 +163,7 @@ def create_safetensor_index(self): metadata["weight_map"][name] = shard_name # write the metadata and weight map to the index file - with open(f"{self.model_path}/model.safetensors.index.json", "w") as f: + with open(self.safetensor_index_path, "w") as f: json.dump(metadata, f, indent=4) print("model.safetensors.index.json created") @@ -179,19 +179,41 @@ def shard_safetensor_index(self, weight_map): # rewrite model.safetensors.index.json for only needed layers try: mst_json = {} - for safetensor_path in self.safetensors_path: - with open(safetensor_path, "r") as mst_file: - mst_json = json.load(mst_file) - mst_json["weight_map"] = layer_weight_map + with open(self.safetensor_index_path, "r") as mst_file: + mst_json = json.load(mst_file) + mst_json["weight_map"] = layer_weight_map - if DEBUG >= 4: - print(f"rewritten safetensor index \n{json.dumps(mst_json, indent=4)}") + if DEBUG >= 4: + print(f"new safetensor index\n{json.dumps(mst_json, indent=4)}\n") - os.remove(safetensor_path) + os.remove(self.safetensor_index_path) - with open(safetensor_path, "w") as mst_file: - json.dump(mst_json, mst_file, indent=4) + with open(self.safetensor_index_path, "w") as mst_file: + json.dump(mst_json, mst_file, indent=4) except Exception as err: print(f"err: {err}") raise - + + def restore_backups(self): + """ + Restore the original safetensor and index json, if any, from the backup file. + """ + try: + for safetensor_path in self.safetensors_path: + backup_path = safetensor_path+".backup" + if os.path.exists(backup_path): + shutil.copy(backup_path, safetensor_path) + print(f"Safetensor restored from backup at {backup_path}") + else: + print("No backup found. Cannot restore.") + + backup_index_path = self.safetensor_index_path+".backup" + if os.path.exists(backup_index_path): + shutil.copy(backup_index_path, self.safetensor_index_path) + print(f"Safetensor index JSON restored from backup at {backup_index_path}") + else: + print("No backup found. Cannot restore") + except Exception as err: + print(f"Error in restore_backup: {err}") + raise + diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index 2d24c8b2..1594551c 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -22,7 +22,7 @@ async def test_inference_engine( shard = Shard( model_id=model_id, start_layer=0, - end_layer=n_layers-1, + end_layer=0, n_layers=n_layers ) diff --git a/exo/inference/torch/tests/test_safetensor_shard.py b/exo/inference/torch/tests/test_safetensor_shard.py index 14821248..88c12ec0 100644 --- a/exo/inference/torch/tests/test_safetensor_shard.py +++ b/exo/inference/torch/tests/test_safetensor_shard.py @@ -9,10 +9,11 @@ from exo.download.hf.hf_shard_download import HFShardDownloader from exo.download.hf.hf_helpers import get_weight_map +from transformers import AutoModelForCausalLM, AutoTokenizer async def main(): - start_layer = 3 - end_layer = 5 + start_layer = 0 + end_layer = 1 # Create a Shard object shard = Shard( @@ -36,8 +37,31 @@ async def main(): tensor_shard = HFSafeTensorShard(model_path, shard) tensor_shard.modify_safetensor() tensor_shard.create_safetensor_index() - tensor_shard.restore_backup() + # load model and test + model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=shard.model_id, + local_files_only=True, + num_hidden_layers=shard.end_layer - shard.start_layer + ).to("cuda") + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "In one simple word, what is the color of a red apple?"} + ] + + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = tokenizer([text], return_tensors="pt") + + print(f"model_inputs:\n{model_inputs}") + + tensor_shard.restore_backup() if __name__ == "__main__": asyncio.run(main()) diff --git a/exo/inference/torch/tests/test_simple_model.py b/exo/inference/torch/tests/test_simple_model.py index 2a36717f..5ffd30ef 100644 --- a/exo/inference/torch/tests/test_simple_model.py +++ b/exo/inference/torch/tests/test_simple_model.py @@ -4,42 +4,47 @@ """ from transformers import AutoModelForCausalLM, AutoTokenizer -model = AutoModelForCausalLM.from_pretrained( - "Qwen/Qwen2-0.5B-Instruct", - torch_dtype="auto", - device_map="auto" -) -tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - -prompt = "In a single word only, what is the last name of the current president of the USA?" - -messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} -] -text = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True -) -model_inputs = tokenizer([text], return_tensors="pt") - -print(f"model_inputs:\n{model_inputs}") - -print(f"generation_config:\n{model.generation_config}") - -generated_ids = model.generate( - model_inputs.input_ids, - attention_mask=model_inputs.attention_mask, - max_new_tokens=512, - do_sample=True -) - -generated_ids = [ - output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) -] - -response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] - -print(f"Prompt: {prompt}\n") -print(f"Response: {response}\n") +def run_simple(prompt: str): + model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen2-0.5B-Instruct", + torch_dtype="auto", + device_map="auto" + ) + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = tokenizer([text], return_tensors="pt") + + print(f"model_inputs:\n{model_inputs}") + + print(f"generation_config:\n{model.generation_config}") + + generated_ids = model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + max_new_tokens=512, + do_sample=True + ) + + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + print(f"Prompt: {prompt}\n") + print(f"Response: {response}\n") + +if __name__ == "__main__": + run_simple( + "In a single word only, what is the last name of the current president of the USA?" + ) From 696c264d45e36b832309bc9e61e08e2fcaf94ae1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 20 Oct 2024 04:05:12 -0800 Subject: [PATCH 453/491] updating backup and backup restore --- exo/inference/torch/model/hf_safe_tensor_shard.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index 95162d3f..052fafb0 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -202,6 +202,7 @@ def restore_backups(self): for safetensor_path in self.safetensors_path: backup_path = safetensor_path+".backup" if os.path.exists(backup_path): + os.remove(safetensor_path) shutil.copy(backup_path, safetensor_path) print(f"Safetensor restored from backup at {backup_path}") else: @@ -209,6 +210,7 @@ def restore_backups(self): backup_index_path = self.safetensor_index_path+".backup" if os.path.exists(backup_index_path): + os.remove(self.safetensor_index_path) shutil.copy(backup_index_path, self.safetensor_index_path) print(f"Safetensor index JSON restored from backup at {backup_index_path}") else: From 9514e922f73250c5df37ccd29f61ee88d37e9d8c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 20 Oct 2024 04:12:05 -0800 Subject: [PATCH 454/491] added removing backup when restoring --- exo/inference/torch/model/hf_safe_tensor_shard.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index 052fafb0..2dc6c7ff 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -204,6 +204,7 @@ def restore_backups(self): if os.path.exists(backup_path): os.remove(safetensor_path) shutil.copy(backup_path, safetensor_path) + os.remove(backup_path) print(f"Safetensor restored from backup at {backup_path}") else: print("No backup found. Cannot restore.") @@ -212,6 +213,7 @@ def restore_backups(self): if os.path.exists(backup_index_path): os.remove(self.safetensor_index_path) shutil.copy(backup_index_path, self.safetensor_index_path) + os.remove(backup_index_path) print(f"Safetensor index JSON restored from backup at {backup_index_path}") else: print("No backup found. Cannot restore") From d65505ee8b59c7d5b14ee301c9655b7cfcdefb97 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 20 Oct 2024 04:45:50 -0800 Subject: [PATCH 455/491] added generating weight map if none, did updates to backup and restore process for sharding HF safetensors --- .../torch/model/hf_safe_tensor_shard.py | 38 ++++++++++--------- .../torch/tests/test_inference_engine.py | 30 +++++++-------- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index 2dc6c7ff..537e73d5 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -6,6 +6,7 @@ import shutil import json +from typing import Optional from pathlib import Path from safetensors import safe_open @@ -23,13 +24,21 @@ def __init__(self, model_path: Path, shard: Shard): self.shard = shard self.safetensors_path = self.get_safetensors() self.safetensor_index_path = f"{self.model_path}/model.safetensors.index.json" + self.metadata = { + "metadata": { + "total_size": 0 + }, + "weight_map": {} + } def get_safetensors(self) -> list: safetensors_path = [] try: for file_name in os.listdir(self.model_path): if file_name.endswith(".safetensors"): - safetensors_path.append(os.path.join(self.model_path, file_name)) + safetensor_path = os.path.join(self.model_path, file_name) + if safetensor_path not in safetensors_path: + safetensors_path.append(safetensor_path) except Exception as err: print(f"Error in get_safetensor_path: {err}") raise @@ -42,7 +51,7 @@ def backup_safetensor(self): backup_path = safetensor_path+".backup" if not os.path.exists(backup_path): shutil.copy(safetensor_path, backup_path) - print(f"Backup created at {backup_path}") + print(f"Backup created at {backup_path}") else: print("Backup already exists. Skipping backup.") except Exception as err: @@ -120,14 +129,7 @@ def create_safetensor_index(self): print(f"backed up index json {self.safetensor_index_path}") if self.safetensors_path: - # initialize the metadata and weight_map - metadata = { - "metadata": { - "total_size": 0 - }, - "weight_map": {} - } - + # initialize the metadata and weight_map for safetensor_file in self.safetensors_path: # use the safetensor file name as the shard_name shard_name = os.path.basename(safetensor_file) @@ -157,20 +159,24 @@ def create_safetensor_index(self): raise ValueError(f"unsupported dtype: {dtype}") tensor_size = total_elements * element_size - metadata["metadata"]["total_size"] += tensor_size + self.metadata["metadata"]["total_size"] += tensor_size # add to weight_map, mapping the tensor to the shard (file) name - metadata["weight_map"][name] = shard_name + self.metadata["weight_map"][name] = shard_name # write the metadata and weight map to the index file with open(self.safetensor_index_path, "w") as f: - json.dump(metadata, f, indent=4) + json.dump(self.metadata, f, indent=4) print("model.safetensors.index.json created") else: print("No safetensor files provided.") - def shard_safetensor_index(self, weight_map): + def shard_safetensor_index(self, weight_map: Optional[dict] = None): + if weight_map is None: + weight_map = self.metadata["weight_map"] + + print(f"shard\n{weight_map}") layer_weight_map = extract_layers( weight_map, self.shard @@ -206,8 +212,6 @@ def restore_backups(self): shutil.copy(backup_path, safetensor_path) os.remove(backup_path) print(f"Safetensor restored from backup at {backup_path}") - else: - print("No backup found. Cannot restore.") backup_index_path = self.safetensor_index_path+".backup" if os.path.exists(backup_index_path): @@ -215,8 +219,6 @@ def restore_backups(self): shutil.copy(backup_index_path, self.safetensor_index_path) os.remove(backup_index_path) print(f"Safetensor index JSON restored from backup at {backup_index_path}") - else: - print("No backup found. Cannot restore") except Exception as err: print(f"Error in restore_backup: {err}") raise diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index 1594551c..2b72b859 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -116,26 +116,26 @@ async def test_inference_engine( assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - #try: - # print("\n\n -------- TEST Qwen/Qwen2.5-3B-Instruct -------- \n\n") - # asyncio.run(test_inference_engine( - # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # TorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Qwen/Qwen2.5-3B-Instruct", - # 36 - # )) - #except Exception as err: - # print(f"\n!!!! QWEN2 TEST FAILED \n{err}\n") - try: - print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + print("\n\n -------- TEST Qwen/Qwen2-0.5B-Instruct -------- \n\n") asyncio.run(test_inference_engine( TorchDynamicShardInferenceEngine(HFShardDownloader()), TorchDynamicShardInferenceEngine(HFShardDownloader()), - "unsloth/Meta-Llama-3.1-8B-Instruct", - 32 + "Qwen/Qwen2-0.5B-Instruct", + 36 )) except Exception as err: - print(f"\n!!!! unsloth/Meta-Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") + print(f"\n!!!! QWEN2 TEST FAILED \n{err}\n") + + #try: + # print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n") + # asyncio.run(test_inference_engine( + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # TorchDynamicShardInferenceEngine(HFShardDownloader()), + # "unsloth/Meta-Llama-3.1-8B-Instruct", + # 32 + # )) + #except Exception as err: + # print(f"\n!!!! unsloth/Meta-Llama-3.1-8B-Instruct TEST FAILED \n{err}\n") From d5b61131c79b2bab01685913e4838c1743b0b959 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 20 Oct 2024 04:56:50 -0800 Subject: [PATCH 456/491] cleaning up logging --- .../torch/model/hf_safe_tensor_shard.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index 537e73d5..34c9c411 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -51,9 +51,9 @@ def backup_safetensor(self): backup_path = safetensor_path+".backup" if not os.path.exists(backup_path): shutil.copy(safetensor_path, backup_path) - print(f"Backup created at {backup_path}") - else: - print("Backup already exists. Skipping backup.") + + if DEBUG >= 4: + print(f"Backup created at {backup_path}") except Exception as err: print(f"Error in backup_safetensor: {err}") raise @@ -80,19 +80,22 @@ def modify_safetensor(self): print(f"modify_safetensor [{layer_number}] extracting {key}") new_tensors[key] = f.get_tensor(key) safetensor_is_used = True - + # Save the modified safetensor if safetensor_is_used: save_file(new_tensors, safetensor_path, metadata) modified_size = os.path.getsize(safetensor_path) - print(f"Safetensor modified and saved to {safetensor_path}") - print(f"Initial size: {initial_size / (1024**3):.2f} GB") - print(f"Modified size: {modified_size / (1024**3):.2f} GB") + if DEBUG >= 4: + print(f"Safetensor modified and saved to {safetensor_path}") + print(f"Initial size: {initial_size / (1024**3):.2f} GB") + print(f"Modified size: {modified_size / (1024**3):.2f} GB") else: # remove unused safetensors os.remove(safetensor_path) - print(f"Removed safetensor: {safetensor_path}") + + if DEBUG >= 4: + print(f"Removed safetensor: {safetensor_path}") except Exception as err: print(f"Error modifying safetensor: {err}") raise @@ -126,7 +129,9 @@ def create_safetensor_index(self): backup_index_path = f"{self.model_path}/model.safetensors.index.json.backup" if not os.path.exists(backup_index_path): shutil.copy(self.safetensor_index_path, backup_index_path) - print(f"backed up index json {self.safetensor_index_path}") + + if DEBUG >= 4: + print(f"backed up index json {self.safetensor_index_path}") if self.safetensors_path: # initialize the metadata and weight_map @@ -168,7 +173,8 @@ def create_safetensor_index(self): with open(self.safetensor_index_path, "w") as f: json.dump(self.metadata, f, indent=4) - print("model.safetensors.index.json created") + if DEBUG >= 4: + print(f"created new {self.safetensor_index_path}") else: print("No safetensor files provided.") @@ -176,7 +182,6 @@ def shard_safetensor_index(self, weight_map: Optional[dict] = None): if weight_map is None: weight_map = self.metadata["weight_map"] - print(f"shard\n{weight_map}") layer_weight_map = extract_layers( weight_map, self.shard @@ -211,14 +216,18 @@ def restore_backups(self): os.remove(safetensor_path) shutil.copy(backup_path, safetensor_path) os.remove(backup_path) - print(f"Safetensor restored from backup at {backup_path}") + + if DEBUG >= 4: + print(f"Safetensor restored from backup at {backup_path}") backup_index_path = self.safetensor_index_path+".backup" if os.path.exists(backup_index_path): os.remove(self.safetensor_index_path) shutil.copy(backup_index_path, self.safetensor_index_path) os.remove(backup_index_path) - print(f"Safetensor index JSON restored from backup at {backup_index_path}") + + if DEBUG >= 4: + print(f"Safetensor index JSON restored from backup at {backup_index_path}") except Exception as err: print(f"Error in restore_backup: {err}") raise From d2302ccd424e072d84cab0ad9be88ff91b6f2a58 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 20 Oct 2024 05:04:51 -0800 Subject: [PATCH 457/491] updating docstring in newest class file --- .../torch/model/hf_safe_tensor_shard.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/model/hf_safe_tensor_shard.py index 34c9c411..c3afdea5 100644 --- a/exo/inference/torch/model/hf_safe_tensor_shard.py +++ b/exo/inference/torch/model/hf_safe_tensor_shard.py @@ -32,6 +32,12 @@ def __init__(self, model_path: Path, shard: Shard): } def get_safetensors(self) -> list: + """ + Gets a list of all files that have the extension .safetensors + + Return: + list: A list of all the safetensors file paths + """ safetensors_path = [] try: for file_name in os.listdir(self.model_path): @@ -120,10 +126,6 @@ def extract_layer_number(self, key): def create_safetensor_index(self): """ Creates a model.safetensors.index.json file from a list of safetensor files. - - Args: - - Raises: """ if os.path.exists(self.safetensor_index_path): backup_index_path = f"{self.model_path}/model.safetensors.index.json.backup" @@ -179,6 +181,13 @@ def create_safetensor_index(self): print("No safetensor files provided.") def shard_safetensor_index(self, weight_map: Optional[dict] = None): + """ + Modify the weight_map of the safetensors index json to only + get weights for the working layers + + Args: + weight_map(dict, Optional): holds which weight maps to which layer + """ if weight_map is None: weight_map = self.metadata["weight_map"] From 72fcf9bb7ebe7b78972a369a9312113638cf1a38 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 21 Oct 2024 13:15:32 -0800 Subject: [PATCH 458/491] starting write of llama3 model outside of transformers and using pytorch --- exo/inference/torch/model/llama3_tokenizer.py | 219 ++++++++++++++++++ exo/inference/torch/model/llm_utils.py | 23 ++ 2 files changed, 242 insertions(+) create mode 100644 exo/inference/torch/model/llama3_tokenizer.py create mode 100644 exo/inference/torch/model/llm_utils.py diff --git a/exo/inference/torch/model/llama3_tokenizer.py b/exo/inference/torch/model/llama3_tokenizer.py new file mode 100644 index 00000000..e595d4b3 --- /dev/null +++ b/exo/inference/torch/model/llama3_tokenizer.py @@ -0,0 +1,219 @@ +""" +Llama3 tokenizer from https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/tokenizer.py +""" +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + +import os +from logging import getLogger +from pathlib import Path +from typing import ( + AbstractSet, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Union, +) + +import tiktoken + +from tiktoken.load import load_tiktoken_bpe + +logger = getLogger(__name__) + + +# The tiktoken tokenizer can handle <=400k chars without +# pyo3_runtime.PanicException. +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +# https://github.com/openai/tiktoken/issues/195 +# Here we iterate over subsequences and split if we exceed the limit +# of max consecutive non-whitespace or whitespace characters. +MAX_NO_WHITESPACES_CHARS = 25_000 + + +_INSTANCE = None + + +class Tokenizer: + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + @classmethod + def get_instance(cls): + global _INSTANCE + + if _INSTANCE is None: + _INSTANCE = Tokenizer( + os.path.join(os.path.dirname(__file__), "tokenizer.model") + ) + return _INSTANCE + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + "<|image|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.eos_id, + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_special ("all"|set[str]): allowed special tokens in string + disallowed_special ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] diff --git a/exo/inference/torch/model/llm_utils.py b/exo/inference/torch/model/llm_utils.py new file mode 100644 index 00000000..0868207d --- /dev/null +++ b/exo/inference/torch/model/llm_utils.py @@ -0,0 +1,23 @@ +""" +Utility methods used by LLMs +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchtune.modules import RotaryPositionalEmbeddings + +def rope_embed( + self, + input_embeddings: torch.Tensor, + position_ids: torch.Tensor, +): + """ + Wrapper of rotary embeddings using pytorch module + + Args: + input_embeddings (torch.Tensor): token embeddings from input + position_ids (torch.Tensor): position ids of tokens + """ + rotary_emb = RotaryPositionalEmbeddings() + + From 9cac5ab706af117de2255a0bf515a1cf84a573f7 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 21 Oct 2024 13:38:04 -0800 Subject: [PATCH 459/491] moving llama3 modeling source code, updating readme file --- exo/inference/torch/README.md | 30 +++++++++++-------- exo/inference/torch/model/llama3/model.py | 5 ++++ .../tokenizer.py} | 0 exo/inference/torch/model/llm_utils.py | 24 +++++++++++---- 4 files changed, 42 insertions(+), 17 deletions(-) create mode 100644 exo/inference/torch/model/llama3/model.py rename exo/inference/torch/model/{llama3_tokenizer.py => llama3/tokenizer.py} (100%) diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md index 1beca27c..da67faa2 100644 --- a/exo/inference/torch/README.md +++ b/exo/inference/torch/README.md @@ -1,4 +1,21 @@ -# PyTorch & HuggingFace inference engine +# PyTorch inference engine + +## Devs +- [Vincent Castro](https://github.com/risingsunomi) + +## Notes/Issues +### 10/10/2024 +- To select a pytorch device via environment variables, set the variable TORCH_DEVICE + - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM + - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) + - Looking into adding mobile device support properly +- If device is not CPU the data type defaults to float32 else float16. + +### 10/13/2024 +Still working on split model development (see test_split_model.py). Right now, it seems to do it but still transformers is loading more in the RAM and GPU as it loads up a larger models (causing an OOM). Will research and add to next update. Right now, tests are added and are in development. + +### 10/21/2024 +Working on removing transformers due to inference and VRAM usage [issues](https://github.com/exo-explore/exo/pull/139#issuecomment-2424953962). Creating a pure pytorch implementation of llama3 as using transformers wont work for exo. Using some code from meta but also implementing the use of torchtune. ## Tech @@ -31,14 +48,3 @@ GPU 4: NVIDIA Quadro P400 2GB GPU 5: NVIDIA Quadro P400 2GB ``` - -## Notes/Issues -### 10/10/2024 -- To select a pytorch device via environment variables, set the variable TORCH_DEVICE - - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM - - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) - - Looking into adding mobile device support properly -- If device is not CPU the data type defaults to float32 else float16. - -### 10/13/2024 -Still working on split model development (see test_split_model.py). Right now, it seems to do it but still transformers is loading more in the RAM and GPU as it loads up a larger models (causing an OOM). Will research and add to next update. Right now, tests are added and are in development. diff --git a/exo/inference/torch/model/llama3/model.py b/exo/inference/torch/model/llama3/model.py new file mode 100644 index 00000000..d26c0c02 --- /dev/null +++ b/exo/inference/torch/model/llama3/model.py @@ -0,0 +1,5 @@ +""" +llama3 model + +Written with pytorch using torchtune and other methods +""" diff --git a/exo/inference/torch/model/llama3_tokenizer.py b/exo/inference/torch/model/llama3/tokenizer.py similarity index 100% rename from exo/inference/torch/model/llama3_tokenizer.py rename to exo/inference/torch/model/llama3/tokenizer.py diff --git a/exo/inference/torch/model/llm_utils.py b/exo/inference/torch/model/llm_utils.py index 0868207d..2b95e547 100644 --- a/exo/inference/torch/model/llm_utils.py +++ b/exo/inference/torch/model/llm_utils.py @@ -6,18 +6,32 @@ import torch.nn.functional as F from torchtune.modules import RotaryPositionalEmbeddings +from typing import Optional + def rope_embed( - self, + head_dim: int, input_embeddings: torch.Tensor, - position_ids: torch.Tensor, -): + position_ids: Optional[torch.Tensor], +) -> torch.Tensor: """ Wrapper of rotary embeddings using pytorch module Args: - input_embeddings (torch.Tensor): token embeddings from input + input_embeddings (torch.Tensor): token embeddings from input position_ids (torch.Tensor): position ids of tokens + + Returns: + torch.Tensor: output with RoPE applied """ - rotary_emb = RotaryPositionalEmbeddings() + try: + rotary_emb = RotaryPositionalEmbeddings(head_dim) + output = rotary_emb.forward( + input_embeddings, + input_pos=position_ids + ) + except Exception: + raise + + return output From 80120084722384ec1fe8318c73dea5d7a13fe0bb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 22 Oct 2024 17:32:17 -0800 Subject: [PATCH 460/491] adding pytorch based llama model, added testing and working through bugs --- exo/inference/torch/inference.py | 2 +- exo/inference/torch/model/llama3/model.py | 5 - exo/inference/torch/model/llama3/tokenizer.py | 219 -------------- exo/inference/torch/model/llm_utils.py | 37 --- .../torch/{model => models}/__init__.py | 0 exo/inference/torch/{model => models}/hf.py | 0 .../{model => models}/hf_safe_tensor_shard.py | 0 exo/inference/torch/models/llama3.py | 275 ++++++++++++++++++ exo/inference/torch/models/llm_utils.py | 92 ++++++ .../torch/tests/test_llama3_model.py | 129 ++++++++ .../torch/tests/test_safetensor_shard.py | 2 +- 11 files changed, 498 insertions(+), 263 deletions(-) delete mode 100644 exo/inference/torch/model/llama3/model.py delete mode 100644 exo/inference/torch/model/llama3/tokenizer.py delete mode 100644 exo/inference/torch/model/llm_utils.py rename exo/inference/torch/{model => models}/__init__.py (100%) rename exo/inference/torch/{model => models}/hf.py (100%) rename exo/inference/torch/{model => models}/hf_safe_tensor_shard.py (100%) create mode 100644 exo/inference/torch/models/llama3.py create mode 100644 exo/inference/torch/models/llm_utils.py create mode 100644 exo/inference/torch/tests/test_llama3_model.py diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/inference.py index f89f2367..23bbe814 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/inference.py @@ -12,7 +12,7 @@ from typing import Optional, Tuple, Union, List from exo.inference.shard import Shard from exo.inference.inference_engine import InferenceEngine -from exo.inference.torch.model.hf import ShardedHuggingFaceModel +from exo.inference.torch.models.hf import ShardedHuggingFaceModel from exo.inference.tokenizers import resolve_tokenizer from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader diff --git a/exo/inference/torch/model/llama3/model.py b/exo/inference/torch/model/llama3/model.py deleted file mode 100644 index d26c0c02..00000000 --- a/exo/inference/torch/model/llama3/model.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -llama3 model - -Written with pytorch using torchtune and other methods -""" diff --git a/exo/inference/torch/model/llama3/tokenizer.py b/exo/inference/torch/model/llama3/tokenizer.py deleted file mode 100644 index e595d4b3..00000000 --- a/exo/inference/torch/model/llama3/tokenizer.py +++ /dev/null @@ -1,219 +0,0 @@ -""" -Llama3 tokenizer from https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/tokenizer.py -""" -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - -import os -from logging import getLogger -from pathlib import Path -from typing import ( - AbstractSet, - cast, - Collection, - Dict, - Iterator, - List, - Literal, - Optional, - Sequence, - Union, -) - -import tiktoken - -from tiktoken.load import load_tiktoken_bpe - -logger = getLogger(__name__) - - -# The tiktoken tokenizer can handle <=400k chars without -# pyo3_runtime.PanicException. -TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - -# https://github.com/openai/tiktoken/issues/195 -# Here we iterate over subsequences and split if we exceed the limit -# of max consecutive non-whitespace or whitespace characters. -MAX_NO_WHITESPACES_CHARS = 25_000 - - -_INSTANCE = None - - -class Tokenizer: - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - @classmethod - def get_instance(cls): - global _INSTANCE - - if _INSTANCE is None: - _INSTANCE = Tokenizer( - os.path.join(os.path.dirname(__file__), "tokenizer.model") - ) - return _INSTANCE - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. - """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - "<|image|>", - ] - reserved_tokens = [ - f"<|reserved_special_token_{2 + i}|>" - for i in range(self.num_reserved_special_tokens - len(special_tokens)) - ] - special_tokens = special_tokens + reserved_tokens - - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - - self.n_words: int = num_base_tokens + len(special_tokens) - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.eot_id: int = self.special_tokens["<|eot_id|>"] - self.eom_id: int = self.special_tokens["<|eom_id|>"] - self.python_tag_id = self.special_tokens["<|python_tag|>"] - self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] - self.stop_tokens = [ - self.eos_id, - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], - ] - - def encode( - self, - s: str, - *, - bos: bool, - eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_special ("all"|set[str]): allowed special tokens in string - disallowed_special ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). - - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - if allowed_special is None: - allowed_special = set() - assert type(s) is str - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend( - self.model.encode( - substr, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. - return self.model.decode(cast(List[int], t)) - - @staticmethod - def _split_whitespaces_or_nonwhitespaces( - s: str, max_consecutive_slice_len: int - ) -> Iterator[str]: - """ - Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` - consecutive whitespaces or consecutive non-whitespaces. - """ - current_slice_len = 0 - current_slice_is_space = s[0].isspace() if len(s) > 0 else False - slice_start = 0 - - for i in range(len(s)): - is_now_space = s[i].isspace() - - if current_slice_is_space ^ is_now_space: - current_slice_len = 1 - current_slice_is_space = is_now_space - else: - current_slice_len += 1 - if current_slice_len > max_consecutive_slice_len: - yield s[slice_start:i] - slice_start = i - current_slice_len = 1 - yield s[slice_start:] diff --git a/exo/inference/torch/model/llm_utils.py b/exo/inference/torch/model/llm_utils.py deleted file mode 100644 index 2b95e547..00000000 --- a/exo/inference/torch/model/llm_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Utility methods used by LLMs -""" -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchtune.modules import RotaryPositionalEmbeddings - -from typing import Optional - -def rope_embed( - head_dim: int, - input_embeddings: torch.Tensor, - position_ids: Optional[torch.Tensor], -) -> torch.Tensor: - """ - Wrapper of rotary embeddings using pytorch module - - Args: - input_embeddings (torch.Tensor): token embeddings from input - position_ids (torch.Tensor): position ids of tokens - - Returns: - torch.Tensor: output with RoPE applied - """ - try: - rotary_emb = RotaryPositionalEmbeddings(head_dim) - output = rotary_emb.forward( - input_embeddings, - input_pos=position_ids - ) - except Exception: - raise - - return output - - diff --git a/exo/inference/torch/model/__init__.py b/exo/inference/torch/models/__init__.py similarity index 100% rename from exo/inference/torch/model/__init__.py rename to exo/inference/torch/models/__init__.py diff --git a/exo/inference/torch/model/hf.py b/exo/inference/torch/models/hf.py similarity index 100% rename from exo/inference/torch/model/hf.py rename to exo/inference/torch/models/hf.py diff --git a/exo/inference/torch/model/hf_safe_tensor_shard.py b/exo/inference/torch/models/hf_safe_tensor_shard.py similarity index 100% rename from exo/inference/torch/model/hf_safe_tensor_shard.py rename to exo/inference/torch/models/hf_safe_tensor_shard.py diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py new file mode 100644 index 00000000..c0af9ea3 --- /dev/null +++ b/exo/inference/torch/models/llama3.py @@ -0,0 +1,275 @@ +""" +llama3 model + +Written with pytorch using torchtune and other methods +""" +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torchtune.modules import MultiHeadAttention, RotaryPositionalEmbeddings, KVCache + +class LlamaBlock(nn.Module): + """ + Encoder block class for the LLaMA model without residual connections. + """ + def __init__( + self, + dim, + heads, + num_kv_heads, + head_dim, + ff_dim, + rms_norm_eps, + attention_dropout=0.0, + use_bias=False, + max_seq_len=4096, + pos_embeddings=None + ): + super(LlamaBlock, self).__init__() + + # Define linear projections for Q, K, V, and Output + self.q_proj = nn.Linear(dim, heads * head_dim, bias=use_bias) + self.k_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) + self.v_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) + self.output_proj = nn.Linear(heads * head_dim, dim, bias=use_bias) + + # Define optional query normalization + self.q_norm = nn.LayerNorm(head_dim, eps=rms_norm_eps) + + # MultiHeadAttention from torchtune + self.attn = MultiHeadAttention( + embed_dim=dim, + num_heads=heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=pos_embeddings, + q_norm=self.q_norm, + k_norm=self.q_norm, + kv_cache=None, + max_seq_len=max_seq_len, + is_causal=True, + attn_dropout=attention_dropout + ) + + # RMSNorm layers before and after attention and feed-forward layers + self.norm1 = nn.LayerNorm(dim, eps=rms_norm_eps) + self.norm2 = nn.LayerNorm(dim, eps=rms_norm_eps) + + # Feed-forward layer with SwiGLU activation + self.feed_forward = nn.Sequential( + nn.Linear(dim, ff_dim), + nn.GLU(), # SwiGLU approximation + nn.Linear(ff_dim // 2, dim) + ) + + def forward( + self, + x, + kv_cache: Optional[KVCache] = None, + attention_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, KVCache]: + """ + Forward pass with integrated attention and key-value caching. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + kv_cache (Optional[KVCache]): KVCache object for managing past key-value states. + attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, 1, 1, seq_len). + input_pos (Optional[torch.Tensor]): Position IDs tensor of shape (batch_size, seq_len). + + Returns: + Tuple[torch.Tensor, KVCache]: + - x (torch.Tensor): Output tensor of shape (batch_size, seq_len, dim). + - kv_cache (KVCache): Updated KVCache object. + """ + # Apply normalization before attention + residual = x + x = self.norm1(x) + + # Compute Q, K, V projections + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + # Initialize or update KVCache + if kv_cache is None: + kv_cache = KVCache( + batch_size=x.size(0), + max_seq_len=x.size(1), + num_heads=self.attn.num_heads, + head_dim=self.attn.head_dim, + dtype=x.dtype + ) + + # Update KVCache with new key-value pairs + k_val, v_val = kv_cache.update(k, v) + + # Apply MultiHeadAttention with key-value caching + x = self.attn(q, k_val, v_val, mask=attention_mask, input_pos=input_pos) + + # Residual connection + x = x + residual + + # Apply feed-forward network with residual connection + residual = x + x = self.norm2(x) + x = self.feed_forward(x) + x = x + residual + + return x, kv_cache + +class LlamaModel(nn.Module): + """ + LlamaModel is a pure PyTorch implementation of the LLaMA architecture + """ + + def __init__(self, config, tokenizer): + """ + Initialize the LlamaModel. + + Args: + config (dict): Configuration dictionary containing model parameters. + - hidden_size (int): Size of the hidden layers. + - num_hidden_layers (int): Number of transformer layers. + - num_attention_heads (int): Number of attention heads. + - intermediate_size (int): Size of the intermediate (feed-forward) layers. + - vocab_size (int): Vocabulary size for the embedding layer. + - max_position_embeddings (int): Maximum number of positional embeddings. + - rms_norm_eps (float): Epsilon for RMS normalization. + - head_dim (int): Dimension of each attention head. + - attention_dropout (float): Dropout rate for attention layers. + tokenizer: Tokenizer used for input preprocessing. + """ + super(LlamaModel, self).__init__() + + # Load configurations from config + self.config = config + self.hidden_size = config['hidden_size'] + self.num_layers = config['num_hidden_layers'] + self.num_heads = config['num_attention_heads'] + self.num_kv_heads = config['num_key_value_heads'] + self.intermediate_size = config['intermediate_size'] + self.vocab_size = config['vocab_size'] + self.max_position_embeddings = config['max_position_embeddings'] + self.rms_norm_eps = config['rms_norm_eps'] + self.head_dim = config['head_dim'] + self.attention_dropout = config.get('attention_dropout', 0.0) + + # Model layers + self.embed = nn.Embedding(self.vocab_size, self.hidden_size) + self.rotary_pos_emb = RotaryPositionalEmbeddings( + self.hidden_size // self.num_heads, + config['rope_scaling']['original_max_position_embeddings'], + config['rope_theta'] + ) + self.layers = nn.ModuleList([ + LlamaBlock( + dim=self.hidden_size, + heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + ff_dim=self.intermediate_size, + rms_norm_eps=self.rms_norm_eps, + attention_dropout=self.attention_dropout, + use_bias=config.get('attention_bias', False) + ) for _ in range(self.num_layers) + ]) + self.norm = nn.LayerNorm(self.hidden_size, eps=self.rms_norm_eps) + self.to_logits = nn.Linear(self.hidden_size, self.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + pos_ids: Optional[torch.Tensor] = None, + past_kv_cache: Optional[KVCache] = None, + return_hidden_states: bool = False + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], KVCache]: + """ + Forward pass with integrated position ID handling, attention mask, and optional KVCache. + + Args: + input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len). + attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len). + pos_ids (Optional[torch.Tensor]): Position IDs. If None, they are calculated automatically. + past_kv_cache (Optional[KVCache]): Optional KVCache for efficient generation. + If provided, it stores past key-value states for faster autoregressive inference. + return_hidden_states (bool): Whether to return hidden states from each layer. + + Returns: + Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], KVCache]: + - logits (torch.Tensor): Output logits of shape (batch_size, seq_len, vocab_size). + - hidden_states (Optional[Tuple[torch.Tensor]]): Hidden states from each layer, if return_hidden_states is True. + - past_kv_cache (KVCache): Updated KVCache object. + """ + batch_size, seq_len = input_ids.shape + + # Create initial embeddings + x = self.embed(input_ids) + + # Initialize position IDs if not provided + if pos_ids is None: + past_seen_tokens = past_kv_cache.size if past_kv_cache is not None else 0 + pos_ids = torch.arange( + past_seen_tokens, + past_seen_tokens + seq_len, + device=input_ids.device + ).unsqueeze(0).expand(batch_size, -1) + + # Reshape x to prepare for rotary embeddings: (batch_size, seq_len, num_heads, head_dim) + x = x.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Apply rotary positional embeddings + x = self.rotary_pos_emb( + x=x, + input_pos=pos_ids + ) + + # Reshape x back to original shape: (batch_size, seq_len, hidden_size) + x = x.view(batch_size, seq_len, self.hidden_size) + + # Initialize or use the provided KVCache + if past_kv_cache is None: + past_kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=self.max_position_embeddings, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=x.dtype + ) + + # Apply attention mask if provided (convert to appropriate format) + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # Shape: (batch_size, 1, 1, seq_len) + attention_mask = (1.0 - attention_mask) * -1e4 # Convert to large negative values + + # Track hidden states if required + hidden_states = [] + + # Forward pass through layers with KVCache + for layer_idx, layer in enumerate(self.layers): + x, k_val, v_val = layer(x, past_kv_cache, layer_idx, attention_mask) + + # Update KVCache + past_kv_cache.update(k_val, v_val) + + if return_hidden_states: + hidden_states.append(x) + + # Apply final layer normalization + x = self.norm(x) + + # Compute logits + logits = self.to_logits(x) + + # Prepare the return values + if return_hidden_states: + return logits, tuple(hidden_states), past_kv_cache + else: + return logits, None, past_kv_cache diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py new file mode 100644 index 00000000..89057021 --- /dev/null +++ b/exo/inference/torch/models/llm_utils.py @@ -0,0 +1,92 @@ +""" +Utility methods used by LLMs +""" +import json +from pathlib import Path + +import torch +import torch.nn.functional as F + +from exo.helpers import DEBUG + +def load_model_config(model_config_path: Path) -> dict: + """ + Loads the config.json of the model + + Args: + model_path (Path): local path to model config json + + Returns: + dict: The config as a dictionary + """ + model_config = {} + with open(model_config_path, "r") as f: + model_config = json.load(f) + return model_config + +def select_next_token( + logits, + top_k=0, + top_p=0.0, + temperature=1.0, + use_max=False, +): + """ + Selects the next token from logits using top-k, top-p, and temperature scaling. + + Args: + logits (torch.Tensor): Logits tensor of shape (batch_size, vocab_size). + top_k (int): Number of top logits to consider for sampling. + top_p (float): Cumulative probability threshold for nucleus sampling. + temperature (float): Scaling factor for temperature. + use_max (bool): Whether to use argmax for next token selection. + debug (bool): If True, prints debugging information. + + Returns: + next_token (torch.Tensor): The next token selected (batch_size,). + """ + # Get logits for the last token in the sequence + logits = logits[:, -1, :].clone().float() + + # Apply temperature scaling + if temperature != 1.0: + logits = logits / temperature + + # Apply top-k filtering + if top_k > 0: + # Get the top-k logits and set the rest to -inf + top_k_values, _ = torch.topk(logits, top_k, dim=-1) + min_top_k_value = top_k_values[:, -1, None] + logits = torch.where(logits < min_top_k_value, torch.tensor(float('-inf'), device=logits.device), logits) + + # Apply top-p (nucleus) filtering + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Mask tokens exceeding the top-p threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() # Shift right + sorted_indices_to_remove[:, 0] = 0 # Ensure at least one token is selected + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, float('-inf')) + + # Calculate probabilities + probs = F.softmax(logits, dim=-1) + + # Select next token + if not use_max: + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) + + # Debugging output + if DEBUG >= 4: + print(f"Logits: {logits}") + print(f"Probabilities: {probs}") + print(f"Next token: {next_token}") + + return next_token.squeeze(-1) + + diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py new file mode 100644 index 00000000..25485b11 --- /dev/null +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -0,0 +1,129 @@ +""" +Test of pytorch based llama3 model +""" +from pathlib import Path + +import torch +from transformers import AutoTokenizer +from huggingface_hub import snapshot_download +from safetensors.torch import load_file as load_safetensors +from exo.inference.torch.models.llm_utils import load_model_config, select_next_token +from exo.inference.torch.models.llama3 import LlamaModel, KVCache + +# Constants +MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" + +# Get the path to the model files from the Hugging Face cache +cache_dir = Path(snapshot_download(MODEL_NAME)) +print(f"Cache directory: {cache_dir}") + +# Load model configuration +config = load_model_config(cache_dir / "config.json") + +# Initialize tokenizer +tokenizer = AutoTokenizer.from_pretrained(cache_dir) + +# Initialize LlamaModel with config and tokenizer +model = LlamaModel(config, tokenizer) + +# Load weights from safetensors files in the cache directory +safetensors_files = list(cache_dir.glob("*.safetensors")) +if not safetensors_files: + raise FileNotFoundError("No safetensors files found in the cache directory.") + +# Load weights from each found safetensors file +for safetensor_file in safetensors_files: + print(f"Loading weights from: {safetensor_file}") + state_dict = load_safetensors(safetensor_file) + model.load_state_dict(state_dict, strict=False) + +model.eval() # Set the model to evaluation mode + +# Sample text for testing +test_text = "Once upon a time," + +def test_forward_pass(model, tokenizer, text): + """ + Test the forward pass of the LlamaModel with given input text. + """ + # Tokenize input text + inputs = tokenizer(text, return_tensors="pt") + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask") + + # Initialize KVCache + past_kv_cache = KVCache( + batch_size=input_ids.size(0), + max_seq_len=model.max_position_embeddings, + num_heads=model.num_heads, + head_dim=model.head_dim, + dtype=input_ids.dtype + ) + + # Forward pass with KVCache + with torch.no_grad(): + logits, hidden_states, _ = model( + input_ids, + attention_mask=attention_mask, + pos_ids=None, + past_kv_cache=past_kv_cache, + return_hidden_states=True + ) + + # Print logits shape and hidden state information + print(f"Logits shape: {logits.shape}") + if hidden_states: + print(f"Number of hidden states: {len(hidden_states)}") + print(f"Shape of last hidden state: {hidden_states[-1].shape}") + +def test_generation(model, tokenizer, text, max_length=50): + """ + Test the generation capabilities of the LlamaModel with sample text. + """ + # Tokenize input text + inputs = tokenizer(text, return_tensors="pt") + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask") + + # Initialize KVCache for caching + past_kv_cache = KVCache( + batch_size=input_ids.size(0), + max_seq_len=model.max_position_embeddings, + num_heads=model.num_heads, + head_dim=model.head_dim, + dtype=input_ids.dtype + ) + + # Start with initial input_ids + generated_ids = input_ids.clone() + + # Generate tokens step-by-step + for _ in range(max_length): + with torch.no_grad(): + logits, _, past_kv_cache = model( + generated_ids, + attention_mask=attention_mask, + past_kv_cache=past_kv_cache + ) + + # Select next token using logits + next_token = select_next_token(logits, top_k=50, top_p=0.9, temperature=0.7, use_max=False) + + # Update generated_ids + generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) + + # Check for EOS token + if next_token.item() == tokenizer.eos_token_id: + break + + # Decode generated text + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + print(f"Generated text: {generated_text}") + +if __name__ == "__main__": + print("Testing forward pass:") + test_forward_pass(model, tokenizer, test_text) + + print("\nTesting generation:") + test_generation(model, tokenizer, test_text) + diff --git a/exo/inference/torch/tests/test_safetensor_shard.py b/exo/inference/torch/tests/test_safetensor_shard.py index 88c12ec0..ef72ef58 100644 --- a/exo/inference/torch/tests/test_safetensor_shard.py +++ b/exo/inference/torch/tests/test_safetensor_shard.py @@ -5,7 +5,7 @@ import asyncio from exo.inference.shard import Shard -from exo.inference.torch.model.hf_safe_tensor_shard import HFSafeTensorShard +from exo.inference.torch.models.hf_safe_tensor_shard import HFSafeTensorShard from exo.download.hf.hf_shard_download import HFShardDownloader from exo.download.hf.hf_helpers import get_weight_map From 76323d727dcc3fecbc18cded7c1b1d3cf339f79b Mon Sep 17 00:00:00 2001 From: Vincent C Date: Tue, 22 Oct 2024 17:35:35 -0800 Subject: [PATCH 461/491] Update llama3.py removing some of docstring --- exo/inference/torch/models/llama3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index c0af9ea3..8156b110 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -11,7 +11,7 @@ class LlamaBlock(nn.Module): """ - Encoder block class for the LLaMA model without residual connections. + Encoder block class for the LLaMA model """ def __init__( self, From 0d66acdfa369a0d2986bb21cd8c35bf86532a587 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 23 Oct 2024 12:14:16 -0800 Subject: [PATCH 462/491] updating pytorch llama model still, currently broken but backing up as continuing the rewrite/refactor --- exo/inference/torch/models/llama3.py | 191 ++++++++++-------- .../torch/tests/test_safetensor_shard.py | 8 +- 2 files changed, 116 insertions(+), 83 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 8156b110..23e04958 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,11 +3,19 @@ Written with pytorch using torchtune and other methods """ +import math from typing import Optional, Tuple import torch import torch.nn as nn -from torchtune.modules import MultiHeadAttention, RotaryPositionalEmbeddings, KVCache +from torchtune.modules import ( + MultiHeadAttention, + RotaryPositionalEmbeddings, + KVCache, + RMSNorm +) + +from exo.inference.shard import Shard class LlamaBlock(nn.Module): """ @@ -21,12 +29,25 @@ def __init__( head_dim, ff_dim, rms_norm_eps, + rotary_pos_emb, attention_dropout=0.0, use_bias=False, max_seq_len=4096, pos_embeddings=None ): super(LlamaBlock, self).__init__() + # Class vars + self.dim = dim + self.heads = heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.ff_dim = ff_dim + self.rms_norm_eps = rms_norm_eps + self.attention_dropout = attention_dropout + self.use_bias = use_bias + self.max_seq_len = max_seq_len + self.pos_embeddings = pos_embeddings + self.rotary_pos_emb = rotary_pos_emb # Define linear projections for Q, K, V, and Output self.q_proj = nn.Linear(dim, heads * head_dim, bias=use_bias) @@ -68,12 +89,13 @@ def __init__( ) def forward( - self, - x, - kv_cache: Optional[KVCache] = None, - attention_mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, KVCache]: + self, + hidden_states: torch.Tensor, + kv_cache: Optional[KVCache] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.FloatTensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Forward pass with integrated attention and key-value caching. @@ -81,55 +103,60 @@ def forward( x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). kv_cache (Optional[KVCache]): KVCache object for managing past key-value states. attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, 1, 1, seq_len). - input_pos (Optional[torch.Tensor]): Position IDs tensor of shape (batch_size, seq_len). + position_ids (Optional[torch.Tensor]): Position IDs tensor of shape (batch_size, seq_len). Returns: - Tuple[torch.Tensor, KVCache]: + Tuple[torch.Tensor, KVCache]: - x (torch.Tensor): Output tensor of shape (batch_size, seq_len, dim). - kv_cache (KVCache): Updated KVCache object. """ - # Apply normalization before attention - residual = x - x = self.norm1(x) + batch_size, seq_len, _ = hidden_states.shape + attn_output, attn_weights = None - # Compute Q, K, V projections - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) + # Do kvq projection + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Reshape + query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # Initialize or update KVCache if kv_cache is None: kv_cache = KVCache( - batch_size=x.size(0), - max_seq_len=x.size(1), - num_heads=self.attn.num_heads, + batch_size=batch_size, + max_seq_len=self.attn.max_seq_len, + num_heads=self.heads, head_dim=self.attn.head_dim, - dtype=x.dtype + dtype=hidden_states.dtype ) - # Update KVCache with new key-value pairs - k_val, v_val = kv_cache.update(k, v) + # cache + value_states = kv_cache.update(key_states, value_states) - # Apply MultiHeadAttention with key-value caching - x = self.attn(q, k_val, v_val, mask=attention_mask, input_pos=input_pos) + # Attention weights and causal mask + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # Residual connection - x = x + residual + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - # Apply feed-forward network with residual connection - residual = x - x = self.norm2(x) - x = self.feed_forward(x) - x = x + residual + - return x, kv_cache + return attn_output, attn_weights, kv_cache class LlamaModel(nn.Module): """ LlamaModel is a pure PyTorch implementation of the LLaMA architecture """ - def __init__(self, config, tokenizer): + def __init__(self, config: dict, shard: Shard): """ Initialize the LlamaModel. @@ -144,10 +171,11 @@ def __init__(self, config, tokenizer): - rms_norm_eps (float): Epsilon for RMS normalization. - head_dim (int): Dimension of each attention head. - attention_dropout (float): Dropout rate for attention layers. - tokenizer: Tokenizer used for input preprocessing. """ super(LlamaModel, self).__init__() + self.shard = shard + # Load configurations from config self.config = config self.hidden_size = config['hidden_size'] @@ -160,9 +188,10 @@ def __init__(self, config, tokenizer): self.rms_norm_eps = config['rms_norm_eps'] self.head_dim = config['head_dim'] self.attention_dropout = config.get('attention_dropout', 0.0) + self.padding_idx = config["pad_token_id"] # Model layers - self.embed = nn.Embedding(self.vocab_size, self.hidden_size) + self.embed = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) self.rotary_pos_emb = RotaryPositionalEmbeddings( self.hidden_size // self.num_heads, config['rope_scaling']['original_max_position_embeddings'], @@ -171,105 +200,107 @@ def __init__(self, config, tokenizer): self.layers = nn.ModuleList([ LlamaBlock( dim=self.hidden_size, - heads=self.num_heads, + heads=self.hidden_size // self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, ff_dim=self.intermediate_size, rms_norm_eps=self.rms_norm_eps, attention_dropout=self.attention_dropout, - use_bias=config.get('attention_bias', False) + use_bias=config.get('attention_bias', False), + rotary_pos_emb=self.rotary_pos_emb ) for _ in range(self.num_layers) ]) - self.norm = nn.LayerNorm(self.hidden_size, eps=self.rms_norm_eps) + self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self.to_logits = nn.Linear(self.hidden_size, self.vocab_size) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - pos_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, past_kv_cache: Optional[KVCache] = None, - return_hidden_states: bool = False - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], KVCache]: + ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], KVCache]: """ Forward pass with integrated position ID handling, attention mask, and optional KVCache. Args: input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len). attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len). - pos_ids (Optional[torch.Tensor]): Position IDs. If None, they are calculated automatically. + position_ids (Optional[torch.Tensor]): Position IDs. If None, they are calculated automatically. + cache_position (Optional[torch.LongTensor]): the positions of inputs in the sequence past_kv_cache (Optional[KVCache]): Optional KVCache for efficient generation. If provided, it stores past key-value states for faster autoregressive inference. - return_hidden_states (bool): Whether to return hidden states from each layer. Returns: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], KVCache]: - - logits (torch.Tensor): Output logits of shape (batch_size, seq_len, vocab_size). - - hidden_states (Optional[Tuple[torch.Tensor]]): Hidden states from each layer, if return_hidden_states is True. + - logits (Optional[torch.Tensor]): Output logits of shape (batch_size, seq_len, vocab_size). + - hidden_states (Optional[torch.Tensor]): Hidden states from each layer - past_kv_cache (KVCache): Updated KVCache object. """ batch_size, seq_len = input_ids.shape # Create initial embeddings - x = self.embed(input_ids) + input_embeds = self.embed(input_ids) + + # Initialize or use the provided KVCache + if past_kv_cache is None: + past_kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=self.max_position_embeddings, + num_heads=self.num_heads, + head_dim=self.head_dim, + dtype=x.dtype + ) # Initialize position IDs if not provided - if pos_ids is None: + if cache_position is None: past_seen_tokens = past_kv_cache.size if past_kv_cache is not None else 0 - pos_ids = torch.arange( + cache_position = torch.arange( past_seen_tokens, past_seen_tokens + seq_len, device=input_ids.device ).unsqueeze(0).expand(batch_size, -1) - # Reshape x to prepare for rotary embeddings: (batch_size, seq_len, num_heads, head_dim) - x = x.view(batch_size, seq_len, self.num_heads, self.head_dim) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + hidden_states = input_embeds # Apply rotary positional embeddings - x = self.rotary_pos_emb( - x=x, - input_pos=pos_ids + position_embeddings = self.rotary_pos_emb( + hidden_states, + input_pos=position_ids ) - # Reshape x back to original shape: (batch_size, seq_len, hidden_size) - x = x.view(batch_size, seq_len, self.hidden_size) - - # Initialize or use the provided KVCache - if past_kv_cache is None: - past_kv_cache = KVCache( - batch_size=batch_size, - max_seq_len=self.max_position_embeddings, - num_heads=self.num_heads, - head_dim=self.head_dim, - dtype=x.dtype - ) - # Apply attention mask if provided (convert to appropriate format) if attention_mask is not None: attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # Shape: (batch_size, 1, 1, seq_len) attention_mask = (1.0 - attention_mask) * -1e4 # Convert to large negative values - # Track hidden states if required - hidden_states = [] - # Forward pass through layers with KVCache - for layer_idx, layer in enumerate(self.layers): - x, k_val, v_val = layer(x, past_kv_cache, layer_idx, attention_mask) - - # Update KVCache - past_kv_cache.update(k_val, v_val) + for layer_idx in range(self.shard.end_layer, self.shard.start_layer): + layer_hidden_state, layer_kv_cache = layer( + hidden_states=hidden_states, + kv_cache=past_kv_cache, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings + ) - if return_hidden_states: - hidden_states.append(x) + hidden_states = layer_hidden_state # Apply final layer normalization - x = self.norm(x) + hidden_states = self.norm(hidden_states) - # Compute logits - logits = self.to_logits(x) + # Compute logits if at end layer + if self.shard.is_last_layer(): + logits = self.to_logits(hidden_states) + else: + logits = None # Prepare the return values if return_hidden_states: - return logits, tuple(hidden_states), past_kv_cache + return logits, hidden_states, past_kv_cache else: return logits, None, past_kv_cache diff --git a/exo/inference/torch/tests/test_safetensor_shard.py b/exo/inference/torch/tests/test_safetensor_shard.py index ef72ef58..dd84ff18 100644 --- a/exo/inference/torch/tests/test_safetensor_shard.py +++ b/exo/inference/torch/tests/test_safetensor_shard.py @@ -17,7 +17,7 @@ async def main(): # Create a Shard object shard = Shard( - model_id="meta-llama/Llama-3.2-1B-Instruct", + model_id="unsloth/Meta-Llama-3.1-8B-Instruct", start_layer=start_layer, end_layer=end_layer-1, n_layers=32 @@ -42,7 +42,9 @@ async def main(): model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path=shard.model_id, local_files_only=True, - num_hidden_layers=shard.end_layer - shard.start_layer + num_hidden_layers=shard.end_layer - shard.start_layer, + #device_map="auto", + torch_dtype="float16" ).to("cuda") tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") @@ -61,7 +63,7 @@ async def main(): print(f"model_inputs:\n{model_inputs}") - tensor_shard.restore_backup() + tensor_shard.restore_backups() if __name__ == "__main__": asyncio.run(main()) From 1512d13b52e87636c741c07fe79383cd46cdc891 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 25 Oct 2024 02:18:42 -0800 Subject: [PATCH 463/491] updated llamablock and llamamodel, created a MLP helper class to use in other transformers, added in torch based 4d linear transformation causal mask, need to add in device map or just using device as running out of ram, think other issues at play as should not run out of ram before getting to hidden layers --- exo/inference/torch/models/hf.py | 5 +- exo/inference/torch/models/llama3.py | 183 +++++++++--------- exo/inference/torch/models/llm_utils.py | 129 ++++++++++++ .../torch/tests/test_llama3_model.py | 32 ++- 4 files changed, 246 insertions(+), 103 deletions(-) diff --git a/exo/inference/torch/models/hf.py b/exo/inference/torch/models/hf.py index f15d5d19..5d5b03e4 100644 --- a/exo/inference/torch/models/hf.py +++ b/exo/inference/torch/models/hf.py @@ -1,5 +1,3 @@ -import os -import json from typing import Tuple, Optional, Union, List from pathlib import Path @@ -8,8 +6,7 @@ from exo.inference.shard import Shard from exo.helpers import DEBUG -from exo.inference.torch.utils import extract_layers -from exo.inference.torch.model.hf_safe_tensor_shard import HFSafeTensorShard +from exo.inference.torch.models.hf_safe_tensor_shard import HFSafeTensorShard from transformers import ( AutoModelForCausalLM, diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 23e04958..94607388 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,7 +3,6 @@ Written with pytorch using torchtune and other methods """ -import math from typing import Optional, Tuple import torch @@ -16,6 +15,7 @@ ) from exo.inference.shard import Shard +from exo.inference.torch.models.llm_utils import MLP, create_4d_causal_attention_mask class LlamaBlock(nn.Module): """ @@ -36,7 +36,6 @@ def __init__( pos_embeddings=None ): super(LlamaBlock, self).__init__() - # Class vars self.dim = dim self.heads = heads self.num_kv_heads = num_kv_heads @@ -48,45 +47,19 @@ def __init__( self.max_seq_len = max_seq_len self.pos_embeddings = pos_embeddings self.rotary_pos_emb = rotary_pos_emb - - # Define linear projections for Q, K, V, and Output self.q_proj = nn.Linear(dim, heads * head_dim, bias=use_bias) self.k_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) self.v_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) self.output_proj = nn.Linear(heads * head_dim, dim, bias=use_bias) - - # Define optional query normalization - self.q_norm = nn.LayerNorm(head_dim, eps=rms_norm_eps) - - # MultiHeadAttention from torchtune - self.attn = MultiHeadAttention( - embed_dim=dim, - num_heads=heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - q_proj=self.q_proj, - k_proj=self.k_proj, - v_proj=self.v_proj, - output_proj=self.output_proj, - pos_embeddings=pos_embeddings, - q_norm=self.q_norm, - k_norm=self.q_norm, - kv_cache=None, - max_seq_len=max_seq_len, - is_causal=True, - attn_dropout=attention_dropout - ) - - # RMSNorm layers before and after attention and feed-forward layers - self.norm1 = nn.LayerNorm(dim, eps=rms_norm_eps) - self.norm2 = nn.LayerNorm(dim, eps=rms_norm_eps) - - # Feed-forward layer with SwiGLU activation - self.feed_forward = nn.Sequential( - nn.Linear(dim, ff_dim), - nn.GLU(), # SwiGLU approximation - nn.Linear(ff_dim // 2, dim) + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.mlp = MLP( + input_dim=dim, + hidden_dims=[ff_dim], # Single hidden layer with ff_dim as the hidden size + output_dim=dim, + activation='gelu', + dropout=attention_dropout ) + self.post_norm = RMSNorm(dim, eps=rms_norm_eps) def forward( self, @@ -94,62 +67,60 @@ def forward( kv_cache: Optional[KVCache] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - position_embeddings: Optional[torch.FloatTensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[KVCache]]: """ - Forward pass with integrated attention and key-value caching. + Forward pass with integrated attention, resnet and key-value caching. Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + hidden_states (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). kv_cache (Optional[KVCache]): KVCache object for managing past key-value states. attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, 1, 1, seq_len). position_ids (Optional[torch.Tensor]): Position IDs tensor of shape (batch_size, seq_len). Returns: Tuple[torch.Tensor, KVCache]: - - x (torch.Tensor): Output tensor of shape (batch_size, seq_len, dim). - - kv_cache (KVCache): Updated KVCache object. + - Output tensor of shape (batch_size, seq_len, dim). + - Updated KVCache object. """ - batch_size, seq_len, _ = hidden_states.shape - attn_output, attn_weights = None - - # Do kvq projection - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Reshape - query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # Initialize or update KVCache - if kv_cache is None: - kv_cache = KVCache( - batch_size=batch_size, - max_seq_len=self.attn.max_seq_len, - num_heads=self.heads, - head_dim=self.attn.head_dim, - dtype=hidden_states.dtype - ) - - # cache - value_states = kv_cache.update(key_states, value_states) - - # Attention weights and causal mask - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + # setting up resnet + residual = hidden_states + + # Apply RMSNorm to input + hidden_states = self.input_norm(hidden_states) + + # Apply MultiHeadAttention with KVCache + hidden_states = MultiHeadAttention( + embed_dim=self.dim, + num_heads=self.heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + q_proj=self.q_proj, + k_proj=self.k_proj, + v_proj=self.v_proj, + output_proj=self.output_proj, + pos_embeddings=self.rotary_pos_emb, + q_norm=self.q_norm, + k_norm=self.q_norm, + kv_cache=kv_cache, # Passed during the forward call + max_seq_len=self.max_seq_len, + is_causal=True, + attn_dropout=self.attention_dropout + )( + x=hidden_states, + mask=attention_mask, + input_pos=position_ids + ) - + # Residual connection + hidden_states = residual + hidden_states + residual = hidden_states + # Post attention normalization + hidden_states = self.post_norm(hidden_states) + # Feed-forward network with MLP and residual connection + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states + residual - return attn_output, attn_weights, kv_cache + return hidden_states, kv_cache class LlamaModel(nn.Module): """ @@ -188,7 +159,8 @@ def __init__(self, config: dict, shard: Shard): self.rms_norm_eps = config['rms_norm_eps'] self.head_dim = config['head_dim'] self.attention_dropout = config.get('attention_dropout', 0.0) - self.padding_idx = config["pad_token_id"] + self.padding_idx = config.get("pad_token_id") + self.device_map="any" # Model layers self.embed = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) @@ -220,7 +192,7 @@ def forward( position_ids: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, past_kv_cache: Optional[KVCache] = None, - ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], KVCache]: + ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], Optional[KVCache]]: """ Forward pass with integrated position ID handling, attention mask, and optional KVCache. @@ -250,7 +222,7 @@ def forward( max_seq_len=self.max_position_embeddings, num_heads=self.num_heads, head_dim=self.head_dim, - dtype=x.dtype + dtype=input_embeds.dtype ) # Initialize position IDs if not provided @@ -267,28 +239,60 @@ def forward( hidden_states = input_embeds + # Reshape hidden_states to (batch_size, seq_len, num_heads, head_dim) + batch_size, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Reshape position_ids to match (batch_size, seq_len) + if position_ids.dim() != 2: + position_ids = position_ids.squeeze(0) + + print(f"hidden_states: {hidden_states.shape}") + print(f"position_ids: {position_ids.shape}") + # Apply rotary positional embeddings position_embeddings = self.rotary_pos_emb( hidden_states, input_pos=position_ids ) - # Apply attention mask if provided (convert to appropriate format) + print(f"position_embeddings: {position_embeddings.shape}") + + # Reshape back to (batch_size, seq_len, hidden_size) + hidden_states = hidden_states.view(batch_size, seq_len, self.hidden_size) + print(f"hidden_states: {hidden_states.shape}") + + # create 4d causal mask + causal_mask = None if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # Shape: (batch_size, 1, 1, seq_len) - attention_mask = (1.0 - attention_mask) * -1e4 # Convert to large negative values + causal_mask = create_4d_causal_attention_mask( + attention_mask=attention_mask, + seq_len=hidden_states.size(1), + target_len=self.max_position_embeddings, + dtype=hidden_states.dtype, + device=hidden_states.device, + cache_pos=torch.arange(self.max_position_embeddings, device=hidden_states.device), + batch_size=hidden_states.size(0) + ) + + print(f"attention_mask: {attention_mask.shape}") + print(f"causal_mask: {causal_mask.shape}") # Forward pass through layers with KVCache for layer_idx in range(self.shard.end_layer, self.shard.start_layer): - layer_hidden_state, layer_kv_cache = layer( + print(f"forward layer #{layer_idx}") + encoder_layer = self.layers[layer_idx] + print(f"encoder_layer\n{encoder_layer}") + layer_hidden_state, layer_kv_cache = self.layers[layer_idx]( hidden_states=hidden_states, kv_cache=past_kv_cache, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, position_embeddings=position_embeddings ) hidden_states = layer_hidden_state + past_kv_cache = layer_kv_cache # Apply final layer normalization hidden_states = self.norm(hidden_states) @@ -299,8 +303,7 @@ def forward( else: logits = None - # Prepare the return values - if return_hidden_states: + if logits is None: return logits, hidden_states, past_kv_cache else: return logits, None, past_kv_cache diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 89057021..22f22485 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -5,6 +5,7 @@ from pathlib import Path import torch +import torch.nn as nn import torch.nn.functional as F from exo.helpers import DEBUG @@ -89,4 +90,132 @@ def select_next_token( return next_token.squeeze(-1) +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dims, output_dim, activation='gelu', dropout=0.0, use_batchnorm=False): + """ + General MLP (Multi-Layer Perceptron) module. + + Args: + input_dim (int): Dimensionality of the input. + hidden_dims (list of int): List of hidden layer dimensions. + output_dim (int): Dimensionality of the output. + activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). + dropout (float): Dropout probability. + use_batchnorm (bool): Whether to use batch normalization. + """ + super(MLP, self).__init__() + + self.layers = nn.ModuleList() + self.use_batchnorm = use_batchnorm + + # Activation function mapping + activations = { + 'relu': nn.ReLU(), + 'gelu': nn.GELU(), + 'tanh': nn.Tanh(), + 'sigmoid': nn.Sigmoid(), + 'leaky_relu': nn.LeakyReLU(0.2) + } + + # Ensure valid activation + if activation not in activations: + raise ValueError(f"Invalid activation: {activation}. Choose from {list(activations.keys())}") + + self.activation = activations[activation] + + # Construct MLP layers + prev_dim = input_dim + for h_dim in hidden_dims: + self.layers.append(nn.Linear(prev_dim, h_dim)) + if use_batchnorm: + self.layers.append(nn.BatchNorm1d(h_dim)) + self.layers.append(self.activation) + if dropout > 0: + self.layers.append(nn.Dropout(dropout)) + prev_dim = h_dim + + # Output layer + self.output_layer = nn.Linear(prev_dim, output_dim) + + def forward(self, x): + """ + Forward pass for the MLP module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after the MLP transformations. + """ + for layer in self.layers: + x = layer(x) + return self.output_layer(x) + +def create_4d_causal_attention_mask( + attention_mask: torch.Tensor, + seq_len: int, + target_len: int, + dtype: torch.dtype, + device: torch.device, + cache_pos: torch.Tensor, + batch_size: int, +) -> torch.Tensor: + """ + Creates a 4D causal attention mask from a 2D mask, with adjustments for static caching. + + Args: + attention_mask (torch.Tensor): + A 2D tensor of shape (batch_size, key_value_length) or a 4D tensor of shape + (batch_size, 1, query_length, key_value_length). + seq_len (int): + Sequence length of the input being processed. + target_len (int): + Target length to generate the causal mask. + dtype (torch.dtype): + Data type for the causal mask. + device (torch.device): + Device to place the causal mask on. + cache_pos (torch.Tensor): + Cache position indices indicating the position of the input tokens in the sequence. + batch_size (int): + Number of samples in the batch. + + Returns: + torch.Tensor: + A 4D causal mask of shape (batch_size, 1, query_length, key_value_length). + """ + if attention_mask is not None and attention_mask.dim() == 4: + # If the mask is already 4D, return it directly + return attention_mask + + min_value = torch.finfo(dtype).min + + # Create a 2D causal mask of shape (seq_len, target_len) + causal_mask = torch.full( + (seq_len, target_len), fill_value=min_value, dtype=dtype, device=device + ) + + if seq_len != 1: + # Mask positions after the current position + causal_mask = torch.triu(causal_mask, diagonal=1) + + # Adjust causal mask for cache position + causal_mask *= (torch.arange(target_len, device=device) > cache_pos.view(-1, 1)) + + # Expand to 4D and batch size + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + if attention_mask is not None: + # Create a padding mask based on the input attention_mask + mask_len = attention_mask.shape[-1] + causal_mask = causal_mask.clone() # Ensure contiguous memory for in-place operations + padding_mask = causal_mask[:, :, :, :mask_len] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + + # Apply padding to the causal mask + causal_mask[:, :, :, :mask_len] = causal_mask[:, :, :, :mask_len].masked_fill( + padding_mask, min_value + ) + + return causal_mask diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 25485b11..13d50f8d 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -9,8 +9,8 @@ from safetensors.torch import load_file as load_safetensors from exo.inference.torch.models.llm_utils import load_model_config, select_next_token from exo.inference.torch.models.llama3 import LlamaModel, KVCache +from exo.inference.shard import Shard -# Constants MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" # Get the path to the model files from the Hugging Face cache @@ -20,11 +20,21 @@ # Load model configuration config = load_model_config(cache_dir / "config.json") +print(f"current config\n{config}") + +# Setup shard +shard = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=int(config["num_hidden_layers"]) - 1, + n_layers=int(config["num_hidden_layers"]) +) + # Initialize tokenizer tokenizer = AutoTokenizer.from_pretrained(cache_dir) # Initialize LlamaModel with config and tokenizer -model = LlamaModel(config, tokenizer) +model = LlamaModel(config, shard) # Load weights from safetensors files in the cache directory safetensors_files = list(cache_dir.glob("*.safetensors")) @@ -48,9 +58,12 @@ def test_forward_pass(model, tokenizer, text): """ # Tokenize input text inputs = tokenizer(text, return_tensors="pt") - input_ids = inputs["input_ids"] + input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask") + print(f"input_ids: {input_ids}") + print(f"attention_mask: {attention_mask}") + # Initialize KVCache past_kv_cache = KVCache( batch_size=input_ids.size(0), @@ -62,17 +75,18 @@ def test_forward_pass(model, tokenizer, text): # Forward pass with KVCache with torch.no_grad(): - logits, hidden_states, _ = model( + logits, hidden_states, past_kv_cache = model( input_ids, attention_mask=attention_mask, - pos_ids=None, - past_kv_cache=past_kv_cache, - return_hidden_states=True + position_ids=None, + past_kv_cache=past_kv_cache ) # Print logits shape and hidden state information - print(f"Logits shape: {logits.shape}") - if hidden_states: + if logits is not None: + print(f"Logits shape: {logits.shape}") + + if hidden_states is not None: print(f"Number of hidden states: {len(hidden_states)}") print(f"Shape of last hidden state: {hidden_states[-1].shape}") From 0eb80448fd1714f261f64d2282f05bbd45654b3e Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 25 Oct 2024 03:07:34 -0800 Subject: [PATCH 464/491] fixing causual mask loading error, updated testing, working on logit selection issues producing gibberish --- exo/inference/torch/models/llama3.py | 6 +- .../torch/tests/test_llama3_model.py | 130 +++++++----------- 2 files changed, 51 insertions(+), 85 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 94607388..954df4a1 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -267,11 +267,11 @@ def forward( if attention_mask is not None: causal_mask = create_4d_causal_attention_mask( attention_mask=attention_mask, - seq_len=hidden_states.size(1), - target_len=self.max_position_embeddings, + seq_len=hidden_states.shape[1], + target_len=attention_mask.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device, - cache_pos=torch.arange(self.max_position_embeddings, device=hidden_states.device), + cache_pos=cache_position, batch_size=hidden_states.size(0) ) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 13d50f8d..6308335b 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -4,6 +4,7 @@ from pathlib import Path import torch +import torchtune.generation as ttg from transformers import AutoTokenizer from huggingface_hub import snapshot_download from safetensors.torch import load_file as load_safetensors @@ -12,83 +13,9 @@ from exo.inference.shard import Shard MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" +TEMP=0.7 +TOP_K=25 -# Get the path to the model files from the Hugging Face cache -cache_dir = Path(snapshot_download(MODEL_NAME)) -print(f"Cache directory: {cache_dir}") - -# Load model configuration -config = load_model_config(cache_dir / "config.json") - -print(f"current config\n{config}") - -# Setup shard -shard = Shard( - model_id=MODEL_NAME, - start_layer=0, - end_layer=int(config["num_hidden_layers"]) - 1, - n_layers=int(config["num_hidden_layers"]) -) - -# Initialize tokenizer -tokenizer = AutoTokenizer.from_pretrained(cache_dir) - -# Initialize LlamaModel with config and tokenizer -model = LlamaModel(config, shard) - -# Load weights from safetensors files in the cache directory -safetensors_files = list(cache_dir.glob("*.safetensors")) -if not safetensors_files: - raise FileNotFoundError("No safetensors files found in the cache directory.") - -# Load weights from each found safetensors file -for safetensor_file in safetensors_files: - print(f"Loading weights from: {safetensor_file}") - state_dict = load_safetensors(safetensor_file) - model.load_state_dict(state_dict, strict=False) - -model.eval() # Set the model to evaluation mode - -# Sample text for testing -test_text = "Once upon a time," - -def test_forward_pass(model, tokenizer, text): - """ - Test the forward pass of the LlamaModel with given input text. - """ - # Tokenize input text - inputs = tokenizer(text, return_tensors="pt") - input_ids = inputs.get("input_ids") - attention_mask = inputs.get("attention_mask") - - print(f"input_ids: {input_ids}") - print(f"attention_mask: {attention_mask}") - - # Initialize KVCache - past_kv_cache = KVCache( - batch_size=input_ids.size(0), - max_seq_len=model.max_position_embeddings, - num_heads=model.num_heads, - head_dim=model.head_dim, - dtype=input_ids.dtype - ) - - # Forward pass with KVCache - with torch.no_grad(): - logits, hidden_states, past_kv_cache = model( - input_ids, - attention_mask=attention_mask, - position_ids=None, - past_kv_cache=past_kv_cache - ) - - # Print logits shape and hidden state information - if logits is not None: - print(f"Logits shape: {logits.shape}") - - if hidden_states is not None: - print(f"Number of hidden states: {len(hidden_states)}") - print(f"Shape of last hidden state: {hidden_states[-1].shape}") def test_generation(model, tokenizer, text, max_length=50): """ @@ -96,7 +23,7 @@ def test_generation(model, tokenizer, text, max_length=50): """ # Tokenize input text inputs = tokenizer(text, return_tensors="pt") - input_ids = inputs["input_ids"] + input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask") # Initialize KVCache for caching @@ -121,7 +48,9 @@ def test_generation(model, tokenizer, text, max_length=50): ) # Select next token using logits - next_token = select_next_token(logits, top_k=50, top_p=0.9, temperature=0.7, use_max=False) + #next_token = select_next_token(logits, top_k=50, top_p=0.9, temperature=0.7, use_max=False) + next_token = ttg.sample(logits[:, -1, :].clone().float(), temperature=TEMP, top_k=TOP_K).squeeze(-1) + print(f"next_token: {next_token}") # Update generated_ids generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) @@ -132,12 +61,49 @@ def test_generation(model, tokenizer, text, max_length=50): # Decode generated text generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - print(f"Generated text: {generated_text}") + print(f"\nPrompt: {text}") + print(f"\nGenerated Response: {generated_text}") if __name__ == "__main__": - print("Testing forward pass:") - test_forward_pass(model, tokenizer, test_text) - print("\nTesting generation:") + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + print(f"Cache directory: {cache_dir}") + + # Load model configuration + config = load_model_config(cache_dir / "config.json") + + print(f"current config\n{config}") + + # Setup shard + shard = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=int(config["num_hidden_layers"]) - 1, + n_layers=int(config["num_hidden_layers"]) + ) + + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(cache_dir) + + # Initialize LlamaModel with config and tokenizer + model = LlamaModel(config, shard) + + # Load weights from safetensors files in the cache directory + safetensors_files = list(cache_dir.glob("*.safetensors")) + if not safetensors_files: + raise FileNotFoundError("No safetensors files found in the cache directory.") + + # Load weights from each found safetensors file + for safetensor_file in safetensors_files: + print(f"Loading weights from: {safetensor_file}") + state_dict = load_safetensors(safetensor_file) + model.load_state_dict(state_dict, strict=False) + + model.eval() # Set the model to evaluation mode + + # Sample text for testing + test_text = "What color is a red apple?" + test_generation(model, tokenizer, test_text) From a6768b4717fa095268ce9e7b0f95abf7a6f1e7e1 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 25 Oct 2024 03:43:11 -0800 Subject: [PATCH 465/491] adding a chat temple from tokenizer to test, looking at padding ids to see if they is causing more gibberish, looking into better logit sampling and looking over other generation setup --- exo/inference/torch/models/llama3.py | 5 ++--- .../torch/tests/test_llama3_model.py | 22 +++++++++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 954df4a1..60d08426 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -160,7 +160,6 @@ def __init__(self, config: dict, shard: Shard): self.head_dim = config['head_dim'] self.attention_dropout = config.get('attention_dropout', 0.0) self.padding_idx = config.get("pad_token_id") - self.device_map="any" # Model layers self.embed = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) @@ -183,7 +182,7 @@ def __init__(self, config: dict, shard: Shard): ) for _ in range(self.num_layers) ]) self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) - self.to_logits = nn.Linear(self.hidden_size, self.vocab_size) + self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) def forward( self, @@ -299,7 +298,7 @@ def forward( # Compute logits if at end layer if self.shard.is_last_layer(): - logits = self.to_logits(hidden_states) + logits = self.lm_head(hidden_states[:, -1:, :]) else: logits = None diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 6308335b..6537115b 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -14,18 +14,31 @@ MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" TEMP=0.7 -TOP_K=25 +TOP_K=35 +TOP_P=0.9 -def test_generation(model, tokenizer, text, max_length=50): +def test_generation(model, tokenizer, text, max_length=10): """ Test the generation capabilities of the LlamaModel with sample text. """ # Tokenize input text - inputs = tokenizer(text, return_tensors="pt") + prompt = tokenizer.apply_chat_template([ + { + "role": "user", + "content": text + } + ], tokenize=False, add_generation_prompt=True) + + print(f"prompt: {prompt}") + + inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs.get("input_ids") attention_mask = inputs.get("attention_mask") + print(f"input_ids: {input_ids}") + print(f"attention_mask: {attention_mask}") + # Initialize KVCache for caching past_kv_cache = KVCache( batch_size=input_ids.size(0), @@ -48,12 +61,13 @@ def test_generation(model, tokenizer, text, max_length=50): ) # Select next token using logits - #next_token = select_next_token(logits, top_k=50, top_p=0.9, temperature=0.7, use_max=False) + #next_token = select_next_token(logits, top_k=TOP_K, top_p=TOP_P, temperature=TEMP, use_max=False) next_token = ttg.sample(logits[:, -1, :].clone().float(), temperature=TEMP, top_k=TOP_K).squeeze(-1) print(f"next_token: {next_token}") # Update generated_ids generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) + print(f"generated_ids: {generated_ids}") # Check for EOS token if next_token.item() == tokenizer.eos_token_id: From 8ba24e2bc8fd5830237b4b0387628ff0d809a776 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 26 Oct 2024 03:45:29 -0800 Subject: [PATCH 466/491] fixing parameter defintion on 4d mask method, commiting before trying to upgrade to main fork --- exo/inference/torch/models/llm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 22f22485..10eabc74 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -157,7 +157,7 @@ def create_4d_causal_attention_mask( target_len: int, dtype: torch.dtype, device: torch.device, - cache_pos: torch.Tensor, + cache_pos: torch.LongTensor, batch_size: int, ) -> torch.Tensor: """ From cfb10ba39e196a515e6be87bffce9d628e44a0df Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 28 Oct 2024 01:16:40 -0800 Subject: [PATCH 467/491] added in more base llm functions like multiheadattention and rotate embed, working on llama model and getting shapes/reshaping right, running into kv cache issue --- exo/inference/torch/README.md | 3 + exo/inference/torch/models/llama3.py | 139 +++++---- exo/inference/torch/models/llm_utils.py | 285 ++++++++++++++---- .../torch/tests/test_llama3_model.py | 17 +- 4 files changed, 316 insertions(+), 128 deletions(-) diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md index da67faa2..43b3782a 100644 --- a/exo/inference/torch/README.md +++ b/exo/inference/torch/README.md @@ -17,6 +17,9 @@ Still working on split model development (see test_split_model.py). Right now, i ### 10/21/2024 Working on removing transformers due to inference and VRAM usage [issues](https://github.com/exo-explore/exo/pull/139#issuecomment-2424953962). Creating a pure pytorch implementation of llama3 as using transformers wont work for exo. Using some code from meta but also implementing the use of torchtune. +### 10/27/2024 +Still working on llama3 model but wanted to note that a better KVCache needs to be investigated. + ## Tech Tested on diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 60d08426..bf2e6f16 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -8,14 +8,17 @@ import torch import torch.nn as nn from torchtune.modules import ( - MultiHeadAttention, - RotaryPositionalEmbeddings, KVCache, - RMSNorm + RMSNorm, ) from exo.inference.shard import Shard -from exo.inference.torch.models.llm_utils import MLP, create_4d_causal_attention_mask +from exo.inference.torch.models.llm_utils import ( + MultiLayerPreceptron, + MultiHeadAttention, + RotaryEmbedding, + create_4d_causal_attention_mask +) class LlamaBlock(nn.Module): """ @@ -24,12 +27,13 @@ class LlamaBlock(nn.Module): def __init__( self, dim, - heads, - num_kv_heads, head_dim, + num_heads, + num_kv_heads, ff_dim, - rms_norm_eps, rotary_pos_emb, + mlp, + rms_norm_eps=1e-6, attention_dropout=0.0, use_bias=False, max_seq_len=4096, @@ -37,9 +41,9 @@ def __init__( ): super(LlamaBlock, self).__init__() self.dim = dim - self.heads = heads - self.num_kv_heads = num_kv_heads self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads self.ff_dim = ff_dim self.rms_norm_eps = rms_norm_eps self.attention_dropout = attention_dropout @@ -47,26 +51,21 @@ def __init__( self.max_seq_len = max_seq_len self.pos_embeddings = pos_embeddings self.rotary_pos_emb = rotary_pos_emb - self.q_proj = nn.Linear(dim, heads * head_dim, bias=use_bias) + self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=use_bias) self.k_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) self.v_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) - self.output_proj = nn.Linear(heads * head_dim, dim, bias=use_bias) - self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) - self.mlp = MLP( - input_dim=dim, - hidden_dims=[ff_dim], # Single hidden layer with ff_dim as the hidden size - output_dim=dim, - activation='gelu', - dropout=attention_dropout - ) - self.post_norm = RMSNorm(dim, eps=rms_norm_eps) + self.output_proj = nn.Linear(num_heads * head_dim, dim, bias=use_bias) + self.input_layer_norm = RMSNorm(dim, eps=rms_norm_eps) + self.mlp = mlp + self.post_attention_norm = RMSNorm(dim, eps=rms_norm_eps) def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], kv_cache: Optional[KVCache] = None, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Optional[KVCache]]: """ Forward pass with integrated attention, resnet and key-value caching. @@ -86,36 +85,39 @@ def forward( residual = hidden_states # Apply RMSNorm to input - hidden_states = self.input_norm(hidden_states) + hidden_states = self.input_layer_norm(hidden_states) + print(f"self.input_layer_norm(hidden_states) {hidden_states.shape}") + + batch_size, seq_len, _ = hidden_states.shape + hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim).squeeze() + print(f"hidden_states: {hidden_states.shape}") # Apply MultiHeadAttention with KVCache - hidden_states = MultiHeadAttention( - embed_dim=self.dim, - num_heads=self.heads, + mh_attn = MultiHeadAttention( + hidden_size=self.head_dim, + num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, - q_proj=self.q_proj, - k_proj=self.k_proj, - v_proj=self.v_proj, - output_proj=self.output_proj, - pos_embeddings=self.rotary_pos_emb, - q_norm=self.q_norm, - k_norm=self.q_norm, - kv_cache=kv_cache, # Passed during the forward call - max_seq_len=self.max_seq_len, + kv_cache=kv_cache, is_causal=True, - attn_dropout=self.attention_dropout - )( - x=hidden_states, - mask=attention_mask, - input_pos=position_ids + attention_dropout=self.attention_dropout, + rotary_emb=self.rotary_pos_emb + ) + + hidden_states = mh_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + position_embeddings=position_embeddings ) # Residual connection hidden_states = residual + hidden_states residual = hidden_states + print(f"hidden_states: {hidden_states}") + print(f"residual: {residual}") # Post attention normalization - hidden_states = self.post_norm(hidden_states) + hidden_states = self.post_attention_norm(hidden_states) # Feed-forward network with MLP and residual connection hidden_states = self.mlp(hidden_states) hidden_states = hidden_states + residual @@ -161,24 +163,32 @@ def __init__(self, config: dict, shard: Shard): self.attention_dropout = config.get('attention_dropout', 0.0) self.padding_idx = config.get("pad_token_id") - # Model layers + # Model layers and methods self.embed = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) - self.rotary_pos_emb = RotaryPositionalEmbeddings( - self.hidden_size // self.num_heads, + self.rotary_pos_emb = RotaryEmbedding( + self.head_dim, config['rope_scaling']['original_max_position_embeddings'], config['rope_theta'] ) + self.mlp = MultiLayerPreceptron( + input_dim=self.hidden_size, + hidden_dims=[self.intermediate_size], # Single hidden layer with ff_dim as the hidden size + output_dim=self.hidden_size, + activation='gelu', + dropout=self.attention_dropout + ) self.layers = nn.ModuleList([ LlamaBlock( dim=self.hidden_size, - heads=self.hidden_size // self.num_heads, + head_dim=self.hidden_size // self.num_heads, + num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, ff_dim=self.intermediate_size, rms_norm_eps=self.rms_norm_eps, attention_dropout=self.attention_dropout, use_bias=config.get('attention_bias', False), - rotary_pos_emb=self.rotary_pos_emb + rotary_pos_emb=self.rotary_pos_emb, + mlp=self.mlp ) for _ in range(self.num_layers) ]) self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) @@ -189,7 +199,7 @@ def forward( input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, past_kv_cache: Optional[KVCache] = None, ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], Optional[KVCache]]: """ @@ -214,15 +224,15 @@ def forward( # Create initial embeddings input_embeds = self.embed(input_ids) - # Initialize or use the provided KVCache - if past_kv_cache is None: - past_kv_cache = KVCache( - batch_size=batch_size, - max_seq_len=self.max_position_embeddings, - num_heads=self.num_heads, - head_dim=self.head_dim, - dtype=input_embeds.dtype - ) + ## Initialize or use the provided KVCache + #if past_kv_cache is None: + # past_kv_cache = KVCache( + # batch_size=batch_size, + # max_seq_len=self.max_position_embeddings, + # num_heads=self.num_heads, + # head_dim=self.head_dim, + # dtype=input_embeds.dtype + # ) # Initialize position IDs if not provided if cache_position is None: @@ -231,11 +241,13 @@ def forward( past_seen_tokens, past_seen_tokens + seq_len, device=input_ids.device - ).unsqueeze(0).expand(batch_size, -1) + ) + #.unsqueeze(0).expand(batch_size, -1) if position_ids is None: position_ids = cache_position.unsqueeze(0) + print(f"input_embeds: {input_embeds.shape}") hidden_states = input_embeds # Reshape hidden_states to (batch_size, seq_len, num_heads, head_dim) @@ -250,14 +262,12 @@ def forward( print(f"position_ids: {position_ids.shape}") # Apply rotary positional embeddings - position_embeddings = self.rotary_pos_emb( - hidden_states, - input_pos=position_ids - ) + position_embeddings = self.rotary_pos_emb(hidden_states, position_ids) - print(f"position_embeddings: {position_embeddings.shape}") + print(f"position_embeddings: {position_embeddings}") # Reshape back to (batch_size, seq_len, hidden_size) + print(f"hidden_size: {self.hidden_size}") hidden_states = hidden_states.view(batch_size, seq_len, self.hidden_size) print(f"hidden_states: {hidden_states.shape}") @@ -278,13 +288,12 @@ def forward( print(f"causal_mask: {causal_mask.shape}") # Forward pass through layers with KVCache - for layer_idx in range(self.shard.end_layer, self.shard.start_layer): + for layer_idx in range(self.shard.start_layer, self.shard.end_layer): print(f"forward layer #{layer_idx}") encoder_layer = self.layers[layer_idx] print(f"encoder_layer\n{encoder_layer}") layer_hidden_state, layer_kv_cache = self.layers[layer_idx]( hidden_states=hidden_states, - kv_cache=past_kv_cache, attention_mask=causal_mask, position_ids=position_ids, position_embeddings=position_embeddings @@ -293,6 +302,8 @@ def forward( hidden_states = layer_hidden_state past_kv_cache = layer_kv_cache + print(f"layer_kv_cache: {layer_kv_cache.size}") + # Apply final layer normalization hidden_states = self.norm(hidden_states) diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 10eabc74..1f5abbe5 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -3,10 +3,13 @@ """ import json from pathlib import Path +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +import torchtune.modules as ttm +import math from exo.helpers import DEBUG @@ -90,7 +93,7 @@ def select_next_token( return next_token.squeeze(-1) -class MLP(nn.Module): +class MultiLayerPreceptron(nn.Module): def __init__(self, input_dim, hidden_dims, output_dim, activation='gelu', dropout=0.0, use_batchnorm=False): """ General MLP (Multi-Layer Perceptron) module. @@ -103,7 +106,7 @@ def __init__(self, input_dim, hidden_dims, output_dim, activation='gelu', dropou dropout (float): Dropout probability. use_batchnorm (bool): Whether to use batch normalization. """ - super(MLP, self).__init__() + super(MultiLayerPreceptron, self).__init__() self.layers = nn.ModuleList() self.use_batchnorm = use_batchnorm @@ -152,70 +155,238 @@ def forward(self, x): return self.output_layer(x) def create_4d_causal_attention_mask( - attention_mask: torch.Tensor, - seq_len: int, - target_len: int, - dtype: torch.dtype, - device: torch.device, - cache_pos: torch.LongTensor, - batch_size: int, + attention_mask: torch.Tensor, + seq_len: int, + target_len: int, + dtype: torch.dtype, + device: torch.device, + cache_pos: torch.Tensor, + batch_size: int, ) -> torch.Tensor: + """ + Creates a 4D causal attention mask from a 2D mask, with adjustments for static caching. + + Args: + attention_mask (torch.Tensor): + A 2D tensor of shape (batch_size, key_value_length) or a 4D tensor of shape + (batch_size, 1, query_length, key_value_length). + seq_len (int): + Sequence length of the input being processed. + target_len (int): + Target length to generate the causal mask. + dtype (torch.dtype): + Data type for the causal mask. + device (torch.device): + Device to place the causal mask on. + cache_pos (torch.Tensor): + Cache position indices indicating the position of the input tokens in the sequence. + batch_size (int): + Number of samples in the batch. + + Returns: + torch.Tensor: + A 4D causal mask of shape (batch_size, 1, query_length, key_value_length). + """ + if attention_mask is not None and attention_mask.dim() == 4: + # If the mask is already 4D, return it directly + return attention_mask + + min_value = torch.finfo(dtype).min + + # Create a 2D causal mask of shape (seq_len, target_len) + causal_mask = torch.full( + (seq_len, target_len), fill_value=min_value, dtype=dtype, device=device + ) + + if seq_len != 1: + # Mask positions after the current position + causal_mask = torch.triu(causal_mask, diagonal=1) + + # Adjust causal mask for cache position + causal_mask *= (torch.arange(target_len, device=device) > cache_pos.view(-1, 1)) + + # Expand to 4D and batch size + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + if attention_mask is not None: + # Create a padding mask based on the input attention_mask + mask_len = attention_mask.shape[-1] + causal_mask = causal_mask.clone() # Ensure contiguous memory for in-place operations + padding_mask = causal_mask[:, :, :, :mask_len] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + + # Apply padding to the causal mask + causal_mask[:, :, :, :mask_len] = causal_mask[:, :, :, :mask_len].masked_fill( + padding_mask, min_value + ) + + return causal_mask + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +class MultiHeadAttention(nn.Module): + """Multi-headed attention mechanism.""" + + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rotary_emb, + kv_cache: Optional[ttm.KVCache] = None, + attention_dropout=0.0, + is_causal=True + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.attention_dropout = attention_dropout + self.is_causal = is_causal + self.rotary_emb = rotary_emb + + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False) + + self.kv_cache = kv_cache + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.size() + + if self.kv_cache is None or self.kv_cache.batch_size != batch_size: + self.kv_cache = ttm.KVCache( + batch_size=batch_size, + max_seq_len=seq_len, + num_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=hidden_states.dtype + ) + + # Project to queries, keys, and values + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Reshape to [batch_size, num_heads, seq_len, head_dim] + query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") + + # Apply rotary positional embeddings if position_ids are provided + # or use position_embeddings + if position_embeddings is not None: + cos, sin = position_embeddings + else: + cos, sin = self.rotary_emb(query_states, position_ids) + + # Expand cos and sin to match the shape of query_states + cos = cos[:, :, None, :self.head_dim].expand_as(query_states) + sin = sin[:, :, None, :self.head_dim].expand_as(query_states) + print(f"cos: {cos.shape} | sin: {sin.shape}") + + # Apply rotary embeddings to queries and keys + query_states = (query_states * cos) + (rotate_half(query_states) * sin) + key_states = (key_states * cos) + (rotate_half(key_states) * sin) + + # Repeat keys and values if needed + if self.num_heads > self.num_kv_heads: + n_rep = self.num_heads // self.num_kv_heads + key_states = torch.repeat_interleave(key_states, n_rep, dim=1) + value_states = torch.repeat_interleave(value_states, n_rep, dim=1) + + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") + + # Forcing caching always enabled + key_states, value_states = self.kv_cache.update(key_states, value_states) + + # Compute attention scores + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + + # Apply causal mask, if applicable + if self.is_causal: + causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=hidden_states.device)) + attn_weights = attn_weights.masked_fill(causal_mask == 0, float('-inf')) + + # Apply attention mask, if provided + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # Softmax normalization + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + # Compute attention output + attn_output = torch.matmul(attn_weights, value_states) + + # Reshape to [batch_size, seq_len, hidden_size] + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) + + # Project back to hidden size + attn_output = self.o_proj(attn_output) + + return attn_output + +class RotaryEmbedding(nn.Module): + """Rotary Position Embedding.""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0, rope_type="default", device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.scaling_factor = scaling_factor + self.rope_type = rope_type + + # Initialize the inverse frequency for RoPE + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: """ - Creates a 4D causal attention mask from a 2D mask, with adjustments for static caching. + Compute the rotary position embeddings (cos, sin) for the given input tensor. Args: - attention_mask (torch.Tensor): - A 2D tensor of shape (batch_size, key_value_length) or a 4D tensor of shape - (batch_size, 1, query_length, key_value_length). - seq_len (int): - Sequence length of the input being processed. - target_len (int): - Target length to generate the causal mask. - dtype (torch.dtype): - Data type for the causal mask. - device (torch.device): - Device to place the causal mask on. - cache_pos (torch.Tensor): - Cache position indices indicating the position of the input tokens in the sequence. - batch_size (int): - Number of samples in the batch. + x (torch.Tensor): The input tensor of shape (batch_size, seq_len, num_heads, head_dim). + position_ids (torch.Tensor): The position indices for the sequence. Returns: - torch.Tensor: - A 4D causal mask of shape (batch_size, 1, query_length, key_value_length). + Tuple[torch.Tensor, torch.Tensor]: The cos and sin embeddings. """ - if attention_mask is not None and attention_mask.dim() == 4: - # If the mask is already 4D, return it directly - return attention_mask + # Expand inv_freq to match the batch size and sequence length + batch_size, seq_len = position_ids.size(0), position_ids.size(1) + inv_freq_expanded = self.inv_freq[None, :, None].expand(batch_size, -1, seq_len) - min_value = torch.finfo(dtype).min + # Expand position_ids to match the frequency tensor + position_ids_expanded = position_ids[:, None, :].float() - # Create a 2D causal mask of shape (seq_len, target_len) - causal_mask = torch.full( - (seq_len, target_len), fill_value=min_value, dtype=dtype, device=device - ) + # Compute cos and sin embeddings + freqs = torch.einsum("bnd,bnl->bnd", inv_freq_expanded, position_ids_expanded) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() - if seq_len != 1: - # Mask positions after the current position - causal_mask = torch.triu(causal_mask, diagonal=1) + # Apply the scaling factor to cos and sin embeddings + cos = cos * self.scaling_factor + sin = sin * self.scaling_factor - # Adjust causal mask for cache position - causal_mask *= (torch.arange(target_len, device=device) > cache_pos.view(-1, 1)) - - # Expand to 4D and batch size - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - - if attention_mask is not None: - # Create a padding mask based on the input attention_mask - mask_len = attention_mask.shape[-1] - causal_mask = causal_mask.clone() # Ensure contiguous memory for in-place operations - padding_mask = causal_mask[:, :, :, :mask_len] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - - # Apply padding to the causal mask - causal_mask[:, :, :, :mask_len] = causal_mask[:, :, :, :mask_len].masked_fill( - padding_mask, min_value - ) - - return causal_mask + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 6537115b..973b66da 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -40,13 +40,16 @@ def test_generation(model, tokenizer, text, max_length=10): print(f"attention_mask: {attention_mask}") # Initialize KVCache for caching - past_kv_cache = KVCache( - batch_size=input_ids.size(0), - max_seq_len=model.max_position_embeddings, - num_heads=model.num_heads, - head_dim=model.head_dim, - dtype=input_ids.dtype - ) + past_kv_cache = None + #past_kv_cache = KVCache( + # batch_size=input_ids.size(0), + # max_seq_len=model.max_position_embeddings, + # num_heads=model.num_heads, + # head_dim=model.head_dim, + # dtype=input_ids.dtype + #) + + #print(f"past_kv_cache: {past_kv_cache}") # Start with initial input_ids generated_ids = input_ids.clone() From ea868c6b7b4155bc88fc589449cefab4085f1a8f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 30 Oct 2024 03:20:09 -0800 Subject: [PATCH 468/491] updating attentions, changed model struct, fixing kv cache --- exo/inference/torch/models/llama3.py | 157 +++---- exo/inference/torch/models/llm_utils.py | 413 +++++++++++++----- .../torch/tests/test_llama3_model.py | 63 +-- 3 files changed, 393 insertions(+), 240 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index bf2e6f16..09c162e6 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -9,13 +9,14 @@ import torch.nn as nn from torchtune.modules import ( KVCache, - RMSNorm, + RMSNorm ) from exo.inference.shard import Shard from exo.inference.torch.models.llm_utils import ( MultiLayerPreceptron, - MultiHeadAttention, + #MultiHeadAttention, + SDPAttention, RotaryEmbedding, create_4d_causal_attention_mask ) @@ -27,36 +28,14 @@ class LlamaBlock(nn.Module): def __init__( self, dim, - head_dim, - num_heads, - num_kv_heads, - ff_dim, - rotary_pos_emb, mlp, - rms_norm_eps=1e-6, - attention_dropout=0.0, - use_bias=False, - max_seq_len=4096, - pos_embeddings=None + self_attn, + rms_norm_eps=1e-6 ): super(LlamaBlock, self).__init__() - self.dim = dim - self.head_dim = head_dim - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.ff_dim = ff_dim - self.rms_norm_eps = rms_norm_eps - self.attention_dropout = attention_dropout - self.use_bias = use_bias - self.max_seq_len = max_seq_len - self.pos_embeddings = pos_embeddings - self.rotary_pos_emb = rotary_pos_emb - self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=use_bias) - self.k_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) - self.v_proj = nn.Linear(dim, num_kv_heads * head_dim, bias=use_bias) - self.output_proj = nn.Linear(num_heads * head_dim, dim, bias=use_bias) - self.input_layer_norm = RMSNorm(dim, eps=rms_norm_eps) + self.self_attn = self_attn self.mlp = mlp + self.input_layer_norm = RMSNorm(dim, eps=rms_norm_eps) self.post_attention_norm = RMSNorm(dim, eps=rms_norm_eps) def forward( @@ -88,23 +67,12 @@ def forward( hidden_states = self.input_layer_norm(hidden_states) print(f"self.input_layer_norm(hidden_states) {hidden_states.shape}") - batch_size, seq_len, _ = hidden_states.shape - hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim).squeeze() - print(f"hidden_states: {hidden_states.shape}") - - # Apply MultiHeadAttention with KVCache - mh_attn = MultiHeadAttention( - hidden_size=self.head_dim, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - kv_cache=kv_cache, - is_causal=True, - attention_dropout=self.attention_dropout, - rotary_emb=self.rotary_pos_emb - ) + #batch_size, seq_len, _ = hidden_states.shape + #hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim).squeeze() + #print(f"hidden_states: {hidden_states.shape}") - hidden_states = mh_attn( + # Apply MultiHeadAttention with KVCache + hidden_states, kv_cache = self.self_attn( hidden_states=hidden_states, position_ids=position_ids, attention_mask=attention_mask, @@ -114,8 +82,8 @@ def forward( # Residual connection hidden_states = residual + hidden_states residual = hidden_states - print(f"hidden_states: {hidden_states}") - print(f"residual: {residual}") + print(f"hidden_states: {hidden_states.shape}") + print(f"residual: {residual.shape}") # Post attention normalization hidden_states = self.post_attention_norm(hidden_states) # Feed-forward network with MLP and residual connection @@ -163,41 +131,42 @@ def __init__(self, config: dict, shard: Shard): self.attention_dropout = config.get('attention_dropout', 0.0) self.padding_idx = config.get("pad_token_id") - # Model layers and methods + # Model layers and methods, order matters self.embed = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) - self.rotary_pos_emb = RotaryEmbedding( - self.head_dim, - config['rope_scaling']['original_max_position_embeddings'], - config['rope_theta'] - ) - self.mlp = MultiLayerPreceptron( - input_dim=self.hidden_size, - hidden_dims=[self.intermediate_size], # Single hidden layer with ff_dim as the hidden size - output_dim=self.hidden_size, - activation='gelu', - dropout=self.attention_dropout - ) self.layers = nn.ModuleList([ LlamaBlock( dim=self.hidden_size, - head_dim=self.hidden_size // self.num_heads, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - ff_dim=self.intermediate_size, rms_norm_eps=self.rms_norm_eps, - attention_dropout=self.attention_dropout, - use_bias=config.get('attention_bias', False), - rotary_pos_emb=self.rotary_pos_emb, - mlp=self.mlp + self_attn=SDPAttention( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.hidden_size // self.num_heads, + is_causal=True, + attention_dropout=self.attention_dropout, + rotary_emb=RotaryEmbedding( + self.head_dim + ), + attention_bias=config.get('attention_bias', False) + ), + mlp=MultiLayerPreceptron( + input_dim=self.hidden_size, + hidden_dim=self.intermediate_size, + activation=self.config.get("hidden_act", "silu"), + use_bias=self.config.get("mlp_bias", False) + ), ) for _ in range(self.num_layers) ]) self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) + self.rotary_pos_emb = RotaryEmbedding( + self.head_dim + ) self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) def forward( self, input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None, past_kv_cache: Optional[KVCache] = None, @@ -215,7 +184,7 @@ def forward( Returns: Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], KVCache]: - - logits (Optional[torch.Tensor]): Output logits of shape (batch_size, seq_len, vocab_size). + - pred_score (Optional[torch.Tensor]): Prediction scores from lm_head of model. - hidden_states (Optional[torch.Tensor]): Hidden states from each layer - past_kv_cache (KVCache): Updated KVCache object. """ @@ -239,11 +208,13 @@ def forward( past_seen_tokens = past_kv_cache.size if past_kv_cache is not None else 0 cache_position = torch.arange( past_seen_tokens, - past_seen_tokens + seq_len, + past_seen_tokens + input_embeds.shape[1], device=input_ids.device ) #.unsqueeze(0).expand(batch_size, -1) + print(f"cache_position: {cache_position.shape}") + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -264,28 +235,30 @@ def forward( # Apply rotary positional embeddings position_embeddings = self.rotary_pos_emb(hidden_states, position_ids) - print(f"position_embeddings: {position_embeddings}") - # Reshape back to (batch_size, seq_len, hidden_size) print(f"hidden_size: {self.hidden_size}") hidden_states = hidden_states.view(batch_size, seq_len, self.hidden_size) print(f"hidden_states: {hidden_states.shape}") - # create 4d causal mask - causal_mask = None - if attention_mask is not None: - causal_mask = create_4d_causal_attention_mask( - attention_mask=attention_mask, - seq_len=hidden_states.shape[1], - target_len=attention_mask.shape[-1], - dtype=hidden_states.dtype, - device=hidden_states.device, - cache_pos=cache_position, - batch_size=hidden_states.size(0) - ) + # create/update 4d causal mask + seq_len = input_embeds.shape[1] + + if past_kv_cache is not None: + target_len = past_kv_cache.size + seq_len + 1 + else: + target_len = seq_len + 1 + causal_mask = create_4d_causal_attention_mask( + attention_mask=attention_mask, + seq_len=seq_len, + target_len=target_len, + dtype=input_embeds.dtype, + device=input_embeds.device, + cache_pos=cache_position, + batch_size=input_embeds.size(0) + ) - print(f"attention_mask: {attention_mask.shape}") - print(f"causal_mask: {causal_mask.shape}") + print(f"attention_mask: {attention_mask.shape}") + print(f"causal_mask: {causal_mask.shape}") # Forward pass through layers with KVCache for layer_idx in range(self.shard.start_layer, self.shard.end_layer): @@ -307,13 +280,13 @@ def forward( # Apply final layer normalization hidden_states = self.norm(hidden_states) - # Compute logits if at end layer + # Compute prediction score from lm head if at end layer if self.shard.is_last_layer(): - logits = self.lm_head(hidden_states[:, -1:, :]) + pred_score = self.lm_head(hidden_states) else: - logits = None + pred_score = None - if logits is None: - return logits, hidden_states, past_kv_cache + if pred_score is None: + return pred_score, hidden_states, past_kv_cache else: - return logits, None, past_kv_cache + return pred_score, None, past_kv_cache diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 1f5abbe5..f43c228b 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -11,6 +11,8 @@ import torchtune.modules as ttm import math +from transformers.models.mamba.modeling_mamba import causal_conv1d_update + from exo.helpers import DEBUG def load_model_config(model_config_path: Path) -> dict: @@ -39,7 +41,7 @@ def select_next_token( Selects the next token from logits using top-k, top-p, and temperature scaling. Args: - logits (torch.Tensor): Logits tensor of shape (batch_size, vocab_size). + logits (torch.Tensor): Logits or prediction scores tensor of shape (batch_size, vocab_size). top_k (int): Number of top logits to consider for sampling. top_p (float): Cumulative probability threshold for nucleus sampling. temperature (float): Scaling factor for temperature. @@ -58,49 +60,50 @@ def select_next_token( # Apply top-k filtering if top_k > 0: - # Get the top-k logits and set the rest to -inf - top_k_values, _ = torch.topk(logits, top_k, dim=-1) - min_top_k_value = top_k_values[:, -1, None] - logits = torch.where(logits < min_top_k_value, torch.tensor(float('-inf'), device=logits.device), logits) + top_k = min(top_k, logits.size(-1)) + min_topk = torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(logits < min_topk, float("-inf")) # Apply top-p (nucleus) filtering if top_p > 0.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Mask tokens exceeding the top-p threshold - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() # Shift right - sorted_indices_to_remove[:, 0] = 0 # Ensure at least one token is selected + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + sorted_indices_to_remove[..., -1:] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits = logits.masked_fill(indices_to_remove, float('-inf')) - # Calculate probabilities - probs = F.softmax(logits, dim=-1) - # Select next token if not use_max: + probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = torch.argmax(logits, dim=-1, keepdim=True) + next_token = next_token[:, None].squeeze(-1) + # Debugging output if DEBUG >= 4: print(f"Logits: {logits}") - print(f"Probabilities: {probs}") print(f"Next token: {next_token}") - return next_token.squeeze(-1) + return next_token class MultiLayerPreceptron(nn.Module): - def __init__(self, input_dim, hidden_dims, output_dim, activation='gelu', dropout=0.0, use_batchnorm=False): + def __init__( + self, + input_dim, + hidden_dim, + activation='gelu', + use_bias=False + ): """ General MLP (Multi-Layer Perceptron) module. Args: input_dim (int): Dimensionality of the input. - hidden_dims (list of int): List of hidden layer dimensions. + hidden_dims (int): Hidden layer/intermediate dimensions. output_dim (int): Dimensionality of the output. activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). dropout (float): Dropout probability. @@ -108,39 +111,27 @@ def __init__(self, input_dim, hidden_dims, output_dim, activation='gelu', dropou """ super(MultiLayerPreceptron, self).__init__() - self.layers = nn.ModuleList() - self.use_batchnorm = use_batchnorm - # Activation function mapping activations = { 'relu': nn.ReLU(), 'gelu': nn.GELU(), 'tanh': nn.Tanh(), 'sigmoid': nn.Sigmoid(), - 'leaky_relu': nn.LeakyReLU(0.2) + 'leaky_relu': nn.LeakyReLU(0.2), + 'silu': nn.SiLU() } # Ensure valid activation if activation not in activations: raise ValueError(f"Invalid activation: {activation}. Choose from {list(activations.keys())}") - self.activation = activations[activation] - # Construct MLP layers - prev_dim = input_dim - for h_dim in hidden_dims: - self.layers.append(nn.Linear(prev_dim, h_dim)) - if use_batchnorm: - self.layers.append(nn.BatchNorm1d(h_dim)) - self.layers.append(self.activation) - if dropout > 0: - self.layers.append(nn.Dropout(dropout)) - prev_dim = h_dim - - # Output layer - self.output_layer = nn.Linear(prev_dim, output_dim) - - def forward(self, x): + self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.down_proj = nn.Linear(hidden_dim, input_dim, bias=use_bias) + self.act_fn = activations[activation] + + def forward(self, x) -> torch.Tensor: """ Forward pass for the MLP module. @@ -150,9 +141,8 @@ def forward(self, x): Returns: torch.Tensor: Output tensor after the MLP transformations. """ - for layer in self.layers: - x = layer(x) - return self.output_layer(x) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj def create_4d_causal_attention_mask( attention_mask: torch.Tensor, @@ -228,8 +218,69 @@ def rotate_half(x): x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) +class RotaryEmbedding(nn.Module): + """ + Rotary Position Embedding. + + This computes the inverse frequencies according to the original RoPE implementation. + There are other implementations that will be added. + Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.scaling_factor = scaling_factor + + # Initialize the inverse frequency for RoPE + inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the rotary position embeddings (cos, sin) for the given input tensor. + + Args: + x (torch.Tensor): The input tensor of shape (batch_size, seq_len, num_heads, head_dim). + position_ids (torch.Tensor): The position indices for the sequence. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The cos and sin embeddings. + """ + # Expand inv_freq to match the batch size + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.size(0), -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Compute cos and sin embeddings + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Apply the scaling factor to cos and sin embeddings + cos = cos * self.scaling_factor + sin = sin * self.scaling_factor + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + +# ------------------ +# Attention Methods +# ------------------ + class MultiHeadAttention(nn.Module): - """Multi-headed attention mechanism.""" + """ + Multi-headed attention mechanism. + + Using the "attention is all you need" implementation. Other implementations will follow. + Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L277 + Ref: https://pytorch.org/torchtune/0.3/_modules/torchtune/modules/attention.html + """ def __init__( self, @@ -240,7 +291,8 @@ def __init__( rotary_emb, kv_cache: Optional[ttm.KVCache] = None, attention_dropout=0.0, - is_causal=True + is_causal=True, + attention_bias=False ): super().__init__() self.hidden_size = hidden_size @@ -249,42 +301,37 @@ def __init__( self.head_dim = head_dim self.attention_dropout = attention_dropout self.is_causal = is_causal - self.rotary_emb = rotary_emb - - self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) - self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) - self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) - self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False) - self.kv_cache = kv_cache + # nn layers + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=attention_bias) + self.rotary_emb = rotary_emb + def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - ) -> torch.Tensor: + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cos_sin_unsqueeze: int=1 + ) -> Tuple[torch.Tensor, ttm.KVCache]: batch_size, seq_len, _ = hidden_states.size() - if self.kv_cache is None or self.kv_cache.batch_size != batch_size: - self.kv_cache = ttm.KVCache( - batch_size=batch_size, - max_seq_len=seq_len, - num_heads=self.num_kv_heads, - head_dim=self.head_dim, - dtype=hidden_states.dtype - ) - # Project to queries, keys, and values query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") # Reshape to [batch_size, num_heads, seq_len, head_dim] query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) print(f"query_states: {query_states.shape}") print(f"key_states: {key_states.shape}") print(f"value_states: {value_states.shape}") @@ -294,99 +341,227 @@ def forward( if position_embeddings is not None: cos, sin = position_embeddings else: - cos, sin = self.rotary_emb(query_states, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) - # Expand cos and sin to match the shape of query_states - cos = cos[:, :, None, :self.head_dim].expand_as(query_states) - sin = sin[:, :, None, :self.head_dim].expand_as(query_states) + print(f"cos: {cos.shape} | sin: {sin.shape}") + # Expand cos and sin to match hidden_states' shape + cos = cos.unsqueeze(cos_sin_unsqueeze) + sin = sin.unsqueeze(cos_sin_unsqueeze) print(f"cos: {cos.shape} | sin: {sin.shape}") # Apply rotary embeddings to queries and keys query_states = (query_states * cos) + (rotate_half(query_states) * sin) key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - # Repeat keys and values if needed - if self.num_heads > self.num_kv_heads: - n_rep = self.num_heads // self.num_kv_heads - key_states = torch.repeat_interleave(key_states, n_rep, dim=1) - value_states = torch.repeat_interleave(value_states, n_rep, dim=1) - print(f"query_states: {query_states.shape}") print(f"key_states: {key_states.shape}") print(f"value_states: {value_states.shape}") # Forcing caching always enabled + if self.kv_cache is not None: + print(f"self.kv_cache.size {self.kv_cache.size}") + print(f"key_states.size(0) {key_states.size(2)}") + if self.kv_cache is None or self.kv_cache.batch_size != key_states.size(0): + print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") + self.kv_cache = ttm.KVCache( + batch_size=key_states.size(0), + max_seq_len=key_states.size(2), + num_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=hidden_states.dtype + ) key_states, value_states = self.kv_cache.update(key_states, value_states) + print(f"kv_cache: {self.kv_cache.size}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") - # Compute attention scores - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + # Repeat keys and values if needed + #if self.num_heads > self.num_kv_heads: + n_rep = self.num_heads // self.num_kv_heads + key_states = torch.repeat_interleave(key_states, n_rep, dim=1) + value_states = torch.repeat_interleave(value_states, n_rep, dim=1) + + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") - # Apply causal mask, if applicable - if self.is_causal: - causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=hidden_states.device)) - attn_weights = attn_weights.masked_fill(causal_mask == 0, float('-inf')) + # Compute attention scores + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + print(f"attn_weights: {attn_weights.shape}") # Apply attention mask, if provided if attention_mask is not None: - attn_weights = attn_weights + attention_mask + print(f"attention_mask: {attention_mask.shape}") + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + print(f"causal_mask: {causal_mask.shape}") + attn_weights = attn_weights + causal_mask + print(f"attn_weights: {attn_weights.shape}") # Softmax normalization attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + print(f"attn_weights: {attn_weights.shape}") # Compute attention output attn_output = torch.matmul(attn_weights, value_states) + print(f"attn_output: {attn_output.shape}") - # Reshape to [batch_size, seq_len, hidden_size] - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) + # Transpose attention output + attn_output = attn_output.transpose(1,2).contiguous() + print(f"attn_output: {attn_output.shape}") + + # Reshape [batch_size, seq_len, -1] + attn_output = attn_output.reshape(batch_size, seq_len, -1) + print(f"attn_output after transpose: {attn_output.shape}") # Project back to hidden size attn_output = self.o_proj(attn_output) + print(f"attn_output: {attn_output.shape}") - return attn_output + return attn_output, self.kv_cache -class RotaryEmbedding(nn.Module): - """Rotary Position Embedding.""" +class SDPAttention(nn.Module): + """ + Scaled dot product attention mechanism. - def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0, rope_type="default", device=None): + Using the scaled dot product attention method from pytorch + Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L524 + """ + + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rotary_emb, + kv_cache: Optional[ttm.KVCache] = None, + attention_dropout=0.0, + is_causal=True, + attention_bias=False + ): super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.scaling_factor = scaling_factor - self.rope_type = rope_type + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.attention_dropout = attention_dropout + self.is_causal = is_causal + self.kv_cache = kv_cache - # Initialize the inverse frequency for RoPE - inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + # nn layers + self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=attention_bias) + self.rotary_emb = rotary_emb - def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Compute the rotary position embeddings (cos, sin) for the given input tensor. + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cos_sin_unsqueeze: int=1 + ) -> Tuple[torch.Tensor, ttm.KVCache]: + batch_size, seq_len, _ = hidden_states.size() - Args: - x (torch.Tensor): The input tensor of shape (batch_size, seq_len, num_heads, head_dim). - position_ids (torch.Tensor): The position indices for the sequence. + # Project to queries, keys, and values + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") - Returns: - Tuple[torch.Tensor, torch.Tensor]: The cos and sin embeddings. - """ - # Expand inv_freq to match the batch size and sequence length - batch_size, seq_len = position_ids.size(0), position_ids.size(1) - inv_freq_expanded = self.inv_freq[None, :, None].expand(batch_size, -1, seq_len) + # Reshape to [batch_size, num_heads, seq_len, head_dim] + query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") - # Expand position_ids to match the frequency tensor - position_ids_expanded = position_ids[:, None, :].float() + # Apply rotary positional embeddings if position_ids are provided + # or use position_embeddings + if position_embeddings is not None: + cos, sin = position_embeddings + else: + cos, sin = self.rotary_emb(value_states, position_ids) - # Compute cos and sin embeddings - freqs = torch.einsum("bnd,bnl->bnd", inv_freq_expanded, position_ids_expanded) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() + print(f"cos: {cos.shape} | sin: {sin.shape}") + # Expand cos and sin to match hidden_states' shape + cos = cos.unsqueeze(cos_sin_unsqueeze) + sin = sin.unsqueeze(cos_sin_unsqueeze) + print(f"cos: {cos.shape} | sin: {sin.shape}") - # Apply the scaling factor to cos and sin embeddings - cos = cos * self.scaling_factor - sin = sin * self.scaling_factor + # Apply rotary embeddings to queries and keys + query_states = (query_states * cos) + (rotate_half(query_states) * sin) + key_states = (key_states * cos) + (rotate_half(key_states) * sin) + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + # Forcing caching always enabled + if self.kv_cache is not None: + print(f"self.kv_cache.size {self.kv_cache.size}") + print(f"key_states.size(0) {key_states.size(2)}") + if self.kv_cache is None or self.kv_cache.size != key_states.size(2): + print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") + self.kv_cache = ttm.KVCache( + batch_size=key_states.size(0), + max_seq_len=key_states.size(2), + num_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=hidden_states.dtype + ) + key_states, value_states = self.kv_cache.update(key_states, value_states) + print(f"kv_cache: {self.kv_cache.size}") + print(f"from kv_cache / key_states: {key_states.shape}") + print(f"from kv_cache / value_states: {value_states.shape}") + + # Repeat keys and values if needed + #if self.num_heads > self.num_kv_heads: + n_rep = self.num_heads // self.num_kv_heads + key_states = torch.repeat_interleave(key_states, n_rep, dim=1) + value_states = torch.repeat_interleave(value_states, n_rep, dim=1) + + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") + + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + print(f"causal_mask: {causal_mask.shape}") + + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + print(f"query_states: {query_states.shape}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") + + is_causal = True if causal_mask is None and seq_len > 1 else False + + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + print(f"attn_output: {attn_output.shape}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, -1) + + attn_output = self.o_proj(attn_output) + + print(f"attn_output: {attn_output.shape}") + + return attn_output, self.kv_cache diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 973b66da..d61cf5b4 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -9,27 +9,31 @@ from huggingface_hub import snapshot_download from safetensors.torch import load_file as load_safetensors from exo.inference.torch.models.llm_utils import load_model_config, select_next_token -from exo.inference.torch.models.llama3 import LlamaModel, KVCache +from exo.inference.torch.models.llama3 import LlamaModel from exo.inference.shard import Shard MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" TEMP=0.7 -TOP_K=35 +TOP_K=25 TOP_P=0.9 -def test_generation(model, tokenizer, text, max_length=10): +def test_generation(model, tokenizer, text, max_length=10, config=None): """ Test the generation capabilities of the LlamaModel with sample text. """ # Tokenize input text prompt = tokenizer.apply_chat_template([ + { + "role": "system", + "content": "You are a helpful assistant." + }, { "role": "user", "content": text } ], tokenize=False, add_generation_prompt=True) - + print(f"prompt: {prompt}") inputs = tokenizer(prompt, return_tensors="pt") @@ -39,47 +43,48 @@ def test_generation(model, tokenizer, text, max_length=10): print(f"input_ids: {input_ids}") print(f"attention_mask: {attention_mask}") - # Initialize KVCache for caching - past_kv_cache = None - #past_kv_cache = KVCache( - # batch_size=input_ids.size(0), - # max_seq_len=model.max_position_embeddings, - # num_heads=model.num_heads, - # head_dim=model.head_dim, - # dtype=input_ids.dtype - #) - - #print(f"past_kv_cache: {past_kv_cache}") - # Start with initial input_ids generated_ids = input_ids.clone() # Generate tokens step-by-step + past_kvs = None + + print(f"{model}") + for _ in range(max_length): with torch.no_grad(): - logits, _, past_kv_cache = model( + pred_score, hstates, past_kvs = model( generated_ids, attention_mask=attention_mask, - past_kv_cache=past_kv_cache + past_kv_cache=past_kvs ) - # Select next token using logits - #next_token = select_next_token(logits, top_k=TOP_K, top_p=TOP_P, temperature=TEMP, use_max=False) - next_token = ttg.sample(logits[:, -1, :].clone().float(), temperature=TEMP, top_k=TOP_K).squeeze(-1) + print(f"pred_score: {pred_score.shape}") + print(f"hstates: {hstates.shape if hstates is not None else None}") + print(f"past_kvs: {past_kvs.size if past_kvs is not None else None}") + # Select next token using pred_score + #next_token = select_next_token(pred_score, top_k=TOP_K, top_p=TOP_P, temperature=TEMP, use_max=False) + next_token = ttg.sample(pred_score, temperature=TEMP, top_k=TOP_K)[:, -1, :] print(f"next_token: {next_token}") # Update generated_ids - generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) + generated_ids = torch.cat([generated_ids, next_token], dim=1) print(f"generated_ids: {generated_ids}") # Check for EOS token - if next_token.item() == tokenizer.eos_token_id: - break + print(f"next_token.item(): {next_token.item()}") + + if config: + print(config["eos_token_id"]) + if next_token.item() in config["eos_token_id"]: + break + else: + if next_token.item() == tokenizer.eos_token_id: + break # Decode generated text generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - print(f"\nPrompt: {text}") - print(f"\nGenerated Response: {generated_text}") + print(f"\n\n\n\nGenerated Response: {generated_text}") if __name__ == "__main__": print("\nTesting generation:") @@ -101,7 +106,7 @@ def test_generation(model, tokenizer, text, max_length=10): ) # Initialize tokenizer - tokenizer = AutoTokenizer.from_pretrained(cache_dir) + tokenizer = AutoTokenizer.from_pretrained(shard.model_id) # Initialize LlamaModel with config and tokenizer model = LlamaModel(config, shard) @@ -120,7 +125,7 @@ def test_generation(model, tokenizer, text, max_length=10): model.eval() # Set the model to evaluation mode # Sample text for testing - test_text = "What color is a red apple?" + test_text = "Hello" - test_generation(model, tokenizer, test_text) + test_generation(model, tokenizer, test_text, 5, config) From f1822e292196be72f858475e0fd4f72a4927515f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Wed, 30 Oct 2024 03:28:03 -0800 Subject: [PATCH 469/491] fixing kvcache for multiheadattention, fixing layers names for loading weights properly --- exo/inference/torch/models/llama3.py | 10 +++++----- exo/inference/torch/models/llm_utils.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 09c162e6..13958d06 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -35,7 +35,7 @@ def __init__( super(LlamaBlock, self).__init__() self.self_attn = self_attn self.mlp = mlp - self.input_layer_norm = RMSNorm(dim, eps=rms_norm_eps) + self.input_layernorm = RMSNorm(dim, eps=rms_norm_eps) self.post_attention_norm = RMSNorm(dim, eps=rms_norm_eps) def forward( @@ -64,8 +64,8 @@ def forward( residual = hidden_states # Apply RMSNorm to input - hidden_states = self.input_layer_norm(hidden_states) - print(f"self.input_layer_norm(hidden_states) {hidden_states.shape}") + hidden_states = self.input_layernorm(hidden_states) + print(f"self.input_layernorm(hidden_states) {hidden_states.shape}") #batch_size, seq_len, _ = hidden_states.shape #hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim).squeeze() @@ -158,7 +158,7 @@ def __init__(self, config: dict, shard: Shard): ) for _ in range(self.num_layers) ]) self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) - self.rotary_pos_emb = RotaryEmbedding( + self.rotary_emb = RotaryEmbedding( self.head_dim ) self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) @@ -233,7 +233,7 @@ def forward( print(f"position_ids: {position_ids.shape}") # Apply rotary positional embeddings - position_embeddings = self.rotary_pos_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) # Reshape back to (batch_size, seq_len, hidden_size) print(f"hidden_size: {self.hidden_size}") diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index f43c228b..459823ca 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -360,7 +360,7 @@ def forward( if self.kv_cache is not None: print(f"self.kv_cache.size {self.kv_cache.size}") print(f"key_states.size(0) {key_states.size(2)}") - if self.kv_cache is None or self.kv_cache.batch_size != key_states.size(0): + if self.kv_cache is None or self.kv_cache.size != key_states.size(2): print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") self.kv_cache = ttm.KVCache( batch_size=key_states.size(0), From 38028c06183886e9caae7eddd8ed4a0e14c5a76f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 31 Oct 2024 15:47:53 -0800 Subject: [PATCH 470/491] doing work with position_id and causal mask --- exo/inference/torch/models/llama3.py | 33 ++++++++++------- .../torch/tests/test_llama3_model.py | 37 ++++++++++++++++++- 2 files changed, 56 insertions(+), 14 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 13958d06..045a790b 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -71,7 +71,7 @@ def forward( #hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim).squeeze() #print(f"hidden_states: {hidden_states.shape}") - # Apply MultiHeadAttention with KVCache + # Apply MultiHeadAttention with KVCache hidden_states, kv_cache = self.self_attn( hidden_states=hidden_states, position_ids=position_ids, @@ -132,7 +132,7 @@ def __init__(self, config: dict, shard: Shard): self.padding_idx = config.get("pad_token_id") # Model layers and methods, order matters - self.embed = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) self.layers = nn.ModuleList([ LlamaBlock( dim=self.hidden_size, @@ -191,7 +191,7 @@ def forward( batch_size, seq_len = input_ids.shape # Create initial embeddings - input_embeds = self.embed(input_ids) + input_embeds = self.embed_tokens(input_ids) ## Initialize or use the provided KVCache #if past_kv_cache is None: @@ -216,10 +216,17 @@ def forward( print(f"cache_position: {cache_position.shape}") if position_ids is None: - position_ids = cache_position.unsqueeze(0) + #position_ids = cache_position.unsqueeze(0) + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) - print(f"input_embeds: {input_embeds.shape}") - hidden_states = input_embeds + # cache based input generation + if past_kv_cache is not None: + hidden_states = input_embeds[:, -cache_position.shape[0]:] + else: + hidden_states = input_embeds + + print(f"LM hidden_states: {hidden_states.shape}") # Reshape hidden_states to (batch_size, seq_len, num_heads, head_dim) batch_size, seq_len, _ = hidden_states.shape @@ -247,7 +254,7 @@ def forward( target_len = past_kv_cache.size + seq_len + 1 else: target_len = seq_len + 1 - causal_mask = create_4d_causal_attention_mask( + attention_mask = create_4d_causal_attention_mask( attention_mask=attention_mask, seq_len=seq_len, target_len=target_len, @@ -258,7 +265,6 @@ def forward( ) print(f"attention_mask: {attention_mask.shape}") - print(f"causal_mask: {causal_mask.shape}") # Forward pass through layers with KVCache for layer_idx in range(self.shard.start_layer, self.shard.end_layer): @@ -267,7 +273,7 @@ def forward( print(f"encoder_layer\n{encoder_layer}") layer_hidden_state, layer_kv_cache = self.layers[layer_idx]( hidden_states=hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, position_ids=position_ids, position_embeddings=position_embeddings ) @@ -277,15 +283,16 @@ def forward( print(f"layer_kv_cache: {layer_kv_cache.size}") - # Apply final layer normalization - hidden_states = self.norm(hidden_states) - # Compute prediction score from lm head if at end layer if self.shard.is_last_layer(): - pred_score = self.lm_head(hidden_states) + # Apply final layer normalization + hidden_states = self.norm(hidden_states) + pred_score = self.lm_head(hidden_states[:, -1:, :]) else: pred_score = None + print(f"end attention_mask: {attention_mask.shape}") + if pred_score is None: return pred_score, hidden_states, past_kv_cache else: diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index d61cf5b4..5a79f14d 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -17,6 +17,24 @@ TOP_K=25 TOP_P=0.9 +def check_weights(model, state_dict): + """ + Verifies that the weights from the state dictionary are properly loaded into the model. + """ + model_state_dict = model.state_dict() + for name, param in model_state_dict.items(): + if name in state_dict: + loaded_param = state_dict[name] + if param.shape != loaded_param.shape: + print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") + else: + print(f"{name}: loaded correctly") + else: + print(f"{name} not found in the state_dict") + + for name in state_dict: + if name not in model_state_dict: + print(f"Unexpected weight {name} found in state_dict") def test_generation(model, tokenizer, text, max_length=10, config=None): """ @@ -110,6 +128,7 @@ def test_generation(model, tokenizer, text, max_length=10, config=None): # Initialize LlamaModel with config and tokenizer model = LlamaModel(config, shard) + print(f"\nmodel: {model}") # Load weights from safetensors files in the cache directory safetensors_files = list(cache_dir.glob("*.safetensors")) @@ -120,8 +139,24 @@ def test_generation(model, tokenizer, text, max_length=10, config=None): for safetensor_file in safetensors_files: print(f"Loading weights from: {safetensor_file}") state_dict = load_safetensors(safetensor_file) - model.load_state_dict(state_dict, strict=False) + # remap to work with our model + remapped_state_dict = {} + for key, value in state_dict.items(): + # Remove the 'model.' prefix if it exists + print(f"remapping: {key}") + if key.startswith('model.'): + new_key = key[len('model.'):] # Remove 'model.' + else: + new_key = key + + remapped_state_dict[new_key] = value + + model.load_state_dict(remapped_state_dict, strict=False) + + check_weights(model, remapped_state_dict) + + #exit() model.eval() # Set the model to evaluation mode # Sample text for testing From 0fd1797d69f7b1be9e198e903090db0345eb0f9d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 31 Oct 2024 16:04:10 -0800 Subject: [PATCH 471/491] updating torch readme with current model in development --- exo/inference/torch/README.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md index 43b3782a..2ac5a743 100644 --- a/exo/inference/torch/README.md +++ b/exo/inference/torch/README.md @@ -51,3 +51,36 @@ GPU 4: NVIDIA Quadro P400 2GB GPU 5: NVIDIA Quadro P400 2GB ``` +## Current Model + +WIP pytorch llama model + +``` +# Llama-3.2-1B-Instruct # + +LlamaModel( + (embed): Embedding(128256, 2048) + (layers): ModuleList( + (0-15): 16 x LlamaBlock( + (self_attn): SDPAttention( + (q_proj): Linear(in_features=2048, out_features=2048, bias=False) + (k_proj): Linear(in_features=2048, out_features=512, bias=False) + (v_proj): Linear(in_features=2048, out_features=512, bias=False) + (o_proj): Linear(in_features=2048, out_features=2048, bias=False) + (rotary_emb): RotaryEmbedding() + ) + (mlp): MultiLayerPreceptron( + (gate_proj): Linear(in_features=2048, out_features=8192, bias=False) + (up_proj): Linear(in_features=2048, out_features=8192, bias=False) + (down_proj): Linear(in_features=8192, out_features=2048, bias=False) + (act_fn): SiLU() + ) + (input_layer_norm): RMSNorm() + (post_attention_norm): RMSNorm() + ) + ) + (norm): RMSNorm() + (rotary_pos_emb): RotaryEmbedding() + (lm_head): Linear(in_features=2048, out_features=128256, bias=False) +) +``` From 5aaffe6f2cc52560113cee2a6af5374617323616 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 2 Nov 2024 09:47:07 -0800 Subject: [PATCH 472/491] implemented using torchtune multiheadattention, added dot product attention but not implemented fully, added RMSNorm from modeling llama on HF, added weight renaming and loading along with handling no lm_head weight in safetensor where you then use embed weight as seen with gpt2, still not generating proper reponses further dev being done --- exo/inference/torch/models/llama3.py | 222 +++++------------ exo/inference/torch/models/llm_utils.py | 224 ++++++++++-------- .../torch/tests/test_llama3_model.py | 52 ++-- 3 files changed, 226 insertions(+), 272 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 045a790b..5f42e25b 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -7,18 +7,12 @@ import torch import torch.nn as nn -from torchtune.modules import ( - KVCache, - RMSNorm -) +import torchtune.modules as ttm from exo.inference.shard import Shard from exo.inference.torch.models.llm_utils import ( MultiLayerPreceptron, - #MultiHeadAttention, - SDPAttention, - RotaryEmbedding, - create_4d_causal_attention_mask + RMSNorm ) class LlamaBlock(nn.Module): @@ -27,39 +21,49 @@ class LlamaBlock(nn.Module): """ def __init__( self, - dim, + config, mlp, self_attn, rms_norm_eps=1e-6 ): super(LlamaBlock, self).__init__() + self.config = config self.self_attn = self_attn self.mlp = mlp - self.input_layernorm = RMSNorm(dim, eps=rms_norm_eps) - self.post_attention_norm = RMSNorm(dim, eps=rms_norm_eps) + self.input_layernorm = RMSNorm(self.config['hidden_size'], eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(self.config['hidden_size'], eps=rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - kv_cache: Optional[KVCache] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[KVCache]]: + ) -> torch.Tensor: """ Forward pass with integrated attention, resnet and key-value caching. Args: hidden_states (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). - kv_cache (Optional[KVCache]): KVCache object for managing past key-value states. - attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, 1, 1, seq_len). position_ids (Optional[torch.Tensor]): Position IDs tensor of shape (batch_size, seq_len). Returns: Tuple[torch.Tensor, KVCache]: - Output tensor of shape (batch_size, seq_len, dim). - - Updated KVCache object. """ + if isinstance(self.self_attn, ttm.MultiHeadAttention): + if self.self_attn.kv_cache is None: + # setup cache + self.self_attn.setup_cache( + batch_size=hidden_states.size(0), + dtype=hidden_states.dtype, + max_seq_len=2048, #self.config['max_position_embeddings'] + ) + + # Reshape `attention_mask` to match the expected shape: [batch_size, seq_len, seq_len] + if attention_mask is not None: + attention_mask = attention_mask[:, None, :].expand(-1, hidden_states.size(1), -1).float() + print(f"reshaped attention_mask: {attention_mask.shape}") + # setting up resnet residual = hidden_states @@ -67,16 +71,10 @@ def forward( hidden_states = self.input_layernorm(hidden_states) print(f"self.input_layernorm(hidden_states) {hidden_states.shape}") - #batch_size, seq_len, _ = hidden_states.shape - #hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim).squeeze() - #print(f"hidden_states: {hidden_states.shape}") - - # Apply MultiHeadAttention with KVCache - hidden_states, kv_cache = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - position_embeddings=position_embeddings + hidden_states = self.self_attn( + x=hidden_states, + #mask=attention_mask, + input_pos=position_ids ) # Residual connection @@ -85,19 +83,19 @@ def forward( print(f"hidden_states: {hidden_states.shape}") print(f"residual: {residual.shape}") # Post attention normalization - hidden_states = self.post_attention_norm(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) # Feed-forward network with MLP and residual connection hidden_states = self.mlp(hidden_states) hidden_states = hidden_states + residual - return hidden_states, kv_cache + return hidden_states class LlamaModel(nn.Module): """ LlamaModel is a pure PyTorch implementation of the LLaMA architecture """ - def __init__(self, config: dict, shard: Shard): + def __init__(self, config: dict, shard: Shard, is_causal=True): """ Initialize the LlamaModel. @@ -129,37 +127,42 @@ def __init__(self, config: dict, shard: Shard): self.rms_norm_eps = config['rms_norm_eps'] self.head_dim = config['head_dim'] self.attention_dropout = config.get('attention_dropout', 0.0) + self.attention_bias = config.get('attention_bias', False) self.padding_idx = config.get("pad_token_id") + self.has_lm_head_weight = False # Model layers and methods, order matters - self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) self.layers = nn.ModuleList([ LlamaBlock( - dim=self.hidden_size, + config=self.config, rms_norm_eps=self.rms_norm_eps, - self_attn=SDPAttention( - hidden_size=self.hidden_size, + self_attn=ttm.MultiHeadAttention( + embed_dim=self.hidden_size, num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.hidden_size // self.num_heads, - is_causal=True, - attention_dropout=self.attention_dropout, - rotary_emb=RotaryEmbedding( - self.head_dim - ), - attention_bias=config.get('attention_bias', False) + num_kv_heads=self.num_heads, + head_dim= self.hidden_size // self.num_heads, + q_proj=nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.attention_bias), + k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=self.attention_bias), + v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=self.attention_bias), + output_proj=nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.attention_bias), + max_seq_len=2048, #self.max_position_embeddings, + is_causal=is_causal, + attn_dropout=self.attention_dropout ), mlp=MultiLayerPreceptron( input_dim=self.hidden_size, hidden_dim=self.intermediate_size, activation=self.config.get("hidden_act", "silu"), use_bias=self.config.get("mlp_bias", False) - ), + ) ) for _ in range(self.num_layers) ]) - self.norm = RMSNorm(self.hidden_size, eps=self.rms_norm_eps) - self.rotary_emb = RotaryEmbedding( - self.head_dim + self.norm = RMSNorm(hidden_size=self.hidden_size, eps=self.rms_norm_eps) + self.rotary_emb = ttm.RotaryPositionalEmbeddings( + dim=self.hidden_size // self.num_heads, + max_seq_len=2048, #self.max_position_embeddings, + base=self.config.get('rope_theta', 10000) ) self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) @@ -168,132 +171,33 @@ def forward( input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None, - cache_position: Optional[torch.Tensor] = None, - past_kv_cache: Optional[KVCache] = None, - ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], Optional[KVCache]]: - """ - Forward pass with integrated position ID handling, attention mask, and optional KVCache. - - Args: - input_ids (torch.Tensor): Input token IDs of shape (batch_size, seq_len). - attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len). - position_ids (Optional[torch.Tensor]): Position IDs. If None, they are calculated automatically. - cache_position (Optional[torch.LongTensor]): the positions of inputs in the sequence - past_kv_cache (Optional[KVCache]): Optional KVCache for efficient generation. - If provided, it stores past key-value states for faster autoregressive inference. - - Returns: - Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], KVCache]: - - pred_score (Optional[torch.Tensor]): Prediction scores from lm_head of model. - - hidden_states (Optional[torch.Tensor]): Hidden states from each layer - - past_kv_cache (KVCache): Updated KVCache object. - """ - batch_size, seq_len = input_ids.shape + ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + _, seq_len = input_ids.shape - # Create initial embeddings input_embeds = self.embed_tokens(input_ids) - ## Initialize or use the provided KVCache - #if past_kv_cache is None: - # past_kv_cache = KVCache( - # batch_size=batch_size, - # max_seq_len=self.max_position_embeddings, - # num_heads=self.num_heads, - # head_dim=self.head_dim, - # dtype=input_embeds.dtype - # ) - - # Initialize position IDs if not provided - if cache_position is None: - past_seen_tokens = past_kv_cache.size if past_kv_cache is not None else 0 - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + input_embeds.shape[1], - device=input_ids.device - ) - #.unsqueeze(0).expand(batch_size, -1) - - print(f"cache_position: {cache_position.shape}") - if position_ids is None: - #position_ids = cache_position.unsqueeze(0) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[:, -seq_len:] - # cache based input generation - if past_kv_cache is not None: - hidden_states = input_embeds[:, -cache_position.shape[0]:] - else: - hidden_states = input_embeds - - print(f"LM hidden_states: {hidden_states.shape}") - - # Reshape hidden_states to (batch_size, seq_len, num_heads, head_dim) - batch_size, seq_len, _ = hidden_states.shape - hidden_states = hidden_states.view(batch_size, seq_len, self.num_heads, self.head_dim) + print(f"LM input_embeds: {input_embeds.shape}") + print(f"LM attention_mask: {attention_mask.shape}") - # Reshape position_ids to match (batch_size, seq_len) - if position_ids.dim() != 2: - position_ids = position_ids.squeeze(0) + hidden_states = input_embeds - print(f"hidden_states: {hidden_states.shape}") - print(f"position_ids: {position_ids.shape}") - - # Apply rotary positional embeddings - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # Reshape back to (batch_size, seq_len, hidden_size) - print(f"hidden_size: {self.hidden_size}") - hidden_states = hidden_states.view(batch_size, seq_len, self.hidden_size) - print(f"hidden_states: {hidden_states.shape}") - - # create/update 4d causal mask - seq_len = input_embeds.shape[1] - - if past_kv_cache is not None: - target_len = past_kv_cache.size + seq_len + 1 - else: - target_len = seq_len + 1 - attention_mask = create_4d_causal_attention_mask( - attention_mask=attention_mask, - seq_len=seq_len, - target_len=target_len, - dtype=input_embeds.dtype, - device=input_embeds.device, - cache_pos=cache_position, - batch_size=input_embeds.size(0) - ) - - print(f"attention_mask: {attention_mask.shape}") - - # Forward pass through layers with KVCache for layer_idx in range(self.shard.start_layer, self.shard.end_layer): - print(f"forward layer #{layer_idx}") - encoder_layer = self.layers[layer_idx] - print(f"encoder_layer\n{encoder_layer}") - layer_hidden_state, layer_kv_cache = self.layers[layer_idx]( - hidden_states=hidden_states, + #print(f"forward layer #{layer_idx}") + #print(f"{self.layers[layer_idx]}") + hidden_states = self.layers[layer_idx]( + hidden_states=input_embeds, attention_mask=attention_mask, position_ids=position_ids, - position_embeddings=position_embeddings ) - hidden_states = layer_hidden_state - past_kv_cache = layer_kv_cache - - print(f"layer_kv_cache: {layer_kv_cache.size}") - - # Compute prediction score from lm head if at end layer if self.shard.is_last_layer(): - # Apply final layer normalization - hidden_states = self.norm(hidden_states) - pred_score = self.lm_head(hidden_states[:, -1:, :]) - else: - pred_score = None + pred_score = self.lm_head(self.norm(hidden_states)[:, -1:, :]) - print(f"end attention_mask: {attention_mask.shape}") + return pred_score, None - if pred_score is None: - return pred_score, hidden_states, past_kv_cache - else: - return pred_score, None, past_kv_cache + return None, hidden_states diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 459823ca..b4c0658a 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -11,8 +11,6 @@ import torchtune.modules as ttm import math -from transformers.models.mamba.modeling_mamba import causal_conv1d_update - from exo.helpers import DEBUG def load_model_config(model_config_path: Path) -> dict: @@ -90,60 +88,6 @@ def select_next_token( return next_token -class MultiLayerPreceptron(nn.Module): - def __init__( - self, - input_dim, - hidden_dim, - activation='gelu', - use_bias=False - ): - """ - General MLP (Multi-Layer Perceptron) module. - - Args: - input_dim (int): Dimensionality of the input. - hidden_dims (int): Hidden layer/intermediate dimensions. - output_dim (int): Dimensionality of the output. - activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). - dropout (float): Dropout probability. - use_batchnorm (bool): Whether to use batch normalization. - """ - super(MultiLayerPreceptron, self).__init__() - - # Activation function mapping - activations = { - 'relu': nn.ReLU(), - 'gelu': nn.GELU(), - 'tanh': nn.Tanh(), - 'sigmoid': nn.Sigmoid(), - 'leaky_relu': nn.LeakyReLU(0.2), - 'silu': nn.SiLU() - } - - # Ensure valid activation - if activation not in activations: - raise ValueError(f"Invalid activation: {activation}. Choose from {list(activations.keys())}") - - # Construct MLP layers - self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) - self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) - self.down_proj = nn.Linear(hidden_dim, input_dim, bias=use_bias) - self.act_fn = activations[activation] - - def forward(self, x) -> torch.Tensor: - """ - Forward pass for the MLP module. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor after the MLP transformations. - """ - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - def create_4d_causal_attention_mask( attention_mask: torch.Tensor, seq_len: int, @@ -269,6 +213,83 @@ def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class MultiLayerPreceptron(nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + activation='silu', + use_bias=False + ): + """ + General MLP (Multi-Layer Perceptron) module. + + Args: + input_dim (int): Dimensionality of the input. + hidden_dims (int): Hidden layer/intermediate dimensions. + output_dim (int): Dimensionality of the output. + activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). + dropout (float): Dropout probability. + use_batchnorm (bool): Whether to use batch normalization. + """ + super(MultiLayerPreceptron, self).__init__() + + # Activation function mapping + activations = { + 'relu': nn.ReLU(), + 'gelu': nn.GELU(), + 'tanh': nn.Tanh(), + 'sigmoid': nn.Sigmoid(), + 'leaky_relu': nn.LeakyReLU(0.2), + 'silu': nn.SiLU() + } + + # Ensure valid activation + if activation not in activations: + raise ValueError(f"Invalid activation: {activation}. Choose from {list(activations.keys())}") + + # Construct MLP layers + self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.down_proj = nn.Linear(hidden_dim, input_dim, bias=use_bias) + self.act_fn = activations[activation] + + def forward(self, x) -> torch.Tensor: + """ + Forward pass for the MLP module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after the MLP transformations. + """ + + return self.down_proj( + self.act_fn( + self.gate_proj(x) + ) * self.up_proj(x) + ) + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + RMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + # ------------------ # Attention Methods # ------------------ @@ -289,7 +310,6 @@ def __init__( num_kv_heads, head_dim, rotary_emb, - kv_cache: Optional[ttm.KVCache] = None, attention_dropout=0.0, is_causal=True, attention_bias=False @@ -301,7 +321,6 @@ def __init__( self.head_dim = head_dim self.attention_dropout = attention_dropout self.is_causal = is_causal - self.kv_cache = kv_cache # nn layers self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) @@ -316,8 +335,9 @@ def forward( position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: Optional[ttm.KVCache] = None, cos_sin_unsqueeze: int=1 - ) -> Tuple[torch.Tensor, ttm.KVCache]: + ) -> Tuple[torch.Tensor, Optional[ttm.KVCache]]: batch_size, seq_len, _ = hidden_states.size() # Project to queries, keys, and values @@ -357,22 +377,25 @@ def forward( print(f"value_states: {value_states.shape}") # Forcing caching always enabled - if self.kv_cache is not None: - print(f"self.kv_cache.size {self.kv_cache.size}") - print(f"key_states.size(0) {key_states.size(2)}") - if self.kv_cache is None or self.kv_cache.size != key_states.size(2): - print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") - self.kv_cache = ttm.KVCache( - batch_size=key_states.size(0), - max_seq_len=key_states.size(2), - num_heads=self.num_kv_heads, - head_dim=self.head_dim, - dtype=hidden_states.dtype - ) - key_states, value_states = self.kv_cache.update(key_states, value_states) - print(f"kv_cache: {self.kv_cache.size}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") + if kv_cache is not None: + #print(f"kv_cache.size {kv_cache.size}") + + #print(f"key_states.size(2) {key_states.size(2)}") + + #if kv_cache.size != key_states.size(2): + # print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") + # kv_cache = ttm.KVCache( + # batch_size=key_states.size(0), + # max_seq_len=key_states.size(2), + # num_heads=self.num_kv_heads, + # head_dim=self.head_dim, + # dtype=hidden_states.dtype + # ) + + key_states, value_states = kv_cache.update(key_states, value_states) + print(f"kv_cache: {kv_cache.size}") + print(f"key_states: {key_states.shape}") + print(f"value_states: {value_states.shape}") # Repeat keys and values if needed #if self.num_heads > self.num_kv_heads: @@ -417,7 +440,7 @@ def forward( attn_output = self.o_proj(attn_output) print(f"attn_output: {attn_output.shape}") - return attn_output, self.kv_cache + return attn_output, kv_cache class SDPAttention(nn.Module): """ @@ -434,7 +457,6 @@ def __init__( num_kv_heads, head_dim, rotary_emb, - kv_cache: Optional[ttm.KVCache] = None, attention_dropout=0.0, is_causal=True, attention_bias=False @@ -446,7 +468,6 @@ def __init__( self.head_dim = head_dim self.attention_dropout = attention_dropout self.is_causal = is_causal - self.kv_cache = kv_cache # nn layers self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) @@ -461,8 +482,9 @@ def forward( position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: Optional[ttm.KVCache] = None, cos_sin_unsqueeze: int=1 - ) -> Tuple[torch.Tensor, ttm.KVCache]: + ) -> Tuple[torch.Tensor, Optional[ttm.KVCache]]: batch_size, seq_len, _ = hidden_states.size() # Project to queries, keys, and values @@ -502,22 +524,30 @@ def forward( print(f"value_states: {value_states.shape}") # Forcing caching always enabled - if self.kv_cache is not None: - print(f"self.kv_cache.size {self.kv_cache.size}") - print(f"key_states.size(0) {key_states.size(2)}") - if self.kv_cache is None or self.kv_cache.size != key_states.size(2): - print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") - self.kv_cache = ttm.KVCache( - batch_size=key_states.size(0), - max_seq_len=key_states.size(2), - num_heads=self.num_kv_heads, - head_dim=self.head_dim, - dtype=hidden_states.dtype - ) - key_states, value_states = self.kv_cache.update(key_states, value_states) - print(f"kv_cache: {self.kv_cache.size}") - print(f"from kv_cache / key_states: {key_states.shape}") - print(f"from kv_cache / value_states: {value_states.shape}") + if kv_cache is not None: + #print(f"kv_cache.size {kv_cache.size}") + #print(f"key_states.size(0) {key_states.size(2)}") + + #if kv_cache.size != key_states.size(2): + # print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") + # kv_cache = ttm.KVCache( + # batch_size=key_states.size(0), + # max_seq_len=key_states.size(2), + # num_heads=self.num_kv_heads, + # head_dim=self.head_dim, + # dtype=hidden_states.dtype + # ) + + key_states, value_states = kv_cache.update(key_states, value_states) + + # **Slice KVCache to match `query_states` length** + key_states = key_states[:, :, :seq_len, :] + value_states = value_states[:, :, :seq_len, :] + + # kv_cache.update(key_states, value_states) + print(f"kv_cache: {kv_cache.size}") + print(f"from kv_cache / key_states: {key_states.shape}") + print(f"from kv_cache / value_states: {value_states.shape}") # Repeat keys and values if needed #if self.num_heads > self.num_kv_heads: @@ -550,7 +580,7 @@ def forward( key_states, value_states, attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + dropout_p=0.0, is_causal=is_causal, ) @@ -563,5 +593,5 @@ def forward( print(f"attn_output: {attn_output.shape}") - return attn_output, self.kv_cache + return attn_output, kv_cache diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 5a79f14d..b5f07d00 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -1,6 +1,7 @@ """ Test of pytorch based llama3 model """ +import re from pathlib import Path import torch @@ -14,7 +15,7 @@ MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" TEMP=0.7 -TOP_K=25 +TOP_K=35 TOP_P=0.9 def check_weights(model, state_dict): @@ -22,6 +23,7 @@ def check_weights(model, state_dict): Verifies that the weights from the state dictionary are properly loaded into the model. """ model_state_dict = model.state_dict() + print(f"model_state_dict: {model_state_dict.keys()}") for name, param in model_state_dict.items(): if name in state_dict: loaded_param = state_dict[name] @@ -29,8 +31,8 @@ def check_weights(model, state_dict): print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") else: print(f"{name}: loaded correctly") - else: - print(f"{name} not found in the state_dict") + #else: + # print(f"{name} not found in the state_dict") for name in state_dict: if name not in model_state_dict: @@ -65,44 +67,46 @@ def test_generation(model, tokenizer, text, max_length=10, config=None): generated_ids = input_ids.clone() # Generate tokens step-by-step - past_kvs = None - print(f"{model}") for _ in range(max_length): with torch.no_grad(): - pred_score, hstates, past_kvs = model( + pred_score, hstates = model( generated_ids, - attention_mask=attention_mask, - past_kv_cache=past_kvs + attention_mask=attention_mask ) + print("\n\n------------------------------------------------------") print(f"pred_score: {pred_score.shape}") print(f"hstates: {hstates.shape if hstates is not None else None}") - print(f"past_kvs: {past_kvs.size if past_kvs is not None else None}") # Select next token using pred_score - #next_token = select_next_token(pred_score, top_k=TOP_K, top_p=TOP_P, temperature=TEMP, use_max=False) - next_token = ttg.sample(pred_score, temperature=TEMP, top_k=TOP_K)[:, -1, :] + next_token = select_next_token(pred_score, top_k=TOP_K, top_p=TOP_P, temperature=TEMP, use_max=False) + #next_token = ttg.sample(pred_score, temperature=TEMP, top_k=TOP_K)[:, -1, :] print(f"next_token: {next_token}") # Update generated_ids generated_ids = torch.cat([generated_ids, next_token], dim=1) - print(f"generated_ids: {generated_ids}") + print(f"generated_ids: {generated_ids.shape}") + + # Update attention mask + #attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), device=attention_mask.device)], dim=1) + print(f"attention_mask: {attention_mask.shape}") # Check for EOS token print(f"next_token.item(): {next_token.item()}") if config: - print(config["eos_token_id"]) if next_token.item() in config["eos_token_id"]: break else: if next_token.item() == tokenizer.eos_token_id: break - # Decode generated text - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - print(f"\n\n\n\nGenerated Response: {generated_text}") + # Decode generated text + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + print(f"\n\n\n\nGenerated Response: {generated_text}") + + print("\n\n------------------------------------------------------") if __name__ == "__main__": print("\nTesting generation:") @@ -142,6 +146,7 @@ def test_generation(model, tokenizer, text, max_length=10, config=None): # remap to work with our model remapped_state_dict = {} + tied_embed_weight = None for key, value in state_dict.items(): # Remove the 'model.' prefix if it exists print(f"remapping: {key}") @@ -150,8 +155,23 @@ def test_generation(model, tokenizer, text, max_length=10, config=None): else: new_key = key + # change o_proj to output_proj + re_o_proj = re.findall(r'layers.(\d+).(\w+).(o_proj).(\w+)', new_key) + if len(re_o_proj) != 0: + new_key = f"layers.{re_o_proj[0][0]}.{re_o_proj[0][1]}.output_proj.weight" + remapped_state_dict[new_key] = value + # saving embed for tied weights + if new_key == 'embed_tokens.weight': + tied_embed_weight = value + + if new_key == 'lm_head.weight': + model.has_lm_head_weight = True + + if not model.has_lm_head_weight: + remapped_state_dict['lm_head.weight'] = tied_embed_weight + model.load_state_dict(remapped_state_dict, strict=False) check_weights(model, remapped_state_dict) From b2b63c3c863fca640c75237775699a3856da1d3b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 3 Nov 2024 04:55:36 -0900 Subject: [PATCH 473/491] FINALLY A WORKING PYTORCH ONLY MODEL, working on logit gen, shard testing and then inference engine testing but we are almost there. HELL YEAAAAAAAAAAA --- exo/inference/torch/models/llama3.py | 238 ++++++++-------- exo/inference/torch/models/llm_utils.py | 200 +++++++++----- .../torch/tests/test_llama3_model.py | 257 +++++++++--------- 3 files changed, 381 insertions(+), 314 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 5f42e25b..2feb5e18 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,11 +3,13 @@ Written with pytorch using torchtune and other methods """ -from typing import Optional, Tuple +from typing import Tuple, List import torch import torch.nn as nn import torchtune.modules as ttm +import torchtune.generation as ttg +from torchtune.models.llama3_1 import Llama3ScaledRoPE from exo.inference.shard import Shard from exo.inference.torch.models.llm_utils import ( @@ -37,7 +39,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - position_ids: Optional[torch.Tensor] = None + max_seq_len: int = 2048 ) -> torch.Tensor: """ Forward pass with integrated attention, resnet and key-value caching. @@ -56,148 +58,146 @@ def forward( self.self_attn.setup_cache( batch_size=hidden_states.size(0), dtype=hidden_states.dtype, - max_seq_len=2048, #self.config['max_position_embeddings'] + max_seq_len=max_seq_len, #self.config['max_position_embeddings'] ) - # Reshape `attention_mask` to match the expected shape: [batch_size, seq_len, seq_len] - if attention_mask is not None: - attention_mask = attention_mask[:, None, :].expand(-1, hidden_states.size(1), -1).float() - print(f"reshaped attention_mask: {attention_mask.shape}") - - # setting up resnet - residual = hidden_states - # Apply RMSNorm to input hidden_states = self.input_layernorm(hidden_states) print(f"self.input_layernorm(hidden_states) {hidden_states.shape}") + # get causal mask from attention mask + causal_mask = ttg.get_causal_mask_from_padding_mask( + attention_mask.bool(), + max_seq_len + ) + + print(f"causal_mask: {causal_mask.shape}") + + # get position_ids from attention mask + position_ids = ttg.get_position_ids_from_padding_mask( + attention_mask.bool() + ) + + print(f"position_ids: {position_ids.shape}") + hidden_states = self.self_attn( x=hidden_states, - #mask=attention_mask, - input_pos=position_ids + y=hidden_states, + mask=causal_mask, + #input_pos=position_ids ) # Residual connection - hidden_states = residual + hidden_states - residual = hidden_states print(f"hidden_states: {hidden_states.shape}") - print(f"residual: {residual.shape}") # Post attention normalization hidden_states = self.post_attention_layernorm(hidden_states) # Feed-forward network with MLP and residual connection hidden_states = self.mlp(hidden_states) - hidden_states = hidden_states + residual return hidden_states -class LlamaModel(nn.Module): +def LlamaModel( + config: dict, + shard: Shard, + is_causal: bool=True, + max_seq_len: int=4096 +): """ - LlamaModel is a pure PyTorch implementation of the LLaMA architecture + LlamaModel using torchtune """ - - def __init__(self, config: dict, shard: Shard, is_causal=True): - """ - Initialize the LlamaModel. - - Args: - config (dict): Configuration dictionary containing model parameters. - - hidden_size (int): Size of the hidden layers. - - num_hidden_layers (int): Number of transformer layers. - - num_attention_heads (int): Number of attention heads. - - intermediate_size (int): Size of the intermediate (feed-forward) layers. - - vocab_size (int): Vocabulary size for the embedding layer. - - max_position_embeddings (int): Maximum number of positional embeddings. - - rms_norm_eps (float): Epsilon for RMS normalization. - - head_dim (int): Dimension of each attention head. - - attention_dropout (float): Dropout rate for attention layers. - """ - super(LlamaModel, self).__init__() - - self.shard = shard - - # Load configurations from config - self.config = config - self.hidden_size = config['hidden_size'] - self.num_layers = config['num_hidden_layers'] - self.num_heads = config['num_attention_heads'] - self.num_kv_heads = config['num_key_value_heads'] - self.intermediate_size = config['intermediate_size'] - self.vocab_size = config['vocab_size'] - self.max_position_embeddings = config['max_position_embeddings'] - self.rms_norm_eps = config['rms_norm_eps'] - self.head_dim = config['head_dim'] - self.attention_dropout = config.get('attention_dropout', 0.0) - self.attention_bias = config.get('attention_bias', False) - self.padding_idx = config.get("pad_token_id") - self.has_lm_head_weight = False - - # Model layers and methods, order matters - self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - LlamaBlock( - config=self.config, - rms_norm_eps=self.rms_norm_eps, - self_attn=ttm.MultiHeadAttention( - embed_dim=self.hidden_size, - num_heads=self.num_heads, - num_kv_heads=self.num_heads, - head_dim= self.hidden_size // self.num_heads, - q_proj=nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.attention_bias), - k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=self.attention_bias), - v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=self.attention_bias), - output_proj=nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.attention_bias), - max_seq_len=2048, #self.max_position_embeddings, - is_causal=is_causal, - attn_dropout=self.attention_dropout - ), - mlp=MultiLayerPreceptron( - input_dim=self.hidden_size, - hidden_dim=self.intermediate_size, - activation=self.config.get("hidden_act", "silu"), - use_bias=self.config.get("mlp_bias", False) - ) - ) for _ in range(self.num_layers) - ]) - self.norm = RMSNorm(hidden_size=self.hidden_size, eps=self.rms_norm_eps) - self.rotary_emb = ttm.RotaryPositionalEmbeddings( - dim=self.hidden_size // self.num_heads, - max_seq_len=2048, #self.max_position_embeddings, - base=self.config.get('rope_theta', 10000) + print(shard) + + # Load configurations from config + rope_scaling = config.get("rope_scaling") + hidden_head_dim = config["hidden_size"] // config["num_attention_heads"] + + # Model layers and methods, order matters + embed_tokens = nn.Embedding( + config["vocab_size"], + config["hidden_size"] + ) + + layers = [] + for _ in range(shard.n_layers): + pos_embeddings = Llama3ScaledRoPE( + dim=hidden_head_dim, + max_seq_len=max_seq_len, + base=config.get('rope_theta', 10000), + scale_factor=rope_scaling['factor'] if rope_scaling else 32 ) - self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, - ) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - _, seq_len = input_ids.shape - input_embeds = self.embed_tokens(input_ids) - - if position_ids is None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[:, -seq_len:] - - print(f"LM input_embeds: {input_embeds.shape}") - print(f"LM attention_mask: {attention_mask.shape}") + self_attn = ttm.MultiHeadAttention( + embed_dim=config["hidden_size"], + num_heads=config["num_attention_heads"], + num_kv_heads=config["num_key_value_heads"], + head_dim=hidden_head_dim, + q_proj=nn.Linear( + config["hidden_size"], + config["num_attention_heads"] * config["head_dim"], + bias=config.get('attention_bias', False) + ), + k_proj = nn.Linear( + config["hidden_size"], + config["num_key_value_heads"] * config["head_dim"], + bias=config.get('attention_bias', False) + ), + v_proj = nn.Linear( + config["hidden_size"], + config["num_key_value_heads"] * config["head_dim"], + bias=config.get('attention_bias', False) + ), + output_proj=nn.Linear( + config["hidden_size"], + config["hidden_size"], + bias=config.get('attention_bias', False) + ), + max_seq_len=max_seq_len, + is_causal=is_causal, + attn_dropout=config.get('attention_dropout', 0.0), + pos_embeddings=pos_embeddings + ) - hidden_states = input_embeds + mlp = MultiLayerPreceptron( + config["hidden_size"], + config['intermediate_size'], + 'silu' + ) - for layer_idx in range(self.shard.start_layer, self.shard.end_layer): - #print(f"forward layer #{layer_idx}") - #print(f"{self.layers[layer_idx]}") - hidden_states = self.layers[layer_idx]( - hidden_states=input_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - ) + layer = ttm.TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]), + mlp_norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) + ) - if self.shard.is_last_layer(): - pred_score = self.lm_head(self.norm(hidden_states)[:, -1:, :]) + layers.append(layer) + + return ttm.TransformerDecoder( + tok_embeddings=embed_tokens, + layers=nn.ModuleList(layers), + max_seq_len=max_seq_len, + num_heads=config["num_attention_heads"], + head_dim=config["head_dim"], + norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]), + output=nn.Linear(config["hidden_size"], config["vocab_size"]), + num_layers=shard.n_layers, + #output_hidden_states=list(range(shard.start_layer, shard.end_layer)) + ) + +class ShardedLlamaModel(nn.Module): + def __init__(self, config: dict, shard: Shard, is_causal=True): + super(ShardedLlamaModel, self).__init__() - return pred_score, None + self.shard = shard + self.config = config + self.model = LlamaModel(config, shard, is_causal) - return None, hidden_states + def generate( + self, + prompt: torch.Tensor + ): + """ + move login being done in test_llama3_model for generation to here + along with test sharding + """ + pass diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index b4c0658a..0946890e 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -1,9 +1,10 @@ """ Utility methods used by LLMs """ +import re import json from pathlib import Path -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch import torch.nn as nn @@ -11,7 +12,18 @@ import torchtune.modules as ttm import math +from safetensors.torch import load_file as load_safetensors + +from transformers import ( + LogitsProcessorList, + TopKLogitsWarper, + TopPLogitsWarper, + TemperatureLogitsWarper +) +from transformers.cache_utils import Cache, DynamicCache + from exo.helpers import DEBUG +from exo.inference.shard import Shard def load_model_config(model_config_path: Path) -> dict: """ @@ -28,65 +40,127 @@ def load_model_config(model_config_path: Path) -> dict: model_config = json.load(f) return model_config -def select_next_token( - logits, - top_k=0, - top_p=0.0, - temperature=1.0, - use_max=False, -): +def check_weights(model, state_dict): """ - Selects the next token from logits using top-k, top-p, and temperature scaling. - - Args: - logits (torch.Tensor): Logits or prediction scores tensor of shape (batch_size, vocab_size). - top_k (int): Number of top logits to consider for sampling. - top_p (float): Cumulative probability threshold for nucleus sampling. - temperature (float): Scaling factor for temperature. - use_max (bool): Whether to use argmax for next token selection. - debug (bool): If True, prints debugging information. - - Returns: - next_token (torch.Tensor): The next token selected (batch_size,). + Verifies that the weights from the state dictionary are properly loaded into the model. """ - # Get logits for the last token in the sequence - logits = logits[:, -1, :].clone().float() - - # Apply temperature scaling - if temperature != 1.0: - logits = logits / temperature - - # Apply top-k filtering - if top_k > 0: - top_k = min(top_k, logits.size(-1)) - min_topk = torch.topk(logits, top_k)[0][..., -1, None] - logits = logits.masked_fill(logits < min_topk, float("-inf")) + model_state_dict = model.state_dict() + for name, param in model_state_dict.items(): + if name in state_dict: + print(f"\nchecking {name}\n") + loaded_param = state_dict[name] + if param.shape != loaded_param.shape: + print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") + else: + print(f"{name}: loaded correctly") + + for name in state_dict: + if name not in model_state_dict: + print(f"Unexpected weight {name} found in state_dict") + +def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): + """ + Loads weights from huggingface and changes it to match torchtune naming structure + """ + # Load weights from safetensors files in the cache directory + safetensors_files = list(cache_dir.glob("*.safetensors")) + if not safetensors_files: + raise FileNotFoundError("No safetensors files found in the cache directory.") + + # Load weights from each found safetensors file + stitch_lmhead = True + for safetensor_file in safetensors_files: + state_dict = load_safetensors(safetensor_file) + + # remap to work with our model + remapped_state_dict = {} + tied_embed_weight = None + for key, value in state_dict.items(): + # load layer by shard + lnrgx = re.findall(r'model\.layers\.(\d+).*', key) + layer_num = int(lnrgx[0]) if len(lnrgx) > 0 else None + shard_layer_range = list(range(shard.start_layer, shard.end_layer)) + if layer_num in shard_layer_range: + # change input layer norm to sa_norm for torchtune + re_iln = re.findall( + rf'model.layers\.{layer_num}\.(input_layernorm)\.weight', key) + if len(re_iln) != 0: + key = f"model.layers.{layer_num}.sa_norm.weight" + + # change post attention layernorm to mlp_norm for torchtune + re_pal = re.findall( + rf'model.layers\.{layer_num}\.(post_attention_layernorm)\.weight', key) + if len(re_pal) != 0: + key = f"model.layers.{layer_num}.mlp_norm.weight" + + # change o_proj to output_proj + re_o_proj = re.findall(rf'model\.layers\.{layer_num}.(\w+)\.o_proj\.weight', key) + if len(re_o_proj) != 0: + key = f"model.layers.{layer_num}.{re_o_proj[0]}.output_proj.weight" + + # change self_attn to attn + re_attn = re.findall(rf'model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)', key) + if len(re_attn) != 0 and re_attn[0][0] == "self_attn": + key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + + # saving embed for tied weights + elif key == 'model.embed_tokens.weight': + tied_embed_weight = value + # change name for torchtune + key = 'model.tok_embeddings.weight' + + elif key == 'lm_head.weight': + stitch_lmhead = False + # change key for torchtune + key = 'model.output.weight' + + elif key == 'model.norm.weight': + key = 'model.norm.weight' + + remapped_state_dict[key] = value + + if stitch_lmhead: + remapped_state_dict['model.output.weight'] = tied_embed_weight + + model.load_state_dict(remapped_state_dict, strict=False) + + #if DEBUG >= 7: + print("\n--- checking weights ----\n") + check_weights(model, remapped_state_dict) + +def hf_logit_sample( + logits, + input_ids, + use_max: bool=False, + top_k: int=0, + top_p: float=0.9, + temp: float=1.0, +) -> torch.Tensor: + """ + Logit sampling using transformers + """ + logits_processor = LogitsProcessorList([ + TopKLogitsWarper(top_k), + TemperatureLogitsWarper(temp), + TopPLogitsWarper(top_p) + ]) - # Apply top-p (nucleus) filtering - if top_p > 0.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=False) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - sorted_indices_to_remove = cumulative_probs <= (1 - top_p) - sorted_indices_to_remove[..., -1:] = 0 + # get a single cloned logit + logits = logits[:, -1, :].clone().float() - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - logits = logits.masked_fill(indices_to_remove, float('-inf')) + next_token_scores = logits_processor(input_ids, logits) - # Select next token if not use_max: - probs = F.softmax(logits, dim=-1) + probs = nn.functional.softmax(next_token_scores, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: - next_token = torch.argmax(logits, dim=-1, keepdim=True) + next_token = torch.argmax(next_token_scores, dim=-1) - next_token = next_token[:, None].squeeze(-1) - - # Debugging output if DEBUG >= 4: - print(f"Logits: {logits}") - print(f"Next token: {next_token}") + print(f"input_ids: {input_ids}") + print(f"next_token: {next_token}") - return next_token + return next_token[:, None].squeeze(-1) def create_4d_causal_attention_mask( attention_mask: torch.Tensor, @@ -98,7 +172,7 @@ def create_4d_causal_attention_mask( batch_size: int, ) -> torch.Tensor: """ - Creates a 4D causal attention mask from a 2D mask, with adjustments for static caching. + Creates a 4D causal attention mask from a 2D mask Args: attention_mask (torch.Tensor): @@ -142,17 +216,16 @@ def create_4d_causal_attention_mask( # Expand to 4D and batch size causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - # Create a padding mask based on the input attention_mask - mask_len = attention_mask.shape[-1] - causal_mask = causal_mask.clone() # Ensure contiguous memory for in-place operations - padding_mask = causal_mask[:, :, :, :mask_len] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 + # Create a padding mask based on the input attention_mask + mask_len = attention_mask.shape[-1] + causal_mask = causal_mask.clone() # Ensure contiguous memory for in-place operations + padding_mask = causal_mask[:, :, :, :mask_len] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 - # Apply padding to the causal mask - causal_mask[:, :, :, :mask_len] = causal_mask[:, :, :, :mask_len].masked_fill( - padding_mask, min_value - ) + # Apply padding to the causal mask + causal_mask[:, :, :, :mask_len] = causal_mask[:, :, :, :mask_len].masked_fill( + padding_mask, min_value + ) return causal_mask @@ -278,18 +351,15 @@ def __init__(self, hidden_size, eps=1e-6): """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps + self.eps = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) return self.weight * hidden_states.to(input_dtype) - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - # ------------------ # Attention Methods # ------------------ diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index b5f07d00..32e8fa1d 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -1,112 +1,148 @@ """ Test of pytorch based llama3 model """ -import re from pathlib import Path import torch -import torchtune.generation as ttg from transformers import AutoTokenizer from huggingface_hub import snapshot_download -from safetensors.torch import load_file as load_safetensors -from exo.inference.torch.models.llm_utils import load_model_config, select_next_token -from exo.inference.torch.models.llama3 import LlamaModel + +import torchtune.generation as ttg +from torchtune.models import llama3 +from torchtune.data import Message + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + hf_logit_sample, + load_model_weights_torchtune, + create_4d_causal_attention_mask +) +from exo.inference.torch.models.llama3 import ShardedLlamaModel from exo.inference.shard import Shard -MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" -TEMP=0.7 + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TEMP=0.6 TOP_K=35 TOP_P=0.9 +MAX_SEQ_LEN=2048 -def check_weights(model, state_dict): - """ - Verifies that the weights from the state dictionary are properly loaded into the model. - """ - model_state_dict = model.state_dict() - print(f"model_state_dict: {model_state_dict.keys()}") - for name, param in model_state_dict.items(): - if name in state_dict: - loaded_param = state_dict[name] - if param.shape != loaded_param.shape: - print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") - else: - print(f"{name}: loaded correctly") - #else: - # print(f"{name} not found in the state_dict") - - for name in state_dict: - if name not in model_state_dict: - print(f"Unexpected weight {name} found in state_dict") - -def test_generation(model, tokenizer, text, max_length=10, config=None): +def test_generation(text, max_length=10, config=None): """ Test the generation capabilities of the LlamaModel with sample text. """ # Tokenize input text - prompt = tokenizer.apply_chat_template([ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": text - } - ], tokenize=False, add_generation_prompt=True) - - print(f"prompt: {prompt}") + messages = [] + messages.extend( + [ + Message(role="user", content=text), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ] + ) - inputs = tokenizer(prompt, return_tensors="pt") - input_ids = inputs.get("input_ids") - attention_mask = inputs.get("attention_mask") + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + print(f"tokenizer_out: {tokenizer_out}") + tokens = tokenizer_out["tokens"] + prompt = torch.tensor(tokens, dtype=torch.int) + + if prompt.ndim == 1: + prompt = prompt.view(1, -1) + + bsz, prompt_length = prompt.size() + total_response_length = prompt_length + MAX_SEQ_LEN + generated_tokens = prompt.clone() + resp_max_seq_len = ( + total_response_length + if not shard_model.model.caches_are_enabled() + else shard_model.model.decoder_max_cache_seq_len + ) - print(f"input_ids: {input_ids}") - print(f"attention_mask: {attention_mask}") + # masking for proper attention + padding_masks = prompt != llama_tokenizer.pad_id + if not padding_masks.all(): + padding_masks = torch.nn.functional.pad( + padding_masks, + (0, MAX_SEQ_LEN), + value=True + ) + + masks = ttg.get_causal_mask_from_padding_mask( + padding_masks, + target_seq_len=resp_max_seq_len + ) + + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + masks = torch.tril( + torch.ones( + total_response_length, + resp_max_seq_len if resp_max_seq_len is not None else MAX_SEQ_LEN, + dtype=torch.bool, + device=prompt.device, + ) + ).unsqueeze(0) + + input_pos = torch.arange( + 0, total_response_length, device=prompt.device + ).unsqueeze(0) + + if shard_model.model.caches_are_enabled(): + curr_masks = masks[:, :prompt_length] + else: + curr_masks = masks[:, :prompt_length, :prompt_length] + + print(f"padding_masks: {padding_masks.shape}") + print(padding_masks.all()) + + next_token, gen_logits = ttg.generate_next_token( + shard_model.model, + input_pos=input_pos[:, :prompt_length].squeeze(), + x=prompt, + mask=curr_masks, + q=torch.empty( + ( + prompt.size(0), + shard_model.model.tok_embeddings.num_embeddings + ), device=prompt.device + ).exponential_(1, generator=None) + ) - # Start with initial input_ids - generated_ids = input_ids.clone() + print(f"next_token: {next_token}") - # Generate tokens step-by-step - print(f"{model}") + generated_tokens = torch.cat([generated_tokens, next_token], dim=-1) - for _ in range(max_length): - with torch.no_grad(): - pred_score, hstates = model( - generated_ids, - attention_mask=attention_mask - ) + print(f"generated_tokens: {generated_tokens}") - print("\n\n------------------------------------------------------") - print(f"pred_score: {pred_score.shape}") - print(f"hstates: {hstates.shape if hstates is not None else None}") - # Select next token using pred_score - next_token = select_next_token(pred_score, top_k=TOP_K, top_p=TOP_P, temperature=TEMP, use_max=False) - #next_token = ttg.sample(pred_score, temperature=TEMP, top_k=TOP_K)[:, -1, :] - print(f"next_token: {next_token}") + curr_pos = prompt_length - # Update generated_ids - generated_ids = torch.cat([generated_ids, next_token], dim=1) - print(f"generated_ids: {generated_ids.shape}") + # stop tokens logic + stop_tokens = None + stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) + stop_tokens = ( + torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype) + if stop_tokens + else None + ) + stop_token_mask = torch.ones( + (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device + ) - # Update attention mask - #attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), device=attention_mask.device)], dim=1) - print(f"attention_mask: {attention_mask.shape}") + # finish writing stop token logic using torchtune generation + # ref https://github.com/pytorch/torchtune/blob/main/torchtune/generation/_generation.py#L337 - # Check for EOS token - print(f"next_token.item(): {next_token.item()}") + for _ in range(max_length): - if config: - if next_token.item() in config["eos_token_id"]: - break + if shard_model.model.caches_are_enabled(): + curr_input_pos = input_pos[:, curr_pos] + curr_masks = masks[:, curr_pos, None, :] else: - if next_token.item() == tokenizer.eos_token_id: - break + tokens = generated_tokens.clone() + curr_input_pos = input_pos[:, : curr_pos + 1] + curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] - # Decode generated text - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - print(f"\n\n\n\nGenerated Response: {generated_text}") - - print("\n\n------------------------------------------------------") + generated_tokens = generated_tokens.tolist() + print(f"resp: {llama_tokenizer.decode(generated_tokens[0])}") if __name__ == "__main__": print("\nTesting generation:") @@ -123,64 +159,25 @@ def test_generation(model, tokenizer, text, max_length=10, config=None): shard = Shard( model_id=MODEL_NAME, start_layer=0, - end_layer=int(config["num_hidden_layers"]) - 1, + end_layer=int(config["num_hidden_layers"]), n_layers=int(config["num_hidden_layers"]) ) # Initialize tokenizer - tokenizer = AutoTokenizer.from_pretrained(shard.model_id) + llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" + llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) + #tokenizer = AutoTokenizer.from_pretrained( + # MODEL_NAME, + # add_eos_token=True + #) # Initialize LlamaModel with config and tokenizer - model = LlamaModel(config, shard) - print(f"\nmodel: {model}") - - # Load weights from safetensors files in the cache directory - safetensors_files = list(cache_dir.glob("*.safetensors")) - if not safetensors_files: - raise FileNotFoundError("No safetensors files found in the cache directory.") - - # Load weights from each found safetensors file - for safetensor_file in safetensors_files: - print(f"Loading weights from: {safetensor_file}") - state_dict = load_safetensors(safetensor_file) - - # remap to work with our model - remapped_state_dict = {} - tied_embed_weight = None - for key, value in state_dict.items(): - # Remove the 'model.' prefix if it exists - print(f"remapping: {key}") - if key.startswith('model.'): - new_key = key[len('model.'):] # Remove 'model.' - else: - new_key = key - - # change o_proj to output_proj - re_o_proj = re.findall(r'layers.(\d+).(\w+).(o_proj).(\w+)', new_key) - if len(re_o_proj) != 0: - new_key = f"layers.{re_o_proj[0][0]}.{re_o_proj[0][1]}.output_proj.weight" - - remapped_state_dict[new_key] = value - - # saving embed for tied weights - if new_key == 'embed_tokens.weight': - tied_embed_weight = value - - if new_key == 'lm_head.weight': - model.has_lm_head_weight = True - - if not model.has_lm_head_weight: - remapped_state_dict['lm_head.weight'] = tied_embed_weight - - model.load_state_dict(remapped_state_dict, strict=False) - - check_weights(model, remapped_state_dict) - - #exit() - model.eval() # Set the model to evaluation mode + shard_model = ShardedLlamaModel(config, shard) + print(f"\nshard_model: {shard_model}") + load_model_weights_torchtune(cache_dir, shard, shard_model) # Sample text for testing test_text = "Hello" - test_generation(model, tokenizer, test_text, 5, config) + test_generation(test_text, 5, config) From f53ebd17646e95fcf5f8c2d0e807c26300b958fe Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 3 Nov 2024 10:12:59 -0900 Subject: [PATCH 474/491] cleaning up custom dot product attention but might be removed, building out next parts for distributed inference --- exo/inference/torch/models/llama3.py | 28 ++++++++++------ exo/inference/torch/models/llm_utils.py | 44 +++++++++++++------------ 2 files changed, 41 insertions(+), 31 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 2feb5e18..7f41d129 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,7 +3,7 @@ Written with pytorch using torchtune and other methods """ -from typing import Tuple, List +from typing import Optional import torch import torch.nn as nn @@ -105,8 +105,6 @@ def LlamaModel( """ LlamaModel using torchtune """ - print(shard) - # Load configurations from config rope_scaling = config.get("rope_scaling") hidden_head_dim = config["hidden_size"] // config["num_attention_heads"] @@ -185,19 +183,29 @@ def LlamaModel( ) class ShardedLlamaModel(nn.Module): - def __init__(self, config: dict, shard: Shard, is_causal=True): + def __init__(self, + config: dict, + shard: Shard, + device: torch.device=torch.device("cpu"), + hidden_states: Optional[torch.Tensor] = None, + is_causal=True + ): super(ShardedLlamaModel, self).__init__() self.shard = shard self.config = config self.model = LlamaModel(config, shard, is_causal) + self.device = device - def generate( - self, - prompt: torch.Tensor - ): + def generate(self, prompt: torch.Tensor): """ - move login being done in test_llama3_model for generation to here - along with test sharding + move logit generation being done in test_llama3_model for generation to here + along with sharding """ + self.model.output_hidden_states = list(range(shard.start_layer, shard.end_layer)) + + # pass hidden state to model until last layer + # can be done with model's encoder_input and encoder_mask + # on last layer can generate + pass diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 0946890e..df345968 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -24,6 +24,7 @@ from exo.helpers import DEBUG from exo.inference.shard import Shard +from exo.inference.torch.tests.test_llama3_model import MAX_SEQ_LEN def load_model_config(model_config_path: Path) -> dict: """ @@ -68,13 +69,13 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): raise FileNotFoundError("No safetensors files found in the cache directory.") # Load weights from each found safetensors file - stitch_lmhead = True + paried_lmhead = True for safetensor_file in safetensors_files: state_dict = load_safetensors(safetensor_file) # remap to work with our model remapped_state_dict = {} - tied_embed_weight = None + paried_embed_weight = None for key, value in state_dict.items(): # load layer by shard lnrgx = re.findall(r'model\.layers\.(\d+).*', key) @@ -103,14 +104,14 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): if len(re_attn) != 0 and re_attn[0][0] == "self_attn": key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" - # saving embed for tied weights + # saving embed for paired weights elif key == 'model.embed_tokens.weight': - tied_embed_weight = value + paried_embed_weight = value # change name for torchtune key = 'model.tok_embeddings.weight' elif key == 'lm_head.weight': - stitch_lmhead = False + paried_lmhead = False # change key for torchtune key = 'model.output.weight' @@ -119,8 +120,8 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): remapped_state_dict[key] = value - if stitch_lmhead: - remapped_state_dict['model.output.weight'] = tied_embed_weight + if paried_lmhead: + remapped_state_dict['model.output.weight'] = paried_embed_weight model.load_state_dict(remapped_state_dict, strict=False) @@ -529,7 +530,8 @@ def __init__( rotary_emb, attention_dropout=0.0, is_causal=True, - attention_bias=False + attention_bias=False, + kv_max_seq_len=2048 ): super().__init__() self.hidden_size = hidden_size @@ -538,6 +540,7 @@ def __init__( self.head_dim = head_dim self.attention_dropout = attention_dropout self.is_causal = is_causal + self.kv_max_seq_len = kv_max_seq_len # nn layers self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) @@ -593,20 +596,19 @@ def forward( print(f"key_states: {key_states.shape}") print(f"value_states: {value_states.shape}") - # Forcing caching always enabled + # Caching if kv_cache is not None: - #print(f"kv_cache.size {kv_cache.size}") - #print(f"key_states.size(0) {key_states.size(2)}") - - #if kv_cache.size != key_states.size(2): - # print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") - # kv_cache = ttm.KVCache( - # batch_size=key_states.size(0), - # max_seq_len=key_states.size(2), - # num_heads=self.num_kv_heads, - # head_dim=self.head_dim, - # dtype=hidden_states.dtype - # ) + if kv_cache.size >= self.max_seq_len: + # double the cache each time space is ran out + self.kv_max_seq_len = self.kv_max_seq_len + self.kv_max_seq_len + + kv_cache = ttm.KVCache( + batch_size=key_states.size(0), + max_seq_len=self.kv_max_seq_len, + num_heads=self.num_kv_heads, + head_dim=self.head_dim, + dtype=hidden_states.dtype + ) key_states, value_states = kv_cache.update(key_states, value_states) From e8db8eee975b5773f007e908ccc9a3541a99ae51 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 10 Nov 2024 13:29:52 -0900 Subject: [PATCH 475/491] first layer run fixes, variable layer length weight loading fixes, working on split modeling --- exo/inference/torch/models/llama3.py | 85 +++++++- exo/inference/torch/models/llm_utils.py | 34 +-- .../torch/tests/test_llama3_model.py | 195 +++++++++--------- 3 files changed, 195 insertions(+), 119 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 7f41d129..a8769edf 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,7 +3,7 @@ Written with pytorch using torchtune and other methods """ -from typing import Optional +from typing import Optional, Any, Tuple, List import torch import torch.nn as nn @@ -186,26 +186,91 @@ class ShardedLlamaModel(nn.Module): def __init__(self, config: dict, shard: Shard, + tokenizer: Any, device: torch.device=torch.device("cpu"), hidden_states: Optional[torch.Tensor] = None, is_causal=True ): super(ShardedLlamaModel, self).__init__() + self.tokenizer = tokenizer self.shard = shard self.config = config self.model = LlamaModel(config, shard, is_causal) self.device = device - def generate(self, prompt: torch.Tensor): + def generate( + self, + input_tensor: torch.Tensor, + max_seq_len: int=4096 + ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: """ - move logit generation being done in test_llama3_model for generation to here - along with sharding + Generate logits and/or hidden_states from llama model + + Args + input (torch.Tensor) - tokens if initial first layer input and hidden states after + max_seq_len (int) - Max sequence length of generation, default 4096 """ - self.model.output_hidden_states = list(range(shard.start_layer, shard.end_layer)) - - # pass hidden state to model until last layer - # can be done with model's encoder_input and encoder_mask - # on last layer can generate + self.model.output_hidden_states = list(range(self.shard.start_layer, self.shard.end_layer)) + + if self.shard.is_first_layer(): + tokens = input_tensor + + if tokens.ndim == 1: + tokens = tokens.view(1, -1) + + _, tokens_length = tokens.size() + total_response_length = tokens_length + max_seq_len + resp_max_seq_len = ( + total_response_length + if not self.model.caches_are_enabled() + else self.model.decoder_max_cache_seq_len + ) + + # masking for proper attention + padding_masks = tokens != self.tokenizer.pad_id + if not padding_masks.all(): + padding_masks = torch.nn.functional.pad( + padding_masks, + (0, max_seq_len), + value=True + ) + + masks = ttg.get_causal_mask_from_padding_mask( + padding_masks, + target_seq_len=resp_max_seq_len + ) + + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + masks = torch.tril( + torch.ones( + total_response_length, + resp_max_seq_len if resp_max_seq_len is not None else max_seq_len, + dtype=torch.bool, + device=tokens.device, + ) + ).unsqueeze(0) + + input_pos = torch.arange( + 0, total_response_length, device=tokens.device + ).unsqueeze(0) + + if self.model.caches_are_enabled(): + curr_masks = masks[:, :tokens_length] + else: + curr_masks = masks[:, :tokens_length, :tokens_length] + + model_output = self.model( + tokens=tokens, + mask=curr_masks, + input_pos=input_pos[:, :tokens_length].squeeze() + ) + + model_logits = model_output[-1] + model_output.pop() # remove logits + model_hs = model_output # hidden states - pass + return model_hs, model_logits + else: + return None, None diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index df345968..8e4ec14d 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -24,7 +24,6 @@ from exo.helpers import DEBUG from exo.inference.shard import Shard -from exo.inference.torch.tests.test_llama3_model import MAX_SEQ_LEN def load_model_config(model_config_path: Path) -> dict: """ @@ -70,6 +69,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): # Load weights from each found safetensors file paried_lmhead = True + shard_layer_range = list(range(shard.start_layer, shard.end_layer)) for safetensor_file in safetensors_files: state_dict = load_safetensors(safetensor_file) @@ -80,45 +80,44 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): # load layer by shard lnrgx = re.findall(r'model\.layers\.(\d+).*', key) layer_num = int(lnrgx[0]) if len(lnrgx) > 0 else None - shard_layer_range = list(range(shard.start_layer, shard.end_layer)) if layer_num in shard_layer_range: # change input layer norm to sa_norm for torchtune re_iln = re.findall( rf'model.layers\.{layer_num}\.(input_layernorm)\.weight', key) if len(re_iln) != 0: - key = f"model.layers.{layer_num}.sa_norm.weight" + remapped_state_dict[f"model.layers.{layer_num}.sa_norm.weight"] = value # change post attention layernorm to mlp_norm for torchtune re_pal = re.findall( rf'model.layers\.{layer_num}\.(post_attention_layernorm)\.weight', key) if len(re_pal) != 0: - key = f"model.layers.{layer_num}.mlp_norm.weight" - - # change o_proj to output_proj - re_o_proj = re.findall(rf'model\.layers\.{layer_num}.(\w+)\.o_proj\.weight', key) - if len(re_o_proj) != 0: - key = f"model.layers.{layer_num}.{re_o_proj[0]}.output_proj.weight" + remapped_state_dict[f"model.layers.{layer_num}.mlp_norm.weight"] = value # change self_attn to attn + # along with changing o_proj to output_proj re_attn = re.findall(rf'model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)', key) if len(re_attn) != 0 and re_attn[0][0] == "self_attn": - key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + if re_attn[0][1] == "o_proj": + remapped_state_dict[f"model.layers.{layer_num}.attn.output_proj.weight"] = value + else: + remapped_state_dict[f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}"] = value # saving embed for paired weights elif key == 'model.embed_tokens.weight': paried_embed_weight = value # change name for torchtune - key = 'model.tok_embeddings.weight' + remapped_state_dict['model.tok_embeddings.weight'] = value elif key == 'lm_head.weight': paried_lmhead = False - # change key for torchtune - key = 'model.output.weight' - - elif key == 'model.norm.weight': - key = 'model.norm.weight' - remapped_state_dict[key] = value + # get everything else except layers, embed_tokens and lm_head + if ( + len(re.findall(r'model\.layers\..*', key)) == 0 + and key != "model.embed_tokens.weight" + and key != "lm_head.weight" + ): + remapped_state_dict[key] = value if paried_lmhead: remapped_state_dict['model.output.weight'] = paried_embed_weight @@ -127,6 +126,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): #if DEBUG >= 7: print("\n--- checking weights ----\n") + print(f"\nremapped_state_dict: {remapped_state_dict.keys()}\n") check_weights(model, remapped_state_dict) def hf_logit_sample( diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 32e8fa1d..71fcbfd8 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -11,14 +11,14 @@ from torchtune.models import llama3 from torchtune.data import Message + +from exo.inference.torch.models.llama3 import ShardedLlamaModel +from exo.inference.shard import Shard + from exo.inference.torch.models.llm_utils import ( load_model_config, - hf_logit_sample, load_model_weights_torchtune, - create_4d_causal_attention_mask ) -from exo.inference.torch.models.llama3 import ShardedLlamaModel -from exo.inference.shard import Shard MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -46,103 +46,114 @@ def test_generation(text, max_length=10, config=None): tokens = tokenizer_out["tokens"] prompt = torch.tensor(tokens, dtype=torch.int) - if prompt.ndim == 1: - prompt = prompt.view(1, -1) + hidden_states, logits = shard_model.generate(prompt) - bsz, prompt_length = prompt.size() - total_response_length = prompt_length + MAX_SEQ_LEN - generated_tokens = prompt.clone() - resp_max_seq_len = ( - total_response_length - if not shard_model.model.caches_are_enabled() - else shard_model.model.decoder_max_cache_seq_len - ) + if hidden_states is not None: + print(f"hidden_states: {hidden_states[0].shape}\n{hidden_states}") - # masking for proper attention - padding_masks = prompt != llama_tokenizer.pad_id - if not padding_masks.all(): - padding_masks = torch.nn.functional.pad( - padding_masks, - (0, MAX_SEQ_LEN), - value=True - ) - - masks = ttg.get_causal_mask_from_padding_mask( - padding_masks, - target_seq_len=resp_max_seq_len - ) - - input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) - else: - masks = torch.tril( - torch.ones( - total_response_length, - resp_max_seq_len if resp_max_seq_len is not None else MAX_SEQ_LEN, - dtype=torch.bool, - device=prompt.device, - ) - ).unsqueeze(0) - - input_pos = torch.arange( - 0, total_response_length, device=prompt.device - ).unsqueeze(0) - - if shard_model.model.caches_are_enabled(): - curr_masks = masks[:, :prompt_length] - else: - curr_masks = masks[:, :prompt_length, :prompt_length] - - print(f"padding_masks: {padding_masks.shape}") - print(padding_masks.all()) - - next_token, gen_logits = ttg.generate_next_token( - shard_model.model, - input_pos=input_pos[:, :prompt_length].squeeze(), - x=prompt, - mask=curr_masks, - q=torch.empty( - ( - prompt.size(0), - shard_model.model.tok_embeddings.num_embeddings - ), device=prompt.device - ).exponential_(1, generator=None) - ) + if logits is not None: + print(f"logits: {logits.shape}\n{logits}") + #if prompt.ndim == 1: + # prompt = prompt.view(1, -1) - print(f"next_token: {next_token}") + #bsz, prompt_length = prompt.size() + #total_response_length = prompt_length + MAX_SEQ_LEN + #generated_tokens = prompt.clone() + #resp_max_seq_len = ( + # total_response_length + # if not shard_model.model.caches_are_enabled() + # else shard_model.model.decoder_max_cache_seq_len + #) - generated_tokens = torch.cat([generated_tokens, next_token], dim=-1) + ## masking for proper attention + #padding_masks = prompt != llama_tokenizer.pad_id + #if not padding_masks.all(): + # padding_masks = torch.nn.functional.pad( + # padding_masks, + # (0, MAX_SEQ_LEN), + # value=True + # ) + + # masks = ttg.get_causal_mask_from_padding_mask( + # padding_masks, + # target_seq_len=resp_max_seq_len + # ) + + # input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + #else: + # masks = torch.tril( + # torch.ones( + # total_response_length, + # resp_max_seq_len if resp_max_seq_len is not None else MAX_SEQ_LEN, + # dtype=torch.bool, + # device=prompt.device, + # ) + # ).unsqueeze(0) + + # input_pos = torch.arange( + # 0, total_response_length, device=prompt.device + # ).unsqueeze(0) + + #if shard_model.model.caches_are_enabled(): + # curr_masks = masks[:, :prompt_length] + #else: + # curr_masks = masks[:, :prompt_length, :prompt_length] + + #rand_sample = torch.empty( + # ( + # prompt.size(0), + # self.model.tok_embeddings.num_embeddings + # ), device=prompt.device + #).exponential_(1, generator=None) + + #print(f"padding_masks: {padding_masks.shape}") + #print(padding_masks.all()) + + ## this can be sepearted out for dist inference + ## see https://github.com/pytorch/torchtune/blob/bc4acc19ffab2366a14468c97294992dbb7c50d1/torchtune/generation/_generation.py#L66 + #next_token, gen_logits = ttg.generate_next_token( + # shard_model.model, + # input_pos=input_pos[:, :prompt_length].squeeze(), + # x=prompt, + # mask=curr_masks, + # q=rand_sample + #) - print(f"generated_tokens: {generated_tokens}") + #print(f"next_token: {next_token}") - curr_pos = prompt_length + #generated_tokens = torch.cat([generated_tokens, next_token], dim=-1) - # stop tokens logic - stop_tokens = None - stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) - stop_tokens = ( - torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype) - if stop_tokens - else None - ) - stop_token_mask = torch.ones( - (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device - ) + #print(f"generated_tokens: {generated_tokens}") + + #curr_pos = prompt_length + + ## stop tokens logic + #stop_tokens = None + #stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) + #stop_tokens = ( + # torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype) + # if stop_tokens + # else None + #) + #stop_token_mask = torch.ones( + # (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device + #) - # finish writing stop token logic using torchtune generation - # ref https://github.com/pytorch/torchtune/blob/main/torchtune/generation/_generation.py#L337 + ## finish writing stop token logic using torchtune generation + ## ref https://github.com/pytorch/torchtune/blob/main/torchtune/generation/_generation.py#L337 - for _ in range(max_length): + #for _ in range(max_length): - if shard_model.model.caches_are_enabled(): - curr_input_pos = input_pos[:, curr_pos] - curr_masks = masks[:, curr_pos, None, :] - else: - tokens = generated_tokens.clone() - curr_input_pos = input_pos[:, : curr_pos + 1] - curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] + # if shard_model.model.caches_are_enabled(): + # curr_input_pos = input_pos[:, curr_pos] + # curr_masks = masks[:, curr_pos, None, :] + # else: + # tokens = generated_tokens.clone() + # curr_input_pos = input_pos[:, : curr_pos + 1] + # curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] - generated_tokens = generated_tokens.tolist() - print(f"resp: {llama_tokenizer.decode(generated_tokens[0])}") + #generated_tokens = generated_tokens.tolist() + #print(f"resp: {llama_tokenizer.decode(generated_tokens[0])}") if __name__ == "__main__": print("\nTesting generation:") @@ -159,7 +170,7 @@ def test_generation(text, max_length=10, config=None): shard = Shard( model_id=MODEL_NAME, start_layer=0, - end_layer=int(config["num_hidden_layers"]), + end_layer=4,#int(config["num_hidden_layers"]), n_layers=int(config["num_hidden_layers"]) ) @@ -172,7 +183,7 @@ def test_generation(text, max_length=10, config=None): #) # Initialize LlamaModel with config and tokenizer - shard_model = ShardedLlamaModel(config, shard) + shard_model = ShardedLlamaModel(config, shard, llama_tokenizer) print(f"\nshard_model: {shard_model}") load_model_weights_torchtune(cache_dir, shard, shard_model) From 22bc6a78d2ee5c859533f2b60650bacb7047c5b0 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 10 Nov 2024 15:04:52 -0900 Subject: [PATCH 476/491] made it so weight for last output layer is only loaded when shard is last layer --- exo/inference/torch/models/llm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 8e4ec14d..9e139238 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -119,7 +119,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): ): remapped_state_dict[key] = value - if paried_lmhead: + if paried_lmhead and shard.is_last_layer(): remapped_state_dict['model.output.weight'] = paried_embed_weight model.load_state_dict(remapped_state_dict, strict=False) From 7f2abc3ad87a6cf8cb00c73067fff2866c8ae563 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Tue, 12 Nov 2024 06:03:56 -0900 Subject: [PATCH 477/491] working on sharding issue where hidden state is not working when being passed --- exo/inference/torch/models/llama3.py | 277 +++++++++++------- .../torch/tests/test_llama3_model.py | 83 +++++- 2 files changed, 241 insertions(+), 119 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index a8769edf..2536f57f 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,13 +3,14 @@ Written with pytorch using torchtune and other methods """ -from typing import Optional, Any, Tuple, List +from typing import Optional, Any, Tuple, List, Union, Callable import torch import torch.nn as nn import torchtune.modules as ttm import torchtune.generation as ttg from torchtune.models.llama3_1 import Llama3ScaledRoPE +from torchtune.modules.attention_utils import _MaskType from exo.inference.shard import Shard from exo.inference.torch.models.llm_utils import ( @@ -17,84 +18,101 @@ RMSNorm ) -class LlamaBlock(nn.Module): - """ - Encoder block class for the LLaMA model - """ +class ShardTransformerDecoder(ttm.TransformerDecoder): def __init__( self, - config, - mlp, - self_attn, - rms_norm_eps=1e-6 + *, + shard: Shard, + tok_embeddings: nn.Embedding, + layers: Union[nn.Module, List[nn.Module], nn.ModuleList], + max_seq_len: int, + num_heads: int, + head_dim: int, + norm: nn.Module, + output: Union[nn.Linear, Callable], + num_layers: Optional[int] = None, + output_hidden_states: Optional[List[int]] = None ): - super(LlamaBlock, self).__init__() - self.config = config - self.self_attn = self_attn - self.mlp = mlp - self.input_layernorm = RMSNorm(self.config['hidden_size'], eps=rms_norm_eps) - self.post_attention_layernorm = RMSNorm(self.config['hidden_size'], eps=rms_norm_eps) + super().__init__( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=norm, + output=output, + num_layers=num_layers, + output_hidden_states=output_hidden_states, + ) + + self.shard = shard def forward( self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - max_seq_len: int = 2048 - ) -> torch.Tensor: - """ - Forward pass with integrated attention, resnet and key-value caching. - - Args: - hidden_states (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). - position_ids (Optional[torch.Tensor]): Position IDs tensor of shape (batch_size, seq_len). + tokens: torch.Tensor, + *, + mask: Optional[_MaskType] = None, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + # for captured hidden states + hidden = [] + + # Determine the type of input and shape + print(f"tokens.ndim: {tokens.ndim}") + if tokens.ndim == 3: + h = tokens # Use directly as hidden states + else: + h = self.tok_embeddings(tokens) # Apply token tok_embeddings - Returns: - Tuple[torch.Tensor, KVCache]: - - Output tensor of shape (batch_size, seq_len, dim). - """ - if isinstance(self.self_attn, ttm.MultiHeadAttention): - if self.self_attn.kv_cache is None: - # setup cache - self.self_attn.setup_cache( - batch_size=hidden_states.size(0), - dtype=hidden_states.dtype, - max_seq_len=max_seq_len, #self.config['max_position_embeddings'] - ) + # capture tok hidden state, if needed + if 0 in self.output_hidden_states: + hidden.append(h) - # Apply RMSNorm to input - hidden_states = self.input_layernorm(hidden_states) - print(f"self.input_layernorm(hidden_states) {hidden_states.shape}") + seq_len = h.shape[1] - # get causal mask from attention mask - causal_mask = ttg.get_causal_mask_from_padding_mask( - attention_mask.bool(), - max_seq_len + self._validate_inputs( + seq_len, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, ) - print(f"causal_mask: {causal_mask.shape}") + # Initialize a list to capture hidden states if requested + hidden = [] + for i in range(self.shard.start_layer, self.shard.end_layer+1): + layer = self.layers[i] + + # Process through each transformer layer + h = layer( + h, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) - # get position_ids from attention mask - position_ids = ttg.get_position_ids_from_padding_mask( - attention_mask.bool() - ) + # capture wanted hidden states + if i in self.output_hidden_states: + hidden.append(h) - print(f"position_ids: {position_ids.shape}") + print(f"\n\n\nhidden layer H[{i}]\n{h}\n\n\n") - hidden_states = self.self_attn( - x=hidden_states, - y=hidden_states, - mask=causal_mask, - #input_pos=position_ids - ) + # Apply normalization + h = self.norm(h) - # Residual connection - print(f"hidden_states: {hidden_states.shape}") - # Post attention normalization - hidden_states = self.post_attention_layernorm(hidden_states) - # Feed-forward network with MLP and residual connection - hidden_states = self.mlp(hidden_states) + # Handle chunked output if needed + if self.num_output_chunks > 0: + output = self.chunked_output(h) + else: + output = self.output(h).float() - return hidden_states + # Return list if hidden states are requested + output = output if not hidden else [*hidden, output] + print(f"\n\noutput {output}\n\n") + return output def LlamaModel( config: dict, @@ -170,7 +188,7 @@ def LlamaModel( layers.append(layer) - return ttm.TransformerDecoder( + return ShardTransformerDecoder( tok_embeddings=embed_tokens, layers=nn.ModuleList(layers), max_seq_len=max_seq_len, @@ -179,7 +197,8 @@ def LlamaModel( norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]), output=nn.Linear(config["hidden_size"], config["vocab_size"]), num_layers=shard.n_layers, - #output_hidden_states=list(range(shard.start_layer, shard.end_layer)) + #output_hidden_states=list(range(shard.start_layer, shard.end_layer)), + shard=shard ) class ShardedLlamaModel(nn.Module): @@ -201,76 +220,114 @@ def __init__(self, def generate( self, - input_tensor: torch.Tensor, + tokens: torch.Tensor, + hidden_state: Optional[torch.Tensor] = None, max_seq_len: int=4096 ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: """ Generate logits and/or hidden_states from llama model Args - input (torch.Tensor) - tokens if initial first layer input and hidden states after + tokens (torch.Tensor) - tokens from prompt tokenization + hidden_state (torch.Tensor, optional) - hidden state from last activated hidden layer, if any max_seq_len (int) - Max sequence length of generation, default 4096 """ - self.model.output_hidden_states = list(range(self.shard.start_layer, self.shard.end_layer)) - - if self.shard.is_first_layer(): - tokens = input_tensor + print(self.shard) + print(self.shard.is_last_layer()) + if not self.shard.is_last_layer(): + self.model.output_hidden_states = [self.shard.end_layer] + + if tokens.ndim == 1: + tokens = tokens.view(1, -1) + + _, tokens_length = tokens.size() + total_response_length = tokens_length + max_seq_len + resp_max_seq_len = ( + total_response_length + if not self.model.caches_are_enabled() + else self.model.decoder_max_cache_seq_len + ) - if tokens.ndim == 1: - tokens = tokens.view(1, -1) + # clone tokens + generated_tokens = tokens.clone() - _, tokens_length = tokens.size() - total_response_length = tokens_length + max_seq_len - resp_max_seq_len = ( - total_response_length - if not self.model.caches_are_enabled() - else self.model.decoder_max_cache_seq_len + # masking for proper attention + padding_masks = generated_tokens != self.tokenizer.pad_id + if not padding_masks.all(): + padding_masks = torch.nn.functional.pad( + padding_masks, + (0, max_seq_len), + value=True ) - # masking for proper attention - padding_masks = tokens != self.tokenizer.pad_id - if not padding_masks.all(): - padding_masks = torch.nn.functional.pad( - padding_masks, - (0, max_seq_len), - value=True - ) + masks = ttg.get_causal_mask_from_padding_mask( + padding_masks, + target_seq_len=resp_max_seq_len + ) - masks = ttg.get_causal_mask_from_padding_mask( - padding_masks, - target_seq_len=resp_max_seq_len + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + masks = torch.tril( + torch.ones( + total_response_length, + resp_max_seq_len if resp_max_seq_len is not None else total_response_length, + dtype=torch.bool, + device=tokens.device, ) - - input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) - else: - masks = torch.tril( - torch.ones( - total_response_length, - resp_max_seq_len if resp_max_seq_len is not None else max_seq_len, - dtype=torch.bool, - device=tokens.device, - ) - ).unsqueeze(0) + ).unsqueeze(0) input_pos = torch.arange( - 0, total_response_length, device=tokens.device + 0, total_response_length, device=generated_tokens.device ).unsqueeze(0) - if self.model.caches_are_enabled(): - curr_masks = masks[:, :tokens_length] - else: - curr_masks = masks[:, :tokens_length, :tokens_length] + if self.model.caches_are_enabled(): + curr_masks = masks[:, :tokens_length] + else: + curr_masks = masks[:, :tokens_length, :tokens_length] + + if hidden_state is not None: + #_, hs_len, _ = hidden_state.size() + #total_hidden_length = hs_len + max_seq_len + #hs_max_seq_len = ( + # total_response_length + # if not self.model.caches_are_enabled() + # else self.model.decoder_max_cache_seq_len + #) + + #hs_mask = torch.tril( + # torch.ones( + # total_hidden_length, + # hs_max_seq_len if hs_max_seq_len is not None else max_seq_len, + # dtype=torch.bool, + # device=tokens.device, + # ) + #).unsqueeze(0) + + #if self.model.caches_are_enabled(): + #hs_curr_masks = hs_mask[:, :hs_len] + #else: + #hs_curr_masks = hs_mask[:, :hs_len, :hs_len] + model_output = self.model( + tokens=hidden_state, + mask=curr_masks, + input_pos=input_pos[:, :tokens_length].squeeze(), + ) + else: model_output = self.model( tokens=tokens, mask=curr_masks, input_pos=input_pos[:, :tokens_length].squeeze() ) + print(f"\nmodel_output: {model_output}") + + if isinstance(model_output, list): model_logits = model_output[-1] model_output.pop() # remove logits - model_hs = model_output # hidden states - - return model_hs, model_logits + model_hs = model_output[-1] # get last hidden state else: - return None, None + model_logits = model_output + model_hs = None + + return model_hs, model_logits diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 71fcbfd8..d742ff9c 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -27,7 +27,7 @@ TOP_P=0.9 MAX_SEQ_LEN=2048 -def test_generation(text, max_length=10, config=None): +def test_generation_1(shard_model, text): """ Test the generation capabilities of the LlamaModel with sample text. """ @@ -49,10 +49,12 @@ def test_generation(text, max_length=10, config=None): hidden_states, logits = shard_model.generate(prompt) if hidden_states is not None: - print(f"hidden_states: {hidden_states[0].shape}\n{hidden_states}") + print(f"hidden_states[{len(hidden_states)}]: {hidden_states}") if logits is not None: print(f"logits: {logits.shape}\n{logits}") + + return hidden_states, logits, prompt #if prompt.ndim == 1: # prompt = prompt.view(1, -1) @@ -155,6 +157,44 @@ def test_generation(text, max_length=10, config=None): #generated_tokens = generated_tokens.tolist() #print(f"resp: {llama_tokenizer.decode(generated_tokens[0])}") +def test_generation_2(shard_model, tokens, hidden_state): + print("Generate with the rest of layers") + hidden_states, logits = shard_model.generate( + tokens=tokens, + hidden_state=hidden_state + ) + + if hidden_states is not None: + print(f"hidden_states {hidden_states.shape}: {hidden_states}") + + if logits is not None: + print(f"logits: {logits.shape}\n{logits}") + + rand_sample = torch.empty( + ( + logits.size(0), + shard_model.model.tok_embeddings.num_embeddings + ), device=logits.device + ).exponential_(1, generator=None) + + logit = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=rand_sample + ) + + print(f"logit: {logit}") + + generated_tokens = tokens.clone() + generated_tokens = torch.cat([generated_tokens, logit.squeeze(-1)], dim=-1).tolist() + + print(f"generated_tokens: {generated_tokens}") + + print(f"resp: {llama_tokenizer.decode(generated_tokens)}\n\n\n") + + return hidden_states, logits + if __name__ == "__main__": print("\nTesting generation:") # Get the path to the model files from the Hugging Face cache @@ -167,10 +207,20 @@ def test_generation(text, max_length=10, config=None): print(f"current config\n{config}") # Setup shard - shard = Shard( + s1_end = int(int(config["num_hidden_layers"])/2) + shard_1 = Shard( model_id=MODEL_NAME, start_layer=0, - end_layer=4,#int(config["num_hidden_layers"]), + end_layer=s1_end, + n_layers=int(config["num_hidden_layers"]) + ) + + s2_start = s1_end + 1 + s2_end = shard_1.n_layers - 1 + shard_2 = Shard( + model_id=MODEL_NAME, + start_layer=s2_start, + end_layer=s2_end, n_layers=int(config["num_hidden_layers"]) ) @@ -183,12 +233,27 @@ def test_generation(text, max_length=10, config=None): #) # Initialize LlamaModel with config and tokenizer - shard_model = ShardedLlamaModel(config, shard, llama_tokenizer) - print(f"\nshard_model: {shard_model}") - load_model_weights_torchtune(cache_dir, shard, shard_model) + shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer) + print(f"\nshard_model_1: {shard_model_1}") + load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) # Sample text for testing - test_text = "Hello" + #prompt = "First letter in the word 'Red'" + prompt = "Hello" + shard_1_hs, shard_1_logits, shard_1_tokens = test_generation_1(shard_model_1, prompt) + + print(f"shard_1_hs:\n{shard_1_hs}") + print(f"shard_1_logits:\n{shard_1_logits}") + print(f"shard_1_tokens:\n{shard_1_tokens}") + + del shard_model_1.model + del shard_model_1 + + shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer) + print(f"\nshard_model_2: {shard_model_2}") + load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) + shard_2_hs, shard_2_logits = test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) - test_generation(test_text, 5, config) + print(f"shard_2_hs:\n{shard_2_hs}") + print(f"shard_2_logits:\n{shard_2_logits}") From bdf3240481349ba52a3f6313eed5f2c3a169e7b5 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Thu, 14 Nov 2024 23:56:51 -0900 Subject: [PATCH 478/491] fixing last hidden value handling --- exo/inference/torch/models/llama3.py | 47 ++++++++++--------- .../torch/tests/test_llama3_model.py | 14 ++++-- 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 2536f57f..c29979d9 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -56,9 +56,6 @@ def forward( encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - # for captured hidden states - hidden = [] - # Determine the type of input and shape print(f"tokens.ndim: {tokens.ndim}") if tokens.ndim == 3: @@ -66,10 +63,6 @@ def forward( else: h = self.tok_embeddings(tokens) # Apply token tok_embeddings - # capture tok hidden state, if needed - if 0 in self.output_hidden_states: - hidden.append(h) - seq_len = h.shape[1] self._validate_inputs( @@ -81,9 +74,13 @@ def forward( ) # Initialize a list to capture hidden states if requested - hidden = [] + # for captured hidden states + hidden = None + for i in range(self.shard.start_layer, self.shard.end_layer+1): - layer = self.layers[i] + layer = self.layers[i] + + print(f"\nhidden layer in H[{i}]\n{h}\n") # Process through each transformer layer h = layer( @@ -94,12 +91,13 @@ def forward( input_pos=input_pos, ) - # capture wanted hidden states - if i in self.output_hidden_states: - hidden.append(h) + # for shard model just capture the last hs computed + if i == self.shard.end_layer: + hidden = h - print(f"\n\n\nhidden layer H[{i}]\n{h}\n\n\n") + print(f"\nhidden layer out H[{i}]->H[{i+1}]\n{h}\n") + print(f"last hidden: {hidden}") # Apply normalization h = self.norm(h) @@ -110,7 +108,7 @@ def forward( output = self.output(h).float() # Return list if hidden states are requested - output = output if not hidden else [*hidden, output] + output = [hidden, output] print(f"\n\noutput {output}\n\n") return output @@ -207,8 +205,8 @@ def __init__(self, shard: Shard, tokenizer: Any, device: torch.device=torch.device("cpu"), - hidden_states: Optional[torch.Tensor] = None, - is_causal=True + is_causal=True, + use_cache=False ): super(ShardedLlamaModel, self).__init__() @@ -217,6 +215,7 @@ def __init__(self, self.config = config self.model = LlamaModel(config, shard, is_causal) self.device = device + self.use_cache = use_cache def generate( self, @@ -234,13 +233,19 @@ def generate( """ print(self.shard) print(self.shard.is_last_layer()) - if not self.shard.is_last_layer(): - self.model.output_hidden_states = [self.shard.end_layer] if tokens.ndim == 1: tokens = tokens.view(1, -1) - _, tokens_length = tokens.size() + bsz, tokens_length = tokens.size() + + # setup cache + if not self.model.caches_are_enabled() and self.use_cache: + self.model.setup_caches(bsz, torch.float, decoder_max_seq_len=self.model.decoder_max_cache_seq_len) + + if not self.shard.is_last_layer(): + self.model.output_hidden_states = [self.shard.end_layer] + total_response_length = tokens_length + max_seq_len resp_max_seq_len = ( total_response_length @@ -323,9 +328,9 @@ def generate( print(f"\nmodel_output: {model_output}") if isinstance(model_output, list): - model_logits = model_output[-1] + model_logits = model_output[1] model_output.pop() # remove logits - model_hs = model_output[-1] # get last hidden state + model_hs = model_output[0] # get last hidden state else: model_logits = model_output model_hs = None diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index d742ff9c..d573cf73 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -2,7 +2,8 @@ Test of pytorch based llama3 model """ from pathlib import Path - +import gc +import time import torch from transformers import AutoTokenizer from huggingface_hub import snapshot_download @@ -246,14 +247,19 @@ def test_generation_2(shard_model, tokens, hidden_state): print(f"shard_1_logits:\n{shard_1_logits}") print(f"shard_1_tokens:\n{shard_1_tokens}") + gc.collect() + torch.cuda.empty_cache() + + if shard_model_1.model.caches_are_enabled(): + shard_model_1.model.reset_caches() + del shard_model_1.model del shard_model_1 + #time.sleep(10) + shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer) print(f"\nshard_model_2: {shard_model_2}") load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) shard_2_hs, shard_2_logits = test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) - print(f"shard_2_hs:\n{shard_2_hs}") - print(f"shard_2_logits:\n{shard_2_logits}") - From 227199f720bbe3b732836155ecfe02346de55e3c Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 15 Nov 2024 00:07:40 -0900 Subject: [PATCH 479/491] update test --- exo/inference/torch/tests/test_llama3_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index d573cf73..8b5799f8 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -192,7 +192,7 @@ def test_generation_2(shard_model, tokens, hidden_state): print(f"generated_tokens: {generated_tokens}") - print(f"resp: {llama_tokenizer.decode(generated_tokens)}\n\n\n") + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens)}\n\n\n") return hidden_states, logits From 5af630268c31942502d7be10d256b31a7c852589 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 15 Nov 2024 00:09:14 -0900 Subject: [PATCH 480/491] update test --- exo/inference/torch/tests/test_llama3_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 8b5799f8..0ef2748e 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -240,7 +240,7 @@ def test_generation_2(shard_model, tokens, hidden_state): # Sample text for testing #prompt = "First letter in the word 'Red'" - prompt = "Hello" + prompt = "GM, say it back" shard_1_hs, shard_1_logits, shard_1_tokens = test_generation_1(shard_model_1, prompt) print(f"shard_1_hs:\n{shard_1_hs}") From d7e5aca57a7cecf17e84019f4ded0a3b7bf63143 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 15 Nov 2024 00:11:32 -0900 Subject: [PATCH 481/491] update test --- exo/inference/torch/tests/test_llama3_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 0ef2748e..22abb278 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -192,7 +192,7 @@ def test_generation_2(shard_model, tokens, hidden_state): print(f"generated_tokens: {generated_tokens}") - print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens)}\n\n\n") + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(logit.squeeze(-1).tolist())}\n\n\n") return hidden_states, logits From 1874d2307f52cac6b3898c46b6aea034ae9b798d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 15 Nov 2024 00:14:18 -0900 Subject: [PATCH 482/491] update test, turn on caching --- exo/inference/torch/tests/test_llama3_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py index 22abb278..1b16d91f 100644 --- a/exo/inference/torch/tests/test_llama3_model.py +++ b/exo/inference/torch/tests/test_llama3_model.py @@ -234,7 +234,7 @@ def test_generation_2(shard_model, tokens, hidden_state): #) # Initialize LlamaModel with config and tokenizer - shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer) + shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer, use_cache=True) print(f"\nshard_model_1: {shard_model_1}") load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) @@ -258,7 +258,7 @@ def test_generation_2(shard_model, tokens, hidden_state): #time.sleep(10) - shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer) + shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer, use_cache=True) print(f"\nshard_model_2: {shard_model_2}") load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) shard_2_hs, shard_2_logits = test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) From 3a0ad62226d8b18967581cd197999ae5c10f3193 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 15 Nov 2024 00:19:05 -0900 Subject: [PATCH 483/491] test safetensor load --- exo/inference/torch/models/llm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 9e139238..8e4ec14d 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -119,7 +119,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): ): remapped_state_dict[key] = value - if paried_lmhead and shard.is_last_layer(): + if paried_lmhead: remapped_state_dict['model.output.weight'] = paried_embed_weight model.load_state_dict(remapped_state_dict, strict=False) From 6098ae5324845d4520bc2f9158720dde8251face Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Fri, 15 Nov 2024 01:00:15 -0900 Subject: [PATCH 484/491] test hidden alignment --- exo/inference/torch/models/llama3.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index c29979d9..3280600e 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -60,6 +60,10 @@ def forward( print(f"tokens.ndim: {tokens.ndim}") if tokens.ndim == 3: h = tokens # Use directly as hidden states + + # check states alignment + align_check = self.layers[0].in_features == h.shape[-1] + print(f"align_check {align_check}") else: h = self.tok_embeddings(tokens) # Apply token tok_embeddings From fa1e70fdc93b73ccb46054d2986540574f03ac4f Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 17 Nov 2024 05:53:04 -0900 Subject: [PATCH 485/491] updates to torchtune model, fixing non-generation errors, created split and full test, separating huggingface and torchtune inference engines --- .../torch/{inference.py => hf_inference.py} | 5 +- exo/inference/torch/models/llama3.py | 266 +++---- exo/inference/torch/models/llm_utils.py | 711 ++++-------------- exo/inference/torch/pt_inference.py | 5 + .../torch/tests/test_inference_engine.py | 6 +- exo/inference/torch/tests/test_llama3_full.py | 121 +++ .../torch/tests/test_llama3_model.py | 265 ------- .../torch/tests/test_llama3_split.py | 131 ++++ 8 files changed, 524 insertions(+), 986 deletions(-) rename exo/inference/torch/{inference.py => hf_inference.py} (98%) create mode 100644 exo/inference/torch/pt_inference.py create mode 100644 exo/inference/torch/tests/test_llama3_full.py delete mode 100644 exo/inference/torch/tests/test_llama3_model.py create mode 100644 exo/inference/torch/tests/test_llama3_split.py diff --git a/exo/inference/torch/inference.py b/exo/inference/torch/hf_inference.py similarity index 98% rename from exo/inference/torch/inference.py rename to exo/inference/torch/hf_inference.py index 23bbe814..1b4f19e0 100644 --- a/exo/inference/torch/inference.py +++ b/exo/inference/torch/hf_inference.py @@ -25,9 +25,10 @@ TEMP = 0.6 TOP_P = 0.9 -class TorchDynamicShardInferenceEngine(InferenceEngine): +class HFDynamicShardInferenceEngine(InferenceEngine): """ - Torch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. + HuggingFace Dynamic Shard Inference Engine + Performing model inference with sharded Pytorch based HuggingFace models. """ def __init__(self, shard_downloader: HFShardDownloader): diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 3280600e..cb3456eb 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -3,6 +3,7 @@ Written with pytorch using torchtune and other methods """ + from typing import Optional, Any, Tuple, List, Union, Callable import torch @@ -15,10 +16,17 @@ from exo.inference.shard import Shard from exo.inference.torch.models.llm_utils import ( MultiLayerPreceptron, - RMSNorm + RMSNorm, + get_torch_dtype ) + class ShardTransformerDecoder(ttm.TransformerDecoder): + """ + ShardTransformerDecorder + Custom version of torchtune TransformerDecoder to allow for + sharding of models and passing of hidden layers between shards + """ def __init__( self, *, @@ -31,7 +39,7 @@ def __init__( norm: nn.Module, output: Union[nn.Linear, Callable], num_layers: Optional[int] = None, - output_hidden_states: Optional[List[int]] = None + output_hidden_states: Optional[List[int]] = None, ): super().__init__( tok_embeddings=tok_embeddings, @@ -57,34 +65,29 @@ def forward( input_pos: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: # Determine the type of input and shape - print(f"tokens.ndim: {tokens.ndim}") if tokens.ndim == 3: h = tokens # Use directly as hidden states - - # check states alignment - align_check = self.layers[0].in_features == h.shape[-1] - print(f"align_check {align_check}") else: h = self.tok_embeddings(tokens) # Apply token tok_embeddings - seq_len = h.shape[1] + seq_len = h.shape[1] - self._validate_inputs( - seq_len, - mask=mask, - encoder_input=encoder_input, - encoder_mask=encoder_mask, - input_pos=input_pos, - ) + self._validate_inputs( + seq_len, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) # Initialize a list to capture hidden states if requested # for captured hidden states - hidden = None + hidden = [] - for i in range(self.shard.start_layer, self.shard.end_layer+1): + for i in range(self.shard.start_layer, self.shard.end_layer + 1): layer = self.layers[i] - print(f"\nhidden layer in H[{i}]\n{h}\n") + print(f"\nhidden layer in H[{i}]\n{h}\nmask\n{mask}\ninput_pos\n{input_pos}\n{self.output_hidden_states}\n") # Process through each transformer layer h = layer( @@ -95,138 +98,141 @@ def forward( input_pos=input_pos, ) - # for shard model just capture the last hs computed - if i == self.shard.end_layer: - hidden = h + if i in self.output_hidden_states: + hidden.append(h) - print(f"\nhidden layer out H[{i}]->H[{i+1}]\n{h}\n") + print(f"\nhidden layer out H[{i}]->H[{i + 1}]\n{h}\n") - print(f"last hidden: {hidden}") # Apply normalization h = self.norm(h) # Handle chunked output if needed if self.num_output_chunks > 0: - output = self.chunked_output(h) + output = self.chunked_output(h) else: - output = self.output(h).float() + output = self.output(h).float() # Return list if hidden states are requested - output = [hidden, output] + output = [hidden[-1], output] if hidden else output print(f"\n\noutput {output}\n\n") return output -def LlamaModel( - config: dict, - shard: Shard, - is_causal: bool=True, - max_seq_len: int=4096 -): +def LlamaModel(config: dict, shard: Shard): """ LlamaModel using torchtune """ - # Load configurations from config - rope_scaling = config.get("rope_scaling") - hidden_head_dim = config["hidden_size"] // config["num_attention_heads"] - - # Model layers and methods, order matters - embed_tokens = nn.Embedding( - config["vocab_size"], - config["hidden_size"] + # rope scaling config + if config["rope_scaling"] is not None: + scale_factor = config["rope_scaling"].get("factor", 32) + + rope = Llama3ScaledRoPE( + dim=config["head_dim"], + max_seq_len=config["max_seq_len"], + base=config["rope_base"], + scale_factor=scale_factor, ) layers = [] for _ in range(shard.n_layers): - pos_embeddings = Llama3ScaledRoPE( - dim=hidden_head_dim, - max_seq_len=max_seq_len, - base=config.get('rope_theta', 10000), - scale_factor=rope_scaling['factor'] if rope_scaling else 32 - ) - self_attn = ttm.MultiHeadAttention( - embed_dim=config["hidden_size"], - num_heads=config["num_attention_heads"], - num_kv_heads=config["num_key_value_heads"], - head_dim=hidden_head_dim, + embed_dim=config["embed_dim"], + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + head_dim=config["head_dim"], q_proj=nn.Linear( - config["hidden_size"], - config["num_attention_heads"] * config["head_dim"], - bias=config.get('attention_bias', False) + config["embed_dim"], + config["num_heads"] * config["head_dim"], + bias=config["attn_bias"], ), - k_proj = nn.Linear( - config["hidden_size"], - config["num_key_value_heads"] * config["head_dim"], - bias=config.get('attention_bias', False) + k_proj=nn.Linear( + config["embed_dim"], + config["num_kv_heads"] * config["head_dim"], + bias=config["attn_bias"], ), - v_proj = nn.Linear( - config["hidden_size"], - config["num_key_value_heads"] * config["head_dim"], - bias=config.get('attention_bias', False) + v_proj=nn.Linear( + config["embed_dim"], + config["num_kv_heads"] * config["head_dim"], + bias=config["attn_bias"], ), output_proj=nn.Linear( - config["hidden_size"], - config["hidden_size"], - bias=config.get('attention_bias', False) + config["embed_dim"], + config["embed_dim"], + bias=config["attn_bias"], ), - max_seq_len=max_seq_len, - is_causal=is_causal, - attn_dropout=config.get('attention_dropout', 0.0), - pos_embeddings=pos_embeddings + max_seq_len=config["max_seq_len"], + attn_dropout=config["attn_dropout"], + pos_embeddings=rope, ) mlp = MultiLayerPreceptron( - config["hidden_size"], - config['intermediate_size'], - 'silu' + config["embed_dim"], + config["intermediate_dim"], + config["hidden_act"] ) layer = ttm.TransformerSelfAttentionLayer( attn=self_attn, mlp=mlp, - sa_norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]), - mlp_norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]) + sa_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + mlp_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), ) layers.append(layer) + + layers = nn.ModuleList(layers) + tok_embeddings = nn.Embedding(config["vocab_size"], config["embed_dim"]) + # output_proj = ttm.TiedLinear(tok_embeddings) + output_proj = nn.Linear( + config["embed_dim"], + config["vocab_size"], + bias=config["attn_bias"], + ) return ShardTransformerDecoder( - tok_embeddings=embed_tokens, - layers=nn.ModuleList(layers), - max_seq_len=max_seq_len, - num_heads=config["num_attention_heads"], + tok_embeddings=tok_embeddings, + shard=shard, + layers=layers, + max_seq_len=config["max_seq_len"], + num_heads=config["num_heads"], head_dim=config["head_dim"], - norm=RMSNorm(config["hidden_size"], eps=config["rms_norm_eps"]), - output=nn.Linear(config["hidden_size"], config["vocab_size"]), - num_layers=shard.n_layers, - #output_hidden_states=list(range(shard.start_layer, shard.end_layer)), - shard=shard + norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + output=output_proj, + num_layers=config["num_layers"] ) + # return ttm.TransformerDecoder( + # tok_embeddings=tok_embeddings, + # layers=layers, + # max_seq_len=config["max_seq_len"], + # num_heads=config["num_heads"], + # head_dim=config["head_dim"], + # norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + # output=output_proj, + # num_layers=config["num_layers"], + # ) + + class ShardedLlamaModel(nn.Module): - def __init__(self, - config: dict, - shard: Shard, - tokenizer: Any, - device: torch.device=torch.device("cpu"), - is_causal=True, - use_cache=False - ): + def __init__( + self, + config: dict, + shard: Shard, + tokenizer: Any, + device: Optional[torch.device] = None, + max_seq_len: Optional[int] = None + ): super(ShardedLlamaModel, self).__init__() self.tokenizer = tokenizer self.shard = shard self.config = config - self.model = LlamaModel(config, shard, is_causal) - self.device = device - self.use_cache = use_cache + self.dtype = get_torch_dtype(self.config["torch_dtype"]) if "torch_dtype" in self.config else torch.float + self.device = device if device is not None else torch.device("cpu") + self.use_cache = self.config.get("use_cache", False) + self.model = LlamaModel(config, self.shard).to(dtype=self.dtype, device=self.device) + self.max_seq_len = max_seq_len if max_seq_len is not None else 4096 - def generate( - self, - tokens: torch.Tensor, - hidden_state: Optional[torch.Tensor] = None, - max_seq_len: int=4096 - ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: + def generate(self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: """ Generate logits and/or hidden_states from llama model @@ -245,17 +251,14 @@ def generate( # setup cache if not self.model.caches_are_enabled() and self.use_cache: - self.model.setup_caches(bsz, torch.float, decoder_max_seq_len=self.model.decoder_max_cache_seq_len) + with self.device: + self.model.setup_caches(bsz, self.dtype, decoder_max_seq_len=self.model.decoder_max_cache_seq_len) if not self.shard.is_last_layer(): self.model.output_hidden_states = [self.shard.end_layer] - total_response_length = tokens_length + max_seq_len - resp_max_seq_len = ( - total_response_length - if not self.model.caches_are_enabled() - else self.model.decoder_max_cache_seq_len - ) + total_response_length = tokens_length + self.max_seq_len + resp_max_seq_len = total_response_length if not self.model.caches_are_enabled() else self.model.decoder_max_cache_seq_len # clone tokens generated_tokens = tokens.clone() @@ -263,16 +266,9 @@ def generate( # masking for proper attention padding_masks = generated_tokens != self.tokenizer.pad_id if not padding_masks.all(): - padding_masks = torch.nn.functional.pad( - padding_masks, - (0, max_seq_len), - value=True - ) + padding_masks = torch.nn.functional.pad(padding_masks, (0, self.max_seq_len), value=True) - masks = ttg.get_causal_mask_from_padding_mask( - padding_masks, - target_seq_len=resp_max_seq_len - ) + masks = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=resp_max_seq_len) input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) else: @@ -285,56 +281,34 @@ def generate( ) ).unsqueeze(0) - input_pos = torch.arange( - 0, total_response_length, device=generated_tokens.device - ).unsqueeze(0) - + input_pos = torch.arange(0, total_response_length, device=generated_tokens.device).unsqueeze(0) + if self.model.caches_are_enabled(): curr_masks = masks[:, :tokens_length] else: curr_masks = masks[:, :tokens_length, :tokens_length] - if hidden_state is not None: - #_, hs_len, _ = hidden_state.size() - #total_hidden_length = hs_len + max_seq_len - #hs_max_seq_len = ( - # total_response_length - # if not self.model.caches_are_enabled() - # else self.model.decoder_max_cache_seq_len - #) - - #hs_mask = torch.tril( - # torch.ones( - # total_hidden_length, - # hs_max_seq_len if hs_max_seq_len is not None else max_seq_len, - # dtype=torch.bool, - # device=tokens.device, - # ) - #).unsqueeze(0) - - #if self.model.caches_are_enabled(): - #hs_curr_masks = hs_mask[:, :hs_len] - #else: - #hs_curr_masks = hs_mask[:, :hs_len, :hs_len] + input_pos = input_pos[:, :tokens_length].squeeze() + if hidden_state is not None: model_output = self.model( tokens=hidden_state, mask=curr_masks, - input_pos=input_pos[:, :tokens_length].squeeze(), + input_pos=input_pos, ) else: model_output = self.model( tokens=tokens, mask=curr_masks, - input_pos=input_pos[:, :tokens_length].squeeze() + input_pos=input_pos, ) print(f"\nmodel_output: {model_output}") if isinstance(model_output, list): model_logits = model_output[1] - model_output.pop() # remove logits - model_hs = model_output[0] # get last hidden state + model_output.pop() # remove logits + model_hs = model_output[0] # get last hidden state else: model_logits = model_output model_hs = None diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 8e4ec14d..f1a60e10 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -1,6 +1,7 @@ """ Utility methods used by LLMs """ + import re import json from pathlib import Path @@ -10,21 +11,30 @@ import torch.nn as nn import torch.nn.functional as F import torchtune.modules as ttm +from torchtune.models.convert_weights import hf_to_tune import math from safetensors.torch import load_file as load_safetensors -from transformers import ( - LogitsProcessorList, - TopKLogitsWarper, - TopPLogitsWarper, - TemperatureLogitsWarper -) +from transformers import LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper from transformers.cache_utils import Cache, DynamicCache from exo.helpers import DEBUG from exo.inference.shard import Shard + +def get_torch_dtype(dtype_str: str) -> torch.dtype: + """ + Get dtype from setting in model's config.json + """ + if dtype_str == "bfloat16": + return torch.bfloat16 + elif dtype_str == "float16": + return torch.float16 + else: + return torch.float16 + + def load_model_config(model_config_path: Path) -> dict: """ Loads the config.json of the model @@ -37,9 +47,28 @@ def load_model_config(model_config_path: Path) -> dict: """ model_config = {} with open(model_config_path, "r") as f: - model_config = json.load(f) + base_config = json.load(f) + + model_config = { + "rope_scaling": base_config.get("rope_scaling"), + "embed_dim": base_config["hidden_size"], + "num_heads": base_config["num_attention_heads"], + "head_dim": base_config["hidden_size"] // base_config["num_attention_heads"], # Assuming embed_dim = hidden_size + "num_kv_heads": base_config["num_key_value_heads"], + "max_seq_len": base_config["max_position_embeddings"], + "intermediate_dim": base_config["intermediate_size"], + "attn_dropout": base_config.get("attention_dropout", 0.0), + "norm_eps": base_config["rms_norm_eps"], + "rope_base": base_config["rope_theta"], + "vocab_size": base_config["vocab_size"], + "num_layers": base_config["num_hidden_layers"], + "attn_bias": base_config.get("attention_bias", False), + "hidden_act": base_config.get("hidden_act", "silu") + } + return model_config + def check_weights(model, state_dict): """ Verifies that the weights from the state dictionary are properly loaded into the model. @@ -53,11 +82,12 @@ def check_weights(model, state_dict): print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") else: print(f"{name}: loaded correctly") - + for name in state_dict: if name not in model_state_dict: print(f"Unexpected weight {name} found in state_dict") + def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): """ Loads weights from huggingface and changes it to match torchtune naming structure @@ -70,231 +100,91 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): # Load weights from each found safetensors file paried_lmhead = True shard_layer_range = list(range(shard.start_layer, shard.end_layer)) + + full_state_dict = None for safetensor_file in safetensors_files: state_dict = load_safetensors(safetensor_file) - # remap to work with our model - remapped_state_dict = {} - paried_embed_weight = None - for key, value in state_dict.items(): - # load layer by shard - lnrgx = re.findall(r'model\.layers\.(\d+).*', key) - layer_num = int(lnrgx[0]) if len(lnrgx) > 0 else None - if layer_num in shard_layer_range: - # change input layer norm to sa_norm for torchtune - re_iln = re.findall( - rf'model.layers\.{layer_num}\.(input_layernorm)\.weight', key) - if len(re_iln) != 0: - remapped_state_dict[f"model.layers.{layer_num}.sa_norm.weight"] = value - - # change post attention layernorm to mlp_norm for torchtune - re_pal = re.findall( - rf'model.layers\.{layer_num}\.(post_attention_layernorm)\.weight', key) - if len(re_pal) != 0: - remapped_state_dict[f"model.layers.{layer_num}.mlp_norm.weight"] = value - - # change self_attn to attn - # along with changing o_proj to output_proj - re_attn = re.findall(rf'model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)', key) - if len(re_attn) != 0 and re_attn[0][0] == "self_attn": - if re_attn[0][1] == "o_proj": - remapped_state_dict[f"model.layers.{layer_num}.attn.output_proj.weight"] = value - else: - remapped_state_dict[f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}"] = value - - # saving embed for paired weights - elif key == 'model.embed_tokens.weight': - paried_embed_weight = value - # change name for torchtune - remapped_state_dict['model.tok_embeddings.weight'] = value - - elif key == 'lm_head.weight': - paried_lmhead = False - - # get everything else except layers, embed_tokens and lm_head - if ( - len(re.findall(r'model\.layers\..*', key)) == 0 - and key != "model.embed_tokens.weight" - and key != "lm_head.weight" - ): - remapped_state_dict[key] = value - - if paried_lmhead: - remapped_state_dict['model.output.weight'] = paried_embed_weight - - model.load_state_dict(remapped_state_dict, strict=False) - - #if DEBUG >= 7: - print("\n--- checking weights ----\n") - print(f"\nremapped_state_dict: {remapped_state_dict.keys()}\n") - check_weights(model, remapped_state_dict) - -def hf_logit_sample( - logits, - input_ids, - use_max: bool=False, - top_k: int=0, - top_p: float=0.9, - temp: float=1.0, -) -> torch.Tensor: - """ - Logit sampling using transformers - """ - logits_processor = LogitsProcessorList([ - TopKLogitsWarper(top_k), - TemperatureLogitsWarper(temp), - TopPLogitsWarper(top_p) - ]) - - # get a single cloned logit - logits = logits[:, -1, :].clone().float() - - next_token_scores = logits_processor(input_ids, logits) - - if not use_max: - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - else: - next_token = torch.argmax(next_token_scores, dim=-1) - - if DEBUG >= 4: - print(f"input_ids: {input_ids}") - print(f"next_token: {next_token}") - - return next_token[:, None].squeeze(-1) - -def create_4d_causal_attention_mask( - attention_mask: torch.Tensor, - seq_len: int, - target_len: int, - dtype: torch.dtype, - device: torch.device, - cache_pos: torch.Tensor, - batch_size: int, -) -> torch.Tensor: - """ - Creates a 4D causal attention mask from a 2D mask - - Args: - attention_mask (torch.Tensor): - A 2D tensor of shape (batch_size, key_value_length) or a 4D tensor of shape - (batch_size, 1, query_length, key_value_length). - seq_len (int): - Sequence length of the input being processed. - target_len (int): - Target length to generate the causal mask. - dtype (torch.dtype): - Data type for the causal mask. - device (torch.device): - Device to place the causal mask on. - cache_pos (torch.Tensor): - Cache position indices indicating the position of the input tokens in the sequence. - batch_size (int): - Number of samples in the batch. - - Returns: - torch.Tensor: - A 4D causal mask of shape (batch_size, 1, query_length, key_value_length). - """ - if attention_mask is not None and attention_mask.dim() == 4: - # If the mask is already 4D, return it directly - return attention_mask - - min_value = torch.finfo(dtype).min - - # Create a 2D causal mask of shape (seq_len, target_len) - causal_mask = torch.full( - (seq_len, target_len), fill_value=min_value, dtype=dtype, device=device - ) - - if seq_len != 1: - # Mask positions after the current position - causal_mask = torch.triu(causal_mask, diagonal=1) - - # Adjust causal mask for cache position - causal_mask *= (torch.arange(target_len, device=device) > cache_pos.view(-1, 1)) - - # Expand to 4D and batch size - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - - # Create a padding mask based on the input attention_mask - mask_len = attention_mask.shape[-1] - causal_mask = causal_mask.clone() # Ensure contiguous memory for in-place operations - padding_mask = causal_mask[:, :, :, :mask_len] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - - # Apply padding to the causal mask - causal_mask[:, :, :, :mask_len] = causal_mask[:, :, :, :mask_len].masked_fill( - padding_mask, min_value - ) - - return causal_mask - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - -class RotaryEmbedding(nn.Module): - """ - Rotary Position Embedding. - - This computes the inverse frequencies according to the original RoPE implementation. - There are other implementations that will be added. - Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py - """ - - def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.scaling_factor = scaling_factor - - # Initialize the inverse frequency for RoPE - inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Compute the rotary position embeddings (cos, sin) for the given input tensor. - - Args: - x (torch.Tensor): The input tensor of shape (batch_size, seq_len, num_heads, head_dim). - position_ids (torch.Tensor): The position indices for the sequence. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The cos and sin embeddings. - """ - # Expand inv_freq to match the batch size - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.size(0), -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - - # Compute cos and sin embeddings - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Apply the scaling factor to cos and sin embeddings - cos = cos * self.scaling_factor - sin = sin * self.scaling_factor + if full_state_dict is not None: + full_state_dict = full_state_dict | state_dict + else: + full_state_dict = state_dict + + # remap to work with our model + remapped_state_dict = {} + paried_embed_weight = None + for key, value in full_state_dict.items(): + # load layer by shard + lnrgx = re.findall(r"model\.layers\.(\d+).*", key) + layer_num = int(lnrgx[0]) if len(lnrgx) > 0 else None + if layer_num in shard_layer_range: + # change input layer norm to sa_norm for torchtune + re_iln = re.findall(rf"model.layers\.{layer_num}\.(input_layernorm)\.weight", key) + if len(re_iln) != 0: + new_key = f"model.layers.{layer_num}.sa_norm.weight" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + + # change post attention layernorm to mlp_norm for torchtune + re_pal = re.findall(rf"model.layers\.{layer_num}\.(post_attention_layernorm)\.weight", key) + if len(re_pal) != 0: + new_key = f"model.layers.{layer_num}.mlp_norm.weight" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + + # change self_attn to attn + # along with changing o_proj to output_proj + re_attn = re.findall(rf"model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)", key) + if len(re_attn) != 0 and re_attn[0][0] == "self_attn": + if re_attn[0][1] == "o_proj": + new_key = f"model.layers.{layer_num}.attn.output_proj.weight" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + else: + new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + # print(f"{key} == {new_key}") + remapped_state_dict[new_key] = value + + # set mlp weights + re_mlp = re.findall(rf"model\.layers\.{layer_num}.mlp.(\w+)\.(\w+)", key) + if len(re_mlp) != 0: + new_key = f"model.layers.{layer_num}.mlp.{re_mlp[0][0]}.{re_mlp[0][1]}" + # print(f"load mlp {key}") + remapped_state_dict[new_key] = value + + # saving embed for paired weights + elif key == "model.embed_tokens.weight": + paried_embed_weight = value + # change name for torchtune + # print("model.embed_tokens.weight == model.tok_embeddings.weight") + remapped_state_dict["model.tok_embeddings.weight"] = value + + elif key == "lm_head.weight": + paried_lmhead = False + + # get everything else except layers, embed_tokens and lm_head + if len(re.findall(r"model\.layers\..*", key)) == 0 and key != "model.embed_tokens.weight" and key != "lm_head.weight": + # print(f"loading other weight: {key}") + remapped_state_dict[key] = value + + if paried_lmhead: + # print(f"model.output.weight: {paried_embed_weight}") + remapped_state_dict["model.output.weight"] = paried_embed_weight + + # print("\nRemapped state dict\n") + # for rsdk in remapped_state_dict.keys(): + # print(f"-- {rsdk}") + + model.load_state_dict(remapped_state_dict, strict=False) + + # if DEBUG >= 7: + # print("\n--- checking weights ----\n") + # print(f"\nremapped_state_dict: {remapped_state_dict.keys()}\n") + # check_weights(model, remapped_state_dict) - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class MultiLayerPreceptron(nn.Module): - def __init__( - self, - input_dim, - hidden_dim, - activation='silu', - use_bias=False - ): + def __init__(self, input_dim, hidden_dim, activation="silu", use_bias=False): """ General MLP (Multi-Layer Perceptron) module. @@ -303,24 +193,24 @@ def __init__( hidden_dims (int): Hidden layer/intermediate dimensions. output_dim (int): Dimensionality of the output. activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). - dropout (float): Dropout probability. - use_batchnorm (bool): Whether to use batch normalization. + use_bias (bool): Use bias with linearization """ super(MultiLayerPreceptron, self).__init__() # Activation function mapping activations = { - 'relu': nn.ReLU(), - 'gelu': nn.GELU(), - 'tanh': nn.Tanh(), - 'sigmoid': nn.Sigmoid(), - 'leaky_relu': nn.LeakyReLU(0.2), - 'silu': nn.SiLU() + "relu": nn.ReLU(), + "gelu": nn.GELU(), + "tanh": nn.Tanh(), + "sigmoid": nn.Sigmoid(), + "leaky_relu": nn.LeakyReLU(0.2), + "silu": nn.SiLU() } # Ensure valid activation if activation not in activations: - raise ValueError(f"Invalid activation: {activation}. Choose from {list(activations.keys())}") + raise ValueError( + f"Invalid activation: {activation}. Choose from {list(activations.keys())}") # Construct MLP layers self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) @@ -329,341 +219,22 @@ def __init__( self.act_fn = activations[activation] def forward(self, x) -> torch.Tensor: - """ - Forward pass for the MLP module. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor after the MLP transformations. - """ + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return self.down_proj( - self.act_fn( - self.gate_proj(x) - ) * self.up_proj(x) - ) class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - RMSNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.eps = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - return self.weight * hidden_states.to(input_dtype) - -# ------------------ -# Attention Methods -# ------------------ - -class MultiHeadAttention(nn.Module): - """ - Multi-headed attention mechanism. - - Using the "attention is all you need" implementation. Other implementations will follow. - Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L277 - Ref: https://pytorch.org/torchtune/0.3/_modules/torchtune/modules/attention.html - """ - - def __init__( - self, - hidden_size, - num_heads, - num_kv_heads, - head_dim, - rotary_emb, - attention_dropout=0.0, - is_causal=True, - attention_bias=False - ): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = head_dim - self.attention_dropout = attention_dropout - self.is_causal = is_causal - - # nn layers - self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=attention_bias) - self.rotary_emb = rotary_emb - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - kv_cache: Optional[ttm.KVCache] = None, - cos_sin_unsqueeze: int=1 - ) -> Tuple[torch.Tensor, Optional[ttm.KVCache]]: - batch_size, seq_len, _ = hidden_states.size() - - # Project to queries, keys, and values - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Reshape to [batch_size, num_heads, seq_len, head_dim] - query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Apply rotary positional embeddings if position_ids are provided - # or use position_embeddings - if position_embeddings is not None: - cos, sin = position_embeddings - else: - cos, sin = self.rotary_emb(value_states, position_ids) - - print(f"cos: {cos.shape} | sin: {sin.shape}") - # Expand cos and sin to match hidden_states' shape - cos = cos.unsqueeze(cos_sin_unsqueeze) - sin = sin.unsqueeze(cos_sin_unsqueeze) - print(f"cos: {cos.shape} | sin: {sin.shape}") - - # Apply rotary embeddings to queries and keys - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Forcing caching always enabled - if kv_cache is not None: - #print(f"kv_cache.size {kv_cache.size}") - - #print(f"key_states.size(2) {key_states.size(2)}") - - #if kv_cache.size != key_states.size(2): - # print(f"\n MAKE NEW KVCACHE batch_size={key_states.size(0)} max_seq_len={key_states.size(2)}") - # kv_cache = ttm.KVCache( - # batch_size=key_states.size(0), - # max_seq_len=key_states.size(2), - # num_heads=self.num_kv_heads, - # head_dim=self.head_dim, - # dtype=hidden_states.dtype - # ) - - key_states, value_states = kv_cache.update(key_states, value_states) - print(f"kv_cache: {kv_cache.size}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Repeat keys and values if needed - #if self.num_heads > self.num_kv_heads: - n_rep = self.num_heads // self.num_kv_heads - key_states = torch.repeat_interleave(key_states, n_rep, dim=1) - value_states = torch.repeat_interleave(value_states, n_rep, dim=1) - - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Compute attention scores - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - print(f"attn_weights: {attn_weights.shape}") - - # Apply attention mask, if provided - if attention_mask is not None: - print(f"attention_mask: {attention_mask.shape}") - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - print(f"causal_mask: {causal_mask.shape}") - attn_weights = attn_weights + causal_mask - print(f"attn_weights: {attn_weights.shape}") - - # Softmax normalization - attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) - print(f"attn_weights: {attn_weights.shape}") - - # Compute attention output - attn_output = torch.matmul(attn_weights, value_states) - print(f"attn_output: {attn_output.shape}") - - # Transpose attention output - attn_output = attn_output.transpose(1,2).contiguous() - print(f"attn_output: {attn_output.shape}") - - # Reshape [batch_size, seq_len, -1] - attn_output = attn_output.reshape(batch_size, seq_len, -1) - print(f"attn_output after transpose: {attn_output.shape}") - - # Project back to hidden size - attn_output = self.o_proj(attn_output) - print(f"attn_output: {attn_output.shape}") - - return attn_output, kv_cache - -class SDPAttention(nn.Module): - """ - Scaled dot product attention mechanism. - - Using the scaled dot product attention method from pytorch - Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L524 - """ - - def __init__( - self, - hidden_size, - num_heads, - num_kv_heads, - head_dim, - rotary_emb, - attention_dropout=0.0, - is_causal=True, - attention_bias=False, - kv_max_seq_len=2048 - ): + def __init__(self, hidden_size, eps=1e-6): + """ + RMSNorm + designed for llama model but used for other models + """ super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = head_dim - self.attention_dropout = attention_dropout - self.is_causal = is_causal - self.kv_max_seq_len = kv_max_seq_len - - # nn layers - self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=attention_bias) - self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) - self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=attention_bias) - self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=attention_bias) - self.rotary_emb = rotary_emb - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - kv_cache: Optional[ttm.KVCache] = None, - cos_sin_unsqueeze: int=1 - ) -> Tuple[torch.Tensor, Optional[ttm.KVCache]]: - batch_size, seq_len, _ = hidden_states.size() - - # Project to queries, keys, and values - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Reshape to [batch_size, num_heads, seq_len, head_dim] - query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Apply rotary positional embeddings if position_ids are provided - # or use position_embeddings - if position_embeddings is not None: - cos, sin = position_embeddings - else: - cos, sin = self.rotary_emb(value_states, position_ids) - - print(f"cos: {cos.shape} | sin: {sin.shape}") - # Expand cos and sin to match hidden_states' shape - cos = cos.unsqueeze(cos_sin_unsqueeze) - sin = sin.unsqueeze(cos_sin_unsqueeze) - print(f"cos: {cos.shape} | sin: {sin.shape}") - - # Apply rotary embeddings to queries and keys - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - # Caching - if kv_cache is not None: - if kv_cache.size >= self.max_seq_len: - # double the cache each time space is ran out - self.kv_max_seq_len = self.kv_max_seq_len + self.kv_max_seq_len - - kv_cache = ttm.KVCache( - batch_size=key_states.size(0), - max_seq_len=self.kv_max_seq_len, - num_heads=self.num_kv_heads, - head_dim=self.head_dim, - dtype=hidden_states.dtype - ) - - key_states, value_states = kv_cache.update(key_states, value_states) - - # **Slice KVCache to match `query_states` length** - key_states = key_states[:, :, :seq_len, :] - value_states = value_states[:, :, :seq_len, :] - - # kv_cache.update(key_states, value_states) - print(f"kv_cache: {kv_cache.size}") - print(f"from kv_cache / key_states: {key_states.shape}") - print(f"from kv_cache / value_states: {value_states.shape}") - - # Repeat keys and values if needed - #if self.num_heads > self.num_kv_heads: - n_rep = self.num_heads // self.num_kv_heads - key_states = torch.repeat_interleave(key_states, n_rep, dim=1) - value_states = torch.repeat_interleave(value_states, n_rep, dim=1) - - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - causal_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - print(f"causal_mask: {causal_mask.shape}") - - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - print(f"query_states: {query_states.shape}") - print(f"key_states: {key_states.shape}") - print(f"value_states: {value_states.shape}") - - is_causal = True if causal_mask is None and seq_len > 1 else False - - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=0.0, - is_causal=is_causal, - ) - - print(f"attn_output: {attn_output.shape}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_len, -1) - - attn_output = self.o_proj(attn_output) - - print(f"attn_output: {attn_output.shape}") - - return attn_output, kv_cache - + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return self.weight * hidden_states.to(input_dtype) \ No newline at end of file diff --git a/exo/inference/torch/pt_inference.py b/exo/inference/torch/pt_inference.py new file mode 100644 index 00000000..7b8e7bba --- /dev/null +++ b/exo/inference/torch/pt_inference.py @@ -0,0 +1,5 @@ +""" +TorchDynamicShardInferenceEngine +Sharded inference engine using PyTorch based torchtune models +""" + diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py index 2b72b859..c7230c89 100644 --- a/exo/inference/torch/tests/test_inference_engine.py +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -5,7 +5,7 @@ import asyncio from exo.inference.shard import Shard -from exo.inference.torch.inference import TorchDynamicShardInferenceEngine +from exo.inference.torch.hf_inference import HFDynamicShardInferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine @@ -119,8 +119,8 @@ async def test_inference_engine( try: print("\n\n -------- TEST Qwen/Qwen2-0.5B-Instruct -------- \n\n") asyncio.run(test_inference_engine( - TorchDynamicShardInferenceEngine(HFShardDownloader()), - TorchDynamicShardInferenceEngine(HFShardDownloader()), + HFDynamicShardInferenceEngine(HFShardDownloader()), + HFDynamicShardInferenceEngine(HFShardDownloader()), "Qwen/Qwen2-0.5B-Instruct", 36 )) diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py new file mode 100644 index 00000000..a981db77 --- /dev/null +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -0,0 +1,121 @@ +""" +Test of pytorch based llama3 models +full layer run +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg +from torchtune.models import llama3 +from torchtune.data import Message + + +from exo.inference.torch.models.llama3 import ShardedLlamaModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, +) + + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TEMP = 0.6 +TOP_K = 300 +MAX_GEN_TOKENS = 50 + +def main(model, prompt: str, device: torch.device=torch.device("cpu")): + # Tokenize input text + messages = [] + messages.extend([ + Message(role="system", content="You are a helpful and creative AI assistant."), + Message(role="user", content=prompt), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ]) + + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + print(f"tokenizer_out: {tokenizer_out}") + tokens = torch.tensor(tokenizer_out["tokens"], dtype=torch.int, device=device) + + _, logits = model.generate(tokens=tokens) + + tokens = ttg.sample(logits=logits[:, -1].clone(), temperature=TEMP, top_k=TOP_K) + + print(f"tokens: {tokens}") + + generated_tokens = tokens.clone().tolist() + print(f"generated_tokens: {generated_tokens}") + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens[0])}\n\n\n") + + +def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu")): + # Tokenize input text + messages = [] + messages.extend([ + Message(role="system", content="You are a helpful and creative AI assistant."), + Message(role="user", content=user_prompt), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ]) + + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + prompt = torch.tensor(tokenizer_out["tokens"], dtype=torch.int, device=device) + print(f"tokens prompt: {prompt}") + print(f"pad_id: {llama_tokenizer.pad_id}") + + generated_tokens, _ = ttg.generate( + model=model.model, + prompt=prompt, + max_generated_tokens=MAX_GEN_TOKENS, + pad_id=llama_tokenizer.pad_id, + temperature=TEMP, + top_k=TOP_K, + stop_tokens=llama_tokenizer.stop_tokens, + ) + generated_tokens = generated_tokens[:, -MAX_GEN_TOKENS:].tolist() + + print(f"generated_tokens: {generated_tokens}") + + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens[0])}\n\n\n") + + +if __name__ == "__main__": + # prompt = "hello" + prompt = "What is the capital of france?" + + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + print(f"Cache directory: {cache_dir}") + + # Load model configuration + config = load_model_config(cache_dir / "config.json") + + print(f"current config\n{config}") + + # Setup shard + n_layers = int(config["num_layers"]) + shard_1 = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers, + ) + + # Initialize tokenizer + llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" + llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) + print(llama_tokenizer.stop_tokens) + + # Initialize LlamaModel with config and tokenizer + # device = torch.device("cuda") + device = None + shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer, device=device) + print(f"\nshard_model_1: {shard_model_1}") + + load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + + # main(shard_model_1, prompt, device) + normal_full(shard_model_1, prompt, device) diff --git a/exo/inference/torch/tests/test_llama3_model.py b/exo/inference/torch/tests/test_llama3_model.py deleted file mode 100644 index 1b16d91f..00000000 --- a/exo/inference/torch/tests/test_llama3_model.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -Test of pytorch based llama3 model -""" -from pathlib import Path -import gc -import time -import torch -from transformers import AutoTokenizer -from huggingface_hub import snapshot_download - -import torchtune.generation as ttg -from torchtune.models import llama3 -from torchtune.data import Message - - -from exo.inference.torch.models.llama3 import ShardedLlamaModel -from exo.inference.shard import Shard - -from exo.inference.torch.models.llm_utils import ( - load_model_config, - load_model_weights_torchtune, -) - - -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" -TEMP=0.6 -TOP_K=35 -TOP_P=0.9 -MAX_SEQ_LEN=2048 - -def test_generation_1(shard_model, text): - """ - Test the generation capabilities of the LlamaModel with sample text. - """ - # Tokenize input text - messages = [] - messages.extend( - [ - Message(role="user", content=text), - # Empty assistant message to kick-start generation - Message(role="assistant", content=""), - ] - ) - - tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) - print(f"tokenizer_out: {tokenizer_out}") - tokens = tokenizer_out["tokens"] - prompt = torch.tensor(tokens, dtype=torch.int) - - hidden_states, logits = shard_model.generate(prompt) - - if hidden_states is not None: - print(f"hidden_states[{len(hidden_states)}]: {hidden_states}") - - if logits is not None: - print(f"logits: {logits.shape}\n{logits}") - - return hidden_states, logits, prompt - #if prompt.ndim == 1: - # prompt = prompt.view(1, -1) - - #bsz, prompt_length = prompt.size() - #total_response_length = prompt_length + MAX_SEQ_LEN - #generated_tokens = prompt.clone() - #resp_max_seq_len = ( - # total_response_length - # if not shard_model.model.caches_are_enabled() - # else shard_model.model.decoder_max_cache_seq_len - #) - - ## masking for proper attention - #padding_masks = prompt != llama_tokenizer.pad_id - #if not padding_masks.all(): - # padding_masks = torch.nn.functional.pad( - # padding_masks, - # (0, MAX_SEQ_LEN), - # value=True - # ) - - # masks = ttg.get_causal_mask_from_padding_mask( - # padding_masks, - # target_seq_len=resp_max_seq_len - # ) - - # input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) - #else: - # masks = torch.tril( - # torch.ones( - # total_response_length, - # resp_max_seq_len if resp_max_seq_len is not None else MAX_SEQ_LEN, - # dtype=torch.bool, - # device=prompt.device, - # ) - # ).unsqueeze(0) - - # input_pos = torch.arange( - # 0, total_response_length, device=prompt.device - # ).unsqueeze(0) - - #if shard_model.model.caches_are_enabled(): - # curr_masks = masks[:, :prompt_length] - #else: - # curr_masks = masks[:, :prompt_length, :prompt_length] - - #rand_sample = torch.empty( - # ( - # prompt.size(0), - # self.model.tok_embeddings.num_embeddings - # ), device=prompt.device - #).exponential_(1, generator=None) - - #print(f"padding_masks: {padding_masks.shape}") - #print(padding_masks.all()) - - ## this can be sepearted out for dist inference - ## see https://github.com/pytorch/torchtune/blob/bc4acc19ffab2366a14468c97294992dbb7c50d1/torchtune/generation/_generation.py#L66 - #next_token, gen_logits = ttg.generate_next_token( - # shard_model.model, - # input_pos=input_pos[:, :prompt_length].squeeze(), - # x=prompt, - # mask=curr_masks, - # q=rand_sample - #) - - #print(f"next_token: {next_token}") - - #generated_tokens = torch.cat([generated_tokens, next_token], dim=-1) - - #print(f"generated_tokens: {generated_tokens}") - - #curr_pos = prompt_length - - ## stop tokens logic - #stop_tokens = None - #stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device) - #stop_tokens = ( - # torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype) - # if stop_tokens - # else None - #) - #stop_token_mask = torch.ones( - # (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device - #) - - ## finish writing stop token logic using torchtune generation - ## ref https://github.com/pytorch/torchtune/blob/main/torchtune/generation/_generation.py#L337 - - #for _ in range(max_length): - - # if shard_model.model.caches_are_enabled(): - # curr_input_pos = input_pos[:, curr_pos] - # curr_masks = masks[:, curr_pos, None, :] - # else: - # tokens = generated_tokens.clone() - # curr_input_pos = input_pos[:, : curr_pos + 1] - # curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] - - #generated_tokens = generated_tokens.tolist() - #print(f"resp: {llama_tokenizer.decode(generated_tokens[0])}") - -def test_generation_2(shard_model, tokens, hidden_state): - print("Generate with the rest of layers") - hidden_states, logits = shard_model.generate( - tokens=tokens, - hidden_state=hidden_state - ) - - if hidden_states is not None: - print(f"hidden_states {hidden_states.shape}: {hidden_states}") - - if logits is not None: - print(f"logits: {logits.shape}\n{logits}") - - rand_sample = torch.empty( - ( - logits.size(0), - shard_model.model.tok_embeddings.num_embeddings - ), device=logits.device - ).exponential_(1, generator=None) - - logit = ttg.sample( - logits=logits[:, -1].clone(), - temperature=TEMP, - top_k=TOP_K, - q=rand_sample - ) - - print(f"logit: {logit}") - - generated_tokens = tokens.clone() - generated_tokens = torch.cat([generated_tokens, logit.squeeze(-1)], dim=-1).tolist() - - print(f"generated_tokens: {generated_tokens}") - - print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(logit.squeeze(-1).tolist())}\n\n\n") - - return hidden_states, logits - -if __name__ == "__main__": - print("\nTesting generation:") - # Get the path to the model files from the Hugging Face cache - cache_dir = Path(snapshot_download(MODEL_NAME)) - print(f"Cache directory: {cache_dir}") - - # Load model configuration - config = load_model_config(cache_dir / "config.json") - - print(f"current config\n{config}") - - # Setup shard - s1_end = int(int(config["num_hidden_layers"])/2) - shard_1 = Shard( - model_id=MODEL_NAME, - start_layer=0, - end_layer=s1_end, - n_layers=int(config["num_hidden_layers"]) - ) - - s2_start = s1_end + 1 - s2_end = shard_1.n_layers - 1 - shard_2 = Shard( - model_id=MODEL_NAME, - start_layer=s2_start, - end_layer=s2_end, - n_layers=int(config["num_hidden_layers"]) - ) - - # Initialize tokenizer - llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" - llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) - #tokenizer = AutoTokenizer.from_pretrained( - # MODEL_NAME, - # add_eos_token=True - #) - - # Initialize LlamaModel with config and tokenizer - shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer, use_cache=True) - print(f"\nshard_model_1: {shard_model_1}") - load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) - - # Sample text for testing - #prompt = "First letter in the word 'Red'" - prompt = "GM, say it back" - shard_1_hs, shard_1_logits, shard_1_tokens = test_generation_1(shard_model_1, prompt) - - print(f"shard_1_hs:\n{shard_1_hs}") - print(f"shard_1_logits:\n{shard_1_logits}") - print(f"shard_1_tokens:\n{shard_1_tokens}") - - gc.collect() - torch.cuda.empty_cache() - - if shard_model_1.model.caches_are_enabled(): - shard_model_1.model.reset_caches() - - del shard_model_1.model - del shard_model_1 - - #time.sleep(10) - - shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer, use_cache=True) - print(f"\nshard_model_2: {shard_model_2}") - load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) - shard_2_hs, shard_2_logits = test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) - diff --git a/exo/inference/torch/tests/test_llama3_split.py b/exo/inference/torch/tests/test_llama3_split.py new file mode 100644 index 00000000..7bc0fe7c --- /dev/null +++ b/exo/inference/torch/tests/test_llama3_split.py @@ -0,0 +1,131 @@ +""" +Test of pytorch based llama3 model +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg +from torchtune.models import llama3 +from torchtune.data import Message + + +from exo.inference.torch.models.llama3 import ShardedLlamaModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, +) + + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TEMP = 0.6 +TOP_K = 300 + +def test_generation_1(shard_model, prompt): + """ + Test the generation capabilities of the LlamaModel with sample text. + """ + # Tokenize input text + messages = [] + messages.extend([ + Message(role="system", content="You are a helpful and creative AI assistant."), + Message(role="user", content=prompt), + # Empty assistant message to kick-start generation + Message(role="assistant", content=""), + ]) + + print(f"last?: {shard_model.shard.is_last_layer()}") + tokenizer_out = llama_tokenizer({"messages": messages}, inference=True) + print(f"tokenizer_out: {tokenizer_out}") + tokens = torch.tensor(tokenizer_out["tokens"], dtype=torch.int) + + hidden_states, _ = shard_model.generate(tokens) + + if hidden_states is not None: + print(f"hidden_states[{len(hidden_states)}]: {hidden_states}") + + return hidden_states, tokens + + +def test_generation_2(shard_model, in_tokens, hidden_state): + print("Generate with the rest of layers") + hidden_states, logits = shard_model.generate( + tokens=in_tokens, + hidden_state=hidden_state + ) + + if hidden_states is not None: + print(f"hidden_states {hidden_states.shape}: {hidden_states}") + + if logits is not None: + print(f"logits: {logits.shape}\n{logits}") + + # rand_sample = torch.empty(( + # logits.size(0), + # shard_model.model.tok_embeddings.num_embeddings + # ), + # device=logits.device + # ).exponential_(1, generator=None) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + # q=rand_sample + ) + + print(f"tokens: {tokens}") + + generated_tokens = tokens.clone() + generated_tokens = generated_tokens.tolist() + + print(f"generated_tokens: {generated_tokens}") + + print(f"\n\n[resp from model]\n\n{llama_tokenizer.decode(generated_tokens[0])}\n\n\n") + + +if __name__ == "__main__": + print("\nTesting generation:") + + prompt = "What is the capital of france? Say it in one word and nothing else, please." + + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + + # Load model configuration + config = load_model_config(cache_dir / "config.json") + + # Setup shard + n_layers = int(config["num_layers"]) + s1_end = int(n_layers/2) + shard_1 = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=s1_end, + n_layers=n_layers + ) + + shard_2 = Shard( + model_id=MODEL_NAME, + start_layer=s1_end + 1, + end_layer=n_layers - 1, + n_layers=n_layers + ) + + # Initialize tokenizer + llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" + llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) + + # Initialize LlamaModel with config and tokenizer + shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer) + print(f"\nshard_model_1: {shard_model_1}") + load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + shard_1_hs, shard_1_tokens = test_generation_1(shard_model_1, prompt) + + shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer) + print(f"\nshard_model_2: {shard_model_2}") + load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) + test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) From d958bf98d4a18556e8c330a31f21fb547ed20fef Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 17 Nov 2024 06:09:15 -0900 Subject: [PATCH 486/491] split model working, updates to safetensor loading letting shard control --- exo/inference/torch/models/llama3.py | 2 +- exo/inference/torch/models/llm_utils.py | 10 ++++------ exo/inference/torch/tests/test_llama3_full.py | 3 +-- exo/inference/torch/tests/test_llama3_split.py | 4 ++-- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index cb3456eb..dfa0aad9 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -230,7 +230,7 @@ def __init__( self.device = device if device is not None else torch.device("cpu") self.use_cache = self.config.get("use_cache", False) self.model = LlamaModel(config, self.shard).to(dtype=self.dtype, device=self.device) - self.max_seq_len = max_seq_len if max_seq_len is not None else 4096 + self.max_seq_len = self.config["max_seq_len"] def generate(self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: """ diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index f1a60e10..68c9ec8a 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -76,7 +76,7 @@ def check_weights(model, state_dict): model_state_dict = model.state_dict() for name, param in model_state_dict.items(): if name in state_dict: - print(f"\nchecking {name}\n") + # print(f"\nchecking {name}\n") loaded_param = state_dict[name] if param.shape != loaded_param.shape: print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") @@ -115,9 +115,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): paried_embed_weight = None for key, value in full_state_dict.items(): # load layer by shard - lnrgx = re.findall(r"model\.layers\.(\d+).*", key) - layer_num = int(lnrgx[0]) if len(lnrgx) > 0 else None - if layer_num in shard_layer_range: + for layer_num in range(shard.start_layer, shard.end_layer + 1): # change input layer norm to sa_norm for torchtune re_iln = re.findall(rf"model.layers\.{layer_num}\.(input_layernorm)\.weight", key) if len(re_iln) != 0: @@ -153,7 +151,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): remapped_state_dict[new_key] = value # saving embed for paired weights - elif key == "model.embed_tokens.weight": + if key == "model.embed_tokens.weight": paried_embed_weight = value # change name for torchtune # print("model.embed_tokens.weight == model.tok_embeddings.weight") @@ -180,7 +178,7 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): # if DEBUG >= 7: # print("\n--- checking weights ----\n") # print(f"\nremapped_state_dict: {remapped_state_dict.keys()}\n") - # check_weights(model, remapped_state_dict) + check_weights(model, remapped_state_dict) class MultiLayerPreceptron(nn.Module): diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py index a981db77..dabdae21 100644 --- a/exo/inference/torch/tests/test_llama3_full.py +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -23,7 +23,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TEMP = 0.6 -TOP_K = 300 +TOP_K = 25 MAX_GEN_TOKENS = 50 def main(model, prompt: str, device: torch.device=torch.device("cpu")): @@ -107,7 +107,6 @@ def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu" # Initialize tokenizer llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) - print(llama_tokenizer.stop_tokens) # Initialize LlamaModel with config and tokenizer # device = torch.device("cuda") diff --git a/exo/inference/torch/tests/test_llama3_split.py b/exo/inference/torch/tests/test_llama3_split.py index 7bc0fe7c..0d5f69df 100644 --- a/exo/inference/torch/tests/test_llama3_split.py +++ b/exo/inference/torch/tests/test_llama3_split.py @@ -22,7 +22,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TEMP = 0.6 -TOP_K = 300 +TOP_K = 25 def test_generation_1(shard_model, prompt): """ @@ -90,7 +90,7 @@ def test_generation_2(shard_model, in_tokens, hidden_state): if __name__ == "__main__": print("\nTesting generation:") - prompt = "What is the capital of france? Say it in one word and nothing else, please." + prompt = "Say 'Hello'" # Get the path to the model files from the Hugging Face cache cache_dir = Path(snapshot_download(MODEL_NAME)) From c8bdb0971c4b8836715903af1b86704ce4bd3c6b Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 17 Nov 2024 08:31:14 -0900 Subject: [PATCH 487/491] reduced model loading ram by loading only some layers in layer list, inference is still very high --- exo/inference/torch/models/llama3.py | 137 ++++++++++++------ exo/inference/torch/models/llm_utils.py | 13 +- exo/inference/torch/tests/test_llama3_full.py | 2 + .../torch/tests/test_llama3_split.py | 41 +++--- 4 files changed, 124 insertions(+), 69 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index dfa0aad9..9d2b9607 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -14,11 +14,7 @@ from torchtune.modules.attention_utils import _MaskType from exo.inference.shard import Shard -from exo.inference.torch.models.llm_utils import ( - MultiLayerPreceptron, - RMSNorm, - get_torch_dtype -) +from exo.inference.torch.models.llm_utils import MultiLayerPreceptron, RMSNorm, get_torch_dtype class ShardTransformerDecoder(ttm.TransformerDecoder): @@ -27,6 +23,7 @@ class ShardTransformerDecoder(ttm.TransformerDecoder): Custom version of torchtune TransformerDecoder to allow for sharding of models and passing of hidden layers between shards """ + def __init__( self, *, @@ -55,6 +52,44 @@ def __init__( self.shard = shard + def setup_caches( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, + ): + """ + modified version for shard + + assume just decoder layers + """ + if decoder_max_seq_len is not None: + self.decoder_max_cache_seq_len = decoder_max_seq_len + else: + self.decoder_max_cache_seq_len = self.max_seq_len + + for layer in self.layers: + if layer is not None: + layer.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=self.encoder_max_cache_seq_len, + decoder_max_seq_len=self.decoder_max_cache_seq_len, + ) + + def caches_are_enabled(self) -> bool: + """ + modified version for shard + """ + if self.layers[0] is not None: + return self.layers[0].caches_are_enabled() + else: + for layer in self.layers: + if layer is not None: + return layer.caches_are_enabled() + def forward( self, tokens: torch.Tensor, @@ -90,18 +125,19 @@ def forward( print(f"\nhidden layer in H[{i}]\n{h}\nmask\n{mask}\ninput_pos\n{input_pos}\n{self.output_hidden_states}\n") # Process through each transformer layer - h = layer( - h, - mask=mask, - encoder_input=encoder_input, - encoder_mask=encoder_mask, - input_pos=input_pos, - ) + with torch.no_grad(): + h = layer( + h, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) - if i in self.output_hidden_states: - hidden.append(h) + if i in self.output_hidden_states: + hidden.append(h) - print(f"\nhidden layer out H[{i}]->H[{i + 1}]\n{h}\n") + print(f"\nhidden layer out H[{i}]->H[{i + 1}]\n{h}\n") # Apply normalization h = self.norm(h) @@ -117,6 +153,7 @@ def forward( print(f"\n\noutput {output}\n\n") return output + def LlamaModel(config: dict, shard: Shard): """ LlamaModel using torchtune @@ -132,8 +169,10 @@ def LlamaModel(config: dict, shard: Shard): scale_factor=scale_factor, ) - layers = [] - for _ in range(shard.n_layers): + # hack to align sharded weights with layers + # fill unused layer positions with None + layers = [None for _ in range(shard.n_layers)] + for i in range(shard.start_layer, shard.end_layer + 1): self_attn = ttm.MultiHeadAttention( embed_dim=config["embed_dim"], num_heads=config["num_heads"], @@ -164,11 +203,7 @@ def LlamaModel(config: dict, shard: Shard): pos_embeddings=rope, ) - mlp = MultiLayerPreceptron( - config["embed_dim"], - config["intermediate_dim"], - config["hidden_act"] - ) + mlp = MultiLayerPreceptron(config["embed_dim"], config["intermediate_dim"], config["hidden_act"]) layer = ttm.TransformerSelfAttentionLayer( attn=self_attn, @@ -177,16 +212,18 @@ def LlamaModel(config: dict, shard: Shard): mlp_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), ) - layers.append(layer) - + layers[i] = layer + + for i in range(len(layers)): + print(f"layers[{i}]: {layers[i]}") layers = nn.ModuleList(layers) tok_embeddings = nn.Embedding(config["vocab_size"], config["embed_dim"]) - # output_proj = ttm.TiedLinear(tok_embeddings) - output_proj = nn.Linear( - config["embed_dim"], - config["vocab_size"], - bias=config["attn_bias"], - ) + output_proj = ttm.TiedLinear(tok_embeddings) + # output_proj = nn.Linear( + # config["embed_dim"], + # config["vocab_size"], + # bias=config["attn_bias"], + # ) return ShardTransformerDecoder( tok_embeddings=tok_embeddings, @@ -197,7 +234,7 @@ def LlamaModel(config: dict, shard: Shard): head_dim=config["head_dim"], norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), output=output_proj, - num_layers=config["num_layers"] + num_layers=config["num_layers"], ) # return ttm.TransformerDecoder( @@ -214,13 +251,14 @@ def LlamaModel(config: dict, shard: Shard): class ShardedLlamaModel(nn.Module): def __init__( - self, - config: dict, - shard: Shard, - tokenizer: Any, - device: Optional[torch.device] = None, - max_seq_len: Optional[int] = None - ): + self, + config: dict, + shard: Shard, + tokenizer: Any, + device: Optional[torch.device] = None, + max_new_tokens: Optional[int] = 10, + use_cache: Optional[bool] = False + ): super(ShardedLlamaModel, self).__init__() self.tokenizer = tokenizer @@ -228,11 +266,19 @@ def __init__( self.config = config self.dtype = get_torch_dtype(self.config["torch_dtype"]) if "torch_dtype" in self.config else torch.float self.device = device if device is not None else torch.device("cpu") - self.use_cache = self.config.get("use_cache", False) - self.model = LlamaModel(config, self.shard).to(dtype=self.dtype, device=self.device) + self.use_cache = self.config.get("use_cache", False) if not use_cache else use_cache + + + self.max_new_tokens = max_new_tokens self.max_seq_len = self.config["max_seq_len"] - def generate(self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: + self.model = LlamaModel(config, self.shard).to(dtype=self.dtype, device=self.device) + + def generate( + self, + tokens: torch.Tensor, + hidden_state: Optional[torch.Tensor] = None + ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: """ Generate logits and/or hidden_states from llama model @@ -241,6 +287,7 @@ def generate(self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = hidden_state (torch.Tensor, optional) - hidden state from last activated hidden layer, if any max_seq_len (int) - Max sequence length of generation, default 4096 """ + print(self.shard) print(self.shard.is_last_layer()) @@ -252,7 +299,11 @@ def generate(self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = # setup cache if not self.model.caches_are_enabled() and self.use_cache: with self.device: - self.model.setup_caches(bsz, self.dtype, decoder_max_seq_len=self.model.decoder_max_cache_seq_len) + self.model.setup_caches( + bsz, + self.dtype, + decoder_max_seq_len=tokens.numel() + self.max_new_tokens + ) if not self.shard.is_last_layer(): self.model.output_hidden_states = [self.shard.end_layer] @@ -282,7 +333,7 @@ def generate(self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = ).unsqueeze(0) input_pos = torch.arange(0, total_response_length, device=generated_tokens.device).unsqueeze(0) - + if self.model.caches_are_enabled(): curr_masks = masks[:, :tokens_length] else: diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py index 68c9ec8a..9edd779a 100644 --- a/exo/inference/torch/models/llm_utils.py +++ b/exo/inference/torch/models/llm_utils.py @@ -152,27 +152,28 @@ def load_model_weights_torchtune(cache_dir: Path, shard: Shard, model: Any): # saving embed for paired weights if key == "model.embed_tokens.weight": - paried_embed_weight = value + # paried_embed_weight = value # change name for torchtune # print("model.embed_tokens.weight == model.tok_embeddings.weight") remapped_state_dict["model.tok_embeddings.weight"] = value - elif key == "lm_head.weight": - paried_lmhead = False + # elif key == "lm_head.weight": + # paried_lmhead = False # get everything else except layers, embed_tokens and lm_head if len(re.findall(r"model\.layers\..*", key)) == 0 and key != "model.embed_tokens.weight" and key != "lm_head.weight": # print(f"loading other weight: {key}") remapped_state_dict[key] = value - if paried_lmhead: + # if paried_lmhead: # print(f"model.output.weight: {paried_embed_weight}") - remapped_state_dict["model.output.weight"] = paried_embed_weight + # remapped_state_dict["model.output.weight"] = paried_embed_weight # print("\nRemapped state dict\n") # for rsdk in remapped_state_dict.keys(): # print(f"-- {rsdk}") - + del state_dict + del full_state_dict model.load_state_dict(remapped_state_dict, strict=False) # if DEBUG >= 7: diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py index dabdae21..f8b93160 100644 --- a/exo/inference/torch/tests/test_llama3_full.py +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -66,6 +66,7 @@ def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu" print(f"tokens prompt: {prompt}") print(f"pad_id: {llama_tokenizer.pad_id}") + generated_tokens, _ = ttg.generate( model=model.model, prompt=prompt, @@ -75,6 +76,7 @@ def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu" top_k=TOP_K, stop_tokens=llama_tokenizer.stop_tokens, ) + generated_tokens = generated_tokens[:, -MAX_GEN_TOKENS:].tolist() print(f"generated_tokens: {generated_tokens}") diff --git a/exo/inference/torch/tests/test_llama3_split.py b/exo/inference/torch/tests/test_llama3_split.py index 0d5f69df..c7155e30 100644 --- a/exo/inference/torch/tests/test_llama3_split.py +++ b/exo/inference/torch/tests/test_llama3_split.py @@ -23,6 +23,8 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TEMP = 0.6 TOP_K = 25 +MAX_NEW_TOKENS=10 + def test_generation_1(shard_model, prompt): """ @@ -52,10 +54,7 @@ def test_generation_1(shard_model, prompt): def test_generation_2(shard_model, in_tokens, hidden_state): print("Generate with the rest of layers") - hidden_states, logits = shard_model.generate( - tokens=in_tokens, - hidden_state=hidden_state - ) + hidden_states, logits = shard_model.generate(tokens=in_tokens, hidden_state=hidden_state) if hidden_states is not None: print(f"hidden_states {hidden_states.shape}: {hidden_states}") @@ -90,7 +89,7 @@ def test_generation_2(shard_model, in_tokens, hidden_state): if __name__ == "__main__": print("\nTesting generation:") - prompt = "Say 'Hello'" + prompt = "Hello, just say 'Hello' back nothing else" # Get the path to the model files from the Hugging Face cache cache_dir = Path(snapshot_download(MODEL_NAME)) @@ -100,32 +99,34 @@ def test_generation_2(shard_model, in_tokens, hidden_state): # Setup shard n_layers = int(config["num_layers"]) - s1_end = int(n_layers/2) - shard_1 = Shard( - model_id=MODEL_NAME, - start_layer=0, - end_layer=s1_end, - n_layers=n_layers - ) + s1_end = int(n_layers / 2) + shard_1 = Shard(model_id=MODEL_NAME, start_layer=0, end_layer=s1_end, n_layers=n_layers) - shard_2 = Shard( - model_id=MODEL_NAME, - start_layer=s1_end + 1, - end_layer=n_layers - 1, - n_layers=n_layers - ) + shard_2 = Shard(model_id=MODEL_NAME, start_layer=s1_end + 1, end_layer=n_layers - 1, n_layers=n_layers) # Initialize tokenizer llama_tokenizer_path = f"{cache_dir}/original/tokenizer.model" llama_tokenizer = llama3.llama3_tokenizer(path=llama_tokenizer_path) # Initialize LlamaModel with config and tokenizer - shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer) + shard_model_1 = ShardedLlamaModel( + config, + shard_1, + llama_tokenizer, + None, + MAX_NEW_TOKENS + ) print(f"\nshard_model_1: {shard_model_1}") load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) shard_1_hs, shard_1_tokens = test_generation_1(shard_model_1, prompt) - shard_model_2 = ShardedLlamaModel(config, shard_2, llama_tokenizer) + shard_model_2 = ShardedLlamaModel( + config, + shard_2, + llama_tokenizer, + None, + MAX_NEW_TOKENS + ) print(f"\nshard_model_2: {shard_model_2}") load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) test_generation_2(shard_model_2, shard_1_tokens, shard_1_hs) From 75817ebd9c36ee63ce59e8892f5dea1273291c1d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 17 Nov 2024 09:07:45 -0900 Subject: [PATCH 488/491] updating readme --- exo/inference/torch/README.md | 48 +++++++++++-------- exo/inference/torch/models/llama3.py | 11 +++-- exo/inference/torch/tests/test_llama3_full.py | 16 +++++-- .../torch/tests/test_llama3_split.py | 6 ++- 4 files changed, 49 insertions(+), 32 deletions(-) diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md index 2ac5a743..9d4e757d 100644 --- a/exo/inference/torch/README.md +++ b/exo/inference/torch/README.md @@ -20,6 +20,9 @@ Working on removing transformers due to inference and VRAM usage [issues](https: ### 10/27/2024 Still working on llama3 model but wanted to note that a better KVCache needs to be investigated. +#### 11/17/2024 +Llama sharded model now working and next step is inference engine. Still testing on small llama 3.2 1B but will try larger models. + ## Tech Tested on @@ -58,29 +61,32 @@ WIP pytorch llama model ``` # Llama-3.2-1B-Instruct # -LlamaModel( - (embed): Embedding(128256, 2048) - (layers): ModuleList( - (0-15): 16 x LlamaBlock( - (self_attn): SDPAttention( - (q_proj): Linear(in_features=2048, out_features=2048, bias=False) - (k_proj): Linear(in_features=2048, out_features=512, bias=False) - (v_proj): Linear(in_features=2048, out_features=512, bias=False) - (o_proj): Linear(in_features=2048, out_features=2048, bias=False) - (rotary_emb): RotaryEmbedding() - ) - (mlp): MultiLayerPreceptron( - (gate_proj): Linear(in_features=2048, out_features=8192, bias=False) - (up_proj): Linear(in_features=2048, out_features=8192, bias=False) - (down_proj): Linear(in_features=8192, out_features=2048, bias=False) - (act_fn): SiLU() +ShardedLlamaModel( + (model): ShardTransformerDecoder( + (tok_embeddings): Embedding(128256, 2048) + (layers): ModuleList( + (0-15): 16 x TransformerSelfAttentionLayer( + (attn): MultiHeadAttention( + (q_proj): Linear(in_features=2048, out_features=2048, bias=False) + (k_proj): Linear(in_features=2048, out_features=512, bias=False) + (v_proj): Linear(in_features=2048, out_features=512, bias=False) + (output_proj): Linear(in_features=2048, out_features=2048, bias=False) + (pos_embeddings): Llama3ScaledRoPE() + ) + (mlp): MultiLayerPreceptron( + (gate_proj): Linear(in_features=2048, out_features=8192, bias=False) + (up_proj): Linear(in_features=2048, out_features=8192, bias=False) + (down_proj): Linear(in_features=8192, out_features=2048, bias=False) + (act_fn): SiLU() + ) + (sa_norm): RMSNorm() + (mlp_norm): RMSNorm() + (sa_scale): Identity() + (mlp_scale): Identity() ) - (input_layer_norm): RMSNorm() - (post_attention_norm): RMSNorm() ) + (norm): RMSNorm() ) - (norm): RMSNorm() - (rotary_pos_emb): RotaryEmbedding() - (lm_head): Linear(in_features=2048, out_features=128256, bias=False) ) + ``` diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 9d2b9607..7b0076cb 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -70,6 +70,7 @@ def setup_caches( else: self.decoder_max_cache_seq_len = self.max_seq_len + lic = 0 for layer in self.layers: if layer is not None: layer.setup_caches( @@ -78,6 +79,9 @@ def setup_caches( encoder_max_seq_len=self.encoder_max_cache_seq_len, decoder_max_seq_len=self.decoder_max_cache_seq_len, ) + + print(f"Setup cache for layer {lic}") + lic+=1 def caches_are_enabled(self) -> bool: """ @@ -266,7 +270,7 @@ def __init__( self.config = config self.dtype = get_torch_dtype(self.config["torch_dtype"]) if "torch_dtype" in self.config else torch.float self.device = device if device is not None else torch.device("cpu") - self.use_cache = self.config.get("use_cache", False) if not use_cache else use_cache + self.use_cache = use_cache if use_cache else self.config.get("use_cache", False) self.max_new_tokens = max_new_tokens @@ -287,10 +291,6 @@ def generate( hidden_state (torch.Tensor, optional) - hidden state from last activated hidden layer, if any max_seq_len (int) - Max sequence length of generation, default 4096 """ - - print(self.shard) - print(self.shard.is_last_layer()) - if tokens.ndim == 1: tokens = tokens.view(1, -1) @@ -299,6 +299,7 @@ def generate( # setup cache if not self.model.caches_are_enabled() and self.use_cache: with self.device: + print("\n\nSETTING UP CACHES\n\n") self.model.setup_caches( bsz, self.dtype, diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py index f8b93160..3a2afcdb 100644 --- a/exo/inference/torch/tests/test_llama3_full.py +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -24,7 +24,7 @@ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TEMP = 0.6 TOP_K = 25 -MAX_GEN_TOKENS = 50 +MAX_NEW_TOKENS = 10 def main(model, prompt: str, device: torch.device=torch.device("cpu")): # Tokenize input text @@ -70,14 +70,14 @@ def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu" generated_tokens, _ = ttg.generate( model=model.model, prompt=prompt, - max_generated_tokens=MAX_GEN_TOKENS, + max_generated_tokens=MAX_NEW_TOKENS, pad_id=llama_tokenizer.pad_id, temperature=TEMP, top_k=TOP_K, stop_tokens=llama_tokenizer.stop_tokens, ) - generated_tokens = generated_tokens[:, -MAX_GEN_TOKENS:].tolist() + generated_tokens = generated_tokens[:, -MAX_NEW_TOKENS:].tolist() print(f"generated_tokens: {generated_tokens}") @@ -113,8 +113,16 @@ def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu" # Initialize LlamaModel with config and tokenizer # device = torch.device("cuda") device = None - shard_model_1 = ShardedLlamaModel(config, shard_1, llama_tokenizer, device=device) + shard_model_1 = ShardedLlamaModel( + config, + shard_1, + llama_tokenizer, + device, + MAX_NEW_TOKENS, + use_cache=True + ) print(f"\nshard_model_1: {shard_model_1}") + exit() load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) diff --git a/exo/inference/torch/tests/test_llama3_split.py b/exo/inference/torch/tests/test_llama3_split.py index c7155e30..68272765 100644 --- a/exo/inference/torch/tests/test_llama3_split.py +++ b/exo/inference/torch/tests/test_llama3_split.py @@ -114,7 +114,8 @@ def test_generation_2(shard_model, in_tokens, hidden_state): shard_1, llama_tokenizer, None, - MAX_NEW_TOKENS + MAX_NEW_TOKENS, + use_cache=True ) print(f"\nshard_model_1: {shard_model_1}") load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) @@ -125,7 +126,8 @@ def test_generation_2(shard_model, in_tokens, hidden_state): shard_2, llama_tokenizer, None, - MAX_NEW_TOKENS + MAX_NEW_TOKENS, + use_cache=True ) print(f"\nshard_model_2: {shard_model_2}") load_model_weights_torchtune(cache_dir, shard_2, shard_model_2) From 73630d1ef92d64c843c40413e483cf1c8fb5a0ed Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Mon, 18 Nov 2024 11:50:58 -0900 Subject: [PATCH 489/491] building out torch inference engine --- exo/inference/torch/models/llama3.py | 8 +-- exo/inference/torch/pt_inference.py | 62 +++++++++++++++++++ exo/inference/torch/tests/test_llama3_full.py | 1 - 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 7b0076cb..5356218f 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -70,7 +70,6 @@ def setup_caches( else: self.decoder_max_cache_seq_len = self.max_seq_len - lic = 0 for layer in self.layers: if layer is not None: layer.setup_caches( @@ -79,9 +78,6 @@ def setup_caches( encoder_max_seq_len=self.encoder_max_cache_seq_len, decoder_max_seq_len=self.decoder_max_cache_seq_len, ) - - print(f"Setup cache for layer {lic}") - lic+=1 def caches_are_enabled(self) -> bool: """ @@ -287,9 +283,8 @@ def generate( Generate logits and/or hidden_states from llama model Args - tokens (torch.Tensor) - tokens from prompt tokenization + tokens (torch.Tensor) - tokens from prompt tokenization and generation hidden_state (torch.Tensor, optional) - hidden state from last activated hidden layer, if any - max_seq_len (int) - Max sequence length of generation, default 4096 """ if tokens.ndim == 1: tokens = tokens.view(1, -1) @@ -299,7 +294,6 @@ def generate( # setup cache if not self.model.caches_are_enabled() and self.use_cache: with self.device: - print("\n\nSETTING UP CACHES\n\n") self.model.setup_caches( bsz, self.dtype, diff --git a/exo/inference/torch/pt_inference.py b/exo/inference/torch/pt_inference.py index 7b8e7bba..e88ec060 100644 --- a/exo/inference/torch/pt_inference.py +++ b/exo/inference/torch/pt_inference.py @@ -2,4 +2,66 @@ TorchDynamicShardInferenceEngine Sharded inference engine using PyTorch based torchtune models """ +import os +import asyncio +import torch + +from torchtune.models import llama3 + +from exo.inference.inference_engine import InferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.shard import Shard +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, +) + +# supported models +from exo.inference.torch.models.llama3 import ShardedLlamaModel + +TEMP = 0.6 +TOP_K = 25 + +class TorchDynamicShardInferenceEngine(InferenceEngine): + def __init__(self, shard_downloader: HFShardDownloader, model_id: str="llama"): + self.shard = None + self.shard_downloader = shard_downloader + self.model_id = model_id + self.supported_models = ["llama"] + + # device settings + if os.environ.get("TORCH_DEVICE"): + self.device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + async def ensure_shard(self, shard: Shard): + if self.shard == shard: + return + + # download model safetensors and shard + model_path = await self.shard_downloader.ensure_shard(shard) + model_config = load_model_config(model_path / "config.json") + + self.tokenizer = llama3.llama3_tokenizer( + path=f"{model_path}/original/tokenizer.model" + ) + + if self.model_id not in self.supported_models: + raise ValueError( + f"Model {self.model_id} not supported, only supported models are\n{self.supported_models}" + ) + + self.sharded_model = ShardedLlamaModel( + model_config, + shard, + self.tokenizer, + self.device, + None, + use_cache=True + ) diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py index 3a2afcdb..7ffb4dce 100644 --- a/exo/inference/torch/tests/test_llama3_full.py +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -122,7 +122,6 @@ def normal_full(model, user_prompt: str, device: torch.device=torch.device("cpu" use_cache=True ) print(f"\nshard_model_1: {shard_model_1}") - exit() load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) From ad993324f615a22941f0d764a0a87c8af54e53eb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 23 Nov 2024 10:44:05 -0900 Subject: [PATCH 490/491] creating torch inference engine, separated torch and hf torch engines, adding hf engine tests --- exo/inference/torch/hf_inference.py | 10 ++- exo/inference/torch/models/llama3.py | 61 +++++++++++++---- exo/inference/torch/pt_inference.py | 67 ++++++++++++++++++- ..._engine.py => test_hf_inference_engine.py} | 0 .../torch/tests/test_pt_inference_engine.py | 53 +++++++++++++++ 5 files changed, 169 insertions(+), 22 deletions(-) rename exo/inference/torch/tests/{test_inference_engine.py => test_hf_inference_engine.py} (100%) create mode 100644 exo/inference/torch/tests/test_pt_inference_engine.py diff --git a/exo/inference/torch/hf_inference.py b/exo/inference/torch/hf_inference.py index 1b4f19e0..4912a0a2 100644 --- a/exo/inference/torch/hf_inference.py +++ b/exo/inference/torch/hf_inference.py @@ -1,4 +1,7 @@ -# experimental, based off of tinygrad/inference.py +""" +HFDynamicShardInferenceEngine +Sharded inference engine using PyTorch based HuggingFace transformers +""" import asyncio import os import json @@ -26,11 +29,6 @@ TOP_P = 0.9 class HFDynamicShardInferenceEngine(InferenceEngine): - """ - HuggingFace Dynamic Shard Inference Engine - Performing model inference with sharded Pytorch based HuggingFace models. - """ - def __init__(self, shard_downloader: HFShardDownloader): """ Initialize the inference engine. diff --git a/exo/inference/torch/models/llama3.py b/exo/inference/torch/models/llama3.py index 5356218f..feef0baa 100644 --- a/exo/inference/torch/models/llama3.py +++ b/exo/inference/torch/models/llama3.py @@ -78,7 +78,7 @@ def setup_caches( encoder_max_seq_len=self.encoder_max_cache_seq_len, decoder_max_seq_len=self.decoder_max_cache_seq_len, ) - + def caches_are_enabled(self) -> bool: """ modified version for shard @@ -89,7 +89,7 @@ def caches_are_enabled(self) -> bool: for layer in self.layers: if layer is not None: return layer.caches_are_enabled() - + def forward( self, tokens: torch.Tensor, @@ -159,6 +159,7 @@ def LlamaModel(config: dict, shard: Shard): LlamaModel using torchtune """ # rope scaling config + scale_factor = 32 if config["rope_scaling"] is not None: scale_factor = config["rope_scaling"].get("factor", 32) @@ -214,8 +215,8 @@ def LlamaModel(config: dict, shard: Shard): layers[i] = layer - for i in range(len(layers)): - print(f"layers[{i}]: {layers[i]}") + #for i in range(len(layers)): + # print(f"layers[{i}]: {layers[i]}") layers = nn.ModuleList(layers) tok_embeddings = nn.Embedding(config["vocab_size"], config["embed_dim"]) output_proj = ttm.TiedLinear(tok_embeddings) @@ -251,12 +252,12 @@ def LlamaModel(config: dict, shard: Shard): class ShardedLlamaModel(nn.Module): def __init__( - self, - config: dict, - shard: Shard, - tokenizer: Any, + self, + config: dict, + shard: Shard, + tokenizer: Any, device: Optional[torch.device] = None, - max_new_tokens: Optional[int] = 10, + max_new_tokens: int = 2048, use_cache: Optional[bool] = False ): super(ShardedLlamaModel, self).__init__() @@ -266,19 +267,23 @@ def __init__( self.config = config self.dtype = get_torch_dtype(self.config["torch_dtype"]) if "torch_dtype" in self.config else torch.float self.device = device if device is not None else torch.device("cpu") - self.use_cache = use_cache if use_cache else self.config.get("use_cache", False) - - self.max_new_tokens = max_new_tokens self.max_seq_len = self.config["max_seq_len"] + if use_cache: + self.use_cache = use_cache + else: + self.config.get("use_cache", False) + self.model = LlamaModel(config, self.shard).to(dtype=self.dtype, device=self.device) + print(f"model loaded: {self.model}\n") + def generate( self, tokens: torch.Tensor, hidden_state: Optional[torch.Tensor] = None - ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], bool]: """ Generate logits and/or hidden_states from llama model @@ -292,6 +297,7 @@ def generate( bsz, tokens_length = tokens.size() # setup cache + print(self.model) if not self.model.caches_are_enabled() and self.use_cache: with self.device: self.model.setup_caches( @@ -351,6 +357,26 @@ def generate( print(f"\nmodel_output: {model_output}") + # stop token + stop_tokens = None + + stop_token_reached = torch.zeros( + bsz, + dtype=torch.bool, + device=tokens.device + ) + stop_tokens = ( + torch.tensor( + stop_tokens, + device=tokens.device, + dtype=tokens.dtype + ) + if stop_tokens + else None + ) + + finished = False + if isinstance(model_output, list): model_logits = model_output[1] model_output.pop() # remove logits @@ -359,4 +385,11 @@ def generate( model_logits = model_output model_hs = None - return model_hs, model_logits + if stop_tokens is not None: + stop_token_reached = ttg._generation.update_stop_tokens_tracker( + tokens, stop_tokens, stop_token_reached + ) + + finished = True if stop_token_reached.all() else False + + return model_hs, model_logits, finished diff --git a/exo/inference/torch/pt_inference.py b/exo/inference/torch/pt_inference.py index e88ec060..b5a1c8fd 100644 --- a/exo/inference/torch/pt_inference.py +++ b/exo/inference/torch/pt_inference.py @@ -3,7 +3,11 @@ Sharded inference engine using PyTorch based torchtune models """ import os +from typing import Optional, Tuple, Union, List +import functools +from concurrent.futures import ThreadPoolExecutor +import numpy as np import asyncio import torch @@ -12,6 +16,7 @@ from exo.inference.inference_engine import InferenceEngine from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.shard import Shard +from exo.helpers import DEBUG from exo.inference.torch.models.llm_utils import ( load_model_config, load_model_weights_torchtune, @@ -40,6 +45,57 @@ def __init__(self, shard_downloader: HFShardDownloader, model_id: str="llama"): else: self.device = torch.device("cpu") + async def infer_prompt( + self, + request_id: str, + shard: Shard, + prompt: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + # ensure shard + await self.ensure_shard(shard) + + # tokenize + tokens = torch.tensor( + self.tokenizer.encode(prompt, add_bos=True, add_eos=True), + dtype=torch.int + ) + hidden_states = None + + # generate + loop = asyncio.get_running_loop() + with ThreadPoolExecutor() as pool: + hidden_states, logits, finished = await loop.run_in_executor( + pool, + functools.partial( + self.sharded_model.generate, + tokens=tokens + ) + ) + + if hidden_states is not None: + return hidden_states.numpy(force=True), "", finished + else: + return logits.numpy(force=True), "", finished + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + # ensure shard + await self.ensure_shard(shard) + + return np.empty((1,1)), "", False + async def ensure_shard(self, shard: Shard): if self.shard == shard: return @@ -58,10 +114,17 @@ async def ensure_shard(self, shard: Shard): ) self.sharded_model = ShardedLlamaModel( - model_config, - shard, + model_config, + shard, self.tokenizer, self.device, None, use_cache=True ) + + # load sharded weights + load_model_weights_torchtune( + model_path, + shard, + self.sharded_model + ) diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_hf_inference_engine.py similarity index 100% rename from exo/inference/torch/tests/test_inference_engine.py rename to exo/inference/torch/tests/test_hf_inference_engine.py diff --git a/exo/inference/torch/tests/test_pt_inference_engine.py b/exo/inference/torch/tests/test_pt_inference_engine.py new file mode 100644 index 00000000..e430989a --- /dev/null +++ b/exo/inference/torch/tests/test_pt_inference_engine.py @@ -0,0 +1,53 @@ +""" +Test inference engine and model sharding +""" +import time +import asyncio + +from exo.inference.shard import Shard +from exo.inference.torch.pt_inference import TorchDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine + +import numpy as np + +async def test_inference_engine( + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): + + prompt = "In a single word only, what is the last name of the current president of the USA?" + + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=0, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) + +if __name__ == '__main__': + try: + print("\n\n -------- TEST meta-llama/Llama-3.2-1B-Instruct -------- \n\n") + asyncio.run(test_inference_engine( + TorchDynamicShardInferenceEngine(HFShardDownloader()), + TorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Llama-3.2-1B-Instruct", + 16 + )) + except Exception as err: + print(f"\n!!!! LLAMA TEST FAILED \n{err}\n") + + From 6ab6f1c504ce46fec812ff283078ac7fc971e78d Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sat, 23 Nov 2024 11:08:59 -0900 Subject: [PATCH 491/491] merge --- build/lib/exo/__init__.py | 1 + build/lib/exo/api/__init__.py | 1 + build/lib/exo/api/chatgpt_api.py | 358 +++++++++++ build/lib/exo/download/__init__.py | 0 build/lib/exo/download/download_progress.py | 61 ++ build/lib/exo/download/hf/__init__.py | 0 build/lib/exo/download/hf/hf_helpers.py | 403 ++++++++++++ .../lib/exo/download/hf/hf_shard_download.py | 77 +++ build/lib/exo/download/shard_download.py | 26 + build/lib/exo/helpers.py | 234 +++++++ build/lib/exo/inference/__init__.py | 0 .../exo/inference/debug_inference_engine.py | 59 ++ build/lib/exo/inference/inference_engine.py | 34 + build/lib/exo/inference/mlx/__init__.py | 0 .../lib/exo/inference/mlx/models/__init__.py | 0 build/lib/exo/inference/mlx/models/base.py | 9 + .../exo/inference/mlx/models/deepseek_v2.py | 127 ++++ build/lib/exo/inference/mlx/models/llama.py | 125 ++++ build/lib/exo/inference/mlx/models/llava.py | 585 ++++++++++++++++++ .../inference/mlx/sharded_inference_engine.py | 40 ++ build/lib/exo/inference/mlx/sharded_model.py | 86 +++ build/lib/exo/inference/mlx/sharded_utils.py | 207 +++++++ .../exo/inference/mlx/test_sharded_llama.py | 40 ++ .../exo/inference/mlx/test_sharded_llava.py | 64 ++ .../exo/inference/mlx/test_sharded_model.py | 52 ++ build/lib/exo/inference/pytorch/__init__.py | 0 build/lib/exo/inference/pytorch/helpers.py | 24 + build/lib/exo/inference/pytorch/inference.py | 211 +++++++ .../exo/inference/pytorch/model/__init__.py | 0 build/lib/exo/inference/pytorch/model/hf.py | 155 +++++ .../lib/exo/inference/pytorch/model/utils.py | 83 +++ .../pytorch/test_inference_engine.py | 141 +++++ build/lib/exo/inference/shard.py | 39 ++ .../exo/inference/test_inference_engine.py | 64 ++ build/lib/exo/inference/tokenizers.py | 45 ++ build/lib/exo/models.py | 44 ++ build/lib/exo/networking/__init__.py | 5 + build/lib/exo/networking/discovery.py | 17 + build/lib/exo/networking/grpc/__init__.py | 0 .../lib/exo/networking/grpc/grpc_discovery.py | 188 ++++++ .../exo/networking/grpc/grpc_peer_handle.py | 109 ++++ build/lib/exo/networking/grpc/grpc_server.py | 118 ++++ .../exo/networking/grpc/node_service_pb2.py | 61 ++ .../networking/grpc/node_service_pb2_grpc.py | 272 ++++++++ .../networking/grpc/test_grpc_discovery.py | 22 + build/lib/exo/networking/peer_handle.py | 48 ++ build/lib/exo/networking/server.py | 11 + build/lib/exo/orchestration/__init__.py | 4 + build/lib/exo/orchestration/node.py | 47 ++ build/lib/exo/orchestration/standard_node.py | 385 ++++++++++++ build/lib/exo/orchestration/test_node.py | 57 ++ build/lib/exo/stats/__init__.py | 0 build/lib/exo/stats/metrics.py | 29 + build/lib/exo/test_callbacks.py | 50 ++ build/lib/exo/topology/__init__.py | 0 build/lib/exo/topology/device_capabilities.py | 207 +++++++ .../lib/exo/topology/partitioning_strategy.py | 40 ++ ...g_memory_weighted_partitioning_strategy.py | 18 + .../exo/topology/test_device_capabilities.py | 91 +++ build/lib/exo/topology/test_map_partitions.py | 81 +++ ...g_memory_weighted_partitioning_strategy.py | 90 +++ build/lib/exo/topology/topology.py | 49 ++ build/lib/exo/viz/__init__.py | 0 build/lib/exo/viz/test_topology_viz.py | 129 ++++ build/lib/exo/viz/topology_viz.py | 307 +++++++++ 65 files changed, 5830 insertions(+) create mode 100644 build/lib/exo/__init__.py create mode 100644 build/lib/exo/api/__init__.py create mode 100644 build/lib/exo/api/chatgpt_api.py create mode 100644 build/lib/exo/download/__init__.py create mode 100644 build/lib/exo/download/download_progress.py create mode 100644 build/lib/exo/download/hf/__init__.py create mode 100644 build/lib/exo/download/hf/hf_helpers.py create mode 100644 build/lib/exo/download/hf/hf_shard_download.py create mode 100644 build/lib/exo/download/shard_download.py create mode 100644 build/lib/exo/helpers.py create mode 100644 build/lib/exo/inference/__init__.py create mode 100644 build/lib/exo/inference/debug_inference_engine.py create mode 100644 build/lib/exo/inference/inference_engine.py create mode 100644 build/lib/exo/inference/mlx/__init__.py create mode 100644 build/lib/exo/inference/mlx/models/__init__.py create mode 100644 build/lib/exo/inference/mlx/models/base.py create mode 100644 build/lib/exo/inference/mlx/models/deepseek_v2.py create mode 100644 build/lib/exo/inference/mlx/models/llama.py create mode 100644 build/lib/exo/inference/mlx/models/llava.py create mode 100644 build/lib/exo/inference/mlx/sharded_inference_engine.py create mode 100644 build/lib/exo/inference/mlx/sharded_model.py create mode 100644 build/lib/exo/inference/mlx/sharded_utils.py create mode 100644 build/lib/exo/inference/mlx/test_sharded_llama.py create mode 100644 build/lib/exo/inference/mlx/test_sharded_llava.py create mode 100644 build/lib/exo/inference/mlx/test_sharded_model.py create mode 100644 build/lib/exo/inference/pytorch/__init__.py create mode 100644 build/lib/exo/inference/pytorch/helpers.py create mode 100644 build/lib/exo/inference/pytorch/inference.py create mode 100644 build/lib/exo/inference/pytorch/model/__init__.py create mode 100644 build/lib/exo/inference/pytorch/model/hf.py create mode 100644 build/lib/exo/inference/pytorch/model/utils.py create mode 100644 build/lib/exo/inference/pytorch/test_inference_engine.py create mode 100644 build/lib/exo/inference/shard.py create mode 100644 build/lib/exo/inference/test_inference_engine.py create mode 100644 build/lib/exo/inference/tokenizers.py create mode 100644 build/lib/exo/models.py create mode 100644 build/lib/exo/networking/__init__.py create mode 100644 build/lib/exo/networking/discovery.py create mode 100644 build/lib/exo/networking/grpc/__init__.py create mode 100644 build/lib/exo/networking/grpc/grpc_discovery.py create mode 100644 build/lib/exo/networking/grpc/grpc_peer_handle.py create mode 100644 build/lib/exo/networking/grpc/grpc_server.py create mode 100644 build/lib/exo/networking/grpc/node_service_pb2.py create mode 100644 build/lib/exo/networking/grpc/node_service_pb2_grpc.py create mode 100644 build/lib/exo/networking/grpc/test_grpc_discovery.py create mode 100644 build/lib/exo/networking/peer_handle.py create mode 100644 build/lib/exo/networking/server.py create mode 100644 build/lib/exo/orchestration/__init__.py create mode 100644 build/lib/exo/orchestration/node.py create mode 100644 build/lib/exo/orchestration/standard_node.py create mode 100644 build/lib/exo/orchestration/test_node.py create mode 100644 build/lib/exo/stats/__init__.py create mode 100644 build/lib/exo/stats/metrics.py create mode 100644 build/lib/exo/test_callbacks.py create mode 100644 build/lib/exo/topology/__init__.py create mode 100644 build/lib/exo/topology/device_capabilities.py create mode 100644 build/lib/exo/topology/partitioning_strategy.py create mode 100644 build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py create mode 100644 build/lib/exo/topology/test_device_capabilities.py create mode 100644 build/lib/exo/topology/test_map_partitions.py create mode 100644 build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py create mode 100644 build/lib/exo/topology/topology.py create mode 100644 build/lib/exo/viz/__init__.py create mode 100644 build/lib/exo/viz/test_topology_viz.py create mode 100644 build/lib/exo/viz/topology_viz.py diff --git a/build/lib/exo/__init__.py b/build/lib/exo/__init__.py new file mode 100644 index 00000000..e802d331 --- /dev/null +++ b/build/lib/exo/__init__.py @@ -0,0 +1 @@ +from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION diff --git a/build/lib/exo/api/__init__.py b/build/lib/exo/api/__init__.py new file mode 100644 index 00000000..660e7507 --- /dev/null +++ b/build/lib/exo/api/__init__.py @@ -0,0 +1 @@ +from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI diff --git a/build/lib/exo/api/chatgpt_api.py b/build/lib/exo/api/chatgpt_api.py new file mode 100644 index 00000000..1abda85f --- /dev/null +++ b/build/lib/exo/api/chatgpt_api.py @@ -0,0 +1,358 @@ +import uuid +import time +import asyncio +import json +from pathlib import Path +from transformers import AutoTokenizer +from typing import List, Literal, Union, Dict +from aiohttp import web +import aiohttp_cors +import traceback +from exo import DEBUG, VERSION +from exo.helpers import PrefixDict +from exo.inference.shard import Shard +from exo.inference.tokenizers import resolve_tokenizer +from exo.orchestration import Node +from exo.models import model_base_shards +from typing import Callable + +class Message: + def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]): + self.role = role + self.content = content + + def to_dict(self): + return {"role": self.role, "content": self.content} + + +class ChatCompletionRequest: + def __init__(self, model: str, messages: List[Message], temperature: float): + self.model = model + self.messages = messages + self.temperature = temperature + + def to_dict(self): + return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature} + + +def generate_completion( + chat_request: ChatCompletionRequest, + tokenizer, + prompt: str, + request_id: str, + tokens: List[int], + stream: bool, + finish_reason: Union[Literal["length", "stop"], None], + object_type: Literal["chat.completion", "text_completion"], +) -> dict: + completion = { + "id": f"chatcmpl-{request_id}", + "object": object_type, + "created": int(time.time()), + "model": chat_request.model, + "system_fingerprint": f"exo_{VERSION}", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": tokenizer.decode(tokens)}, + "logprobs": None, + "finish_reason": finish_reason, + }], + } + + if not stream: + completion["usage"] = { + "prompt_tokens": len(tokenizer.encode(prompt)), + "completion_tokens": len(tokens), + "total_tokens": len(tokenizer.encode(prompt)) + len(tokens), + } + + choice = completion["choices"][0] + if object_type.startswith("chat.completion"): + key_name = "delta" if stream else "message" + choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)} + elif object_type == "text_completion": + choice["text"] = tokenizer.decode(tokens) + else: + ValueError(f"Unsupported response type: {object_type}") + + return completion + + +def remap_messages(messages: List[Message]) -> List[Message]: + remapped_messages = [] + last_image = None + for message in messages: + if not isinstance(message.content, list): + remapped_messages.append(message) + continue + + remapped_content = [] + for content in message.content: + if isinstance(content, dict): + if content.get("type") in ["image_url", "image"]: + image_url = content.get("image_url", {}).get("url") or content.get("image") + if image_url: + last_image = {"type": "image", "image": image_url} + remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"}) + else: + remapped_content.append(content) + else: + remapped_content.append(content) + remapped_messages.append(Message(role=message.role, content=remapped_content)) + + if last_image: + # Replace the last image placeholder with the actual image content + for message in reversed(remapped_messages): + for i, content in enumerate(message.content): + if isinstance(content, dict): + if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]": + message.content[i] = last_image + return remapped_messages + + return remapped_messages + + +def build_prompt(tokenizer, _messages: List[Message]): + if len(_messages) == 1: + user_msg = _messages[0] + + # get instruct sys message + sys_msg = Message(role="system", content="You are a helpful assistant.") + + # restructure for sys_msg to go first + _messages = [sys_msg, user_msg] + + messages = remap_messages(_messages) + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + if DEBUG >= 3: + print(f"prompt: {str(prompt)}") + for msg in messages: + print(f"chat role: {msg.role}\ncontent: {msg.content}") + + image_str = None + for message in messages: + if not isinstance(message.content, list): + continue + + for content in message.content: + # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41 + # follows the convention in https://platform.openai.com/docs/guides/vision + if isinstance(content, dict) and content.get("type", None) == "image": + image_str = content.get("image", None) + break + + return prompt, image_str + + +def parse_message(data: dict): + if "role" not in data or "content" not in data: + raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'") + return Message(data["role"], data["content"]) + + +def parse_chat_request(data: dict): + return ChatCompletionRequest( + data.get("model", "llama-3.1-8b"), + [parse_message(msg) for msg in data["messages"]], + data.get("temperature", 0.0), + ) + + +class PromptSession: + def __init__(self, request_id: str, timestamp: int, prompt: str): + self.request_id = request_id + self.timestamp = timestamp + self.prompt = prompt + + +class ChatGPTAPI: + def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None): + self.node = node + self.inference_engine_classname = inference_engine_classname + self.response_timeout_secs = response_timeout_secs + self.on_chat_completion_request = on_chat_completion_request + self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload + self.prompts: PrefixDict[str, PromptSession] = PrefixDict() + self.prev_token_lens: Dict[str, int] = {} + self.stream_tasks: Dict[str, asyncio.Task] = {} + cors = aiohttp_cors.setup(self.app) + cors_options = aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + allow_methods="*", + ) + cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options}) + cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options}) + cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}) + cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}) + cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) + cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}) + + self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat" + self.app.router.add_get("/", self.handle_root) + self.app.router.add_static("/", self.static_dir, name="static") + + # Add middleware to log every request + self.app.middlewares.append(self.log_request) + + async def log_request(self, app, handler): + async def middleware(request): + if DEBUG >= 2: print(f"Received request: {request.method} {request.path}") + return await handler(request) + + return middleware + + async def handle_root(self, request): + return web.FileResponse(self.static_dir/"index.html") + + async def handle_get_models(self, request): + return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()]) + + async def handle_post_chat_token_encode(self, request): + data = await request.json() + shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname) + messages = [parse_message(msg) for msg in data.get("messages", [])] + tokenizer = await resolve_tokenizer(shard.model_id) + return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])}) + + async def handle_post_chat_completions(self, request): + data = await request.json() + if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}") + stream = data.get("stream", False) + chat_request = parse_chat_request(data) + if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead + chat_request.model = "llama-3.1-8b" + if not chat_request.model or chat_request.model not in model_base_shards: + if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b") + chat_request.model = "llama-3.1-8b" + shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None) + if not shard: + supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines] + return web.json_response( + {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, + status=400, + ) + + tokenizer = await resolve_tokenizer(shard.model_id) + if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}") + + prompt, image_str = build_prompt(tokenizer, chat_request.messages) + request_id = str(uuid.uuid4()) + if self.on_chat_completion_request: + try: + self.on_chat_completion_request(request_id, chat_request, prompt) + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + # request_id = None + # match = self.prompts.find_longest_prefix(prompt) + # if match and len(prompt) > len(match[1].prompt): + # if DEBUG >= 2: + # print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}") + # request_id = match[1].request_id + # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) + # # remove the matching prefix from the prompt + # prompt = prompt[len(match[1].prompt):] + # else: + # request_id = str(uuid.uuid4()) + # self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt)) + + callback_id = f"chatgpt-api-wait-response-{request_id}" + callback = self.node.on_token.register(callback_id) + + if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}") + try: + await self.node.process_prompt(shard, prompt, image_str, request_id=request_id) + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) + + try: + if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s") + + if stream: + response = web.StreamResponse( + status=200, + reason="OK", + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + }, + ) + await response.prepare(request) + + async def stream_result(request_id: str, tokens: List[int], is_finished: bool): + prev_last_tokens_len = self.prev_token_lens.get(request_id, 0) + self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens)) + new_tokens = tokens[prev_last_tokens_len:] + finish_reason = None + eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, + AutoTokenizer) else getattr(tokenizer, "eos_token_id", None) + if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id: + new_tokens = new_tokens[:-1] + if is_finished: + finish_reason = "stop" + if is_finished and not finish_reason: + finish_reason = "length" + + completion = generate_completion( + chat_request, + tokenizer, + prompt, + request_id, + new_tokens, + stream, + finish_reason, + "chat.completion", + ) + if DEBUG >= 2: print(f"Streaming completion: {completion}") + try: + await response.write(f"data: {json.dumps(completion)}\n\n".encode()) + except Exception as e: + if DEBUG >= 2: print(f"Error streaming completion: {e}") + if DEBUG >= 2: traceback.print_exc() + + def on_result(_request_id: str, tokens: List[int], is_finished: bool): + self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished)) + + return _request_id == request_id and is_finished + + _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs) + if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete + if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.") + try: + await asyncio.wait_for(self.stream_tasks[request_id], timeout=30) + except asyncio.TimeoutError: + print("WARNING: Stream task timed out. This should not happen.") + await response.write_eof() + return response + else: + _, tokens, _ = await callback.wait( + lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, + timeout=self.response_timeout_secs, + ) + + finish_reason = "length" + eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id + if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}") + if tokens[-1] == eos_token_id: + tokens = tokens[:-1] + finish_reason = "stop" + + return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion")) + except asyncio.TimeoutError: + return web.json_response({"detail": "Response generation timed out"}, status=408) + finally: + deregistered_callback = self.node.on_token.deregister(callback_id) + if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}") + + async def run(self, host: str = "0.0.0.0", port: int = 8000): + runner = web.AppRunner(self.app) + await runner.setup() + site = web.TCPSite(runner, host, port) + await site.start() diff --git a/build/lib/exo/download/__init__.py b/build/lib/exo/download/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/download/download_progress.py b/build/lib/exo/download/download_progress.py new file mode 100644 index 00000000..779e5328 --- /dev/null +++ b/build/lib/exo/download/download_progress.py @@ -0,0 +1,61 @@ +from typing import Dict, Callable, Coroutine, Any, Literal +from dataclasses import dataclass +from datetime import timedelta + + +@dataclass +class RepoFileProgressEvent: + repo_id: str + repo_revision: str + file_path: str + downloaded: int + downloaded_this_session: int + total: int + speed: int + eta: timedelta + status: Literal["not_started", "in_progress", "complete"] + + def to_dict(self): + return { + "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session, + "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status + } + + @classmethod + def from_dict(cls, data): + if 'eta' in data: data['eta'] = timedelta(seconds=data['eta']) + return cls(**data) + + +@dataclass +class RepoProgressEvent: + repo_id: str + repo_revision: str + completed_files: int + total_files: int + downloaded_bytes: int + downloaded_bytes_this_session: int + total_bytes: int + overall_speed: int + overall_eta: timedelta + file_progress: Dict[str, RepoFileProgressEvent] + status: Literal["not_started", "in_progress", "complete"] + + def to_dict(self): + return { + "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes, + "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(), + "file_progress": {k: v.to_dict() + for k, v in self.file_progress.items()}, "status": self.status + } + + @classmethod + def from_dict(cls, data): + if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta']) + if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()} + + return cls(**data) + + +RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]] +RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]] diff --git a/build/lib/exo/download/hf/__init__.py b/build/lib/exo/download/hf/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/download/hf/hf_helpers.py b/build/lib/exo/download/hf/hf_helpers.py new file mode 100644 index 00000000..8fd96dc5 --- /dev/null +++ b/build/lib/exo/download/hf/hf_helpers.py @@ -0,0 +1,403 @@ +import asyncio +import aiohttp +import json +import os +from urllib.parse import urljoin +from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal +from datetime import datetime, timedelta +from fnmatch import fnmatch +from pathlib import Path +from typing import Generator, Iterable, TypeVar, TypedDict +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type +from exo.helpers import DEBUG +from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback +from exo.inference.shard import Shard +import aiofiles +from aiofiles import os as aios + +T = TypeVar("T") + +async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]: + refs_dir = get_repo_root(repo_id)/"refs" + refs_file = refs_dir/revision + if await aios.path.exists(refs_file): + async with aiofiles.open(refs_file, 'r') as f: + commit_hash = (await f.read()).strip() + snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash + return snapshot_dir + return None + + +def filter_repo_objects( + items: Iterable[T], + *, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + key: Optional[Callable[[T], str]] = None, +) -> Generator[T, None, None]: + if isinstance(allow_patterns, str): + allow_patterns = [allow_patterns] + if isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + if allow_patterns is not None: + allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns] + if ignore_patterns is not None: + ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] + + if key is None: + + def _identity(item: T) -> str: + if isinstance(item, str): + return item + if isinstance(item, Path): + return str(item) + raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") + + key = _identity + + for item in items: + path = key(item) + if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): + continue + if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): + continue + yield item + + +def _add_wildcard_to_directories(pattern: str) -> str: + if pattern[-1] == "/": + return pattern + "*" + return pattern + + +def get_hf_home() -> Path: + """Get the Hugging Face home directory.""" + return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface")) + + +async def get_hf_token(): + """Retrieve the Hugging Face token from the user's HF_HOME directory.""" + token_path = get_hf_home()/"token" + if await aios.path.exists(token_path): + async with aiofiles.open(token_path, 'r') as f: + return (await f.read()).strip() + return None + + +async def get_auth_headers(): + """Get authentication headers if a token is available.""" + token = await get_hf_token() + if token: + return {"Authorization": f"Bearer {token}"} + return {} + + +def get_repo_root(repo_id: str) -> Path: + """Get the root directory for a given repo ID in the Hugging Face cache.""" + sanitized_repo_id = repo_id.replace("/", "--") + return get_hf_home()/"hub"/f"models--{sanitized_repo_id}" + + +async def fetch_file_list(session, repo_id, revision, path=""): + api_url = f"https://huggingface.co/api/models/{repo_id}/tree/{revision}" + url = f"{api_url}/{path}" if path else api_url + + headers = await get_auth_headers() + async with session.get(url, headers=headers) as response: + if response.status == 200: + data = await response.json() + files = [] + for item in data: + if item["type"] == "file": + files.append({"path": item["path"], "size": item["size"]}) + elif item["type"] == "directory": + subfiles = await fetch_file_list(session, repo_id, revision, item["path"]) + files.extend(subfiles) + return files + else: + raise Exception(f"Failed to fetch file list: {response.status}") + + +@retry( + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True +) +async def download_file( + session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True +): + base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/" + url = urljoin(base_url, file_path) + local_path = os.path.join(save_directory, file_path) + + await aios.makedirs(os.path.dirname(local_path), exist_ok=True) + + # Check if file already exists and get its size + local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0 + + headers = await get_auth_headers() + if use_range_request: + headers["Range"] = f"bytes={local_file_size}-" + + async with session.get(url, headers=headers) as response: + total_size = int(response.headers.get('Content-Length', 0)) + downloaded_size = local_file_size + downloaded_this_session = 0 + mode = 'ab' if use_range_request else 'wb' + if downloaded_size == total_size: + if DEBUG >= 2: print(f"File already downloaded: {file_path}") + if progress_callback: + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) + return + + if response.status == 200: + # File doesn't support range requests or we're not using them, start from beginning + mode = 'wb' + downloaded_size = 0 + elif response.status == 206: + # Partial content, resume download + content_range = response.headers.get('Content-Range', '') + try: + total_size = int(content_range.split('/')[-1]) + except ValueError: + if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") + return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) + elif response.status == 416: + # Range not satisfiable, get the actual file size + content_range = response.headers.get('Content-Range', '') + try: + total_size = int(content_range.split('/')[-1]) + if downloaded_size == total_size: + if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}") + if progress_callback: + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) + return + except ValueError: + if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") + return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) + else: + raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}") + + if downloaded_size == total_size: + print(f"File already downloaded: {file_path}") + if progress_callback: + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) + return + + DOWNLOAD_CHUNK_SIZE = 32768 + start_time = datetime.now() + async with aiofiles.open(local_path, mode) as f: + async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): + await f.write(chunk) + downloaded_size += len(chunk) + downloaded_this_session += len(chunk) + if progress_callback and total_size: + elapsed_time = (datetime.now() - start_time).total_seconds() + speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0 + remaining_size = total_size - downloaded_size + eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0) + status = "in_progress" if downloaded_size < total_size else "complete" + if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}") + await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status)) + if DEBUG >= 2: print(f"Downloaded: {file_path}") + + +async def download_repo_files( + repo_id: str, + revision: str = "main", + progress_callback: Optional[RepoProgressCallback] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + max_parallel_downloads: int = 4 +) -> Path: + repo_root = get_repo_root(repo_id) + refs_dir = repo_root/"refs" + snapshots_dir = repo_root/"snapshots" + cachedreqs_dir = repo_root/"cachedreqs" + + # Ensure directories exist + await aios.makedirs(refs_dir, exist_ok=True) + await aios.makedirs(snapshots_dir, exist_ok=True) + await aios.makedirs(cachedreqs_dir, exist_ok=True) + + # Check if we have a cached commit hash + refs_file = refs_dir/revision + if await aios.path.exists(refs_file): + async with aiofiles.open(refs_file, 'r') as f: + commit_hash = (await f.read()).strip() + if DEBUG >= 2: print(f"Commit hash is already hashed at {refs_file}: {commit_hash}") + else: + async with aiohttp.ClientSession() as session: + # Fetch the commit hash for the given revision + api_url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}" + headers = await get_auth_headers() + async with session.get(api_url, headers=headers) as response: + if response.status != 200: + raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}") + revision_info = await response.json() + commit_hash = revision_info['sha'] + + # Cache the commit hash + async with aiofiles.open(refs_file, 'w') as f: + await f.write(commit_hash) + + # Set up the snapshot directory + snapshot_dir = snapshots_dir/commit_hash + await aios.makedirs(snapshot_dir, exist_ok=True) + + # Set up the cached file list directory + cached_file_list_dir = cachedreqs_dir/commit_hash + await aios.makedirs(cached_file_list_dir, exist_ok=True) + cached_file_list_path = cached_file_list_dir/"fetch_file_list.json" + + async with aiohttp.ClientSession() as session: + # Check if we have a cached file list + if await aios.path.exists(cached_file_list_path): + async with aiofiles.open(cached_file_list_path, 'r') as f: + file_list = json.loads(await f.read()) + if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}") + else: + file_list = await fetch_file_list(session, repo_id, revision) + # Cache the file list + async with aiofiles.open(cached_file_list_path, 'w') as f: + await f.write(json.dumps(file_list)) + if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}") + + filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"])) + total_files = len(filtered_file_list) + total_bytes = sum(file["size"] for file in filtered_file_list) + file_progress: Dict[str, RepoFileProgressEvent] = { + file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") + for file in filtered_file_list + } + start_time = datetime.now() + + async def download_with_progress(file_info, progress_state): + local_path = snapshot_dir/file_info["path"] + if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]: + if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}") + progress_state['completed_files'] += 1 + progress_state['downloaded_bytes'] += file_info["size"] + file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete") + if progress_callback: + elapsed_time = (datetime.now() - start_time).total_seconds() + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 + remaining_bytes = total_bytes - progress_state['downloaded_bytes'] + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) + status = "in_progress" if progress_state['completed_files'] < total_files else "complete" + await progress_callback( + RepoProgressEvent( + repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, + overall_eta, file_progress, status + ) + ) + return + + async def file_progress_callback(event: RepoFileProgressEvent): + progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded + progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session + file_progress[event.file_path] = event + if progress_callback: + elapsed_time = (datetime.now() - start_time).total_seconds() + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 + remaining_bytes = total_bytes - progress_state['downloaded_bytes'] + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) + status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete" + await progress_callback( + RepoProgressEvent( + repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, + overall_eta, file_progress, status + ) + ) + + await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback) + progress_state['completed_files'] += 1 + file_progress[ + file_info["path"] + ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete") + if progress_callback: + elapsed_time = (datetime.now() - start_time).total_seconds() + overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 + remaining_bytes = total_bytes - progress_state['downloaded_bytes'] + overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) + status = "in_progress" if progress_state['completed_files'] < total_files else "complete" + await progress_callback( + RepoProgressEvent( + repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, + overall_eta, file_progress, status + ) + ) + + progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0} + + semaphore = asyncio.Semaphore(max_parallel_downloads) + + async def download_with_semaphore(file_info): + async with semaphore: + await download_with_progress(file_info, progress_state) + + tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list] + await asyncio.gather(*tasks) + + return snapshot_dir + + +async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]: + """ + Retrieve the weight map from the model.safetensors.index.json file. + + Args: + repo_id (str): The Hugging Face repository ID. + revision (str): The revision of the repository to use. + + Returns: + Optional[Dict[str, str]]: The weight map if it exists, otherwise None. + """ + + # Download the index file + await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json") + + # Check if the file exists + repo_root = get_repo_root(repo_id) + snapshot_dir = repo_root/"snapshots" + index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None) + + if index_file: + index_file_path = snapshot_dir/index_file + if await aios.path.exists(index_file_path): + async with aiofiles.open(index_file_path, 'r') as f: + index_data = json.loads(await f.read()) + return index_data.get("weight_map") + + return None + + +def extract_layer_num(tensor_name: str) -> Optional[int]: + # This is a simple example and might need to be adjusted based on the actual naming convention + parts = tensor_name.split('.') + for part in parts: + if part.isdigit(): + return int(part) + return None + + +def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: + default_patterns = [ + "*.json", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + ] + shard_specific_patterns = [] + if weight_map: + for tensor_name, filename in weight_map.items(): + layer_num = extract_layer_num(tensor_name) + if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: + shard_specific_patterns.append(filename) + sorted_file_names = sorted(weight_map.values()) + if shard.is_first_layer(): + shard_specific_patterns.append(sorted_file_names[0]) + elif shard.is_last_layer(): + shard_specific_patterns.append(sorted_file_names[-1]) + else: + shard_specific_patterns = ["*.safetensors"] + return list(set(default_patterns + shard_specific_patterns)) # Remove duplicates diff --git a/build/lib/exo/download/hf/hf_shard_download.py b/build/lib/exo/download/hf/hf_shard_download.py new file mode 100644 index 00000000..eb562c3c --- /dev/null +++ b/build/lib/exo/download/hf/hf_shard_download.py @@ -0,0 +1,77 @@ +import asyncio +import traceback +from pathlib import Path +from typing import Dict, List, Tuple +from exo.inference.shard import Shard +from exo.download.shard_download import ShardDownloader +from exo.download.download_progress import RepoProgressEvent +from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root +from exo.helpers import AsyncCallbackSystem, DEBUG + + +class HFShardDownloader(ShardDownloader): + def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4): + self.quick_check = quick_check + self.max_parallel_downloads = max_parallel_downloads + self.active_downloads: Dict[Shard, asyncio.Task] = {} + self.completed_downloads: Dict[Shard, Path] = {} + self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() + + async def ensure_shard(self, shard: Shard) -> Path: + if shard in self.completed_downloads: + return self.completed_downloads[shard] + if self.quick_check: + repo_root = get_repo_root(shard.model_id) + snapshots_dir = repo_root/"snapshots" + if snapshots_dir.exists(): + visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')] + if visible_dirs: + most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime) + return most_recent_dir + + # If a download on this shard is already in progress, keep that one + for active_shard in self.active_downloads: + if active_shard == shard: + if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.") + return await self.active_downloads[shard] + + # Cancel any downloads for this model_id on a different shard + existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id] + for active_shard in existing_active_shards: + if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})") + task = self.active_downloads[active_shard] + task.cancel() + try: + await task + except asyncio.CancelledError: + pass # This is expected when cancelling a task + except Exception as e: + if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}") + traceback.print_exc() + self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id} + + # Start new download + download_task = asyncio.create_task(self._download_shard(shard)) + self.active_downloads[shard] = download_task + try: + path = await download_task + self.completed_downloads[shard] = path + return path + finally: + # Ensure the task is removed even if an exception occurs + print(f"Removing download task for {shard}: {shard in self.active_downloads}") + if shard in self.active_downloads: + self.active_downloads.pop(shard) + + async def _download_shard(self, shard: Shard) -> Path: + async def wrapped_progress_callback(event: RepoProgressEvent): + self._on_progress.trigger_all(shard, event) + + weight_map = await get_weight_map(shard.model_id) + allow_patterns = get_allow_patterns(weight_map, shard) + + return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads) + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self._on_progress diff --git a/build/lib/exo/download/shard_download.py b/build/lib/exo/download/shard_download.py new file mode 100644 index 00000000..771fb868 --- /dev/null +++ b/build/lib/exo/download/shard_download.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple +from pathlib import Path +from exo.inference.shard import Shard +from exo.download.download_progress import RepoProgressEvent +from exo.helpers import AsyncCallbackSystem + + +class ShardDownloader(ABC): + @abstractmethod + async def ensure_shard(self, shard: Shard) -> Path: + """ + Ensures that the shard is downloaded. + Does not allow multiple overlapping downloads at once. + If you try to download a Shard which overlaps a Shard that is already being downloaded, + the download will be cancelled and a new download will start. + + Args: + shard (Shard): The shard to download. + """ + pass + + @property + @abstractmethod + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + pass diff --git a/build/lib/exo/helpers.py b/build/lib/exo/helpers.py new file mode 100644 index 00000000..d8a5c6cc --- /dev/null +++ b/build/lib/exo/helpers.py @@ -0,0 +1,234 @@ +import os +import asyncio +from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List +import socket +import random +import platform +import psutil +import uuid +import netifaces +from pathlib import Path + +DEBUG = int(os.getenv("DEBUG", default="0")) +DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0")) +VERSION = "0.0.1" + +exo_text = r""" + _____ _____ + / _ \ \/ / _ \ +| __/> < (_) | + \___/_/\_\___/ + """ + + +def get_system_info(): + if psutil.MACOS: + if platform.machine() == "arm64": + return "Apple Silicon Mac" + if platform.machine() in ["x86_64", "i386"]: + return "Intel Mac" + return "Unknown Mac architecture" + if psutil.LINUX: + return "Linux" + return "Non-Mac, non-Linux system" + +def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int: + used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports") + + def read_used_ports(): + if os.path.exists(used_ports_file): + with open(used_ports_file, "r") as f: + return [int(line.strip()) for line in f if line.strip().isdigit()] + return [] + + def write_used_port(port, used_ports): + with open(used_ports_file, "w") as f: + print(used_ports[-19:]) + for p in used_ports[-19:] + [port]: + f.write(f"{p}\n") + + used_ports = read_used_ports() + available_ports = set(range(min_port, max_port + 1)) - set(used_ports) + + while available_ports: + port = random.choice(list(available_ports)) + if DEBUG >= 2: print(f"Trying to find available port {port=}") + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((host, port)) + write_used_port(port, used_ports) + return port + except socket.error: + available_ports.remove(port) + + raise RuntimeError("No available ports in the specified range") + + +def print_exo(): + print(exo_text) + + +def print_yellow_exo(): + yellow = "\033[93m" # ANSI escape code for yellow + reset = "\033[0m" # ANSI escape code to reset color + print(f"{yellow}{exo_text}{reset}") + + +def terminal_link(uri, label=None): + if label is None: + label = uri + parameters = "" + + # OSC 8 ; params ; URI ST OSC 8 ;; ST + escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\" + + return escape_mask.format(parameters, uri, label) + + +T = TypeVar("T") +K = TypeVar("K") + + +class AsyncCallback(Generic[T]): + def __init__(self) -> None: + self.condition: asyncio.Condition = asyncio.Condition() + self.result: Optional[Tuple[T, ...]] = None + self.observers: list[Callable[..., None]] = [] + + async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]: + async with self.condition: + await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout) + assert self.result is not None # for type checking + return self.result + + def on_next(self, callback: Callable[..., None]) -> None: + self.observers.append(callback) + + def set(self, *args: T) -> None: + self.result = args + for observer in self.observers: + observer(*args) + asyncio.create_task(self.notify()) + + async def notify(self) -> None: + async with self.condition: + self.condition.notify_all() + + +class AsyncCallbackSystem(Generic[K, T]): + def __init__(self) -> None: + self.callbacks: Dict[K, AsyncCallback[T]] = {} + + def register(self, name: K) -> AsyncCallback[T]: + if name not in self.callbacks: + self.callbacks[name] = AsyncCallback[T]() + return self.callbacks[name] + + def deregister(self, name: K) -> None: + if name in self.callbacks: + del self.callbacks[name] + + def trigger(self, name: K, *args: T) -> None: + if name in self.callbacks: + self.callbacks[name].set(*args) + + def trigger_all(self, *args: T) -> None: + for callback in self.callbacks.values(): + callback.set(*args) + + +K = TypeVar('K', bound=str) +V = TypeVar('V') + + +class PrefixDict(Generic[K, V]): + def __init__(self): + self.items: Dict[K, V] = {} + + def add(self, key: K, value: V) -> None: + self.items[key] = value + + def find_prefix(self, argument: str) -> List[Tuple[K, V]]: + return [(key, value) for key, value in self.items.items() if argument.startswith(key)] + + def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]: + matches = self.find_prefix(argument) + if len(matches) == 0: + return None + + return max(matches, key=lambda x: len(x[0])) + + +def is_valid_uuid(val): + try: + uuid.UUID(str(val)) + return True + except ValueError: + return False + + +def get_or_create_node_id(): + NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id" + try: + if NODE_ID_FILE.is_file(): + with open(NODE_ID_FILE, "r") as f: + stored_id = f.read().strip() + if is_valid_uuid(stored_id): + if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}") + return stored_id + else: + if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.") + + new_id = str(uuid.uuid4()) + with open(NODE_ID_FILE, "w") as f: + f.write(new_id) + + if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}") + return new_id + except IOError as e: + if DEBUG >= 2: print(f"IO error creating node_id: {e}") + return str(uuid.uuid4()) + except Exception as e: + if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}") + return str(uuid.uuid4()) + + +def pretty_print_bytes(size_in_bytes: int) -> str: + if size_in_bytes < 1024: + return f"{size_in_bytes} B" + elif size_in_bytes < 1024**2: + return f"{size_in_bytes / 1024:.2f} KB" + elif size_in_bytes < 1024**3: + return f"{size_in_bytes / (1024 ** 2):.2f} MB" + elif size_in_bytes < 1024**4: + return f"{size_in_bytes / (1024 ** 3):.2f} GB" + else: + return f"{size_in_bytes / (1024 ** 4):.2f} TB" + + +def pretty_print_bytes_per_second(bytes_per_second: int) -> str: + if bytes_per_second < 1024: + return f"{bytes_per_second} B/s" + elif bytes_per_second < 1024**2: + return f"{bytes_per_second / 1024:.2f} KB/s" + elif bytes_per_second < 1024**3: + return f"{bytes_per_second / (1024 ** 2):.2f} MB/s" + elif bytes_per_second < 1024**4: + return f"{bytes_per_second / (1024 ** 3):.2f} GB/s" + else: + return f"{bytes_per_second / (1024 ** 4):.2f} TB/s" + + +def get_all_ip_addresses(): + try: + ip_addresses = [] + for interface in netifaces.interfaces(): + ifaddresses = netifaces.ifaddresses(interface) + if netifaces.AF_INET in ifaddresses: + for link in ifaddresses[netifaces.AF_INET]: + ip = link['addr'] + ip_addresses.append(ip) + return list(set(ip_addresses)) + except: + if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.") + return ["localhost"] diff --git a/build/lib/exo/inference/__init__.py b/build/lib/exo/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/debug_inference_engine.py b/build/lib/exo/inference/debug_inference_engine.py new file mode 100644 index 00000000..27bcb592 --- /dev/null +++ b/build/lib/exo/inference/debug_inference_engine.py @@ -0,0 +1,59 @@ +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine +import asyncio +import numpy as np + + +# An inference engine should work the same for any number of Shards, as long as the Shards are continuous. +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): + from exo.inference.tinygrad.inference import Tokenizer + from pathlib import Path + + _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model")) + + prompt = "In a single word only, what is the last name of the president of the United States? " + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), + input_data=resp_full, + inference_state=inference_state_full, + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), + input_data=resp1, + inference_state=inference_state_1, + ) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), + input_data=resp2, + inference_state=inference_state_2, + ) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), + input_data=resp3, + inference_state=inference_state_3, + ) + + print(f"{resp2=}") + print(f"full: {_tokenizer.decode(resp_full)}") + print(f"next full: {_tokenizer.decode(next_resp_full)}") + print(f"resp2: {_tokenizer.decode(resp2)}") + print(f"{resp4=}") + print(f"resp4: {_tokenizer.decode(resp4)}") + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + + +asyncio.run(test_inference_engine( + TinygradDynamicShardInferenceEngine(), + TinygradDynamicShardInferenceEngine(), + "llama3-8b-sfr", +)) diff --git a/build/lib/exo/inference/inference_engine.py b/build/lib/exo/inference/inference_engine.py new file mode 100644 index 00000000..2b98adbe --- /dev/null +++ b/build/lib/exo/inference/inference_engine.py @@ -0,0 +1,34 @@ +import numpy as np +import os + +from typing import Tuple, Optional +from abc import ABC, abstractmethod +from .shard import Shard + + +class InferenceEngine(ABC): + @abstractmethod + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + pass + + @abstractmethod + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]: + pass + + +def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'): + if inference_engine_name == "mlx": + from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine + + return MLXDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "tinygrad": + from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine + import tinygrad.helpers + tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) + + return TinygradDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "pytorch": + from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine + return PyTorchDynamicShardInferenceEngine(shard_downloader) + else: + raise ValueError(f"Inference engine {inference_engine_name} not supported") diff --git a/build/lib/exo/inference/mlx/__init__.py b/build/lib/exo/inference/mlx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/mlx/models/__init__.py b/build/lib/exo/inference/mlx/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/mlx/models/base.py b/build/lib/exo/inference/mlx/models/base.py new file mode 100644 index 00000000..a1f1878c --- /dev/null +++ b/build/lib/exo/inference/mlx/models/base.py @@ -0,0 +1,9 @@ +from typing import Optional +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import KVCache + + +class IdentityBlock(nn.Module): + def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array: + return x diff --git a/build/lib/exo/inference/mlx/models/deepseek_v2.py b/build/lib/exo/inference/mlx/models/deepseek_v2.py new file mode 100644 index 00000000..9ea271ed --- /dev/null +++ b/build/lib/exo/inference/mlx/models/deepseek_v2.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass, field +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import KVCache +from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer +from .base import IdentityBlock +from exo.inference.shard import Shard + + +@dataclass +class ModelArgs(ModelArgs): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, Shard): + return + if not isinstance(self.shard, dict): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + self.shard = Shard(**self.shard) + + +class DeepseekV2Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.num_hidden_layers = config.num_hidden_layers + self.vocab_size = config.vocab_size + if self.args.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = [] + for i in range(self.num_hidden_layers): + if self.args.shard.start_layer <= i <= self.args.shard.end_layer: + self.layers.append(DeepseekV2DecoderLayer(config, i)) + else: + self.layers.append(IdentityBlock()) + + if self.args.shard.is_last_layer(): + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + x: mx.array, + cache: Optional[KVCache] = None, + ) -> mx.array: + if self.args.shard.is_first_layer(): + h = self.embed_tokens(x) + else: + h = x + + mask = None + T = h.shape[1] + if T > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(T) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None]*len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + if self.args.shard.is_last_layer(): + h = self.norm(h) + return h + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.model_type = config.model_type + self.model = DeepseekV2Model(config) + if self.args.shard.is_last_layer(): + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[KVCache] = None, + ): + out = self.model(inputs, cache) + if self.args.shard.is_last_layer(): + return self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + + for key, value in weights.items(): + if key.startswith('model.layers.'): + layer_num = int(key.split('.')[2]) + if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: + shard_state_dict[key] = value + elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')): + shard_state_dict[key] = value + + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict: + to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)] + shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + + return shard_state_dict + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return ( + self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, + self.args.v_head_dim, + ) + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/build/lib/exo/inference/mlx/models/llama.py b/build/lib/exo/inference/mlx/models/llama.py new file mode 100644 index 00000000..719d6a88 --- /dev/null +++ b/build/lib/exo/inference/mlx/models/llama.py @@ -0,0 +1,125 @@ +from dataclasses import dataclass, field + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.llama import TransformerBlock, ModelArgs + +from ...shard import Shard +from .base import IdentityBlock + + +@dataclass +class ModelArgs(ModelArgs): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + super().__post_init__() # Ensure parent initializations are respected + + if isinstance(self.shard, Shard): + return + if not isinstance(self.shard, dict): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + self.shard = Shard(**self.shard) + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + if self.args.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [] + for i in range(self.num_hidden_layers): + if self.args.shard.start_layer <= i <= self.args.shard.end_layer: + self.layers.append(TransformerBlock(args=args)) + else: + self.layers.append(IdentityBlock()) + if self.args.shard.is_last_layer(): + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + if self.args.shard.is_first_layer(): + h = self.embed_tokens(inputs) + else: + h = inputs + + mask = None + if h.shape[1] > 1: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None]*len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + if self.args.shard.is_last_layer(): + h = self.norm(h) + return h + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if self.args.shard.is_last_layer(): + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.shard.is_last_layer(): + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + + for key, value in weights.items(): + if "self_attn.rotary_emb.inv_freq" in key: + continue + if key.startswith('model.layers.'): + layer_num = int(key.split('.')[2]) + if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: + shard_state_dict[key] = value + elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'): + shard_state_dict[key] = value + elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'): + shard_state_dict[key] = value + elif self.args.shard.is_last_layer() and (key.startswith('model.norm')): + shard_state_dict[key] = value + + return shard_state_dict + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads) + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads diff --git a/build/lib/exo/inference/mlx/models/llava.py b/build/lib/exo/inference/mlx/models/llava.py new file mode 100644 index 00000000..b734b09b --- /dev/null +++ b/build/lib/exo/inference/mlx/models/llava.py @@ -0,0 +1,585 @@ +# Copyright © 2024 Apple Inc. + +import math +import inspect +from dataclasses import dataclass, field +from typing import Optional, Dict, Union + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import BaseModelArgs, KVCache +from exo.inference.shard import Shard +from .base import IdentityBlock +import numpy as np + + +@dataclass +class VisionConfig: + model_type: str + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + + @classmethod + def from_dict(cls, params): + return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) + + +class VisionAttention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError("The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0") + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + scale = math.sqrt(1/queries.shape[-1]) + scores = (queries*scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + +class VisionMLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = nn.GELU(approx="fast") + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x + + +class VisionEncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y + + +class VisionEncoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = mx.zeros((config.hidden_size,)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + patch_embeddings = self.patch_embedding(x) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim)) + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + embeddings += self.position_embedding.weight + return embeddings + + +class ClipVisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = VisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) + + encoder_states = (x,) if output_hidden_states else None + + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) + + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states + + +class VisionModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + + self.model_type = config.model_type + if self.model_type != "clip_vision_model": + raise ValueError(f"Unsupported model type: {self.model_type}") + + self.vision_model = ClipVisionModel(config) + + def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array: + return self.vision_model(x, output_hidden_states) + + def sanitize(self, weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + # Remove unused position_ids + continue + elif "patch_embedding.weight" in k: + # PyTorch conv2d weight tensors have shape: + # [out_channels, in_channels, kH, KW] + # MLX conv2d expects the weight be of shape: + # [out_channels, kH, KW, in_channels] + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights + + +@dataclass +class TextConfig: + model_type: str + hidden_size: int = 4096 + num_hidden_layers: int = 32 + intermediate_size: int = 11008 + num_attention_heads: int = 32 + head_dim: int = None + rms_norm_eps: float = 1e-6 + vocab_size: int = 32000 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + @classmethod + def from_dict(cls, params): + return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}) + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.head_dim is None: + self.head_dim = self.hidden_size // self.num_attention_heads + + if self.model_type is None: + self.model_type = "llama" + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + +class TextAttention(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + + dim = config.hidden_size + self.n_heads = n_heads = config.num_attention_heads + self.n_kv_heads = n_kv_heads = config.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = config.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False) + self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False) + + rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1) + self.rope = nn.RoPE( + head_dim, + traditional=config.rope_traditional, + base=config.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class TextMLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TextConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.self_attn = TextAttention(config) + self.mlp = TextMLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class Llama(nn.Module): + def __init__(self, config: TextConfig, shard: Shard): + super().__init__() + self.config = config + self.shard = shard + self.vocab_size = config.vocab_size + self.model_type = config.model_type + self.num_hidden_layers = config.num_hidden_layers + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + assert self.vocab_size > 0 + if self.shard.is_first_layer(): + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [] + for i in range(self.num_hidden_layers): + if self.shard.start_layer <= i <= self.shard.end_layer: + self.layers.append(TransformerBlock(config=config)) + else: + self.layers.append(IdentityBlock()) + if self.shard.is_last_layer(): + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + # for passing merged input embeddings + if inputs_embeds is None: + if self.shard.is_first_layer(): + h = self.embed_tokens(inputs) + else: + h = inputs + else: + h = inputs_embeds + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None]*len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + if self.shard.is_last_layer(): + h = self.norm(h) + return h + + +class LanguageModel(nn.Module): + def __init__(self, config: TextConfig, shard: Shard): + super().__init__() + self.model_type = config.model_type + if self.model_type != "llama": + raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported") + self.shard = shard + self.model = Llama(config, shard) + if self.shard.is_last_layer(): + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + inputs_embeds=None, + ): + out = self.model(inputs, cache, inputs_embeds) + if self.shard.is_last_layer(): + out = self.lm_head(out) + return out + + def sanitize(self, weights): + shard_state_dict = {} + for key, value in weights.items(): + if "self_attn.rotary_emb.inv_freq" in key: + continue + + if key.startswith('language_model.model.layers.'): + layer_num = int(key.split('.')[3]) + if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer: + continue + if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'): + continue + elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')): + continue + + shard_state_dict[key] = value + + return shard_state_dict + + +@dataclass +class LlaVAConfig(BaseModelArgs): + text_config: TextConfig + vision_config: VisionConfig = None + model_type: str = "llava" + ignore_index: int = -100 + image_token_index: int = 32000 + vision_feature_select_strategy: str = "default" + vision_feature_layer: int = -2 + vocab_size: int = 32000 + + @classmethod + def from_dict(cls, params): + updated_params = {} + class_params = inspect.signature(cls).parameters + for k, v in params.items(): + if k in class_params: + if k in ["text_config", "vision_config"]: + v = class_params[k].annotation.from_dict(v) + updated_params.update({k: v}) + + return cls(**updated_params) + + +@dataclass +class ModelArgs(LlaVAConfig): + shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) + + def __post_init__(self): + if isinstance(self.shard, dict): + self.shard = Shard(**self.shard) + + if not isinstance(self.shard, Shard): + raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") + + if not self.shard.is_first_layer(): + self.vision_config = None + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlaVAConfig): + super().__init__() + self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.gelu = nn.GELU() + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + x = self.linear_1(x) + x = self.gelu(x) + x = self.linear_2(x) + return x + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.model_type = config.model_type + if config.vision_config: + self.vision_tower = VisionModel(config.vision_config) + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.vision_feature_layer = config.vision_feature_layer + self.vision_feature_select_strategy = config.vision_feature_select_strategy + self.language_model = LanguageModel(config.text_config, config.shard) + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + ): + if pixel_values is None: + return self.language_model(input_ids) + + # Get the input embeddings from the language model + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + # Get the ouptut hidden states from the vision model + *_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True) + + # Select the hidden states from the desired layer + selected_image_feature = hidden_states[self.vision_feature_layer] + + if self.vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError("Unexpected feature selection strategy: " + f"{self.vision_feature_select_strategy}") + + # Pass image features through the multi-modal projector + image_features = self.multi_modal_projector(selected_image_feature) + + # Insert special image tokens in the input_ids + final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids) + return final_inputs_embeds + + def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids): + image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + # Positions of tokens in input_ids, assuming batch size is 1 + image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + + if len(image_positions) != num_images: + raise ValueError(f"The number of image tokens ({len(image_positions)}) does not " + f" match the number of image inputs ({num_images}).") + + text_segments = [] + start_idx = 0 + + for position in image_positions: + text_segments.append(inputs_embeds[:, start_idx:position]) + start_idx = position + 1 + + image_embeddings = mx.split(image_features, image_features.shape[0]) + final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] + final_embeddings += [inputs_embeds[:, start_idx:]] + + # Create a final embedding of shape + # (1, num_image_patches*num_images + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=1) + + def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None): + input_embddings = None + if pixel_values is not None: + input_embddings = self.get_input_embeddings(input_ids, pixel_values) + logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings) + return logits + + def sanitize(self, weights): + if self.config.vision_config: + weights = self.vision_tower.sanitize(weights) + else: + weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))} + weights = self.language_model.sanitize(weights) + return weights + + @property + def layers(self): + return self.language_model.model.layers + + @property + def head_dim(self): + return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads) + + @property + def n_kv_heads(self): + return self.language_model.model.num_key_value_heads diff --git a/build/lib/exo/inference/mlx/sharded_inference_engine.py b/build/lib/exo/inference/mlx/sharded_inference_engine.py new file mode 100644 index 00000000..40cabfeb --- /dev/null +++ b/build/lib/exo/inference/mlx/sharded_inference_engine.py @@ -0,0 +1,40 @@ +import numpy as np +import mlx.core as mx +from ..inference_engine import InferenceEngine +from .sharded_model import StatefulShardedModel +from .sharded_utils import load_shard, get_image_from_str +from ..shard import Shard +from typing import Optional +from exo.download.shard_download import ShardDownloader + + +class MLXDynamicShardInferenceEngine(InferenceEngine): + def __init__(self, shard_downloader: ShardDownloader): + self.shard = None + self.shard_downloader = shard_downloader + + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + await self.ensure_shard(shard) + if image_str: + image = await get_image_from_str(image_str) + inputs = self.tokenizer(prompt, image, return_tensors="np") + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, input_ids, pixel_values)) + else: + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))) + return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id + + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool): + await self.ensure_shard(shard) + output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data))) + return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id + + async def ensure_shard(self, shard: Shard): + if self.shard == shard: + return + + model_path = await self.shard_downloader.ensure_shard(shard) + model_shard, self.tokenizer = await load_shard(model_path, shard) + self.stateful_sharded_model = StatefulShardedModel(shard, model_shard) + self.shard = shard diff --git a/build/lib/exo/inference/mlx/sharded_model.py b/build/lib/exo/inference/mlx/sharded_model.py new file mode 100644 index 00000000..c4570fbf --- /dev/null +++ b/build/lib/exo/inference/mlx/sharded_model.py @@ -0,0 +1,86 @@ +from typing import Dict, Generator, Optional, Tuple +from collections import OrderedDict + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import KVCache, RotatingKVCache +from mlx_lm.sample_utils import top_p_sampling + +from ..shard import Shard + + +class StatefulShardedModel: + def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2): + self.shard = shard + self.model = model + self.max_kv_size = max_kv_size + self.max_caches = max_caches + self.caches = OrderedDict() + + def step( + self, + request_id: str, + x, + pixel_values=None, + temp: float = 0.0, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, + ) -> Generator[Tuple[mx.array, mx.array], None, None]: + def sample(logits: mx.array) -> Tuple[mx.array, float]: + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + logits[:, indices] += values + + if temp == 0: + token = mx.argmax(logits, axis=-1) + else: + if top_p > 0 and top_p < 1.0: + token = top_p_sampling(logits, top_p, temp) + else: + token = mx.random.categorical(logits*(1/temp)) + + return token + + y = x + + if request_id not in self.caches: + self.init_cache(request_id) + else: + self.caches.move_to_end(request_id) + + cache = self.caches[request_id] + + if pixel_values is None: + output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache) + else: + output = self.model(y, pixel_values=pixel_values, cache=cache) + + if self.shard.is_last_layer(): + logits = output[:, -1, :] + y = sample(logits) + return y + else: + return output + + def __call__( + self, + request_id: str, + x, + temp: float = 0.0, + top_p: float = 1.0, + logit_bias: Optional[Dict[int, float]] = None, + ) -> Generator[Tuple[mx.array, mx.array], None, None]: + return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias) + + def init_cache(self, request_id: str): + kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads) + if self.max_kv_size is not None: + cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads] + else: + cache = [KVCache(self.model.head_dim, n) for n in kv_heads] + + if len(self.caches) >= self.max_caches: + self.caches.popitem(last=False) + + self.caches[request_id] = cache diff --git a/build/lib/exo/inference/mlx/sharded_utils.py b/build/lib/exo/inference/mlx/sharded_utils.py new file mode 100644 index 00000000..7fa38eaa --- /dev/null +++ b/build/lib/exo/inference/mlx/sharded_utils.py @@ -0,0 +1,207 @@ +# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py + +import glob +import importlib +import json +import logging +import asyncio +import aiohttp +from functools import partial +from pathlib import Path +from typing import Optional, Tuple, Union, List, Callable +from PIL import Image +from io import BytesIO +import base64 + +import mlx.core as mx +import mlx.nn as nn +from transformers import AutoProcessor + +from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper +from mlx_lm.tuner.utils import apply_lora_layers + +from exo import DEBUG +from ..shard import Shard + + +class ModelNotFoundError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +MODEL_REMAPPING = { + "mistral": "llama", # mistral is compatible with llama + "phi-msft": "phixtral", +} + + +def _get_classes(config: dict): + """ + Retrieve the model and model args classes based on the configuration. + + Args: + config (dict): The model configuration. + + Returns: + A tuple containing the Model class and the ModelArgs class. + """ + model_type = config["model_type"] + model_type = MODEL_REMAPPING.get(model_type, model_type) + try: + arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}") + except ImportError: + msg = f"Model type {model_type} not supported." + logging.error(msg) + raise ValueError(msg) + + return arch.Model, arch.ModelArgs + + +def load_config(model_path: Path) -> dict: + try: + with open(model_path/"config.json", "r") as f: + config = json.load(f) + except FileNotFoundError: + logging.error(f"Config file not found in {model_path}") + raise + return config + + +def load_model_shard( + model_path: Path, + shard: Shard, + lazy: bool = False, + model_config: dict = {}, +) -> nn.Module: + """ + Load and initialize the model from a given path. + + Args: + model_path (Path): The path to load the model from. + lazy (bool): If False eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + model_config(dict, optional): Configuration parameters for the model. + Defaults to an empty dictionary. + + Returns: + nn.Module: The loaded and initialized model. + + Raises: + FileNotFoundError: If the weight files (.safetensors) are not found. + ValueError: If the model class or args class are not found or cannot be instantiated. + """ + config = load_config(model_path) + config.update(model_config) + + # TODO hack + config["shard"] = { + "model_id": model_path.name, + "start_layer": shard.start_layer, + "end_layer": shard.end_layer, + "n_layers": shard.n_layers, + } + + weight_files = glob.glob(str(model_path/"model*.safetensors")) + + if not weight_files: + # Try weight for back-compat + weight_files = glob.glob(str(model_path/"weight*.safetensors")) + + if not weight_files: + logging.error(f"No safetensors found in {model_path}") + raise FileNotFoundError(f"No safetensors found in {model_path}") + + weights = {} + for wf in sorted(weight_files): + if DEBUG >= 8: + layer_nums = set() + for k in mx.load(wf): + if k.startswith("model.layers."): + layer_num = int(k.split(".")[2]) + layer_nums.add(layer_num) + if k.startswith("language_model.model.layers."): + layer_num = int(k.split(".")[3]) + layer_nums.add(layer_num) + print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},") + + weights.update(mx.load(wf)) + + model_class, model_args_class = _get_classes(config=config) + + model_args = model_args_class.from_dict(config) + model = model_class(model_args) + + if hasattr(model, "sanitize"): + weights = model.sanitize(weights) + + if (quantization := config.get("quantization", None)) is not None: + # Handle legacy models which may not have everything quantized + def class_predicate(p, m): + if not hasattr(m, "to_quantized"): + return False + return f"{p}.scales" in weights + + nn.quantize( + model, + **quantization, + class_predicate=class_predicate, + ) + + model.load_weights(list(weights.items()), strict=True) + + if not lazy: + mx.eval(model.parameters()) + + model.eval() + return model + + +async def load_shard( + model_path: str, + shard: Shard, + tokenizer_config={}, + model_config={}, + adapter_path: Optional[str] = None, + lazy: bool = False, +) -> Tuple[nn.Module, TokenizerWrapper]: + model = load_model_shard(model_path, shard, lazy, model_config) + if adapter_path is not None: + model = apply_lora_layers(model, adapter_path) + model.eval() + + # TODO: figure out a generic solution + if model.model_type == "llava": + processor = AutoProcessor.from_pretrained(model_path) + processor.eos_token_id = processor.tokenizer.eos_token_id + processor.encode = processor.tokenizer.encode + return model, processor + else: + tokenizer = load_tokenizer(model_path, tokenizer_config) + return model, tokenizer + + +async def get_image_from_str(_image_str: str): + image_str = _image_str.strip() + + if image_str.startswith("http"): + async with aiohttp.ClientSession() as session: + async with session.get(image_str, timeout=10) as response: + content = await response.read() + return Image.open(BytesIO(content)).convert("RGB") + elif image_str.startswith("data:image/"): + # Extract the image format and base64 data + format_prefix, base64_data = image_str.split(";base64,") + image_format = format_prefix.split("/")[1].lower() + if DEBUG >= 2: print(f"{image_str=} {image_format=}") + imgdata = base64.b64decode(base64_data) + img = Image.open(BytesIO(imgdata)) + + # Convert to RGB if not already + if img.mode != "RGB": + img = img.convert("RGB") + + return img + else: + raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.") diff --git a/build/lib/exo/inference/mlx/test_sharded_llama.py b/build/lib/exo/inference/mlx/test_sharded_llama.py new file mode 100644 index 00000000..1c48b936 --- /dev/null +++ b/build/lib/exo/inference/mlx/test_sharded_llama.py @@ -0,0 +1,40 @@ +import mlx.core as mx +from exo.inference.mlx.sharded_model import StatefulShardedModel +from exo.inference.mlx.sharded_utils import load_shard +from exo.inference.shard import Shard + +# 79, 80 for Llama-3-70B +shard_full = Shard("llama", 0, 31, 32) +shard1 = Shard("llama", 0, 12, 32) +shard2 = Shard("llama", 13, 31, 32) + +full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full) +model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1) +model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2) + +full = StatefulShardedModel(shard_full, full_model_shard) +m1 = StatefulShardedModel(shard1, model_shard1) +m2 = StatefulShardedModel(shard2, model_shard2) + +prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:" +prompt_tokens = mx.array(full_tokenizer.encode(prompt)) +max_tokens = 50 + +resp = prompt_tokens +full_generated_tokens = [] +for _ in range(max_tokens): + resp = full.step(resp) + full_generated_tokens.append(resp.item()) + +print("full response: ", full_tokenizer.decode(full_generated_tokens)) + +sharded_generated_tokens = [] +sharded_resp = prompt_tokens +for _ in range(max_tokens): + resp1 = m1.step(sharded_resp) + sharded_resp = m2.step(resp1) + sharded_generated_tokens.append(sharded_resp.item()) + +print("sharded response: ", tokenizer1.decode(sharded_generated_tokens)) + +assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens) diff --git a/build/lib/exo/inference/mlx/test_sharded_llava.py b/build/lib/exo/inference/mlx/test_sharded_llava.py new file mode 100644 index 00000000..958a5acc --- /dev/null +++ b/build/lib/exo/inference/mlx/test_sharded_llava.py @@ -0,0 +1,64 @@ +import codecs +import asyncio +import requests +from PIL import Image +from io import BytesIO + +import mlx.core as mx +from mlx_lm.models.base import KVCache + +from exo.inference.mlx.sharded_model import StatefulShardedModel +from exo.inference.mlx.sharded_utils import load_shard +from exo.inference.shard import Shard + +shard_full = Shard("llava", 0, 31, 32) +shard1 = Shard("llava", 0, 12, 32) +shard2 = Shard("llava", 13, 31, 32) + +model_path = "llava-hf/llava-1.5-7b-hf" + +full_model_shard, full_processor = asyncio.run(load_shard(model_path, shard=shard_full)) +model_shard1, processor1 = asyncio.run(load_shard(model_path, shard=shard1)) +model_shard2, processor2 = asyncio.run(load_shard(model_path, shard=shard2)) + +full = StatefulShardedModel(shard_full, full_model_shard) +m1 = StatefulShardedModel(shard1, model_shard1) +m2 = StatefulShardedModel(shard2, model_shard2) + +PROMPT = "USER: \nWhat are these?\nASSISTANT:" +IMAGE_FILE = "http://images.cocodataset.org/val2017/000000039769.jpg" +response = requests.get(IMAGE_FILE) +img = Image.open(BytesIO(response.content)) +prompt = codecs.decode(PROMPT, "unicode_escape") +inputs = full_processor(prompt, img, return_tensors="np") +pixel_values = mx.array(inputs["pixel_values"]) +input_ids = mx.array(inputs["input_ids"]) + +print(prompt) +y = full.step("full", input_ids, pixel_values, temp=0) +full_generated_tokens = [y.item()] + +for _ in range(13): + y = full.step("full", y, temp=0) + full_generated_tokens.append(y.item()) + +full_response = full_processor.tokenizer.decode(full_generated_tokens) +print("full response:", full_response) + +inputs = processor1(prompt, img, return_tensors="np") +pixel_values = mx.array(inputs["pixel_values"]) +input_ids = mx.array(inputs["input_ids"]) + +y = m1.step("shard", input_ids, pixel_values, temp=0) +y = m2.step("shard", y, temp=0) +full_generated_tokens = [y.item()] + +for _ in range(13): + y = m1.step("shard", y, temp=0) + y = m2.step("shard", y, temp=0) + full_generated_tokens.append(y.item()) + +sharded_response = processor2.tokenizer.decode(full_generated_tokens) +print("sharded response:", sharded_response) + +assert full_response == sharded_response diff --git a/build/lib/exo/inference/mlx/test_sharded_model.py b/build/lib/exo/inference/mlx/test_sharded_model.py new file mode 100644 index 00000000..c9743d07 --- /dev/null +++ b/build/lib/exo/inference/mlx/test_sharded_model.py @@ -0,0 +1,52 @@ +from exo.inference.shard import Shard +import mlx.core as mx +import mlx.nn as nn +from typing import Optional +import numpy as np + + +class DummyModel(nn.Module): + def __init__(self, shard: Optional[Shard] = None): + self.shard = shard + self.layers = [ + nn.Linear(8, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 128), + nn.Linear(128, 8), + ] + + self.n_kv_heads = 4 + self.head_dim = 4 + + def __call__(self, x, cache=None): + if self.shard: + for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]: + x = layer(x) + if self.shard.is_last_layer(): + x = x.reshape((1, 2, 4)) + else: + for layer in self.layers: + x = layer(x) + x = x.reshape((1, 2, 4)) + + return x + + +model = DummyModel() +model.save_weights("./test_weights.npz") +n_layers = 5 +shard1 = Shard("test", 0, n_layers // 2, n_layers) +sharded_model1 = DummyModel(shard1) +shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers) +sharded_model2 = DummyModel(shard2) + +model.load_weights("./test_weights.npz") +sharded_model1.load_weights("./test_weights.npz") +sharded_model2.load_weights("./test_weights.npz") + +fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8])) +resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8])) +resp2 = sharded_model2(resp1) + +assert np.all(np.array(fullresp) == np.array(resp2)) diff --git a/build/lib/exo/inference/pytorch/__init__.py b/build/lib/exo/inference/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/pytorch/helpers.py b/build/lib/exo/inference/pytorch/helpers.py new file mode 100644 index 00000000..addea2db --- /dev/null +++ b/build/lib/exo/inference/pytorch/helpers.py @@ -0,0 +1,24 @@ +# Helper functions for pytorch inference +# Some code coming from tinygrad but written towards pytorch + +import asyncio +import aiohttp +from tqdm import tqdm +from pathlib import Path +from typing import List + +async def fetch_file_async(session, url: str, output_path: Path): + async with session.get(url) as response: + response.raise_for_status() + with open(output_path, 'wb') as f: + async for chunk in response.content.iter_chunked(8192): + f.write(chunk) + +async def download_files(urls: List[str], output_paths: List[Path]): + async with aiohttp.ClientSession() as session: + tasks = [] + for url, output_path in zip(urls, output_paths): + tasks.append(fetch_file_async(session, url, output_path)) + + for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Downloading files"): + await f diff --git a/build/lib/exo/inference/pytorch/inference.py b/build/lib/exo/inference/pytorch/inference.py new file mode 100644 index 00000000..ba834eb6 --- /dev/null +++ b/build/lib/exo/inference/pytorch/inference.py @@ -0,0 +1,211 @@ +# experimental, based off of tinygrad/inference.py +import numpy as np +import torch +import numpy as np +import json +from typing import Optional, Tuple +from exo.inference.shard import Shard +from exo.inference.inference_engine import InferenceEngine +from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel +from exo.api.chatgpt_api import resolve_tokenizer +from exo.helpers import DEBUG +from transformers import DynamicCache +from accelerate import disk_offload + +class PyTorchDynamicShardInferenceEngine(InferenceEngine): + """ + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models. + """ + + def __init__(self, shard): + """ + Initialize the inference engine. + + Args: + debug (bool): If True, enables debug logging. Defaults to False. + """ + self.shard = shard + self.model = None + self.tokenizer = None + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + async def infer_prompt( + self, + request_id: str, + shard: Optional[Shard] = None, + prompt: str = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + + await self.ensure_shard(shard) + + # need to make this so inference_state is not a string + # cant use it with dynamic cache + + tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device) + tokens = self.model.embed_tokens(tokens) + current_kvs = None + + if DEBUG >= 4: + print("infer_prompt called") + print(f"tokens: {tokens}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] + + output_data, current_kvs = self.model.forward( + tokens, + past_kv + ) + + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + + if DEBUG >= 4: + print(f"output_data: {output_data}\n") + print(f"output_data.size {output_data.size}\n") + + print(f"finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + print(f"output_data[-1] {output_data[-1]}") + + if output_data.size == 1: + print(f"size 1 output_data.item() {output_data.item()}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } + + return ( + output_data, + json.dumps(cache_dict), + is_finished + ) + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + + await self.ensure_shard(shard) + + current_kvs = None + + if input_data.size == 1: + in_tensor = torch.from_numpy( + input_data, + ).unsqueeze(0).long().to(self.device) + else: + in_tensor = torch.from_numpy( + input_data + ).long().to(self.device) + + in_tensor = self.model.embed_tokens(in_tensor) + + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"input_data.size: {input_data.size}") + print(f"input_tensor: {in_tensor}\n") + print(f"shard: {self.shard}") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + # convert inference_state or cache from json to DynamicCache + past_kv = DynamicCache() + if inference_state != None: + try: + cache_dict = json.loads(inference_state) + past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']] + past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']] + + if DEBUG >= 4: + print("Loaded past_kv from JSON") + print(f"past_kv: {past_kv}") + print(f"past_kv.key_cache len: {len(past_kv.key_cache)}") + print(f"past_kv.value_cache len: {len(past_kv.value_cache)}") + except json.JSONDecodeError: + print(f"ERROR DECODING INFERENCE STATE") + + output_data, current_kvs = self.model.forward( + in_tensor, + past_kv + ) + + is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id] + + if DEBUG >= 4: + print(f"in_tensor: {in_tensor}\n") + print(f"output_data: {output_data}\n") + print(f"output_data.size {output_data.size}\n") + print(f"finished: {is_finished}") + print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}") + print(f"output_data[-1] {output_data[-1]}") + + if output_data.size == 1: + print(f"size 1 output_data.item() {output_data.item()}") + print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}") + + + cache_dict = { + 'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache], + 'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache] + } + + return ( + output_data, + json.dumps(cache_dict), + is_finished + ) + + async def ensure_shard(self, shard: Optional[Shard]): + """ + Ensure the model shard is loaded and ready for inference. + + Args: + shard (Optional[Shard]): Shard information for the model. + """ + # if self.shard == shard: + # return + + if DEBUG >= 4: + print(f"Loading new shard: {shard}") + + if self.model: + if DEBUG >= 2: + print(f"\nCLEARING MODEL {shard.model_id}\n") + print(f"before allocated: {torch.cuda.memory_allocated()}") + print(f"before reserved: {torch.cuda.memory_reserved()}") + + # delete model and free up memory to reload + # self.model.cuda() + # disk_offload(model=self.model, offload_dir="./.offload") + import gc + + del self.model + gc.collect() + torch.cuda.empty_cache() + + if DEBUG >= 2: + print(f"after allocated: {torch.cuda.memory_allocated()}") + print(f"after reserved: {torch.cuda.memory_reserved()}") + + self.shard = shard + self.tokenizer = await resolve_tokenizer(shard.model_id) + self.model = ShardedHuggingFaceModel(shard, self.tokenizer) + + if DEBUG >= 4: + print(f"Shard loaded successfully: {shard}") \ No newline at end of file diff --git a/build/lib/exo/inference/pytorch/model/__init__.py b/build/lib/exo/inference/pytorch/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/inference/pytorch/model/hf.py b/build/lib/exo/inference/pytorch/model/hf.py new file mode 100644 index 00000000..aa2873c5 --- /dev/null +++ b/build/lib/exo/inference/pytorch/model/hf.py @@ -0,0 +1,155 @@ +import torch +import numpy as np +from transformers import AutoModelForCausalLM, DynamicCache, Cache +from exo.inference.shard import Shard +from exo.helpers import DEBUG +from typing import Tuple, Optional, Union, List +from exo.inference.pytorch.model.utils import sample_logits + +TOP_P = 0.75 #0.95 +TOP_K = 20 +TEMP = 0.8 + +class ShardedHuggingFaceModel(torch.nn.Module): + def __init__(self, shard: Shard, tokenizer: any): + super(ShardedHuggingFaceModel, self).__init__() + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.shard = shard + self.tokenizer = tokenizer + + # Load the model + try: + self.llm_model = AutoModelForCausalLM.from_pretrained( + shard.model_id, + torch_dtype=torch.float32, + device_map="auto", + # offload_buffers=True + ) + + # disk_offload(model=self.llm_model, offload_dir="./.offload") + + self.base_model = self.llm_model.model + except Exception as err: + print(f"Error loading model: {err}") + raise + + if DEBUG >= 2: + print(f"\nShardedHuggingFaceModel init with shard {shard}") + print(f"self.llm_model: {self.llm_model}") + print(f"self.base_model: {self.base_model}") + + if DEBUG >= 2: + print(f"full_model.model layer: {len(self.base_model.layers)}") + + # Embeddings and final layer norm + # used for doing what forward LlamaModel does in transformers + self.norm = self.base_model.norm + self.lm_head = self.llm_model.lm_head + self.embed_tokens = self.base_model.embed_tokens + + def forward( + self, + input_ids: torch.tensor, + past_kvs: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + ) -> Tuple[np.ndarray, any]: + """ + Forward through layers using the base model + + Args: + input_ids: tensor input + past_kvs: past key value stores for cache + use_cache: use cache + + Returns: + hidden_states: numpy of states between layers + or logits: numpy of normalization and linearization of last hidden state + past_kvs: DynamicCache of past key values if use_cache is true + + Ref: + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/qwen2/modeling_qwen2.py#L804 + https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L887 + """ + if DEBUG >= 4: + print("forward called") + print(f"input_ids: {input_ids}\n") + print(f"layer_count: {self.shard.get_layer_count()}") + print(f"is_first_layer: {self.shard.is_first_layer()}") + print(f"is_last_layer: {self.shard.is_last_layer()}") + + past_kvs = DynamicCache.from_legacy_cache(past_kvs) + past_seen_tokens = past_kvs.get_seq_length() if past_kvs is not None else 0 + + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + input_ids.shape[1], + device=input_ids.device + ).to(self.device) + + position_ids = cache_position.unsqueeze(0).to(self.device) + + try: + position_embeddings = self.base_model.rotary_emb( + input_ids, + position_ids + ) + except Exception as err: + print(f"rotary_emb not found in base_model") + position_embeddings = None + + # progress through layers + for i in range(self.shard.start_layer, self.shard.end_layer + 1): + decoder_layer = self.base_model.layers[i] + + if DEBUG >= 4: + print("Going through layer") + print(f"{decoder_layer}") + print("input_ids") + print(f"{input_ids}") + + layer_outputs = decoder_layer( + input_ids, + position_ids=position_ids if not position_embeddings else None, + position_embeddings=position_embeddings, + past_key_value=past_kvs, + use_cache=True, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + next_kvs = layer_outputs[1] + + if DEBUG >= 3: + print(f"layer_outputs {layer_outputs}") + + if self.shard.is_last_layer(): + hs_norm = self.norm(hidden_states) + hs_lm_head = self.llm_model.lm_head(hs_norm).float() + + # Use the sampling function with default settings + with torch.no_grad(): + output_token = sample_logits( + hs_lm_head[:, -1, :], + TEMP, + TOP_P, + TOP_K + ).numpy(force=True).flatten() + + if DEBUG >= 2: + print(f"hs_norm: {hs_norm}") + print(f"hs_lm_head: {hs_lm_head}") + print(f"output_token: {output_token}") + + return (output_token, next_kvs) + + with torch.no_grad(): + out_hidden_states = hidden_states.numpy(force=True) + + return ( + out_hidden_states, + next_kvs + ) \ No newline at end of file diff --git a/build/lib/exo/inference/pytorch/model/utils.py b/build/lib/exo/inference/pytorch/model/utils.py new file mode 100644 index 00000000..df84b397 --- /dev/null +++ b/build/lib/exo/inference/pytorch/model/utils.py @@ -0,0 +1,83 @@ +import torch +from torch.nn import functional as F + +def top_p_sampling(scaled_logits: torch.Tensor, top_p: float) -> torch.Tensor: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + scaled_logits (torch.Tensor): The scaled logits from the model's output. + top_p (float): The cumulative probability threshold for top-p filtering. + temp (float): Temperature parameter for softmax distribution reshaping. + + Returns: + torch.Tensor: Token selected based on the top-p criterion. + + Ref: + https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py#L67C1-L97C17 + """ + scaled_logits = torch.where(torch.isnan(scaled_logits), torch.zeros_like(scaled_logits), scaled_logits) + scaled_logits = torch.where(torch.isinf(scaled_logits), torch.full_like(scaled_logits, 1e6), scaled_logits) + + probs = torch.softmax(scaled_logits, dim=-1) + + sorted_probs, sorted_indices = torch.sort( + probs, + descending=True, + dim=-1 + ) + + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + mask = cumulative_probs > top_p + + top_probs = torch.where(mask, torch.zeros_like(sorted_probs), sorted_probs) + sum_probs = top_probs.sum(dim=-1, keepdim=True) + top_probs = torch.where(sum_probs > 0, top_probs / sum_probs, torch.ones_like(top_probs) / top_probs.size(-1)) + + if torch.isnan(top_probs).any() or torch.isinf(top_probs).any(): + print("Warning: Top probabilities contain NaN or Inf values after normalization") + top_probs = torch.where(torch.isnan(top_probs) | torch.isinf(top_probs), + 1.0 / top_probs.size(-1), + top_probs) + + sorted_token = torch.multinomial(top_probs, num_samples=1) + + token = sorted_indices.gather(-1, sorted_token) + + return token.squeeze(-1) + +def sample_logits(logits, temp, top_p, top_k): + """ + Sample tokens from logits using temperature, top-k, and top-p (nucleus) sampling. + + Args: + logits (torch.Tensor): The logits distribution to sample from. + temp (float): temp for scaling logits. + top_p (float): The cumulative probability threshold for nucleus sampling. + + Returns: + torch.Tensor: The selected token index. + """ + + # Ensure logits are float + logits = logits.float() + + # If temp is very low, just use argmax + if temp == 0: + return logits.argmax(dim=-1) + + scaled_logits = logits/temp + + # top k + if top_k > 0: + top_values, top_indices = torch.topk(scaled_logits, top_k, dim=-1) + scaled_logits = torch.zeros_like(logits).scatter_(-1, top_indices, top_values) + + # Top-p sampling + if 0 < top_p < 1.0: + return top_p_sampling(scaled_logits, top_p) + else: + # random distribution selection + probs = torch.softmax(scaled_logits, dim=-1) + rand_sample = torch.distributions.Categorical(probs) + return rand_sample.sample().squeeze() \ No newline at end of file diff --git a/build/lib/exo/inference/pytorch/test_inference_engine.py b/build/lib/exo/inference/pytorch/test_inference_engine.py new file mode 100644 index 00000000..bacf53bc --- /dev/null +++ b/build/lib/exo/inference/pytorch/test_inference_engine.py @@ -0,0 +1,141 @@ + +import asyncio +from exo.inference.shard import Shard +from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.helpers import DEBUG +import os +import numpy as np + +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): + # prompt = "Why is the sky blue?" + prompt = "In a single word only, what is the last name of the current president of the USA?" + + # shard = Shard( + # model_id=model_id, + # start_layer=0, + # end_layer=n_layers-1, + # n_layers=n_layers + # ) + + # resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + # "A", + # shard=shard, + # prompt=prompt + # ) + + # print(f"resp_full: {resp_full}") + + # next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + # "A", + # shard=shard, + # input_data=resp_full, + # inference_state=inference_state_full, + # ) + + # print(f"next_resp_full: {next_resp_full}") + + pp = int(n_layers/2) + + resp_shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=pp, + n_layers=n_layers + ) + + resp_shard2 = Shard( + model_id=model_id, + start_layer=pp + 1, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( + "B", + shard=resp_shard, + prompt=prompt + ) + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp1, + inference_state=inference_state_1, + ) + + # resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + # "B", + # shard=resp_shard, + # input_data=resp2, + # inference_state=inference_state_2, + # ) + + # resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + # "B", + # shard=resp_shard2, + # input_data=resp3, + # inference_state=inference_state_3, + # ) + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + +if __name__ == '__main__': + # try: + # print(f"\n\n -------- TEST QWEN2 -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "andrijdavid/Llama3-1B-Base", + # 3 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "meta-llama/Meta-Llama-3.1-8B", + # 32 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! META LLAMA 3.1 8B TEST FAILED \n{err}\n") + + # try: + # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Chickaboo/ChickaQ-Large", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") + + try: + print(f"\n\n --------- TEST ambrosfitz/TinyLlama-1.1B-Chat-yawp -------\n\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "ambrosfitz/TinyLlama-1.1B-Chat-yawp", + 22 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! ambrosfitz/TinyLlama-1.1B-Chat-yawp TEST FAILED \n{err}\n") + diff --git a/build/lib/exo/inference/shard.py b/build/lib/exo/inference/shard.py new file mode 100644 index 00000000..21b662f6 --- /dev/null +++ b/build/lib/exo/inference/shard.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class Shard: + model_id: str + start_layer: int + end_layer: int + n_layers: int + + def __hash__(self): + return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers)) + + def is_first_layer(self) -> bool: + return self.start_layer == 0 + + def is_last_layer(self) -> bool: + return self.end_layer == self.n_layers - 1 + + def get_layer_count(self) -> int: + return self.end_layer - self.start_layer + 1 + + def to_dict(self) -> dict: + return { + "model_id": self.model_id, + "start_layer": self.start_layer, + "end_layer": self.end_layer, + "n_layers": self.n_layers, + } + + def from_dict(data: dict) -> 'Shard': + return Shard(**data) + + def overlaps(self, other: 'Shard') -> bool: + return shards_overlap(self, other) + + +def shards_overlap(shard1: Shard, shard2: Shard) -> bool: + return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)) diff --git a/build/lib/exo/inference/test_inference_engine.py b/build/lib/exo/inference/test_inference_engine.py new file mode 100644 index 00000000..e57c608d --- /dev/null +++ b/build/lib/exo/inference/test_inference_engine.py @@ -0,0 +1,64 @@ +from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.inference.inference_engine import InferenceEngine +from exo.inference.shard import Shard +from exo.helpers import DEBUG +import os +import asyncio +import numpy as np + + +# An inference engine should work the same for any number of Shards, as long as the Shards are continuous. +async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): + prompt = "In a single word only, what is the last name of the current president of the USA?" + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), + input_data=resp_full, + inference_state=inference_state_full, + ) + + pp = 15 + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt) + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + input_data=resp1, + inference_state=inference_state_1, + ) + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), + input_data=resp2, + inference_state=inference_state_2, + ) + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32), + input_data=resp3, + inference_state=inference_state_3, + ) + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) + + +asyncio.run(test_inference_engine( + MLXDynamicShardInferenceEngine(HFShardDownloader()), + MLXDynamicShardInferenceEngine(HFShardDownloader()), + "mlx-community/Meta-Llama-3-8B-Instruct-4bit", +)) + +if os.getenv("RUN_TINYGRAD", default="0") == "1": + import tinygrad + import os + from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine + tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) + asyncio.run( + test_inference_engine( + TinygradDynamicShardInferenceEngine(HFShardDownloader()), + TinygradDynamicShardInferenceEngine(HFShardDownloader()), + "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", + ) + ) diff --git a/build/lib/exo/inference/tokenizers.py b/build/lib/exo/inference/tokenizers.py new file mode 100644 index 00000000..9accd943 --- /dev/null +++ b/build/lib/exo/inference/tokenizers.py @@ -0,0 +1,45 @@ +import traceback +from aiofiles import os as aios +from transformers import AutoTokenizer, AutoProcessor +from exo.download.hf.hf_helpers import get_local_snapshot_dir +from exo.helpers import DEBUG + +async def resolve_tokenizer(model_id: str): + local_path = await get_local_snapshot_dir(model_id) + if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}") + try: + if await aios.path.exists(local_path): + if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}") + return await _resolve_tokenizer(local_path) + except: + if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...") + if DEBUG >= 5: traceback.print_exc() + return await _resolve_tokenizer(model_id) + +async def _resolve_tokenizer(model_id_or_local_path: str): + try: + if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}") + if "Mistral-Large" in str(model_id_or_local_path): + use_fast = True + else: + use_fast = False + processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=use_fast) + if not hasattr(processor, 'eos_token_id'): + processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id + if not hasattr(processor, 'encode'): + processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode + if not hasattr(processor, 'decode'): + processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode + return processor + except Exception as e: + if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}") + if DEBUG >= 4: print(traceback.format_exc()) + + try: + if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}") + return AutoTokenizer.from_pretrained(model_id_or_local_path) + except Exception as e: + if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}") + if DEBUG >= 4: print(traceback.format_exc()) + + raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}") diff --git a/build/lib/exo/models.py b/build/lib/exo/models.py new file mode 100644 index 00000000..137b881c --- /dev/null +++ b/build/lib/exo/models.py @@ -0,0 +1,44 @@ +from exo.inference.shard import Shard + +model_base_shards = { + ### llama + "llama-3.1-8b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Meta-Llama-3.1-8B", start_layer=0, end_layer=0, n_layers=32), + }, + "llama-3.1-70b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),}, + "llama-3-8b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32), + "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32), + }, + "llama-3-70b": { + "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), + "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80), + }, + "llama-3-2B-Base": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-2B-Base", start_layer=0, end_layer=0, n_layers=6), + }, + "llama-3-1B-Base": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), + }, + "TinyLlama-1.1B-Chat-yaw": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="ambrosfitz/TinyLlama-1.1B-Chat-yawp", start_layer=0, end_layer=0, n_layers=22), + }, + ### mistral + "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, + "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, + ### deepseek v2 + "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),}, + ### llava + "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, + ### qwen + "Qwen2-0.5B-Instruct": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), + }, + +} diff --git a/build/lib/exo/networking/__init__.py b/build/lib/exo/networking/__init__.py new file mode 100644 index 00000000..44a10a30 --- /dev/null +++ b/build/lib/exo/networking/__init__.py @@ -0,0 +1,5 @@ +from .discovery import Discovery +from .peer_handle import PeerHandle +from .server import Server + +__all__ = ["Discovery", "PeerHandle", "Server"] diff --git a/build/lib/exo/networking/discovery.py b/build/lib/exo/networking/discovery.py new file mode 100644 index 00000000..cdcbfabc --- /dev/null +++ b/build/lib/exo/networking/discovery.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from typing import List +from .peer_handle import PeerHandle + + +class Discovery(ABC): + @abstractmethod + async def start(self) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass + + @abstractmethod + async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: + pass diff --git a/build/lib/exo/networking/grpc/__init__.py b/build/lib/exo/networking/grpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/networking/grpc/grpc_discovery.py b/build/lib/exo/networking/grpc/grpc_discovery.py new file mode 100644 index 00000000..eb08a838 --- /dev/null +++ b/build/lib/exo/networking/grpc/grpc_discovery.py @@ -0,0 +1,188 @@ +import asyncio +import json +import socket +import time +from typing import List, Dict, Callable, Tuple, Coroutine +from ..discovery import Discovery +from ..peer_handle import PeerHandle +from .grpc_peer_handle import GRPCPeerHandle +from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES +from exo import DEBUG_DISCOVERY + + +class ListenProtocol(asyncio.DatagramProtocol): + def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]): + super().__init__() + self.on_message = on_message + self.loop = asyncio.get_event_loop() + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + asyncio.create_task(self.on_message(data, addr)) + + +class GRPCDiscovery(Discovery): + def __init__( + self, + node_id: str, + node_port: int, + listen_port: int, + broadcast_port: int = None, + broadcast_interval: int = 1, + device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, + discovery_timeout: int = 30, + ): + self.node_id = node_id + self.node_port = node_port + self.device_capabilities = device_capabilities + self.listen_port = listen_port + self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port + self.broadcast_interval = broadcast_interval + self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {} + self.broadcast_task = None + self.listen_task = None + self.cleanup_task = None + self.discovery_timeout = discovery_timeout + + async def start(self): + self.device_capabilities = device_capabilities() + self.broadcast_task = asyncio.create_task(self.task_broadcast_presence()) + self.listen_task = asyncio.create_task(self.task_listen_for_peers()) + self.cleanup_task = asyncio.create_task(self.task_cleanup_peers()) + + async def stop(self): + if self.broadcast_task: + self.broadcast_task.cancel() + if self.listen_task: + self.listen_task.cancel() + if self.cleanup_task: + self.cleanup_task.cancel() + if self.broadcast_task or self.listen_task or self.cleanup_task: + await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True) + + async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: + if DEBUG_DISCOVERY >= 2: + print("Starting peer discovery process...") + + if wait_for_peers > 0: + while len(self.known_peers) == 0: + if DEBUG_DISCOVERY >= 2: + print("No peers discovered yet, retrying in 1 second...") + await asyncio.sleep(1) # Keep trying to find peers + if DEBUG_DISCOVERY >= 2: + print(f"Discovered first peer: {next(iter(self.known_peers.values()))}") + + grace_period = 5 # seconds + while True: + initial_peer_count = len(self.known_peers) + if DEBUG_DISCOVERY >= 2: + print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...") + if len(self.known_peers) == initial_peer_count: + if wait_for_peers > 0: + await asyncio.sleep(grace_period) + if DEBUG_DISCOVERY >= 2: + print(f"Waiting additional {wait_for_peers} seconds for more peers.") + wait_for_peers = 0 + else: + if DEBUG_DISCOVERY >= 2: + print("No new peers discovered in the last grace period. Ending discovery process.") + break # No new peers found in the grace period, we are done + + return [peer_handle for peer_handle, _, _ in self.known_peers.values()] + + async def task_broadcast_presence(self): + transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET) + sock = transport.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + + message = json.dumps({ + "type": "discovery", + "node_id": self.node_id, + "grpc_port": self.node_port, + "device_capabilities": self.device_capabilities.to_dict(), + }).encode("utf-8") + + while True: + try: + if DEBUG_DISCOVERY >= 3: + print(f"Broadcast presence: {message}") + transport.sendto(message, ("", self.broadcast_port)) + await asyncio.sleep(self.broadcast_interval) + except Exception as e: + print(f"Error in broadcast presence: {e}") + import traceback + + print(traceback.format_exc()) + + async def on_listen_message(self, data, addr): + if not data: + return + + decoded_data = data.decode("utf-8", errors="ignore") + + # Check if the decoded data starts with a valid JSON character + if not (decoded_data.strip() and decoded_data.strip()[0] in "{["): + if DEBUG_DISCOVERY >= 2: + print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}") + return + + try: + decoder = json.JSONDecoder(strict=False) + message = decoder.decode(decoded_data) + except json.JSONDecodeError as e: + if DEBUG_DISCOVERY >= 2: + print(f"Error decoding JSON data from {addr}: {e}") + return + + if DEBUG_DISCOVERY >= 2: + print(f"received from peer {addr}: {message}") + + if message["type"] == "discovery" and message["node_id"] != self.node_id: + peer_id = message["node_id"] + peer_host = addr[0] + peer_port = message["grpc_port"] + device_capabilities = DeviceCapabilities(**message["device_capabilities"]) + if peer_id not in self.known_peers: + self.known_peers[peer_id] = ( + GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), + time.time(), + time.time(), + ) + if DEBUG_DISCOVERY >= 2: + print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}") + self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time()) + + async def task_listen_for_peers(self): + await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port)) + if DEBUG_DISCOVERY >= 2: + print("Started listen task") + + async def task_cleanup_peers(self): + while True: + try: + current_time = time.time() + peers_to_remove = [ + peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() + if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout + ] + if DEBUG_DISCOVERY >= 2: + print( + "Peer statuses:", + {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" + for peer_handle, connected_at, last_seen in self.known_peers.values()}, + ) + if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: + print(f"Cleaning up peers: {peers_to_remove}") + for peer_id in peers_to_remove: + if peer_id in self.known_peers: + del self.known_peers[peer_id] + if DEBUG_DISCOVERY >= 2: + print(f"Removed peer {peer_id} due to inactivity.") + await asyncio.sleep(self.broadcast_interval) + except Exception as e: + print(f"Error in cleanup peers: {e}") + import traceback + + print(traceback.format_exc()) diff --git a/build/lib/exo/networking/grpc/grpc_peer_handle.py b/build/lib/exo/networking/grpc/grpc_peer_handle.py new file mode 100644 index 00000000..0629dc77 --- /dev/null +++ b/build/lib/exo/networking/grpc/grpc_peer_handle.py @@ -0,0 +1,109 @@ +import grpc +import numpy as np +from typing import Optional, Tuple, List + +# These would be generated from the .proto file +from . import node_service_pb2 +from . import node_service_pb2_grpc + +from ..peer_handle import PeerHandle +from exo.inference.shard import Shard +from exo.topology.topology import Topology +from exo.topology.device_capabilities import DeviceCapabilities + + +class GRPCPeerHandle(PeerHandle): + def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities): + self._id = _id + self.address = address + self._device_capabilities = device_capabilities + self.channel = None + self.stub = None + + def id(self) -> str: + return self._id + + def device_capabilities(self) -> DeviceCapabilities: + return self._device_capabilities + + async def connect(self): + self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)]) + self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel) + + async def is_connected(self) -> bool: + return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY + + async def disconnect(self): + if self.channel: + await self.channel.close() + self.channel = None + self.stub = None + + async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + request = node_service_pb2.PromptRequest( + prompt=prompt, + image_str=image_str, + shard=node_service_pb2.Shard( + model_id=shard.model_id, + start_layer=shard.start_layer, + end_layer=shard.end_layer, + n_layers=shard.n_layers, + ), + request_id=request_id, + inference_state=inference_state, + ) + response = await self.stub.SendPrompt(request) + + if not response.tensor_data or not response.shape or not response.dtype: + return None + + return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) + + async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + request = node_service_pb2.TensorRequest( + shard=node_service_pb2.Shard( + model_id=shard.model_id, + start_layer=shard.start_layer, + end_layer=shard.end_layer, + n_layers=shard.n_layers, + ), + tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)), + request_id=request_id, + inference_state=inference_state, + ) + response = await self.stub.SendTensor(request) + + if not response.tensor_data or not response.shape or not response.dtype: + return None + + return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) + + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + request = node_service_pb2.GetInferenceResultRequest(request_id=request_id) + response = await self.stub.GetInferenceResult(request) + if response.tensor is None: + return None, response.is_finished + return ( + np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), + response.is_finished, + ) + + async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: + request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth) + response = await self.stub.CollectTopology(request) + topology = Topology() + for node_id, capabilities in response.nodes.items(): + device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops) + topology.update_node(node_id, device_capabilities) + for node_id, peers in response.peer_graph.items(): + for peer_id in peers.peer_ids: + topology.add_edge(node_id, peer_id) + return topology + + async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished) + await self.stub.SendResult(request) + + async def send_opaque_status(self, request_id: str, status: str) -> None: + request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status) + await self.stub.SendOpaqueStatus(request) diff --git a/build/lib/exo/networking/grpc/grpc_server.py b/build/lib/exo/networking/grpc/grpc_server.py new file mode 100644 index 00000000..1481ef51 --- /dev/null +++ b/build/lib/exo/networking/grpc/grpc_server.py @@ -0,0 +1,118 @@ +import grpc +from concurrent import futures +import numpy as np +from asyncio import CancelledError + +from . import node_service_pb2 +from . import node_service_pb2_grpc +from exo import DEBUG +from exo.inference.shard import Shard +from exo.orchestration import Node + + +class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): + def __init__(self, node: Node, host: str, port: int): + self.node = node + self.host = host + self.port = port + self.server = None + + async def start(self) -> None: + self.server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_metadata_size", 32*1024*1024), + ("grpc.max_send_message_length", 128*1024*1024), + ("grpc.max_receive_message_length", 128*1024*1024), + ], + ) + node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server) + listen_addr = f"{self.host}:{self.port}" + self.server.add_insecure_port(listen_addr) + await self.server.start() + if DEBUG >= 1: print(f"Server started, listening on {listen_addr}") + + async def stop(self) -> None: + if self.server: + try: + await self.server.stop(grace=5) + await self.server.wait_for_termination() + except CancelledError: + pass + if DEBUG >= 1: print("Server stopped and all connections are closed") + + async def SendPrompt(self, request, context): + shard = Shard( + model_id=request.shard.model_id, + start_layer=request.shard.start_layer, + end_layer=request.shard.end_layer, + n_layers=request.shard.n_layers, + ) + prompt = request.prompt + image_str = request.image_str + request_id = request.request_id + result = await self.node.process_prompt(shard, prompt, image_str, request_id) + if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}") + tensor_data = result.tobytes() if result is not None else None + return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() + + async def SendTensor(self, request, context): + shard = Shard( + model_id=request.shard.model_id, + start_layer=request.shard.start_layer, + end_layer=request.shard.end_layer, + n_layers=request.shard.n_layers, + ) + tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape) + request_id = request.request_id + inference_state = request.inference_state + + result = await self.node.process_tensor(shard, tensor, request_id, inference_state) + if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}") + tensor_data = result.tobytes() if result is not None else None + return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() + + async def GetInferenceResult(self, request, context): + request_id = request.request_id + result = await self.node.get_inference_result(request_id) + if DEBUG >= 5: print(f"GetInferenceResult {request_id=}: {result}") + tensor_data = result[0].tobytes() if result[0] is not None else None + return ( + node_service_pb2.InferenceResult( + tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), + is_finished=result[1], + ) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1]) + ) + + async def CollectTopology(self, request, context): + max_depth = request.max_depth + visited = set(request.visited) + topology = await self.node.collect_topology(visited, max_depth) + nodes = { + node_id: + node_service_pb2.DeviceCapabilities( + model=cap.model, + chip=cap.chip, + memory=cap.memory, + flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8), + ) + for node_id, cap in topology.nodes.items() + } + peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()} + if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}") + return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph) + + async def SendResult(self, request, context): + request_id = request.request_id + result = request.result + is_finished = request.is_finished + if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}") + self.node.on_token.trigger_all(request_id, result, is_finished) + return node_service_pb2.Empty() + + async def SendOpaqueStatus(self, request, context): + request_id = request.request_id + status = request.status + if DEBUG >= 5: print(f"Received SendOpaqueStatus request: {request_id=} {status=}") + self.node.on_opaque_status.trigger_all(request_id, status) + return node_service_pb2.Empty() diff --git a/build/lib/exo/networking/grpc/node_service_pb2.py b/build/lib/exo/networking/grpc/node_service_pb2.py new file mode 100644 index 00000000..cae2d080 --- /dev/null +++ b/build/lib/exo/networking/grpc/node_service_pb2.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: node_service.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xc3\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x16\n\timage_str\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x05 \x01(\tH\x02\x88\x01\x01\x42\x0c\n\n_image_strB\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0finference_state\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x07\n\x05\x45mpty2\xde\x03\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TOPOLOGY_NODESENTRY']._loaded_options = None + _globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001' + _globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001' + _globals['_SHARD']._serialized_start = 36 + _globals['_SHARD']._serialized_end = 119 + _globals['_PROMPTREQUEST']._serialized_start = 122 + _globals['_PROMPTREQUEST']._serialized_end = 317 + _globals['_TENSORREQUEST']._serialized_start = 320 + _globals['_TENSORREQUEST']._serialized_end = 499 + _globals['_GETINFERENCERESULTREQUEST']._serialized_start = 501 + _globals['_GETINFERENCERESULTREQUEST']._serialized_end = 548 + _globals['_INFERENCERESULT']._serialized_start = 550 + _globals['_INFERENCERESULT']._serialized_end = 642 + _globals['_TENSOR']._serialized_start = 644 + _globals['_TENSOR']._serialized_end = 703 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_start = 705 + _globals['_COLLECTTOPOLOGYREQUEST']._serialized_end = 765 + _globals['_TOPOLOGY']._serialized_start = 768 + _globals['_TOPOLOGY']._serialized_end = 1038 + _globals['_TOPOLOGY_NODESENTRY']._serialized_start = 889 + _globals['_TOPOLOGY_NODESENTRY']._serialized_end = 967 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start = 969 + _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end = 1038 + _globals['_PEERS']._serialized_start = 1040 + _globals['_PEERS']._serialized_end = 1065 + _globals['_DEVICEFLOPS']._serialized_start = 1067 + _globals['_DEVICEFLOPS']._serialized_end = 1122 + _globals['_DEVICECAPABILITIES']._serialized_start = 1124 + _globals['_DEVICECAPABILITIES']._serialized_end = 1231 + _globals['_SENDRESULTREQUEST']._serialized_start = 1233 + _globals['_SENDRESULTREQUEST']._serialized_end = 1309 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_start = 1311 + _globals['_SENDOPAQUESTATUSREQUEST']._serialized_end = 1372 + _globals['_EMPTY']._serialized_start = 1374 + _globals['_EMPTY']._serialized_end = 1381 + _globals['_NODESERVICE']._serialized_start = 1384 + _globals['_NODESERVICE']._serialized_end = 1862 +# @@protoc_insertion_point(module_scope) diff --git a/build/lib/exo/networking/grpc/node_service_pb2_grpc.py b/build/lib/exo/networking/grpc/node_service_pb2_grpc.py new file mode 100644 index 00000000..ea1d3c98 --- /dev/null +++ b/build/lib/exo/networking/grpc/node_service_pb2_grpc.py @@ -0,0 +1,272 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . import node_service_pb2 as node__service__pb2 + +GRPC_GENERATED_VERSION = '1.64.1' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning + ) + + +class NodeServiceStub(object): + """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendPrompt = channel.unary_unary( + '/node_service.NodeService/SendPrompt', + request_serializer=node__service__pb2.PromptRequest.SerializeToString, + response_deserializer=node__service__pb2.Tensor.FromString, + _registered_method=True + ) + self.SendTensor = channel.unary_unary( + '/node_service.NodeService/SendTensor', + request_serializer=node__service__pb2.TensorRequest.SerializeToString, + response_deserializer=node__service__pb2.Tensor.FromString, + _registered_method=True + ) + self.GetInferenceResult = channel.unary_unary( + '/node_service.NodeService/GetInferenceResult', + request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString, + response_deserializer=node__service__pb2.InferenceResult.FromString, + _registered_method=True + ) + self.CollectTopology = channel.unary_unary( + '/node_service.NodeService/CollectTopology', + request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString, + response_deserializer=node__service__pb2.Topology.FromString, + _registered_method=True + ) + self.SendResult = channel.unary_unary( + '/node_service.NodeService/SendResult', + request_serializer=node__service__pb2.SendResultRequest.SerializeToString, + response_deserializer=node__service__pb2.Empty.FromString, + _registered_method=True + ) + self.SendOpaqueStatus = channel.unary_unary( + '/node_service.NodeService/SendOpaqueStatus', + request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString, + response_deserializer=node__service__pb2.Empty.FromString, + _registered_method=True + ) + + +class NodeServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + def SendPrompt(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTensor(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetInferenceResult(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CollectTopology(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendResult(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendOpaqueStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_NodeServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendPrompt': + grpc.unary_unary_rpc_method_handler( + servicer.SendPrompt, + request_deserializer=node__service__pb2.PromptRequest.FromString, + response_serializer=node__service__pb2.Tensor.SerializeToString, + ), + 'SendTensor': + grpc.unary_unary_rpc_method_handler( + servicer.SendTensor, + request_deserializer=node__service__pb2.TensorRequest.FromString, + response_serializer=node__service__pb2.Tensor.SerializeToString, + ), + 'GetInferenceResult': + grpc.unary_unary_rpc_method_handler( + servicer.GetInferenceResult, + request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString, + response_serializer=node__service__pb2.InferenceResult.SerializeToString, + ), + 'CollectTopology': + grpc.unary_unary_rpc_method_handler( + servicer.CollectTopology, + request_deserializer=node__service__pb2.CollectTopologyRequest.FromString, + response_serializer=node__service__pb2.Topology.SerializeToString, + ), + 'SendResult': + grpc.unary_unary_rpc_method_handler( + servicer.SendResult, + request_deserializer=node__service__pb2.SendResultRequest.FromString, + response_serializer=node__service__pb2.Empty.SerializeToString, + ), + 'SendOpaqueStatus': + grpc.unary_unary_rpc_method_handler( + servicer.SendOpaqueStatus, + request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString, + response_serializer=node__service__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers) + + +# This class is part of an EXPERIMENTAL API. +class NodeService(object): + """Missing associated documentation comment in .proto file.""" + @staticmethod + def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendPrompt', + node__service__pb2.PromptRequest.SerializeToString, + node__service__pb2.Tensor.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendTensor', + node__service__pb2.TensorRequest.SerializeToString, + node__service__pb2.Tensor.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/GetInferenceResult', + node__service__pb2.GetInferenceResultRequest.SerializeToString, + node__service__pb2.InferenceResult.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/CollectTopology', + node__service__pb2.CollectTopologyRequest.SerializeToString, + node__service__pb2.Topology.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendResult', + node__service__pb2.SendResultRequest.SerializeToString, + node__service__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) + + @staticmethod + def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/node_service.NodeService/SendOpaqueStatus', + node__service__pb2.SendOpaqueStatusRequest.SerializeToString, + node__service__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True + ) diff --git a/build/lib/exo/networking/grpc/test_grpc_discovery.py b/build/lib/exo/networking/grpc/test_grpc_discovery.py new file mode 100644 index 00000000..13372bbb --- /dev/null +++ b/build/lib/exo/networking/grpc/test_grpc_discovery.py @@ -0,0 +1,22 @@ +import asyncio +import unittest +from .grpc_discovery import GRPCDiscovery + + +class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679) + self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678) + await self.node1.start() + await self.node2.start() + + async def asyncTearDown(self): + await self.node1.stop() + await self.node2.stop() + + async def test_discovery(self): + await asyncio.sleep(4) + + # Check discovered peers + print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()])) + print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()])) diff --git a/build/lib/exo/networking/peer_handle.py b/build/lib/exo/networking/peer_handle.py new file mode 100644 index 00000000..cf232d00 --- /dev/null +++ b/build/lib/exo/networking/peer_handle.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple, List +import numpy as np +from exo.inference.shard import Shard +from exo.topology.device_capabilities import DeviceCapabilities +from exo.topology.topology import Topology + + +class PeerHandle(ABC): + @abstractmethod + def id(self) -> str: + pass + + @abstractmethod + def device_capabilities(self) -> DeviceCapabilities: + pass + + @abstractmethod + async def connect(self) -> None: + pass + + @abstractmethod + async def is_connected(self) -> bool: + pass + + @abstractmethod + async def disconnect(self) -> None: + pass + + @abstractmethod + async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + pass + + @abstractmethod + async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]: + pass + + @abstractmethod + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + pass + + @abstractmethod + async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: + pass + + @abstractmethod + async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + pass diff --git a/build/lib/exo/networking/server.py b/build/lib/exo/networking/server.py new file mode 100644 index 00000000..8e7f9812 --- /dev/null +++ b/build/lib/exo/networking/server.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + + +class Server(ABC): + @abstractmethod + async def start(self) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass diff --git a/build/lib/exo/orchestration/__init__.py b/build/lib/exo/orchestration/__init__.py new file mode 100644 index 00000000..478af537 --- /dev/null +++ b/build/lib/exo/orchestration/__init__.py @@ -0,0 +1,4 @@ +from .node import Node +from .standard_node import StandardNode + +__all__ = ["Node", "StandardNode"] diff --git a/build/lib/exo/orchestration/node.py b/build/lib/exo/orchestration/node.py new file mode 100644 index 00000000..60b72974 --- /dev/null +++ b/build/lib/exo/orchestration/node.py @@ -0,0 +1,47 @@ +from typing import Optional, Tuple, List +import numpy as np +from abc import ABC, abstractmethod +from exo.helpers import AsyncCallbackSystem +from exo.inference.shard import Shard +from exo.topology.topology import Topology + + +class Node(ABC): + @abstractmethod + async def start(self, wait_for_peers: int = 0) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass + + @abstractmethod + async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + pass + + @abstractmethod + async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + pass + + @abstractmethod + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + pass + + @abstractmethod + async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology: + pass + + @property + @abstractmethod + def current_topology(self) -> Topology: + pass + + @property + @abstractmethod + def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]: + pass + + @property + @abstractmethod + def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: + pass diff --git a/build/lib/exo/orchestration/standard_node.py b/build/lib/exo/orchestration/standard_node.py new file mode 100644 index 00000000..b968b659 --- /dev/null +++ b/build/lib/exo/orchestration/standard_node.py @@ -0,0 +1,385 @@ +import numpy as np +import json +import asyncio +import uuid +import time +import traceback +from typing import List, Dict, Optional, Tuple, Union +from exo.networking import Discovery, PeerHandle, Server +from exo.inference.inference_engine import InferenceEngine, Shard +from .node import Node +from exo.topology.topology import Topology +from exo.topology.device_capabilities import device_capabilities +from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards +from exo import DEBUG +from exo.helpers import AsyncCallbackSystem +from exo.viz.topology_viz import TopologyViz +from exo.download.hf.hf_helpers import RepoProgressEvent + + +class StandardNode(Node): + def __init__( + self, + _id: str, + server: Server, + inference_engine: InferenceEngine, + discovery: Discovery, + partitioning_strategy: PartitioningStrategy = None, + max_generate_tokens: int = 1024, + chatgpt_api_endpoints: List[str] = [], + web_chat_urls: List[str] = [], + disable_tui: Optional[bool] = False, + topology_viz: Optional[TopologyViz] = None, + ): + self.id = _id + self.inference_engine = inference_engine + self.server = server + self.discovery = discovery + self.partitioning_strategy = partitioning_strategy + self.peers: List[PeerHandle] = {} + self.topology: Topology = Topology() + self.device_capabilities = device_capabilities() + self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {} + self.max_generate_tokens = max_generate_tokens + self.topology_viz = topology_viz + self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]() + self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]() + self._on_opaque_status.register("node_status").on_next(self.on_node_status) + self.node_download_progress: Dict[str, RepoProgressEvent] = {} + + async def start(self, wait_for_peers: int = 0) -> None: + await self.server.start() + await self.discovery.start() + await self.update_peers(wait_for_peers) + await self.collect_topology() + if DEBUG >= 2: print(f"Collected topology: {self.topology}") + asyncio.create_task(self.periodic_topology_collection(5)) + + async def stop(self) -> None: + await self.discovery.stop() + await self.server.stop() + + def on_node_status(self, request_id, opaque_status): + try: + status_data = json.loads(opaque_status) + if status_data.get("type", "") == "node_status": + if status_data.get("status", "").startswith("start_"): + self.current_topology.active_node_id = status_data.get("node_id") + elif status_data.get("status", "").startswith("end_"): + if status_data.get("node_id") == self.current_topology.active_node_id: + self.current_topology.active_node_id = None + download_progress = None + if status_data.get("type", "") == "download_progress": + if DEBUG >= 5: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}") + download_progress = RepoProgressEvent.from_dict(status_data.get('progress')) + self.node_download_progress[status_data.get('node_id')] = download_progress + if self.topology_viz: + self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress) + except Exception as e: + if DEBUG >= 1: print(f"Error updating visualization: {e}") + if DEBUG >= 1: traceback.print_exc() + + async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + shard = self.get_current_shard(base_shard) + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "image_str": image_str, + "inference_state": inference_state, + "request_id": request_id, + }), + ) + ) + start_time = time.perf_counter_ns() + resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state) + end_time = time.perf_counter_ns() + elapsed_time_ns = end_time - start_time + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "image_str": image_str, + "inference_state": inference_state, + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), + ) + ) + return resp + + async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: + if request_id is None: + request_id = str(uuid.uuid4()) + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + shard = self.get_current_shard(base_shard) + + if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}") + if shard.start_layer != 0: + if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}") + await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state) + return + + result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state) + is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if is_finished: + self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) + asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity + + if result.size == 1: + self.buffered_token_output[request_id][0].append(result.item()) + self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished) + + if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") + + if not is_finished: + asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state)) + + return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None + + async def process_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + inference_state: Optional[str] = None, + ) -> Optional[np.ndarray]: + shard = self.get_current_shard(base_shard) + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "tensor_size": tensor.size, + "tensor_shape": tensor.shape, + "request_id": request_id, + "inference_state": inference_state, + }), + ) + ) + start_time = time.perf_counter_ns() + resp = await self._process_tensor(shard, tensor, request_id, inference_state) + end_time = time.perf_counter_ns() + elapsed_time_ns = end_time - start_time + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), + ) + ) + return resp + + async def _process_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + inference_state: Optional[str] = None, + ) -> Optional[np.ndarray]: + if request_id is None: + request_id = str(uuid.uuid4()) + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + shard = self.get_current_shard(base_shard) + + try: + if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}") + result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state) + is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if is_finished: + self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True) + asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity + + if result.size == 1: # we got a new token out + self.buffered_token_output[request_id][0].append(result.item()) + self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished) + if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") + + if not is_finished: + asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state)) + + return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None + except Exception as e: + print(f"Error processing tensor for shard {shard}: {e}") + traceback.print_exc() + return None + + async def forward_to_next_shard( + self, + base_shard: Shard, + tensor_or_prompt: Union[np.ndarray, str], + request_id: str, + image_str: Optional[str] = None, + inference_state: Optional[str] = None, + ) -> None: + if not self.partitioning_strategy: + if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.") + return + shard = self.get_current_shard(base_shard) + + partitions = self.partitioning_strategy.partition(self.topology) + shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id) + current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) + if DEBUG >= 1: print(f"Current partition index: {current_partition_index}") + if current_partition_index is not None: + next_partition_index = (current_partition_index+1) % len(partitions) + next_partition: Partition = partitions[next_partition_index] + next_shard = shards[next_partition_index] + if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}") + + if next_partition.node_id == self.id: + if isinstance(tensor_or_prompt, np.ndarray): + await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state) + else: + await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state) + return + + target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None) + if not target_peer: + raise ValueError(f"Peer for {next_partition} not found") + + if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}") + + if isinstance(tensor_or_prompt, np.ndarray): + await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state) + else: + await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state) + + def get_current_shard(self, base_shard: Shard) -> Shard: + partitions = self.partitioning_strategy.partition(self.topology) + shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id) + current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) + if current_partition_index is None: + raise ValueError(f"No current partition found for node: {self.id}") + return shards[current_partition_index] + + async def update_peers(self, wait_for_peers: int = 0) -> None: + self.peers = await self.discovery.discover_peers(wait_for_peers) + for peer in self.peers: + is_connected = await peer.is_connected() + if DEBUG >= 2 and is_connected: + print(f"Already connected to {peer.id()}: {is_connected}") + if not is_connected: + if DEBUG >= 2: print(f"Connecting to {peer.id()}...") + await peer.connect() + if DEBUG >= 1: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})") + + async def periodic_topology_collection(self, interval: int): + while True: + await asyncio.sleep(interval) + try: + await self.update_peers() + await self.collect_topology() + except Exception as e: + print(f"Error collecting topology: {e}") + traceback.print_exc() + + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + if request_id not in self.buffered_token_output: + return None, False + return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1] + + async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology: + next_topology = Topology() + next_topology.update_node(self.id, self.device_capabilities) + + if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}") + + prev_visited = visited.copy() + # TODO: should we add our own peer id here? + visited.update(p.id() for p in self.peers) + + for peer in self.peers: + next_topology.update_node(peer.id(), peer.device_capabilities()) + next_topology.add_edge(self.id, peer.id()) + + if peer.id() in prev_visited: + continue + + if max_depth <= 0: + if DEBUG >= 2: print("Max depth reached. Skipping...") + continue + + try: + other_topology = await peer.collect_topology(visited, max_depth=max_depth - 1) + if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") + self.topology.merge(other_topology) + except Exception as e: + print(f"Error collecting topology from {peer.id()}: {e}") + + next_topology.active_node_id = self.topology.active_node_id # this is not so clean. + self.topology = next_topology + if self.topology_viz: + self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id) + return next_topology + + @property + def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]: + return self._on_token + + @property + def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: + return self._on_opaque_status + + def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None: + if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}") + self.on_token.trigger_all(request_id, tokens, is_finished) + + async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + async def send_result_to_peer(peer): + try: + await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0) + except asyncio.TimeoutError: + print(f"Timeout broadcasting result to {peer.id()}") + except Exception as e: + print(f"Error broadcasting result to {peer.id()}: {e}") + traceback.print_exc() + + await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True) + + async def broadcast_opaque_status(self, request_id: str, status: str) -> None: + if DEBUG >= 5: print(f"Broadcasting opaque status: {request_id=} {status=}") + + async def send_status_to_peer(peer): + try: + await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0) + except asyncio.TimeoutError: + print(f"Timeout sending opaque status to {peer.id()}") + except Exception as e: + print(f"Error sending opaque status to {peer.id()}: {e}") + traceback.print_exc() + + await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True) + # in the case of opaque status, we also want to receive our own opaque statuses + self.on_opaque_status.trigger_all(request_id, status) + + @property + def current_topology(self) -> Topology: + return self.topology diff --git a/build/lib/exo/orchestration/test_node.py b/build/lib/exo/orchestration/test_node.py new file mode 100644 index 00000000..230ef0cf --- /dev/null +++ b/build/lib/exo/orchestration/test_node.py @@ -0,0 +1,57 @@ +import unittest +from unittest.mock import Mock, AsyncMock +import numpy as np + +from .standard_node import StandardNode +from exo.networking.peer_handle import PeerHandle + + +class TestNode(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.mock_inference_engine = AsyncMock() + self.mock_server = AsyncMock() + self.mock_server.start = AsyncMock() + self.mock_server.stop = AsyncMock() + self.mock_discovery = AsyncMock() + self.mock_discovery.start = AsyncMock() + self.mock_discovery.stop = AsyncMock() + mock_peer1 = Mock(spec=PeerHandle) + mock_peer1.id.return_value = "peer1" + mock_peer2 = Mock(spec=PeerHandle) + mock_peer2.id.return_value = "peer2" + self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2]) + + self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery) + + async def asyncSetUp(self): + await self.node.start() + + async def asyncTearDown(self): + await self.node.stop() + + async def test_node_initialization(self): + self.assertEqual(self.node.node_id, "test_node") + self.assertEqual(self.node.host, "localhost") + self.assertEqual(self.node.port, 50051) + + async def test_node_start(self): + self.mock_server.start.assert_called_once_with("localhost", 50051) + + async def test_node_stop(self): + await self.node.stop() + self.mock_server.stop.assert_called_once() + + async def test_discover_and_connect_to_peers(self): + await self.node.discover_and_connect_to_peers() + self.assertEqual(len(self.node.peers), 2) + self.assertIn("peer1", map(lambda p: p.id(), self.node.peers)) + self.assertIn("peer2", map(lambda p: p.id(), self.node.peers)) + + async def test_process_tensor_calls_inference_engine(self): + mock_peer = Mock() + self.node.peers = [mock_peer] + + input_tensor = np.array([69, 1, 2]) + await self.node.process_tensor(input_tensor, None) + + self.node.inference_engine.process_shard.assert_called_once_with(input_tensor) diff --git a/build/lib/exo/stats/__init__.py b/build/lib/exo/stats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/stats/metrics.py b/build/lib/exo/stats/metrics.py new file mode 100644 index 00000000..f29533ff --- /dev/null +++ b/build/lib/exo/stats/metrics.py @@ -0,0 +1,29 @@ +from exo.orchestration import Node +from prometheus_client import start_http_server, Counter, Histogram +import json + +# Create metrics to track time spent and requests made. +PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"]) +PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"]) +PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"]) + + +def start_metrics_server(node: Node, port: int): + start_http_server(port) + + def _on_opaque_status(request_id, opaque_status: str): + status_data = json.loads(opaque_status) + _type = status_data.get("type", "") + node_id = status_data.get("node_id", "") + if _type != "node_status": + return + status = status_data.get("status", "") + + if status == "end_process_prompt": + PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc() + elif status == "end_process_tensor": + elapsed_time_ns = status_data.get("elapsed_time_ns", 0) + PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc() + PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds + + node.on_opaque_status.register("stats").on_next(_on_opaque_status) diff --git a/build/lib/exo/test_callbacks.py b/build/lib/exo/test_callbacks.py new file mode 100644 index 00000000..c10083d6 --- /dev/null +++ b/build/lib/exo/test_callbacks.py @@ -0,0 +1,50 @@ +import asyncio +from typing import Any, Callable +from exo.helpers import AsyncCallbackSystem, AsyncCallback + + +# Usage example +async def main() -> None: + callback_system = AsyncCallbackSystem[str, Any]() + + # Register callbacks + callback1 = callback_system.register("callback1") + callback2 = callback_system.register("callback2") + + def on_next_callback(name: str) -> Callable[..., None]: + def callback(*args: Any) -> None: + print(f"{name} received values: {args}") + + return callback + + callback1.on_next(on_next_callback("Callback1")) + callback2.on_next(on_next_callback("Callback2")) + + async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None: + try: + result = await callback.wait(condition, timeout=2) + print(f"{name} wait completed with result: {result}") + except asyncio.TimeoutError: + print(f"{name} wait timed out") + + # Trigger all callbacks at once + callback_system.trigger_all("Hello", 42, True) + + # Wait for all callbacks with different conditions + await asyncio.gather( + wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0), + wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True), + ) + + # Trigger individual callback + callback_system.trigger("callback2", "World", -10, False) + + # Demonstrate timeout + new_callback = callback_system.register("new_callback") + new_callback.on_next(on_next_callback("NewCallback")) + await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100) + + callback_system.trigger("callback2", "World", 200, False) + + +asyncio.run(main()) diff --git a/build/lib/exo/topology/__init__.py b/build/lib/exo/topology/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/topology/device_capabilities.py b/build/lib/exo/topology/device_capabilities.py new file mode 100644 index 00000000..51db53ef --- /dev/null +++ b/build/lib/exo/topology/device_capabilities.py @@ -0,0 +1,207 @@ +from exo import DEBUG +from dataclasses import dataclass, asdict +import subprocess +import psutil + +TFLOPS = 1.00 + + +@dataclass +class DeviceFlops: + # units of TFLOPS + fp32: float + fp16: float + int8: float + + def __str__(self): + return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS" + + def to_dict(self): + return asdict(self) + + +@dataclass +class DeviceCapabilities: + model: str + chip: str + memory: int + flops: DeviceFlops + + def __str__(self): + return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}" + + def __post_init__(self): + if isinstance(self.flops, dict): + self.flops = DeviceFlops(**self.flops) + + def to_dict(self): + return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()} + + +UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) + +CHIP_FLOPS = { + # Source: https://www.cpu-monkey.com + # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative + ### M chips + "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS), + "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS), + "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS), + "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS), + "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS), + "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS), + "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS), + "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS), + "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS), + "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + ### A chips + "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS), + "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS), + "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS), + "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS), + "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS), + ### NVIDIA GPUs + # RTX 40 series + "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS), + "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS), + "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS), + "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS), + "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS), + "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS), + # RTX 30 series + "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS), + "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS), + "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS), + "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS), + "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS), + "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS), + "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS), + "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS), + "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS), + "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS), + # RTX 20 series + "NVIDIA GEFORCE RTX 2060": DeviceFlops(fp32=6.45*TFLOPS, fp16=12.9*TFLOPS, int8=25.8*TFLOPS), + "NVIDIA GEFORCE RTX 2060 SUPER": DeviceFlops(fp32=7.2*TFLOPS, fp16=14.4*TFLOPS, int8=28.8*TFLOPS), + "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS), + "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS), + "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS), + "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS), + "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS), + # QUATRO RTX Ampere series + "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS), + "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS), + "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS), + "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS), + "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS), + # Common Server GPUs + "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS), + "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), + "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS), + "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS), + # ... add more devices if needed ... + ### AMD GPUs + # RX 6000 series + "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS), + "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS), + "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS), + "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS), + "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS), + "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS), + "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS), + "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS), + "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS), + # RX 7000 series + "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS), + "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS), + "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS), + "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS), + "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS), + "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS), + ### Qualcomm embedded chips: TODO +} +CHIP_FLOPS.update({f"LAPTOP GPU {key}": value for key, value in CHIP_FLOPS.items()}) +CHIP_FLOPS.update({f"Laptop GPU {key}": value for key, value in CHIP_FLOPS.items()}) +CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items()}) +CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()}) + + +def device_capabilities() -> DeviceCapabilities: + if psutil.MACOS: + return mac_device_capabilities() + elif psutil.LINUX: + return linux_device_capabilities() + else: + return DeviceCapabilities( + model="Unknown Device", + chip="Unknown Chip", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) + + +def mac_device_capabilities() -> DeviceCapabilities: + # Fetch the model of the Mac using system_profiler + model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8") + model_line = next((line for line in model.split("\n") if "Model Name" in line), None) + model_id = model_line.split(": ")[1] if model_line else "Unknown Model" + chip_line = next((line for line in model.split("\n") if "Chip" in line), None) + chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip" + memory_line = next((line for line in model.split("\n") if "Memory" in line), None) + memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory" + memory_units = memory_str.split() + memory_value = int(memory_units[0]) + if memory_units[1] == "GB": + memory = memory_value*1024 + else: + memory = memory_value + + # Assuming static values for other attributes for demonstration + return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))) + + +def linux_device_capabilities() -> DeviceCapabilities: + import psutil + from tinygrad import Device + + if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}") + if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU": + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + gpu_name = pynvml.nvmlDeviceGetName(handle).upper() + gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}") + + return DeviceCapabilities( + model=f"Linux Box ({gpu_name})", + chip=gpu_name, + memory=gpu_memory_info.total // 2**20, + flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + elif Device.DEFAULT == "AMD": + # TODO AMD support + return DeviceCapabilities( + model="Linux Box (AMD)", + chip="Unknown AMD", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) + else: + return DeviceCapabilities( + model=f"Linux Box (Device: {Device.DEFAULT})", + chip=f"Unknown Chip (Device: {Device.DEFAULT})", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) diff --git a/build/lib/exo/topology/partitioning_strategy.py b/build/lib/exo/topology/partitioning_strategy.py new file mode 100644 index 00000000..29c3dc6a --- /dev/null +++ b/build/lib/exo/topology/partitioning_strategy.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import List +from dataclasses import dataclass +from .topology import Topology +from exo.inference.shard import Shard + + +# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1 +@dataclass +class Partition: + node_id: str + start: float + end: float + + +class PartitioningStrategy(ABC): + @abstractmethod + def partition(self, topology: Topology) -> List[Partition]: + pass + + +def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]: + shards = [] + for i, partition in enumerate(partitions): + start_layer = int(partition.start*num_layers) + end_layer = int(partition.end*num_layers) - 1 + + # Ensure the last partition covers up to num_layers - 1 + if i == len(partitions) - 1: + end_layer = num_layers - 1 + + # Ensure no empty shards + if start_layer <= end_layer: + shards.append(Shard(model_id, start_layer, end_layer, num_layers)) + + # Ensure full coverage + if shards and shards[-1].end_layer < num_layers - 1: + shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers) + + return shards diff --git a/build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py b/build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py new file mode 100644 index 00000000..6550aeb1 --- /dev/null +++ b/build/lib/exo/topology/ring_memory_weighted_partitioning_strategy.py @@ -0,0 +1,18 @@ +from typing import List +from .partitioning_strategy import PartitioningStrategy +from .topology import Topology +from .partitioning_strategy import Partition + + +class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy): + def partition(self, topology: Topology) -> List[Partition]: + nodes = list(topology.all_nodes()) + nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True) + total_memory = sum(node[1].memory for node in nodes) + partitions = [] + start = 0 + for node in nodes: + end = round(start + (node[1].memory/total_memory), 5) + partitions.append(Partition(node[0], start, end)) + start = end + return partitions diff --git a/build/lib/exo/topology/test_device_capabilities.py b/build/lib/exo/topology/test_device_capabilities.py new file mode 100644 index 00000000..5f8b4c3a --- /dev/null +++ b/build/lib/exo/topology/test_device_capabilities.py @@ -0,0 +1,91 @@ +import unittest +from unittest.mock import patch +from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS + + +class TestMacDeviceCapabilities(unittest.TestCase): + @patch("subprocess.check_output") + def test_mac_device_capabilities_pro(self, mock_check_output): + # Mock the subprocess output + mock_check_output.return_value = b""" +Hardware: + +Hardware Overview: + +Model Name: MacBook Pro +Model Identifier: Mac15,9 +Model Number: Z1CM000EFB/A +Chip: Apple M3 Max +Total Number of Cores: 16 (12 performance and 4 efficiency) +Memory: 128 GB +System Firmware Version: 10000.000.0 +OS Loader Version: 10000.000.0 +Serial Number (system): XXXXXXXXXX +Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX +Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX +Activation Lock Status: Enabled +""" + + # Call the function + result = mac_device_capabilities() + + # Check the results + self.assertIsInstance(result, DeviceCapabilities) + self.assertEqual(result.model, "MacBook Pro") + self.assertEqual(result.chip, "Apple M3 Max") + self.assertEqual(result.memory, 131072) # 16 GB in MB + self.assertEqual( + str(result), + "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS", + ) + + @patch("subprocess.check_output") + def test_mac_device_capabilities_air(self, mock_check_output): + # Mock the subprocess output + mock_check_output.return_value = b""" +Hardware: + +Hardware Overview: + +Model Name: MacBook Air +Model Identifier: Mac14,2 +Model Number: MLY33B/A +Chip: Apple M2 +Total Number of Cores: 8 (4 performance and 4 efficiency) +Memory: 8 GB +System Firmware Version: 10000.00.0 +OS Loader Version: 10000.00.0 +Serial Number (system): XXXXXXXXXX +Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX +Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX +Activation Lock Status: Disabled +""" + + # Call the function + result = mac_device_capabilities() + + # Check the results + self.assertIsInstance(result, DeviceCapabilities) + self.assertEqual(result.model, "MacBook Air") + self.assertEqual(result.chip, "Apple M2") + self.assertEqual(result.memory, 8192) # 8 GB in MB + + @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB") + def test_mac_device_capabilities_real(self): + # Call the function without mocking + result = mac_device_capabilities() + + # Check the results + self.assertIsInstance(result, DeviceCapabilities) + self.assertEqual(result.model, "MacBook Pro") + self.assertEqual(result.chip, "Apple M3 Max") + self.assertEqual(result.memory, 131072) # 128 GB in MB + self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)) + self.assertEqual( + str(result), + "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/topology/test_map_partitions.py b/build/lib/exo/topology/test_map_partitions.py new file mode 100644 index 00000000..5254915e --- /dev/null +++ b/build/lib/exo/topology/test_map_partitions.py @@ -0,0 +1,81 @@ +import unittest +from typing import List +from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards +from exo.inference.shard import Shard + + +class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): + def test_map_partitions_to_shards(self): + partitions = [ + Partition("node1", 0.0, 0.42857), + Partition("node2", 0.42857, 0.71428), + Partition("node3", 0.71428, 0.99999), + ] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 12, 32), + Shard("model", 13, 21, 32), + Shard("model", 22, 31, 32), + ], + ) + + partitions = [ + Partition("node1", 0.0, 0.1), + Partition("node2", 0.1, 0.2), + Partition("node3", 0.2, 1.0), + ] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 2, 32), + Shard("model", 3, 5, 32), + Shard("model", 6, 31, 32), + ], + ) + + partitions = [ + Partition("node1", 0.0, 1.0), + ] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 31, 32), + ], + ) + + partitions = [] + shards = map_partitions_to_shards(partitions, 32, "model") + self.assertEqual(shards, []) + + def test_broken_map_partitions_to_shards(self): + # this was an old broken implementation that sometimes had rounding errors! + def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str): + shards = [] + for i, partition in enumerate(partitions): + start_layer = int(partition.start*num_layers) + end_layer = int(partition.end*num_layers) - 1 + shards.append(Shard(model_id, start_layer, end_layer, num_layers)) + return shards + + partitions = [ + Partition("node1", 0.0, 0.42857), + Partition("node2", 0.42857, 0.71428), + Partition("node3", 0.71428, 0.99999), + ] + shards = _broken_map_partitions_to_shards(partitions, 32, "model") + self.assertEqual( + shards, + [ + Shard("model", 0, 12, 32), + Shard("model", 13, 21, 32), + Shard("model", 22, 30, 32), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py b/build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py new file mode 100644 index 00000000..fd466f36 --- /dev/null +++ b/build/lib/exo/topology/test_ring_memory_weighted_partitioning_strategy.py @@ -0,0 +1,90 @@ +import unittest +from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy +from exo.topology.topology import Topology +from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops +from exo.topology.partitioning_strategy import Partition + + +class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): + def test_partition(self): + # triangle + # node1 -> node2 -> node3 -> node1 + topology = Topology() + topology.update_node( + "node1", + DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + topology.update_node( + "node2", + DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + topology.update_node( + "node3", + DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + topology.add_edge("node1", "node2") + topology.add_edge("node2", "node3") + topology.add_edge("node3", "node1") + topology.add_edge("node1", "node3") + + strategy = RingMemoryWeightedPartitioningStrategy() + partitions = strategy.partition(topology) + + self.assertEqual(len(partitions), 3) + self.assertEqual( + partitions, + [ + Partition("node3", 0.0, 0.6), + Partition("node1", 0.6, 0.9), + Partition("node2", 0.9, 1.0), + ], + ) + + def test_partition_rounding(self): + # triangle + # node1 -> node2 -> node3 -> node1 + topology = Topology() + topology.update_node( + "node1", + DeviceCapabilities( + model="MacBook Pro", + chip="test1", + memory=128*1024*1024*1024, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ), + ) + topology.update_node( + "node2", + DeviceCapabilities( + model="Mac Studio", + chip="test2", + memory=192*1024*1024*1024, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ), + ) + topology.update_node( + "node3", + DeviceCapabilities( + model="MacBook Pro", + chip="test3", + memory=128*1024*1024*1024, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ), + ) + + strategy = RingMemoryWeightedPartitioningStrategy() + partitions = strategy.partition(topology) + + self.assertEqual(len(partitions), 3) + self.assertEqual( + partitions, + [ + Partition("node3", 0.0, 0.42857), + Partition("node1", 0.6, 0.9), + Partition("node2", 0.9, 1.0), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/topology/topology.py b/build/lib/exo/topology/topology.py new file mode 100644 index 00000000..46b512e5 --- /dev/null +++ b/build/lib/exo/topology/topology.py @@ -0,0 +1,49 @@ +from .device_capabilities import DeviceCapabilities +from typing import Dict, Set, Optional + + +class Topology: + def __init__(self): + self.nodes: Dict[str, DeviceCapabilities] = {} # Maps node IDs to DeviceCapabilities + self.peer_graph: Dict[str, Set[str]] = {} # Adjacency list representing the graph + self.active_node_id: Optional[str] = None + + def update_node(self, node_id: str, device_capabilities: DeviceCapabilities): + self.nodes[node_id] = device_capabilities + + def get_node(self, node_id: str) -> DeviceCapabilities: + return self.nodes.get(node_id) + + def all_nodes(self): + return self.nodes.items() + + def add_edge(self, node1_id: str, node2_id: str): + if node1_id not in self.peer_graph: + self.peer_graph[node1_id] = set() + if node2_id not in self.peer_graph: + self.peer_graph[node2_id] = set() + self.peer_graph[node1_id].add(node2_id) + self.peer_graph[node2_id].add(node1_id) + + def get_neighbors(self, node_id: str) -> Set[str]: + return self.peer_graph.get(node_id, set()) + + def all_edges(self): + edges = [] + for node, neighbors in self.peer_graph.items(): + for neighbor in neighbors: + if (neighbor, node) not in edges: # Avoid duplicate edges + edges.append((node, neighbor)) + return edges + + def merge(self, other: "Topology"): + for node_id, capabilities in other.nodes.items(): + self.update_node(node_id, capabilities) + for node_id, neighbors in other.peer_graph.items(): + for neighbor in neighbors: + self.add_edge(node_id, neighbor) + + def __str__(self): + nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items()) + edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items()) + return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})" diff --git a/build/lib/exo/viz/__init__.py b/build/lib/exo/viz/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/build/lib/exo/viz/test_topology_viz.py b/build/lib/exo/viz/test_topology_viz.py new file mode 100644 index 00000000..e57de1ae --- /dev/null +++ b/build/lib/exo/viz/test_topology_viz.py @@ -0,0 +1,129 @@ +import asyncio +import unittest +from datetime import timedelta +from exo.viz.topology_viz import TopologyViz +from exo.topology.topology import Topology +from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops +from exo.topology.partitioning_strategy import Partition +from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent + + +def create_hf_repo_progress_event( + completed_files: int = 5, + total_files: int = 10, + downloaded_bytes: int = 500000000, + downloaded_bytes_this_session: int = 250000000, + total_bytes: int = 1000000000, + overall_speed: int = 5000000, + overall_eta: timedelta = timedelta(seconds=100), + file_progress: dict = None, + status: str = "in_progress" +) -> RepoProgressEvent: + if file_progress is None: + file_progress = { + "file1.bin": + RepoFileProgressEvent( + repo_id="repo_id", + repo_revision="repo_revision", + file_path="file1.bin", + downloaded=100000000, + downloaded_this_session=50000000, + total=200000000, + speed=1000000, + eta=timedelta(seconds=100), + status="in_progress" + ), "file2.bin": + RepoFileProgressEvent( + repo_id="repo_id", + repo_revision="repo_revision", + file_path="file2.bin", + downloaded=200000000, + downloaded_this_session=100000000, + total=200000000, + speed=2000000, + eta=timedelta(seconds=0), + status="complete" + ) + } + + return RepoProgressEvent( + repo_id="repo_id", + repo_revision="repo_revision", + completed_files=completed_files, + total_files=total_files, + downloaded_bytes=downloaded_bytes, + downloaded_bytes_this_session=downloaded_bytes_this_session, + total_bytes=total_bytes, + overall_speed=overall_speed, + overall_eta=overall_eta, + file_progress=file_progress, + status=status + ) + + +class TestNodeViz(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.topology = Topology() + self.topology.update_node( + "node1", + DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)), + ) + self.topology.update_node( + "node2", + DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)), + ) + self.topology.update_node( + "node3", + DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)), + ) + self.topology.update_node( + "node4", + DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)), + ) + + self.top_viz = TopologyViz() + await asyncio.sleep(2) # Simulate running for a short time + + async def test_layout_generation(self): + # self.top_viz._generate_layout() + self.top_viz.refresh() + import time + + time.sleep(2) + self.top_viz.update_visualization( + self.topology, + [ + Partition("node1", 0, 0.2), + Partition("node4", 0.2, 0.4), + Partition("node2", 0.4, 0.8), + Partition("node3", 0.8, 0.9), + ], + "node1", + { + "node1": create_hf_repo_progress_event(), + "node2": create_hf_repo_progress_event(), + "node3": create_hf_repo_progress_event(), + "node4": create_hf_repo_progress_event(), + }, + ) + time.sleep(2) + self.topology.active_node_id = "node3" + self.top_viz.update_visualization( + self.topology, + [ + Partition("node1", 0, 0.3), + Partition("node5", 0.3, 0.5), + Partition("node2", 0.5, 0.7), + Partition("node4", 0.7, 0.9), + ], + "node5", + { + "node1": create_hf_repo_progress_event(), + "node5": create_hf_repo_progress_event(), + }, + ) + time.sleep(2) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/lib/exo/viz/topology_viz.py b/build/lib/exo/viz/topology_viz.py new file mode 100644 index 00000000..3664f378 --- /dev/null +++ b/build/lib/exo/viz/topology_viz.py @@ -0,0 +1,307 @@ +import math +from collections import OrderedDict +from typing import List, Optional, Tuple, Dict +from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second +from exo.topology.topology import Topology +from exo.topology.partitioning_strategy import Partition +from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES +from rich.console import Console, Group +from rich.text import Text +from rich.live import Live +from rich.style import Style +from rich.table import Table +from rich.layout import Layout +from rich.syntax import Syntax +from rich.panel import Panel +from rich.markdown import Markdown + + +class TopologyViz: + def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []): + self.chatgpt_api_endpoints = chatgpt_api_endpoints + self.web_chat_urls = web_chat_urls + self.topology = Topology() + self.partitions: List[Partition] = [] + self.node_id = None + self.node_download_progress: Dict[str, RepoProgressEvent] = {} + self.requests: OrderedDict[str, Tuple[str, str]] = {} + + self.console = Console() + self.layout = Layout() + self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25)) + self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow") + self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green") + self.download_panel = Panel("", title="Download Progress", border_style="cyan") + self.layout["main"].update(self.main_panel) + self.layout["prompt_output"].update(self.prompt_output_panel) + self.layout["download"].update(self.download_panel) + + # Initially hide the prompt_output panel + self.layout["prompt_output"].visible = False + self.live_panel = Live(self.layout, auto_refresh=False, console=self.console) + self.live_panel.start() + + def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Dict[str, RepoProgressEvent] = {}): + self.topology = topology + self.partitions = partitions + self.node_id = node_id + if node_download_progress: + self.node_download_progress = node_download_progress + self.refresh() + + def update_prompt(self, request_id: str, prompt: Optional[str] = None): + if request_id in self.requests: + self.requests[request_id] = [prompt, self.requests[request_id][1]] + else: + self.requests[request_id] = [prompt, ""] + self.refresh() + + def update_prompt_output(self, request_id: str, output: Optional[str] = None): + if request_id in self.requests: + self.requests[request_id] = [self.requests[request_id][0], output] + else: + self.requests[request_id] = ["", output] + self.refresh() + + def refresh(self): + self.main_panel.renderable = self._generate_main_layout() + # Update the panel title with the number of nodes and partitions + node_count = len(self.topology.nodes) + self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})" + + # Update and show/hide prompt and output panel + if any(r[0] or r[1] for r in self.requests.values()): + self.prompt_output_panel = self._generate_prompt_output_layout() + self.layout["prompt_output"].update(self.prompt_output_panel) + self.layout["prompt_output"].visible = True + else: + self.layout["prompt_output"].visible = False + + # Only show download_panel if there are in-progress downloads + if any(progress.status == "in_progress" for progress in self.node_download_progress.values()): + self.download_panel.renderable = self._generate_download_layout() + self.layout["download"].visible = True + else: + self.layout["download"].visible = False + + self.live_panel.update(self.layout, refresh=True) + + def _generate_prompt_output_layout(self) -> Panel: + content = [] + requests = list(self.requests.values())[-3:] # Get the 3 most recent requests + max_width = self.console.width - 6 # Full width minus padding and icon + max_lines = 13 # Maximum number of lines for the entire panel content + + for (prompt, output) in reversed(requests): + prompt_icon, output_icon = "💬️", "🤖" + + # Process prompt + prompt_lines = prompt.split('\n') + if len(prompt_lines) > max_lines // 2: + prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...'] + prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue") + prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white") + + # Process output + output_lines = output.split('\n') + remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing + if len(output_lines) > remaining_lines: + output_lines = output_lines[:remaining_lines - 1] + ['...'] + output_text = Text(f"\n{output_icon} ", style="bold bright_magenta") + output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white") + + content.append(prompt_text) + content.append(output_text) + content.append(Text()) # Empty line between entries + + return Panel( + Group(*content), + title="", + border_style="cyan", + height=15, # Increased height to accommodate multiple lines + expand=True # Allow the panel to expand to full width + ) + + def _generate_main_layout(self) -> str: + # Calculate visualization parameters + num_partitions = len(self.partitions) + radius_x = 30 + radius_y = 12 + center_x, center_y = 50, 24 # Increased center_y to add more space + + # Generate visualization + visualization = [[" " for _ in range(100)] for _ in range(48)] # Increased height to 48 + + # Add exo_text at the top in bright yellow + exo_lines = exo_text.split("\n") + yellow_style = Style(color="bright_yellow") + max_line_length = max(len(line) for line in exo_lines) + for i, line in enumerate(exo_lines): + centered_line = line.center(max_line_length) + start_x = (100-max_line_length) // 2 + 15 + colored_line = Text(centered_line, style=yellow_style) + for j, char in enumerate(str(colored_line)): + if 0 <= start_x + j < 100 and i < len(visualization): + visualization[i][start_x + j] = char + + # Display chatgpt_api_endpoints and web_chat_urls + info_lines = [] + if len(self.web_chat_urls) > 0: + info_lines.append(f"Web Chat URL (tinychat): {' '.join(self.web_chat_urls[:1])}") + if len(self.chatgpt_api_endpoints) > 0: + info_lines.append(f"ChatGPT API endpoint: {' '.join(self.chatgpt_api_endpoints[:1])}") + + info_start_y = len(exo_lines) + 1 + for i, line in enumerate(info_lines): + start_x = (100 - len(line)) // 2 + 15 + for j, char in enumerate(line): + if 0 <= start_x + j < 100 and info_start_y + i < 48: + visualization[info_start_y + i][start_x + j] = char + + # Calculate total FLOPS and position on the bar + total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions) + bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2 + + # Add GPU poor/rich bar + bar_width = 30 + bar_start_x = (100-bar_width) // 2 + bar_y = info_start_y + len(info_lines) + 1 + + # Create a gradient bar using emojis + gradient_bar = Text() + emojis = ["🟥", "🟧", "🟨", "🟩"] + for i in range(bar_width): + emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1) + gradient_bar.append(emojis[emoji_index]) + + # Add the gradient bar to the visualization + visualization[bar_y][bar_start_x - 1] = "[" + visualization[bar_y][bar_start_x + bar_width] = "]" + for i, segment in enumerate(str(gradient_bar)): + visualization[bar_y][bar_start_x + i] = segment + + # Add labels + visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor" + visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich" + + # Add position indicator and FLOPS value + pos_x = bar_start_x + int(bar_pos*bar_width) + flops_str = f"{total_flops:.2f} TFLOPS" + visualization[bar_y - 1][pos_x] = "▼" + visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str + visualization[bar_y + 2][pos_x] = "▲" + + # Add an extra empty line for spacing + bar_y += 4 + + for i, partition in enumerate(self.partitions): + device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES) + + angle = 2*math.pi*i/num_partitions + x = int(center_x + radius_x*math.cos(angle)) + y = int(center_y + radius_y*math.sin(angle)) + + # Place node with different color for active node and this node + if partition.node_id == self.topology.active_node_id: + visualization[y][x] = "🔴" + elif partition.node_id == self.node_id: + visualization[y][x] = "🟢" + else: + visualization[y][x] = "🔵" + + # Place node info (model, memory, TFLOPS, partition) on three lines + node_info = [ + f"{device_capabilities.model} {device_capabilities.memory // 1024}GB", + f"{device_capabilities.flops.fp16}TFLOPS", + f"[{partition.start:.2f}-{partition.end:.2f}]", + ] + + # Calculate info position based on angle + info_distance_x = radius_x + 6 + info_distance_y = radius_y + 3 + info_x = int(center_x + info_distance_x*math.cos(angle)) + info_y = int(center_y + info_distance_y*math.sin(angle)) + + # Adjust text position to avoid overwriting the node icon and prevent cutoff + if info_x < x: + info_x = max(0, x - len(max(node_info, key=len)) - 1) + elif info_x > x: + info_x = min(99 - len(max(node_info, key=len)), info_x) + + # Adjust for top and bottom nodes + if 5*math.pi/4 < angle < 7*math.pi/4: + info_x += 4 + elif math.pi/4 < angle < 3*math.pi/4: + info_x += 3 + info_y -= 2 + + for j, line in enumerate(node_info): + for k, char in enumerate(line): + if 0 <= info_y + j < 48 and 0 <= info_x + k < 100: + if info_y + j != y or info_x + k != x: + visualization[info_y + j][info_x + k] = char + + # Draw line to next node + next_i = (i+1) % num_partitions + next_angle = 2*math.pi*next_i/num_partitions + next_x = int(center_x + radius_x*math.cos(next_angle)) + next_y = int(center_y + radius_y*math.sin(next_angle)) + + # Simple line drawing + steps = max(abs(next_x - x), abs(next_y - y)) + for step in range(1, steps): + line_x = int(x + (next_x-x)*step/steps) + line_y = int(y + (next_y-y)*step/steps) + if 0 <= line_y < 48 and 0 <= line_x < 100: + visualization[line_y][line_x] = "-" + + # Convert to string + return "\n".join("".join(str(char) for char in row) for row in visualization) + + def _generate_download_layout(self) -> Table: + summary = Table(show_header=False, box=None, padding=(0, 1), expand=True) + summary.add_column("Info", style="cyan", no_wrap=True, ratio=50) + summary.add_column("Progress", style="cyan", no_wrap=True, ratio=40) + summary.add_column("Percentage", style="cyan", no_wrap=True, ratio=10) + + # Current node download progress + if self.node_id in self.node_download_progress: + download_progress = self.node_download_progress[self.node_id] + title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):" + summary.add_row(Text(title, style="bold")) + progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})" + summary.add_row(progress_info) + + eta_info = f"{download_progress.overall_eta}" + summary.add_row(eta_info) + + summary.add_row("") # Empty row for spacing + + for file_path, file_progress in download_progress.file_progress.items(): + if file_progress.status != "complete": + progress = int(file_progress.downloaded/file_progress.total*30) + bar = f"[{'=' * progress}{' ' * (30 - progress)}]" + percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%" + summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage) + + summary.add_row("") # Empty row for spacing + + # Other nodes download progress summary + summary.add_row(Text("Other Nodes Download Progress:", style="bold")) + for node_id, progress in self.node_download_progress.items(): + if node_id != self.node_id: + device = self.topology.nodes.get(node_id) + partition = next((p for p in self.partitions if p.node_id == node_id), None) + partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else "" + percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0 + speed = pretty_print_bytes_per_second(progress.overall_speed) + device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}" + progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})" + progress_bar = f"[{'=' * int(percentage // 3.33)}{' ' * (30 - int(percentage // 3.33))}]" + percentage_str = f"{percentage:.1f}%" + eta_str = f"{progress.overall_eta}" + summary.add_row(device_info, progress_info, percentage_str) + summary.add_row("", progress_bar, eta_str) + + return summary