Handle MPS for export and generate+compile #716

Merged 4 commits · May 8, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -216,7 +216,7 @@
 For more details on quantization and what settings to use for your use
 case visit our [Quantization documentation](docs/quantization.md) or
 run `python3 torchchat.py export`

-[end default]:
+[end default]: end

 ### Deploy and run on iOS
8 changes: 8 additions & 0 deletions cli.py
@@ -330,6 +330,14 @@ def arg_init(args):
         args.quantize.get("executor", {}).get("accelerator", args.device)
     )

+    if "mps" in args.device:
+        if args.compile or args.compile_prefill:
+            print(
+                "Warning: compilation is not available with device MPS, ignoring option to engage compilation"
+            )
+            args.compile = False
+            args.compile_prefill = False
+
     if hasattr(args, "seed") and args.seed:
         torch.manual_seed(args.seed)
     return args
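The cli.py guard above can be sketched as a standalone function (the `disable_compile_on_mps` name is hypothetical; the real PR performs this check inline in `arg_init` on the parsed argparse namespace):

```python
from argparse import Namespace

def disable_compile_on_mps(args: Namespace) -> Namespace:
    # torch.compile is not supported on the MPS backend here, so the PR
    # downgrades to eager execution with a warning instead of failing later.
    if "mps" in args.device and (args.compile or args.compile_prefill):
        print(
            "Warning: compilation is not available with device MPS, "
            "ignoring option to engage compilation"
        )
        args.compile = False
        args.compile_prefill = False
    return args

args = disable_compile_on_mps(
    Namespace(device="mps", compile=True, compile_prefill=False)
)
print(args.compile, args.compile_prefill)  # → False False
```

Non-MPS devices pass through untouched, so the guard is a no-op for CUDA/CPU runs.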
5 changes: 4 additions & 1 deletion export.py
@@ -48,6 +48,10 @@ def main(args):
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path

+    if output_dso_path and "mps" in builder_args.device:
+        print("Warning! Device MPS not supported for export. Exporting for device CPU.")
+        builder_args.device = "cpu"
+
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
     if not builder_args.gguf_path:
@@ -85,7 +89,6 @@ def main(args):
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
-            print(f">{output_pte_path}<")
             if executorch_export_available:
                 print(f"Exporting model using ExecuTorch to {output_pte_path}")
                 export_model_et(
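The export.py fallback follows the same pattern; a minimal sketch as a pure function (the `resolve_export_device` helper is hypothetical — the real code mutates `builder_args.device` inside `main`):

```python
from typing import Optional

def resolve_export_device(device: str, output_dso_path: Optional[str]) -> str:
    # DSO (AOT Inductor) export does not support MPS in this PR,
    # so exports targeting a .so fall back to CPU with a warning.
    if output_dso_path and "mps" in device:
        print("Warning! Device MPS not supported for export. Exporting for device CPU.")
        return "cpu"
    return device

print(resolve_export_device("mps", "model.so"))  # falls back to "cpu"
print(resolve_export_device("mps", None))        # no DSO export: stays "mps"
```

Note the fallback only fires for the DSO path; PTE (ExecuTorch) export is handled separately.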