Skip to content

Commit 6a6ba04

Browse files
Update
[ghstack-poisoned]
2 parents 750badf + 2667a0c commit 6a6ba04

File tree

3 files changed

+172
-33
lines changed

3 files changed

+172
-33
lines changed

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,35 @@ enum class SyncType {
7777
// =======================
7878
// ETMetalShaderLibrary - ExecuTorch Metal shader library management
7979
// =======================
80+
81+
/**
82+
* @class ETMetalShaderLibrary
83+
* @brief Manages Metal shader library compilation and kernel function retrieval.
84+
*
85+
* This class provides a high-level interface for compiling Metal shading language
86+
* source code into a Metal library and creating compute pipeline states for
87+
* kernel functions. It handles the creation and caching of Metal compute pipeline
88+
* states and functions, which should be reused across multiple kernel dispatches.
89+
*
90+
* The class automatically compiles the provided shader source code upon construction
91+
* and maintains an internal cache of compute pipeline states for different kernel
92+
* functions to avoid redundant compilation.
93+
*
94+
* Example usage:
95+
* @code
96+
* std::string shaderSource = R"(
97+
* #include <metal_stdlib>
98+
* using namespace metal;
99+
* kernel void my_kernel(device float* data [[buffer(0)]],
100+
* uint tid [[thread_position_in_grid]]) {
101+
* data[tid] = data[tid] * 2.0;
102+
* }
103+
* )";
104+
*
105+
* ETMetalShaderLibrary library(shaderSource);
106+
* auto kernelFunction = library.getKernelFunction("my_kernel");
107+
* @endcode
108+
*/
80109
class ETMetalShaderLibrary {
81110
public:
82111
ETMetalShaderLibrary(const std::string& source);
@@ -103,6 +132,45 @@ class ETMetalShaderLibrary {
103132
// =======================
104133
// ETMetalKernelFunction - ExecuTorch Metal kernel function execution
105134
// =======================
135+
136+
/**
137+
* @class ETMetalKernelFunction
138+
* @brief Represents a Metal compute kernel function ready for execution.
139+
*
140+
* This class encapsulates a Metal compute pipeline state and function, providing
141+
* a high-level interface for setting kernel arguments and dispatching compute
142+
* work to the GPU. It handles the encoding of compute commands and manages the
143+
* interaction with Metal's compute command encoder.
144+
*
145+
* The class supports different dispatch patterns:
146+
* - Single-dimension dispatch for linear workloads
147+
* - Multi-dimensional dispatch for grid-based workloads
148+
* - Custom thread group sizes for performance optimization
149+
*
150+
* Kernel arguments can be set using tensors (which will be mapped to Metal buffers)
151+
* or scalar values. The class handles the encoding of these arguments
152+
* into the compute command encoder.
153+
*
154+
* Example usage:
155+
* @code
156+
* // Get kernel function from library
157+
* auto kernelFunction = library.getKernelFunction("vector_add");
158+
*
159+
* // Start encoding commands
160+
* kernelFunction->startEncoding();
161+
*
162+
* // Set tensor arguments
163+
* kernelFunction->setArg(0, inputTensorA);
164+
* kernelFunction->setArg(1, inputTensorB);
165+
* kernelFunction->setArg(2, outputTensor);
166+
*
167+
* // Set scalar argument
168+
* kernelFunction->setArg(3, static_cast<int64_t>(numElements));
169+
*
170+
* // Dispatch for linear workload
171+
* kernelFunction->dispatchSingle(numElements);
172+
* @endcode
173+
*/
106174
class ETMetalKernelFunction {
107175
public:
108176
ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func);
@@ -132,6 +200,45 @@ class ETMetalKernelFunction {
132200
// =======================
133201
// ETMetalStream - Metal command buffer and synchronization management
134202
// =======================
203+
204+
/**
205+
* @class ETMetalStream
206+
* @brief Manages Metal compute command streams and provides GPU synchronization.
207+
*
208+
* This class serves as the central management hub for Metal GPU operations, providing
209+
* a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle,
210+
* compute command encoder management, and various synchronization patterns required for
211+
* efficient GPU computation.
212+
*
213+
* Key features:
214+
* - Lazy command buffer and encoder creation for optimal resource usage
215+
* - Thread-safe operations using serial dispatch queues
216+
* - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE)
217+
* - Kernel coalescing to batch multiple operations efficiently
218+
* - MPSGraph integration for high-level neural network operations
219+
* - Memory operations (copy, fill) with GPU acceleration via blit encoders
220+
*
221+
* The stream follows PyTorch's MPS stream design patterns, providing similar semantics
222+
* for command buffer management and synchronization.
223+
*
224+
* Example usage:
225+
* @code
226+
* // Get current stream (typically the default stream)
227+
* ETMetalStream* stream = getCurrentMetalStream();
228+
*
229+
* // Execute kernel operations (handled automatically)
230+
* auto kernelFunction = library.getKernelFunction("my_kernel");
231+
* kernelFunction->startEncoding();
232+
* kernelFunction->setArg(0, inputTensor);
233+
* kernelFunction->dispatchSingle(numElements);
234+
*
235+
* // Synchronize to ensure completion
236+
* stream->synchronize(SyncType::COMMIT_AND_WAIT);
237+
*
238+
* // Copy between GPU buffers using blit encoder
239+
* stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT);
240+
* @endcode
241+
*/
135242
class ETMetalStream {
136243
public:
137244
ETMetalStream();

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,26 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
743743

744744
void ETMetalStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer, size_t length,
745745
size_t srcOffset, size_t dstOffset, SyncType syncType) {
746+
747+
if (length == 0) {
748+
return;
749+
}
750+
// Reject nil buffers, then check that offset + length stays within each buffer's bounds before copying
751+
if (!srcBuffer || !dstBuffer) {
752+
ET_LOG(Error, "ETMetalStream::copy: Source or destination buffer is nil");
753+
return;
754+
}
755+
NSUInteger srcBufferLength = [srcBuffer length];
756+
NSUInteger dstBufferLength = [dstBuffer length];
757+
if (srcOffset + length > srcBufferLength) {
758+
ET_LOG(Error, "ETMetalStream::copy: Source offset (%zu) + length (%zu) exceeds source buffer size (%zu)", srcOffset, length, srcBufferLength);
759+
return;
760+
}
761+
if (dstOffset + length > dstBufferLength) {
762+
ET_LOG(Error, "ETMetalStream::copy: Destination offset (%zu) + length (%zu) exceeds destination buffer size (%zu)", dstOffset, length, dstBufferLength);
763+
return;
764+
}
765+
746766
dispatch_sync(serialQueue_, ^{
747767
@autoreleasepool {
748768
endKernelCoalescing();
@@ -792,8 +812,6 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
792812
targetOperations:nil
793813
resultsDictionary:results
794814
executionDescriptor:nil];
795-
796-
//synchronize(syncType);
797815
}
798816
});
799817
}

backends/apple/metal/runtime/shims/shim_mps.mm

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,31 +97,33 @@ AOTITorchError aoti_torch_mps_get_kernel_function(
9797
return Error::InvalidArgument;
9898
}
9999

100-
try {
101-
auto* library = reinterpret_cast<ETMetalShaderLibrary*>(library_handle);
102-
auto function_shared_ptr = library->getKernelFunction(std::string(kernel_name));
103-
if (!function_shared_ptr) {
104-
ET_LOG(Error, "aoti_torch_mps_get_kernel_function: Failed to get kernel function '%s'", kernel_name);
105-
return Error::Internal;
106-
}
100+
@autoreleasepool {
101+
try {
102+
auto* library = reinterpret_cast<ETMetalShaderLibrary*>(library_handle);
103+
auto function_shared_ptr = library->getKernelFunction(std::string(kernel_name));
104+
if (!function_shared_ptr) {
105+
ET_LOG(Error, "aoti_torch_mps_get_kernel_function: Failed to get kernel function '%s'", kernel_name);
106+
return Error::Internal;
107+
}
107108

108-
auto* raw_function = function_shared_ptr.get();
109+
auto* raw_function = function_shared_ptr.get();
109110

110-
// Store the shared_ptr to keep the object alive
111-
storeFunctionHandle(raw_function, function_shared_ptr);
111+
// Store the shared_ptr to keep the object alive
112+
storeFunctionHandle(raw_function, function_shared_ptr);
112113

113-
// Return raw pointer to match existing API
114-
*function_handle = reinterpret_cast<AOTIMetalKernelFunctionHandle>(raw_function);
114+
// Return raw pointer to match existing API
115+
*function_handle = reinterpret_cast<AOTIMetalKernelFunctionHandle>(raw_function);
115116

116-
ET_LOG(Debug, "aoti_torch_mps_get_kernel_function: Got kernel function '%s' -> %p", kernel_name, raw_function);
117-
return Error::Ok;
117+
ET_LOG(Debug, "aoti_torch_mps_get_kernel_function: Got kernel function '%s' -> %p", kernel_name, raw_function);
118+
return Error::Ok;
118119

119-
} catch (const std::exception& e) {
120-
ET_LOG(Error, "aoti_torch_mps_get_kernel_function exception: %s", e.what());
121-
return Error::Internal;
122-
} catch (...) {
123-
ET_LOG(Error, "aoti_torch_mps_get_kernel_function: unknown exception");
124-
return Error::Internal;
120+
} catch (const std::exception& e) {
121+
ET_LOG(Error, "aoti_torch_mps_get_kernel_function exception: %s", e.what());
122+
return Error::Internal;
123+
} catch (...) {
124+
ET_LOG(Error, "aoti_torch_mps_get_kernel_function: unknown exception");
125+
return Error::Internal;
126+
}
125127
}
126128
}
127129

@@ -133,19 +135,21 @@ AOTITorchError aoti_torch_mps_start_encoding(
133135
return Error::InvalidArgument;
134136
}
135137

136-
try {
137-
auto* function = reinterpret_cast<ETMetalKernelFunction*>(func);
138-
function->startEncoding();
138+
@autoreleasepool {
139+
try {
140+
auto* function = reinterpret_cast<ETMetalKernelFunction*>(func);
141+
function->startEncoding();
139142

140-
ET_LOG(Debug, "aoti_torch_mps_start_encoding: Started encoding for function %p", function);
141-
return Error::Ok;
143+
ET_LOG(Debug, "aoti_torch_mps_start_encoding: Started encoding for function %p", function);
144+
return Error::Ok;
142145

143-
} catch (const std::exception& e) {
144-
ET_LOG(Error, "aoti_torch_mps_start_encoding exception: %s", e.what());
145-
return Error::Internal;
146-
} catch (...) {
147-
ET_LOG(Error, "aoti_torch_mps_start_encoding: unknown exception");
148-
return Error::Internal;
146+
} catch (const std::exception& e) {
147+
ET_LOG(Error, "aoti_torch_mps_start_encoding exception: %s", e.what());
148+
return Error::Internal;
149+
} catch (...) {
150+
ET_LOG(Error, "aoti_torch_mps_start_encoding: unknown exception");
151+
return Error::Internal;
152+
}
149153
}
150154
}
151155

@@ -268,6 +272,11 @@ AOTITorchError aoti_torch_mps_dispatch_array(
268272
return Error::InvalidArgument;
269273
}
270274

275+
if (!length) {
276+
ET_LOG(Error, "aoti_torch_mps_dispatch_array: null length pointer");
277+
return Error::InvalidArgument;
278+
}
279+
271280
try {
272281
auto* function = reinterpret_cast<ETMetalKernelFunction*>(func);
273282
function->dispatchArray(length, length_size);
@@ -296,6 +305,11 @@ AOTITorchError aoti_torch_mps_dispatch_array_with_group_size(
296305
return Error::InvalidArgument;
297306
}
298307

308+
if (!length) {
309+
ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: null length pointer");
310+
return Error::InvalidArgument;
311+
}
312+
299313
try {
300314
auto* function = reinterpret_cast<ETMetalKernelFunction*>(func);
301315
function->dispatchArrayWithGroupSize(length, length_size, group_size, group_size_size);

0 commit comments

Comments
 (0)