
Commit 46667b6

Merge pull request #8 from risingsunomi/pr139-dev
Pr139 dev
2 parents ea41845 + ed5bea7 commit 46667b6

18 files changed: +509 -207 lines changed

.gitignore (+3)

@@ -170,3 +170,6 @@ cython_debug/
 #.idea/

 **/*.xcodeproj/*
+
+# PyTorch interface
+.offload

exo/api/chatgpt_api.py (+20 -1)

@@ -113,8 +113,27 @@ def remap_messages(messages: List[Message]) -> List[Message]:


 def build_prompt(tokenizer, _messages: List[Message]):
+  if len(_messages) == 1:
+    user_msg = _messages[0]
+
+    # get instruct sys message
+    sys_msg = Message(role="system", content="You are a helpful assistant.")
+
+    # restructure for sys_msg to go first
+    _messages = [sys_msg, user_msg]
+
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+  prompt = tokenizer.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True
+  )
+
+  if DEBUG >= 3:
+    print(f"prompt: {str(prompt)}")
+    for msg in messages:
+      print(f"chat role: {msg.role}\ncontent: {msg.content}")
+
   image_str = None
   for message in messages:
     if not isinstance(message.content, list):
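
For context, a standalone sketch of what the new single-message path in build_prompt amounts to; the model id and dict-style messages are illustrative stand-ins for exo's Message objects:

```python
# Sketch of the added behavior: when only one (user) message is passed,
# a default system message is prepended before the chat template is applied.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # example model
messages = [
  {"role": "system", "content": "You are a helpful assistant."},
  {"role": "user", "content": "What is the capital of France?"},
]
prompt = tokenizer.apply_chat_template(
  messages,
  tokenize=False,
  add_generation_prompt=True
)
print(prompt)  # system + user turns followed by the assistant generation prompt
```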

exo/download/__init__.py

Whitespace-only changes.

exo/download/hf/__init__.py

Whitespace-only changes.

exo/inference/inference_engine.py (+4 -1)

@@ -8,7 +8,7 @@

 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass

   @abstractmethod
@@ -27,5 +27,8 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

     return TinygradDynamicShardInferenceEngine(shard_downloader)
+  elif inference_engine_name == "pytorch":
+    from exo.inference.pytorch.inference import PyTorchDynamicShardInferenceEngine
+    return PyTorchDynamicShardInferenceEngine(shard_downloader)
   else:
     raise ValueError(f"Inference engine {inference_engine_name} not supported")
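
A brief usage sketch of the factory after this change; the shard_downloader argument is assumed to be whatever ShardDownloader instance the node already constructs, and is not defined here:

```python
# Selecting the new engine by name; any unrecognized name still raises ValueError.
from exo.inference.inference_engine import get_inference_engine

engine = get_inference_engine("pytorch", shard_downloader)  # shard_downloader: assumed pre-existing
print(type(engine).__name__)  # PyTorchDynamicShardInferenceEngine
```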

exo/inference/pytorch/README.md (new file, +26)

# PyTorch & HuggingFace inference engine

Experimental, still under development.

## Install

Install the needed Python modules; make sure to use CUDA 12.4 for the PyTorch install:

```console
$ pip install torch --index-url https://download.pytorch.org/whl/cu124
$ pip install transformers accelerate
```

After installing accelerate you will hit a dependency error; ignore it for now until we can fix it, as exo works fine with numpy 1.26.4:

```console
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
exo 0.0.1 requires numpy==2.0.0, but you have numpy 1.26.4 which is incompatible.
```
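
To confirm which numpy you ended up with after the accelerate install (an extra sanity check, not one of the original steps):

```console
$ python -c "import numpy; print(numpy.__version__)"
1.26.4
```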

## Low VRAM Notes

- When trying to use disk_offload we get the error "Cannot copy out of meta tensor; no data!"; looking up the error, it is tied to [low VRAM](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13087#issuecomment-2080272004). A rough sketch of the offload call is shown below.
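
For context, this is roughly how such an offload is attempted with accelerate (a sketch only: the model id and offload directory are illustrative, and the error above typically means the weights are still on the meta device when the copy starts):

```python
# Rough sketch of offloading a HF model's weights to disk with accelerate.
# Assumes the model is fully materialized (not on the meta device); otherwise
# accelerate raises "Cannot copy out of meta tensor; no data!".
import torch
from accelerate import disk_offload
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in model id
model = disk_offload(
  model,
  offload_dir=".offload",  # matches the directory added to .gitignore in this commit
  execution_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
```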

## Multiple GPUs in 1 Machine Notes

### Running multiple GPUs on 1 machine

- Getting the error "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)"; a minimal reproduction and fix are sketched below.
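
A minimal illustration of what triggers that message and the usual fix, which is to move everything onto one device before concatenating (toy tensors, not exo code):

```python
# torch.cat requires all of its inputs on one device; tensors produced on
# different GPUs must be moved to a common device before concatenation.
import torch

a = torch.randn(2, 4, device="cuda:0")
b = torch.randn(2, 4, device="cuda:1")

# torch.cat([a, b], dim=0) raises the "same device" RuntimeError above.
merged = torch.cat([a, b.to("cuda:0")], dim=0)  # move b to cuda:0 first
print(merged.device)  # cuda:0
```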

exo/inference/pytorch/__init__.py

Whitespace-only changes.

exo/inference/pytorch/inference.py (+96 -41)

@@ -1,30 +1,30 @@
 # experimental, based off of tinygrad/inference.py
-
 import numpy as np
 import torch
 import numpy as np
 import json
-from typing import Optional, Callable, Tuple
+from typing import Optional, Tuple
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.pytorch.model.hf import ShardedHuggingFaceModel
 from exo.api.chatgpt_api import resolve_tokenizer
 from exo.helpers import DEBUG
 from transformers import DynamicCache
+from accelerate import disk_offload

 class PyTorchDynamicShardInferenceEngine(InferenceEngine):
   """
   PyTorch Dynamic Shard Inference Engine for performing model inference with sharded models.
   """

-  def __init__(self):
+  def __init__(self, shard):
     """
     Initialize the inference engine.

     Args:
       debug (bool): If True, enables debug logging. Defaults to False.
     """
-    self.shard = None
+    self.shard = shard
     self.model = None
     self.tokenizer = None
     self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -37,41 +37,57 @@ async def infer_prompt(
     image_str: Optional[str] = None,
     inference_state: Optional[str] = None
   ) -> Tuple[np.ndarray, str, bool]:
-    if DEBUG >= 2:
-      print("infer_prompt called")
-
+
     await self.ensure_shard(shard)

     # need to make this so inference_state is not a string
     # cant use it with dynamic cache

-    tokens = self.tokenizer.encode(prompt, return_tensors="pt")
+    tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+    tokens = self.model.embed_tokens(tokens)
+    current_kvs = None

-    if DEBUG >= 2:
+    if DEBUG >= 4:
+      print("infer_prompt called")
       print(f"tokens: {tokens}\n")
-
-    output_data = self.model.forward_layers(
-      tokens
+      print(f"layer_count: {self.shard.get_layer_count()}")
+      print(f"is_first_layer: {self.shard.is_first_layer()}")
+      print(f"is_last_layer: {self.shard.is_last_layer()}")
+
+    # convert inference_state or cache from json to DynamicCache
+    past_kv = DynamicCache()
+    if inference_state != None:
+      cache_dict = json.loads(inference_state)
+      past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']]
+      past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']]
+
+    output_data, current_kvs = self.model.forward(
+      tokens,
+      past_kv
     )

     is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id]

-    if is_finished:
-      print(f"token from llm decode: {self.tokenizer.decode(output_data)}")
-
-
-    if DEBUG >= 2:
+    if DEBUG >= 4:
       print(f"output_data: {output_data}\n")
       print(f"output_data.size {output_data.size}\n")
-      print(f"output_data.item() {output_data.item()}")
+
       print(f"finished: {is_finished}")
       print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}")
       print(f"output_data[-1] {output_data[-1]}")
-      print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+      if output_data.size == 1:
+        print(f"size 1 output_data.item() {output_data.item()}")
+        print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+    cache_dict = {
+      'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache],
+      'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache]
+    }

     return (
       output_data,
-      "",
+      json.dumps(cache_dict),
       is_finished
     )
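
Since the hunk above ships the KV cache between nodes as the inference_state string, here is a standalone sketch of that JSON round trip (toy shapes; it assumes the installed transformers exposes DynamicCache.key_cache / value_cache as lists of tensors, which the code above already relies on):

```python
# Toy round trip of a DynamicCache through JSON, mirroring how infer_prompt
# serializes current_kvs and rebuilds past_kv from inference_state.
import json
import torch
from transformers import DynamicCache

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One-layer toy cache with shape [batch, heads, seq_len, head_dim].
cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)

# Serialize: tensors -> nested lists -> JSON string (the inference_state payload).
inference_state = json.dumps({
  "key_cache": [t.tolist() for t in cache.key_cache],
  "value_cache": [t.tolist() for t in cache.value_cache],
})

# Deserialize on the receiving node: JSON -> tensors on the local device.
loaded = json.loads(inference_state)
past_kv = DynamicCache()
past_kv.key_cache = [torch.tensor(x).to(device) for x in loaded["key_cache"]]
past_kv.value_cache = [torch.tensor(x).to(device) for x in loaded["value_cache"]]
```

Serializing every tensor to JSON floats is simple but heavy; the hunk's own comment ("need to make this so inference_state is not a string") points at the same limitation, so the sketch mirrors the diff rather than proposing a final design.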

@@ -80,39 +96,78 @@ async def infer_tensor(
     request_id: str,
     shard: Shard,
     input_data: np.ndarray,
-    inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+    inference_state: Optional[str] = None
+  ) -> Tuple[np.ndarray, str, bool]:

-    in_tensor = torch.tensor(input_data)
-
-    # Ensure input_data is 2D: [batch_size, seq_len]
-    if in_tensor.dim() == 1:
-      in_tensor = in_tensor.unsqueeze(0)  # Add a batch dimension: [1, seq_len]
+    await self.ensure_shard(shard)

-    if DEBUG >= 2:
-      print("infer_tensor called")
-      print(f"input_data: {input_data}\n")
-      print(f"in_tensor: {in_tensor}\n")
+    current_kvs = None

-    await self.ensure_shard(shard)
+    if input_data.size == 1:
+      in_tensor = torch.from_numpy(
+        input_data,
+      ).unsqueeze(0).long().to(self.device)
+    else:
+      in_tensor = torch.from_numpy(
+        input_data
+      ).long().to(self.device)
+
+    in_tensor = self.model.embed_tokens(in_tensor)

-    output_data = self.model.forward_layers(
-      in_tensor
+    if DEBUG >= 4:
+      print("infer_tensor called")
+      print(f"input_data: {input_data}")
+      print(f"input_data.size: {input_data.size}")
+      print(f"input_tensor: {in_tensor}\n")
+      print(f"shard: {self.shard}")
+      print(f"layer_count: {self.shard.get_layer_count()}")
+      print(f"is_first_layer: {self.shard.is_first_layer()}")
+      print(f"is_last_layer: {self.shard.is_last_layer()}")
+
+    # convert inference_state or cache from json to DynamicCache
+    past_kv = DynamicCache()
+    if inference_state != None:
+      try:
+        cache_dict = json.loads(inference_state)
+        past_kv.key_cache = [torch.tensor(data).to(self.device) for data in cache_dict['key_cache']]
+        past_kv.value_cache = [torch.tensor(data).to(self.device) for data in cache_dict['value_cache']]
+
+        if DEBUG >= 4:
+          print("Loaded past_kv from JSON")
+          print(f"past_kv: {past_kv}")
+          print(f"past_kv.key_cache len: {len(past_kv.key_cache)}")
+          print(f"past_kv.value_cache len: {len(past_kv.value_cache)}")
+      except json.JSONDecodeError:
+        print(f"ERROR DECODING INFERENCE STATE")
+
+    output_data, current_kvs = self.model.forward(
+      in_tensor,
+      past_kv
     )

     is_finished = output_data.size == 1 and output_data.item() in [self.tokenizer.eos_token_id]

-    if DEBUG >= 2:
+    if DEBUG >= 4:
+      print(f"in_tensor: {in_tensor}\n")
       print(f"output_data: {output_data}\n")
       print(f"output_data.size {output_data.size}\n")
-      print(f"output_data.item() {output_data.item()}")
       print(f"finished: {is_finished}")
       print(f"self.tokenizer.eos_token_id {self.tokenizer.eos_token_id}")
       print(f"output_data[-1] {output_data[-1]}")
-      print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+      if output_data.size == 1:
+        print(f"size 1 output_data.item() {output_data.item()}")
+        print(f"output_data.item() in [self.tokenizer.eos_token_id]: {output_data.item() in [self.tokenizer.eos_token_id]}")
+
+
+    cache_dict = {
+      'key_cache': [tensor.tolist() for tensor in current_kvs.key_cache],
+      'value_cache': [tensor.tolist() for tensor in current_kvs.value_cache]
+    }

     return (
       output_data,
-      "",
+      json.dumps(cache_dict),
       is_finished
     )
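
As a small aside on the hunk above, a toy illustration of the shape handling in infer_tensor: a single sampled token id (size 1) gets a batch dimension before being embedded, while larger arrays are assumed to already be batched. The embedding layer here is a stand-in for the sharded model's embed_tokens:

```python
# Toy illustration of turning a numpy token-id array into embedded model input.
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embed = torch.nn.Embedding(num_embeddings=32000, embedding_dim=64).to(device)  # toy vocab/dims

single = np.array([42])  # one sampled token id, size == 1
t_single = torch.from_numpy(single).unsqueeze(0).long().to(device)  # shape [1, 1]

batch = np.array([[1, 15043, 3186]])  # already [batch, seq_len]
t_batch = torch.from_numpy(batch).long().to(device)  # shape [1, 3]

print(embed(t_single).shape)  # torch.Size([1, 1, 64])
print(embed(t_batch).shape)   # torch.Size([1, 3, 64])
```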

@@ -126,12 +181,12 @@ async def ensure_shard(self, shard: Optional[Shard]):
     if self.shard == shard:
       return

-    if DEBUG >= 2:
+    if DEBUG >= 4:
       print(f"Loading new shard: {shard}")

-    self.model = ShardedHuggingFaceModel(shard)
-    self.tokenizer = await resolve_tokenizer(shard.model_id)
     self.shard = shard
+    self.tokenizer = await resolve_tokenizer(shard.model_id)
+    self.model = ShardedHuggingFaceModel(shard, self.tokenizer)

-    if DEBUG >= 2:
+    if DEBUG >= 4:
       print(f"Shard loaded successfully: {shard}")

exo/inference/pytorch/model/__init__.py

Whitespace-only changes.
