From 4a7dab8cfb7111aa2323ad840cda68d65b81e86f Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Wed, 6 Nov 2024 11:41:57 -0500
Subject: [PATCH] [AOTI] Remove the original model weights in Python
 deployment (#1337)

* [AOTI] Remove the original model weights in Python deployment

Summary: Fixes https://github.com/pytorch/torchchat/issues/1302. Because the
AOTI-compiled model contains a copy of the model weights, we need to release
the corresponding eager model weights in the Python deployment path.

* Revert "[AOTI] Remove the original model weights in Python deployment"

This reverts commit 962ec0d913bbf6c30496560f12b4726445dce7da.

* Refactor the code

* Add setup_cache for aoti_package_path

---------

Co-authored-by: Jack-Khuu
---
 torchchat/cli/builder.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index a7a22a1e8..fb2bfb299 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -536,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         model = _load_model_default(builder_args)
     # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
+    if builder_args.dso_path or builder_args.aoti_package_path:
+        # The AOTI-compiled model will load its own weights.
+        # Release weights here to avoid OOM
+        import gc
+        if hasattr(model, "model"):
+            model.model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
@@ -584,6 +593,12 @@ def _initialize_model(
             # attributes will NOT be seen on by AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
+            # Using the cpp runner to run the AOTI-compiled model is recommended.
+
+            def do_nothing(max_batch_size, max_seq_length):
+                pass
+            model.setup_caches = do_nothing
+
             model.forward = torch._export.aot_load(
                 str(builder_args.dso_path.absolute()), builder_args.device
             )
@@ -617,6 +632,11 @@ def _initialize_model(
             aoti_compiled_model = load_package(
                 str(builder_args.aoti_package_path.absolute())
             )
+
+            def do_nothing(max_batch_size, max_seq_length):
+                pass
+            model.setup_caches = do_nothing
+
             model.forward = aoti_compiled_model
             metadata = aoti_compiled_model.get_metadata()
             builder_args.device = metadata["AOTI_DEVICE_KEY"]
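
For readers following the Python deployment path this patch touches, below is a
minimal sketch of the same pattern outside the diff context: release the eager
weights once an AOTI artifact is requested, then make setup_caches a no-op and
route forward through the compiled artifact. It assumes a recent PyTorch with
torch._inductor.package.load_package available; the helper names and the
exportedModels/llama3.pt2 path are illustrative, not part of the patch.

# Sketch only: mirrors the pattern added by this patch, with hypothetical names.
import gc

import torch
from torch._inductor.package import load_package  # assumes a recent PyTorch


def release_eager_weights(model: torch.nn.Module) -> None:
    # The AOTI artifact ships its own copy of the weights, so the eager copy
    # held under model.model can be dropped to avoid keeping two copies (OOM).
    if hasattr(model, "model"):
        model.model = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def attach_aoti_package(model: torch.nn.Module, package_path: str) -> torch.nn.Module:
    # Swap the eager forward for the AOTI-compiled callable and neutralize
    # setup_caches: the compiled artifact maintains its own kv_cache buffers,
    # which the eager setup_caches cannot see or modify.
    compiled = load_package(package_path)

    def do_nothing(max_batch_size, max_seq_length):
        pass

    model.setup_caches = do_nothing
    model.forward = compiled
    return model


# Hypothetical usage:
#   model = _load_model(builder_args)            # eager skeleton, weights released
#   release_eager_weights(model)
#   attach_aoti_package(model, "exportedModels/llama3.pt2")

The .so path in the second hunk follows the same shape, except that the
compiled callable comes from torch._export.aot_load(dso_path, device) instead
of load_package.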