fix tensor parallelism for float8 training with rowwise scaling #1718
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1718

Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM, left a couple minor comments
test/float8/test_dtensor.py (Outdated)
```diff
@@ -196,14 +213,25 @@ def _test_fp8_mlp_tensor_parallelism_base(
     sp_model = copy.deepcopy(toy_model)
     sp_model = convert_to_float8_training(sp_model, config=config)

+    # for tensorwise scaling, enable float8 all_gather
+    # for rowwise scaling, keep high precision all_gather
```
Can we expand this comment to explain the reasoning behind this (why fp8 all gather for tensorwise and HP all gather for rowwise)?
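For context while that comment is being expanded, here is the reasoning as I understand it, with a minimal illustration (my sketch, not code from the PR): tensorwise scaling uses a single scale for the whole tensor, so shards cast to float8 with that scale before the all-gather match the corresponding slices of the gathered tensor cast afterwards; rowwise scales are per-row amaxes, which in general change once shards are concatenated along a non-row dimension, so the gather has to happen first, in high precision.

```python
# Minimal sketch (not PR code): tensorwise scales are shard-invariant,
# rowwise scales generally are not when sharding cuts across rows.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)

# tensorwise: one amax for the whole tensor; shards cast with the
# full-tensor scale are consistent, so float8 all_gather is safe
tensorwise_amax = x.abs().max()

# rowwise: one amax per row along the last dim
rowwise_amax = x.abs().amax(dim=-1, keepdim=True)

# shard along the last dim, as a sequence-parallel Shard(1) layout might;
# shard-local rowwise amaxes no longer match the full-tensor ones, so
# casting shards to float8 before the gather would bake in wrong scales
shard = x[:, :4]
shard_rowwise_amax = shard.abs().amax(dim=-1, keepdim=True)
print(torch.equal(shard_rowwise_amax, rowwise_amax))  # False in general
```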
```python
prepare_input = prepare_input_cls(
    input_layouts=Shard(1),
    desired_input_layouts=Replicate(),
    fwd_config_submodule_fqn="w2",
)
```
So this is saying we use the forward config from the FFN w2 linear layer to perform the fp8 conversion on the inputs? If so, why specifically w2?
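As background for this question, a hedged sketch of the pattern such an FQN argument addresses (module and attribute names below are my assumptions, not torchao internals): when one module input feeds several float8 linears, it can only be cast to float8 once, so the cast settings must be taken from a single named consumer.

```python
# Hypothetical sketch: one shared input feeding multiple float8 linears.
# The shared input is cast once, so exactly one consumer's config must
# be chosen; an FQN like "w2" names that consumer. Names are made up.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyFFN(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)  # consumes the same x as w1
        self.out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # x would be cast to float8 once and reused by both w1 and w2,
        # hence the need to pick one submodule's cast config up front
        return self.out(F.silu(self.w1(x)) * self.w2(x))

def pick_input_cast_config(module: nn.Module, fqn: str):
    # resolve the named submodule and read its (hypothetical) forward
    # input cast config to use for the shared input
    return getattr(module.get_submodule(fqn), "input_cast_config", None)
```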
```diff
@@ -169,7 +169,9 @@ def backward(ctx, grad_output):
     # workaround from https://github.com/pytorch/pytorch/issues/141881
     # to avoid saving float8 weight from forward to backward when
     # FSDP is on
-    weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0)
+    g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
```
nit: since this workaround is now different from the one in the GitHub issue referenced in the comment, it would be helpful to update the comment to explain how the modified workaround fixes the interaction with TP.
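For readers following along, a minimal sketch of the zero-valued fake-dependency trick being adjusted here (an illustration under assumed shapes, not the PR's exact code): the added term is computed from grad_output but always equals zero, so weight values are unchanged while the graph now sees the high-precision weight as depending on the gradient, which keeps the float8 weight from being saved from forward to backward.

```python
# Illustrative sketch of the fake-dependency workaround (assumed shapes,
# not the PR's exact code). The added term is exactly zero, so values
# are unchanged, but a data dependency on grad_output is created.
import torch

def add_fake_grad_dependency(weight_hp_t: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # zero the whole reshaped tensor first, then index; the diff reorders
    # the ops this way, apparently so the chain stays a plain
    # reshape/mul/index sequence that also remains valid when
    # grad_output is a sharded DTensor
    g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
    # adding a zero scalar broadcasts against any weight shape
    return weight_hp_t + g_reshaped[0, 0]
```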
Summary:

1. add a test for toy model + TP + float8 rowwise scaling training
2. fix underlying issues to make the test pass:
   a. add fast path for tensor view where the new shape is the same as old shape, for rowwise scaled float8 (this is needed for DTensor)
   b. modify the fake grad dependency workaround to work when grad is a DTensor

Test Plan:

1. ./test/float8/test_everything.sh (one transient failure: https://www.internalfb.com/phabricator/paste/view/P1733103301)
2. verified that float8 rowwise scaling behaves sanely in torchtitan on LLaMa 3 8B on 8 H100s, with tp 2:

```
// requires pytorch/torchtitan#808

// baseline - bfloat16 + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:41:16,175 - root - INFO - step: 40  loss: 7.4240  memory: 35.56GiB(37.43%)  tps: 1,669  mfu: 9.77%

// float8 baseline - float8 tensorwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:44:07,806 - root - INFO - step: 40  loss: 7.4993  memory: 35.57GiB(37.44%)  tps: 2,141  mfu: 12.54%

// float8 rowwise without zero fake dep (for sanity) + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:47:51,400 - root - INFO - step: 40  loss: 7.3472  memory: 35.55GiB(37.42%)  tps: 1,858  mfu: 10.88%

// float8 rowwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:51:20,864 - root - INFO - step: 40  loss: 9.4211  memory: 35.55GiB(37.42%)  tps: 1,820  mfu: 10.66%
```

Reviewers:

Subscribers:

Tasks:

Tags:
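Item 2a of the summary (the same-shape view fast path for rowwise-scaled float8) can be sketched roughly as follows; the function and argument names are mine, not torchao's:

```python
# Rough sketch of a same-shape view fast path (hypothetical names).
# Rowwise scales carry one entry per row, so an arbitrary view would
# need to re-derive how scales map onto the new shape; when the new
# shape equals the old one, data and scales can be reused directly,
# which is the case DTensor exercises.
import torch

def float8_rowwise_view(data: torch.Tensor, scale: torch.Tensor, new_shape):
    if torch.Size(new_shape) == data.shape:
        # fast path: shape is unchanged, existing rowwise scales stay valid
        return data.view(new_shape), scale
    # the general case must check that the view keeps rows intact before
    # reusing per-row scales; omitted in this sketch
    raise NotImplementedError("only the same-shape fast path is sketched")
```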
Force-pushed from 2ccf537 to f4adfb0 (Compare)