Skip to content

Commit

Permalink
[AOTI] Remove the original model weights in Python deployment (#1337)
Browse files Browse the repository at this point in the history
* [AOTI] Remove the original model weights in Python deployment

Summary: Fixes #1302. Because an AOTI-compiled model contains its own copy of the model weights, we need to release the corresponding eager model weights in the Python deployment path.

* Revert "[AOTI] Remove the original model weights in Python deployment"

This reverts commit 962ec0d.

* Refactor the code

* Add setup_cache for aoti_package_path

---------

Co-authored-by: Jack-Khuu <[email protected]>
  • Loading branch information
desertfire and Jack-Khuu authored Nov 6, 2024
1 parent 54455a3 commit 4a7dab8
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,15 @@ def _load_model(builder_args: BuilderArgs) -> Model:
model = _load_model_default(builder_args)
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

if builder_args.dso_path or builder_args.aoti_package_path:
# AOTI-compiled model will load its own weights.
# Release weights here to avoid OOM
import gc
if hasattr(model, "model"):
model.model = None
gc.collect()
torch.cuda.empty_cache()

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()

Expand Down Expand Up @@ -584,6 +593,12 @@ def _initialize_model(
# attributes will NOT be seen by the AOTI-compiled forward
# function, e.g. calling model.setup_caches will NOT touch
# AOTI-compiled and maintained model buffers such as kv_cache.
# Using cpp runner to run AOTI compiled model is recommended.

# No-op replacement that is assigned to model.setup_caches right after this
# definition. It accepts (max_batch_size, max_seq_length) and ignores both.
# NOTE(review): presumably mirrors the eager model's setup_caches signature so
# existing callers keep working — confirm against the Model class.
def do_nothing(max_batch_size, max_seq_length):
pass
model.setup_caches = do_nothing

model.forward = torch._export.aot_load(
str(builder_args.dso_path.absolute()), builder_args.device
)
Expand Down Expand Up @@ -617,6 +632,11 @@ def _initialize_model(
aoti_compiled_model = load_package(
str(builder_args.aoti_package_path.absolute())
)

# No-op replacement that is assigned to model.setup_caches right after this
# definition, before model.forward is swapped for the loaded AOTI package.
# It accepts (max_batch_size, max_seq_length) and ignores both.
# NOTE(review): presumably mirrors the eager model's setup_caches signature so
# existing callers keep working — confirm against the Model class.
def do_nothing(max_batch_size, max_seq_length):
pass
model.setup_caches = do_nothing

model.forward = aoti_compiled_model
metadata = aoti_compiled_model.get_metadata()
builder_args.device = metadata["AOTI_DEVICE_KEY"]
Expand Down

0 comments on commit 4a7dab8

Please sign in to comment.