@@ -77,6 +77,35 @@ enum class SyncType {
 // =======================
 // ETMetalShaderLibrary - ExecuTorch Metal shader library management
 // =======================
+
+/**
+ * @class ETMetalShaderLibrary
+ * @brief Manages Metal shader library compilation and kernel function retrieval.
+ *
+ * This class provides a high-level interface for compiling Metal Shading Language
+ * source code into a Metal library and creating compute pipeline states for
+ * kernel functions. It handles the creation and caching of Metal compute pipeline
+ * states and functions so that they can be reused across multiple kernel dispatches.
+ *
+ * The class automatically compiles the provided shader source code upon construction
+ * and maintains an internal cache of compute pipeline states for different kernel
+ * functions to avoid redundant compilation.
+ *
+ * Example usage:
+ * @code
+ * std::string shaderSource = R"(
+ *     #include <metal_stdlib>
+ *     using namespace metal;
+ *     kernel void my_kernel(device float* data [[buffer(0)]],
+ *                           uint tid [[thread_position_in_grid]]) {
+ *         data[tid] = data[tid] * 2.0;
+ *     }
+ * )";
+ *
+ * ETMetalShaderLibrary library(shaderSource);
+ * auto kernelFunction = library.getKernelFunction("my_kernel");
+ * @endcode
+ */
 class ETMetalShaderLibrary {
  public:
   ETMetalShaderLibrary(const std::string& source);
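
Because the library caches pipeline states, looking up the same kernel name repeatedly is expected to return the already-compiled state rather than rebuilding it. Below is a minimal sketch of that reuse, assuming the shaderSource from the example above and using the kernel-function calls documented further down; the tensor and element-count names are placeholders, not identifiers from the ExecuTorch sources.

    ETMetalShaderLibrary library(shaderSource);

    // First lookup triggers pipeline state creation for "my_kernel".
    auto doubleIt = library.getKernelFunction("my_kernel");
    // A second lookup of the same name should be served from the internal cache.
    auto doubleItAgain = library.getKernelFunction("my_kernel");

    // The same function handle can then be reused across many dispatches.
    doubleIt->startEncoding();
    doubleIt->setArg(0, dataTensor);        // placeholder tensor argument
    doubleIt->dispatchSingle(numElements);  // placeholder element count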
@@ -103,6 +132,45 @@ class ETMetalShaderLibrary {
 // =======================
 // ETMetalKernelFunction - ExecuTorch Metal kernel function execution
 // =======================
+
+/**
+ * @class ETMetalKernelFunction
+ * @brief Represents a Metal compute kernel function ready for execution.
+ *
+ * This class encapsulates a Metal compute pipeline state and function, providing
+ * a high-level interface for setting kernel arguments and dispatching compute
+ * work to the GPU. It handles the encoding of compute commands and manages the
+ * interaction with Metal's compute command encoder.
+ *
+ * The class supports different dispatch patterns:
+ * - Single-dimension dispatch for linear workloads
+ * - Multi-dimensional dispatch for grid-based workloads
+ * - Custom thread group sizes for performance optimization
+ *
+ * Kernel arguments can be set using tensors (which will be mapped to Metal buffers)
+ * or scalar values. The class handles the encoding of these arguments
+ * into the compute command encoder.
+ *
+ * Example usage:
+ * @code
+ * // Get kernel function from library
+ * auto kernelFunction = library.getKernelFunction("vector_add");
+ *
+ * // Start encoding commands
+ * kernelFunction->startEncoding();
+ *
+ * // Set tensor arguments
+ * kernelFunction->setArg(0, inputTensorA);
+ * kernelFunction->setArg(1, inputTensorB);
+ * kernelFunction->setArg(2, outputTensor);
+ *
+ * // Set scalar argument
+ * kernelFunction->setArg(3, static_cast<int64_t>(numElements));
+ *
+ * // Dispatch for linear workload
+ * kernelFunction->dispatchSingle(numElements);
+ * @endcode
+ */
 class ETMetalKernelFunction {
  public:
   ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func);
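
The example above dispatches a kernel named vector_add without showing its Metal source. A matching shader might look like the sketch below: the buffer indices mirror the setArg(0..3) calls in the comment, but the kernel body, the bounds check, and the use of a 64-bit length argument are illustrative assumptions rather than code taken from the ExecuTorch sources.

    // Illustrative MSL source for a "vector_add" kernel matching the example above.
    std::string vectorAddSource = R"(
        #include <metal_stdlib>
        using namespace metal;
        kernel void vector_add(device const float* a [[buffer(0)]],
                               device const float* b [[buffer(1)]],
                               device float* out     [[buffer(2)]],
                               constant long& n      [[buffer(3)]],  // assumes 64-bit integer buffer support in the target MSL version
                               uint tid [[thread_position_in_grid]]) {
            if (long(tid) < n) {
                out[tid] = a[tid] + b[tid];
            }
        }
    )";

    ETMetalShaderLibrary vectorAddLibrary(vectorAddSource);
    auto vectorAdd = vectorAddLibrary.getKernelFunction("vector_add");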
@@ -132,6 +200,45 @@ class ETMetalKernelFunction {
 // =======================
 // ETMetalStream - Metal command buffer and synchronization management
 // =======================
+
+/**
+ * @class ETMetalStream
+ * @brief Manages Metal compute command streams and provides GPU synchronization.
+ *
+ * This class serves as the central management hub for Metal GPU operations, providing
+ * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle,
+ * compute command encoder management, and various synchronization patterns required for
+ * efficient GPU computation.
+ *
+ * Key features:
+ * - Lazy command buffer and encoder creation for optimal resource usage
+ * - Thread-safe operations using serial dispatch queues
+ * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE)
+ * - Kernel coalescing to batch multiple operations efficiently
+ * - MPSGraph integration for high-level neural network operations
+ * - Memory operations (copy, fill) with GPU acceleration via blit encoders
+ *
+ * The stream follows PyTorch's MPS stream design patterns, providing similar semantics
+ * for command buffer management and synchronization.
+ *
+ * Example usage:
+ * @code
+ * // Get current stream (typically the default stream)
+ * ETMetalStream* stream = getCurrentMetalStream();
+ *
+ * // Execute kernel operations (handled automatically)
+ * auto kernelFunction = library.getKernelFunction("my_kernel");
+ * kernelFunction->startEncoding();
+ * kernelFunction->setArg(0, inputTensor);
+ * kernelFunction->dispatchSingle(numElements);
+ *
+ * // Synchronize to ensure completion
+ * stream->synchronize(SyncType::COMMIT_AND_WAIT);
+ *
+ * // Copy between GPU buffers using blit encoder
+ * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT);
+ * @endcode
+ */
 class ETMetalStream {
  public:
   ETMetalStream();
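
Taken together, the synchronization modes listed above suggest a common pattern: submit batched work with a non-blocking commit and only wait once, when the CPU actually needs the results. The sketch below uses only calls that appear in the comments above; library is the ETMetalShaderLibrary from the earlier examples, and the tensor names and sizes are placeholders.

    ETMetalStream* stream = getCurrentMetalStream();

    // Encode two kernels back to back; the stream can coalesce them into one command buffer.
    auto scaleKernel = library.getKernelFunction("my_kernel");
    scaleKernel->startEncoding();
    scaleKernel->setArg(0, scratchTensor);   // placeholder tensor
    scaleKernel->dispatchSingle(numElements);

    auto addKernel = library.getKernelFunction("vector_add");
    addKernel->startEncoding();
    addKernel->setArg(0, inputTensorA);
    addKernel->setArg(1, inputTensorB);
    addKernel->setArg(2, outputTensor);
    addKernel->setArg(3, static_cast<int64_t>(numElements));
    addKernel->dispatchSingle(numElements);

    // Submit the batched work without blocking the CPU.
    stream->synchronize(SyncType::COMMIT);

    // Later, before reading outputTensor on the CPU, block until the GPU has finished.
    stream->synchronize(SyncType::COMMIT_AND_WAIT);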