
Commit 6636c48

larryliu0820 authored and JacobSzwejbka committed
[aoti-et] Store weights outside of .so (#15180)
This pull request introduces significant changes to how model constants (weights) are packaged and loaded in the CUDA backend, improving modularity and flexibility. It also updates the runner APIs to support more flexible loading modes and bumps PyTorch commit pins and nightly version references. **Packaging and Loading Model Constants:** * Model constants are now separated from the `.so` file and stored as a binary blob on disk, rather than being packaged directly into the shared object. The preprocessing logic in `cuda_backend.py` is updated to handle the new file outputs and manage cleanup. [[1]](diffhunk://#diff-5b5ea2257772b3aba04b2534f5ea1429a0c631bfd25a7ef531f526e76c471d7aL149-R153) [[2]](diffhunk://#diff-5b5ea2257772b3aba04b2534f5ea1429a0c631bfd25a7ef531f526e76c471d7aL165-R211) * The C++ CUDA backend (`cuda_backend.cpp`) now loads the new weights blob from the named data map and feeds it into the model container using a newly added API function. The buffer is freed immediately after use for better resource management. [[1]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375R153-R154) [[2]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375R183-R195) **API and Infrastructure Updates:** * A new function pointer type, `AOTInductorModelUpdateConstantsFromBlobFunc`, is added to the delegate handle structure in `aoti_delegate_handle.h` to support updating model constants from a binary blob. [[1]](diffhunk://#diff-0598c198d53bf756f6013186ea3155f15ddef247aa863e83ef30f27991b3a0a7R74-R78) [[2]](diffhunk://#diff-0598c198d53bf756f6013186ea3155f15ddef247aa863e83ef30f27991b3a0a7R95) * The CUDA backend now loads this new symbol from the shared object at runtime. **Runner API Improvements:** * The multimodal runner API is updated to accept a `Module::LoadMode` parameter, allowing for more flexible loading options such as memory mapping. This change is propagated through helper functions and their headers. 
[[1]](diffhunk://#diff-0ac16dbe4eaefa08e21fbda582fe2cd2b482f43aaedfc1bf2f31becf5e7bb843L322-R322) [[2]](diffhunk://#diff-005ac94c6b217e02d652aafc206d36b2ec1190af36aa0a632fd406975dfc2600L271-R272) [[3]](diffhunk://#diff-005ac94c6b217e02d652aafc206d36b2ec1190af36aa0a632fd406975dfc2600L281-R284) [[4]](diffhunk://#diff-ac7a381a7828a6f1a543d2beab4cf503c2d3547ab86821c8e1777df9305108aaL143-R144) **Dependency Updates:** * The PyTorch commit pin is updated in `.ci/docker/ci_commit_pins/pytorch.txt` and the nightly version is bumped in `torch_pin.py` for compatibility with the new packaging logic. [[1]](diffhunk://#diff-e873e85ae7aa52ebeadb13a27cf83eff1891b1011e27f94ec040eb8407893c5eL1-R1) [[2]](diffhunk://#diff-9665391232bd21d4ee0a293cbc7f76d99db902ab1e6e045a59f9a132325babc9L2-R2)
1 parent 04fc6ff commit 6636c48

File tree: 17 files changed, +193 −82 lines changed

.ci/docker/ci_commit_pins/pytorch.txt — 1 addition, 1 deletion

```diff
@@ -1 +1 @@
-53a2908a10f414a2f85caa06703a26a40e873869
+e6f766c7d750d40603eee3f66c5915bac606b3ea
```
.ci/scripts/utils.sh — 39 additions, 0 deletions

```diff
@@ -44,6 +44,44 @@ install_pip_dependencies() {
   popd || return
 }
 
+dedupe_macos_loader_path_rpaths() {
+  if [[ "$(uname)" != "Darwin" ]]; then
+    return
+  fi
+
+  local torch_lib_dir
+  pushd ..
+  torch_lib_dir=$(python -c "import importlib.util; print(importlib.util.find_spec('torch').submodule_search_locations[0])")/lib
+  popd
+
+  if [[ -z "${torch_lib_dir}" || ! -d "${torch_lib_dir}" ]]; then
+    return
+  fi
+
+  local torch_libs=(
+    "libtorch_cpu.dylib"
+    "libtorch.dylib"
+    "libc10.dylib"
+  )
+
+  for lib_name in "${torch_libs[@]}"; do
+    local lib_path="${torch_lib_dir}/${lib_name}"
+    if [[ ! -f "${lib_path}" ]]; then
+      continue
+    fi
+
+    local removed=0
+    # Repeatedly remove the @loader_path rpath entries until none remain.
+    while install_name_tool -delete_rpath @loader_path "${lib_path}" 2>/dev/null; do
+      removed=1
+    done
+
+    if [[ "${removed}" == "1" ]]; then
+      install_name_tool -add_rpath @loader_path "${lib_path}" || true
+    fi
+  done
+}
+
 install_domains() {
   echo "Install torchvision and torchaudio"
   pip install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}"
@@ -101,6 +139,7 @@ install_pytorch_and_domains() {
     echo "Use cached wheel at ${cached_torch_wheel}"
   fi
 
+  dedupe_macos_loader_path_rpaths
   # Grab the pinned audio and vision commits from PyTorch
   TORCHAUDIO_VERSION=$(cat .github/ci_commit_pins/audio.txt)
   export TORCHAUDIO_VERSION
```
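The shell function above drains every `@loader_path` rpath entry with repeated `install_name_tool -delete_rpath` calls, then re-adds exactly one. The underlying "remove all duplicates, restore a single copy" pattern can be sketched in Python (the function name and list values here are illustrative, not part of the commit):

```python
def dedupe_entry(entries, target):
    """Remove every copy of `target` from `entries`; if any were removed,
    append exactly one back. This mirrors the shell loop that repeatedly
    deletes the @loader_path rpath and then re-adds it once."""
    removed = False
    while target in entries:  # analogous to the `while install_name_tool ...` loop
        entries.remove(target)
        removed = True
    if removed:  # analogous to the single `-add_rpath` call
        entries.append(target)
    return entries

rpaths = ["@loader_path", "/usr/lib", "@loader_path", "@loader_path"]
print(dedupe_entry(rpaths, "@loader_path"))  # → ['/usr/lib', '@loader_path']
```

The re-add step only runs when duplicates were found, so libraries that never carried the rpath are left untouched, just as in the shell version.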

.github/workflows/pull.yml — 1 addition, 0 deletions

```diff
@@ -351,6 +351,7 @@ jobs:
 
           # reinstall executorch
           bash ./install_executorch.sh --minimal
+          pip list
 
           # run python unittest
           python -m unittest examples.models.moshi.mimi.test_mimi
```

CMakeLists.txt — 6 additions, 27 deletions

```diff
@@ -99,28 +99,6 @@ announce_configured_options(CCACHE_PROGRAM)
 
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-# Setup RPATH. See
-# https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling
-# Use separate rpaths during build and install phases
-set(CMAKE_SKIP_BUILD_RPATH OFF)
-# Don't use the install-rpath during the build phase
-set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
-# Automatically add all linked folders that are NOT in the build directory to
-# the rpath (per library?)
-#
-# TODO: Doesn't work for us right now because we are not installing .so's into
-# the correct locations. For example we have libcustom_ops_aot_lib.so depending
-# on _portable_lib.so, which was eventually put under
-# <site-packages>/executorch/extension/pybindings/ but this rpath is not
-# automatically added because at build time it seems `portable_lib` is being
-# built under the same directory, so no extra rpath is being added. To properly
-# fix this we need to install `portable_lib` into the correct path.
-set(CMAKE_INSTALL_RPATH_USE_LINK_PATH ON)
-# ------------------------------ OPTIONS -------------------------------------
-# WARNING: Please don't add example specific options in this CMakeLists.txt.
-# Instead please use `find_package(executorch REQUIRED)` in the example
-# directory and add a new executable in the example `CMakeLists.txt`.
-
 if(NOT EXECUTORCH_ENABLE_LOGGING)
   # Avoid pulling in the logging strings, which can be large. Note that this
   # will set the compiler flag for all targets in this directory, and for all
@@ -901,12 +879,13 @@ if(EXECUTORCH_BUILD_PYBIND)
 
   # Set RPATH to find PyTorch libraries relative to the installation location
   # This goes from executorch/extension/pybindings up to site-packages, then to
-  # torch/lib
+  # torch/lib. Don't do this to APPLE, as it will error out on the following
+  # error:
+  #
   if(APPLE)
-    set_target_properties(
-      portable_lib PROPERTIES BUILD_RPATH "@loader_path/../../../torch/lib"
-                              INSTALL_RPATH "@loader_path/../../../torch/lib"
-    )
+    # Skip setting @loader_path for APPLE, since it causes error like ld:
+    # duplicate LC_RPATH '@loader_path' in '<site-packages>/torch/lib/
+    # libtorch_cpu.dylib'
   else()
     set_target_properties(
       portable_lib PROPERTIES BUILD_RPATH "$ORIGIN/../../../torch/lib"
```

backends/aoti/aoti_delegate_handle.h — 6 additions, 0 deletions

```diff
@@ -71,6 +71,11 @@ using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
     AOTInductorModelContainerHandle container_handle,
     size_t* num_constants);
 
+// Update the model container with the constant tensors
+using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    const uint8_t* weight_blob_ptr);
+
 } // extern "C"
 
 // AOTI Delegate Handle structure
@@ -87,6 +92,7 @@ struct AOTIDelegateHandle {
   AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
   AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
   AOTInductorModelContainerRunFunc run;
+  AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;
};

} // namespace aoti
```
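The new `update_constants_from_blob` member is the one handle entry the runtime treats as optional: the commit resolves the other container functions strictly, but tolerates this symbol being absent from a `.so` built by an older torch. That required-vs-optional resolution policy can be sketched in Python, with a plain dict standing in for the `dlsym`-style lookup (`resolve_symbols` and the `exports` table are hypothetical, not the real API):

```python
def resolve_symbols(symbol_table, required, optional):
    """Build a handle dict: required symbols must resolve (error if absent),
    optional ones fall back to None, mirroring how the CUDA backend fills
    AOTIDelegateHandle."""
    handle = {}
    for name in required:
        if name not in symbol_table:
            raise KeyError(f"missing required symbol: {name}")
        handle[name] = symbol_table[name]
    for name in optional:
        # A missing optional symbol is tolerated, e.g. a .so compiled
        # before AOTInductorModelUpdateConstantsFromBlob existed.
        handle[name] = symbol_table.get(name)
    return handle

exports = {"AOTInductorModelContainerRun": lambda: "ran"}
h = resolve_symbols(
    exports,
    required=["AOTInductorModelContainerRun"],
    optional=["AOTInductorModelUpdateConstantsFromBlob"],
)
print(h["AOTInductorModelUpdateConstantsFromBlob"])  # None for an old .so
```

The caller then guards uses of the optional entry with a null check, exactly as `cuda_backend.cpp` does before invoking `update_constants_from_blob`.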

backends/cuda/cuda_backend.py — 32 additions, 5 deletions

```diff
@@ -146,8 +146,11 @@ def preprocess(
             "aot_inductor.embed_kernel_binary": True,
             # Do not link against the full PyTorch/libtorch library
             "aot_inductor.link_libtorch": False,
-            # Package model constants and other generated files directly in the shared object (.so) file
-            "aot_inductor.package_constants_in_so": True,
+            # Separate weight constants from the .so file
+            "aot_inductor.package": True,
+            "aot_inductor.package_constants_in_so": False,
+            # Store weight constants on disk in a binary blob
+            "aot_inductor.package_constants_on_disk_format": "binary_blob",
             # Enable maximum automatic tuning for optimal performance
             "max_autotune": True,
             # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch
@@ -162,25 +165,49 @@ def preprocess(
             ]
         ), torch.no_grad():
             # torch._logging.set_logs(post_grad_graphs=True)
-            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+            # Here we should expect 1 so file and 1 weight blob in the same directory.
+            paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
             raise RuntimeError(
                 f"Method {CudaBackend.method_name_from_compile_specs(compile_specs)} missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n"
                 "Please add them to the AOTI backend."
             )
 
+        # Extract the .so and .blob paths from the returned list
+        so_path = None
+        blob_path = None
+        for path in paths:
+            if path.endswith(".wrapper.so"):
+                so_path = path
+            elif path.endswith(".wrapper_weights.blob"):
+                blob_path = path
+
+        if so_path is None or blob_path is None:
+            raise RuntimeError(
+                f"Could not find required files in compiled paths, got {paths}"
+            )
+
         # pyre-ignorep[6]: Incompatible parameter type
         with open(so_path, "rb") as f:
             so_data = f.read()
 
         named_data_store = NamedDataStore()
         method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
+
+        # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
+        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
+
+        # Add weights blob to named data store
+        with open(blob_path, "rb") as f:
+            blob_data = f.read()
         named_data_store.add_named_data(
-            method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
+            method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob"
         )
+        # Clean up the weights blob file
+        os.remove(blob_path)
 
-        # Clean up the generated so file; it has been packaged into the NamdeDataStore
+        # Clean up the generated so file; it has been packaged into the NamedDataStore
         # pyre-ignorep[6]: Incompatible parameter type
         os.remove(so_path)
```
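With `package_constants_in_so` disabled, `aot_compile` returns a list of artifact paths rather than a single `.so` path, and the diff above picks them apart by suffix. That selection step can be exercised on its own (the file names below are made up for illustration):

```python
def split_compiled_outputs(paths):
    """Pick the wrapper .so and the weights blob out of the path list
    returned by aot_compile, mirroring the suffix checks in preprocess()."""
    so_path = None
    blob_path = None
    for path in paths:
        if path.endswith(".wrapper.so"):
            so_path = path
        elif path.endswith(".wrapper_weights.blob"):
            blob_path = path
    # Both artifacts must be present; fail loudly otherwise, as the diff does.
    if so_path is None or blob_path is None:
        raise RuntimeError(f"Could not find required files, got {paths}")
    return so_path, blob_path

print(split_compiled_outputs(
    ["/tmp/model.wrapper.so", "/tmp/model.wrapper_weights.blob"]
))  # → ('/tmp/model.wrapper.so', '/tmp/model.wrapper_weights.blob')
```

Matching on suffixes rather than list position keeps the logic robust to the compiler returning the artifacts in any order, or alongside other generated files.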

backends/cuda/runtime/cuda_backend.cpp — 47 additions, 33 deletions

```diff
@@ -27,15 +27,6 @@
 
 namespace executorch::backends::cuda {
 
-#define LOAD_SYMBOL(handle, member, name, so_handle)                 \
-  do {                                                               \
-    auto symbol_res = get_function(so_handle, #name);                \
-    if (!symbol_res.ok()) {                                          \
-      return symbol_res.error();                                     \
-    }                                                                \
-    handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \
-  } while (0)
-
 using namespace std;
 using namespace aoti;
 
@@ -61,29 +52,37 @@ class ET_EXPERIMENTAL CudaBackend final
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
-    LOAD_SYMBOL(
-        handle,
-        create_with_device,
-        AOTInductorModelContainerCreateWithDevice,
-        so_handle);
+#define LOAD_SYMBOL(member, name)                                    \
+  do {                                                               \
+    auto symbol_res = get_function(so_handle, #name);                \
+    if (!symbol_res.ok()) {                                          \
+      return symbol_res.error();                                     \
+    }                                                                \
+    handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \
+  } while (0)
+
+    LOAD_SYMBOL(create_with_device, AOTInductorModelContainerCreateWithDevice);
 
-    LOAD_SYMBOL(
-        handle, delete_container, AOTInductorModelContainerDelete, so_handle);
+    LOAD_SYMBOL(delete_container, AOTInductorModelContainerDelete);
 
-    LOAD_SYMBOL(
-        handle,
-        get_num_inputs,
-        AOTInductorModelContainerGetNumInputs,
-        so_handle);
+    LOAD_SYMBOL(get_num_inputs, AOTInductorModelContainerGetNumInputs);
 
-    LOAD_SYMBOL(
-        handle,
-        get_num_outputs,
-        AOTInductorModelContainerGetNumOutputs,
-        so_handle);
+    LOAD_SYMBOL(get_num_outputs, AOTInductorModelContainerGetNumOutputs);
 
-    LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle);
+    LOAD_SYMBOL(run, AOTInductorModelContainerRun);
+#undef LOAD_SYMBOL
 
+    auto symbol_res =
+        get_function(so_handle, "AOTInductorModelUpdateConstantsFromBlob");
+    if (symbol_res.ok()) {
+      handle->update_constants_from_blob =
+          reinterpret_cast<AOTInductorModelUpdateConstantsFromBlobFunc>(
+              symbol_res.get());
+    } else {
+      ET_LOG(
+          Info,
+          "Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)");
+    }
     return Error::Ok;
   }
 
@@ -112,13 +111,13 @@ class ET_EXPERIMENTAL CudaBackend final
         method_name.empty() ? "so_blob" : method_name + "_so_blob";
 
     const NamedDataMap* named_data_map = context.get_named_data_map();
-    auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
+    auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str());
     ET_CHECK_OR_RETURN_ERROR(
-        aoti_cuda_buffer.ok(),
+        aoti_dso_buffer.ok(),
         Internal,
         "Failed to get data for key %s: 0x%x",
         so_blob_key.c_str(),
-        static_cast<uint32_t>(aoti_cuda_buffer.error()));
+        static_cast<uint32_t>(aoti_dso_buffer.error()));
 
     // Generate dynamic temporary file path
     filesystem::path temp_dir = filesystem::temp_directory_path();
@@ -132,19 +131,21 @@ class ET_EXPERIMENTAL CudaBackend final
     ET_LOG(
         Info,
         "Writing %zu bytes to %s",
-        aoti_cuda_buffer->size(),
+        aoti_dso_buffer->size(),
         so_path.c_str());
 
     outfile.write(
-        static_cast<const char*>(aoti_cuda_buffer->data()),
-        aoti_cuda_buffer->size());
+        static_cast<const char*>(aoti_dso_buffer->data()),
+        aoti_dso_buffer->size());
 
     ET_CHECK_OR_RETURN_ERROR(
         outfile, AccessFailed, "Failed to write to file %s", so_path.c_str());
 
     // Finish writing the file to disk
     outfile.close();
 
+    // Free the buffer immediately after writing to disk
+    aoti_dso_buffer->Free();
     // Load the lib
     Result<void*> lib_handle_res = load_library(so_path);
     if (!lib_handle_res.ok()) {
@@ -172,6 +173,19 @@ class ET_EXPERIMENTAL CudaBackend final
 
     handle->container_handle = container_handle;
 
+    // Look into named data map for constant data
+    std::string weights_blob_key =
+        method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
+    auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
+    if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
+      ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
+      const void* weights_blob = buffer_res->data();
+      // Feed the weights blob into the container. Under the hood it's copying
+      // weights, so we should free the buffer immediately.
+      ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
+          handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
+      buffer_res->Free();
+    }
     // Create a CUDA stream for asynchronous execution
     cudaStream_t cuda_stream;
     ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
```
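On the runtime side, the backend looks up both blobs in the named data map using the same key scheme the Python preprocessor used when storing them: an optional method-name prefix on a fixed suffix. The naming rule is a direct transcription of the ternaries in the C++ above:

```python
def blob_keys(method_name: str):
    """Derive the named-data-map keys for the .so blob and the weights blob,
    matching the C++ lookups and the add_named_data calls in preprocess()."""
    # Empty method name falls back to the bare suffix, per the C++ ternary:
    # method_name.empty() ? "so_blob" : method_name + "_so_blob"
    so_key = "so_blob" if not method_name else method_name + "_so_blob"
    weights_key = (
        "weights_blob" if not method_name else method_name + "_weights_blob"
    )
    return so_key, weights_key

print(blob_keys("forward"))  # → ('forward_so_blob', 'forward_weights_blob')
print(blob_keys(""))         # → ('so_blob', 'weights_blob')
```

Keeping producer and consumer on one deterministic scheme is what lets the runtime find the weights blob without any extra metadata in the `.pte` file.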

examples/models/moshi/mimi/install_requirements.sh — 3 additions, 3 deletions

```diff
@@ -8,9 +8,9 @@
 set -x
 
 conda install -c conda-forge "ffmpeg<8" -y
-pip install torchcodec==0.7.0.dev20250929 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
-pip install moshi==0.2.4
-pip install bitsandbytes soundfile
+pip install torchcodec==0.7.0.dev20251012 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+pip install moshi==0.2.11
+pip install bitsandbytes soundfile einops
 # Run llama2/install requirements for torchao deps
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 bash "$SCRIPT_DIR"/../../llama/install_requirements.sh
```

examples/models/moshi/mimi/test_mimi.py — 1 addition, 2 deletions

```diff
@@ -189,8 +189,7 @@ def forward(self, x):
                 x = self.mimi_model.upsample(x)
                 (emb,) = self.mimi_model.decoder_transformer(x)
                 emb.transpose(1, 2)
-                with self.mimi_model._context_for_encoder_decoder:
-                    out = self.mimi_model.decoder(emb)
+                out = self.mimi_model.decoder(emb)
                 return out
 
         emb_input = torch.rand(1, 1, 512, device="cpu")
```

examples/models/voxtral/multimodal.cpp — 1 addition, 1 deletion

```diff
@@ -319,7 +319,7 @@ int32_t main(int32_t argc, char** argv) {
   // Create multimodal runner
   std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner =
       ::executorch::extension::llm::create_multimodal_runner(
-          model_path, std::move(tokenizer), data_path);
+          model_path, std::move(tokenizer), data_path, Module::LoadMode::Mmap);
   if (runner == nullptr) {
     ET_LOG(Error, "Failed to create multimodal runner");
     return 1;
```
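Passing `Module::LoadMode::Mmap` asks the runtime to memory-map the model file instead of reading it fully into RAM up front, which matters once weights are stored as a large on-disk blob. The idea can be illustrated with Python's standard `mmap` module (the temporary file below is purely a stand-in for a model file, not part of the commit):

```python
import mmap
import os
import tempfile

# Write a small stand-in "model file" to disk.
fd, path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
    f.write(b"model-bytes-" * 4)

# Memory-map it read-only: pages are faulted in lazily by the OS rather
# than copied into the process up front, which is the benefit a Mmap load
# mode offers for large model files.
with open(path, "rb") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        print(mm[:12])  # → b'model-bytes-'

os.remove(path)
```

The trade-off is the usual one: mmap avoids a large upfront read and lets the OS share and evict pages, at the cost of page-fault latency on first access.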
