 # experimental, based off of tinygrad/inference.py
-
 import numpy as np
 import torch
 import numpy as np
 import json
-from typing import Optional, Callable, Tuple
+from typing import Optional, Tuple
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
 from exo.api.chatgpt_api import resolve_tokenizer
 from exo.helpers import DEBUG
 from transformers import DynamicCache
+from accelerate import disk_offload

 class PyTorchDynamicShardInferenceEngine(InferenceEngine):
     """
     PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models.
     """

-    def __init__(self):
+    def __init__(self, shard):
         """
         Initialize the inference engine.

         Args:
             debug (bool): If True, enables debug logging. Defaults to False.
         """
-        self.shard = None
+        self.shard = shard
         self.model = None
         self.tokenizer = None
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -37,41 +37,57 @@ async def infer_prompt(
         image_str: Optional[str] = None,
         inference_state: Optional[str] = None
     ) -> Tuple[np.ndarray, str, bool]:
-        if DEBUG >= 2:
-            print("infer_prompt called")
-
+
         await self.ensure_shard(shard)

         # need to make this so inference_state is not a string
         # cant use it with dynamic cache

-        tokens = self.tokenizer.encode(prompt, return_tensors="pt")
+        tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+        tokens = self.model.embed_tokens(tokens)
+        current_kvs = None

-        if DEBUG >= 2:
+        if DEBUG >= 4:
+            print("infer_prompt called")
             print(f"tokens: {tokens}\n")
-
-        output_data = self.model.forward_layers(
-            tokens
+            print(f"layer_count: {self.shard.get_layer_count()}")
+            print(f"is_first_layer: {self.shard.is_first_layer()}")
+            print(f"is_last_layer: {self.shard.is_last_layer()}")
+
+        # convert inference_state or cache from json to DynamicCache
+        past_kv = DynamicCache()
+        if inference_state != None:
+            cache_dict = json.loads(inference_state)
+            past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']]
+            past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']]
+
+        output_data, current_kvs = self.model.forward(
+            tokens,
+            past_kv
         )

         is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id]

-        if is_finished:
-            print(f"token from llm decode: {self.tokenizer.decode(output_data)}")
-
-
-        if DEBUG >= 2:
+        if DEBUG >= 4:
             print(f"output_data: {output_data}\n")
             print(f"output_data.size {output_data.size}\n")
-            print(f"output_data.item() {output_data.item()}")
+
             print(f"finished: {is_finished}")
             print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}")
             print(f"output_data[-1] {output_data[-1]}")
-            print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+            if output_data.size == 1:
+                print(f"size 1 output_data.item() {output_data.item()}")
+                print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+        cache_dict = {
+            'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache],
+            'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache]
+        }

         return (
             output_data,
-            "",
+            json.dumps(cache_dict),
             is_finished
         )

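Aside (not part of the diff): the inference_state handling added above is just a JSON round-trip of the DynamicCache contents, so the KV cache can travel between nodes as a plain string. A minimal standalone sketch of that round-trip; the helper names are illustrative, not from this patch:

import json

import torch
from transformers import DynamicCache

def cache_to_json(cache: DynamicCache) -> str:
    # Serialize each per-layer key/value tensor to nested Python lists.
    return json.dumps({
        'key_cache': [k.tolist() for k in cache.key_cache],
        'value_cache': [v.tolist() for v in cache.value_cache]
    })

def cache_from_json(state: str, device: torch.device) -> DynamicCache:
    # Rebuild the cache on the target device, mirroring the
    # `if inference_state != None` branch in infer_prompt above.
    cache_dict = json.loads(state)
    cache = DynamicCache()
    cache.key_cache = [torch.tensor(k).to(device) for k in cache_dict['key_cache']]
    cache.value_cache = [torch.tensor(v).to(device) for v in cache_dict['value_cache']]
    return cache

This ships full float tensors as JSON lists on every step, which is simple but heavy; it works because DynamicCache exposes its per-layer tensors as the key_cache and value_cache lists.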
@@ -80,39 +96,78 @@ async def infer_tensor(
         request_id: str,
         shard: Shard,
         input_data: np.ndarray,
-        inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+        inference_state: Optional[str] = None
+    ) -> Tuple[np.ndarray, str, bool]:

-        in_tensor = torch.tensor(input_data)
-
-        # Ensure input_data is 2D: [batch_size, seq_len]
-        if in_tensor.dim() == 1:
-            in_tensor = in_tensor.unsqueeze(0)  # Add a batch dimension: [1, seq_len]
+        await self.ensure_shard(shard)

-        if DEBUG >= 2:
-            print("infer_tensor called")
-            print(f"input_data: {input_data}\n")
-            print(f"in_tensor: {in_tensor}\n")
+        current_kvs = None

-        await self.ensure_shard(shard)
+        if input_data.size == 1:
+            in_tensor = torch.from_numpy(
+                input_data,
+            ).unsqueeze(0).long().to(self.device)
+        else:
+            in_tensor = torch.from_numpy(
+                input_data
+            ).long().to(self.device)
+
+        in_tensor = self.model.embed_tokens(in_tensor)

-        output_data = self.model.forward_layers(
-            in_tensor
+        if DEBUG >= 4:
+            print("infer_tensor called")
+            print(f"input_data: {input_data}")
+            print(f"input_data.size: {input_data.size}")
+            print(f"input_tensor: {in_tensor}\n")
+            print(f"shard: {self.shard}")
+            print(f"layer_count: {self.shard.get_layer_count()}")
+            print(f"is_first_layer: {self.shard.is_first_layer()}")
+            print(f"is_last_layer: {self.shard.is_last_layer()}")
+
+        # convert inference_state or cache from json to DynamicCache
+        past_kv = DynamicCache()
+        if inference_state != None:
+            try:
+                cache_dict = json.loads(inference_state)
+                past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']]
+                past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']]
+
+                if DEBUG >= 4:
+                    print("Loaded past_kv from JSON")
+                    print(f"past_kv: {past_kv}")
+                    print(f"past_kv.key_cache len: {len(past_kv.key_cache)}")
+                    print(f"past_kv.value_cache len: {len(past_kv.value_cache)}")
+            except json.JSONDecodeError:
+                print(f"ERROR DECODING INFERENCE STATE")
+
+        output_data, current_kvs = self.model.forward(
+            in_tensor,
+            past_kv
         )

         is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id]

-        if DEBUG >= 2:
+        if DEBUG >= 4:
+            print(f"in_tensor: {in_tensor}\n")
             print(f"output_data: {output_data}\n")
             print(f"output_data.size {output_data.size}\n")
-            print(f"output_data.item() {output_data.item()}")
             print(f"finished: {is_finished}")
             print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}")
             print(f"output_data[-1] {output_data[-1]}")
-            print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+            if output_data.size == 1:
+                print(f"size 1 output_data.item() {output_data.item()}")
+                print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+
+        cache_dict = {
+            'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache],
+            'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache]
+        }

         return (
             output_data,
-            "",
+            json.dumps(cache_dict),
             is_finished
         )

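Aside (not part of the diff): the input_data.size == 1 branch added above reshapes a single sampled token id into a batched [1, 1] LongTensor before it is embedded. A standalone sketch of just that shaping, with an illustrative token id:

import numpy as np
import torch

input_data = np.array([42])  # one token id handed over from the previous node (illustrative)
if input_data.size == 1:
    # [1] -> [1, 1]: add the batch dimension the embedding layer expects
    in_tensor = torch.from_numpy(input_data).unsqueeze(0).long()
else:
    in_tensor = torch.from_numpy(input_data).long()

print(in_tensor.shape)  # torch.Size([1, 1])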
@@ -126,12 +181,12 @@ async def ensure_shard(self, shard: Optional[Shard]):
         if self.shard == shard:
             return

-        if DEBUG >= 2:
+        if DEBUG >= 4:
             print(f"Loading new shard: {shard}")

-        self.model = ShardedHuggingFaceModel(shard)
-        self.tokenizer = await resolve_tokenizer(shard.model_id)
         self.shard = shard
+        self.tokenizer = await resolve_tokenizer(shard.model_id)
+        self.model = ShardedHuggingFaceModel(shard, self.tokenizer)

-        if DEBUG >= 2:
+        if DEBUG >= 4:
             print(f"Shard loaded successfully: {shard}")