
(1/n) Support 2D Parallelism #19846

Merged: awaelchli merged 16 commits into master from examples/tp on May 7, 2024

Conversation

awaelchli (Contributor) commented May 4, 2024

What does this PR do?

Adds a new ModelParallelStrategy that enables user-defined model parallelism.

def parallelize_my_model(model, device_mesh):
    # User-defined function that applies the desired parallelizations specific to the model
    # (TP, FSDP2, activation checkpointing, ...)
    ...


strategy = ModelParallelStrategy(
    parallelize_fn=parallelize_my_model,
    # Define the size of the 2D parallelism
    # Set to "auto" to apply TP intra-node and DP inter-node
    data_parallel_size=2,
    tensor_parallel_size=2,
)

fabric = L.Fabric(..., strategy=strategy)
fabric.launch()

# 1. Initializes the device mesh
# 2. Runs `parallelize_fn` here
# 3. Calls `.to_empty()` if model is on meta-device
# 4. Calls `.reset_parameters()` on submodules
model = fabric.setup(model)
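
As a hedged sketch of how this pairs with meta-device initialization (assuming Fabric's init_module; MyModel is a placeholder class):

# Optional: instantiate on the meta device so no memory is allocated up front
with fabric.init_module(empty_init=True):
    model = MyModel()

model = fabric.setup(model)  # materializes, parallelizes, and resets parameters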

The emphasis here must be on user-defined. The strategy does not do anything to the model except set up the device mesh for the user to consume. It is the user's responsibility to correctly parse the device mesh and set up the parallelization in their model. This means applying TP, FSDP, activation checkpointing, etc.
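
For concreteness, a minimal sketch of such a function, assuming a model with two linear submodules named w1 and w2 (the submodule names and the parallel plan are hypothetical; import paths are as of PyTorch 2.3, and the mesh dimension names match the ones this strategy creates):

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

def parallelize_my_model(model, device_mesh):
    # Tensor parallelism over the "tensor_parallel" mesh dimension
    tp_mesh = device_mesh["tensor_parallel"]
    parallelize_module(model, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

    # FSDP2 (fully_shard) over the "data_parallel" mesh dimension
    dp_mesh = device_mesh["data_parallel"]
    if dp_mesh.size() > 1:
        fully_shard(model, mesh=dp_mesh)

    return model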

See examples/fabric/tensor_parallel for a full example.

Future PRs will add documentation pages.

What's not supported yet?

  • Not all checkpoint features are implemented. Only saving and loading distributed checkpoints is supported right now. Future PRs will add the missing code.
  • Mixed precision with a grad scaler (16-mixed) is not supported. Only bf16-mixed and bf16-true are supported. Future PRs will add the missing code.
  • Trainer: Future PRs will implement the same strategy for the PL Trainer, where the parallelize_fn will be replaced by a hook in the LightningModule.
  • HSDP (requires an additional dimension of the device mesh). Could be exposed via an additional optional argument or by making data_parallel_size accept a tuple; see the sketch after this list.
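
For illustration, a hedged sketch of the extra mesh dimension HSDP would need, using PyTorch's init_device_mesh (the shapes and dimension names here are hypothetical):

from torch.distributed.device_mesh import init_device_mesh

# 3D mesh on 8 GPUs: replicate across 2 groups, shard within groups of 2,
# and tensor-parallel within each shard group
mesh = init_device_mesh(
    "cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "tensor_parallel"),
)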

cc @Borda @carmocca @justusschock @awaelchli

@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package labels May 4, 2024
@awaelchli awaelchli added feature Is an improvement or enhancement strategy and removed fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package labels May 4, 2024
@awaelchli awaelchli added this to the 2.3 milestone May 4, 2024
@github-actions github-actions bot added fabric lightning.fabric.Fabric pl Generic label for PyTorch Lightning package labels May 4, 2024
@awaelchli awaelchli marked this pull request as ready for review May 4, 2024 16:22

github-actions bot commented May 4, 2024

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 2.0, oldest) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.10, 2.1) success
pl-cpu (macOS-11, lightning, 3.10, 2.2) success
pl-cpu (macOS-14, lightning, 3.10, 2.3) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.3) success
pl-cpu (windows-2022, lightning, 3.8, 2.0, oldest) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.10, 2.1) success
pl-cpu (windows-2022, lightning, 3.10, 2.2) success
pl-cpu (windows-2022, lightning, 3.10, 2.3) success
pl-cpu (macOS-11, pytorch, 3.8, 2.0) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.0) success
pl-cpu (windows-2022, pytorch, 3.8, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (macOS-12, pytorch, 3.11, 2.1) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py, src/lightning/pytorch/strategies/fsdp.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) (testing Lightning | latest) success
pytorch-lightning (GPUs) (testing PyTorch | latest) success

These checks are required after the changes to src/lightning/pytorch/strategies/fsdp.py, src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py, src/lightning/pytorch/strategies/fsdp.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/strategies/fsdp.py, docs/source-pytorch/conf.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 2.0, oldest) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.11, 2.1) success
fabric-cpu (macOS-11, lightning, 3.11, 2.2) success
fabric-cpu (macOS-14, lightning, 3.10, 2.3) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2) success
fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3) success
fabric-cpu (windows-2022, lightning, 3.8, 2.0, oldest) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.11, 2.1) success
fabric-cpu (windows-2022, lightning, 3.11, 2.2) success
fabric-cpu (windows-2022, lightning, 3.11, 2.3) success
fabric-cpu (macOS-11, fabric, 3.8, 2.0) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.0) success
fabric-cpu (windows-2022, fabric, 3.8, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (macOS-12, fabric, 3.11, 2.1) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.1) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py, tests/tests_fabric/strategies/test_fsdp.py, tests/tests_fabric/strategies/test_fsdp_integration.py, tests/tests_fabric/strategies/test_model_parallel.py, tests/tests_fabric/strategies/test_model_parallel_integration.py, tests/tests_fabric/utilities/test_init.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) (testing Fabric | latest) success
lightning-fabric (GPUs) (testing Lightning | latest) success

These checks are required after the changes to examples/fabric/tensor_parallel/data.py, examples/fabric/tensor_parallel/model.py, examples/fabric/tensor_parallel/parallelism.py, examples/fabric/tensor_parallel/train.py, src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py, tests/tests_fabric/strategies/test_fsdp.py, tests/tests_fabric/strategies/test_fsdp_integration.py, tests/tests_fabric/strategies/test_model_parallel.py, tests/tests_fabric/strategies/test_model_parallel_integration.py, tests/tests_fabric/utilities/test_init.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py, src/lightning/pytorch/strategies/fsdp.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/__init__.py, src/lightning/fabric/strategies/fsdp.py, src/lightning/fabric/strategies/model_parallel.py, src/lightning/fabric/utilities/init.py, src/lightning/pytorch/strategies/fsdp.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

codecov bot commented May 4, 2024

Codecov Report

Attention: Patch coverage is 90.69767%, with 20 lines in your changes missing coverage. Please review.

Project coverage is 59%. Comparing base (0f12271) to head (0d9afe8).

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #19846     +/-   ##
=========================================
- Coverage      84%      59%    -25%     
=========================================
  Files         424      420      -4     
  Lines       34702    34802    +100     
=========================================
- Hits        29097    20437   -8660     
- Misses       5605    14365   +8760     

@github-actions github-actions bot added the docs Documentation related label May 4, 2024
TModel = TypeVar("TModel", bound=Module)


class ModelParallelStrategy(ParallelStrategy):
Contributor:

I find this name confusing. Data-parallelism is not model parallelism yet this class supports both. Can we think of something else?

What about ManualParallelStrategy? It was the name proposed a million years ago: #11922

Contributor (Author):


There has been some internal discussion about alternative names; we will rename the strategy in a follow-up if we converge on a final decision.

src/lightning/fabric/strategies/model_parallel.py: two outdated review threads (resolved)
fabric.launch()
assert fabric.strategy.device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel")
assert fabric.strategy.device_mesh.size(0) == 1
assert fabric.strategy.device_mesh.size(1) == 4
Contributor:


CI runs with 2 devices. This could introduce a deadlock if PyTorch adds a collective call in the device mesh. Could we change it to 2 just in case?

Contributor (Author):


For 2D I need at least 2 * 2 = 4 devices.
What do you mean, this could add a deadlock? Where? I'm calling it on all ranks.

Contributor:


I didn't notice that you had min_cuda_gpus=4.

Be aware then that this will not run on our CI (gets skipped) because all our agents run with only 2 visible CUDA devices.

https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=201341&view=logs&j=3f274fac-2e11-54ca-487e-194c91f3ae9f&t=8e4ceb7c-ceed-5ee0-b6fc-c9023e41cb74&l=1306

awaelchli (Author) commented May 7, 2024:


Yes, I'm aware. I was torn between doing this and simulating it on CPU. DTensor works on CPU, but I'd still need processes spawned for proper end-to-end testing, and we currently don't have the combination of standalone=True and min_gpus=x in the CI. Also, I wasn't convinced that CPU testing would be representative enough. Maybe we need to revisit this later, because it could be easy to miss updates to these tests.
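
For reference, a sketch of the test guard under discussion, assuming Lightning's internal RunIf helper (the test name and import path are illustrative):

from tests_fabric.helpers.runif import RunIf

@RunIf(min_cuda_gpus=4, standalone=True)  # skipped on CI agents with only 2 visible GPUs
def test_device_mesh_2d():
    ...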

if dp_mesh.size() > 1:
    assert dp_mesh.ndim == 1  # Hybrid-sharding not supported

# FSDP2 mixed-precision policy: parameters cast to bf16 for compute, gradients reduced in fp32
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
Contributor:


Isn't this meant to be encapsulated by Fabric?

Contributor (Author):


Since the user modifies the model directly here, probably not. In any case not for this iteration.

Contributor:


Worth leaving a comment then, because as a user I would be confused about how this interacts with setting precision="something" in Fabric.

awaelchli (Author) commented May 7, 2024:


Right, I added a comment. We will need to do something about precision soon; let's brainstorm this for a follow-up.
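
For context, a hedged sketch of how such a policy is applied in user code, independent of the precision argument passed to Fabric (assuming FSDP2's fully_shard, as in the example):

# Note: this policy is applied to the model directly by the user and is
# independent of the `precision` setting passed to Fabric
fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)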

@awaelchli awaelchli requested a review from carmocca May 7, 2024 07:38
@mergify mergify bot added the ready PRs ready to be merged label May 7, 2024
@awaelchli awaelchli merged commit 0c8a193 into master May 7, 2024
117 of 118 checks passed
@awaelchli awaelchli deleted the examples/tp branch May 7, 2024 21:02
Labels: docs (Documentation related), fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement), pl (Generic label for PyTorch Lightning package), ready (PRs ready to be merged), strategy

3 participants