Commit fe6ae45

Merge pull request #20 from risingsunomi/pr139-dev-oct24
adding threadpooling to forward and logit sampling
2 parents 296dff6 + 9d24779 commit fe6ae45

File tree

1 file changed: +54 -12 lines changed

exo/inference/pytorch/inference.py

+54 -12
@@ -1,19 +1,22 @@
 # experimental, based off of tinygrad/inference.py
+import asyncio
 import os
 import re
 import numpy as np
 import torch
 import json
+import functools
+from concurrent.futures import ThreadPoolExecutor

-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, List
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.helpers import DEBUG
 from exo.download.hf.hf_shard_download import HFShardDownloader

-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, Cache

 # llama
 from transformers.models.llama.modeling_llama import LlamaModel
@@ -39,8 +42,6 @@ def __init__(self, shard_downloader: HFShardDownloader):
     """
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.stateful_sharded_model = None
-    self.tokenizer = None

     # the whole history with new logits need to
     # be passed to the model to reach the end token
@@ -59,15 +60,15 @@ def __init__(self, shard_downloader: HFShardDownloader):
     if torch.cuda.is_available():
       self.device = torch.device("cuda")
       self.torch_dtype = torch.float32
-    elif torch.backends.mps.is_available():
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
       self.device = torch.device("mps")
       self.torch_dtype = torch.float32
     else:
       self.device = torch.device("cpu")
       self.torch_dtype = torch.float16

-    # setup unfinished sequence
-    self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device)
+    # setup threadding
+    torch.set_num_threads(torch.get_num_threads())

   def infer_caching(
     self,
@@ -98,6 +99,44 @@ def infer_caching(

     return (past_iids, cached_iids)

+  async def async_forward(
+    self,
+    input_ids: Optional[torch.Tensor] = None,
+    hidden_states: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None
+  ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]:
+
+    loop = asyncio.get_running_loop()
+
+    forward_partial = functools.partial(
+      self.stateful_sharded_model.forward,
+      input_ids=input_ids,
+      hidden_states=hidden_states,
+      attention_mask=attention_mask
+    )
+
+    with ThreadPoolExecutor() as pool:
+      result = await loop.run_in_executor(pool, forward_partial)
+
+    return result
+
+  async def async_logit_sample(
+    self,
+    logits: torch.Tensor
+  ) -> torch.Tensor:
+
+    loop = asyncio.get_running_loop()
+
+    sample_partial = functools.partial(
+      self.stateful_sharded_model.logits_sample,
+      logits=logits
+    )
+
+    with ThreadPoolExecutor() as pool:
+      result = await loop.run_in_executor(pool, sample_partial)
+
+    return result
+
   async def infer_prompt(
     self,
     request_id: str,
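The two new helpers above use the standard asyncio recipe for running a blocking call off the event loop: bind the call's arguments with functools.partial, then await it via loop.run_in_executor so the coroutine yields while PyTorch does the heavy work in a worker thread. Below is a minimal, self-contained sketch of that pattern; blocking_forward and its tensors are illustrative stand-ins, not part of exo.

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

import torch


def blocking_forward(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
  # Stand-in for a blocking, compute-bound call such as stateful_sharded_model.forward(...).
  return (input_ids * attention_mask).float().sum(dim=-1)


async def async_forward(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
  loop = asyncio.get_running_loop()
  # Bind keyword arguments up front so the executor receives a zero-argument callable.
  forward_partial = functools.partial(blocking_forward, input_ids=input_ids, attention_mask=attention_mask)
  # Run the blocking call on a worker thread; the coroutine suspends until it finishes.
  with ThreadPoolExecutor() as pool:
    return await loop.run_in_executor(pool, forward_partial)


if __name__ == "__main__":
  ids = torch.ones(1, 4, dtype=torch.long)
  mask = torch.ones(1, 4, dtype=torch.long)
  print(asyncio.run(async_forward(ids, mask)))  # tensor([4.])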
@@ -129,7 +168,7 @@ async def infer_prompt(
     if DEBUG >= 4:
       print(f"past_input_ids: {self.past_input_ids}\n")

-    shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward(
+    shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward(
       input_ids=self.past_input_ids,
       attention_mask=input_attention_mask
     )
@@ -141,7 +180,7 @@ async def infer_prompt(

     next_token = None
     if shard_logits is not None:
-      next_token = self.stateful_sharded_model.logits_sample(shard_logits)
+      next_token = await self.async_logit_sample(shard_logits)
       self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1)
       input_ids = next_token

@@ -206,24 +245,27 @@ async def infer_tensor(
       print(f"hidden_state: {hidden_states}")
       print(f"inference_state: {inference_state}")

-    shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward(
+    shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward(
       input_ids=self.past_input_ids,
       hidden_states=hidden_states
     )

     next_token = None
     if shard_logits is not None:
-      next_token = self.stateful_sharded_model.logits_sample(shard_logits)
+      next_token = await self.async_logit_sample(shard_logits)
       input_ids = next_token

     #cache
+    next_cached_logits = None
     if next_token is not None:
       if self.past_input_ids is not None:
         next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device)
       elif past_iids is not None:
         next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device)

-    cached_iids = {"input_ids": next_cached_logits.tolist()}
+    cached_iids = {
+      "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else []
+    }

     is_finished = False
     if next_token is not None:
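One design note on the pattern in this commit: a fresh ThreadPoolExecutor is created and torn down on every async_forward / async_logit_sample call. A common alternative, sketched below purely as an assumption (PyTorchShardedEngine and run_blocking are hypothetical names, not exo APIs), is to keep one long-lived executor on the engine and reuse it for every forward and sampling call; loop.run_in_executor also accepts None as the executor to fall back to asyncio's default pool.

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


class PyTorchShardedEngine:
  # Hypothetical skeleton: one shared worker pool instead of a new pool per call.
  def __init__(self, max_workers: int = 1):
    self._pool = ThreadPoolExecutor(max_workers=max_workers)

  async def run_blocking(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    loop = asyncio.get_running_loop()
    # Reuse the shared pool; passing None instead of self._pool would use the loop's default executor.
    return await loop.run_in_executor(self._pool, functools.partial(fn, *args, **kwargs))

  def close(self) -> None:
    self._pool.shutdown(wait=True)


async def _demo() -> None:
  engine = PyTorchShardedEngine()
  try:
    print(await engine.run_blocking(sum, [1, 2, 3]))  # 6
  finally:
    engine.close()


if __name__ == "__main__":
  asyncio.run(_demo())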
