
Wraps sharded model for proper access to its state_dict in FSDP strategy #16558

Merged
merged 46 commits into Lightning-AI:master on Apr 17, 2023

Conversation

SpirinEgor
Contributor

What does this PR do?

Fixes #16526 by following the approach of the previously deleted DDPFullyShardedStrategy.

Does your PR introduce any breaking changes?

No, it doesn't.
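A minimal sketch of the key-mismatch problem this addresses. The wrapper class below is a hypothetical stand-in for Lightning's internal forward wrapper, not the actual implementation; the point is only that wrapping the user's module before sharding changes the state_dict keys, so checkpoints no longer load back into the plain LightningModule:

```python
import torch


class UserModel(torch.nn.Module):
    """Stand-in for a user's LightningModule."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)


class ForwardWrapper(torch.nn.Module):
    """Hypothetical stand-in for the internal wrapper placed around the module."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module


plain_keys = set(UserModel().state_dict().keys())                    # {"layer.weight", "layer.bias"}
wrapped_keys = set(ForwardWrapper(UserModel()).state_dict().keys())  # {"module.layer.weight", ...}
assert plain_keys != wrapped_keys  # the strategy has to bridge this gap when saving/loading
```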

@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Jan 30, 2023
Contributor

@carmocca carmocca left a comment


Thanks for looking into this!

This is only an issue in master, not 1.9, correct? In that case, this doesn't need a CHANGELOG entry.

Can you write a test?

src/pytorch_lightning/strategies/fsdp.py (inline review comment, outdated and resolved)
@SpirinEgor
Contributor Author

SpirinEgor commented Jan 30, 2023

This is only an issue in master, not 1.9, correct?

No, I ran into this in 1.9 with the fsdp_native strategy. On master it is renamed to fsdp, but without the fixes from the now-removed ddp_fully_sharded strategy.

@carmocca carmocca added this to the v1.9.x milestone Jan 30, 2023
@carmocca carmocca added the bug (Something isn't working) and strategy: fsdp (Fully Sharded Data Parallel) labels Jan 30, 2023
@SpirinEgor
Contributor Author

SpirinEgor commented Jan 31, 2023

@carmocca I wrote tests for this. I noticed that the problem actually occurs only for layers that are not wrapped, e.g. small layers, so I parametrized the tests with different wrapping policies and slightly changed BoringModel so that the asserts are meaningful (see the sketch below).
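A minimal sketch of how such a parametrization could look; the policy threshold, strategy arguments, and assertion are assumptions for illustration, not the PR's actual test code:

```python
import functools

import pytest
import torch
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import FSDPStrategy


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires 2 GPUs")
@pytest.mark.parametrize(
    "auto_wrap_policy",
    [
        None,  # only the root module is sharded; inner layers stay unwrapped
        functools.partial(size_based_auto_wrap_policy, min_num_params=1),  # wrap everything
    ],
)
def test_fsdp_state_dict_keys(tmp_path, auto_wrap_policy):
    model = BoringModel()
    expected_keys = set(model.state_dict().keys())

    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=2,
        strategy=FSDPStrategy(auto_wrap_policy=auto_wrap_policy),
        fast_dev_run=True,
    )
    trainer.fit(model)

    # Gathering the state dict is a collective call, so every rank runs it;
    # only rank 0 receives (and checks) the full set of keys.
    state_dict = trainer.strategy.lightning_module_state_dict()
    if trainer.global_rank == 0:
        assert set(state_dict.keys()) == expected_keys
```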

My build has one failing test, but it doesn't look related to my changes. What should I do?

And do you have any suggestions about modifying _LightningModuleWrapperBase?

@SpirinEgor SpirinEgor requested a review from carmocca January 31, 2023 14:39
@mergify mergify bot added the has conflicts label Feb 1, 2023
@mergify mergify bot removed the has conflicts label Feb 2, 2023
@mergify mergify bot removed the has conflicts label Apr 13, 2023
@github-actions github-actions bot removed the ci (Continuous Integration), app (removed), and fabric (lightning.fabric.Fabric) labels Apr 13, 2023
@SpirinEgor SpirinEgor requested a review from awaelchli April 14, 2023 15:35
@SpirinEgor
Contributor Author

@awaelchli I implemented what we discussed. For now, FSDP always aggregates the full state dict on rank zero. CPU offloading depends on the CPUOffload setting of the initialized strategy. I'm not sure if this is okay, but it was enough to pass the tests 😅 Setting cpu_offload=True in FullStateDictConfig led to errors on CI, but not in my environment (details).
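A rough sketch of the logic described above; the attribute names (self.model for the FSDP-wrapped root, self.cpu_offload for the strategy's CPUOffload setting) are assumptions for illustration, not the exact code in this PR:

```python
from torch.distributed.fsdp import CPUOffload, FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class _FSDPStrategySketch:
    """Not the real FSDPStrategy; just enough structure to show the idea."""

    def __init__(self, model, cpu_offload: CPUOffload) -> None:
        self.model = model              # the FSDP-wrapped root module
        self.cpu_offload = cpu_offload  # the strategy's CPUOffload configuration

    def lightning_module_state_dict(self) -> dict:
        # Always consolidate the full state dict on rank zero; whether the
        # gathered tensors are moved to CPU follows the strategy's CPUOffload
        # setting (the reuse questioned in the follow-up comment below).
        config = FullStateDictConfig(
            offload_to_cpu=self.cpu_offload.offload_params,
            rank0_only=True,
        )
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, config):
            return self.model.state_dict()
```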

@SpirinEgor
Contributor Author

I rethought the logic around offloading to the CPU. It's not good to reuse this variable, since it's intended for a completely different purpose. We need to figure out why this doesn't work on CI, because it does work on my setup (2x A100).

@awaelchli
Contributor

I'm looking into it!

@awaelchli awaelchli added the community (This PR is from the community) label Apr 17, 2023
@Borda Borda merged commit bb4e495 into Lightning-AI:master Apr 17, 2023
@mergify mergify bot added the ready (PRs ready to be merged) label Apr 17, 2023
Labels
bug (Something isn't working), community (This PR is from the community), pl (Generic label for PyTorch Lightning package), ready (PRs ready to be merged), strategy: fsdp (Fully Sharded Data Parallel)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Proper extraction of state_dict for fsdp strategy
4 participants