38 changes: 33 additions & 5 deletions backends/apple/metal/metal_backend.py
@@ -108,34 +108,62 @@ def preprocess(
     options: dict[str, typing.Any] = {
         # Do not link against the full PyTorch/libtorch library
         "aot_inductor.link_libtorch": False,
-        # Package model constants and other generated files directly in the shared object (.so) file
-        "aot_inductor.package_constants_in_so": True,
+        # Separate weight constants from the .so file
+        "aot_inductor.package": True,
+        "aot_inductor.package_constants_in_so": False,
+        # Store weight constants on disk in a binary blob
+        "aot_inductor.package_constants_on_disk_format": "binary_blob",
         # Enable maximum automatic tuning for optimal performance
         "max_autotune": True,
         # "aot_inductor.debug_compile": True,
         # "aot_inductor.force_mmap_weights": False,
     }

     with collect_unsupported_fallback_kernels():
-        so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+        paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
             raise RuntimeError(
                 f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
                 "Please add them to the AOTI backend."
             )
+
+    # Extract the .so and .blob paths from the returned list
+    so_path = None
+    blob_path = None
+    for path in paths:
+        if path.endswith(".wrapper.so"):
+            so_path = path
+        elif path.endswith(".wrapper_weights.blob"):
+            blob_path = path
+
+    if so_path is None or blob_path is None:
+        raise RuntimeError(
+            f"Could not find required files in compiled paths, got {paths}"
+        )
+
     # pyre-ignorep[6]: Incompatible parameter type
     with open(so_path, "rb") as f:
         so_data = f.read()

     named_data_store = NamedDataStore()
     method_name = MetalBackend.method_name_from_compile_specs(compile_specs)

     # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
+    named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
+
+    # Add weights blob to named data store
+    with open(blob_path, "rb") as f:
+        blob_data = f.read()
+
     named_data_store.add_named_data(
-        method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
+        method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob"
     )

-    # Clean up the generated so file; it has been packaged into the NamdeDataStore
+    # Clean up the weights blob file
+    os.remove(blob_path)
+
+    # Clean up the generated so file; it has been packaged into the NamedDataStore
     # pyre-ignorep[6]: Incompatible parameter type
     os.remove(so_path)
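These two entries are what the runtime looks up when it initializes the delegate. A minimal sketch of the shared key convention, using names from this diff (the empty-method-name fallback for the so blob is an assumption; the diff only shows that fallback for the weights blob):

// Sketch: named-data keys shared between AOT packaging and the runtime.
std::string so_blob_key =
    method_name.empty() ? "so_blob" : method_name + "_so_blob";
std::string weights_blob_key =
    method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
// At init time the runtime pulls both blobs back out of the .pte:
auto so_res = named_data_map->get_data(so_blob_key.c_str());
auto weights_res = named_data_map->get_data(weights_blob_key.c_str());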

26 changes: 26 additions & 0 deletions backends/apple/metal/runtime/metal_backend.cpp
@@ -106,6 +106,15 @@ class ET_EXPERIMENTAL MetalBackend final
       Debug,
       "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun");

+  LOAD_SYMBOL(
+      handle,
+      update_constants_from_blob,
+      AOTInductorModelUpdateConstantsFromBlob,
+      so_handle);
+  ET_LOG(
+      Debug,
+      "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelUpdateConstantsFromBlob");
+
   ET_LOG(
       Debug,
       "MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully");
@@ -203,6 +212,9 @@ class ET_EXPERIMENTAL MetalBackend final
   outfile.close();
   ET_LOG(Info, "MetalBackend::init - File closed successfully");

+  // Free the buffer immediately after writing to disk
+  aoti_metal_buffer->Free();
+
   // Load the ELF using dlopen
   void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
   ET_CHECK_OR_RETURN_ERROR(
@@ -234,6 +246,20 @@ class ET_EXPERIMENTAL MetalBackend final

   handle->container_handle = container_handle;

+  // Look into the named data map for constant data
+  std::string weights_blob_key =
+      method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
+  auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
+  if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
+    ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
+    const void* weights_blob = buffer_res->data();
+    // Feed the weights blob into the container. Under the hood it's copying
+    // weights, so we should free the buffer immediately.
+    ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
+        handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
+    buffer_res->Free();
+  }
+
   ET_LOG(Info, "MetalBackend::init - Initialization completed successfully");
   return (DelegateHandle*)handle; // Return the handle post-processing
 }
1 change: 1 addition & 0 deletions backends/apple/metal/runtime/shims/et_metal.h
@@ -354,6 +354,7 @@ extern "C" {

 // Memory management functions for Metal
 void* metal_allocate_buffer(long bytes);
+void metal_deallocate_buffer(void* ptr);
 bool metal_is_device_pointer(void* ptr);
 int metal_copy_memory(
     void* dst,
15 changes: 15 additions & 0 deletions backends/apple/metal/runtime/shims/et_metal.mm
@@ -86,6 +86,21 @@ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
   }
 }

+void metal_deallocate_buffer(void* ptr) {
+  @autoreleasepool {
+    auto it = ptr_to_mtl_buffer.find(ptr);
+    if (it != ptr_to_mtl_buffer.end()) {
+      id<MTLBuffer> buffer = it->second;
+      [buffer release];
+      ptr_to_mtl_buffer.erase(it);
+      ET_LOG(Debug, "Deallocated Metal buffer for pointer %p", ptr);
+      ptr = nullptr;
+    } else {
+      ET_LOG(Error, "Failed to find Metal buffer for pointer %p", ptr);
+    }
+  }
+}
+
 void metal_cleanup_resources() {
   if (!ptr_to_mtl_buffer.empty()) {
     @autoreleasepool {
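The header change above only shows the declaration; below is a minimal usage sketch of the shim API from C++ (the buffer size and the usage pattern are illustrative):

// Allocate a Metal-backed buffer through the shim, check that it is a
// device pointer, then release it via the new deallocation hook.
void* buf = metal_allocate_buffer(1024);
if (buf != nullptr && metal_is_device_pointer(buf)) {
  // ... write inputs / read outputs through buf ...
  metal_deallocate_buffer(buf); // releases the underlying MTLBuffer
}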
16 changes: 10 additions & 6 deletions backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -736,9 +736,12 @@ AOTITorchError aoti_torch_mps_convolution(
     throw std::runtime_error("Tensor size mismatch");
   }

-  // Store the tensor handle - mark that we own the memory since we manually allocated it with malloc
+  // Store the tensor handle - mark that we own the memory since we manually allocated it
   *ret0 = output_tensor_handle;
-  is_tensor_own_memory[et_tensor] = true; // We allocated the GPU memory
+  // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
+  // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
+  extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
+  memory_to_n_tensor[tensor_data] = 1;

   ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);

@@ -1327,10 +1330,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
   }

   // Mark that we own the memory for these tensors
-  auto* out_et_tensor = reinterpret_cast<Tensor*>(out_tensor_handle);
-  auto* attn_et_tensor = reinterpret_cast<Tensor*>(attn_tensor_handle);
-  is_tensor_own_memory[out_et_tensor] = true;
-  is_tensor_own_memory[attn_et_tensor] = true;
+  // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
+  // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
+  extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
+  memory_to_n_tensor[out_contents_ptr] = 1;
+  memory_to_n_tensor[attn_contents_ptr] = 1;

   // Set output tensor handles
   *ret0 = out_tensor_handle;
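The switch from the boolean is_tensor_own_memory map to memory_to_n_tensor turns ownership into a reference count keyed by data pointer. A sketch of the implied convention; the NOT_OWN sentinel value and the release helper are illustrative assumptions, not code from this PR:

#include <cstdint>
#include <unordered_map>

// Maps a data pointer to the number of tensors that own it; NOT_OWN marks
// borrowed memory the backend must never free. The sentinel is assumed.
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
constexpr int32_t NOT_OWN = -1;

// Hypothetical release path: free only when the last owning tensor goes away.
void release_tensor_memory(void* data) {
  auto it = memory_to_n_tensor.find(data);
  if (it == memory_to_n_tensor.end() || it->second == NOT_OWN) {
    return; // borrowed memory: the caller frees it
  }
  if (--it->second == 0) {
    metal_deallocate_buffer(data); // from et_metal.h above
    memory_to_n_tensor.erase(it);
  }
}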