Merge pull request #20 from risingsunomi/pr139-dev-oct24
adding threadpooling to forward and logit sampling
risingsunomi authored Oct 9, 2024
2 parents 296dff6 + 9d24779 commit fe6ae45
Showing 1 changed file with 54 additions and 12 deletions.
66 changes: 54 additions & 12 deletions exo/inference/pytorch/inference.py
@@ -1,19 +1,22 @@
# experimental, based off of tinygrad/inference.py
import asyncio
import os
import re
import numpy as np
import torch
import json
import functools
from concurrent.futures import ThreadPoolExecutor

from typing import Optional, Tuple
from typing import Optional, Tuple, Union, List
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
from exo.inference.tokenizers import resolve_tokenizer
from exo.helpers import DEBUG
from exo.download.hf.hf_shard_download import HFShardDownloader

from transformers import AutoTokenizer
from transformers import AutoTokenizer, Cache

# llama
from transformers.models.llama.modeling_llama import LlamaModel
@@ -39,8 +42,6 @@ def __init__(self, shard_downloader: HFShardDownloader):
"""
self.shard = None
self.shard_downloader = shard_downloader
self.stateful_sharded_model = None
self.tokenizer = None

# the whole history with new logits need to
# be passed to the model to reach the end token
@@ -59,15 +60,15 @@ def __init__(self, shard_downloader: HFShardDownloader):
if torch.cuda.is_available():
self.device = torch.device("cuda")
self.torch_dtype = torch.float32
elif torch.backends.mps.is_available():
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
self.device = torch.device("mps")
self.torch_dtype = torch.float32
else:
self.device = torch.device("cpu")
self.torch_dtype = torch.float16

# setup unfinished sequence
self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device)
# setup threading
torch.set_num_threads(torch.get_num_threads())
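The MPS branch now requires both `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()`, so the Metal backend is only selected when it is compiled into the installed torch build and usable at runtime. A standalone sketch of the selection order (the helper name is illustrative; the dtype choices simply mirror the diff):

```python
import torch

def pick_device_and_dtype():
    """Mirror the device/dtype selection used in __init__ above."""
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float32
    # MPS must be both compiled into this build and available at runtime.
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps"), torch.float32
    return torch.device("cpu"), torch.float16

device, dtype = pick_device_and_dtype()
print(device, dtype)
```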

def infer_caching(
self,
@@ -98,6 +99,44 @@ def infer_caching(

return (past_iids, cached_iids)

async def async_forward(
self,
input_ids: Optional[torch.Tensor] = None,
hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None
) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]:

loop = asyncio.get_running_loop()

forward_partial = functools.partial(
self.stateful_sharded_model.forward,
input_ids=input_ids,
hidden_states=hidden_states,
attention_mask=attention_mask
)

with ThreadPoolExecutor() as pool:
result = await loop.run_in_executor(pool, forward_partial)

return result
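`async_forward` hands the blocking `stateful_sharded_model.forward` call to a worker thread via `loop.run_in_executor`, so the asyncio event loop stays responsive while a shard computes. A minimal, self-contained sketch of the same `functools.partial` + executor pattern (`blocking_work` and `run_off_loop` are illustrative names, not part of the diff):

```python
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

def blocking_work(x, scale=2):
    # Stand-in for a CPU/GPU-bound call such as a model forward pass.
    return x * scale

async def run_off_loop(x):
    loop = asyncio.get_running_loop()
    # partial binds the arguments so the executor only needs a
    # zero-argument callable, exactly as async_forward does above.
    call = functools.partial(blocking_work, x, scale=3)
    with ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, call)

print(asyncio.run(run_off_loop(7)))  # 21
```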

async def async_logit_sample(
self,
logits: torch.Tensor
) -> torch.Tensor:

loop = asyncio.get_running_loop()

sample_partial = functools.partial(
self.stateful_sharded_model.logits_sample,
logits=logits
)

with ThreadPoolExecutor() as pool:
result = await loop.run_in_executor(pool, sample_partial)

return result
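Further down, `infer_prompt` and `infer_tensor` await these two wrappers in sequence. A self-contained usage sketch of that flow with a stand-in model (`DummyShardedModel` and `run_blocking` are placeholders, not exo APIs):

```python
import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

import torch

class DummyShardedModel:
    """Placeholder: forward returns (hidden_states, past_kvs, logits)."""
    def forward(self, input_ids=None, hidden_states=None, attention_mask=None):
        # Pretend this is the last shard, so logits (over a 10-token vocab) are produced.
        return None, None, torch.randn(1, 10)

    def logits_sample(self, logits):
        return torch.argmax(logits, dim=-1, keepdim=True)

async def run_blocking(fn, **kwargs):
    # Same thread-offloading pattern as async_forward / async_logit_sample.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, functools.partial(fn, **kwargs))

async def main():
    model = DummyShardedModel()
    _, _, logits = await run_blocking(model.forward, input_ids=torch.tensor([[1, 2, 3]]))
    if logits is not None:
        next_token = await run_blocking(model.logits_sample, logits=logits)
        print(next_token.shape)  # torch.Size([1, 1])

asyncio.run(main())
```

Note that `loop.run_in_executor(None, call)` would reuse the event loop's default thread pool; the diff instead constructs a fresh `ThreadPoolExecutor` for every call.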

async def infer_prompt(
self,
request_id: str,
@@ -129,7 +168,7 @@ async def infer_prompt(
if DEBUG >= 4:
print(f"past_input_ids: {self.past_input_ids}\n")

shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward(
shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward(
input_ids=self.past_input_ids,
attention_mask=input_attention_mask
)
@@ -141,7 +180,7 @@

next_token = None
if shard_logits is not None:
next_token = self.stateful_sharded_model.logits_sample(shard_logits)
next_token = await self.async_logit_sample(shard_logits)
self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1)
input_ids = next_token

@@ -206,24 +245,27 @@ async def infer_tensor(
print(f"hidden_state: {hidden_states}")
print(f"inference_state: {inference_state}")

shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward(
shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward(
input_ids=self.past_input_ids,
hidden_states=hidden_states
)

next_token = None
if shard_logits is not None:
next_token = self.stateful_sharded_model.logits_sample(shard_logits)
next_token = await self.async_logit_sample(shard_logits)
input_ids = next_token

#cache
next_cached_logits = None
if next_token is not None:
if self.past_input_ids is not None:
next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device)
elif past_iids is not None:
next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device)

cached_iids = {"input_ids": next_cached_logits.tolist()}
cached_iids = {
"input_ids": next_cached_logits.tolist() if next_cached_logits is not None else []
}
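The rewritten `cached_iids` assignment guards against `next_cached_logits` being `None`, which happens whenever no `next_token` was sampled (this shard returned no logits); the previous unconditional `next_cached_logits.tolist()` would have raised `AttributeError` in that case. A minimal sketch of the guard with placeholder tensors (`build_cached_iids` is an illustrative helper, not part of the diff):

```python
import torch

def build_cached_iids(past_input_ids, next_token):
    next_cached_logits = None
    if next_token is not None and past_input_ids is not None:
        next_cached_logits = torch.cat([past_input_ids, next_token], dim=-1)
    return {
        "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else []
    }

# A shard that sampled a token caches the extended sequence ...
print(build_cached_iids(torch.tensor([[1, 2]]), torch.tensor([[3]])))  # {'input_ids': [[1, 2, 3]]}
# ... while a shard with no sampled token caches an empty list instead of raising.
print(build_cached_iids(torch.tensor([[1, 2]]), None))                 # {'input_ids': []}
```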

is_finished = False
if next_token is not None: