
Merge pull request #5 from risingsunomi/main
Update.
lipere123 authored Oct 18, 2024
2 parents 1e78ff9 + 69a8955 commit 402a3a8
Showing 7 changed files with 78 additions and 36 deletions.
2 changes: 2 additions & 0 deletions exo/download/hf/hf_helpers.py
@@ -173,6 +173,8 @@ async def download_file(
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
if DEBUG >= 2: print(f"Range not satisfiable {file_path=} {total_size=} {downloaded_size=}")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
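
The added lines handle a 416 "Range Not Satisfiable" response by logging it at DEBUG >= 2 and falling back to a full download with use_range_request=False. A minimal sketch of that fallback pattern, assuming a plain aiohttp download helper (names here are illustrative, not the repo's exact function):

import aiohttp

async def fetch_with_resume(session: aiohttp.ClientSession, url: str, resume_from: int = 0) -> bytes:
    # request only the remaining bytes when resuming a partial download
    headers = {"Range": f"bytes={resume_from}-"} if resume_from else {}
    async with session.get(url, headers=headers) as resp:
        if resp.status == 416:
            # the partial file no longer matches the remote size; restart without a Range header
            return await fetch_with_resume(session, url, resume_from=0)
        resp.raise_for_status()
        return await resp.read()
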
7 changes: 6 additions & 1 deletion exo/inference/torch/inference.py
@@ -91,6 +91,7 @@ def infer_caching(
cached_iids = {"input_ids": past_iids.tolist()}

if DEBUG >= 4:
print(f"cached_iids len: {len(cached_iids)}")
print(f"cached_iids: {cached_iids}")

return (past_iids, cached_iids)
@@ -126,7 +127,11 @@ async def async_forward(
attention_mask=attention_mask
))

return result
if DEBUG >= 4:
print("async_forward")
print(f"result: {result}")

return result[0], result[1], result[2]

async def async_logit_sample(
self,
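
async_forward now logs its result at DEBUG >= 4 and returns the three elements of result explicitly rather than the raw object. A rough sketch of that shape, assuming the blocking forward pass is dispatched to an executor and yields a 3-tuple (the helper name and keyword arguments are illustrative, not the repo's):

import asyncio
import functools

async def run_forward_async(model, input_ids=None, hidden_states=None, attention_mask=None):
    loop = asyncio.get_running_loop()
    # run the blocking forward pass off the event loop
    result = await loop.run_in_executor(
        None,
        functools.partial(
            model.forward,
            input_ids=input_ids,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        ),
    )
    # unpack explicitly so callers always receive a plain 3-tuple
    return result[0], result[1], result[2]
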
39 changes: 26 additions & 13 deletions exo/inference/torch/model/hf.py
@@ -11,7 +11,6 @@
from exo.inference.torch.utils import extract_layers

from transformers import (
AutoConfig,
AutoModelForCausalLM,
DynamicCache,
Cache,
@@ -63,6 +62,7 @@ def __init__(
self.position_ids = None
self.causal_mask = None
self.local_model_path = local_model_path
self.is_sharded_model = False

# setup logit processors
self.logits_processor = LogitsProcessorList([
@@ -82,25 +82,30 @@ def __init__(
# setup pytorch and transformer llm
try:
if weight_map:
self.llm_model_config = self.load_sharded_model(
print("loading shard model")
self.llm_model = self.load_sharded_model(
shard,
weight_map,
offload_buffers=self.offload_buffers
)

self.is_sharded_model = True

# clear out edited safetensor json
# this is needed because the shard downloader just
# appends to and does not redownload the file
os.remove(self.model_safetensors_path)

self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
self.model = self.llm_model.model.to(self.device)
else:
self.llm_model_config = AutoConfig.from_pretrained(
print("loading full model")
self.llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
torch_dtype=self.dtype,
device_map=self.device_map,
offload_buffers=self.offload_buffers
)

self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
offload_buffers=offload_buffers
).to(self.device)

self.model = self.llm_model.model.to(self.device)
except Exception as err:
@@ -112,7 +117,7 @@ def load_sharded_model(
Expand All @@ -112,7 +117,7 @@ def load_sharded_model(
shard: Shard,
weight_map: dict,
offload_buffers: bool
) -> AutoConfig:
) -> AutoModelForCausalLM:
"""
Loads sharded version of model where only needed
weights are loaded for necessary layers
@@ -154,13 +159,18 @@ def load_sharded_model(
shard_num_hidden_layers = shard.end_layer - shard.start_layer
if DEBUG >= 4:
print(f"config with {shard_num_hidden_layers} layers")
return AutoConfig.from_pretrained(

llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
device_map=self.device_map,
torch_dtype=self.dtype,
offload_buffers=offload_buffers,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
)

return llm_model.to(self.device)

except Exception as err:
print(f"err: {err}")
raise
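
load_sharded_model now returns a fully materialized AutoModelForCausalLM rather than an AutoConfig, relying on from_pretrained to forward the num_hidden_layers override to the config so that only the shard's decoder layers are built. A hedged, simplified sketch of that idea (the helper name is illustrative):

import torch
from transformers import AutoModelForCausalLM

def load_layer_slice(model_path: str, start_layer: int, end_layer: int, device: torch.device):
    # only materialize the layers this shard is responsible for
    shard_num_hidden_layers = end_layer - start_layer
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_path,
        torch_dtype=torch.float16,
        local_files_only=True,
        num_hidden_layers=shard_num_hidden_layers,
    )
    return model.to(device)
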
@@ -255,11 +265,14 @@ def forward(
self.cache_position = model_inputs["cache_position"]
self.past_key_values = model_inputs["past_key_values"]

if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")
if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")

# run through decoder layers
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
if self.is_sharded_model:
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
else:
layer_amt = range(self.shard.start_layer, self.shard.end_layer)

if DEBUG >= 4:
print(f"hidden_states: {self.hidden_states}")
@@ -304,7 +317,7 @@ def forward(
# shard.is_last_layer() says True at the start and is not detecting the last layer correctly
if self.shard.is_last_layer():
self.hidden_states = self.model.norm(self.hidden_states)
if use_legacy_cache:
if use_legacy_cache and self.next_decoder_cache is not None:
self.past_key_values = self.next_decoder_cache.to_legacy_cache()
else:
self.past_key_values = self.next_decoder_cache
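
Two behavioral notes on the forward changes above: the decoder-layer range now depends on is_sharded_model, since a sharded model only materializes its own slice of layers (indexed locally from 0) while a full model uses global layer indices; and to_legacy_cache() is only called when next_decoder_cache is not None. An illustrative restatement of the range selection:

def decoder_layer_range(is_sharded_model: bool, start_layer: int, end_layer: int) -> range:
    if is_sharded_model:
        # only (end_layer - start_layer) layers exist locally, indexed from 0
        return range(end_layer - start_layer)
    # full model: index into the complete decoder stack
    return range(start_layer, end_layer)
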
53 changes: 35 additions & 18 deletions exo/inference/torch/tests/test_split_model.py
@@ -17,14 +17,13 @@
from exo.inference.shard import Shard
from exo.inference.torch.utils import print_cuda_vram_stats

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(
repo_id: str,
shard: Shard,
model_path: Path,
weight_map: Optional[dict],
device: Optional[str] = "cuda"
device: Optional[torch.device] = torch.device("cpu")
) -> Optional[AutoModelForCausalLM]:
"""
load model by layer and safetensors
@@ -34,6 +33,24 @@ def load_model(
print("load_model called")
model_st_snapshot = model_path/"model.safetensors.index.json"

if os.environ.get("TORCH_DEVICE"):
device = torch.device(os.environ["TORCH_DEVICE"])
elif torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")

torch.set_default_device(device)

# setup torch dtype
dtype = torch.get_default_dtype()

# setup device_map
if os.environ.get("TORCH_DEVICE_MAP"):
device_map = os.environ["TORCH_DEVICE_MAP"]
else:
device_map = str(device)

if weight_map:
layer_weight_map = {}
non_layer_weights = []
@@ -89,18 +106,18 @@ def load_model(
# setup the weight range for init_weights
shard_num_hidden_layers = shard.end_layer - shard.start_layer
print(f"Setting up LLM config with {shard_num_hidden_layers} hidden layers")
llm_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=model_path,
device_map="cuda",
offload_buffers=True,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
)

# load model with layer edits
# or whole model if no weight_map
print(f"Loading sharded AutoModelForCausalLM from {model_path}")
shard_model = AutoModelForCausalLM.from_config(llm_config).to(device)
shard_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
device_map=device_map,
torch_dtype=dtype,
offload_buffers=True,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
).to(device)

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
@@ -137,8 +154,6 @@ def load_model(
print(f"Prompt: {prompt}\n")
print(f"Response: {response}\n")

print_ram_stats()

# have to clear out edited model safetensors mst_json
os.remove(model_st_snapshot)

@@ -167,13 +182,15 @@ async def test_split_model(
weight_map = await get_weight_map(model_id)

load_model(
model_id,
shard,
model_path,
weight_map
)

if __name__ == "__main__":
n_layers = int(os.environ["N_LAYERS"]) if os.environ.get("N_LAYERS") else 32
start_layer = int(os.environ["START_LAYER"]) if os.environ.get("START_LAYER") else 0
end_layer = int(os.environ["END_LAYER"]) if os.environ.get("END_LAYER") else int(n_layers/2)
#Qwen/Qwen2.5-3B
#try:
# print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n")
@@ -191,9 +208,9 @@ async def test_split_model(
print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n")
asyncio.run(test_split_model(
"unsloth/Meta-Llama-3.1-8B-Instruct",
0,
6,
32
start_layer,
end_layer,
n_layers
))
except Exception as err:
print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n")
print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n")
4 changes: 0 additions & 4 deletions exo/inference/torch/utils.py
@@ -35,10 +35,6 @@ def extract_layers(

non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1])

print(non_layer_weights)
print(f"first: {shard.is_first_layer()}")
print(f"last: {shard.is_last_layer()}")

if shard.is_first_layer():
# this assumes at max only one first weight non-layer for model
first_weight = non_layer_weights[0]
7 changes: 7 additions & 0 deletions exo/models.py
@@ -103,4 +103,11 @@
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2-0.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=24),
"TorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24),
},
### nemotron
"nemotron-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),
},
"nemotron-70b-bf16": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),
},
}
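
The two Nemotron entries follow the existing pattern: a model key mapping each supported engine to a Shard that records the repo id and total layer count. A hedged lookup sketch, assuming a registry dict shaped like the entries above (the actual variable name in exo/models.py may differ):

from exo.inference.shard import Shard

# assumed registry name, for illustration only
model_base_shards = {
    "nemotron-70b": {
        "MLXDynamicShardInferenceEngine": Shard(
            model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit",
            start_layer=0, end_layer=0, n_layers=80,
        ),
    },
}

shard = model_base_shards["nemotron-70b"]["MLXDynamicShardInferenceEngine"]
print(shard.model_id, shard.n_layers)
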
2 changes: 2 additions & 0 deletions exo/tinychat/index.html
@@ -38,6 +38,8 @@
<option value="llama-3.1-405b">Llama 3.1 405B</option>
<option value="llama-3-8b">Llama 3 8B</option>
<option value="llama-3-70b">Llama 3 70B</option>
<option value="nemotron-70b">Nemotron 70B</option>
<option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
<option value="mistral-nemo">Mistral Nemo</option>
<option value="mistral-large">Mistral Large</option>
<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
