Introducing Parameter Sharding and Torch backend for Tensor Parallelism #21724
Conversation
Summary of Changes (from Gemini Code Assist): This pull request significantly enhances Keras's capabilities for large-scale model training by introducing a robust framework for Tensor Parallelism. It provides a new PyTorch-specific distributed backend with collective communication primitives and a flexible parameter sharding mechanism. This allows Keras models to efficiently distribute their parameters across multiple devices, paving the way for more advanced distributed training strategies within the Keras ecosystem.
Code Review
This pull request introduces a foundational framework for Tensor Parallelism in Keras and adds a Torch backend for distributed communication. The changes are substantial and add significant new capabilities. My review focuses on the correctness, generality, and test coverage of this new framework. I've identified some critical issues, such as backend-specific implementations in what should be a backend-agnostic framework, and tests that don't cover the new Torch implementation. There are also opportunities to improve code quality by removing hardcoded logic and reducing code duplication. Addressing these points will help ensure the new Tensor Parallelism framework is robust and maintainable.
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
@@ Coverage Diff @@
## master #21724 +/- ##
==========================================
- Coverage 82.58% 82.40% -0.19%
==========================================
Files 572 577 +5
Lines 58187 59004 +817
Branches 9116 9243 +127
==========================================
+ Hits 48055 48621 +566
- Misses 7808 8029 +221
- Partials 2324 2354 +30
This pull request introduces a foundational framework for Tensor Parallelism in Keras (parameter_sharding.py), enabling the training of large-scale models by sharding their parameters across multiple devices. This is a significant step towards supporting advanced distributed training strategies directly within the Keras ecosystem.
The core of this contribution is a new, backend-agnostic parameter sharding framework and the necessary distributed communication primitives for the PyTorch backend.
Key Changes
PyTorch Distributed Backend
- A new distributed_backend.py module has been added for the PyTorch backend.
- It implements the essential collective communication operations (all_reduce, all_gather, broadcast, scatter) using the torch.distributed package; a minimal sketch of these primitives follows below.
- It provides helper functions for gradient computation (compute_gradients) and device management, aligning its interface with the other Keras backends.
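For orientation, here is a minimal sketch of what such collectives can look like when built on torch.distributed. The function names mirror the operations listed above, but the exact signatures and return conventions in this PR's distributed_backend.py may differ; the sketch assumes a process group has already been initialized (e.g. via torch.distributed.init_process_group) and that tensors divide evenly across ranks.

```python
import torch
import torch.distributed as dist


def all_reduce(tensor):
    # Sum the tensor across all ranks; the result is written in place on
    # every rank and also returned for convenience.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor


def all_gather(tensor):
    # Collect the tensor from every rank and concatenate along axis 0.
    world_size = dist.get_world_size()
    chunks = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(chunks, tensor)
    return torch.cat(chunks, dim=0)


def broadcast(tensor, src=0):
    # Copy the tensor held by rank `src` to every other rank.
    dist.broadcast(tensor, src=src)
    return tensor


def scatter(tensor, src=0):
    # Split the tensor on rank `src` into world_size chunks along axis 0
    # and send one chunk to each rank (assumes even divisibility).
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    out = torch.empty_like(tensor.chunk(world_size, dim=0)[rank])
    scatter_list = list(tensor.chunk(world_size, dim=0)) if rank == src else None
    dist.scatter(out, scatter_list, src=src)
    return out
```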
Parameter Sharding Framework
- Introduces a parameter sharding API under keras/src/distribution/tensor_parallel/.
- ParameterShardingStrategy: a new class that manages the logic for splitting model weights based on user-defined rules specified in a ConfigKeras object.
- ShardedWeight: a wrapper class for sharded keras.Variable objects, allowing them to be seamlessly integrated into the model.
- make_parameter_sharded_model: a factory function that takes a standard Keras model and returns a sharded version, automatically handling the weight splitting and model wrapping. The wrapped ParameterShardedModel injects communication ops (e.g. all-reduce) into the forward pass to ensure correct computations. A hypothetical usage sketch follows below; see the linked Colab for the actual usage.
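The following sketch illustrates how this API might be called. ConfigKeras, make_parameter_sharded_model, and the keras/src/distribution/tensor_parallel/ directory come from this PR, but the import paths, constructor arguments, rule syntax, and the rank/world_size parameters shown here are assumptions made for illustration only and may not match the actual interface.

```python
import keras
# Import paths are assumed from the directory introduced in this PR; the
# exact file/class layout may differ.
from keras.src.distribution.tensor_parallel.parameter_sharding import (
    make_parameter_sharded_model,
)
from keras.src.distribution.tensor_parallel.config import ConfigKeras  # assumed path

# A plain Keras model whose Dense kernels we want to shard.
model = keras.Sequential(
    [
        keras.Input(shape=(1024,)),
        keras.layers.Dense(4096, activation="relu", name="up_proj"),
        keras.layers.Dense(1024, name="down_proj"),
    ]
)

# Hypothetical user-defined rules: split up_proj's kernel along its output
# axis (column parallel) and down_proj's kernel along its input axis
# (row parallel). The rule syntax here is illustrative only.
config = ConfigKeras(
    state_rules={
        r"up_proj.*kernel": "split:axis=1",
        r"down_proj.*kernel": "split:axis=0",
    }
)

# Returns a ParameterShardedModel wrapping ShardedWeight variables; its
# forward pass injects the collective ops (e.g. all-reduce) needed to
# recombine partial results across devices.
sharded_model = make_parameter_sharded_model(
    model, config, rank=0, world_size=2  # rank/world_size args are assumed
)
```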
Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing
This is the 3rd (out of 4) PR for AutoSharding Keras.