Update. #5

Merged
26 commits merged on Oct 18, 2024

Commits
e7470b1
bugfix in llm setup
dtnewman Oct 15, 2024
fa24f46
Merge pull request #27 from dtnewman/main
risingsunomi Oct 15, 2024
61ee67c
add nemotron-70b and nemotron-70b-bf16
AlexCheema Oct 16, 2024
feae3ae
Merge pull request #356 from exo-explore/nemotron
AlexCheema Oct 16, 2024
1e4524b
add nemotron-70b and nemotron-70b-bf16 to tinychat
AlexCheema Oct 16, 2024
5c69f3f
Merge remote-tracking branch 'origin/main' into HEAD
AlexCheema Oct 16, 2024
f5a1cef
handle range not satisfiable edge case
AlexCheema Oct 16, 2024
751bd1c
updating to use automodelforcausallm instead of autoconfig
risingsunomi Oct 16, 2024
7d866d8
removing meta model
risingsunomi Oct 16, 2024
253237b
updating split model test
risingsunomi Oct 16, 2024
e46ffa4
updating split model test
risingsunomi Oct 16, 2024
476b6ba
automodel fix
risingsunomi Oct 16, 2024
f7e02e9
fixing split model test
risingsunomi Oct 16, 2024
bd6322f
pytorch offload buffers error
risingsunomi Oct 17, 2024
c51bd91
device_map any issue with split model
risingsunomi Oct 17, 2024
4a2aef4
updating split model test
risingsunomi Oct 17, 2024
79f0763
fixing split model issue
risingsunomi Oct 17, 2024
cbbc9cf
fixing node issues
risingsunomi Oct 17, 2024
58cebab
fixing node issues
risingsunomi Oct 17, 2024
7f9b1bb
fixing node issues
risingsunomi Oct 17, 2024
c3adec5
fixing node issues
risingsunomi Oct 17, 2024
c8e6acc
fixing node issues
risingsunomi Oct 17, 2024
df028e2
fixing node issues, range issue
risingsunomi Oct 17, 2024
e5a1939
fixing node issues, range issue
risingsunomi Oct 17, 2024
d03a85c
Merge branch 'main' into pr139-dev-oct24
risingsunomi Oct 17, 2024
69a8955
Merge pull request #28 from risingsunomi/pr139-dev-oct24
risingsunomi Oct 17, 2024
7 changes: 1 addition & 6 deletions README.md
@@ -58,12 +58,7 @@ Unlike other distributed inference frameworks, exo does not use a master-worker

Exo supports different [partitioning strategies](exo/topology/partitioning_strategy.py) to split up a model across devices. The default partitioning strategy is [ring memory weighted partitioning](exo/topology/ring_memory_weighted_partitioning_strategy.py). This runs an inference in a ring where each device runs a number of model layers proportional to the memory of the device.

<p>
<picture>
<img alt="ring topology" src="docs/ring-topology.png" width="30%" height="30%">
</picture>
</p>

!["A screenshot of exo running 5 nodes](docs/exo-screenshot.png)

## Installation

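As the README context above notes, ring memory weighted partitioning gives each device a contiguous block of layers proportional to its share of the ring's total memory. A minimal sketch of that idea in Python (illustrative only; the real logic lives in exo/topology/ring_memory_weighted_partitioning_strategy.py and differs in detail):

```python
from dataclasses import dataclass

@dataclass
class Device:
    name: str
    memory_gb: float

def partition_layers(devices: list[Device], n_layers: int) -> dict[str, range]:
    """Assign each device a contiguous layer range proportional to its memory."""
    total_mem = sum(d.memory_gb for d in devices)
    partitions, start = {}, 0
    for i, d in enumerate(devices):
        # The last device takes the remainder so every layer is assigned exactly once.
        count = n_layers - start if i == len(devices) - 1 else round(n_layers * d.memory_gb / total_mem)
        partitions[d.name] = range(start, start + count)
        start += count
    return partitions

print(partition_layers([Device("mac", 16), Device("linux", 48)], n_layers=32))
# {'mac': range(0, 8), 'linux': range(8, 32)}
```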
Binary file added docs/exo-screenshot.png
2 changes: 2 additions & 0 deletions exo/download/hf/hf_helpers.py
@@ -173,6 +173,8 @@ async def download_file(
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
if DEBUG >= 2: print(f"Range not satisfiable {file_path=} {total_size=} {downloaded_size=}")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
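The two added lines above retry the download without a Range header when the server answers 416 (Range Not Satisfiable), instead of failing the resume. A standalone sketch of the same fallback pattern with aiohttp (function and parameter names are illustrative, not exo's API):

```python
import aiohttp

async def fetch(session: aiohttp.ClientSession, url: str, resume_from: int = 0) -> bytes:
    # Try to resume with a Range header; on 416, retry once from the beginning
    # without the Range header rather than aborting the download.
    headers = {"Range": f"bytes={resume_from}-"} if resume_from else {}
    async with session.get(url, headers=headers) as resp:
        if resp.status == 416:
            return await fetch(session, url, resume_from=0)
        resp.raise_for_status()
        return await resp.read()
```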
7 changes: 6 additions & 1 deletion exo/inference/torch/inference.py
@@ -91,6 +91,7 @@ def infer_caching(
cached_iids = {"input_ids": past_iids.tolist()}

if DEBUG >= 4:
print(f"cached_iids len: {len(cached_iids)}")
print(f"cached_iids: {cached_iids}")

return (past_iids, cached_iids)
@@ -126,7 +127,11 @@ async def async_forward(
attention_mask=attention_mask
))

return result
if DEBUG >=4:
print("async_forward")
print(f"result: {result}")

return result[0], result[1], result[2]

async def async_logit_sample(
self,
39 changes: 26 additions & 13 deletions exo/inference/torch/model/hf.py
@@ -11,7 +11,6 @@
from exo.inference.torch.utils import extract_layers

from transformers import (
AutoConfig,
AutoModelForCausalLM,
DynamicCache,
Cache,
@@ -63,6 +62,7 @@ def __init__(
self.position_ids = None
self.causal_mask = None
self.local_model_path = local_model_path
self.is_sharded_model = False

# setup logit processors
self.logits_processor = LogitsProcessorList([
@@ -82,25 +82,30 @@
# setup pytorch and transformer llm
try:
if weight_map:
self.llm_model_config = self.load_sharded_model(
print("loading shard model")
self.llm_model = self.load_sharded_model(
shard,
weight_map,
offload_buffers=self.offload_buffers
)

self.is_sharded_model = True

# clear out edited safetensor json
# this is needed because shard downloader just
# appends and not redownloads the file
os.remove(self.model_safetensors_path)

self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
self.model = self.llm_model.model.to(self.device)
else:
self.llm_model_config = AutoConfig.from_pretrained(
print("loading full model")
self.llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
torch_dtype=self.dtype,
device_map=self.device_map,
offload_buffers=self.offload_buffers
)

self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
offload_buffers=offload_buffers
).to(self.device)

self.model = self.llm_model.model.to(self.device)
except Exception as err:
@@ -112,7 +117,7 @@ def load_sharded_model(
shard: Shard,
weight_map: dict,
offload_buffers: bool
) -> AutoConfig:
) -> AutoModelForCausalLM:
"""
Loads sharded version of model where only needed
weights are loaded for necessary layers
@@ -154,13 +159,18 @@ def load_sharded_model(
shard_num_hidden_layers = shard.end_layer - shard.start_layer
if DEBUG >= 4:
print(f"config with {shard_num_hidden_layers} layers")
return AutoConfig.from_pretrained(

llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
device_map=self.device_map,
torch_dtype=self.dtype,
offload_buffers=offload_buffers,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
)

return llm_model.to(self.device)

except Exception as err:
print(f"err: {err}")
raise
@@ -255,11 +265,14 @@ def forward(
self.cache_position = model_inputs["cache_position"]
self.past_key_values = model_inputs["past_key_values"]

if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")
if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")

# run through decoder layers
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
if self.is_sharded_model:
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
else:
layer_amt = range(self.shard.start_layer, self.shard.end_layer)

if DEBUG >= 4:
print(f"hidden_states: {self.hidden_states}")
@@ -304,7 +317,7 @@ def forward(
# shard is last layer says true at the start and not detecting last layer correctly
if self.shard.is_last_layer():
self.hidden_states = self.model.norm(self.hidden_states)
if use_legacy_cache:
if use_legacy_cache and self.next_decoder_cache is not None:
self.past_key_values = self.next_decoder_cache.to_legacy_cache()
else:
self.past_key_values = self.next_decoder_cache
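The main change in hf.py is dropping the AutoConfig/from_config path: the sharded model is now loaded directly with AutoModelForCausalLM.from_pretrained, overriding num_hidden_layers so a node only materializes its own slice of decoder layers, and forward() then indexes those layers from 0 (range(end_layer - start_layer)) when is_sharded_model is set. A condensed sketch of the loading step, with the weight-map editing and error handling omitted (the dtype and device handling here are assumptions, not the PR's exact values):

```python
import torch
from transformers import AutoModelForCausalLM

def load_shard(model_path: str, start_layer: int, end_layer: int,
               device: torch.device = torch.device("cpu")):
    # Build only end_layer - start_layer decoder layers; the safetensors index
    # is assumed to have been trimmed beforehand so only those weights are read.
    shard_num_hidden_layers = end_layer - start_layer
    llm_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_path,
        torch_dtype=torch.float16,  # assumption; the PR passes its configured dtype
        local_files_only=True,
        num_hidden_layers=shard_num_hidden_layers,
    )
    return llm_model.to(device)
```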
53 changes: 35 additions & 18 deletions exo/inference/torch/tests/test_split_model.py
@@ -17,14 +17,13 @@
from exo.inference.shard import Shard
from exo.inference.torch.utils import print_cuda_vram_stats

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(
repo_id: str,
shard: Shard,
model_path: Path,
weight_map: Optional[dict],
device: Optional[str] = "cuda"
device: Optional[torch.device] = torch.device("cpu")
) -> Optional[AutoModelForCausalLM]:
"""
load model by layer and safetensors
@@ -34,6 +33,24 @@ def load_model(
print("load_model called")
model_st_snapshot = model_path/"model.safetensors.index.json"

if os.environ.get("TORCH_DEVICE"):
device = torch.device(os.environ["TORCH_DEVICE"])
elif torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")

torch.set_default_device(device)

# setup cude dtype
dtype = torch.get_default_dtype()

# setup device_map
if os.environ.get("TORCH_DEVICE_MAP"):
device_map = os.environ["TORCH_DEVICE_MAP"]
else:
device_map = str(device)

if weight_map:
layer_weight_map = {}
non_layer_weights = []
@@ -89,18 +106,18 @@ def load_model(
# setup the weight range for init_weights
shard_num_hidden_layers = shard.end_layer - shard.start_layer
print(f"Setting up LLM config with {shard_num_hidden_layers} hidden layers")
llm_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=model_path,
device_map="cuda",
offload_buffers=True,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
)

# load model with layer edits
# or whole model if no weight_map
print(f"Loading sharded AutoModelForCausalLM from {model_path}")
shard_model = AutoModelForCausalLM.from_config(llm_config).to(device)
shard_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
device_map=device_map,
torch_dtype=dtype,
offload_buffers=True,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
).to(device)

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
@@ -137,8 +154,6 @@ def load_model(
print(f"Prompt: {prompt}\n")
print(f"Response: {response}\n")

print_ram_stats()

# have to clear out edited model safetensors mst_json
os.remove(model_st_snapshot)

@@ -167,13 +182,15 @@ async def test_split_model(
weight_map = await get_weight_map(model_id)

load_model(
model_id,
shard,
model_path,
weight_map
)

if __name__ == "__main__":
n_layers = int(os.environ["N_LAYERS"]) if os.environ.get("N_LAYERS") else 32
start_layer = int(os.environ["START_LAYER"]) if os.environ.get("START_LAYER") else 0
end_layer = int(os.environ["END_LAYER"]) if os.environ.get("END_LAYER") else int(n_layers/2)
#Qwen/Qwen2.5-3B
#try:
# print("\n-------- Test Qwen/Qwen2.5-3B-Instruct ----------\n")
@@ -191,9 +208,9 @@ async def test_split_model(
print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n")
asyncio.run(test_split_model(
"unsloth/Meta-Llama-3.1-8B-Instruct",
0,
6,
32
start_layer,
end_layer,
n_layers
))
except Exception as err:
print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n")
print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n")
4 changes: 0 additions & 4 deletions exo/inference/torch/utils.py
@@ -35,10 +35,6 @@ def extract_layers(

non_layer_weights = sorted(non_layer_weights, key=lambda x: x[1])

print(non_layer_weights)
print(f"first: {shard.is_first_layer()}")
print(f"last: {shard.is_last_layer()}")

if shard.is_first_layer():
# this assumes at max only one first weight non-layer for model
first_weight = non_layer_weights[0]
7 changes: 7 additions & 0 deletions exo/models.py
@@ -103,4 +103,11 @@
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2-0.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=24),
"TorchDynamicShardInferenceEngine": Shard(model_id="Qwen/Qwen2-0.5B-Instruct", start_layer=0, end_layer=0, n_layers=24),
},
### nemotron
"nemotron-70b": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),
},
"nemotron-70b-bf16": {
"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),
},
}
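The two Nemotron entries follow the file's existing pattern: each model key maps to one base Shard per supported inference engine (MLX only here), spanning all 80 layers. An illustrative lookup, with MODEL_REGISTRY standing in for the dict edited above (the real variable name in exo/models.py may differ):

```python
from exo.inference.shard import Shard

MODEL_REGISTRY = {
  "nemotron-70b": {
    "MLXDynamicShardInferenceEngine": Shard(
      model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit",
      start_layer=0, end_layer=0, n_layers=80,
    ),
  },
}

def get_base_shard(model_key: str, engine: str) -> Shard:
  return MODEL_REGISTRY[model_key][engine]

shard = get_base_shard("nemotron-70b", "MLXDynamicShardInferenceEngine")
print(shard.n_layers)  # 80
```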
23 changes: 21 additions & 2 deletions exo/tinychat/index.css
@@ -164,13 +164,32 @@ main {
border-right: 2px solid var(--secondary-color);
box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
}
.download-progress{
margin-bottom: 20em;
.download-progress {
margin-bottom: 12em;
overflow-y: auto;
}
.message > pre {
white-space: pre-wrap;
}

.progress-bar-container {
width: 100%;
background-color: #e0e0e0;
border-radius: 4px;
margin: 10px 0;
}
.progress-bar {
height: 20px;
border-radius: 4px;
transition: width 0.5s ease-in-out;
}
.progress-bar.complete {
background-color: #4CAF50;
}
.progress-bar.in-progress {
background-color: #2196F3;
}

.toast {
width: 100%; /* Take up the full width of the page */
background-color: #fc2a2a; /* Dark background color */
36 changes: 27 additions & 9 deletions exo/tinychat/index.html
@@ -38,6 +38,8 @@
<option value="llama-3.1-405b">Llama 3.1 405B</option>
<option value="llama-3-8b">Llama 3 8B</option>
<option value="llama-3-70b">Llama 3 70B</option>
<option value="nemotron-70b">Nemotron 70B</option>
<option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
<option value="mistral-nemo">Mistral Nemo</option>
<option value="mistral-large">Mistral Large</option>
<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
@@ -153,16 +155,32 @@ <h3 x-text="new Date(_state.time).toLocaleString()"></h3>
</div>

<!-- Download Progress Section -->
<template x-if="downloadProgress">
<div class="download-progress message message-role-assistant">
<h2>Download Progress</h2>
<div class="download-progress-node">
<p><strong>Model:</strong> <span x-text="downloadProgress.repo_id + '@' + downloadProgress.repo_revision"></span></p>
<p><strong>Progress:</strong> <span x-text="`${downloadProgress.downloaded_bytes_display} / ${downloadProgress.total_bytes_display} (${downloadProgress.percentage}%)`"></span></p>
<p><strong>Speed:</strong> <span x-text="downloadProgress.overall_speed_display || 'N/A'"></span></p>
<p><strong>ETA:</strong> <span x-text="downloadProgress.overall_eta_display || 'N/A'"></span></p>
<template x-if="downloadProgress && downloadProgress.length > 0">
<div class="download-progress message message-role-assistant">
<h2>Download Progress</h2>
<br>
<template x-for="(progress, index) in downloadProgress" :key="index">
<div class="download-progress-node">
<br>
<h3 x-text="`Download ${index + 1}`"></h3>
<p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
<p><strong>Status:</strong> <span x-text="progress.status"></span></p>
<div class="progress-bar-container">
<div class="progress-bar"
:class="progress.isComplete ? 'complete' : 'in-progress'"
:style="`width: ${progress.percentage}%;`">
</div>
</div>
<template x-if="!progress.isComplete">
<div>
<p><strong>Progress:</strong> <span x-text="`${progress.downloaded_bytes_display} / ${progress.total_bytes_display} (${progress.percentage}%)`"></span></p>
<p><strong>Speed:</strong> <span x-text="progress.overall_speed_display || 'N/A'"></span></p>
<p><strong>ETA:</strong> <span x-text="progress.overall_eta_display || 'N/A'"></span></p>
</div>
</template>
</div>
</template>
</div>
</div>
</template>

