Skip to content

Commit d8e6d13

Browse files
Update
[ghstack-poisoned]
1 parent 5d71c9b commit d8e6d13

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

backends/apple/metal/metal_backend.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,34 +108,62 @@ def preprocess(
108108
options: dict[str, typing.Any] = {
109109
# Do not link against the full PyTorch/libtorch library
110110
"aot_inductor.link_libtorch": False,
111-
# Package model constants and other generated files directly in the shared object (.so) file
112-
"aot_inductor.package_constants_in_so": True,
111+
# Separate weight constants from the .so file
112+
"aot_inductor.package": True,
113+
"aot_inductor.package_constants_in_so": False,
114+
# Store weight constants on disk in a binary blob
115+
"aot_inductor.package_constants_on_disk_format": "binary_blob",
113116
# Enable maximum automatic tuning for optimal performance
114117
"max_autotune": True,
115118
# "aot_inductor.debug_compile": True,
116119
# "aot_inductor.force_mmap_weights": False,
117120
}
118121

119122
with collect_unsupported_fallback_kernels():
120-
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
123+
paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
121124
if len(missing_fallback_kernels) > 0:
122125
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
123126
raise RuntimeError(
124127
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
125128
"Please add them to the AOTI backend."
126129
)
127130

131+
# Extract the .so and .blob paths from the returned list
132+
so_path = None
133+
blob_path = None
134+
for path in paths:
135+
if path.endswith(".wrapper.so"):
136+
so_path = path
137+
elif path.endswith(".wrapper_weights.blob"):
138+
blob_path = path
139+
140+
if so_path is None or blob_path is None:
141+
raise RuntimeError(
142+
f"Could not find required files in compiled paths, got {paths}"
143+
)
144+
128145
# pyre-ignorep[6]: Incompatible parameter type
129146
with open(so_path, "rb") as f:
130147
so_data = f.read()
131148

132149
named_data_store = NamedDataStore()
133150
method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
151+
152+
# Keep the .so file in the NamedDataStore so that it can be packaged into the .pte file.
153+
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
154+
155+
# Add weights blob to named data store
156+
with open(blob_path, "rb") as f:
157+
blob_data = f.read()
158+
134159
named_data_store.add_named_data(
135-
method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
160+
method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob"
136161
)
137162

138-
# Clean up the generated so file; it has been packaged into the NamdeDataStore
163+
# Clean up the weights blob file
164+
os.remove(blob_path)
165+
166+
# Clean up the generated .so file; it has been packaged into the NamedDataStore
139167
# pyre-ignorep[6]: Incompatible parameter type
140168
os.remove(so_path)
141169

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ class ET_EXPERIMENTAL MetalBackend final
106106
Debug,
107107
"MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun");
108108

109+
LOAD_SYMBOL(
110+
handle,
111+
update_constants_from_blob,
112+
AOTInductorModelUpdateConstantsFromBlob,
113+
so_handle);
114+
ET_LOG(
115+
Debug,
116+
"MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelUpdateConstantsFromBlob");
117+
109118
ET_LOG(
110119
Debug,
111120
"MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully");
@@ -203,6 +212,9 @@ class ET_EXPERIMENTAL MetalBackend final
203212
outfile.close();
204213
ET_LOG(Info, "MetalBackend::init - File closed successfully");
205214

215+
// Free the buffer immediately after writing to disk
216+
aoti_metal_buffer->Free();
217+
206218
// Load the shared library using dlopen
207219
void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
208220
ET_CHECK_OR_RETURN_ERROR(
@@ -234,6 +246,20 @@ class ET_EXPERIMENTAL MetalBackend final
234246

235247
handle->container_handle = container_handle;
236248

249+
// Look into named data map for constant data
250+
std::string weights_blob_key =
251+
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
252+
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
253+
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
254+
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
255+
const void* weights_blob = buffer_res->data();
256+
// Feed the weights blob into the container. Under the hood it's copying
257+
// weights, so we should free the buffer immediately.
258+
ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
259+
handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
260+
buffer_res->Free();
261+
}
262+
237263
ET_LOG(Info, "MetalBackend::init - Initialization completed successfully");
238264
return (DelegateHandle*)handle; // Return the handle post-processing
239265
}

0 commit comments

Comments
 (0)