Implement Metal Inference Engine for exo #361

Open · wants to merge 43 commits into base: main
Conversation


@sambhavnoobcoder sambhavnoobcoder commented Oct 18, 2024

This PR aims to address issue #238.

What I Did

  1. Created a new folder @metal with the following files:

    • inference.py
    • metal_kernel_compiler.py
    • metal_model_shard.py
    • swift_code_generator.py
    • utils.py
  2. Implemented MetalDynamicShardInferenceEngine class in inference.py, which is designed to use Metal for GPU acceleration on Apple Silicon Macs.

  3. Created MetalKernelCompiler in metal_kernel_compiler.py to compile TinyGrad kernels to Metal shaders.

  4. Defined MetalModelShard and MetalKernelMetadata in metal_model_shard.py to represent Metal-specific model shards and kernel metadata.

  5. Implemented SwiftCodeGenerator in swift_code_generator.py to generate Swift wrapper code for Metal kernels.

  6. Added utility classes and functions in utils.py, including Linearizer for optimizing the abstract syntax tree (AST) of kernels.

  7. Updated models.py to include Metal-specific model configurations.

  8. Modified main.py to support the new Metal inference engine.

What I Couldn't Do / Current Errors

  1. Integration Error: The MetalDynamicShardInferenceEngine is not fully integrated with the existing codebase. I'm encountering an attribute error:

    'MetalDynamicShardInferenceEngine' object has no attribute '_numpy_to_metal_buffer'
    

    This error suggests that I need to implement the _numpy_to_metal_buffer method in the MetalDynamicShardInferenceEngine class.

  2. Compilation Issues: I haven't been able to successfully compile the Metal shaders or integrate them with the Swift runtime. This requires further investigation and possibly bridging between Python and Swift/Metal.

  3. Inconsistent kernel capture: at times the code is simply unable to capture the model's kernels, and I am unsure why this happens.

What I Need Help With

  1. Implementing the _numpy_to_metal_buffer method in MetalDynamicShardInferenceEngine. This method should convert numpy arrays to Metal buffers.

  2. Guidance on how to properly compile Metal shaders and integrate them with the Swift runtime from Python.

  3. Assistance with bridging the gap between the Python code and the Metal/Swift implementation.

  4. Review of the Metal kernel compilation process to ensure it's correctly translating TinyGrad operations to Metal operations.

  5. Advice on how to properly initialize and manage Metal devices and command queues from within the Python environment.
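The missing `_numpy_to_metal_buffer` could be sketched via the `pyobjc` Metal bindings mentioned in the next steps. This is a hypothetical, untested sketch: the free-function name and the CPU fallback branch are illustrative only, and the real method would live on `MetalDynamicShardInferenceEngine`.

```python
import numpy as np

try:
    import Metal  # pyobjc-framework-Metal; only available on macOS
    _DEVICE = Metal.MTLCreateSystemDefaultDevice()
except ImportError:
    _DEVICE = None

def numpy_to_metal_buffer(array: np.ndarray):
    """Copy a numpy array into a shared-storage MTLBuffer."""
    data = np.ascontiguousarray(array)  # Metal expects contiguous bytes
    if _DEVICE is None:
        # Illustrative CPU fallback so the code path is exercisable off-macOS;
        # the real engine would raise or pick another backend here.
        return data.tobytes()
    return _DEVICE.newBufferWithBytes_length_options_(
        data.tobytes(), data.nbytes, Metal.MTLResourceStorageModeShared
    )
```

Shared storage mode keeps the buffer visible to both CPU and GPU, which sidesteps explicit blit synchronization while prototyping.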

Next Steps

  1. Implement the missing _numpy_to_metal_buffer method.
  2. Set up a proper bridge between Python and Swift/Metal, possibly using a library like pyobjc.
  3. Implement proper error handling and logging for Metal-specific operations.
  4. Create unit tests for the Metal implementation.
  5. Benchmark the Metal implementation against existing CPU and GPU implementations.

Questions for Reviewers

  1. Is the current approach of translating TinyGrad kernels to Metal shaders the most efficient way, or should we consider a different approach?
  2. How should we handle memory management between Python and Metal?
  3. Are there any specific Metal optimizations we should be considering for large language models?
  4. How can we ensure thread safety when dealing with Metal resources from Python?

I would greatly appreciate any guidance or suggestions on how to proceed with this implementation. Thank you for your time and assistance!

- Implemented `MetalDynamicShardInferenceEngine` for dynamic shard inference.
- Added `MetalModelShard` and `MetalKernelMetadata` for handling model sharding and kernel metadata.
- Included `SwiftCodeGenerator` for generating Swift code compatible with Metal shader kernels.
- Added libraries and dependencies.
- Introduced `MetalKernelMetadata` dataclass to store metadata for Metal kernel operations.
- Fields include kernel name, input/output shapes, work group size, global size, buffer sizes, and operation sequence.
- This class will help manage and organize kernel execution details in Metal-based inference engines.
- Added `MetalModelShard` dataclass to manage sharded model components for Metal-based inference.
- Contains `kernel_metadata` for storing metadata of individual kernels.
- Includes `weights` for holding model weights in a dictionary format.
- Added `config` to store additional configuration data required for the model shard.
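Based on the fields listed above, the two dataclasses might look roughly like this; the field names and types are assumptions, since the actual definitions are not shown in this description.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class MetalKernelMetadata:
    name: str
    input_shapes: List[Tuple[int, ...]]
    output_shape: Tuple[int, ...]
    work_group_size: Tuple[int, int, int]
    global_size: Tuple[int, int, int]
    buffer_sizes: List[int]
    op_sequence: List[str] = field(default_factory=list)

@dataclass
class MetalModelShard:
    kernel_metadata: Dict[str, MetalKernelMetadata]  # per-kernel metadata
    weights: Dict[str, Any]                          # name -> Metal buffer
    config: Dict[str, Any]                           # extra shard config
```
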
- Set up libraries and dependencies.
- Implemented `SwiftCodeGenerator` with the `generate_swift_wrapper` method, which generates Swift code to interface with Metal.
- The generated Swift code includes an `MLXMetalEngine` class with methods for kernel compilation, buffer creation, and execution.
- `MLXMetalEngine` handles Metal device setup, command queue creation, kernel pipeline state management, and kernel execution.
- Added error handling with a custom `MLXMetalError` enum for common Metal-related errors.

- Implemented `_generate_kernel_metadata_initialization` to generate Swift initialization code for kernel metadata.
- This method creates the initialization for each kernel's metadata, including input shapes, output shape, work group size, global size, and buffer sizes.
- The generated code is used to populate the `kernelMetadata` dictionary in the Swift wrapper.
- Implemented `_generate_metal_kernel_source` to generate Metal shader source code.
- This method concatenates the `metal_code` from each kernel's metadata to form the complete Metal source code.
- The generated source is used for kernel compilation in the Swift wrapper.
- Refactor and commit library.
- Introduced `UOp` dataclass to store information about a unified operation.
- The `op` field stores the operation name, while `dtype` represents the data type associated with the operation.
- `args` holds a list of arguments required for the operation.
- This class helps structure operations in a concise and standardized format.
- Introduced `Kernel` dataclass with fields to store kernel `name` and a list of unified operations (`uops`).
- The `uops` field contains a list of `UOp` objects, representing the sequence of operations to be executed within the kernel.
- This class is designed to encapsulate the structure and behavior of individual computation kernels.
- Introduced `ASTNode` class with fields to store the operation (`op`), data type (`dtype`), and a list of arguments (`args`).
- The class models individual nodes in an abstract syntax tree (AST), representing operations in a computational graph.
- Each node holds essential information about the operation and its associated arguments.
- Introduced `Linearizer` class with fields to store the linearizer `name`, an abstract syntax tree (`ast`), and options (`opts`).
- The class initializes an empty list of unified operations (`uops`), which will store the linearized form of the AST nodes.
- The `Linearizer` class is responsible for converting an abstract syntax tree into a sequence of unified operations (`UOp`).
- Added a method to simplify expressions where all arguments are constants.
- Implemented optimizations for add, mul, sub, and div operations.
- Constant folding applied by evaluating constant expressions during AST traversal.
- Implemented `evaluate_constant_expression` to handle constant folding.
- Supported arithmetic operations: addition, multiplication, subtraction, and division.
- Raised ValueError for unsupported operations to ensure robustness in constant folding.
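A minimal sketch of the constant-folding helper described above (the exact signature is an assumption):

```python
def evaluate_constant_expression(op: str, args: list):
    """Fold an arithmetic node whose arguments are all constants."""
    if op == "add":
        return args[0] + args[1]
    if op == "mul":
        return args[0] * args[1]
    if op == "sub":
        return args[0] - args[1]
    if op == "div":
        return args[0] / args[1]
    # Unsupported ops raise, so callers skip folding instead of miscompiling.
    raise ValueError(f"unsupported operation for constant folding: {op}")
```
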
- Introduced `linearize` method to convert optimized AST into micro-operations.
- Utilized `get_optimized_ast` for constant folding before linearization.
- Added a second optimization pass over uOps for further performance improvements.
- Returns a `Kernel` object containing the name and optimized uOps.
- Added `ast_to_uop` method to translate AST nodes into a list of micro-operations.
- Handled constant loading, arithmetic operations (add, mul, sub, div), and assignments.
- Ensured non-constant values are loaded into uOps before operations.
- Included a fallback mechanism to handle unrecognized AST operations.
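The AST-to-uop translation described above might be sketched like this, with `ASTNode` reduced to the fields mentioned earlier; the tuple-based `(op, output_var, inputs)` uop encoding is an assumption for illustration.

```python
from dataclasses import dataclass
from typing import Any, List

@dataclass
class ASTNode:
    op: str          # "const", "add", "mul", "sub", "div"
    dtype: str
    args: List[Any]  # child ASTNodes, or the literal value for "const"

def ast_to_uops(node: ASTNode, uops: list) -> str:
    """Flatten an AST into (op, output_var, inputs) uops, post-order."""
    if node.op == "const":
        name = f"t{len(uops)}"
        uops.append(("load_const", name, [node.args[0]]))
        return name
    if node.op in ("add", "mul", "sub", "div"):
        # Children are emitted first so their results exist before use.
        ins = [ast_to_uops(child, uops) for child in node.args]
        name = f"t{len(uops)}"
        uops.append((node.op, name, ins))
        return name
    raise ValueError(f"unrecognized AST op: {node.op}")
```
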
- Added `optimize_uops` method to eliminate unnecessary uOps.
- Performed backward pass to identify used variables in the uOps.
- Implemented forward pass to retain only essential operations, removing dead code.
- Handled both constant loading and variable storage efficiently.
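A simplified sketch of the dead-code-elimination pass, assuming uops are `(op, output_var, inputs)` tuples (the real encoding may differ):

```python
def optimize_uops(uops):
    """Drop uops whose results are never consumed (dead-code elimination)."""
    used = set()
    # Backward pass: stores seed the live set, liveness propagates to inputs.
    for op, out, inputs in reversed(uops):
        if op == "store" or out in used:
            used.update(i for i in inputs if isinstance(i, str))
            used.add(out)
    # Forward pass: keep stores and any op whose result is live.
    return [u for u in uops if u[0] == "store" or u[1] in used]
```
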
- Introduced `KernelOperation` dataclass to encapsulate the type, inputs, output, and attributes of kernel operations.
- Added default empty dictionary for attributes to ensure flexibility in operation handling.
- Ensured proper initialization of attributes in the `__post_init__` method.
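The `__post_init__` default described above might look like the following sketch (field names assumed from the description):

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@dataclass
class KernelOperation:
    op_type: str
    inputs: List[str]
    output: str
    attributes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Default to a fresh empty dict so instances never share mutable state.
        if self.attributes is None:
            self.attributes = {}
```
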
- Refactored out extra code.
- Committed library and dependencies.
- Introduced `MetalKernelCompiler` to handle kernel storage and metadata management.
- Added `kernels` dictionary to store compiled kernel code by name.
- Maintained `kernel_order` list to track the execution order of kernels.
- Included `kernel_metadata` dictionary to store metadata related to each kernel using `MetalKernelMetadata`.
- Added `_generate_buffer_bindings` method to generate buffer bindings for Metal kernels.
- Dynamically created bindings for input buffers based on the number of inputs.
- Included binding for the output buffer at the correct index.
- Ensured proper formatting for Metal kernel syntax with comma-separated bindings.
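The buffer-binding generator could be sketched as follows; the helper name and the `float` element type are assumptions:

```python
def generate_buffer_bindings(num_inputs: int) -> str:
    """Emit Metal buffer bindings: inputs at 0..n-1, output at index n."""
    bindings = [
        f"device const float* input{i} [[buffer({i})]]"
        for i in range(num_inputs)
    ]
    bindings.append(f"device float* output [[buffer({num_inputs})]]")
    return ",\n    ".join(bindings)
```
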
- Implemented `_convert_shape_to_metal` to convert a list of dimensions into Metal-compatible `uint3` format.
- Ensured proper formatting for the Metal kernel by joining shape dimensions with commas.
- Added `_generate_index_calculation` to compute global indices for 1D, 2D, and 3D grids in Metal kernels.
- Handled indexing based on grid size and thread group IDs (`gid`) for different dimensional cases.
- Ensured efficient index calculation for higher-dimensional operations in Metal compute shaders.
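A sketch of the per-dimension index calculation; the emitted variable names `gid` and `grid_size` follow the description above, but the exact strings are illustrative:

```python
def generate_index_calculation(ndim: int) -> str:
    """Flatten a 1-3D thread position `gid` into a linear index."""
    if ndim == 1:
        return "uint idx = gid.x;"
    if ndim == 2:
        return "uint idx = gid.y * grid_size.x + gid.x;"
    if ndim == 3:
        return ("uint idx = gid.z * grid_size.x * grid_size.y"
                " + gid.y * grid_size.x + gid.x;")
    raise ValueError("only 1-3 dimensional grids are supported")
```
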
- Implemented `_convert_dtype_to_metal` to map Python data types to Metal types.
- Created a dictionary to handle conversions for various float and integer types.
- Provided a default return value of "float" for unsupported data types to ensure compatibility.
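The dtype mapping might be a simple lookup with a `float` fallback, as described (the exact table entries are assumptions):

```python
def convert_dtype_to_metal(dtype: str) -> str:
    """Map a Python/numpy dtype name to the corresponding Metal type."""
    metal_types = {
        "float32": "float",
        "float16": "half",
        "int32": "int",
        "int8": "char",
        "uint8": "uchar",
    }
    # Default to "float" so unsupported dtypes still produce valid shaders.
    return metal_types.get(dtype, "float")
```
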
- Added `_convert_op_to_metal` to translate TinyGrad operations into Metal-compatible `KernelOperation` instances.
- Supported various operation types, including basic arithmetic, activation functions, memory operations, matrix operations, and reduction operations.
- Utilized lambda functions to create `KernelOperation` instances dynamically based on operation attributes.
- Handled unknown operations gracefully by returning a `KernelOperation` with an "unknown" type.
- Implemented `_generate_metal_op_code` to create Metal shader code for various kernel operations.
- Defined templates for arithmetic, activation functions, memory load/store, matrix multiplication, and reduction operations.
- Included handling for unknown operations by providing a fallback template.
- Utilized Python's string formatting to inject operation-specific details into the generated Metal code.
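A minimal template-based version of the op-code generator, covering only a few of the operation families listed above and written as a free function for illustration:

```python
def generate_metal_op_code(op_type: str, inputs: list, output: str) -> str:
    """Render one kernel operation as a Metal statement."""
    arith = {"add": "+", "mul": "*", "sub": "-", "div": "/"}
    if op_type in arith:
        return f"{output} = {inputs[0]} {arith[op_type]} {inputs[1]};"
    if op_type == "relu":
        return f"{output} = max({inputs[0]}, 0.0f);"
    if op_type == "load":
        return f"{output} = {inputs[0]}[idx];"
    # Fallback template for operations not handled yet.
    return f"// unknown op: {op_type}"
```
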
- Added `compile_kernel_to_metal` to convert a TinyGrad kernel into Metal shader code.
- Analyzed kernel operations using a `Linearizer` to prepare for Metal conversion.
- Created `MetalKernelMetadata` to encapsulate kernel details, including input/output shapes, work group size, and operation sequence.
- Generated Metal code with appropriate buffer bindings, global index calculations, temporary variable declarations, and kernel computations.
- Ensured proper handling of global indices to avoid out-of-bounds access in kernel execution.
- Implemented `_generate_temp_declarations` to create temporary variable declarations based on kernel operations.
- Collected unique variable names from operation outputs and inputs.
- Excluded input and output variables from temporary declarations to avoid naming conflicts.
- Formatted the declarations for inclusion in Metal shader code, ensuring proper syntax and spacing.
- Added `_generate_computation_code` to create the core computation logic for Metal kernels.
- Utilized the previously defined `_generate_metal_op_code` method to convert each `KernelOperation` into Metal code.
- Ensured that all generated computation statements are formatted correctly for inclusion in the Metal shader.
- Committed dependencies.
- Introduced `MetalDynamicShardInferenceEngine` to manage dynamic inference using Metal.
- Initialized with a shard downloader for fetching model shards.
- Set up a single-threaded executor for handling inference tasks.
- Included components for Metal kernel compilation (`MetalKernelCompiler`), Swift code generation (`SwiftCodeGenerator`), and placeholders for the Metal engine and tokenizer.
- Prepared the class structure for further implementation of dynamic inference capabilities.
- Implemented `initialize_metal_engine` to set up the Metal engine using kernels from the provided model shard.
- Compiled each kernel to Metal code using `MetalKernelCompiler` and stored both the compiled code and metadata.
- Generated Swift wrapper code for the compiled kernels using `SwiftCodeGenerator`.
- Initialized the Metal engine via a Swift bridge, ensuring compatibility with the generated code.
- Prepared the engine for efficient execution of dynamic inference tasks using Metal shaders.
- Added `infer_prompt` method to handle inference requests for both text-only and image-text inputs.
- Ensured the appropriate model shard is available before processing.
- For image-text inputs, extracted image data and prepared inputs for the vision encoder using an executor for concurrent processing.
- Integrated the vision encoder and text decoder to process inputs and generate output data.
- For text-only inputs, tokenized the prompt using an executor for parallel execution.
- Determined if the inference is finished based on the output data, specifically checking for the end-of-sequence token.
- Returned the output data, an empty string, and a flag indicating whether the inference is complete.
- Added `infer_tensor` method to handle inference requests with raw tensor data.
- Ensured the appropriate model shard is available before processing.
- Converted the input NumPy array to a Metal buffer for compatibility with Metal kernels.
- Executed the inference kernels using the prepared input buffer and stored the output data.
- Determined if the inference is finished by checking the output data for the end-of-sequence token.
- Returned the output data, an empty string, and a flag indicating whether the inference is complete.
- Implemented `ensure_shard` to load a model shard if it is not already loaded.
- Checked if the requested shard matches the currently loaded shard to avoid unnecessary loading.
- Utilized the shard downloader to retrieve the model path for the specified shard.
- Loaded shard weights and configuration using an executor to run the loading process asynchronously.
- Initialized the Metal engine with the newly loaded shard and updated the shard reference.
- Added `_load_metal_shard` to handle the loading of model weights and configuration from a specified path.
- Loaded weights and configuration using the `load_shard` function.
- Converted NumPy weight arrays to Metal buffers for compatibility with Metal kernels.
- Returned a new `MetalModelShard` instance containing the kernel metadata, converted weights, and configuration.
- Implemented `_run_inference_kernels` to handle the sequential execution of inference kernels using Metal.
- Retrieved the kernel sequence for inference from the shard configuration.
- Initialized the inference by using the provided input buffer as the starting input.
- For each kernel in the sequence, passed the current output and associated weights (if available) to the Metal engine.
- Updated the current output after each kernel execution to continue the chain of computations.
- Returned the result in the correct format.
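The sequential kernel chaining described above could be exercised with a stand-in engine, roughly like this (all names here are hypothetical):

```python
class FakeMetalEngine:
    """Stand-in for the Swift-bridged Metal engine, for illustration only."""
    def execute(self, kernel_name, input_buffer, weights):
        # Record the kernel name by tagging the buffer passed down the chain.
        return input_buffer + [kernel_name]

def run_inference_kernels(engine, kernel_sequence, weights, input_buffer):
    """Each kernel consumes the previous kernel's output."""
    current = input_buffer
    for name in kernel_sequence:
        current = engine.execute(name, current, weights.get(name))
    return current
```
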
- Added a parameter for including the Metal inference engine in the models.
- Changed the code to offer Metal as an inference-engine option as well.
- The current implementation is a placeholder that runs the model every time; I aim to make this a proper option in the final PR.