Add export --output-snapshot-path snap.tc, and `--snapshot-path snap.tc` (#1465)

* support model snapshots to save quantized models

* import set backend

---------

Co-authored-by: Michael Gschwind <[email protected]>
mikekgfb and mike94043 authored Jan 31, 2025
1 parent 4356b4c commit 7cbf2a3
Showing 3 changed files with 94 additions and 2 deletions.
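At its core, the change relies on the fact that quantization can swap layer classes inside the model, so a bare state_dict is not enough to reconstruct it: the snapshot instead pickles the whole module with torch.save and reloads it with torch.load(..., weights_only=False). A minimal sketch of that round trip, using a hypothetical Int8Linear stand-in rather than torchchat's real quantizers:

import torch
import torch.nn as nn

class Int8Linear(nn.Module):
    """Hypothetical stand-in for a quantized linear layer."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.detach()
        self.scale = w.abs().max() / 127.0
        self.weight = (w / self.scale).round().to(torch.int8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.weight.float() * self.scale).t()

model = nn.Sequential(nn.Linear(4, 4))
model[0] = Int8Linear(model[0])  # class swap: a state_dict alone can no longer rebuild this
torch.save(model, "snap.tc")     # the snapshot pickles classes and tensors together
restored = torch.load("snap.tc", weights_only=False)  # full unpickle, hence weights_only=False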
33 changes: 33 additions & 0 deletions torchchat/cli/builder.py
@@ -56,6 +56,7 @@ class BuilderArgs:
gguf_kwargs: Optional[Dict[str, Any]] = None
dso_path: Optional[Union[Path, str]] = None
aoti_package_path: Optional[Union[Path, str]] = None
snapshot_path: Optional[Union[Path, str]] = None
pte_path: Optional[Union[Path, str]] = None
device: Optional[str] = None
precision: torch.dtype = torch.float32
@@ -87,6 +88,7 @@ def __post_init__(self):
or (self.dso_path and Path(self.dso_path).is_file())
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
or (self.snapshot_path and Path(self.snapshot_path).is_file())
):
raise RuntimeError(
"need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
dso_path = getattr(args, "dso_path", None)
pte_path = getattr(args, "pte_path", None)
aoti_package_path = getattr(args, "aoti_package_path", None)
snapshot_path = getattr(args, "snapshot_path", None)

is_chat_model = False
if args.is_chat_model:
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
output_pte_path = getattr(args, "output_pte_path", None)
output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
output_dso_path = getattr(args, "output_dso_path", None)
output_snapshot_path = getattr(args, "output_snapshot_path", None)
if output_pte_path and args.dtype.startswith("fast"):
if args.dtype == "fast":
# As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
dso_path=dso_path,
aoti_package_path=aoti_package_path,
pte_path=pte_path,
snapshot_path=snapshot_path,
device=args.device,
precision=dtype,
setup_caches=(
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
model = PTEModel(config, builder_args.pte_path)
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
elif builder_args.snapshot_path:
# Resolve ModelArgs so the loaded snapshot can be validated
# If a manual params_path is provided, use that
if builder_args.params_path:
config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
else:
# TODO: Instead of loading the whole model, refactor to call a
# helper that generates just model.config
with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args)
device_sync(device=builder_args.device)
config = model.config
model = None
try:
model = torch.load(builder_args.snapshot_path, weights_only=False)
except Exception:
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
# _active_backend() does not allow both the DSO and AOTI backends
# to be active at once; choose one.
from torchchat.utils.build_utils import set_backend
set_backend(dso=True, pte=False, aoti_package=False)
if model.config != config:
raise RuntimeError("loaded model architecture mismatch")
##
## import all libraries with custom kernels and custom operators
## that quantize may be pulling in
##

elif builder_args.distributed:
pp_degree = builder_args.pp
tp_degree = builder_args.tp
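The set_backend call in the snapshot branch above selects the eager/DSO execution path before the snapshot's kernels are used. torchchat's real helpers live in torchchat.utils.build_utils; a simplified sketch of the flag pattern they implement (names and internals here are illustrative, not the actual implementation):

_active = {"dso": False, "pte": False, "aoti_package": False}

def set_backend(dso: bool, pte: bool, aoti_package: bool) -> None:
    # the DSO and AOTI-package backends are mutually exclusive; pick one
    assert not (dso and aoti_package), "choose either DSO or AOTI package"
    _active.update(dso=dso, pte=pte, aoti_package=aoti_package)

def _active_backend():
    return _active["dso"], _active["pte"], _active["aoti_package"]

set_backend(dso=True, pte=False, aoti_package=False)  # mirrors the call in builder.py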
14 changes: 13 additions & 1 deletion torchchat/cli/cli.py
@@ -207,6 +207,12 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)
exclusive_parser.add_argument(
"--output-snapshot-path",
type=str,
default=None,
help="Output to the specified PyTorch model and sha256 file",
)
exclusive_parser.add_argument(
"--output-aoti-package-path",
type=str,
@@ -254,7 +260,13 @@ def _add_exported_input_path_args(parser) -> None:
default=None,
help="Use the specified ExecuTorch .pte model file",
)

exclusive_parser.add_argument(
"--snapshot-path",
type=Path,
default=None,
help="Use the specified torchchat snaphot .tc model file",
)


# Add CLI Args related to JIT downloading of model artifacts
def _add_jit_downloading_args(parser) -> None:
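Both new flags join mutually exclusive argument groups, so argparse itself rejects combining them with the other export or input paths. A condensed, self-contained sketch of the pattern (this flag subset is illustrative):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--output-dso-path", type=str, default=None)
group.add_argument("--output-aoti-package-path", type=str, default=None)
group.add_argument("--output-snapshot-path", type=str, default=None)

args = parser.parse_args(["--output-snapshot-path", "snap.tc"])
assert args.output_snapshot_path == "snap.tc"
# passing two of these flags at once makes parse_args exit with a usage error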
49 changes: 48 additions & 1 deletion torchchat/export.py
@@ -28,6 +28,31 @@
default_device = "cpu"


"""
Export Snapshot
"""


def export_snapshot(
model: nn.Module,
device: Optional[str] = None,
output_path: str = "model-snapshot.tc",
) -> str:
"""
Export the model as a snapshot.
Args:
model: The model to be exported.
device: The device to run the model on.
output_path: The path to save the exported model.
Returns:
The path to the exported model.
"""
assert output_path.endswith(".tc"), "use .tc extension for snapshots"
torch.save(model, output_path)
return output_path


"""
Export for Server
"""
@@ -72,6 +97,7 @@ def export_for_server(
"aot_inductor.package": package,
"aot_inductor.metadata": metadata or {},
}

if not package:
options = {"aot_inductor.output_path": output_path}

@@ -373,14 +399,15 @@ def main(args):

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path
output_snapshot_path = args.output_snapshot_path
output_aoti_package_path = args.output_aoti_package_path

if output_pte_path and builder_args.device != "cpu":
print(
f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
)
builder_args.device = "cpu"
elif "mps" in builder_args.device:
elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
print("Warning! Device MPS not supported for export. Exporting for device CPU.")
builder_args.device = "cpu"

@@ -417,6 +444,7 @@ def main(args):
model_to_pte = model
model_to_dso = model
model_to_aoti_package = model
model_to_snapshot = model
else:
if output_pte_path:
_set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -436,6 +464,15 @@
model_to_dso = model_to_aoti_package
_unset_gguf_kwargs(builder_args)

if output_snapshot_path:
_set_gguf_kwargs(builder_args, is_et=False, context="export")
model_to_snapshot = _initialize_model(
builder_args,
quantize,
support_tensor_subclass=False,
)
_unset_gguf_kwargs(builder_args)

with torch.no_grad():
if output_pte_path:
output_pte_path = str(os.path.abspath(output_pte_path))
@@ -483,3 +520,13 @@ def main(args):
package=True,
metadata=metadata,
)

if output_snapshot_path:
output_snapshot_path = str(os.path.abspath(output_snapshot_path))
print(f"Exporting model using Snapshot to {output_snapshot_path}")
export_snapshot(
model_to_snapshot,
builder_args.device,
output_snapshot_path,
)
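After export, the resulting .tc file is an ordinary torch.save pickle, so it can be reloaded and sanity-checked directly. A sketch; note that weights_only=False runs the full pickle machinery, so only load snapshots from sources you trust:

import torch

snapshot = torch.load("snap.tc", weights_only=False)
print(type(snapshot).__name__)                        # the model class
print(sum(p.numel() for p in snapshot.parameters()))  # parameter count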
