diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index a5b23dfe3..1e04800ab 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -56,6 +56,7 @@ class BuilderArgs:
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     aoti_package_path: Optional[Union[Path, str]] = None
+    snapshot_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -87,6 +88,7 @@ def __post_init__(self):
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
+            or (self.snapshot_path and Path(self.snapshot_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
             )
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
+        snapshot_path = getattr(args, "snapshot_path", None)
 
         is_chat_model = False
         if args.is_chat_model:
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         output_pte_path = getattr(args, "output_pte_path", None)
         output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
+        output_snapshot_path = getattr(args, "output_snapshot_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dso_path=dso_path,
             aoti_package_path=aoti_package_path,
             pte_path=pte_path,
+            snapshot_path=snapshot_path,
             device=args.device,
             precision=dtype,
             setup_caches=(
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.snapshot_path:
+        # Resolve the ModelArgs used to validate the loaded snapshot.
+        # If a manual params_path is provided, use that.
+        if builder_args.params_path:
+            config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
+        else:
+            # TODO: Instead of loading the whole model, refactor to call a
+            # helper that generates just model.config
+            with measure_time("Time to load model: {time:.02f} seconds"):
+                model = _load_model(builder_args)
+                device_sync(device=builder_args.device)
+            config = model.config
+            model = None
+        try:
+            model = torch.load(builder_args.snapshot_path, weights_only=False)
+        except Exception:
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+        # _active_backend() does not allow both DSO and AOTI to be true,
+        # so choose one of the two.
+        from torchchat.utils.build_utils import set_backend
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
+            raise RuntimeError("loaded model architecture mismatch")
+        ##
+        ## import all libraries with custom kernels and custom operators
+        ## that quantize may be pulling in
+        ##
+
     elif builder_args.distributed:
         pp_degree = builder_args.pp
         tp_degree = builder_args.tp
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 70f404635..1d531c709 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -207,6 +207,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--output-snapshot-path",
+        type=str,
+        default=None,
+        help="Output to the specified torchchat snapshot .tc model file",
+    )
     exclusive_parser.add_argument(
         "--output-aoti-package-path",
         type=str,
@@ -254,7 +260,13 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified ExecuTorch .pte model file",
     )
-
+    exclusive_parser.add_argument(
+        "--snapshot-path",
+        type=Path,
+        default=None,
+        help="Use the specified torchchat snapshot .tc model file",
+    )
+
 
 # Add CLI Args related to JIT downloading of model artifacts
 def _add_jit_downloading_args(parser) -> None:
diff --git a/torchchat/export.py b/torchchat/export.py
index 829bd47db..e7cb32309 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -28,6 +28,31 @@
 default_device = "cpu"
 
+"""
+Export Snapshot
+"""
+
+
+def export_snapshot(
+    model: nn.Module,
+    device: Optional[str] = None,
+    output_path: str = "model-snapshot.tc",
+) -> str:
+    """
+    Export the model as a snapshot.
+
+    Args:
+        model: The model to be exported.
+        device: The device to run the model on.
+        output_path: The path to save the exported model.
+    Returns:
+        The path to the exported model.
+    """
+    assert output_path.endswith(".tc"), "use .tc extension for snapshots"
+    torch.save(model, output_path)
+    return output_path
+
+
 """
 Export for Server
 """
@@ -72,6 +97,7 @@ def export_for_server(
         "aot_inductor.package": package,
         "aot_inductor.metadata": metadata or {},
     }
+
     if not package:
         options = {"aot_inductor.output_path": output_path}
 
@@ -373,6 +399,7 @@ def main(args):
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
+    output_snapshot_path = args.output_snapshot_path
     output_aoti_package_path = args.output_aoti_package_path
 
     if output_pte_path and builder_args.device != "cpu":
@@ -380,7 +407,7 @@ def main(args):
             f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
         )
         builder_args.device = "cpu"
-    elif "mps" in builder_args.device:
+    elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
         print("Warning! Device MPS not supported for export. Exporting for device CPU.")
         builder_args.device = "cpu"
 
@@ -417,6 +444,7 @@ def main(args):
             model_to_pte = model
             model_to_dso = model
             model_to_aoti_package = model
+            model_to_snapshot = model
         else:
             if output_pte_path:
                 _set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -436,6 +464,15 @@ def main(args):
             model_to_dso = model_to_aoti_package
             _unset_gguf_kwargs(builder_args)
 
+        if output_snapshot_path:
+            _set_gguf_kwargs(builder_args, is_et=False, context="export")
+            model_to_snapshot = _initialize_model(
+                builder_args,
+                quantize,
+                support_tensor_subclass=False,
+            )
+            _unset_gguf_kwargs(builder_args)
+
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
@@ -483,3 +520,13 @@ def main(args):
                 package=True,
                 metadata=metadata,
             )
+
+        if output_snapshot_path:
+            output_snapshot_path = str(os.path.abspath(output_snapshot_path))
+            print(f"Exporting model using Snapshot to {output_snapshot_path}")
+            export_snapshot(
+                model_to_snapshot,
+                builder_args.device,
+                output_snapshot_path,
+            )
+