
fix bug with float8 training + FSDP2 + TP #1327

Merged
vkuzo merged 1 commit into main from 20241122_fix_float8_tp_fsdp_bug on Nov 22, 2024

Conversation

@vkuzo (Contributor) commented Nov 22, 2024

Summary:

Fixes #1313

The combination of float8 training + FSDP2 + TP recently broke; this PR fixes it:

  1. add a test case so we have this covered in local testing scripts.
  2. fix the bug by ensuring we check for Float8Tensor properly when it is wrapped in DTensor (a minimal sketch of the check follows below).
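
The check in question looks roughly like the sketch below. This is a minimal illustration rather than the exact code from this PR; the helper name is hypothetical, and it assumes `Float8Tensor` is importable from `torchao.float8.float8_tensor` and that the public `DTensor` API in `torch.distributed.tensor` is available.

```python
import torch
from torch.distributed.tensor import DTensor

# Assumed import path for torchao's float8 tensor subclass.
from torchao.float8.float8_tensor import Float8Tensor


def _is_float8_tensor(t: torch.Tensor) -> bool:
    # Hypothetical helper: under FSDP2 + TP a weight can be a DTensor whose
    # local shard is a Float8Tensor, so unwrap it before the subclass check.
    if isinstance(t, DTensor):
        t = t.to_local()
    return isinstance(t, Float8Tensor)
```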

Note 1: most of the code in distributed_utils.py was dead code from before we switched to DTensor, so I deleted it in this PR.
Note 2: we already have extensive testing for FSDP2 and TP/SP in separate files. I chose to create a new file for testing those two features together to keep complexity and test runtime manageable.
Note 3: we really should make these distributed test cases run in CI; right now they are still local-only.
Note 4: there are a couple of future follow-ups which would be interesting:

  • in FSDP2 with float8-all-gather, perhaps we should return DTensor(Float8Tensor) instead of Float8Tensor, to stay consistent with how FSDP2 wraps weights without float8-all-gather
  • in DTensor, it would be nice if `isinstance(t, Float8Tensor)` returned True when `t` is a DTensor wrapping a Float8Tensor - food for thought for composability. Having this would enable us to simplify some of the float8 modeling code (see the sketch after this list).
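
To make the composability point concrete, a hypothetical before/after is sketched below; it reuses the imports from the earlier sketch, the function name is illustrative, and the simplified form does not work today.

```python
# Today: modeling code has to special-case DTensor before the subclass check.
def weight_is_float8(w: torch.Tensor) -> bool:
    inner = w.to_local() if isinstance(w, DTensor) else w
    return isinstance(inner, Float8Tensor)

# With the follow-up above, the same intent could collapse to
#     isinstance(w, Float8Tensor)
# returning True even when w is a DTensor wrapping a Float8Tensor.
```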

Test Plan:

```
// tests added in this PR
./test/float8/test_dtensor.sh

// all tests
./test/float8/test_everything.sh

// torchtitan command fails before this PR and passes after
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --training.tensor_parallel_degree 2
```

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot (bot) commented Nov 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1327

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit f8184c0 with merge base 7489c7d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Nov 22, 2024
@vkuzo added the topic: bug fix label on Nov 22, 2024
@vkuzo force-pushed the 20241122_fix_float8_tp_fsdp_bug branch from f97f5ac to 859c70f on November 22, 2024 at 17:50
@vkuzo force-pushed the 20241122_fix_float8_tp_fsdp_bug branch from 859c70f to f8184c0 on November 22, 2024 at 17:53


def _gather_along_first_dim(input_: torch.Tensor):

vkuzo (Contributor, Author) commented: all of this was dead code, deleting

@vkuzo vkuzo merged commit 8f73e84 into main Nov 22, 2024
18 checks passed
sunjiweiswift pushed a commit to sunjiweiswift/ao that referenced this pull request Nov 25, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Labels
CLA Signed · topic: bug fix

Projects
None yet

Development
Successfully merging this pull request may close these issues:

attempting to run aten.abs.default, this is not supported with latest torchtitan + torchao

4 participants