Skip to content

Commit 84d2232

Browse files
angelayiJack-Khuu
authored and
vmpuri
committed
Update aoti calls to utilize new export and packaging APIs (#1455)
Co-authored-by: Jack-Khuu <[email protected]>
1 parent a64b9e3 commit 84d2232

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

Diff for: torchchat/cli/builder.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -589,9 +589,8 @@ def do_nothing(max_batch_size, max_seq_length):
589589
# attributes will NOT be seen on by AOTI-compiled forward
590590
# function, e.g. calling model.setup_cache will NOT touch
591591
# AOTI compiled and maintained model buffers such as kv_cache.
592-
from torch._inductor.package import load_package
593592

594-
aoti_compiled_model = load_package(
593+
aoti_compiled_model = torch._inductor.aoti_load_package(
595594
str(builder_args.aoti_package_path.absolute())
596595
)
597596

Diff for: torchchat/export.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,20 @@ def export_for_server(
7575
if not package:
7676
options = {"aot_inductor.output_path": output_path}
7777

78-
path = torch._export.aot_compile(
78+
ep = torch.export.export(
7979
model,
8080
example_inputs,
8181
dynamic_shapes=dynamic_shapes,
82-
options=options,
8382
)
8483

8584
if package:
86-
from torch._inductor.package import package_aoti
87-
88-
path = package_aoti(output_path, path)
85+
path = torch._inductor.aoti_compile_and_package(
86+
ep, package_path=output_path, inductor_configs=options
87+
)
88+
else:
89+
path = torch._inductor.aot_compile(
90+
ep.module(), example_inputs, options=options
91+
)
8992

9093
print(f"The generated packaged model can be found at: {path}")
9194
return path

0 commit comments

Comments
 (0)