Show downloaded models, improve error handling, ability to delete models, side bar with more detail, button to go back to chat history #456

Open
wants to merge 49 commits into
base: main
Changes from 10 commits
c7dd312
adding logic to check which models are downloaded
cadenmackenzie Nov 13, 2024
de09e2a
reusing helper function to get cached directory
cadenmackenzie Nov 13, 2024
7d7bdd8
removing unnecessary console logs and fixing order of variables in ind…
cadenmackenzie Nov 13, 2024
fb32a85
removing error separation so I can put in different PR
cadenmackenzie Nov 13, 2024
59f5b6d
adding back in set error message
cadenmackenzie Nov 13, 2024
25d67f5
cleaning up logging in index.js
cadenmackenzie Nov 13, 2024
95ce665
removing unnecessary css
cadenmackenzie Nov 13, 2024
3eb726c
removing sorting of models by name
cadenmackenzie Nov 13, 2024
cbeb1b3
fix safari issue
dtnewman Nov 14, 2024
372d873
Merge pull request #1 from dtnewman/dn/downloadModelsV2
cadenmackenzie Nov 14, 2024
d9aabd7
working versions
cadenmackenzie Nov 14, 2024
dfcf513
removing is_model_downloaded method and changing how downloaded varia…
cadenmackenzie Nov 14, 2024
972074e
reducing redundant checks
cadenmackenzie Nov 14, 2024
dd38924
removing checking of percentage for models that are not found locally
cadenmackenzie Nov 14, 2024
bd2985a
Merge pull request #2 from cadenmackenzie/downloadedModelsV2Revisions
cadenmackenzie Nov 14, 2024
649157d
creating HFShardDownloader with quick_check true so it doesn't start d…
cadenmackenzie Nov 17, 2024
c923ef6
modifying how it's being displayed because now calculating overall per…
cadenmackenzie Nov 18, 2024
c61f40c
adding helper function to check file download. also modifying downloa…
cadenmackenzie Nov 18, 2024
dec79ac
modify get_shard_download_status to use helper function
cadenmackenzie Nov 18, 2024
4c6fda7
modifying helper function checking size to follow redirect for .safet…
cadenmackenzie Nov 18, 2024
3ac8687
adding redirect for all requests
cadenmackenzie Nov 18, 2024
3256051
comment
cadenmackenzie Nov 18, 2024
db610f5
removing traceback
cadenmackenzie Nov 18, 2024
6a7de04
removing path update
cadenmackenzie Nov 18, 2024
fad0591
Merge pull request #4 from cadenmackenzie/hf_helperRefactor
cadenmackenzie Nov 18, 2024
b77362b
moving os import
cadenmackenzie Nov 18, 2024
695ab34
removing import get_hf_home
cadenmackenzie Nov 18, 2024
8135437
fixing formatting
cadenmackenzie Nov 19, 2024
91276cc
fixing formatting
cadenmackenzie Nov 19, 2024
8ee6cc3
yapf formatting
cadenmackenzie Nov 19, 2024
0d50167
yapf in download_file
cadenmackenzie Nov 19, 2024
2cdd55d
Merge branch 'main' into downloadedModelsV2
cadenmackenzie Nov 21, 2024
1ca11ea
defining optional
cadenmackenzie Nov 21, 2024
7a8c722
Merge pull request #5 from cadenmackenzie/main
cadenmackenzie Nov 21, 2024
7e6c69f
removing console log
cadenmackenzie Nov 21, 2024
31ce70f
working with side bar to choose model, show download percentage, sele…
cadenmackenzie Nov 21, 2024
fb3baf5
adding amount that has been downloaded if model is not fully downloaded
cadenmackenzie Nov 21, 2024
619df1d
adding functionality to delete the models if there is part of the mod…
cadenmackenzie Nov 22, 2024
a9838a8
formatting handle_delete_model
cadenmackenzie Nov 22, 2024
bc905cd
formatting deleteModel
cadenmackenzie Nov 22, 2024
39139c1
fixing required engines definition
cadenmackenzie Nov 22, 2024
c469d53
Merge pull request #6 from cadenmackenzie/modelSideBarV2
cadenmackenzie Nov 24, 2024
e16170c
backend endpoint now uses SSE to send each model as it's loaded. also …
cadenmackenzie Nov 24, 2024
db45ed6
Merge pull request #7 from cadenmackenzie/downloadedModelsV2_dynamica…
cadenmackenzie Nov 24, 2024
bc83d1f
Merge pull request #8 from cadenmackenzie/main
cadenmackenzie Nov 24, 2024
ded80b0
ensuring requests do not stack up by moving polling to while loop wit…
cadenmackenzie Nov 26, 2024
e99a739
adding a fetch to get initial model object to show models before goin…
cadenmackenzie Nov 26, 2024
445ba7a
Merge pull request #9 from cadenmackenzie/downloadedModelsV2_showingM…
cadenmackenzie Nov 26, 2024
5968a93
Merge pull request #10 from cadenmackenzie/main
cadenmackenzie Nov 27, 2024
106 changes: 88 additions & 18 deletions exo/api/chatgpt_api.py
AlexCheema marked this conversation as resolved.
@@ -17,6 +17,8 @@
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
+import os
+from exo.download.hf.hf_helpers import get_hf_home


class Message:
@@ -200,25 +202,93 @@ async def middleware(request):
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")

+  def is_model_downloaded(self, model_name):
+    if DEBUG >= 2:
+      print(f"\nChecking if model {model_name} is downloaded:")
+
+    cache_dir = get_hf_home() / "hub"
+    repo = get_repo(model_name, self.inference_engine_classname)
+
+    if DEBUG >= 2:
+      print(f" Cache dir: {cache_dir}")
+      print(f" Repo: {repo}")
+      print(f" Engine: {self.inference_engine_classname}")
+
+    if not repo:
+      return False
+
+    # Convert repo path (e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit")
+    # to directory format (e.g. "models--mlx-community--Llama-3.2-1B-Instruct-4bit")
+    repo_parts = repo.split('/')
+    formatted_path = f"models--{repo_parts[0]}--{repo_parts[1]}"
+    repo_path = cache_dir / formatted_path / "snapshots"
+
+    if DEBUG >= 2:
+      print(f" Looking in: {repo_path}")
+
+    if repo_path.exists():
+      # Look for the most recent snapshot directory
+      snapshots = list(repo_path.glob("*"))
+      if snapshots:
+        latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
+
+        # Check for model files and their index files
+        model_files = (
+          list(latest_snapshot.glob("model.safetensors")) +
+          list(latest_snapshot.glob("model.safetensors.index.json")) +
+          list(latest_snapshot.glob("*.mlx"))
+        )
+
+        if DEBUG >= 2:
+          print(f" Latest snapshot: {latest_snapshot}")
+          print(f" Found files: {model_files}")
+
+        # Model is considered downloaded if we find either:
+        # 1. model.safetensors file
+        # 2. model.safetensors.index.json file (for sharded models)
+        # 3. *.mlx file
+        return len(model_files) > 0
+
+    if DEBUG >= 2:
+      print(" No valid model files found")
+    return False

   async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in [
-          model_id for model_id, model_info in model_cards.items()
-          if all(map(
-            lambda engine: engine in model_info["repo"],
-            list(dict.fromkeys([
-              inference_engine_classes.get(engine_name, None)
-              for engine_list in self.node.topology_inference_engines_pool
-              for engine_name in engine_list
-              if engine_name is not None
-            ] + [self.inference_engine_classname]))
-          ))
-        ]
-      }
-    })

+    try:
+      model_pool = {}
+
+      for model_name, pretty in pretty_name.items():
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          # Get required engines
+          required_engines = list(dict.fromkeys([
+            inference_engine_classes.get(engine_name, None)
+            for engine_list in self.node.topology_inference_engines_pool
+            for engine_name in engine_list
+            if engine_name is not None
+          ] + [self.inference_engine_classname]))
+
+          # Check if model supports required engines
+          if all(map(lambda engine: engine in model_info["repo"], required_engines)):
+            is_downloaded = self.is_model_downloaded(model_name)
+            if DEBUG >= 2:
+              print(f"Model {model_name} download status: {is_downloaded}")
+
+            model_pool[model_name] = {
+              "name": pretty,
+              "downloaded": is_downloaded
+            }
+
+      return web.json_response({"model pool": model_pool})
+    except Exception as e:
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Server error: {str(e)}"},
+        status=500
+      )

   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
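
The `is_model_downloaded` check added in this diff maps a repo id to a Hugging Face hub cache directory and looks for weight files in the newest snapshot. A minimal standalone sketch of that idea follows, assuming the cache layout described in the diff's own comments (`models--<org>--<name>/snapshots/<revision>/`). The helper names `repo_cache_dir`, `has_local_weights`, and `demo` are hypothetical and not part of exo; `demo` builds a throwaway directory tree rather than touching a real cache.

```python
from pathlib import Path
import tempfile

# File patterns the diff treats as evidence that a model is present locally.
WEIGHT_PATTERNS = ("model.safetensors", "model.safetensors.index.json", "*.mlx")


def repo_cache_dir(repo_id: str) -> str:
  """Map "org/name" to the cache directory name "models--org--name" (hypothetical helper)."""
  return "models--" + repo_id.replace("/", "--")


def has_local_weights(hub_dir: Path, repo_id: str) -> bool:
  """Return True if the most recently modified snapshot contains a matching file."""
  snapshots_dir = hub_dir / repo_cache_dir(repo_id) / "snapshots"
  if not snapshots_dir.is_dir():
    return False
  snapshots = [p for p in snapshots_dir.iterdir() if p.is_dir()]
  if not snapshots:
    return False
  latest = max(snapshots, key=lambda p: p.stat().st_mtime)
  # any() over a glob generator is True as soon as one path matches.
  return any(any(latest.glob(pattern)) for pattern in WEIGHT_PATTERNS)


def demo() -> tuple:
  """Create a fake cache with one model and query a present and an absent repo."""
  with tempfile.TemporaryDirectory() as tmp:
    hub = Path(tmp)
    repo_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
    snapshot = hub / repo_cache_dir(repo_id) / "snapshots" / "abc123"
    snapshot.mkdir(parents=True)
    (snapshot / "model.safetensors").write_bytes(b"")
    return (
      has_local_weights(hub, repo_id),
      has_local_weights(hub, "mlx-community/not-downloaded"),
    )
```

Like the check in the diff, this only detects that a weights or index file exists; an index file on its own does not show that every shard finished downloading.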

10 changes: 5 additions & 5 deletions exo/tinychat/index.html
@@ -26,13 +26,13 @@
 <body>
   <main x-data="state" x-init="console.log(endpoint)">
     <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity class="toast">
+    <div x-show="errorMessage !== null" x-transition.opacity class="toast">
       <div class="toast-header">
-        <span class="toast-error-message" x-text="errorMessage.basic"></span>
+        <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
         <div class="toast-header-buttons">
           <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
                   class="toast-expand-button"
-                  x-show="errorMessage.stack">
+                  x-show="errorMessage?.stack">
             <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
           </button>
           <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
@@ -41,11 +41,11 @@
         </div>
       </div>
       <div class="toast-content" x-show="errorExpanded" x-transition>
-        <span x-text="errorMessage.stack"></span>
+        <span x-text="errorMessage?.stack || ''"></span>
       </div>
     </div>
     <div class="model-selector">
-      <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+      <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" class='model-select'>
       </select>
     </div>
     <div @popstate.window="
114 changes: 35 additions & 79 deletions exo/tinychat/index.js
@@ -13,6 +13,8 @@ document.addEventListener("alpine:init", () => {
     home: 0,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
+
+    // Initialize error message structure
     errorMessage: null,
     errorExpanded: false,
     errorTimeout: null,
@@ -38,6 +40,9 @@ document.addEventListener("alpine:init", () => {

       // Start polling for download progress
       this.startDownloadProgressPolling();
+
+      // Call populateSelector immediately after initialization
+      this.populateSelector();
     },

     removeHistory(cstate) {
@@ -77,50 +82,25 @@ document.addEventListener("alpine:init", () => {
     async populateSelector() {
       try {
         const response = await fetch(`${window.location.origin}/modelpool`);
-        const responseText = await response.text(); // Get raw response text first
-
         if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
-        }
-
-        // Try to parse the response text
-        let responseJson;
-        try {
-          responseJson = JSON.parse(responseText);
-        } catch (parseError) {
-          console.error('Failed to parse JSON:', parseError);
-          throw new Error(`Invalid JSON response: ${responseText}`);
-        }
-
-        const sel = document.querySelector(".model-select");
-        if (!sel) {
-          throw new Error("Could not find model selector element");
+          const errorText = await response.text();
+          throw new Error(`HTTP error! status: ${response.status}\n${errorText}`);
         }

-        // Clear the current options and add new ones
+        const data = await response.json();
+        const sel = document.querySelector('.model-select');
         sel.innerHTML = '';

-        const modelDict = responseJson["model pool"];
-        if (!modelDict) {
-          throw new Error("Response missing 'model pool' property");
-        }
-
-        Object.entries(modelDict).forEach(([key, value]) => {
+        // Use the model pool entries in their original order
+        Object.entries(data["model pool"]).forEach(([key, value]) => {
           const opt = document.createElement("option");
           opt.value = key;
-          opt.textContent = value;
+          opt.textContent = `${value.name}${value.downloaded ? ' (downloaded)' : ''}`;
           sel.appendChild(opt);
         });
-
-        // Set initial value to the first model
-        const firstKey = Object.keys(modelDict)[0];
-        if (firstKey) {
-          sel.value = firstKey;
-          this.cstate.selectedModel = firstKey;
-        }
       } catch (error) {
         console.error("Error populating model selector:", error);
-        this.errorMessage = `Failed to load models: ${error.message}`;
+        this.setError(error);
       }
     },

@@ -169,29 +149,7 @@ document.addEventListener("alpine:init", () => {
         this.processMessage(value);
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-          message: error.message || 'Unknown error',
-          stack: error.stack,
-          name: error.name || 'Error'
-        };
-
-        this.errorMessage = {
-          basic: `${errorDetails.name}: ${errorDetails.message}`,
-          stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-          clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-          this.errorTimeout = setTimeout(() => {
-            this.errorMessage = null;
-            this.errorExpanded = false;
-          }, 30 * 1000);
-        }
+        this.setError(error);
         this.generating = false;
       }
     },
@@ -309,29 +267,7 @@ document.addEventListener("alpine:init", () => {
         }
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-          message: error.message || 'Unknown error',
-          stack: error.stack,
-          name: error.name || 'Error'
-        };
-
-        this.errorMessage = {
-          basic: `${errorDetails.name}: ${errorDetails.message}`,
-          stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-          clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-          this.errorTimeout = setTimeout(() => {
-            this.errorMessage = null;
-            this.errorExpanded = false;
-          }, 30 * 1000);
-        }
+        this.setError(error);
       } finally {
         this.generating = false;
       }
@@ -467,6 +403,26 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
     },
+
+    // Add a helper method to set errors consistently
+    setError(error) {
+      this.errorMessage = {
+        basic: error.message || "An unknown error occurred",
+        stack: error.stack || ""
+      };
+      this.errorExpanded = false;
+
+      if (this.errorTimeout) {
+        clearTimeout(this.errorTimeout);
+      }
+
+      if (!this.errorExpanded) {
+        this.errorTimeout = setTimeout(() => {
+          this.errorMessage = null;
+          this.errorExpanded = false;
+        }, 30 * 1000);
+      }
+    },
   }));
 });
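
The `populateSelector` change depends on each `model pool` entry carrying a `name` and a `downloaded` flag instead of a plain display string. The sketch below, written in Python to match the backend, shows one way such a payload could be assembled and how the option label follows from it. It mirrors the engine filtering in `handle_model_support`, but every name and value here (`required_engines`, `build_model_pool`, `option_label`, `demo_pool`, `EngineA`, `model-x`, and so on) is a hypothetical stand-in, not exo's real data.

```python
def required_engines(topology_pool, own_engine, engine_classes):
  """Collect engine class names across the topology, deduplicated in first-seen order."""
  names = [
    engine_classes.get(engine_name)
    for engine_list in topology_pool
    for engine_name in engine_list
    if engine_name is not None
  ]
  # dict.fromkeys keeps insertion order, so it works as an ordered de-duplication.
  return list(dict.fromkeys(names + [own_engine]))


def build_model_pool(cards, pretty, engines, is_downloaded):
  """Keep models whose card has a repo for every required engine."""
  pool = {}
  for model_id, display_name in pretty.items():
    card = cards.get(model_id)
    if card is None:
      continue
    if all(engine in card["repo"] for engine in engines):
      pool[model_id] = {"name": display_name, "downloaded": is_downloaded(model_id)}
  return pool


def option_label(entry):
  """Python equivalent of the label the frontend builds for each <option>."""
  return entry["name"] + (" (downloaded)" if entry["downloaded"] else "")


def demo_pool():
  """Run the helpers on made-up data: two engines in the topology, two models."""
  engine_classes = {"a": "EngineA", "b": "EngineB"}
  engines = required_engines([["a"], ["a", "b"]], "EngineA", engine_classes)
  cards = {
    "model-x": {"repo": {"EngineA": "org/x-a", "EngineB": "org/x-b"}},
    "model-y": {"repo": {"EngineA": "org/y-a"}},  # lacks EngineB, so it is filtered out
  }
  pretty = {"model-x": "Model X", "model-y": "Model Y"}
  return engines, build_model_pool(cards, pretty, engines, lambda m: m == "model-x")
```

With this sample data only `model-x` survives the filter, and its label gets the ` (downloaded)` suffix. A client that still expects the old string-valued entries would need updating, since each value is now an object.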
