From 57e14e8cbf6e3d2dbae632a70b8b53c4ec14efcb Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 11:08:19 -0800 Subject: [PATCH 1/2] adding needed libs to setup.py, fixing 4 space to 2 space issue, adding in hf downloader to inference engine, testing --- .gitignore | 3 + exo/inference/pytorch/inference.py | 533 +++++++++--------- exo/inference/pytorch/model/hf.py | 30 +- .../pytorch/tests/test_inference_engine.py | 283 ++++------ .../pytorch/tests/test_split_model.py | 25 +- exo/models.py | 12 +- exo/tinychat/index.html | 1 + setup.py | 2 + 8 files changed, 438 insertions(+), 451 deletions(-) diff --git a/.gitignore b/.gitignore index f5609f311..33907f700 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,6 @@ cython_debug/ # PyTorch interface .offload + +# neovim/vim settings +.vimrc diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 94cea1004..2f87c1b10 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -1,4 +1,5 @@ # experimental, based off of tinygrad/inference.py +import os import numpy as np import torch import json @@ -9,9 +10,9 @@ from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel from exo.api.chatgpt_api import resolve_tokenizer from exo.helpers import DEBUG -from exo.download.shard_download import ShardDownloader +from exo.download.hf.hf_shard_download import HFShardDownloader -# model value options +# model value options TOP_K = 20 TEMP = 0.6 TOP_P = 0.9 @@ -19,267 +20,273 @@ MAX_TIME = 60.0 class PyTorchDynamicShardInferenceEngine(InferenceEngine): + """ + PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. + """ + + def __init__(self, shard_downloader: HFShardDownloader): + """ + Initialize the inference engine. + + Args: + debug (bool): If True, enables debug logging. Defaults to False. """ - PyTorch Dynamic Shard Inference Engine for performing model inference with sharded Pytorch/HF based models. 
+ self.shard = None + self.shard_downloader = shard_downloader + self.stateful_sharded_model = None + self.tokenizer = None + + # the whole history with new logits need to + # be passed to the model to reach the end token + # even with caching + self.past_input_ids = None + + # setup cuda device + if os.environ.get("PYTORCH_DEVICE"): + pytorch_device = os.environ["PYTOCH_DEVICE"] + if pytorch_device not in ["cuda", "mps", "cpu"]: + pytorch_device = "cpu" + + self.device = pytorch_device + self.torch_dtype = torch.float32 if pytorch_device != "cpu" else torch.float16 + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + self.torch_dtype = torch.float32 + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + self.torch_dtype = torch.float32 + else: + self.device = torch.device("cpu") + self.torch_dtype = torch.float16 + + # setup unfinished sequence + self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) + + def infer_caching( + self, + inference_state: Optional[str] = None + ) -> Tuple[Optional[torch.tensor], Optional[dict]]: """ + inference caching from inference_state json + """ + # setup cache and cached input_ids + past_iids = None + cached_iids = None + if inference_state is not None: + try: + infer_state = json.loads(inference_state) + except ValueError: + infer_state = None + + if infer_state is not None: + cached_iids = infer_state["cached_iids"] + if cached_iids is not None: + past_iids = None + if len(cached_iids) > 0: + past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) + cached_iids = {"input_ids": past_iids.tolist()} + + if DEBUG >= 4: + print(f"cached_iids: {cached_iids}") + + return (past_iids, cached_iids) + + async def infer_prompt( + self, + request_id: Optional[str] = None, + shard: Optional[Shard] = None, + prompt: Optional[str] = "", + image_str: Optional[str] = None, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_prompt called") + print(f"prompt: {prompt}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + # setup prompt input + messages = [{"role": "user", "content": prompt}] + txt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + inputs = self.tokenizer([txt], return_tensors="pt") + input_ids = inputs.input_ids.to(self.device) + input_attention_mask = inputs.attention_mask.to(self.device) + batch_size, seq_length = input_ids.shape[:2] + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + if past_iids is not None: + self.past_input_ids = past_iids, + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"past_input_ids: {self.past_input_ids}\n") + + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=self.past_input_ids, + attention_mask=input_attention_mask + ) + + if DEBUG >= 4: + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + next_token = None + if shard_logits is not None: + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) + input_ids = next_token + + if self.past_input_ids is not None: + cached_iids = {"input_ids": self.past_input_ids.tolist()} + + is_finished = False + if next_token 
is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[str] = None + ) -> Tuple[np.ndarray, str, bool]: + if DEBUG >= 4: + print("infer_tensor called") + print(f"input_data: {input_data}") + print(f"shard: {shard}") + print(f"inference_state: {inference_state}") + + await self.ensure_shard(shard) + + input_ids = torch.tensor(input_data).to(self.device) + + # get cache from inference_state + past_iids, cached_iids = self.infer_caching(inference_state) + + # detect if hidden_states or not + hidden_states = None + self.past_input_ids = None + if input_ids.size()[-1] > 1: + hidden_states = input_ids + else: + if past_iids is not None: + self.past_input_ids = past_iids + else: + self.past_input_ids = input_ids + + if DEBUG >= 4: + print(f"input_ids: {input_ids}") + print(f"inference_state: {inference_state}") + + shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( + input_ids=self.past_input_ids, + hidden_states=hidden_states + ) + + hidden_dict = None + if shard_hidden_states is not None: + hidden_dict = {"hidden_states": shard_hidden_states.tolist()} + + next_token = None + if shard_logits is not None: + next_token = self.stateful_sharded_model.logits_sample(shard_logits) + input_ids = next_token + + #cache + if next_token is not None: + if self.past_input_ids is not None: + next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) + elif past_iids is not None: + next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) + + cached_iids = {"input_ids": next_cached_logits.tolist()} + + is_finished = False + if next_token is not None: + is_finished = next_token.item() == self.tokenizer.eos_token_id + + if is_finished: + # clear cache + cached_iids = {"input_ids": []} + + if DEBUG >= 4: + print(f"\ninput_ids: {input_ids}") + print(f"\nshard_hidden_states: {shard_hidden_states}\n") + print(f"\nshard_past_kvs {shard_past_kvs}\n") + print(f"\nshard_logits: {shard_logits}") + + return_values = ( + input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), + json.dumps({"cached_iids": cached_iids}), + is_finished + ) + + if DEBUG >= 4: + print(f"return_values: {return_values}") + + return return_values + + + async def ensure_shard(self, shard: Optional[Shard]): + """ + Ensure the model shard is loaded and ready for inference. - def __init__(self, shard_downloader: ShardDownloader): - """ - Initialize the inference engine. - - Args: - debug (bool): If True, enables debug logging. Defaults to False. 
- """ - self.shard = None - self.shard_downloader = shard_downloader - self.stateful_sharded_model = None - self.tokenizer = None - - # the whole history with new logits need to - # be passed to the model to reach the end token - # even with caching - self.past_input_ids = None - - # setup cuda device - if torch.cuda.is_available(): - self.device = torch.device("cuda") - self.torch_dtype = torch.float32 - elif torch.backends.mps.is_available(): - self.device = torch.device("mps") - self.torch_dtype = torch.float32 - else: - self.device = torch.device("cpu") - self.torch_dtype = torch.float16 - - # setup unfinished sequence - self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device) - - def infer_caching( - self, - inference_state: Optional[str] = None - ) -> Tuple[Optional[torch.tensor], Optional[dict]]: - """ - inference caching from inference_state json - """ - # setup cache and cached input_ids - past_iids = None - cached_iids = None - if inference_state is not None: - try: - infer_state = json.loads(inference_state) - except ValueError: - infer_state = None - - if infer_state is not None: - cached_iids = infer_state["cached_iids"] - if cached_iids is not None: - past_iids = None - if len(cached_iids) > 0: - past_iids = torch.tensor(cached_iids["input_ids"]).to(self.device) - cached_iids = {"input_ids": past_iids.tolist()} - - if DEBUG >= 4: - print(f"cached_iids: {cached_iids}") - - return (past_iids, cached_iids) - - - async def infer_prompt( - self, - request_id: str, - shard: Optional[Shard] = None, - prompt: str = "", - image_str: Optional[str] = None, - inference_state: Optional[str] = None - ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: - print("infer_prompt called") - print(f"prompt: {prompt}") - print(f"shard: {shard}") - print(f"inference_state: {inference_state}") - - await self.ensure_shard(shard) - - # setup prompt input - messages = [{"role": "user", "content": prompt}] - txt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - - inputs = self.tokenizer([txt], return_tensors="pt") - input_ids = inputs.input_ids.to(self.device) - input_attention_mask = inputs.attention_mask.to(self.device) - batch_size, seq_length = input_ids.shape[:2] - - # get cache from inference_state - past_iids, cached_iids = self.infer_caching(inference_state) - - if past_iids is not None: - self.past_input_ids = past_iids, - else: - self.past_input_ids = input_ids - - if DEBUG >= 4: - print(f"past_input_ids: {self.past_input_ids}\n") - - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - input_ids=self.past_input_ids, - attention_mask=input_attention_mask - ) - - if DEBUG >= 4: - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - next_token = None - if shard_logits is not None: - next_token = self.stateful_sharded_model.logits_sample(shard_logits) - self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1) - input_ids = next_token - - if self.past_input_ids is not None: - cached_iids = {"input_ids": self.past_input_ids.tolist()} - - is_finished = False - if next_token is not None: - is_finished = next_token.item() == self.tokenizer.eos_token_id - - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - 
return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), - is_finished - ) - - if DEBUG >= 4: - print(f"return_values: {return_values}") - - return return_values - - async def infer_tensor( - self, - request_id: str, - shard: Shard, - input_data: np.ndarray, - inference_state: Optional[str] = None - ) -> Tuple[np.ndarray, str, bool]: - if DEBUG >= 4: - print("infer_tensor called") - print(f"input_data: {input_data}") - print(f"shard: {shard}") - print(f"inference_state: {inference_state}") - - await self.ensure_shard(shard) - - input_ids = torch.tensor(input_data).to(self.device) - - # get cache from inference_state - past_iids, cached_iids = self.infer_caching(inference_state) - - # detect if hidden_states or not - hidden_states = None - self.past_input_ids = None - if input_ids.size()[-1] > 1: - hidden_states = input_ids - else: - if past_iids is not None: - self.past_input_ids = past_iids - else: - self.past_input_ids = input_ids - - if DEBUG >= 4: - print(f"input_ids: {input_ids}") - print(f"inference_state: {inference_state}") - - shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward( - input_ids=self.past_input_ids, - hidden_states=hidden_states - ) - - hidden_dict = None - if shard_hidden_states is not None: - hidden_dict = {"hidden_states": shard_hidden_states.tolist()} - - next_token = None - if shard_logits is not None: - next_token = self.stateful_sharded_model.logits_sample(shard_logits) - input_ids = next_token - - #cache - if next_token is not None: - if self.past_input_ids is not None: - next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device) - elif past_iids is not None: - next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device) - - cached_iids = {"input_ids": next_cached_logits.tolist()} - - is_finished = False - if next_token is not None: - is_finished = next_token.item() == self.tokenizer.eos_token_id - - if is_finished: - # clear cache - cached_iids = {"input_ids": []} - - if DEBUG >= 4: - print(f"\ninput_ids: {input_ids}") - print(f"\nshard_hidden_states: {shard_hidden_states}\n") - print(f"\nshard_past_kvs {shard_past_kvs}\n") - print(f"\nshard_logits: {shard_logits}") - - return_values = ( - input_ids.numpy(force=True) if shard_logits is not None else shard_hidden_states.numpy(force=True), - json.dumps({"cached_iids": cached_iids}), - is_finished - ) - - if DEBUG >= 4: - print(f"return_values: {return_values}") - - return return_values - - - async def ensure_shard(self, shard: Optional[Shard]): - """ - Ensure the model shard is loaded and ready for inference. - - Args: - shard (Optional[Shard]): Shard information for the model. - """ - if self.shard == shard: - return - - if DEBUG >= 4: - print(f"Loading new shard: {shard}") - - # -- TO DO -- - # Build in shard downloader but requires pulling - # apart how TrainedModel loads weight in its __init__ - # function in the transformer library - # model_path = await self.shard_downloader.ensure_shard(shard) - - self.tokenizer = await resolve_tokenizer(shard.model_id) - self.stateful_sharded_model = ShardedHuggingFaceModel( - shard=shard, - device=self.device, - dtype=self.torch_dtype, - top_k=TOP_K, - temp=TEMP, - top_p=TOP_P, - max_length=MAX_LENGTH, - max_time=MAX_TIME - ) - - self.shard = shard - - if DEBUG >= 4: - print(f"Shard loaded successfully: {shard}") + Args: + shard (Optional[Shard]): Shard information for the model. 
+ """ + if self.shard == shard: + return + + if DEBUG >= 4: + print(f"Loading new shard: {shard}") + + model_path = await self.shard_downloader.ensure_shard(shard) + if DEBUG >= 4: + print(f"model_path: {model_path}") + + self.tokenizer = await resolve_tokenizer(shard.model_id) + self.stateful_sharded_model = ShardedHuggingFaceModel( + shard=shard, + local_model_path=model_path, + device=self.device, + dtype=self.torch_dtype, + top_k=TOP_K, + temp=TEMP, + top_p=TOP_P, + max_length=MAX_LENGTH, + max_time=MAX_TIME + ) + + self.shard = shard + + if DEBUG >= 4: + print(f"Shard loaded successfully: {shard}") diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 1b617d7cd..38cd85c20 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -23,7 +23,8 @@ class ShardedHuggingFaceModel: def __init__( self, shard: Shard, - device, + local_model_path, + device, dtype, top_k: int = 25, temp: float = 0.7, @@ -31,19 +32,20 @@ def __init__( max_length: int = 50, max_time: float = 10.0 ): - # class vars + # class vars self.shard = shard - self.hidden_states = None + self.hidden_states = None self.input_ids = None self.inputs_embeds = None self.attention_mask = None - self.position_embeddings = None - self.past_key_values = None - self.cache_position = None - self.position_ids = None + self.position_embeddings = None + self.past_key_values = None + self.cache_position = None + self.position_ids = None self.causal_mask = None + self.local_model_path = local_model_path - # setup logit processors + # setup logit processors self.logits_processor = LogitsProcessorList([ TopKLogitsWarper(top_k), TemperatureLogitsWarper(temp), @@ -56,13 +58,13 @@ def __init__( # setup pytorch and transformer llm try: self.llm_model = AutoModelForCausalLM.from_pretrained( - shard.model_id, + pretrained_model_name_or_path=self.local_model_path, torch_dtype=self.torch_dtype, device_map="auto", offload_buffers=True ) - self.model = self.llm_model.model + self.model = self.llm_model.model except Exception as err: print(f"error loading and splitting model: {err}") raise @@ -70,7 +72,6 @@ def __init__( def forward( self, - shard: Optional[Shard] = None, input_ids: Optional[torch.tensor] = None, hidden_states: Optional[torch.tensor] = None, attention_mask: Optional[torch.tensor] = None, @@ -93,7 +94,7 @@ def forward( infer_tensor: bool optional, lets forward know to handle tensors Returns: - Tuple of + Tuple of - hidden_states: tensor optional - past_key_values: Cache or list[tensor] optional - logits: tensor Optional @@ -199,9 +200,8 @@ def forward( print(f"hidden_states: {self.hidden_states}") print(f"next_decoder_cache: {self.next_decoder_cache}") - # handle last layer to get logits - # shard is last layer says true at the start and not detecting last layer correctly + # shard is last layer says true at the start and not detecting last layer correctly if self.shard.is_last_layer(): self.hidden_states = self.model.norm(self.hidden_states) if use_legacy_cache: @@ -209,7 +209,7 @@ def forward( else: self.past_key_values = self.next_decoder_cache - # lm_head + # lm_head logits = self.llm_model.lm_head(self.hidden_states).to(self.device) if DEBUG >= 4: diff --git a/exo/inference/pytorch/tests/test_inference_engine.py b/exo/inference/pytorch/tests/test_inference_engine.py index 7e64c137a..854d9b9c9 100644 --- a/exo/inference/pytorch/tests/test_inference_engine.py +++ b/exo/inference/pytorch/tests/test_inference_engine.py @@ -11,164 +11,131 @@ import time async def 
test_inference_engine( - inference_engine_1: InferenceEngine, - inference_engine_2: InferenceEngine, - model_id: str, - n_layers: int): - - # prompt = "Why is the sky blue?" - prompt = "In a single word only, what is the last name of the current president of the USA?" - - shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=n_layers-1, - n_layers=n_layers - ) - - resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( - "A", - shard=shard, - prompt=prompt - ) - - print("\n------------resp_full---------------\n") - print(resp_full) - print("\n------------resp_full---------------\n") - - time.sleep(5) - - next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( - "A", - shard=shard, - input_data=resp_full, - inference_state=inference_state_full, - ) - - print("\n------------next_resp_full---------------\n") - print(next_resp_full) - print("\n------------next_resp_full---------------\n") - - time.sleep(5) - - pp = int(n_layers/2) - - resp_shard = Shard( - model_id=model_id, - start_layer=0, - end_layer=pp, - n_layers=n_layers - ) - - resp_shard2 = Shard( - model_id=model_id, - start_layer=pp + 1, - end_layer=n_layers-1, - n_layers=n_layers - ) - - resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( - "B", - shard=resp_shard, - prompt=prompt - ) - - print("\n------------resp1---------------\n") - print(resp1) - print("\n------------resp1---------------\n") - - time.sleep(5) - - - resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp1, - inference_state=inference_state_1, - ) - - print("\n------------resp2---------------\n") - print(resp2) - print("\n------------resp2---------------\n") - - resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( - "B", - shard=resp_shard, - input_data=resp2, - inference_state=inference_state_2, - ) - - print("\n------------resp3---------------\n") - print(resp3) - print("\n------------resp3---------------\n") - - resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( - "B", - shard=resp_shard2, - input_data=resp3, - inference_state=inference_state_3, - ) - - print("\n------------resp4---------------\n") - print(resp4) - print("\n------------resp4---------------\n") - - assert np.array_equal(resp_full, resp2) - assert np.array_equal(next_resp_full, resp4) + inference_engine_1: InferenceEngine, + inference_engine_2: InferenceEngine, + model_id: str, + n_layers: int): + + # prompt = "Why is the sky blue?" + prompt = "In a single word only, what is the last name of the current president of the USA?" 
+ + shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt( + "A", + shard=shard, + prompt=prompt + ) + + print("\n------------resp_full---------------\n") + print(resp_full) + print("\n------------resp_full---------------\n") + + time.sleep(5) + + next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor( + "A", + shard=shard, + input_data=resp_full, + inference_state=inference_state_full, + ) + + print("\n------------next_resp_full---------------\n") + print(next_resp_full) + print("\n------------next_resp_full---------------\n") + + time.sleep(5) + + pp = int(n_layers/2) + + resp_shard = Shard( + model_id=model_id, + start_layer=0, + end_layer=pp, + n_layers=n_layers + ) + + resp_shard2 = Shard( + model_id=model_id, + start_layer=pp + 1, + end_layer=n_layers-1, + n_layers=n_layers + ) + + resp1, inference_state_1, _ = await inference_engine_1.infer_prompt( + "B", + shard=resp_shard, + prompt=prompt + ) + + print("\n------------resp1---------------\n") + print(resp1) + print("\n------------resp1---------------\n") + + time.sleep(5) + + + resp2, inference_state_2, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp1, + inference_state=inference_state_1, + ) + + print("\n------------resp2---------------\n") + print(resp2) + print("\n------------resp2---------------\n") + + resp3, inference_state_3, _ = await inference_engine_1.infer_tensor( + "B", + shard=resp_shard, + input_data=resp2, + inference_state=inference_state_2, + ) + + print("\n------------resp3---------------\n") + print(resp3) + print("\n------------resp3---------------\n") + + resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor( + "B", + shard=resp_shard2, + input_data=resp3, + inference_state=inference_state_3, + ) + + print("\n------------resp4---------------\n") + print(resp4) + print("\n------------resp4---------------\n") + + assert np.array_equal(resp_full, resp2) + assert np.array_equal(next_resp_full, resp4) if __name__ == '__main__': - try: - print(f"\n\n -------- TEST QWEN2 -------- \n\n") - asyncio.run(test_inference_engine( - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - "Qwen/Qwen2-0.5B-Instruct", - 24 - )) - except Exception as err: - print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n -------- TEST LLAMA3-1B-Base -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "andrijdavid/Llama3-1B-Base", - # 3 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! LLAMA3-1B-Base TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n -------- TEST META LLAMA 3.1 8B -------- \n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "meta-llama/Meta-Llama-3.1-8B", - # 32 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! 
META LLAMA 3.1 8B TEST FAILED \n{err}\n") - - # try: - # print(f"\n\n ------- TEST Chickaboo/ChickaQ-Large -----\n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "Chickaboo/ChickaQ-Large", - # 24 - # )) - # except Exception as err: - # print(f"\n\n !!!!!!!!!!! Chickaboo/ChickaQ-Large TEST FAILED \n{err}\n") - - #try: - # print(f"\n\n --------- TEST TinyLlama/TinyLlama_v1.1 -------\n\n") - # asyncio.run(test_inference_engine( - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), - # "TinyLlama/TinyLlama_v1.1", - # 22 - # )) - #except Exception as err: - # print(f"\n\n !!!!!!!!!!! TinyLlama/TinyLlama_v1.1 TEST FAILED \n{err}\n") + # try: + # print("\n\n -------- TEST QWEN2 -------- \n\n") + # asyncio.run(test_inference_engine( + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + # "Qwen/Qwen2-0.5B-Instruct", + # 24 + # )) + # except Exception as err: + # print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n") + + try: + print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") + asyncio.run(test_inference_engine( + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + PyTorchDynamicShardInferenceEngine(HFShardDownloader()), + "meta-llama/Llama-3.2-1B-Instruct", + 24 + )) + except Exception as err: + print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n") diff --git a/exo/inference/pytorch/tests/test_split_model.py b/exo/inference/pytorch/tests/test_split_model.py index 827bdec2e..157a215d1 100644 --- a/exo/inference/pytorch/tests/test_split_model.py +++ b/exo/inference/pytorch/tests/test_split_model.py @@ -3,14 +3,11 @@ import asyncio import gc from transformers import ( - AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache, Cache, LogitsProcessorList, - #MinLengthLogitsProcessor, - LogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, @@ -286,8 +283,8 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): stopping_critera = StoppingCriteriaList( [ - MaxLengthCriteria(max_length=50), - MaxTimeCriteria(max_time=10.0), + MaxLengthCriteria(max_length=255), + MaxTimeCriteria(max_time=100.0), ] ) @@ -355,9 +352,21 @@ async def model_half_split_test(prompt: str, model_id: str, layers: int): # ) #) - print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") - model_id = "Qwen/Qwen2-0.5B-Instruct" - model_layers = 24 + #print("\n-------- Test Qwen/Qwen2-0.5B-Instruct ----------\n") + #model_id = "Qwen/Qwen2-0.5B-Instruct" + #model_layers = 24 + + #asyncio.run( + # model_half_split_test( + # prompt=prompt, + # model_id=model_id, + # layers=model_layers + # ) + #) + + print("\n-------- Test meta-llama/Llama-3.2-1B-Instruct ----------\n") + model_id = "meta-llama/Llama-3.2-1B-Instruct" + model_layers = 32 asyncio.run( model_half_split_test( diff --git a/exo/models.py b/exo/models.py index 67ea81c41..6f69960ea 100644 --- a/exo/models.py +++ b/exo/models.py @@ -36,8 +36,8 @@ "llama-3-1B-Base": { "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, - "TinyLlama-1.1B-Chat-yaw": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="ambrosfitz/TinyLlama-1.1B-Chat-yawp", start_layer=0, end_layer=0, n_layers=22), + "meta-llama/Llama-3.2-1B-Instruct": { + 
"PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24), }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, @@ -47,11 +47,6 @@ "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),}, ### llava "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),}, - ### qwen - "Qwen2-0.5B-Instruct": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), - }, - ### qwen "qwen-2.5-coder-1.5b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), @@ -74,4 +69,7 @@ "qwen-2.5-math-72b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80), }, + "Qwen2-0.5B-Instruct": { + "PyTorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24), + }, } diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index 9cad69d58..c00d2b0a6 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -38,6 +38,7 @@ + diff --git a/setup.py b/setup.py index 75d570e9f..8401167be 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,8 @@ "transformers==4.43.3", "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad", + "torch==2.4.0+cu124", + "accelerate=0.33.0" ] # Add macOS-specific packages if on Darwin (macOS) From 9fe3ec63dd26b78d9c27e3bcb17f72a79c7ee977 Mon Sep 17 00:00:00 2001 From: risingsunomi Date: Sun, 6 Oct 2024 11:51:33 -0800 Subject: [PATCH 2/2] cleaning up code, added pytorch engine to llama 3.2 1b model shard in models.py, removed old 3.2 1b model shard, moving to test server for more vram --- exo/inference/pytorch/inference.py | 21 ++++++++++++--------- exo/inference/pytorch/model/hf.py | 2 +- exo/models.py | 4 +--- exo/tinychat/index.html | 1 - setup.py | 2 +- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/exo/inference/pytorch/inference.py b/exo/inference/pytorch/inference.py index 2f87c1b10..8264aae83 100644 --- a/exo/inference/pytorch/inference.py +++ b/exo/inference/pytorch/inference.py @@ -12,6 +12,9 @@ from exo.helpers import DEBUG from exo.download.hf.hf_shard_download import HFShardDownloader +# llama +from transformers.models.llama.modeling_llama import LlamaModel + # model value options TOP_K = 20 TEMP = 0.6 @@ -52,7 +55,7 @@ def __init__(self, shard_downloader: HFShardDownloader): if torch.cuda.is_available(): self.device = torch.device("cuda") - self.torch_dtype = torch.float32 + self.torch_dtype = torch.float16 elif torch.backends.mps.is_available(): self.device = torch.device("mps") self.torch_dtype = torch.float32 @@ -105,10 +108,10 @@ async def infer_prompt( print(f"prompt: {prompt}") print(f"shard: {shard}") print(f"inference_state: {inference_state}") - + await self.ensure_shard(shard) - - # setup prompt input + + # setup prompt input messages = [{"role": "user", "content": prompt}] txt = self.tokenizer.apply_chat_template( messages, @@ -174,9 +177,9 @@ async def infer_prompt( async def infer_tensor( self, - request_id: str, - 
shard: Shard, - input_data: np.ndarray, + request_id: str, + shard: Shard, + input_data: np.ndarray, inference_state: Optional[str] = None ) -> Tuple[np.ndarray, str, bool]: if DEBUG >= 4: @@ -192,13 +195,13 @@ async def infer_tensor( # get cache from inference_state past_iids, cached_iids = self.infer_caching(inference_state) - # detect if hidden_states or not + # detect if hidden_states or not hidden_states = None self.past_input_ids = None if input_ids.size()[-1] > 1: hidden_states = input_ids else: - if past_iids is not None: + if past_iids is not None: self.past_input_ids = past_iids else: self.past_input_ids = input_ids diff --git a/exo/inference/pytorch/model/hf.py b/exo/inference/pytorch/model/hf.py index 38cd85c20..57a1590b0 100644 --- a/exo/inference/pytorch/model/hf.py +++ b/exo/inference/pytorch/model/hf.py @@ -16,7 +16,7 @@ TemperatureLogitsWarper ) -# llama +# llama from transformers.models.llama.modeling_llama import LlamaModel class ShardedHuggingFaceModel: diff --git a/exo/models.py b/exo/models.py index 6f69960ea..2f1e7d10a 100644 --- a/exo/models.py +++ b/exo/models.py @@ -4,6 +4,7 @@ ### llama "llama-3.2-1b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16), + "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=16), }, "llama-3.2-3b": { "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28), @@ -36,9 +37,6 @@ "llama-3-1B-Base": { "PyTorchDynamicShardInferenceEngine": Shard(model_id="andrijdavid/Llama3-1B-Base", start_layer=0, end_layer=0, n_layers=3), }, - "meta-llama/Llama-3.2-1B-Instruct": { - "PyTorchDynamicShardInferenceEngine": Shard(model_id="meta-llama/Llama-3.2-1B-Instruct", start_layer=0, end_layer=0, n_layers=24), - }, ### mistral "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),}, "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),}, diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index c00d2b0a6..9cad69d58 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -38,7 +38,6 @@ - diff --git a/setup.py b/setup.py index 8401167be..b23485a7f 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ "uuid==1.30", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad", "torch==2.4.0+cu124", - "accelerate=0.33.0" + "accelerate" ] # Add macOS-specific packages if on Darwin (macOS)
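
A condensed sketch of the device selection the new PyTorchDynamicShardInferenceEngine.__init__ appears to aim for, assuming the PYTORCH_DEVICE environment variable is the intended override and detection otherwise falls back through CUDA, MPS and CPU (per-device dtype choices are left to the engine itself):

import os
import torch

def resolve_device() -> torch.device:
  # explicit override first; unrecognized values fall through to detection
  requested = os.environ.get("PYTORCH_DEVICE", "")
  if requested in ("cuda", "mps", "cpu"):
    return torch.device(requested)
  if torch.cuda.is_available():
    return torch.device("cuda")
  if torch.backends.mps.is_available():
    return torch.device("mps")
  return torch.device("cpu")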
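
The inference_state string passed between infer_prompt and infer_tensor is plain JSON carrying the running token history. A minimal sketch of that round trip, assuming the {"cached_iids": {"input_ids": [...]}} layout used by infer_caching:

import json
from typing import Optional

import torch

def pack_state(past_input_ids: torch.Tensor) -> str:
  # keep the full id history so the next call (or next shard) can rebuild the context
  return json.dumps({"cached_iids": {"input_ids": past_input_ids.tolist()}})

def unpack_state(inference_state: Optional[str], device: torch.device) -> Optional[torch.Tensor]:
  if not inference_state:
    return None
  try:
    state = json.loads(inference_state)
  except ValueError:
    return None
  ids = (state.get("cached_iids") or {}).get("input_ids") or []
  return torch.tensor(ids).to(device) if ids else None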
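
ensure_shard now resolves a local snapshot directory through HFShardDownloader before constructing ShardedHuggingFaceModel. A rough sketch of that load path, assuming local_model_path is the downloaded snapshot directory; the engine itself resolves the tokenizer with resolve_tokenizer, and AutoTokenizer is used here only to keep the example self-contained:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_from_snapshot(local_model_path: str, dtype: torch.dtype = torch.float16):
  # same from_pretrained arguments ShardedHuggingFaceModel passes in hf.py
  model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=local_model_path,
    torch_dtype=dtype,
    device_map="auto",
    offload_buffers=True,
  )
  tokenizer = AutoTokenizer.from_pretrained(local_model_path)
  return model, tokenizer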
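
ShardedHuggingFaceModel wires TOP_K, TEMP and TOP_P into a LogitsProcessorList; logits_sample itself is outside this diff, but applying such a list typically looks like the sketch below (standard transformers usage, not the exact method body):

import torch
from transformers import (
  LogitsProcessorList,
  TemperatureLogitsWarper,
  TopKLogitsWarper,
  TopPLogitsWarper,
)

processors = LogitsProcessorList([
  TopKLogitsWarper(20),          # TOP_K
  TemperatureLogitsWarper(0.6),  # TEMP
  TopPLogitsWarper(0.9),         # TOP_P
])

def sample_next_token(input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
  # logits: [batch, seq_len, vocab]; warp only the final position's scores
  scores = processors(input_ids, logits[:, -1, :])
  probs = torch.nn.functional.softmax(scores, dim=-1)
  return torch.multinomial(probs, num_samples=1)  # shape [batch, 1]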
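
The two-engine test splits the layer range at the midpoint and runs each half on its own engine. In miniature, using the Llama-3.2-1B layer count from the models.py entry added in this patch (the Shard import path is assumed to be the one used elsewhere in exo):

from exo.inference.shard import Shard

model_id = "meta-llama/Llama-3.2-1B-Instruct"
n_layers = 16  # per the models.py shard definition in this patch
pp = n_layers // 2

first_half = Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers)
second_half = Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers)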