diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 3a7c85937..a7f7bbba2 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -533,16 +533,16 @@ def arg_init(args): # Localized import to minimize expensive imports from torchchat.utils.build_utils import get_device_str - if args.device is None or args.device == "fast": + if args.device is None: args.device = get_device_str( args.quantize.get("executor", {}).get("accelerator", default_device) ) else: + args.device = get_device_str(args.device) executor_handler = args.quantize.get("executor", None) - if executor_handler: - if executor_handler["accelerator"] != args.device: - print('overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') - executor_handler["accelerator"] = args.device + if executor_handler and executor_handler["accelerator"] != args.device: + print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') + executor_handler["accelerator"] = args.device if "mps" in args.device: if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):