# experimental, based off of tinygrad/inference.py
+import asyncio
import os
import re
import numpy as np
import torch
import json
+import functools
+from concurrent.futures import ThreadPoolExecutor

-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, List

from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
from exo.inference.tokenizers import resolve_tokenizer
from exo.helpers import DEBUG
from exo.download.hf.hf_shard_download import HFShardDownloader
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, Cache

# llama
from transformers.models.llama.modeling_llama import LlamaModel
@@ -39,8 +42,6 @@ def __init__(self, shard_downloader: HFShardDownloader):
    """
    self.shard = None
    self.shard_downloader = shard_downloader
-    self.stateful_sharded_model = None
-    self.tokenizer = None

    # the whole history with new logits needs to
    # be passed to the model to reach the end token
@@ -59,15 +60,15 @@ def __init__(self, shard_downloader: HFShardDownloader):
    if torch.cuda.is_available():
      self.device = torch.device("cuda")
      self.torch_dtype = torch.float32
-    elif torch.backends.mps.is_available():
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
      self.device = torch.device("mps")
      self.torch_dtype = torch.float32
    else:
      self.device = torch.device("cpu")
      self.torch_dtype = torch.float16

-    # setup unfinished sequence
-    self.unfinished_sequences = torch.ones(1, dtype=torch.long, device=self.device)
+    # setup threading
+    torch.set_num_threads(torch.get_num_threads())

  def infer_caching(
    self,
@@ -98,6 +99,44 @@ def infer_caching(

    return (past_iids, cached_iids)

+  async def async_forward(
+    self,
+    input_ids: Optional[torch.Tensor] = None,
+    hidden_states: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None
+  ) -> Tuple[Optional[torch.Tensor], Optional[Union[Cache, List[torch.FloatTensor]]], Optional[torch.Tensor]]:
+
+    loop = asyncio.get_running_loop()
+
+    forward_partial = functools.partial(
+      self.stateful_sharded_model.forward,
+      input_ids=input_ids,
+      hidden_states=hidden_states,
+      attention_mask=attention_mask
+    )
+
+    with ThreadPoolExecutor() as pool:
+      result = await loop.run_in_executor(pool, forward_partial)
+
+    return result
+
+  async def async_logit_sample(
+    self,
+    logits: torch.Tensor
+  ) -> torch.Tensor:
+
+    loop = asyncio.get_running_loop()
+
+    sample_partial = functools.partial(
+      self.stateful_sharded_model.logits_sample,
+      logits=logits
+    )
+
+    with ThreadPoolExecutor() as pool:
+      result = await loop.run_in_executor(pool, sample_partial)
+
+    return result
+
  async def infer_prompt(
    self,
    request_id: str,
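
The two helpers added above share one pattern: the blocking PyTorch call is wrapped with functools.partial and handed to loop.run_in_executor, so the asyncio event loop keeps serving other requests while the model runs in a worker thread. Below is a minimal, self-contained sketch of that pattern under illustrative names (blocking_forward and offload_to_thread are not part of exo). Note that creating a fresh ThreadPoolExecutor per call, as the diff does, works but is comparatively heavyweight; passing None as the executor would reuse the loop's default thread pool instead.

import asyncio
import functools
from concurrent.futures import ThreadPoolExecutor

import torch


def blocking_forward(input_ids: torch.Tensor) -> torch.Tensor:
  # stand-in for stateful_sharded_model.forward: any blocking, compute-bound call
  return input_ids.float().softmax(dim=-1)


async def offload_to_thread(input_ids: torch.Tensor) -> torch.Tensor:
  loop = asyncio.get_running_loop()
  fn = functools.partial(blocking_forward, input_ids=input_ids)
  with ThreadPoolExecutor() as pool:
    # run_in_executor returns an awaitable; the event loop stays free until it completes
    return await loop.run_in_executor(pool, fn)


if __name__ == "__main__":
  print(asyncio.run(offload_to_thread(torch.tensor([[1, 2, 3]]))))
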
@@ -129,7 +168,7 @@ async def infer_prompt(
    if DEBUG >= 4:
      print(f"past_input_ids: {self.past_input_ids}\n")

-    shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward(
+    shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward(
      input_ids=self.past_input_ids,
      attention_mask=input_attention_mask
    )
@@ -141,7 +180,7 @@ async def infer_prompt(

    next_token = None
    if shard_logits is not None:
-      next_token = self.stateful_sharded_model.logits_sample(shard_logits)
+      next_token = await self.async_logit_sample(shard_logits)
      self.past_input_ids = torch.cat([input_ids, next_token[:, None].squeeze(-1)], dim=-1)
      input_ids = next_token

@@ -206,24 +245,27 @@ async def infer_tensor(
      print(f"hidden_state: {hidden_states}")
      print(f"inference_state: {inference_state}")

-    shard_hidden_states, shard_past_kvs, shard_logits = self.stateful_sharded_model.forward(
+    shard_hidden_states, shard_past_kvs, shard_logits = await self.async_forward(
      input_ids=self.past_input_ids,
      hidden_states=hidden_states
    )

    next_token = None
    if shard_logits is not None:
-      next_token = self.stateful_sharded_model.logits_sample(shard_logits)
+      next_token = await self.async_logit_sample(shard_logits)
      input_ids = next_token

    #cache
+    next_cached_logits = None
    if next_token is not None:
      if self.past_input_ids is not None:
        next_cached_logits = torch.cat([self.past_input_ids, next_token], dim=-1).to(self.device)
      elif past_iids is not None:
        next_cached_logits = torch.cat([past_iids, next_token], dim=-1).to(self.device)

-      cached_iids = {"input_ids": next_cached_logits.tolist()}
+      cached_iids = {
+        "input_ids": next_cached_logits.tolist() if next_cached_logits is not None else []
+      }

    is_finished = False
    if next_token is not None:
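
The last hunk's guard only matters when neither branch of the inner if/elif runs: without the next_cached_logits = None default, building cached_iids could reference an unbound name. A reduced sketch of that guarded cache update follows, with illustrative tensors standing in for the engine's real past_input_ids/past_iids state (build_cached_iids is a hypothetical helper, not exo code).

from typing import Optional

import torch


def build_cached_iids(
  past_input_ids: Optional[torch.Tensor],
  past_iids: Optional[torch.Tensor],
  next_token: Optional[torch.Tensor],
) -> dict:
  # default keeps the dict construction safe when no history exists yet
  next_cached_logits = None
  if next_token is not None:
    if past_input_ids is not None:
      next_cached_logits = torch.cat([past_input_ids, next_token], dim=-1)
    elif past_iids is not None:
      next_cached_logits = torch.cat([past_iids, next_token], dim=-1)
  return {"input_ids": next_cached_logits.tolist() if next_cached_logits is not None else []}


# build_cached_iids(None, None, torch.tensor([[42]]))                      -> {"input_ids": []}
# build_cached_iids(torch.tensor([[1, 2]]), None, torch.tensor([[3]]))     -> {"input_ids": [[1, 2, 3]]}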