(1/n) Support 2D Parallelism #19846
Conversation
⚡ Required checks status: All passing 🟢

Groups summary:
- 🟢 pytorch_lightning: Tests workflow
- 🟢 pytorch_lightning: Azure GPU
- 🟢 pytorch_lightning: Benchmarks
- 🟢 fabric: Docs
- 🟢 pytorch_lightning: Docs
- 🟢 lightning_fabric: CPU workflow
- 🟢 lightning_fabric: Azure GPU
- 🟢 mypy
- 🟢 install

Thank you for your contribution! 💜
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files (Coverage Diff, master vs. #19846):

|          | master | #19846 |   +/- |
|----------|-------:|-------:|------:|
| Coverage |    84% |    59% |  -25% |
| Files    |    424 |    420 |    -4 |
| Lines    |  34702 |  34802 |  +100 |
| Hits     |  29097 |  20437 | -8660 |
| Misses   |   5605 |  14365 | +8760 |
TModel = TypeVar("TModel", bound=Module)


class ModelParallelStrategy(ParallelStrategy):
I find this name confusing. Data parallelism is not model parallelism, yet this class supports both. Can we think of something else? What about ManualParallelStrategy? It was the name proposed a million years ago: #11922
There is some internal discussion about another name; we will rename the strategy in a follow-up if we converge on a final decision.
fabric.launch()
assert fabric.strategy.device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
assert fabric.strategy.device_mesh.size(0) == 1
assert fabric.strategy.device_mesh.size(1) == 4
CI runs with 2 devices. This could introduce a deadlock if PyTorch adds a collective call in the device mesh. Could we change it to 2 just in case?
For 2D I need at least 2*2=4 devices.
What do you mean by "this could add a deadlock"? Where? I'm calling it on all ranks.
I didn't notice that you had min_cuda_gpus=4. Be aware then that this will not run on our CI (it gets skipped), because all our agents run with only 2 visible CUDA devices.
Yes, I'm aware. I was torn between doing this vs. simulating it on CPU. DTensor works on CPU, but I'd still need processes spawned for proper e2e testing, and we currently don't have the combination of standalone=True and min_gpus=x in the CI. I also wasn't convinced that CPU testing would be representative enough. Maybe we need to revisit this later, because it could be easy to miss updates to the tests.
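For reference, a minimal sketch of what the CPU-simulation alternative mentioned above could look like: four processes spawned with the gloo backend building the same 2D device mesh. The helper name, port, and mesh shape here are illustrative only and not part of this PR.

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.device_mesh import init_device_mesh


def _check_mesh(rank: int, world_size: int) -> None:
    # Hypothetical helper: build a ("data_parallel", "tensor_parallel") mesh on CPU
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("data_parallel", "tensor_parallel"))
    assert mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
    assert mesh.size(0) == 2 and mesh.size(1) == 2
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_check_mesh, args=(4,), nprocs=4)
```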
if dp_mesh.size() > 1:
    assert dp_mesh.ndim == 1  # Hybrid-sharding not supported

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
Isn't this meant to be encapsulated by fabric?
Since the user modifies the model directly here, probably not. In any case not for this iteration.
Worth leaving a comment then, because as a user I would be confused about the impact of doing this while also setting precision="something" in Fabric.
Right, I added a comment. We will need to do something about precision soon; let's brainstorm this for a follow-up.
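To make that point concrete, here is a rough sketch (not the PR's actual test code) of where such a MixedPrecisionPolicy typically lives inside a user-defined parallelization function, assuming PyTorch's FSDP2 fully_shard API (the import path can differ between PyTorch versions) and a hypothetical model.layers attribute. The policy is applied by FSDP itself, independently of the precision=... argument passed to Fabric.

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh


def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    dp_mesh = device_mesh["data_parallel"]
    if dp_mesh.size() > 1:
        assert dp_mesh.ndim == 1  # Hybrid-sharding not supported
        # This mixed-precision setting is applied by FSDP's fully_shard,
        # independently of Fabric's `precision` setting.
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
        for layer in model.layers:  # `layers` is a hypothetical attribute of the user's model
            fully_shard(layer, mesh=dp_mesh, mp_policy=mp_policy)
        fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)
    return model
```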
What does this PR do?
Adds a new ModelParallelStrategy that enables user-defined model parallelism. The emphasis here must be on user-defined: the strategy does not do anything to the model except set up the device mesh for the user to consume. It is the user's responsibility to correctly parse the device mesh and set up the parallelization in their model. This means applying TP, FSDP, activation checkpointing, etc.
See examples/fabric/tensor_parallel for a full example; a rough sketch is included below. Future PRs will add documentation pages.
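A minimal sketch of the intended usage, under the assumption that the strategy accepts a parallelize_fn callback plus data_parallel_size/tensor_parallel_size arguments (both mentioned in this PR) and that fabric.setup() invokes the callback with the device mesh; the FeedForward model and the w1/w2 plan are placeholders, not code from this PR.

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):  # placeholder model
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(32, 64)
        self.w2 = nn.Linear(64, 32)

    def forward(self, x):
        return self.w2(self.w1(x).relu())


def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
    # The strategy only builds the device mesh; slicing it and applying TP/FSDP is the user's job.
    tp_mesh = device_mesh["tensor_parallel"]
    plan = {"w1": ColwiseParallel(), "w2": RowwiseParallel()}
    return parallelize_module(model, tp_mesh, plan)


if __name__ == "__main__":
    strategy = ModelParallelStrategy(parallelize, data_parallel_size=1, tensor_parallel_size=4)
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()
    model = fabric.setup(FeedForward())  # assumed to apply `parallelize` to the model
```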
What's not supported yet?

- parallelize_fn will be replaced by a hook in the LightningModule.
- data_parallel_size accepting a tuple.

cc @Borda @carmocca @justusschock @awaelchli