Migrate distributed state dict API #2138
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2138
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3d0d26f with merge base 002b17c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

Coverage diff vs `main` (#2138):

|          | main  | #2138  | +/-     |
|----------|-------|--------|---------|
| Coverage | 9.33% | 65.26% | +55.93% |
| Files    | 289   | 334    | +45     |
| Lines    | 16959 | 19192  | +2233   |
| Hits     | 1583  | 12526  | +10943  |
| Misses   | 15376 | 6666   | -8710   |

☔ View full report in Codecov by Sentry.
…ept 2 device type and optimize memory (#142845)

For the distributed state dict API [migration](pytorch/torchtune#2138), this makes the following changes:

1. `load_from_full_model_state_dict` in torchtune calls `set_model_state_dict` with an option controlling cpu_offload. Add cpu_offload to `_load_model_state_dict` so tensors are processed on CPU when the config is True.
2. Relax the device check: lora_finetune may use 2 device types, so accept that as valid.
3. Optimize memory performance:
   - use `.detach().clone()` instead of a view directly
   - if local_state is not meta, copy `full_tensor[slices]` into `ret.to_local()`
4. Add the corresponding unit tests.

Memory performance calling from torchtune with llama2/7B_full:

1. cpu_offload = True: [memory trace screenshot](https://github.com/user-attachments/assets/429261f5-1107-4592-b295-de3944a2614b)
2. cpu_offload = False: [memory trace screenshot](https://github.com/user-attachments/assets/40bf281a-236a-4218-826b-b1192a10c806)

Pull Request resolved: #142845
Approved by: https://github.com/fegin
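A minimal sketch of what item 1 above could look like on the torchtune side, assuming a simplified signature (the real `load_from_full_model_state_dict` takes more arguments; this is illustrative, not the actual implementation):

```python
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)


def load_from_full_model_state_dict(
    model: nn.Module,
    full_sd: dict,
    cpu_offload: bool = False,
):
    # full_state_dict=True: `full_sd` holds unsharded tensors.
    # cpu_offload=True: copies are staged through CPU, which is the
    # behavior item 1 enables inside _load_model_state_dict.
    options = StateDictOptions(
        full_state_dict=True,
        cpu_offload=cpu_offload,
        strict=True,
    )
    return set_model_state_dict(model, model_state_dict=full_sd, options=options)
```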
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# TODO: change to from_local API (need to add view support for NF4)
How can we get view support for NF4?
cc @andrewor14
Thank you for the review. We currently skip the NF4 tensor part and plan to support NF4 next quarter.
Looks like there's already view support for NF4Tensor? What's the error you're getting?
also cc @drisspg @weifengpy
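For context, a minimal sketch of the `from_local` path the TODO refers to, for regular (non-NF4) tensors. `chunk`, `full_tensor`, and `device_mesh` are assumed to be in scope as in the surrounding code, and the `Shard(0)` placement is an assumption about the sharding layout, not taken from the PR:

```python
from torch.distributed.tensor import DTensor, Shard  # torch.distributed._tensor on older releases

# Current approach from the diff: copy the local chunk into a padded buffer.
sharded_param = full_tensor.new_zeros(chunk.size())
sharded_param[: chunk.size(0)].copy_(chunk)

# The from_local alternative wraps the local shard as a DTensor directly;
# doing the same for NF4 would need the view support discussed above.
dtensor = DTensor.from_local(
    sharded_param,
    device_mesh,            # mesh of the FSDP ranks (assumed in scope)
    placements=[Shard(0)],  # dim-0 sharding, assumed to match the FSDP layout
)
```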
Context
What is the purpose of this PR?
Migrate to the distributed state dict APIs from torch.distributed.
Changelog
What are the changes made in this PR?

Switch to the distributed state dict APIs from torch.distributed (see the sketch after this list):

- `load_from_full_model_state_dict` <- `set_model_state_dict`
- `gather_cpu_state_dict` <- `get_model_state_dict`
- `load_from_full_optimizer_state_dict` <- `set_optimizer_state_dict`
- `get_full_optimizer_state_dict` <- `get_optimizer_state_dict`

To align the inputs, add a `model` input to `get_full_optimizer_state_dict` and `load_from_full_optimizer_state_dict`, and change the `sharded_sd` input of `gather_cpu_state_dict` to `model`.

TODO: NF4 tensors are kept on the current path; NF4 support remains future work.
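A hedged sketch of how the remaining wrappers might map onto the torch.distributed.checkpoint.state_dict APIs; the signatures below are simplified assumptions, not the exact torchtune ones:

```python
import torch.nn as nn
import torch.optim as optim
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)


def gather_cpu_state_dict(model: nn.Module) -> dict:
    # Takes the model itself now (previously a sharded state dict) and lets
    # get_model_state_dict handle the gather and CPU offload.
    opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_model_state_dict(model, options=opts)


def get_full_optimizer_state_dict(model: nn.Module, opt: optim.Optimizer) -> dict:
    # `model` is the newly added input mentioned in the changelog.
    opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_optimizer_state_dict(model, opt, options=opts)


def load_from_full_optimizer_state_dict(
    model: nn.Module, opt: optim.Optimizer, full_sd: dict
) -> None:
    opts = StateDictOptions(full_state_dict=True)
    set_optimizer_state_dict(model, opt, optim_state_dict=full_sd, options=opts)
```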
nf4tensor are kept the same, remain as future work
Test plan
- `pytest tests/torchtune/training/test_distributed.py`
- `pytest tests -m integration_test`

We compared runs with the previous API and the new API: the losses are identical both on initial loading and when resuming from checkpoint.
We also captured memory traces (see the sketch below); the results show that the new API does not increase peak memory compared with the current one.
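One way the memory traces could be captured, assuming PyTorch's built-in CUDA memory snapshot tooling (the actual runs may have used a different mechanism; the file name is illustrative):

```python
import torch

# Start recording allocator events (with stack traces) before the run.
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run one training step with the old API, then again with the new API ...

# Dump a snapshot that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("state_dict_migration_memory.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```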