Update AOTI package
angelayi committed Sep 13, 2024
1 parent 6fae164 commit 0146b38
Showing 9 changed files with 148 additions and 63 deletions.
21 changes: 10 additions & 11 deletions README.md
@@ -182,7 +182,7 @@ python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy an
[skip default]: end

### Server
This mode exposes a REST API for interacting with a model.
The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
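
As a minimal sketch of such a request (assumptions: the server was started in the first terminal with `python3 torchchat.py server llama3.1` and listens on `127.0.0.1:5000`, and the `requests` package is available; the endpoint path and response shape follow the OpenAI chat-completions spec linked above):

```python
# Hypothetical client request; adjust host/port to wherever your server runs.
import requests

payload = {
    "model": "llama3.1",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write me a story about a boy and his bear."},
    ],
    "max_tokens": 200,
}

resp = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
# Per the OpenAI spec, the generated text is in choices[0].message.content.
print(resp.json()["choices"][0]["message"]["content"])
```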
@@ -255,13 +255,14 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener
## Desktop/Server Execution

### AOTI (AOT Inductor)
[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a zipped PT2 file containing all the artifacts generated by AOTInductor, and a [.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable contents
that is then loaded for inference. This can be done in both Python and C++ environments.

The following example exports and executes the Llama3.1 8B Instruct
model. The first command compiles and performs the actual export.
```
python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so
```

```bash
python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2
```
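
Because the exported PT2 package is a zip archive of AOTInductor artifacts, its contents can be inspected with the standard library. A small sketch (the path assumes the export command above was run as shown):

```python
# List the AOTInductor artifacts bundled inside the exported .pt2 (a zip archive).
import zipfile

with zipfile.ZipFile("exportedModels/llama3_1_artifacts.pt2") as archive:
    for name in archive.namelist():
        print(name)  # expect the compiled shared library plus weight/metadata files
```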

> [!NOTE]
@@ -273,12 +274,11 @@ case visit our [customization guide](docs/model_customization.md).

### Run in a Python Environment

To run in a Python environment, use the generate subcommand like before, but include the DSO file.
To run in a Python environment, use the generate subcommand like before, but include the PT2 file.

```bash
python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is"
```

```
python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
```
**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.


### Run using our C++ Runner
@@ -288,11 +288,10 @@ To run in a C++ environment, we need to build the runner binary.
torchchat/utils/scripts/build_native.sh aoti
```

Then run the compiled executable with the exported DSO from earlier.
Then run the compiled executable with the exported PT2 from earlier.
```bash
cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
```
**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.

## Mobile Execution

6 changes: 3 additions & 3 deletions install/install_requirements.sh
@@ -47,10 +47,10 @@ fi
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
PYTORCH_NIGHTLY_VERSION=dev20240814
PYTORCH_NIGHTLY_VERSION=dev20240913

# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20240814
VISION_NIGHTLY_VERSION=dev20240913

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20240910
@@ -74,7 +74,7 @@ fi

# pip packages needed by exir.
REQUIREMENTS_TO_INSTALL=(
torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
)
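
As an optional sanity check after running the script, the installed nightlies can be compared against the pins above (a sketch; it only assumes the packages installed by this script are importable):

```python
# Confirm the installed nightly builds match the pinned versions above.
import torch
import torchvision

print(torch.__version__)        # expect a 2.6.0.dev20240913 nightly build
print(torchvision.__version__)  # expect a 0.20.0.dev20240913 nightly build
```
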
19 changes: 4 additions & 15 deletions runner/run.cpp
@@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
#endif

#ifdef __AOTI_MODEL__
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#ifdef USE_CUDA
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
torch::Device aoti_device(torch::kCPU);

#else // __ET_MODEL__
@@ -93,7 +90,7 @@ typedef struct {
RunState state; // buffers for the "wave" of activations in the forward pass

#ifdef __AOTI_MODEL__
torch::inductor::AOTIModelContainerRunner* runner;
torch::inductor::AOTIModelPackageLoader* runner;
#else // __ET_MODEL__
Module* runner;
#endif
@@ -143,16 +140,8 @@ void build_transformer(
malloc_run_state(&t->state, &t->config);

#ifdef __AOTI_MODEL__
#ifdef USE_CUDA
if (aoti_device.type() == torch::kCUDA) {
t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
aoti_device = torch::Device(torch::kCUDA);
} else {
#else
{
#endif
t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
}
t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
#else //__ET_MODEL__
t->runner = new Module(
/* path to PTE model */ model_path,
65 changes: 53 additions & 12 deletions torchchat/cli/builder.py
@@ -50,6 +50,7 @@ class BuilderArgs:
gguf_path: Optional[Union[Path, str]] = None
gguf_kwargs: Optional[Dict[str, Any]] = None
dso_path: Optional[Union[Path, str]] = None
aoti_package_path: Optional[Union[Path, str]] = None
pte_path: Optional[Union[Path, str]] = None
device: Optional[str] = None
precision: torch.dtype = torch.float32
@@ -69,28 +70,29 @@ def __post_init__(self):
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
or (self.gguf_path and self.gguf_path.is_file())
or (self.dso_path and Path(self.dso_path).is_file())
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
or (self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError(
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
)

if self.dso_path and self.pte_path:
raise RuntimeError("specify either DSO path or PTE path, but not both")
if self.aoti_package_path and self.pte_path:
raise RuntimeError("specify either AOTI package path or PTE path, but not both")

if self.checkpoint_path and (self.dso_path or self.pte_path):
if self.checkpoint_path and (self.aoti_package_path or self.pte_path):
print(
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
"Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
)
if self.checkpoint_dir and (self.dso_path or self.pte_path):
if self.checkpoint_dir and (self.aoti_package_path or self.pte_path):
print(
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
"Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
)
if self.gguf_path and (self.dso_path or self.pte_path):
if self.gguf_path and (self.aoti_package_path or self.pte_path):
print(
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
"Warning: GGUF path ignored because an exported AOTI or PTE path specified"
)
if not (self.dso_path) and not (self.pte_path):
if not (self.aoti_package_path) and not (self.pte_path):
self.prefill_possible = True

@classmethod
@@ -120,6 +122,7 @@ def from_args(cls, args): # -> BuilderArgs:

dso_path = getattr(args, "dso_path", None)
pte_path = getattr(args, "pte_path", None)
aoti_package_path = getattr(args, "aoti_package_path", None)

is_chat_model = False
if args.is_chat_model:
@@ -130,6 +133,7 @@ def from_args(cls, args): # -> BuilderArgs:
checkpoint_dir,
dso_path,
pte_path,
aoti_package_path,
args.gguf_path,
]:
if path is not None:
@@ -145,6 +149,7 @@ def from_args(cls, args): # -> BuilderArgs:


output_pte_path = getattr(args, "output_pte_path", None)
output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
output_dso_path = getattr(args, "output_dso_path", None)
if output_pte_path and args.dtype.startswith("fast"):
if args.dtype == "fast":
@@ -166,10 +171,11 @@ def from_args(cls, args): # -> BuilderArgs:
gguf_path=args.gguf_path,
gguf_kwargs=None,
dso_path=dso_path,
aoti_package_path=aoti_package_path,
pte_path=pte_path,
device=args.device,
precision=dtype,
setup_caches=(output_dso_path or output_pte_path),
setup_caches=(output_dso_path or output_pte_path or output_aoti_package_path),
use_distributed=args.distributed,
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
@@ -184,6 +190,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
speculative_builder_args.gguf_path = None
speculative_builder_args.dso_path = None
speculative_builder_args.aoti_package_path = None
speculative_builder_args.pte_path = None
return speculative_builder_args

@@ -463,11 +470,12 @@ def _initialize_model(
):
print("Loading model...")

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
print("Setting gguf_kwargs for generate.")
is_dso = builder_args.dso_path is not None
is_aoti_package = builder_args.aoti_package_path is not None
is_pte = builder_args.pte_path is not None
assert not (is_dso and is_pte)
assert not (is_dso and is_aoti_package and is_pte)
assert builder_args.gguf_kwargs is None
# TODO: make GGUF load independent of backend
# currently not working because AVX int_mm broken
@@ -501,6 +509,39 @@
)
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

elif builder_args.aoti_package_path:
if not is_cuda_or_cpu_device(builder_args.device):
print(
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
)
builder_args.device = "cpu"

# assert (
# quantize is None or quantize == "{ }"
# ), "quantize not valid for exported PT2 model. Specify quantization during export."

with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args, only_config=True)
device_sync(device=builder_args.device)

try:
# Replace model forward with the AOT-compiled forward
# This is a hacky way to quickly demo AOTI's capability.
# model is still a Python object, and any mutation to its
# attributes will NOT be seen by the AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
from torch._inductor.package import load_package
aoti_compiled_model = load_package(
str(builder_args.aoti_package_path.absolute())
)
model.forward = aoti_compiled_model
metadata = aoti_compiled_model.get_metadata()
builder_args.device = metadata["AOTI_DEVICE_KEY"]
except:
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")

elif builder_args.pte_path:
if not is_cpu_device(builder_args.device):
print(
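
For context on the new `aoti_package_path` branch above, the following is a standalone sketch of loading and calling an exported PT2 package outside of `builder.py`. The package path, token values, and forward signature are illustrative assumptions based on the README export example; only `load_package` and `get_metadata` are taken directly from the diff above.

```python
# Hypothetical standalone use of an exported AOTI package (paths/shapes are illustrative).
import torch
from torch._inductor.package import load_package

compiled = load_package("exportedModels/llama3_1_artifacts.pt2")

# The package records the device it was compiled for; builder.py reads it back
# through the AOTI_DEVICE_KEY metadata entry.
device = compiled.get_metadata()["AOTI_DEVICE_KEY"]

# Illustrative inputs; real shapes and dtypes depend on the exported model,
# which in torchchat takes (tokens, input_pos).
tokens = torch.tensor([[1, 2, 3]], dtype=torch.int, device=device)
input_pos = torch.arange(tokens.shape[1], dtype=torch.int, device=device)

with torch.inference_mode():
    logits = compiled(tokens, input_pos)
print(logits.shape)
```
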
12 changes: 12 additions & 0 deletions torchchat/cli/cli.py
@@ -191,6 +191,12 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)
exclusive_parser.add_argument(
"--output-aoti-package-path",
type=str,
default=None,
help="Output directory for AOTInductor compiled artifacts",
)


def _add_export_args(parser) -> None:
@@ -220,6 +226,12 @@ def _add_exported_input_path_args(parser) -> None:
default=None,
help="Use the specified AOT Inductor .dso model file",
)
exclusive_parser.add_argument(
"--aoti-package-path",
type=Path,
default=None,
help="Use the specified directory containing AOT Inductor compiled files",
)
exclusive_parser.add_argument(
"--pte-path",
type=Path,