Skip to content

Commit

Permalink
[aoti] Add cpp packaging for aoti + loading in python
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Jul 17, 2024
1 parent ee681bf commit 0cf4e99
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 38 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__/
# C extensions
*.so

.vscode
.model-artifacts/
.venv
.torchchat
Expand All @@ -18,3 +19,7 @@ runner-aoti/cmake-out/*

# pte files
*.pte

checkpoints/
exportedModels/
cmake-out/
13 changes: 9 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ model. The first command performs the actual export, the second
command loads the exported model into the Python interface to enable
users to test the exported model.

```
```bash
# Compile
python3 torchchat.py export llama3 --output-dso-path exportedModels/llama3.so
python3 torchchat.py export llama3 --output-aoti-package-path exportedModels/llama3_artifacts --device cpu

# Execute the exported model using Python

python3 torchchat.py generate llama3 --dso-path exportedModels/llama3.so --prompt "Hello my name is"
python3 torchchat.py generate llama3 --aoti-package-path exportedModels/llama3_artifacts --prompt "Hello my name is" --device cpu
```

NOTE: If your machine has cuda add this flag for performance
Expand All @@ -172,9 +172,14 @@ To build the runner binary on your Mac or Linux:
scripts/build_native.sh aoti
```

To compile the AOTI generated artifacts into a `.so`:
```bash
make -C exportedModels/llama3_artifacts
```

Execute
```bash
cmake-out/aoti_run exportedModels/llama3.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time"
cmake-out/aoti_run exportedModels/llama3_artifacts/llama3_artifacts.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time" -d cpu
```

## Mobile Execution
Expand Down
60 changes: 48 additions & 12 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class BuilderArgs:
gguf_path: Optional[Union[Path, str]] = None
gguf_kwargs: Optional[Dict[str, Any]] = None
dso_path: Optional[Union[Path, str]] = None
aoti_package_path: Optional[Union[Path, str]] = None
pte_path: Optional[Union[Path, str]] = None
device: Optional[str] = None
precision: torch.dtype = torch.float32
Expand All @@ -54,28 +55,29 @@ def __post_init__(self):
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
or (self.gguf_path and self.gguf_path.is_file())
or (self.dso_path and Path(self.dso_path).is_file())
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError(
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
)

if self.dso_path and self.pte_path:
raise RuntimeError("specify either DSO path or PTE path, but not both")
if self.pte_path and self.aoti_package_path:
raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")

if self.checkpoint_path and (self.dso_path or self.pte_path):
if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
print(
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
"Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
)
if self.checkpoint_dir and (self.dso_path or self.pte_path):
if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
print(
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
"Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
)
if self.gguf_path and (self.dso_path or self.pte_path):
if self.gguf_path and (self.pte_path or self.aoti_package_path):
print(
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
"Warning: GGUF path ignored because an exported AOTI or PTE path specified"
)
if not (self.dso_path) and not (self.pte_path):
if not (self.dso_path) and not (self.aoti_package_path):
self.prefill_possible = True

@classmethod
Expand Down Expand Up @@ -111,6 +113,7 @@ def from_args(cls, args): # -> BuilderArgs:
checkpoint_path,
checkpoint_dir,
args.dso_path,
args.aoti_package_path,
args.pte_path,
args.gguf_path,
]:
Expand Down Expand Up @@ -145,10 +148,11 @@ def from_args(cls, args): # -> BuilderArgs:
gguf_path=args.gguf_path,
gguf_kwargs=None,
dso_path=args.dso_path,
aoti_package_path=args.aoti_package_path,
pte_path=args.pte_path,
device=args.device,
precision=dtype,
setup_caches=(args.output_dso_path or args.output_pte_path),
setup_caches=(args.output_dso_path or args.output_pte_path or args.output_aoti_package_path),
use_distributed=args.distributed,
is_chat_model=is_chat_model,
)
Expand All @@ -161,6 +165,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
speculative_builder_args.gguf_path = None
speculative_builder_args.dso_path = None
speculative_builder_args.aoti_package_path = None
speculative_builder_args.pte_path = None
return speculative_builder_args

Expand Down Expand Up @@ -432,11 +437,12 @@ def _initialize_model(
):
print("Loading model...")

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
print("Setting gguf_kwargs for generate.")
is_dso = builder_args.dso_path is not None
is_aoti_package = builder_args.aoti_package_path is not None
is_pte = builder_args.pte_path is not None
assert not (is_dso and is_pte)
assert not (is_dso and is_aoti_package and is_pte)
assert builder_args.gguf_kwargs is None
# TODO: make GGUF load independent of backend
# currently not working because AVX int_mm broken
Expand Down Expand Up @@ -470,6 +476,36 @@ def _initialize_model(
)
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

elif builder_args.aoti_package_path:
if not is_cuda_or_cpu_device(builder_args.device):
print(
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
)
builder_args.device = "cpu"

# assert (
# quantize is None or quantize == "{ }"
# ), "quantize not valid for exported PT2 model. Specify quantization during export."

with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args, only_config=True)
device_sync(device=builder_args.device)

try:
# Replace model forward with the AOT-compiled forward
# This is a hacky way to quickly demo AOTI's capability.
# model is still a Python object, and any mutation to its
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
from torch._inductor.package import load_package
model.forward = load_package(
str(builder_args.aoti_package_path.absolute()), builder_args.device
)
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")

elif builder_args.pte_path:
if not is_cpu_device(builder_args.device):
print(
Expand Down
22 changes: 13 additions & 9 deletions build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,42 +69,46 @@ def unpack_packed_weights(

active_builder_args_dso = None
active_builder_args_pte = None
active_builder_args_aoti_package = None


def set_backend(dso, pte):
def set_backend(dso, pte, aoti_package):
global active_builder_args_dso
global active_builder_args_pte
active_builder_args_dso = dso
active_builder_args_aoti_package = aoti_package
active_builder_args_pte = pte


def use_aoti_backend() -> bool:
global active_builder_args_dso
global active_builder_args_aoti_package
global active_builder_args_pte

# eager == aoti, which is when backend has not been explicitly set
if (not active_builder_args_dso) and not (active_builder_args_pte):
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
return True

if active_builder_args_pte and active_builder_args_dso:
if active_builder_args_pte and active_builder_args_aoti_package:
raise RuntimeError(
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
)

return bool(active_builder_args_dso)
return bool(active_builder_args_dso) or bool(active_builder_args_aoti_package)


def use_et_backend() -> bool:
global active_builder_args_dso
global active_builder_args_aoti_package
global active_builder_args_pte

# eager == aoti, which is when backend has not been explicitly set
if not (active_builder_args_pte or active_builder_args_dso):
return False
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
return True

if active_builder_args_pte and active_builder_args_dso:
if active_builder_args_pte and active_builder_args_aoti_package:
raise RuntimeError(
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
)

return bool(active_builder_args_pte)
Expand Down
12 changes: 12 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)
output_path_parser.add_argument(
"--output-aoti-package-path",
type=str,
default=None,
help="Output directory for AOTInductor compiled artifacts",
)


# Add CLI Args representing user provided exported model files
Expand All @@ -174,6 +180,12 @@ def _add_exported_input_path_args(parser) -> None:
default=None,
help="Use the specified AOT Inductor .dso model file",
)
exported_model_path_parser.add_argument(
"--aoti-package-path",
type=Path,
default=None,
help="Use the specified directory containing AOT Inductor compiled files",
)
exported_model_path_parser.add_argument(
"--pte-path",
type=Path,
Expand Down
4 changes: 3 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def main(args) -> None:

if compile:
assert not (
builder_args.dso_path or builder_args.pte_path
builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
), "cannot compile exported model"
global model_forward
model_forward = torch.compile(
Expand All @@ -260,6 +260,8 @@ def main(args) -> None:
)
if builder_args.dso_path:
print(f"For model {builder_args.dso_path}")
if builder_args.aoti_package_path:
print(f"For model {builder_args.aoti_package_path}")
elif builder_args.pte_path:
print(f"For model {builder_args.pte_path}")
elif builder_args.checkpoint_path:
Expand Down
18 changes: 14 additions & 4 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ def main(args):

print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)
set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)

builder_args.dso_path = None
builder_args.pte_path = None
builder_args.aoti_package_path = None
builder_args.setup_caches = True

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path
output_aoti_package_path = args.output_aoti_package_path

if output_pte_path and builder_args.device != "cpu":
print(
Expand Down Expand Up @@ -74,6 +76,7 @@ def main(args):
)
model_to_pte = model
model_to_dso = model
model_to_aoti_package = model
else:
if output_pte_path:
_set_gguf_kwargs(builder_args, is_et=True, context="export")
Expand All @@ -83,12 +86,13 @@ def main(args):
)
_unset_gguf_kwargs(builder_args)

if output_dso_path:
if output_dso_path or output_aoti_package_path:
_set_gguf_kwargs(builder_args, is_et=False, context="export")
model_to_dso = _initialize_model(
model_to_aoti_package = _initialize_model(
builder_args,
quantize,
)
model_to_dso = model_to_aoti_package
_unset_gguf_kwargs(builder_args)

with torch.no_grad():
Expand All @@ -104,10 +108,16 @@ def main(args):
"Export with executorch requested but ExecuTorch could not be loaded"
)
print(executorch_exception)

if output_dso_path:
output_dso_path = str(os.path.abspath(output_dso_path))
print(f"Exporting model using AOT Inductor to {output_dso_path}")
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args, False)

if output_aoti_package_path:
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
export_model_aoti(model_to_aoti_package, builder_args.device, output_aoti_package_path, args, True)


if __name__ == "__main__":
Expand Down
17 changes: 12 additions & 5 deletions export_aoti.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import torch
import torch.nn as nn
from torch.export import Dim
import torch._inductor.config

default_device = "cpu"


def export_model(model: nn.Module, device, output_path, args=None):
def export_model(model: nn.Module, device, output_path, args=None, package=True):
max_seq_length = 350

input = (
Expand All @@ -25,11 +26,17 @@ def export_model(model: nn.Module, device, output_path, args=None):
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}

model.to(device)
so = torch._export.aot_compile(

options = {"aot_inductor.output_path": output_path}
# TODO: workaround until we update torch version
if "aot_inductor.package" in torch._inductor.config._config:
options["aot_inductor.package"] = package

path = torch._export.aot_compile(
model,
args=input,
options={"aot_inductor.output_path": output_path},
options=options,
dynamic_shapes=dynamic_shapes,
)
print(f"The generated DSO model can be found at: {so}")
return so
print(f"The AOTInductor compiled files can be found at: {path}")
return path
6 changes: 3 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def validate_build(
reason = "model compilation for prefill"
if self.compile:
reason = "model compilation"
if builder_args.dso_path:
model_type = "DSO"
if builder_args.aoti_package_path:
model_type = "PT2"
if builder_args.pte_path:
model_type = "PTE"
if model_type and reason:
Expand All @@ -103,7 +103,7 @@ def validate_build(
@classmethod
def from_args(cls, args):
sequential_prefill = (
args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
args.sequential_prefill or bool(args.aoti_package_path) or bool(args.pte_path)
)

return cls(
Expand Down

0 comments on commit 0cf4e99

Please sign in to comment.