From d8e6d136a953f183551ac8d9b82ca8d15cdac366 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 21 Oct 2025 21:15:30 -0700
Subject: [PATCH 1/5] Update

[ghstack-poisoned]
---
 backends/apple/metal/metal_backend.py         | 38 ++++++++++++++++---
 .../apple/metal/runtime/metal_backend.cpp     | 26 +++++++++++++
 2 files changed, 59 insertions(+), 5 deletions(-)

diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
index 13a3534004b..ffc1e6e6d7d 100644
--- a/backends/apple/metal/metal_backend.py
+++ b/backends/apple/metal/metal_backend.py
@@ -108,8 +108,11 @@ def preprocess(
         options: dict[str, typing.Any] = {
             # Do not link against the full PyTorch/libtorch library
             "aot_inductor.link_libtorch": False,
-            # Package model constants and other generated files directly in the shared object (.so) file
-            "aot_inductor.package_constants_in_so": True,
+            # Separate weight constants from the .so file
+            "aot_inductor.package": True,
+            "aot_inductor.package_constants_in_so": False,
+            # Store weight constants on disk in a binary blob
+            "aot_inductor.package_constants_on_disk_format": "binary_blob",
             # Enable maximum automatic tuning for optimal performance
             "max_autotune": True,
             # "aot_inductor.debug_compile": True,
@@ -117,7 +120,7 @@ def preprocess(
         }
 
         with collect_unsupported_fallback_kernels():
-            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+            paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
             if len(missing_fallback_kernels) > 0:
                 formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
                 raise RuntimeError(
@@ -125,17 +128,42 @@ def preprocess(
                     "Please add them to the AOTI backend."
                 )
 
+        # Extract the .so and .blob paths from the returned list
+        so_path = None
+        blob_path = None
+        for path in paths:
+            if path.endswith(".wrapper.so"):
+                so_path = path
+            elif path.endswith(".wrapper_weights.blob"):
+                blob_path = path
+
+        if so_path is None or blob_path is None:
+            raise RuntimeError(
+                f"Could not find required files in compiled paths, got {paths}"
+            )
+
         # pyre-ignorep[6]: Incompatible parameter type
         with open(so_path, "rb") as f:
             so_data = f.read()
 
         named_data_store = NamedDataStore()
         method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
+
+        # Keep the .so file in the NamedDataStore, so that it can be packaged into the .pte file.
+        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
+
+        # Add the weights blob to the named data store
+        with open(blob_path, "rb") as f:
+            blob_data = f.read()
+
         named_data_store.add_named_data(
-            method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
+            method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob"
         )
 
-        # Clean up the generated so file; it has been packaged into the NamdeDataStore
+        # Clean up the weights blob file
+        os.remove(blob_path)
+
+        # Clean up the generated .so file; it has been packaged into the NamedDataStore
         # pyre-ignorep[6]: Incompatible parameter type
         os.remove(so_path)
 
diff --git a/backends/apple/metal/runtime/metal_backend.cpp b/backends/apple/metal/runtime/metal_backend.cpp
index 97b273d428f..f79a2a67b6f 100644
--- a/backends/apple/metal/runtime/metal_backend.cpp
+++ b/backends/apple/metal/runtime/metal_backend.cpp
@@ -106,6 +106,15 @@ class ET_EXPERIMENTAL MetalBackend final
       Debug,
       "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelContainerRun");
 
+  LOAD_SYMBOL(
+      handle,
+      update_constants_from_blob,
+      AOTInductorModelUpdateConstantsFromBlob,
+      so_handle);
+  ET_LOG(
+      Debug,
+      "MetalBackend::load_function_pointers_into_handle - Loaded AOTInductorModelUpdateConstantsFromBlob");
+
   ET_LOG(
       Debug,
       "MetalBackend::load_function_pointers_into_handle - All symbols loaded successfully");
@@ -203,6 +212,9 @@ class ET_EXPERIMENTAL MetalBackend final
     outfile.close();
     ET_LOG(Info, "MetalBackend::init - File closed successfully");
 
+    // Free the buffer immediately after writing to disk
+    aoti_metal_buffer->Free();
+
     // Load the ELF using dlopen
     void* so_handle = dlopen(so_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
     ET_CHECK_OR_RETURN_ERROR(
@@ -234,6 +246,20 @@ class ET_EXPERIMENTAL MetalBackend final
 
     handle->container_handle = container_handle;
 
+    // Look into named data map for constant data
+    std::string weights_blob_key =
+        method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
+    auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
+    if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
+      ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
+      const void* weights_blob = buffer_res->data();
+      // Feed the weights blob into the container. Under the hood it's copying
+      // weights, so we should free the buffer immediately.
+      ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
+          handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
+      buffer_res->Free();
+    }
+
     ET_LOG(Info, "MetalBackend::init - Initialization completed successfully");
     return (DelegateHandle*)handle; // Return the handle post-processing
   }
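[Editor's note] Patch 1 establishes a two-artifact contract between AOT preprocessing and the
runtime: each method contributes a "<method>_so_blob" entry (the compiled shared object) and a
"<method>_weights_blob" entry (the weight constants) to the NamedDataStore, and
MetalBackend::init later looks the weights entry up by the same key. Below is a minimal C++
sketch of that naming convention; the two helper functions are illustrative only and are not
part of the backend (the runtime builds the key inline, with the bare "weights_blob" fallback
for an empty method name, exactly as in the diff above).

    #include <iostream>
    #include <string>

    // Mirrors the key scheme used by preprocess() in metal_backend.py.
    static std::string so_blob_key(const std::string& method_name) {
      return method_name + "_so_blob";
    }

    // Mirrors the lookup in MetalBackend::init in metal_backend.cpp.
    static std::string weights_blob_key(const std::string& method_name) {
      return method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
    }

    int main() {
      std::cout << so_blob_key("forward") << "\n";      // forward_so_blob
      std::cout << weights_blob_key("forward") << "\n"; // forward_weights_blob
      return 0;
    }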
From 00ba52242cdd965c2b263862be67dd72f32a292a Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 21 Oct 2025 21:15:34 -0700
Subject: [PATCH 2/5] Update

[ghstack-poisoned]
---
 backends/apple/metal/runtime/shims/et_metal.h |   1 +
 .../apple/metal/runtime/shims/et_metal.mm     |  15 ++
 .../apple/metal/runtime/shims/et_metal_ops.mm |  16 +-
 backends/apple/metal/runtime/shims/memory.cpp | 230 ++++++++++++------
 backends/apple/metal/runtime/shims/memory.h   |   2 +-
 5 files changed, 181 insertions(+), 83 deletions(-)

diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h
index 75f79e5139c..0e012d18c8f 100644
--- a/backends/apple/metal/runtime/shims/et_metal.h
+++ b/backends/apple/metal/runtime/shims/et_metal.h
@@ -354,6 +354,7 @@ extern "C" {
 
 // Memory management functions for Metal
 void* metal_allocate_buffer(long bytes);
+void metal_deallocate_buffer(void* ptr);
 bool metal_is_device_pointer(void* ptr);
 int metal_copy_memory(
     void* dst,
diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm
index fdca0a28cf3..cae8f96c0d2 100644
--- a/backends/apple/metal/runtime/shims/et_metal.mm
+++ b/backends/apple/metal/runtime/shims/et_metal.mm
@@ -86,6 +86,21 @@ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
   }
 }
 
+void metal_deallocate_buffer(void* ptr) {
+  @autoreleasepool {
+    auto it = ptr_to_mtl_buffer.find(ptr);
+    if (it != ptr_to_mtl_buffer.end()) {
+      id<MTLBuffer> buffer = it->second;
+      [buffer release];
+      ptr_to_mtl_buffer.erase(it);
+      ET_LOG(Debug, "Deallocated Metal buffer for pointer %p", ptr);
+      ptr = nullptr;
+    } else {
+      ET_LOG(Error, "Failed to find Metal buffer for pointer %p", ptr);
+    }
+  }
+}
+
 void metal_cleanup_resources() {
   if (!ptr_to_mtl_buffer.empty()) {
     @autoreleasepool {
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index 0aa90650a1d..94bc3219306 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -736,9 +736,12 @@ AOTITorchError aoti_torch_mps_convolution(
       throw std::runtime_error("Tensor size mismatch");
     }
 
-    // Store the tensor handle - mark that we own the memory since we manually allocated it with malloc
+    // Store the tensor handle - mark that we own the memory since we manually allocated it
     *ret0 = output_tensor_handle;
-    is_tensor_own_memory[et_tensor] = true;  // We allocated the GPU memory
+    // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
+    // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
+    extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
+    memory_to_n_tensor[tensor_data] = 1;
 
     ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
 
@@ -1327,10 +1330,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     }
 
     // Mark that we own the memory for these tensors
-    auto* out_et_tensor = reinterpret_cast<Tensor*>(out_tensor_handle);
-    auto* attn_et_tensor = reinterpret_cast<Tensor*>(attn_tensor_handle);
-    is_tensor_own_memory[out_et_tensor] = true;
-    is_tensor_own_memory[attn_et_tensor] = true;
+    // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
+    // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
+    extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
+    memory_to_n_tensor[out_contents_ptr] = 1;
+    memory_to_n_tensor[attn_contents_ptr] = 1;
 
     // Set output tensor handles
     *ret0 = out_tensor_handle;
diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp
index 83250f308bb..50f8d4e1cd0 100644
--- a/backends/apple/metal/runtime/shims/memory.cpp
+++ b/backends/apple/metal/runtime/shims/memory.cpp
@@ -31,7 +31,12 @@ using namespace executorch::backends::aoti;
 
 // Global storage for tensors and their metadata
 std::unordered_set<std::shared_ptr<Tensor>> tensors;
-std::unordered_map<Tensor*, bool> is_tensor_own_memory;
+
+// Reference counting for memory addresses
+// Maps memory address to number of tensors using it
+// Special value: NOT_OWN (-1) means tensor never owns the memory
+constexpr int32_t NOT_OWN = -1;
+std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
 extern "C" {
 
@@ -110,7 +115,18 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
   // Store the tensor so it doesn't get destroyed
   tensors.insert(tensor);
   *ret_new_tensor = tensor.get();
-  is_tensor_own_memory[tensor.get()] = false;
+
+  // Check if this memory address is already being tracked
+  auto memory_it = memory_to_n_tensor.find(adjusted_data);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it == memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is already being tracked by another tensor",
+      adjusted_data);
+
+  // Mark this memory as NOT_OWN since tensor created from blob never owns
+  // memory
+  memory_to_n_tensor[adjusted_data] = NOT_OWN;
 
   ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull");
   return Error::Ok;
@@ -192,7 +208,9 @@ AOTITorchError aoti_torch_empty_strided(
   // Store the tensor so it doesn't get destroyed
   tensors.insert(tensor);
   *ret_new_tensor = tensor.get();
-  is_tensor_own_memory[tensor.get()] = true;
+
+  // This tensor owns the memory it allocated, set reference count to 1
+  memory_to_n_tensor[ptr] = 1;
 
   ET_LOG(Debug, "aoti_torch_empty_strided: successfull");
   return Error::Ok;
@@ -200,51 +218,81 @@ AOTITorchError aoti_torch_empty_strided(
 
 AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
   ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered");
-  // Find tensor in the set
+
+  // Handle null tensor pointer
+  if (tensor == nullptr) {
+    ET_LOG(Debug, "aoti_torch_delete_tensor_object: null tensor");
+    return Error::Ok;
+  }
+
+  // Check if tensor exists in our tracking
+  bool found_in_tensors = false;
   for (auto it = tensors.begin(); it != tensors.end(); ++it) {
     if (it->get() == tensor) {
-      auto tensor_ptr = *it;
+      found_in_tensors = true;
+      break;
+    }
+  }
 
-      // Check ownership before cleaning up
-      auto ownership_it = is_tensor_own_memory.find(tensor);
-      bool owns_memory = (ownership_it != is_tensor_own_memory.end())
-          ? ownership_it->second
-          : false;
+  // If tensor not found in our tracking, it's invalid
+  ET_CHECK_OR_RETURN_ERROR(
+      found_in_tensors, InvalidArgument, "Didn't find tensor %p", tensor);
 
-      // Clean up ownership metadata
-      is_tensor_own_memory.erase(tensor);
+  // Find and delete the tensor
+  for (auto it = tensors.begin(); it != tensors.end(); ++it) {
+    if (it->get() == tensor) {
+      // Get the tensor before erasing
+      auto tensor_ptr = *it;
+      void* data_ptr = tensor_ptr->mutable_data_ptr();
 
-      if (owns_memory) {
-        // et tensor owns the memory; need to free it manually
-        void* data_ptr = tensor_ptr->mutable_data_ptr();
+      // Find the reference count for this memory address
+      auto memory_it = memory_to_n_tensor.find(data_ptr);
+      if (memory_it != memory_to_n_tensor.end()) {
+        int32_t ref_count = memory_it->second;
 
-        // Check if it's Metal GPU memory
-        if (metal_is_device_pointer(data_ptr)) {
-          // This is Metal GPU memory - the Metal helper will handle cleanup
-          // Metal buffers are automatically managed by ARC when the buffer is
-          // released
+        if (ref_count == NOT_OWN) {
+          // Tensor never owned the memory, skip freeing
+          // Just remove tensor from tracking
           tensors.erase(it);
-          ET_LOG(
-              Debug,
-              "aoti_torch_delete_tensor_object: successfull (Metal GPU memory)");
+          ET_LOG(Debug, "aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free");
           return Error::Ok;
+        } else if (ref_count == 1) {
+          // Only current tensor using this memory, free it
+          // Check if it's Metal GPU memory
+          if (metal_is_device_pointer(data_ptr)) {
+            metal_deallocate_buffer(data_ptr);
+          } else {
+            // This is CPU memory - free immediately
+            free(data_ptr);
+            data_ptr = nullptr;
+            ET_LOG(Debug, "aoti_torch_delete_tensor_object: freeing CPU memory");
+          }
+
+          // Remove from memory tracking
+          memory_to_n_tensor.erase(memory_it);
+        } else if (ref_count > 1) {
+          // Other tensors still using this memory, just decrement count
+          memory_to_n_tensor[data_ptr] = ref_count - 1;
+          ET_LOG(Debug, "aoti_torch_delete_tensor_object: decremented ref count from %d to %d", ref_count, ref_count - 1);
         }
-
-        // This is CPU memory - free immediately
-        free(data_ptr);
+      } else {
+        ET_CHECK_OR_RETURN_ERROR(
+            false,
+            Internal,
+            "Internal error: memory not found during deletion");
       }
-      // else: Don't free memory since the tensor doesn't own it
 
-      // Remove from set (this will call the destructor if it's the last
+      // Remove tensor from set (this will call the destructor if it's the last
       // reference)
       tensors.erase(it);
-      ET_LOG(
-          Debug, "aoti_torch_delete_tensor_object: successfull (CPU memory)");
+      ET_LOG(Debug, "aoti_torch_delete_tensor_object: successful");
       return Error::Ok;
     }
  }
-  ET_LOG(Error, "Didn't find tensor %p", tensor);
-  return Error::InvalidArgument;
+
+  // This should never be reached since we found it above
+  ET_CHECK_OR_RETURN_ERROR(
+      false, Internal, "Internal error: tensor not found after validation");
 }
 
 AOTITorchError aoti_torch_copy_(
@@ -375,6 +423,24 @@ AOTITorchError aoti_torch__reinterpret_tensor(
       InvalidArgument,
       "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
 
+  // Check if storage_offset is not 0 - return error if not
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));
+
+  // Get the device info from the source tensor to perform device_index
+  // validation
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
+
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
+
+  // Ensure device_index is always 0
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
+      device_index);
+
   // Get the dtype from the source tensor
   int32_t dtype = 0;
   ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype));
@@ -382,54 +448,52 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   // Validate dtype using SupportedDTypes
   ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
 
-  int32_t device_type = 0;
-  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
+  // Get the original data pointer from the source tensor
+  void* data_ptr = self->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
 
-  int32_t device_index = 0;
-  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Convert sizes using utility function from utils.h
+  std::vector<executorch::aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using utility function from utils.h
+  std::vector<executorch::aten::StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor view that reinterprets the same memory with different
+  // shape/strides. This creates a view, not a copy - the data pointer is shared
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Reuse the same memory from source tensor
+      sizes, // New sizes with explicit SizesType
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
 
-  // Get the base data pointer from the source tensor
-  void* base_data_ptr = self->mutable_data_ptr();
   ET_CHECK_OR_RETURN_ERROR(
-      base_data_ptr != nullptr,
+      tensor != nullptr,
       InvalidArgument,
-      "Source tensor has null data pointer");
+      "Failed to create reinterpreted tensor view");
 
-  // Calculate new tensor size in elements for logging
-  int64_t new_numel = 1;
-  for (int64_t i = 0; i < ndim; i++) {
-    new_numel *= sizes_ptr[i];
-  }
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
 
-  ET_LOG(
-      Debug,
-      "aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld",
-      base_data_ptr,
-      new_numel,
-      storage_offset);
-
-  // Create a new tensor view that shares the same underlying storage
-  // This is the correct way to implement reinterpret_tensor - as a view, not a
-  // copy
-  AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2(
-      base_data_ptr, // Same underlying data pointer
-      ndim, // New dimensions
-      sizes_ptr, // New sizes
-      strides_ptr, // New strides
-      storage_offset, // Storage offset (will be handled properly now)
-      dtype,
-      device_type,
-      device_index,
-      ret_new_tensor,
-      0, // layout (default)
-      nullptr, // opaque_metadata
-      0 // opaque_metadata_size
-  );
+  *ret_new_tensor = tensor.get();
 
-  if (create_err != Error::Ok) {
-    ET_LOG(Error, "failed to create reinterpreted tensor view");
-    return create_err;
-  }
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
 
   ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
   return Error::Ok;
@@ -437,13 +501,27 @@ AOTITorchError aoti_torch__reinterpret_tensor(
 
 // Cleanup function for clearing global state
 void cleanup_memory() {
-  is_tensor_own_memory.clear();
-  if (!tensors.empty()) {
-    ET_LOG(Error, "Warning: tensors not empty during cleanup");
+  // Use aoti_torch_delete_tensor_object to properly delete each tensor
+  // Note: We need to collect tensor pointers first since deletion modifies the
+  // set
+  std::vector<Tensor*> tensor_ptrs;
+  tensor_ptrs.reserve(tensors.size());
+  for (const auto& tensor_shared : tensors) {
+    tensor_ptrs.push_back(tensor_shared.get());
+  }
+
+  // Now delete each tensor - this will modify the global tensors set
+  for (Tensor* tensor_ptr : tensor_ptrs) {
+    aoti_torch_delete_tensor_object(tensor_ptr);
   }
 
+  // tensors set should now be empty, but ensure it's cleared
+  tensors.clear();
+
   // Clean up Metal resources
   metal_cleanup_resources();
+
+  ET_LOG(Info, "Cleared all tensors and Metal resources");
 }
 
 } // extern "C"
diff --git a/backends/apple/metal/runtime/shims/memory.h b/backends/apple/metal/runtime/shims/memory.h
index 47fb6352b50..5f48fd921c6 100644
--- a/backends/apple/metal/runtime/shims/memory.h
+++ b/backends/apple/metal/runtime/shims/memory.h
@@ -22,7 +22,7 @@ namespace metal {
 extern "C" {
 
 // Global storage declarations
-extern std::unordered_map<Tensor*, bool> is_tensor_own_memory;
+extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
 extern std::unordered_set<std::shared_ptr<Tensor>> tensors;
 
 // Memory-related operations
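[Editor's note] The reference-counting scheme introduced in Patch 2 has three states per data
pointer: NOT_OWN (-1) for memory a tensor merely views (blob tensors), 1 for freshly allocated
memory, and n > 1 when reinterpret-style views share an owned allocation. The self-contained
C++ sketch below models only the delete path; release_ref() is a toy stand-in for
aoti_torch_delete_tensor_object and is not part of the backend.

    #include <cassert>
    #include <cstdint>
    #include <unordered_map>

    constexpr int32_t NOT_OWN = -1;

    // Returns true when the caller should actually free `ptr`.
    // Mirrors the ref_count == NOT_OWN / == 1 / > 1 branches in memory.cpp.
    bool release_ref(std::unordered_map<void*, int32_t>& refs, void* ptr) {
      auto it = refs.find(ptr);
      assert(it != refs.end() && "memory not found during deletion");
      if (it->second == NOT_OWN) {
        return false; // view-only memory is never freed (the entry stays put)
      }
      if (it->second == 1) {
        refs.erase(it); // last owner: free the allocation and stop tracking
        return true;
      }
      --it->second; // other tensors still share the allocation
      return false;
    }

    int main() {
      std::unordered_map<void*, int32_t> refs;
      int owned = 0, viewed = 0;
      refs[&owned] = 2;                    // an allocation plus one reinterpret view
      assert(!release_ref(refs, &owned));  // view deleted: count drops to 1
      assert(release_ref(refs, &owned));   // owner deleted: free now
      refs[&viewed] = NOT_OWN;             // tensor created from an external blob
      assert(!release_ref(refs, &viewed)); // never freed by the shim
      return 0;
    }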
From 136f908a2e6396c6964f98f2f3843c2e7ba22171 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 21 Oct 2025 21:15:39 -0700
Subject: [PATCH 3/5] Update

[ghstack-poisoned]
---
 .../apple/metal/runtime/shims/et_metal.mm     | 14 +++--
 .../apple/metal/runtime/shims/et_metal_ops.mm | 57 +++++++++++++------
 2 files changed, 49 insertions(+), 22 deletions(-)

diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm
index cae8f96c0d2..2ba058de40a 100644
--- a/backends/apple/metal/runtime/shims/et_metal.mm
+++ b/backends/apple/metal/runtime/shims/et_metal.mm
@@ -680,12 +680,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
 
 // Commit methods
 void ETMetalStream::commit() {
-  if (enableCommitAndContinue_ && commandBuffer_) {
-    // Use commit-and-continue for better performance
-    commitAndContinue();
-  } else {
-    flush();
+  if (!commandBuffer_) {
+    ET_LOG(Error, "ETMetalStream::commit: No command buffer to commit");
+    return;
   }
+
+  [commandBuffer_ commit];
+  ET_LOG(Debug, "ETMetalStream::commit: Committed buffer %p", commandBuffer_);
+
+  [commandBuffer_ release];
+  commandBuffer_ = nil;
 }
 
 void ETMetalStream::commitAndWait() {
diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index 94bc3219306..7e1fa66ac7c 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -270,7 +270,7 @@ AOTITorchError aoti_torch_mps_mm_out(
 
     @try {
       // Use stream helper to encode and synchronize correctly
-      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     } @catch (NSException *exception) {
       ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s",
              [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -279,6 +279,14 @@ AOTITorchError aoti_torch_mps_mm_out(
 
     ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully");
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [selfData release];
+    [mat2Data release];
+    [outputData release];
+
     ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully");
 
     return Error::Ok;
@@ -616,14 +624,16 @@ AOTITorchError aoti_torch_mps_convolution(
     feeds[inputPlaceholder] = inputData;
     feeds[weightPlaceholder] = weightData;
 
+    MPSGraphTensorData* biasData = nil;
+
     // Add bias data to feeds if provided
     if (bias_tensor && biasPlaceholder) {
       id<MTLBuffer> bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias");
 
       NSArray* biasShape = @[@(C_out)];
-      MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
-                                                                             shape:biasShape
-                                                                          dataType:mps_dtype];
+      biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer
+                                                         shape:biasShape
+                                                      dataType:mps_dtype];
 
       feeds[biasPlaceholder] = biasData;
       ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds");
@@ -650,7 +660,7 @@ AOTITorchError aoti_torch_mps_convolution(
 
     @try {
       // Use stream helper to encode and synchronize correctly
-      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+      stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     } @catch (NSException *exception) {
       ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s",
              [[exception name] UTF8String], [[exception reason] UTF8String]);
@@ -743,6 +753,15 @@ AOTITorchError aoti_torch_mps_convolution(
     extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
     memory_to_n_tensor[tensor_data] = 1;
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [inputData release];
+    [weightData release];
+    if (biasData) [biasData release];
+    [outputData release];
+
     ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
 
     ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully");
@@ -992,14 +1011,6 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention");
 
     @try {
-      // Check if scaledDotProductAttentionWithQueryTensor is available
-      MPSGraph* testGraph = [MPSGraph new];
-      if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) {
-        ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system");
-        throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system");
-      }
-      ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available");
-
       // Create MPSGraph for scaled dot product attention
       MPSGraph* mpsGraph = [MPSGraph new];
       ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
@@ -1246,6 +1257,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
     feeds[valuePlaceholder] = valueData;
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds");
 
+    MPSGraphTensorData* maskData = nil;
+
     // Add explicit mask data to feeds if provided
     if (explicitMaskPlaceholder && attn_mask && *attn_mask) {
       auto* mask_tensor = reinterpret_cast<Tensor*>(*attn_mask);
@@ -1257,9 +1270,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
         [maskShapeArray addObject:@(mask_tensor->sizes()[i])];
       }
 
-      MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
-                                                                              shape:maskShapeArray
-                                                                           dataType:mps_dtype];
+      maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer
+                                                         shape:maskShapeArray
+                                                      dataType:mps_dtype];
 
       feeds[explicitMaskPlaceholder] = maskData;
       ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds");
@@ -1275,9 +1288,19 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
 
     // Execute via shared stream and keep results on GPU
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream");
-    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE);
+    stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT);
     ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully");
 
+    // Release MPSGraph to prevent memory leak
+    [mpsGraph release];
+    mpsGraph = nil;
+
+    [queryData release];
+    [keyData release];
+    [valueData release];
+    if (maskData) [maskData release];
+    [outputData release];
+
   } @catch (NSException *exception) {
     ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s",
            [[exception name] UTF8String], [[exception reason] UTF8String]);
From 5bd34ee8ed29147ae2171071eb1f00109144e4cd Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 21 Oct 2025 21:19:21 -0700
Subject: [PATCH 4/5] Update

[ghstack-poisoned]
---
 backends/apple/metal/runtime/shims/memory.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp
index 50f8d4e1cd0..b5d2d3161ae 100644
--- a/backends/apple/metal/runtime/shims/memory.cpp
+++ b/backends/apple/metal/runtime/shims/memory.cpp
@@ -254,7 +254,9 @@ AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
           // Tensor never owned the memory, skip freeing
           // Just remove tensor from tracking
           tensors.erase(it);
-          ET_LOG(Debug, "aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free");
+          ET_LOG(
+              Debug,
+              "aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free");
           return Error::Ok;
         } else if (ref_count == 1) {
           // Only current tensor using this memory, free it
@@ -265,7 +267,8 @@ AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
             // This is CPU memory - free immediately
             free(data_ptr);
             data_ptr = nullptr;
-            ET_LOG(Debug, "aoti_torch_delete_tensor_object: freeing CPU memory");
+            ET_LOG(
+                Debug, "aoti_torch_delete_tensor_object: freeing CPU memory");
           }
 
           // Remove from memory tracking
@@ -273,7 +276,11 @@ AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
         } else if (ref_count > 1) {
           // Other tensors still using this memory, just decrement count
           memory_to_n_tensor[data_ptr] = ref_count - 1;
-          ET_LOG(Debug, "aoti_torch_delete_tensor_object: decremented ref count from %d to %d", ref_count, ref_count - 1);
+          ET_LOG(
+              Debug,
+              "aoti_torch_delete_tensor_object: decremented ref count from %d to %d",
+              ref_count,
+              ref_count - 1);
         }
       } else {
         ET_CHECK_OR_RETURN_ERROR(
From fa097b8ca4aaf0b455e201c33a8f7b6177f8eaed Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 21 Oct 2025 22:25:53 -0700
Subject: [PATCH 5/5] Update

[ghstack-poisoned]
---
 .../apple/metal/runtime/shims/et_metal_ops.mm | 416 ++++++++++++------
 1 file changed, 279 insertions(+), 137 deletions(-)

diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm
index 7e1fa66ac7c..b150c68fe6d 100644
--- a/backends/apple/metal/runtime/shims/et_metal_ops.mm
+++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm
@@ -32,8 +32,69 @@
 // Declare the global mapping from et_metal.mm
 extern std::unordered_map<void*, id<MTLBuffer>> ptr_to_mtl_buffer;
 
+// =======================
+// MPSGraph Caching Infrastructure
+// =======================
+
 namespace {
 
+// Cache key structure for different operations
+struct GraphCacheKey {
+  std::string op_name;
+  std::vector<int64_t> shape_params;
+  int32_t dtype;
+  bool transpose_flag;
+
+  bool operator==(const GraphCacheKey& other) const {
+    return op_name == other.op_name &&
+           shape_params == other.shape_params &&
+           dtype == other.dtype &&
+           transpose_flag == other.transpose_flag;
+  }
+};
+
+// Hash function for GraphCacheKey
+struct GraphCacheKeyHash {
+  std::size_t operator()(const GraphCacheKey& key) const {
+    std::size_t hash = std::hash<std::string>{}(key.op_name);
+    for (auto val : key.shape_params) {
+      hash ^= std::hash<int64_t>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+    }
+    hash ^= std::hash<int32_t>{}(key.dtype) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+    hash ^= std::hash<bool>{}(key.transpose_flag) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+    return hash;
+  }
+};
+
+// Struct to store both the compiled graph and its tensors for reuse
+struct CachedGraph {
+  MPSGraph* graph;
+  MPSGraphTensor* input1;
+  MPSGraphTensor* input2;
+  MPSGraphTensor* input3; // Optional (e.g., bias, mask)
+  MPSGraphTensor* output;
+};
+
+// Global cache for compiled MPSGraphs
+// These graphs are never released - they're reused across calls
+static std::unordered_map<GraphCacheKey, CachedGraph, GraphCacheKeyHash> graph_cache;
+
+// Statistics for monitoring cache effectiveness
+struct CacheStats {
+  size_t hits = 0;
+  size_t misses = 0;
+
+  void logStats() {
+    if ((hits + misses) % 100 == 0 && (hits + misses) > 0) {
+      double hit_rate = 100.0 * hits / (hits + misses);
+      ET_LOG(Debug, "MPSGraph cache stats: %zu hits, %zu misses (%.1f%% hit rate)",
+             hits, misses, hit_rate);
+    }
+  }
+};
+
+static CacheStats cache_stats;
+
 // Helper function to get Metal buffer from the global mapping
 static id<MTLBuffer> get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) {
   void* data_ptr = tensor->mutable_data_ptr();
@@ -61,7 +122,7 @@
   return it->second;
 }
 
-} // namespace
+} // anonymous namespace
 
 extern "C" {
 
@@ -180,13 +241,8 @@ AOTITorchError aoti_torch_mps_mm_out(
     ET_LOG(Debug, "aoti_torch_mps_mm_out: dtype=%d, element_size=%zu", dtype, element_size);
     ET_LOG(Debug, "aoti_torch_mps_mm_out: M=%lld, K=%lld, N=%lld", M, K, N);
 
-    // Create MPSGraph for matrix multiplication
-    MPSGraph* mpsGraph = [MPSGraph new];
-    ET_LOG(Debug, "aoti_torch_mps_mm_out: Created MPSGraph instance");
-
-    // Define tensor shapes for placeholders
+    // Define tensor shapes for placeholders (needed for both cache hit and miss)
     NSArray* selfShape = @[@(M), @(K)];
-    NSArray* outShape = @[@(M), @(N)];
 
     // For mat2, we need to handle both contiguous and transposed cases
     // If mat2 is transposed, its physical layout in memory is [N, K] (column-major)
@@ -202,43 +258,91 @@ AOTITorchError aoti_torch_mps_mm_out(
       ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (contiguous): [%d,%d]", (int)K, (int)N);
     }
 
-    ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]",
-           (int)M, (int)K,
-           mat2_is_transposed ? (int)N : (int)K,
-           mat2_is_transposed ? (int)K : (int)N);
+    // Create cache key for this matrix multiplication
+    GraphCacheKey cache_key;
+    cache_key.op_name = "mm";
+    cache_key.shape_params = {M, K, N};
+    cache_key.dtype = dtype;
+    cache_key.transpose_flag = mat2_is_transposed;
+
+    // Check if we have a cached graph
+    MPSGraph* mpsGraph = nullptr;
+    MPSGraphTensor* mmOutput = nil;
+    MPSGraphTensor* selfPlaceholder = nil;
+    MPSGraphTensor* mat2Placeholder = nil;
+
+    auto cache_it = graph_cache.find(cache_key);
+    if (cache_it != graph_cache.end()) {
+      // Cache hit - reuse compiled graph and tensor references
+      CachedGraph& cached = cache_it->second;
+      mpsGraph = cached.graph;
+      selfPlaceholder = cached.input1;
+      mat2Placeholder = cached.input2;
+      mmOutput = cached.output;
+
+      cache_stats.hits++;
+      cache_stats.logStats();
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
 
-    // Create placeholders for input tensors
-    MPSGraphTensor* selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
-                                                            dataType:mps_dtype
-                                                                name:@"self"];
-    MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape
-                                                            dataType:mps_dtype
-                                                                name:@"mat2_physical"];
+    } else {
+      // Cache miss - create and compile new graph
+      mpsGraph = [MPSGraph new];
+      cache_stats.misses++;
+      cache_stats.logStats();
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);
+
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]",
+             (int)M, (int)K,
+             mat2_is_transposed ? (int)N : (int)K,
+             mat2_is_transposed ? (int)K : (int)N);
+
+      // Create placeholders for input tensors
+      selfPlaceholder = [mpsGraph placeholderWithShape:selfShape
+                                              dataType:mps_dtype
+                                                  name:@"self"];
+      mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape
+                                              dataType:mps_dtype
+                                                  name:@"mat2_physical"];
+
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders");
+
+      // If mat2 is transposed, apply transpose operation in the graph to get the logical shape
+      MPSGraphTensor* mat2Logical;
+      if (mat2_is_transposed) {
+        // Transpose from physical [N, K] to logical [K, N]
+        // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors
+        mat2Logical = [mpsGraph transposeTensor:mat2Placeholder
+                                      dimension:-2
+                                  withDimension:-1
+                                           name:@"mat2_transposed"];
+        ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph");
+      } else {
+        // No transpose needed, use placeholder directly
+        mat2Logical = mat2Placeholder;
+        ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)");
+      }
 
-    ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders");
+      // Perform matrix multiplication using MPSGraph with the logical mat2 tensor
+      mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
+                                                 secondaryTensor:mat2Logical
+                                                            name:@"matrix_multiplication"];
 
-    // If mat2 is transposed, apply transpose operation in the graph to get the logical shape
-    MPSGraphTensor* mat2Logical;
-    if (mat2_is_transposed) {
-      // Transpose from physical [N, K] to logical [K, N]
-      // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors
-      mat2Logical = [mpsGraph transposeTensor:mat2Placeholder
-                                    dimension:-2
-                                withDimension:-1
-                                         name:@"mat2_transposed"];
-      ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph");
-    } else {
-      // No transpose needed, use placeholder directly
-      mat2Logical = mat2Placeholder;
-      ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)");
-    }
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor");
+
+      // Cache the compiled graph and tensor references for reuse
+      CachedGraph cached_graph;
+      cached_graph.graph = mpsGraph;
+      cached_graph.input1 = selfPlaceholder;
+      cached_graph.input2 = mat2Placeholder;
+      cached_graph.input3 = nil;
+      cached_graph.output = mmOutput;
+      graph_cache[cache_key] = cached_graph;
 
-    // Perform matrix multiplication using MPSGraph with the logical mat2 tensor
-    MPSGraphTensor* mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder
-                                                              secondaryTensor:mat2Logical
-                                                                         name:@"matrix_multiplication"];
+      ET_LOG(Debug, "aoti_torch_mps_mm_out: Cached compiled MPSGraph for future reuse");
+    } // End of cache miss/hit block
 
-    ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor");
+    // Define output shape
+    NSArray* outShape = @[@(M), @(N)];
 
     // Create feeds dictionary for graph execution
     NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
@@ -279,10 +383,6 @@ AOTITorchError aoti_torch_mps_mm_out(
 
     ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully");
 
-    // Release MPSGraph to prevent memory leak
-    [mpsGraph release];
-    mpsGraph = nil;
-
     [selfData release];
     [mat2Data release];
     [outputData release];
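[Editor's note] The cache key deliberately folds every graph-shaping input (op name, shape
parameters, dtype, transpose flag) into both equality and the boost-style hash combiner, so two
calls reuse a compiled graph only when the generated MPSGraph would be structurally identical.
The self-contained C++ sketch below exercises the same structs with an int payload standing in
for CachedGraph (which holds Objective-C pointers and cannot run off-device).

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct GraphCacheKey {
      std::string op_name;
      std::vector<int64_t> shape_params;
      int32_t dtype;
      bool transpose_flag;
      bool operator==(const GraphCacheKey& other) const {
        return op_name == other.op_name && shape_params == other.shape_params &&
            dtype == other.dtype && transpose_flag == other.transpose_flag;
      }
    };

    struct GraphCacheKeyHash {
      std::size_t operator()(const GraphCacheKey& key) const {
        // boost::hash_combine-style mixing, as in the patch above
        std::size_t hash = std::hash<std::string>{}(key.op_name);
        for (auto val : key.shape_params) {
          hash ^= std::hash<int64_t>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        }
        hash ^= std::hash<int32_t>{}(key.dtype) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        hash ^= std::hash<bool>{}(key.transpose_flag) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
        return hash;
      }
    };

    int main() {
      std::unordered_map<GraphCacheKey, int, GraphCacheKeyHash> cache;
      GraphCacheKey mm{"mm", {64, 128, 32}, /*dtype=*/6, /*transpose_flag=*/false};
      cache[mm] = 1;              // stands in for a CachedGraph entry
      GraphCacheKey mm_t = mm;
      mm_t.transpose_flag = true; // different graph topology -> different entry
      std::cout << cache.count(mm) << " " << cache.count(mm_t) << "\n"; // "1 0"
      return 0;
    }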
@@ -502,106 +602,150 @@ AOTITorchError aoti_torch_mps_convolution(
 
     ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size);
 
-    // Create MPSGraph for convolution
-    MPSGraph* mpsGraph = [MPSGraph new];
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Created MPSGraph instance");
-
-    // Define tensor shapes for placeholders (always 4D NCHW for MPSGraph)
+    // Define tensor shapes for placeholders (needed for both cache hit and miss)
     NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)];
     NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)];
 
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]",
-           (int)N, (int)C_in, (int)H_in, (int)W_in,
-           (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w);
-
-    // Create placeholders for input tensors
-    MPSGraphTensor* inputPlaceholder = [mpsGraph placeholderWithShape:inputShape
-                                                             dataType:mps_dtype
-                                                                 name:@"input"];
-    MPSGraphTensor* weightPlaceholder = [mpsGraph placeholderWithShape:weightShape
-                                                              dataType:mps_dtype
-                                                                  name:@"weight"];
-
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders");
-
-    // Create convolution descriptor
-    MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
-                                                                                                   strideInY:stride_h
-                                                                                             dilationRateInX:dil_w
-                                                                                             dilationRateInY:dil_h
-                                                                                                      groups:groups
-                                                                                                 paddingLeft:pad_w
-                                                                                                paddingRight:pad_w
-                                                                                                  paddingTop:pad_h
-                                                                                               paddingBottom:pad_h
-                                                                                                paddingStyle:MPSGraphPaddingStyleExplicit
-                                                                                                  dataLayout:MPSGraphTensorNamedDataLayoutNCHW
-                                                                                               weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
-
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld",
-           stride_w, stride_h, pad_w, pad_h, dil_w, dil_h, groups);
-
-    // Perform convolution using MPSGraph
+    // Create cache key for this convolution
+    GraphCacheKey cache_key;
+    cache_key.op_name = "conv";
+    cache_key.shape_params = {N, C_in, H_in, W_in, C_out, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dil_h, dil_w, groups};
+    cache_key.dtype = dtype;
+    cache_key.transpose_flag = (transposed != 0);
+
+    // Check if we have a cached graph
+    MPSGraph* mpsGraph = nullptr;
     MPSGraphTensor* convOutput = nil;
-    if (transposed) {
-      ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution");
-      // For transposed convolution, we need to handle output padding
-      int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0;
-      int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? output_padding[1] : 0;
-
-      // For transposed convolution, we need to adjust the padding calculation
-      // In transposed convolution, the effective padding is typically negative
-      // and we use output_padding to control the final output size
-      int64_t transposed_pad_h = pad_h - output_pad_h;
-      int64_t transposed_pad_w = pad_w - output_pad_w;
-
-      // Create transposed convolution descriptor with adjusted padding
-      MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
-                                                                                                               strideInY:stride_h
-                                                                                                         dilationRateInX:dil_w
-                                                                                                         dilationRateInY:dil_h
-                                                                                                                  groups:groups
-                                                                                                             paddingLeft:transposed_pad_w
-                                                                                                            paddingRight:transposed_pad_w
-                                                                                                              paddingTop:transposed_pad_h
-                                                                                                           paddingBottom:transposed_pad_h
-                                                                                                            paddingStyle:MPSGraphPaddingStyleExplicit
-                                                                                                              dataLayout:MPSGraphTensorNamedDataLayoutNCHW
-                                                                                                           weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
-
-      convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
-                                             weightsTensor:weightPlaceholder
-                                                descriptor:transposedConvDesc
-                                                      name:@"transposed_convolution"];
+    MPSGraphTensor* finalOutput = nil;
+    MPSGraphTensor* inputPlaceholder = nil;
+    MPSGraphTensor* weightPlaceholder = nil;
+    MPSGraphTensor* biasPlaceholder = nil;
+    bool has_bias = (bias_tensor != nullptr);
+
+    auto cache_it = graph_cache.find(cache_key);
+    if (cache_it != graph_cache.end()) {
+      // Cache hit - reuse compiled graph and tensor references
+      CachedGraph& cached = cache_it->second;
+      mpsGraph = cached.graph;
+      inputPlaceholder = cached.input1;
+      weightPlaceholder = cached.input2;
+      biasPlaceholder = cached.input3; // May be nil if no bias
+      finalOutput = cached.output;
+
+      cache_stats.hits++;
+      cache_stats.logStats();
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Using cached MPSGraph (cache hit, %zu total hits)", cache_stats.hits);
+
     } else {
-      ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution");
-      convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
-                                             weightsTensor:weightPlaceholder
-                                                descriptor:convDesc
-                                                      name:@"convolution"];
-    }
+      // Cache miss - create and compile new graph
+      mpsGraph = [MPSGraph new];
+      cache_stats.misses++;
+      cache_stats.logStats();
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Created new MPSGraph instance (cache miss, %zu total misses)", cache_stats.misses);
+
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]",
+             (int)N, (int)C_in, (int)H_in, (int)W_in,
+             (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w);
+
+      // Create placeholders for input tensors
+      inputPlaceholder = [mpsGraph placeholderWithShape:inputShape
+                                               dataType:mps_dtype
+                                                   name:@"input"];
+      weightPlaceholder = [mpsGraph placeholderWithShape:weightShape
+                                                dataType:mps_dtype
+                                                    name:@"weight"];
+
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders");
+
+      // Create convolution descriptor
+      MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
+                                                                                                     strideInY:stride_h
+                                                                                               dilationRateInX:dil_w
+                                                                                               dilationRateInY:dil_h
+                                                                                                        groups:groups
+                                                                                                   paddingLeft:pad_w
+                                                                                                  paddingRight:pad_w
+                                                                                                    paddingTop:pad_h
+                                                                                                 paddingBottom:pad_h
+                                                                                                  paddingStyle:MPSGraphPaddingStyleExplicit
+                                                                                                    dataLayout:MPSGraphTensorNamedDataLayoutNCHW
+                                                                                                 weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
+
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld",
+             stride_w, stride_h, pad_w, pad_h, dil_h, dil_w, groups);
+
+      // Perform convolution using MPSGraph
+      if (transposed) {
+        ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution");
+        // For transposed convolution, we need to handle output padding
+        int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0;
+        int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? output_padding[1] : 0;
+
+        // For transposed convolution, we need to adjust the padding calculation
+        // In transposed convolution, the effective padding is typically negative
+        // and we use output_padding to control the final output size
+        int64_t transposed_pad_h = pad_h - output_pad_h;
+        int64_t transposed_pad_w = pad_w - output_pad_w;
+
+        // Create transposed convolution descriptor with adjusted padding
+        MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w
+                                                                                                                 strideInY:stride_h
+                                                                                                           dilationRateInX:dil_w
+                                                                                                           dilationRateInY:dil_h
+                                                                                                                    groups:groups
+                                                                                                               paddingLeft:transposed_pad_w
+                                                                                                              paddingRight:transposed_pad_w
+                                                                                                                paddingTop:transposed_pad_h
+                                                                                                             paddingBottom:transposed_pad_h
+                                                                                                              paddingStyle:MPSGraphPaddingStyleExplicit
+                                                                                                                dataLayout:MPSGraphTensorNamedDataLayoutNCHW
+                                                                                                             weightsLayout:MPSGraphTensorNamedDataLayoutOIHW];
+
+        convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
+                                               weightsTensor:weightPlaceholder
+                                                  descriptor:transposedConvDesc
+                                                        name:@"transposed_convolution"];
+      } else {
+        ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution");
+        convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder
+                                               weightsTensor:weightPlaceholder
+                                                  descriptor:convDesc
+                                                        name:@"convolution"];
+      }
 
-    ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor");
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor");
 
-    // Handle bias if provided
-    MPSGraphTensor* finalOutput = convOutput;
-    MPSGraphTensor* biasPlaceholder = nil;
-    if (bias_tensor) {
-      ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output");
+      // Handle bias if provided
+      if (bias_tensor) {
+        ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output");
 
-      // Create bias placeholder
-      NSArray* biasShape = @[@(C_out)];
-      biasPlaceholder = [mpsGraph placeholderWithShape:biasShape
-                                              dataType:mps_dtype
-                                                  name:@"bias"];
+        // Create bias placeholder
+        NSArray* biasShape = @[@(C_out)];
+        biasPlaceholder = [mpsGraph placeholderWithShape:biasShape
+                                                dataType:mps_dtype
+                                                    name:@"bias"];
 
-      // Add bias to convolution output
-      finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput
-                                        secondaryTensor:biasPlaceholder
-                                                   name:@"add_bias"];
+        // Add bias to convolution output
+        finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput
+                                          secondaryTensor:biasPlaceholder
+                                                     name:@"add_bias"];
 
-      ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph");
-    }
+        ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph");
+      } else {
+        finalOutput = convOutput;
+      }
+
+      // Cache the compiled graph and tensor references for reuse
+      CachedGraph cached_graph;
+      cached_graph.graph = mpsGraph;
+      cached_graph.input1 = inputPlaceholder;
+      cached_graph.input2 = weightPlaceholder;
+      cached_graph.input3 = biasPlaceholder; // May be nil if no bias
+      cached_graph.output = finalOutput;
+      graph_cache[cache_key] = cached_graph;
+
+      ET_LOG(Debug, "aoti_torch_mps_convolution: Cached compiled MPSGraph for future reuse");
+    } // End of cache miss block
 
     // Create feeds dictionary for graph execution
     NSMutableDictionary* feeds = [NSMutableDictionary dictionary];
@@ -748,15 +892,12 @@ AOTITorchError aoti_torch_mps_convolution(
     // Store the tensor handle - mark that we own the memory since we manually allocated it
     *ret0 = output_tensor_handle;
 
+    // Mark that we own the memory for these tensors
     // Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
     // The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
     extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
     memory_to_n_tensor[tensor_data] = 1;
 
-    // Release MPSGraph to prevent memory leak
-    [mpsGraph release];
-    mpsGraph = nil;
-
     [inputData release];
     [weightData release];
     if (biasData) [biasData release];
@@ -1012,6 +1153,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
 
     @try {
       // Create MPSGraph for scaled dot product attention
+      // TODO: Implement caching for attention operation similar to mm and convolution
       MPSGraph* mpsGraph = [MPSGraph new];
       ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance");
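[Editor's note] The TODO above leaves attention as the one op without graph caching. Below is a
hypothetical sketch of how the same machinery might cover it; it assumes the GraphCacheKey,
CachedGraph, and graph_cache declarations from this file, and the key fields chosen here
(batch, heads, sequence lengths, head dim, mask and causal flags) are assumptions, not part of
the patch.

    // Hypothetical: would live in et_metal_ops.mm next to the mm/conv key builders.
    static GraphCacheKey make_sdpa_cache_key(
        int64_t batch, int64_t num_heads, int64_t q_len, int64_t kv_len,
        int64_t head_dim, bool has_mask, bool is_causal, int32_t dtype) {
      GraphCacheKey key;
      key.op_name = "sdpa";
      // Fold the boolean graph-structure flags into shape_params so masked or
      // causal variants never collide with the plain attention graph.
      key.shape_params = {batch, num_heads, q_len, kv_len, head_dim,
                          has_mask ? 1 : 0, is_causal ? 1 : 0};
      key.dtype = dtype;
      key.transpose_flag = false; // unused for attention
      return key;
    }

    // Usage would mirror mm/conv: look up graph_cache, reuse the cached graph and
    // placeholders on a hit, build and insert on a miss, and skip the
    // [mpsGraph release] that the non-cached path performs, since cached graphs
    // are retained for the lifetime of the process. One wrinkle: attention takes
    // query/key/value plus an optional mask, so CachedGraph would need a fourth
    // input slot (or a small vector of inputs) before this op can be cached.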