Handle MPS for export and generate+compile (#716)
* Handle compile for export and generate

* typo

* typo

* typo
mikekgfb authored and malfet committed Jul 17, 2024
1 parent b2d199b commit d22de66
Showing 3 changed files with 13 additions and 2 deletions.
README.md: 1 addition & 1 deletion

@@ -216,7 +216,7 @@ For more details on quantization and what settings to use for your use
 case visit our [Quanitization documentation](docs/quantization.md) or
 run `python3 torchchat.py export`
 
-[end default]:
+[end default]: end
 
 ### Deploy and run on iOS
cli.py: 8 additions & 0 deletions

@@ -330,6 +330,14 @@ def arg_init(args):
             args.quantize.get("executor", {}).get("accelerator", args.device)
         )
 
+    if "mps" in args.device:
+        if args.compile or args.compile_prefill:
+            print(
+                "Warning: compilation is not available with device MPS, ignoring option to engage compilation"
+            )
+            args.compile = False
+            args.compile_prefill = False
+
     if hasattr(args, "seed") and args.seed:
         torch.manual_seed(args.seed)
     return args
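The guard clears both flags rather than raising, so generation on MPS silently proceeds in eager mode. A minimal self-contained sketch of the same pattern (the `argparse.Namespace` setup is an illustrative stand-in for torchchat's parsed arguments, not code from this commit):

```python
import argparse

# Illustrative stand-in for torchchat's parsed CLI arguments.
args = argparse.Namespace(device="mps", compile=True, compile_prefill=True)

# Same shape as the cli.py hunk above: torch.compile is not supported on
# the MPS backend here, so warn once and fall back to eager execution.
if "mps" in args.device:
    if args.compile or args.compile_prefill:
        print(
            "Warning: compilation is not available with device MPS, "
            "ignoring option to engage compilation"
        )
        args.compile = False
        args.compile_prefill = False

print(args.compile, args.compile_prefill)  # -> False False
```

Downgrading the flags instead of erroring keeps existing invocations working unchanged on Apple Silicon rather than failing on an unsupported option.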
export.py: 4 additions & 1 deletion

@@ -48,6 +48,10 @@ def main(args):
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
 
+    if output_dso_path and "mps" in builder_args.device:
+        print("Warning! Device MPS not supported for export. Exporting for device CPU.")
+        builder_args.device = "cpu"
+
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
     if not builder_args.gguf_path:
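Export takes the same non-fatal approach: when a shared-library (DSO) artifact is requested on MPS, the build device is rewritten to CPU before export proceeds. A short sketch of that fallback in isolation (`BuilderArgs` here is a hypothetical minimal stand-in for torchchat's builder configuration):

```python
from dataclasses import dataclass

# Hypothetical minimal stand-in for torchchat's builder configuration;
# only the field the hunk above touches is modeled.
@dataclass
class BuilderArgs:
    device: str

builder_args = BuilderArgs(device="mps")
output_dso_path = "model.so"  # illustrative output path

# Same fallback as the export.py hunk above: DSO export does not target
# MPS, so force the device to CPU and tell the user why.
if output_dso_path and "mps" in builder_args.device:
    print("Warning! Device MPS not supported for export. Exporting for device CPU.")
    builder_args.device = "cpu"

assert builder_args.device == "cpu"
```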
@@ -85,7 +89,6 @@ def main(args):
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
-            print(f">{output_pte_path}<")
             if executorch_export_available:
                 print(f"Exporting model using ExecuTorch to {output_pte_path}")
                 export_model_et(
