Add Autoconfig, Coordinated_Optimizer and Sharding keras implementations for Tensor Parallel Autosharding #21707
Conversation
Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances Keras's capabilities for large-scale model training by introducing foundational support for tensor parallelism autosharding. It provides mechanisms to automatically determine how model layers should be split across multiple devices, and a specialized optimizer to manage the distributed training process, including sharding optimizer states and synchronizing gradients. This enables users to train models that exceed the memory capacity of a single accelerator, making distributed training more accessible and efficient within the Keras ecosystem.
Code Review
This pull request introduces significant new functionality for tensor parallelism autosharding in Keras, including modules for automatic configuration and a coordinated optimizer. The implementation is well-structured, with new logic for analyzing models, generating sharding plans, and synchronizing gradients. However, I've identified a few issues that need attention. There is a critical bug in the CoordinatedOptimizer where a method for applying gradients with sharded states is called but not defined. I also found a couple of high-severity issues related to incorrect logic for matching optimizer states and gathering sharded parameters, which could lead to runtime errors or incorrect behavior. Additionally, there are some medium-severity issues regarding code clarity, such as unused parameters. The accompanying tests are a good start but do not cover the code path with the critical bug.
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

@@ Coverage Diff @@
## master #21707 +/- ##
==========================================
- Coverage 82.59% 82.49% -0.11%
==========================================
Files 572 576 +4
Lines 58322 58923 +601
Branches 9130 9236 +106
==========================================
+ Hits 48173 48606 +433
- Misses 7818 7965 +147
- Partials 2331 2352 +21
if id(current_layer) in processed_layers:
    return
processed_layers.add(id(current_layer))
Per my comment below about not needing recursion, this is not needed.
processed_layers.add(id(current_layer))

name = current_layer.name
full_name = f"{prefix}.{name}" if prefix else name
Because you will never really recurse, the prefix won't work.
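A minimal sketch of the non-recursive alternative (assumed, not the PR's code): iterate the model's flat variable list and rely on Keras 3's Variable.path, which is already unique, so no prefix bookkeeping or layer recursion is needed.

import keras

def kernel_sharding_candidates(model):
    # Hypothetical helper: collect shardable weights keyed by their built-in
    # unique path (e.g. "dense/kernel") instead of a manually prefixed name.
    candidates = {}
    for variable in model.variables:
        if variable.path.endswith("kernel"):
            candidates[variable.path] = tuple(variable.shape)
    return candidates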
self._variable_to_slot_name = {}
opt_name = self.base_optimizer.name

normalized_params = sorted(
I have a hard time following what the code in this method is doing. I think it's trying to re-pair the optimizer variables with the corresponding model variables.

I think it would be easier to capture that information in BaseOptimizer.add_variable_from_reference. Today we have _get_variable_index, but we need something more specific.

Also, what do you call a slot?
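One way to read that suggestion, as a hedged sketch: record the pairing at creation time rather than re-deriving it from names. The SlotTrackingAdam subclass and its slot_map attribute are hypothetical, and the add_variable_from_reference signature is assumed from Keras 3.

import keras

class SlotTrackingAdam(keras.optimizers.Adam):
    # Hypothetical subclass: remember which model variable each optimizer
    # ("slot") variable was created from.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps optimizer-variable path -> (model-variable path, slot name).
        self.slot_map = {}

    def add_variable_from_reference(
        self, reference_variable, name=None, initializer="zeros"
    ):
        opt_var = super().add_variable_from_reference(
            reference_variable, name=name, initializer=initializer
        )
        self.slot_map[opt_var.path] = (reference_variable.path, name)
        return opt_var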
numpy_grad = ops.convert_to_numpy(gradients[0])
synced_numpy = all_reduce_fn(numpy_grad, op="mean")
synced_tensor = ops.convert_to_tensor(synced_numpy)
numpy_grad = ops.convert_to_numpy(gradients[0])

This is going to copy the gradient tensor from GPU/TPU to CPU memory. At that point, there is no sharding anymore, because it gathers all shards from all GPUs into one single recombined tensor on CPU.

synced_numpy = all_reduce_fn(numpy_grad, op="mean")

There is no all_reduce needed here: this is a CPU NumPy array, so you might as well do np.mean. However, what JAX will do is copy the full (unsharded) tensor from CPU to device 0 and perform the mean on that one device.

synced_tensor = ops.convert_to_tensor(synced_numpy)

synced_numpy is already a JAX array per my comment above, so this is a no-op.

So overall, I'm not sure what the intent is, but it looks like this is not doing what you think it's doing. In particular, it's moving gradients back and forth to CPU, which will reduce the throughput.
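For reference, a hedged sketch of keeping the reduction on-device on the JAX backend (illustrative only; this is not code from the PR):

import functools
import jax

@functools.partial(jax.pmap, axis_name="devices")
def mean_gradient(per_device_grad):
    # jax.lax.pmean averages across every device participating in the pmap,
    # so the gradient never makes a round trip through host memory.
    return jax.lax.pmean(per_device_grad, axis_name="devices")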
stacked_grads = keras.ops.stack(
    [ops.convert_to_tensor(g) for g in gradients], axis=0
)
mean_grad = ops.mean(stacked_grads, axis=0)
Why would there be more than 1 gradient in this case?
mean_grad = ops.mean(stacked_grads, axis=0)
return [mean_grad for _ in range(len(gradients))]

def get_weights(self):
Oh weird, I wonder why get_weights is missing in BaseOptimizer.
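A minimal sketch of what a get_weights() could amount to, assuming only the variables property that BaseOptimizer already exposes (the standalone helper name is hypothetical):

import keras

def optimizer_weights(optimizer):
    # Pull every optimizer variable (iterations, momenta, etc.) back as a
    # NumPy array, which is what get_weights() conventionally returns.
    return [keras.ops.convert_to_numpy(v) for v in optimizer.variables]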
self._initialize_sharded_states()


class TensorParallelOptimizer(optimizers.Optimizer):
It appears that TensorParallelOptimizer is mostly a wrapper around CoordinatedOptimizer. Any reason to keep them separate?
This PR introduces support for tensor parallelism autosharding in Keras, enabling users to shard large model layers across multiple devices. This is a crucial feature for training models that are too large to fit into the memory of a single accelerator.
The implementation is centered around two new components:
autoconfig.py: This module contains the logic to analyze a Keras model, identify sharding candidates (e.g., Dense, EinsumDense layers), and generate a sharding plan.
coordinated_optimizer.py: This is an optimizer wrapper that consumes the sharding plan. During training, it intercepts gradients for sharded variables and performs a collective AllReduce to ensure weight updates are correctly synchronized across all devices.
Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing
This is the 2nd (out of 4) PR for AutoSharding Keras.
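For illustration, a minimal sketch of the sharding-candidate idea that autoconfig.py is described as implementing (the helper name and plan format below are hypothetical, not the PR's actual API):

import keras

def naive_sharding_plan(model, world_size):
    # Hypothetical plan format: map layer name -> how to split its kernel.
    plan = {}
    for layer in model.layers:
        if isinstance(layer, (keras.layers.Dense, keras.layers.EinsumDense)):
            # Split the kernel's last (output) dimension across `world_size` devices.
            plan[layer.name] = {"kernel": ("split", -1, world_size)}
    return plan

model = keras.Sequential([
    keras.layers.Dense(256, name="up_proj"),
    keras.layers.Dense(64, name="down_proj"),
])
model.build((None, 64))
print(naive_sharding_plan(model, world_size=4))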