fix bug with float8 training + FSDP2 + TP #1327
Conversation
Dr. CI: ✅ No failures as of commit f8184c0 with merge base 7489c7d. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1327. Note: links to docs will display an error until the docs builds have completed. There is 1 currently active SEV. This comment was automatically generated by Dr. CI and updates every 15 minutes.
force-pushed from f97f5ac to 859c70f
force-pushed from 859c70f to f8184c0
Review comment on `def _gather_along_first_dim(input_: torch.Tensor):` in `distributed_utils.py`:
all of this was dead code, deleting
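For context, here is a rough sketch of the kind of pre-DTensor helper being deleted; it is an illustration based on the function name and signature only, not the file's actual contents, and DTensor now handles this kind of collective internally.

```python
import torch
import torch.distributed as dist


def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor:
    # Illustrative only: all-gather each rank's shard and concatenate
    # the shards along dim 0.
    world_size = dist.get_world_size()
    if world_size == 1:
        return input_
    output = torch.empty(
        (world_size * input_.shape[0], *input_.shape[1:]),
        dtype=input_.dtype,
        device=input_.device,
    )
    dist.all_gather_into_tensor(output, input_.contiguous())
    return output
```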
Summary:

Fixes #1313

The combination of float8 training + FSDP2 + TP recently broke, fixing:

1. add a test case so we have this covered in CI
2. fix the bug, by ensuring we check for `Float8Tensor` properly when it is wrapped in `DTensor` (a hedged sketch of this kind of check follows below)
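A minimal sketch of the kind of check the fix requires: a plain `isinstance` returns False for a DTensor-wrapped Float8Tensor unless the local shard is inspected. This is not the PR's exact code; `_is_float8_like` is a hypothetical helper, and the `Float8Tensor` import path is an assumption about torchao's float8 module layout.

```python
import torch
from torch.distributed.tensor import DTensor

from torchao.float8.float8_tensor import Float8Tensor  # assumed import path


def _is_float8_like(t: torch.Tensor) -> bool:
    # Plain Float8Tensor: the direct isinstance check is enough.
    if isinstance(t, Float8Tensor):
        return True
    # DTensor(Float8Tensor): DTensor is its own tensor subclass and hides the
    # inner type, so unwrap to the local shard before checking.
    if isinstance(t, DTensor):
        return isinstance(t._local_tensor, Float8Tensor)
    return False
```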
Note 1: most of the code in `distributed_utils.py` was dead code from before we switched to DTensor, so I deleted it in this PR.

Note 2: we already have extensive testing for FSDP2 and TP/SP in separate files. I chose to create a new file for testing those two features together to keep complexity and test runtime manageable.

Note 3: we really should make these distributed test cases run in CI; right now it's still local testing only.

Note 4: there are a couple of future follow-ups which would be interesting:

- in FSDP2 with float8-all-gather, perhaps we should return DTensor(Float8Tensor) instead of Float8Tensor, to stay consistent with how FSDP2 wraps weights without float8-all-gather
- in DTensor, it would be nice if `isinstance(t, Float8Tensor)` returned True if `t` is a DTensor wrapping a Float8Tensor - food for thought for composability. Having this would enable us to simplify some of the float8 modeling code.

Test Plan:

```
// tests added in this PR
./test/float8/test_dtensor.sh

// all tests
./test/float8/test_everything.sh

// torchtitan command fails before this PR and passes after
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --training.tensor_parallel_degree 2
```
Reviewers:
Subscribers:
Tasks:
Tags:
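For readers unfamiliar with how the pieces in Note 2 compose, below is a rough, hedged sketch of combining float8 training with TP and FSDP2 in the usual order (convert to float8, then TP, then FSDP2). It is not the PR's test code; the torchao import paths (`convert_to_float8_training`, `Float8ColwiseParallel`, `Float8RowwiseParallel`) are assumptions about the library layout, and the script needs at least 4 GPUs under torchrun.

```python
# Hedged sketch, not the PR's test: compose float8 training with TP + FSDP2.
# Launch with: torchrun --nproc_per_node 4 this_file.py
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchao.float8 import convert_to_float8_training  # assumed path
from torchao.float8.float8_tensor_parallel import (  # assumed path
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def build_model(mesh_2d):
    model = nn.Sequential(
        nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)
    ).cuda()
    # 1. swap nn.Linear modules for float8 training linears
    convert_to_float8_training(model)
    # 2. tensor parallelism on the inner ("tp") mesh dimension
    parallelize_module(
        model,
        mesh_2d["tp"],
        {"0": Float8ColwiseParallel(), "2": Float8RowwiseParallel()},
    )
    # 3. FSDP2 (per-parameter sharding) on the outer ("dp") mesh dimension
    fully_shard(model, mesh=mesh_2d["dp"])
    return model


if __name__ == "__main__":
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    model = build_model(mesh)
    out = model(torch.randn(8, 256, device="cuda"))
    out.sum().backward()
```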