Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Package exo fixes #473

Merged
merged 28 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import aiohttp_cors
import traceback
import os
import signal
import sys
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
Expand Down Expand Up @@ -188,7 +189,7 @@ async def handle_quit(self, request):
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
await shutdown(signal.SIGINT, asyncio.get_event_loop())
await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node)

async def timeout_middleware(self, app, handler):
async def middleware(request):
Expand Down
15 changes: 7 additions & 8 deletions exo/download/hf/hf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
from exo.inference.shard import Shard
import aiofiles
from aiofiles import os as aios

T = TypeVar("T")

Expand Down Expand Up @@ -109,15 +108,15 @@ async def move_models_to_hf(seed_dir: Union[str, Path]):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = get_hf_home()/"hub"
await aios.makedirs(dest_dir, exist_ok=True)
async for path in source_dir.iterdir():
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved
if path.is_dir() and path.startswith("models--"):
await aios.makedirs(dest_dir, exist_ok=True)
async for path in await aios.listdir(source_dir):
if await path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if dest_path.exists():
if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
else:
try:
await aios.rename(str(path), str(dest_path))

except Exception as e:
print(e)


async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
Expand Down
5 changes: 2 additions & 3 deletions exo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def get_all_ip_addresses():
return ["localhost"]


async def shutdown(signal, loop):
async def shutdown(signal, loop, node):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
print("Thank you for using exo.")
Expand All @@ -246,8 +246,7 @@ async def shutdown(signal, loop):
[task.cancel() for task in server_tasks]
print(f"Cancelling {len(server_tasks)} outstanding tasks")
await asyncio.gather(*server_tasks, return_exceptions=True)
await server.stop()
loop.stop()
await node.server.stop()
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved


def is_frozen():
Expand Down
2 changes: 1 addition & 1 deletion exo/inference/mlx/sharded_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def load_shard(
processor.encode = processor.tokenizer.encode
return model, processor
else:
tokenizer = await resolve_tokenizer(model_path)
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer


Expand Down
12 changes: 9 additions & 3 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
finally:
node.on_token.deregister(callback_id)

def clean_path(path):
"""Clean and resolve path"""
if path.startswith("Optional("):
path = path.strip('Optional("').rstrip('")')
return os.path.expanduser(path)

async def main():
loop = asyncio.get_running_loop()
Expand All @@ -211,13 +216,14 @@ async def main():

if not args.models_seed_dir is None:
try:
await move_models_to_hf(args.models_seed_dir)
models_seed_dir = clean_path(args.models_seed_dir)
await move_models_to_hf(models_seed_dir)
except Exception as e:
print(f"Error moving models to .cache/huggingface: {e}")

# Use a more direct approach to handle signals
def handle_exit():
asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node))
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved

if platform.system() != "Windows":
for s in [signal.SIGINT, signal.SIGTERM]:
Expand All @@ -244,7 +250,7 @@ def run():
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop))
loop.run_until_complete(shutdown(signal.SIGTERM, loop, node))
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved
loop.close()


Expand Down
9 changes: 5 additions & 4 deletions scripts/build_exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import os
import pkgutil
import shutil

def run():
site_packages = site.getsitepackages()[0]
Expand All @@ -14,8 +15,8 @@ def run():
"--follow-imports",
"--standalone",
"--output-filename=exo",
"--onefile",
"--python-flag=no_site"
"--python-flag=no_site",
"--onefile"
]

if sys.platform == "darwin":
Expand All @@ -24,8 +25,6 @@ def run():
"--macos-app-mode=gui",
"--macos-app-version=0.0.1",
"--macos-signed-app-name=com.exolabs.exo",
"--macos-sign-identity=auto",
"--macos-sign-notarization",
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved
"--include-distribution-meta=mlx",
"--include-module=mlx._reprlib_fix",
"--include-module=mlx._os_warning",
Expand Down Expand Up @@ -53,6 +52,8 @@ def run():
try:
subprocess.run(command, check=True)
print("Build completed!")
os.makedirs('./dist/main.dist/transformers/models', exist_ok=True)
shutil.copytree(f"{site_packages}/transformers/models", "dist/main.dist/transformers/models", dirs_exist_ok=True)
except subprocess.CalledProcessError as e:
print(f"An error occurred: {e}")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"rich==13.7.1",
"tenacity==9.0.0",
"tqdm==4.66.4",
"transformers==4.46.3",
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved
"transformers==4.46.3" if (sys.version_info.major==3 and sys.version_info.minor>12) else "transformers==4.43.3" ,
AlexCheema marked this conversation as resolved.
Show resolved Hide resolved
"uuid==1.30",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
]
Expand Down