diff --git a/.ci/scripts/validate.sh b/.ci/scripts/validate.sh index 1f7e889d3..ace9ef18d 100644 --- a/.ci/scripts/validate.sh +++ b/.ci/scripts/validate.sh @@ -133,51 +133,51 @@ function generate_aoti_model_output() { echo "******************************************" echo "************** non-quantized *************" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path "${MODEL_DIR}/${MODEL_NAME}.so" --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path "${MODEL_DIR}/${MODEL_NAME}.pt2" --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --aoti-package-path "$MODEL_DIR/${MODEL_NAME}.pt2" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "******************************************" echo "******* Emb: channel-wise quantized ******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "******************************************" echo "******** Emb: group-wise quantized *******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "***********************************************" echo 
"******* Emb: 4bit channel-wise quantized ******" echo "***********************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "***********************************************" echo "******** Emb: 4bit group-wise quantized *******" echo "***********************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" if [ "${EXCLUDE_INT8_QUANT:-false}" == false ]; then echo "******************************************" echo "******* INT8 channel-wise quantized ******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" echo "******************************************" echo "******** INT8 
group-wise quantized *******" echo "******************************************" - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" fi echo "******************************************" @@ -185,8 +185,8 @@ function generate_aoti_model_output() { echo "******************************************" if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then # For CUDA, only bfloat16 makes sense for int4 mm kernel - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti" fi done @@ -285,8 +285,8 @@ function eval_model_sanity_check() { echo "******** INT4 group-wise quantized (AOTI) *******" echo "*************************************************" if [ "$DTYPE" != "float16" ]; then - python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1 - python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1 + python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --dynamic-shapes --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1 cat "$MODEL_DIR/output_eval_aoti" fi; fi; diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 14b8c0712..ee7270a5d 100644 --- 
a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -378,8 +378,8 @@ jobs: echo "::group::Run inference with quantize file" if [ $(uname -s) == Darwin ]; then - python3 torchchat.py export --output-dso-path /tmp/model.so --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth" - python3 torchchat.py generate --dso-path /tmp/model.so --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~ + python3 torchchat.py export --output-aoti-package-path /tmp/model.pt2 --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth" + python3 torchchat.py generate --aoti-package-path /tmp/model.pt2 --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~ fi echo "::endgroup::" @@ -1023,8 +1023,8 @@ jobs: for dtype in fp32 fp16 bf16 fast fast16; do echo "Running export + runner with dtype=$dtype" - python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-dso-path /tmp/model.so - ./cmake-out/aoti_run /tmp/model.so -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" + python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-aoti-package-path /tmp/model.pt2 + ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" done echo "Tests complete." @@ -1118,8 +1118,8 @@ jobs: python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}" echo "Export and run AOTI (C++ runner)" - python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' - ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}" + python torchchat.py export stories110M --output-aoti-package-path ./model.pt2 --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' + ./cmake-out/aoti_run ./model.pt2 -z ./tokenizer.model -t 0 -i "${PRMT}" echo "Generate AOTI" - python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}" + python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}" echo "Tests complete." 
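For reference, the Python-side counterpart of the new `--aoti-package-path` flag exercised in these workflows is sketched below; it mirrors the `load_package` wiring added to `torchchat/cli/builder.py` later in this diff. The `/tmp/model.pt2` path matches the workflow above, and the dummy `tokens`/`input_pos` tensors are illustrative placeholders copied from the example inputs in `torchchat/export.py`.

```python
# Sketch: load and call an exported .pt2 package from Python, which is what
# `torchchat.py generate --aoti-package-path` does via torchchat/cli/builder.py.
# The /tmp/model.pt2 path matches the workflow above; the dummy inputs follow
# the (tokens, input_pos) example inputs used in torchchat/export.py.
import torch
from torch._inductor.package import load_package

aoti_model = load_package("/tmp/model.pt2")

# The package records the device it was compiled for under AOTI_DEVICE_KEY.
device = aoti_model.get_metadata().get("AOTI_DEVICE_KEY", "cpu")

tokens = torch.tensor([[1]], dtype=torch.int, device=device)
input_pos = torch.tensor([0], dtype=torch.int, device=device)

# The loaded package is callable and stands in for the eager model's forward.
logits = aoti_model(tokens, input_pos)
print(type(logits))
```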
diff --git a/.github/workflows/runner-cuda-dtype.yml b/.github/workflows/runner-cuda-dtype.yml index a79c262c3..b83b9904b 100644 --- a/.github/workflows/runner-cuda-dtype.yml +++ b/.github/workflows/runner-cuda-dtype.yml @@ -56,9 +56,9 @@ jobs: for DTYPE in bfloat16; do python torchchat.py generate --dtype ${DTYPE} --checkpoint-path ${MODEL_DIR}/stories15M.pt --temperature 0 --prompt "${PROMPT}" --device cuda - python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-dso-path /tmp/model.so + python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-aoti-package-path /tmp/model.pt2 - ./cmake-out/aoti_run /tmp/model.so -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" + ./cmake-out/aoti_run /tmp/model.pt2 -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}" done diff --git a/README.md b/README.md index 7f972f81c..40fa84d85 100644 --- a/README.md +++ b/README.md @@ -293,13 +293,18 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener ## Desktop/Server Execution ### AOTI (AOT Inductor) -[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`) -that is then loaded for inference. This can be done with both Python and C++ environments. +[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution +for faster inference. The process creates a zipped PT2 file containing all the +artifacts generated by AOTInductor, and a +[.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable +contents that is then loaded for inference. This can be done with both Python +and C++ environments. The following example exports and executes the Llama3.1 8B Instruct model. The first command compiles and performs the actual export. -``` -python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so + +```bash +python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2 ``` > [!NOTE] @@ -311,12 +316,11 @@ case visit our [customization guide](docs/model_customization.md). ### Run in a Python Environment -To run in a python environment, use the generate subcommand like before, but include the dso file. +To run in a Python environment, use the generate subcommand like before, but include the pt2 file. +```bash +python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is" ``` -python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is" -``` -**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`. ### Run using our C++ Runner @@ -326,11 +330,10 @@ To run in a C++ enviroment, we need to build the runner binary. torchchat/utils/scripts/build_native.sh aoti ``` -Then run the compiled executable, with the exported DSO from earlier. +Then run the compiled executable, with the exported .pt2 from earlier. ```bash -cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time" +cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time" ``` -**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`. 
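The device-selection notes removed above are no longer needed because the target device now travels inside the package: both the C++ runner and the Python builder (changed below) read the `AOTI_DEVICE_KEY` metadata entry. A minimal, hypothetical check of an exported package, assuming the `exportedModels/llama3_1_artifacts.pt2` path from the README example:

```python
# Sketch: inspect which device an exported .pt2 was compiled for, instead of
# passing --device / -d explicitly as the old .so flow required.
from torch._inductor.package import load_package

pkg = load_package("exportedModels/llama3_1_artifacts.pt2")

# AOTI_DEVICE_KEY is the same metadata entry read by runner/run.cpp and
# torchchat/cli/builder.py below.
print("Compiled for device:", pkg.get_metadata().get("AOTI_DEVICE_KEY", "<unknown>"))
```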
## Mobile Execution diff --git a/runner/run.cpp b/runner/run.cpp index e161c029e..abfbb4584 100644 --- a/runner/run.cpp +++ b/runner/run.cpp @@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree. #endif #ifdef __AOTI_MODEL__ -#include -#ifdef USE_CUDA -#include -#endif +#include torch::Device aoti_device(torch::kCPU); #else // __ET_MODEL__ @@ -94,7 +91,7 @@ typedef struct { RunState state; // buffers for the "wave" of activations in the forward pass #ifdef __AOTI_MODEL__ - torch::inductor::AOTIModelContainerRunner* runner; + torch::inductor::AOTIModelPackageLoader* runner; #else // __ET_MODEL__ Module* runner; #endif @@ -144,16 +141,8 @@ void build_transformer( malloc_run_state(&t->state, &t->config); #ifdef __AOTI_MODEL__ -#ifdef USE_CUDA - if (aoti_device.type() == torch::kCUDA) { - t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path); - aoti_device = torch::Device(torch::kCUDA); - } else { -#else - { -#endif - t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path); - } + t->runner = new torch::inductor::AOTIModelPackageLoader(model_path); + aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA); #else //__ET_MODEL__ t->runner = new Module( /* path to PTE model */ model_path, diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 511cf1f35..17bc219f8 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -53,6 +53,7 @@ class BuilderArgs: gguf_path: Optional[Union[Path, str]] = None gguf_kwargs: Optional[Dict[str, Any]] = None dso_path: Optional[Union[Path, str]] = None + aoti_package_path: Optional[Union[Path, str]] = None pte_path: Optional[Union[Path, str]] = None device: Optional[str] = None precision: torch.dtype = torch.float32 @@ -75,16 +76,19 @@ def __post_init__(self): or (self.checkpoint_dir and self.checkpoint_dir.is_dir()) or (self.gguf_path and self.gguf_path.is_file()) or (self.dso_path and Path(self.dso_path).is_file()) + or (self.aoti_package_path and Path(self.aoti_package_path).is_file()) or (self.pte_path and Path(self.pte_path).is_file()) ): raise RuntimeError( "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" ) - if self.dso_path and self.pte_path: - raise RuntimeError("specify either DSO path or PTE path, but not both") + if self.aoti_package_path and self.pte_path: + raise RuntimeError( + "specify either AOTI Package path or PTE path, but not more than one" + ) - if self.dso_path or self.pte_path: + if self.dso_path or self.pte_path or self.aoti_package_path: ignored_params = [ (self.checkpoint_path, "checkpoint path"), (self.checkpoint_dir, "checkpoint dir"), @@ -125,6 +129,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": dso_path = getattr(args, "dso_path", None) pte_path = getattr(args, "pte_path", None) + aoti_package_path = getattr(args, "aoti_package_path", None) is_chat_model = False if args.is_chat_model: @@ -135,6 +140,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": checkpoint_dir, dso_path, pte_path, + aoti_package_path, args.gguf_path, ]: if path is not None: @@ -149,6 +155,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": is_chat_model = True output_pte_path = getattr(args, "output_pte_path", None) + output_aoti_package_path = getattr(args, "output_aoti_package_path", None) output_dso_path = getattr(args, "output_dso_path", None) if output_pte_path and args.dtype.startswith("fast"): if 
args.dtype == "fast": @@ -174,10 +181,13 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": gguf_path=args.gguf_path, gguf_kwargs=None, dso_path=dso_path, + aoti_package_path=aoti_package_path, pte_path=pte_path, device=args.device, precision=dtype, - setup_caches=(output_dso_path or output_pte_path), + setup_caches=( + output_dso_path or output_pte_path or output_aoti_package_path + ), distributed=distributed, pp=pp, tp=tp, @@ -195,6 +205,7 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs": speculative_builder_args.checkpoint_path = args.draft_checkpoint_path speculative_builder_args.gguf_path = None speculative_builder_args.dso_path = None + speculative_builder_args.aoti_package_path = None speculative_builder_args.pte_path = None return speculative_builder_args @@ -511,11 +522,14 @@ def _initialize_model( support_tensor_subclass: bool = True, ) -> Model: print("Loading model...") - if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): + if builder_args.gguf_path and ( + builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path + ): print("Setting gguf_kwargs for generate.") is_dso = builder_args.dso_path is not None + is_aoti_package = builder_args.aoti_package_path is not None is_pte = builder_args.pte_path is not None - assert not (is_dso and is_pte) + assert not (is_dso and is_aoti_package and is_pte) assert builder_args.gguf_kwargs is None # TODO: make GGUF load independent of backend # currently not working because AVX int_mm broken @@ -549,6 +563,42 @@ def _initialize_model( ) except: raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}") + + elif builder_args.aoti_package_path: + if not is_cuda_or_cpu_device(builder_args.device): + print( + f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead" + ) + builder_args.device = "cpu" + + # assert ( + # quantize is None or quantize == "{ }" + # ), "quantize not valid for exported PT2 model. Specify quantization during export." + + with measure_time("Time to load model: {time:.02f} seconds"): + model = _load_model(builder_args) + device_sync(device=builder_args.device) + + try: + # Replace model forward with the AOT-compiled forward + # This is a hacky way to quickly demo AOTI's capability. + # model is still a Python object, and any mutation to its + # attributes will NOT be seen on by AOTI-compiled forward + # function, e.g. calling model.setup_cache will NOT touch + # AOTI compiled and maintained model buffers such as kv_cache. 
+ from torch._inductor.package import load_package + + aoti_compiled_model = load_package( + str(builder_args.aoti_package_path.absolute()) + ) + model.forward = aoti_compiled_model + metadata = aoti_compiled_model.get_metadata() + builder_args.device = metadata["AOTI_DEVICE_KEY"] + except: + raise RuntimeError( + f"Failed to load AOTI compiled {builder_args.aoti_package_path}" + ) + elif builder_args.pte_path: if not is_cpu_device(builder_args.device): print( diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index bc41d56ec..740f344a8 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -191,6 +191,12 @@ def _add_export_output_path_args(parser) -> None: default=None, help="Output to the specified AOT Inductor .dso model file", ) + exclusive_parser.add_argument( + "--output-aoti-package-path", + type=str, + default=None, + help="Output directory for AOTInductor compiled artifacts", + ) def _add_export_args(parser) -> None: @@ -220,6 +226,12 @@ def _add_exported_input_path_args(parser) -> None: default=None, help="Use the specified AOT Inductor .dso model file", ) + exclusive_parser.add_argument( + "--aoti-package-path", + type=Path, + default=None, + help="Use the specified directory containing AOT Inductor compiled files", + ) exclusive_parser.add_argument( "--pte-path", type=Path, diff --git a/torchchat/export.py b/torchchat/export.py index 626d4fae3..7a7923119 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -11,6 +11,7 @@ import torch.nn as nn from torch.export import Dim +import torch._inductor from torchchat.cli.builder import ( _initialize_model, @@ -35,8 +36,9 @@ def export_for_server( model: nn.Module, device: Optional[str] = "cpu", - output_path: str = "model.dso", + output_path: str = "model.pt2", dynamic_shapes: bool = False, + package: bool = True, ) -> str: """ Export the model using AOT Compile to get a .dso for server use cases. @@ -49,7 +51,7 @@ def export_for_server( The path to the exported model. 
""" if dynamic_shapes: - input = ( + example_inputs = ( torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) @@ -58,21 +60,31 @@ def export_for_server( # Specify that the first dimension of each input is that batch size dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}} else: - input = ( + example_inputs = ( torch.tensor([[1]], dtype=torch.int, device=device), torch.tensor([0], dtype=torch.int, device=device), ) dynamic_shapes = None with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): - so = torch._export.aot_compile( + metadata = {} # TODO: put more metadata here + options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata} + if not package: + options = {"aot_inductor.output_path": output_path} + + path = torch._export.aot_compile( model, - args=input, - options={"aot_inductor.output_path": output_path}, + example_inputs, dynamic_shapes=dynamic_shapes, + options=options, ) - print(f"The generated DSO model can be found at: {so}") - return so + + if package: + from torch._inductor.package import package_aoti + path = package_aoti(output_path, path) + + print(f"The generated packaged model can be found at: {path}") + return path """ @@ -338,14 +350,16 @@ def main(args): print(f"Using device={builder_args.device}") set_precision(builder_args.precision) - set_backend(dso=args.output_dso_path, pte=args.output_pte_path) + set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path) builder_args.dso_path = None builder_args.pte_path = None + builder_args.aoti_package_path = None builder_args.setup_caches = True output_pte_path = args.output_pte_path output_dso_path = args.output_dso_path + output_aoti_package_path = args.output_aoti_package_path if output_pte_path and builder_args.device != "cpu": print( @@ -383,10 +397,11 @@ def main(args): quantize, tokenizer, max_seq_length=builder_args.max_seq_length, - support_tensor_subclass=output_dso_path is None, + support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None, ) model_to_pte = model model_to_dso = model + model_to_aoti_package = model else: if output_pte_path: _set_gguf_kwargs(builder_args, is_et=True, context="export") @@ -396,13 +411,14 @@ def main(args): ) _unset_gguf_kwargs(builder_args) - if output_dso_path: + if output_dso_path or output_aoti_package_path: _set_gguf_kwargs(builder_args, is_et=False, context="export") - model_to_dso = _initialize_model( + model_to_aoti_package = _initialize_model( builder_args, quantize, support_tensor_subclass=False, ) + model_to_dso = model_to_aoti_package _unset_gguf_kwargs(builder_args) with torch.no_grad(): @@ -419,9 +435,22 @@ def main(args): if output_dso_path: output_dso_path = str(os.path.abspath(output_dso_path)) print(f"Exporting model using AOT Inductor to {output_dso_path}") + print("WARNING!! The path of compiling a dso is deprecated. 
Please use --output-aoti-package-path to create a .pt2 artifact instead.") export_for_server( model_to_dso, builder_args.device, output_dso_path, builder_args.dynamic_shapes, + package=False, + ) + + if output_aoti_package_path: + output_aoti_package_path = str(os.path.abspath(output_aoti_package_path)) + print(f"Exporting model using AOT Inductor to {output_aoti_package_path}") + export_for_server( + model_to_aoti_package, + builder_args.device, + output_aoti_package_path, + builder_args.dynamic_shapes, + package=True, ) diff --git a/torchchat/generate.py b/torchchat/generate.py index 886b0d7ab..be6a2e819 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -163,6 +163,8 @@ def validate_build( reason = "model compilation for prefill" if self.compile: reason = "model compilation" + if builder_args.aoti_package_path: + model_type = "PT2" if builder_args.dso_path: model_type = "DSO" if builder_args.pte_path: @@ -176,7 +178,10 @@ def from_args(cls, args): dso_path = getattr(args, "dso_path", None) pte_path = getattr(args, "pte_path", None) - sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path) + aoti_package_path = getattr(args, "aoti_package_path", None) + sequential_prefill = ( + args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path) + ) # Validate that all image prompts exist before expensive model load if image_prompts := getattr(args, "image_prompts", None): diff --git a/torchchat/usages/eval.py b/torchchat/usages/eval.py index 5993c3781..b708e5840 100644 --- a/torchchat/usages/eval.py +++ b/torchchat/usages/eval.py @@ -260,7 +260,7 @@ def main(args) -> None: if compile: assert not ( - builder_args.dso_path or builder_args.pte_path + builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path ), "cannot compile exported model" model_forward = torch.compile( model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True @@ -288,6 +288,8 @@ def main(args) -> None: ) if builder_args.dso_path: print(f"For model {builder_args.dso_path}") + elif builder_args.aoti_package_path: + print(f"For model {builder_args.aoti_package_path}") elif builder_args.pte_path: print(f"For model {builder_args.pte_path}") elif builder_args.checkpoint_path: diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index fd30f87d5..1b649ffbc 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -6,9 +6,10 @@ from __future__ import annotations -from enum import Enum import logging import os + +from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, List, Tuple @@ -70,34 +71,37 @@ def unpack_packed_weights( active_builder_args_dso = None active_builder_args_pte = None +active_builder_args_aoti_package = None -def set_backend(dso, pte): +def set_backend(dso, pte, aoti_package): global active_builder_args_dso global active_builder_args_pte active_builder_args_dso = dso + active_builder_args_aoti_package = aoti_package active_builder_args_pte = pte class _Backend(Enum): - AOTI = 0, + AOTI = (0,) EXECUTORCH = 1 def _active_backend() -> _Backend: global active_builder_args_dso + global active_builder_args_aoti_package global active_builder_args_pte # eager == aoti, which is when backend has not been explicitly set - if (not active_builder_args_dso) and not (active_builder_args_pte): - return _Backend.AOTI + if (not active_builder_args_pte) and (not active_builder_args_aoti_package): + return _Backend.AOTI - if active_builder_args_pte 
and active_builder_args_dso: + if active_builder_args_pte and active_builder_args_aoti_package: raise RuntimeError( - "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!" + "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!" ) - return _Backend.AOTI if active_builder_args_dso else _Backend.EXECUTORCH + return _Backend.AOTI if active_builder_args_aoti_package else _Backend.EXECUTORCH def use_aoti_backend() -> bool: @@ -115,22 +119,24 @@ def use_et_backend() -> bool: def set_precision(dtype): - """set_precision() is a torchchat-internal API that records the dtype we're building the model for. -The precision is recorded for future queries by get_precision(), so that when building a model, -or performing optimizations, we can query the type the user is building the model for. -This is an informational value that can be used when we want to know what type to build for (e.g., a kv cache). -Changing the `precision` does not change the precision of the model. -""" - + """set_precision() is a torchchat-internal API that records the dtype we're building the model for. + The precision is recorded for future queries by get_precision(), so that when building a model, + or performing optimizations, we can query the type the user is building the model for. + This is an informational value that can be used when we want to know what type to build for (e.g., a kv cache). + Changing the `precision` does not change the precision of the model. + """ + global precision - assert precision is None, "only set precision once to avoid inconsistent answers during different phases of model build and export" + assert ( + precision is None + ), "only set precision once to avoid inconsistent answers during different phases of model build and export" precision = dtype def get_precision(): """get_precision() is a torchchat-internal API that returns the dtype we're building the model for, as specified by the `--dtype` CLI option+, -or the precision quantizer. -""" + or the precision quantizer. + """ global precision # if (and only if) precision has not been set, update it to the default value torch.float32 if precision is None: @@ -224,7 +230,7 @@ def canonical_path(path): def state_dict_device(d, device="cpu") -> Dict: - return {key : weight.to(device=device) for (key, weight) in d.items()} + return {key: weight.to(device=device) for (key, weight) in d.items()} #########################################################################
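Putting the export-side changes together, the following sketch shows roughly what `export_for_server(..., package=True)` now does under the hood; `TinyModel`, its example inputs, and the `model.pt2` output path are placeholders standing in for the torchchat `Model` and the `(tokens, input_pos)` inputs constructed in `torchchat/export.py`.

```python
# Sketch of the PT2 packaging flow wired into export_for_server (package=True).
# TinyModel and its example inputs are placeholders; torchchat exports its
# transformer Model with (tokens, input_pos) inputs the same way.
import torch
import torch._export
import torch._inductor
from torch._inductor.package import package_aoti


class TinyModel(torch.nn.Module):
    def forward(self, tokens, input_pos):
        return tokens.float() + input_pos.float()


model = TinyModel().eval()
example_inputs = (
    torch.tensor([[1]], dtype=torch.int),
    torch.tensor([0], dtype=torch.int),
)

with torch.no_grad():
    # With aot_inductor.package=True, aot_compile emits the individual
    # AOTInductor artifacts instead of writing a single .so file;
    # aot_inductor.metadata is embedded into the package.
    artifacts = torch._export.aot_compile(
        model,
        example_inputs,
        options={"aot_inductor.package": True, "aot_inductor.metadata": {}},
    )

# package_aoti zips the artifacts (including the runnable .so) into one .pt2.
pt2_path = package_aoti("model.pt2", artifacts)
print(f"The generated packaged model can be found at: {pt2_path}")
```

The resulting `.pt2` is the single artifact consumed by `torchchat.py generate --aoti-package-path`, `torchchat.py eval --aoti-package-path`, and the C++ `aoti_run` binary.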