
Migrate distributed state dict API #2138

Open — wants to merge 19 commits into base: main
Conversation

@mori360 (Contributor) commented Dec 10, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Migrate distributed state dict APIs from torch.distributed.

Changelog

What are the changes made in this PR?

Switch to distributed state dict APIs from torch.distributed.

  • load_from_full_model_state_dict <- set_model_state_dict
  • gather_cpu_state_dict <- get_model_state_dict
  • load_from_full_optimizer_state_dict <- set_optimizer_state_dict
  • get_full_optimizer_state_dict <- get_optimizer_state_dict

To align the inputs, add a `model` input to `get_full_optimizer_state_dict` and `load_from_full_optimizer_state_dict`.
Change the `sharded_sd` input of `gather_cpu_state_dict` to `model`.

TODO:
NF4Tensor handling is kept unchanged; supporting it remains future work.

Test plan

pytest tests/torchtune/training/test_distributed.py
pytest tests -m integration_test

We compared runs with the previous API and the new API: losses are identical both on initial loading and when resuming from a checkpoint.

We also captured memory traces; the results show that the new API's peak memory does not exceed that of the current one.


pytorch-bot bot commented Dec 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2138

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3d0d26f with merge base 002b17c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 10, 2024
@joecummings joecummings added the distributed Anything related to distributed env (multi-GPU, multi-node) label Dec 10, 2024
@codecov-commenter commented
Codecov Report

Attention: Patch coverage is 3.38983% with 57 lines in your changes missing coverage. Please review.

Project coverage is 65.26%. Comparing base (f2bd4bc) to head (8b575be).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/training/_distributed.py 3.50% 55 Missing ⚠️
tests/torchtune/training/test_distributed.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2138       +/-   ##
==========================================
+ Coverage   9.33%   65.26%   +55.93%     
==========================================
  Files        289      334       +45     
  Lines      16959    19192     +2233     
==========================================
+ Hits        1583    12526    +10943     
+ Misses     15376     6666     -8710     

☔ View full report in Codecov by Sentry.

@mori360 mori360 changed the title Mitigate distributed state dict API Migrate distributed state dict API Dec 18, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Dec 19, 2024
…ept 2 device type and optimize memory (#142845)

For the distributed state dict API [migration](pytorch/torchtune#2138), make the following changes here:
1. `load_from_full_model_state_dict` in torchtune calls `set_model_state_dict` with options controlling whether to use cpu_offload. Add `cpu_offload` to `_load_model_state_dict` to move tensors to CPU when the config is True.
2. Change the device check, as lora_finetune might have 2 device types; accept that as valid.
3. Some changes to optimize memory performance:
3.1 use `.detach().clone()` instead of a view directly
3.2 if `local_state` is not on meta, copy `full_tensor[slices]` into `ret.to_local()`
4. add relevant unit tests
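The memory optimization in 3.1 can be illustrated with a small, hypothetical sketch; plain tensors stand in for the DTensor locals used in the actual code.

```python
import torch

# 3.1: slicing returns a view that keeps the whole source tensor's storage
# alive; .detach().clone() materializes an independent copy, so the full
# tensor can be freed earlier, lowering peak memory.
full_tensor = torch.arange(8.0)
view_chunk = full_tensor[:4]                    # shares storage with full_tensor
owned_chunk = full_tensor[:4].detach().clone()  # owns its own storage
assert view_chunk.data_ptr() == full_tensor.data_ptr()
assert owned_chunk.data_ptr() != full_tensor.data_ptr()

# 3.2 (illustrated with plain tensors): when the local shard is already
# materialized, copy the relevant slice of the full tensor into it in place
# rather than allocating a new tensor for it.
local_shard = torch.empty(4)
local_shard.copy_(full_tensor[:4])
```

The trade-off is a transient extra copy during `.detach().clone()`, in exchange for not pinning the entire full tensor for the lifetime of the shard.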

Memory performance calling from TorchTune with llama2/7B_full:
1. cpu_offload = True
![Memory trace, cpu_offload = True](https://github.com/user-attachments/assets/429261f5-1107-4592-b295-de3944a2614b)

2. cpu_offload = False
![Memory trace, cpu_offload = False](https://github.com/user-attachments/assets/40bf281a-236a-4218-826b-b1192a10c806)

Pull Request resolved: #142845
Approved by: https://github.com/fegin
```python
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
```
Contributor
How can we get view support for NF4?

cc @andrewor14

Contributor Author
Thank you for the review. We currently skip the NF4 tensor part and plan to support NF4 next quarter.

Contributor
Looks like there's already view support for NF4Tensor? What's the error you're getting?

also cc @drisspg @weifengpy

@mori360 mori360 marked this pull request as ready for review December 20, 2024 23:57
@mori360 mori360 requested a review from joecummings December 20, 2024 23:58
5 participants