Skip to content

Commit 54ab20b

Browse files
mikekgfbmalfet
authored andcommitted
Handle MPS with for export and generate+compile (#716)
* Handle compile for export and generate * typo * typo * typo
1 parent c3c383b commit 54ab20b

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ For more details on quantization and what settings to use for your use
216216
case visit our [Quanitization documentation](docs/quantization.md) or
217217
run `python3 torchchat.py export`
218218

219-
[end default]:
219+
[end default]: end
220220

221221
### Deploy and run on iOS
222222

Diff for: cli.py

+8
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,14 @@ def arg_init(args):
330330
args.quantize.get("executor", {}).get("accelerator", args.device)
331331
)
332332

333+
if "mps" in args.device:
334+
if args.compile or args.compile_prefill:
335+
print(
336+
"Warning: compilation is not available with device MPS, ignoring option to engage compilation"
337+
)
338+
args.compile = False
339+
args.compile_prefill = False
340+
333341
if hasattr(args, "seed") and args.seed:
334342
torch.manual_seed(args.seed)
335343
return args

Diff for: export.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def main(args):
4848
output_pte_path = args.output_pte_path
4949
output_dso_path = args.output_dso_path
5050

51+
if output_dso_path and "mps" in builder_args.device:
52+
print("Warning! Device MPS not supported for export. Exporting for device CPU.")
53+
builder_args.device = "cpu"
54+
5155
# TODO: clean this up
5256
# This mess is because ET does not support _weight_int4pack_mm right now
5357
if not builder_args.gguf_path:
@@ -85,7 +89,6 @@ def main(args):
8589
with torch.no_grad():
8690
if output_pte_path:
8791
output_pte_path = str(os.path.abspath(output_pte_path))
88-
print(f">{output_pte_path}<")
8992
if executorch_export_available:
9093
print(f"Exporting model using ExecuTorch to {output_pte_path}")
9194
export_model_et(

0 commit comments

Comments
 (0)