fix bug with float8 training + FSDP2 + TP (pytorch#1327)
Summary:

The combination of float8 training + FSDP2 + TP recently broke. This commit fixes it by:

1. adding a test case so this combination is covered in CI
2. fixing the bug, by ensuring we check for `Float8Tensor` properly when it is wrapped in `DTensor`

Note 1: most of the code in `distributed_utils.py` was dead code from before we switched to DTensor, so it is deleted in this PR.

Note 2: we already have extensive testing for FSDP2 and TP/SP in separate files. A new file is added for testing the two features together, to keep complexity and test runtime manageable.

Note 3: we really should make these distributed test cases run in CI; right now they are run locally only.

Note 4: there are a couple of interesting future follow-ups:

- in FSDP2 with float8 all-gather, perhaps we should return `DTensor(Float8Tensor)` instead of `Float8Tensor`, to stay consistent with how FSDP2 wraps weights without float8 all-gather
- in DTensor, it would be nice if `isinstance(t, Float8Tensor)` returned True when `t` is a `DTensor` wrapping a `Float8Tensor` (food for thought for composability); having this would let us simplify some of the float8 modeling code

Test Plan:

```
# tests added in this PR
./test/float8/test_dtensor.sh

# all tests
./test/float8/test_everything.sh

# torchtitan command fails before this PR and passes after
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --training.tensor_parallel_degree 2
```
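For context on the fix described above, a DTensor-aware `Float8Tensor` check can be sketched roughly as follows. This is a minimal illustration, not the exact code landed in the PR; the helper name `_is_float8_tensor` is hypothetical, and the import paths assume the torchao/PyTorch layout this commit targets.

```
import torch
from torch.distributed._tensor import DTensor

from torchao.float8.float8_tensor import Float8Tensor


def _is_float8_tensor(t: torch.Tensor) -> bool:
    # Hypothetical helper: a DTensor wrapping a Float8Tensor does not pass a
    # plain isinstance check, so unwrap to the local shard first.
    if isinstance(t, DTensor):
        t = t.to_local()
    return isinstance(t, Float8Tensor)
```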
1 parent 8acd8e7 · commit 05ed2e6
Showing 8 changed files with 177 additions and 144 deletions.
New file (121 lines added):

```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of manually defined float16 TP vs float8 TP of toy models

Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import copy
import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)
from torchao.testing.float8.dtensor_utils import ToyModel


def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (world_size // 2, 2),
        mesh_dim_names=("dp", "tp"),
    )
    # seed must be the same in all processes
    torch.manual_seed(1)
    return device_mesh


def _test_fp8_mlp_tensor_parallelism_base(
    mesh: DeviceMesh, size=16, compile: bool = False
):
    device = mesh.device_type

    config = Float8LinearConfig(
        emulate=True,
        enable_fsdp_float8_all_gather=True,
    )

    toy_model = ToyModel().to(device)

    tp_model = copy.deepcopy(toy_model)
    tp_model = convert_to_float8_training(tp_model, config=config)

    # apply TP
    tp_model = parallelize_module(
        tp_model,
        mesh["tp"],
        {
            "ffn.w1": Float8ColwiseParallel(),
            "ffn.w2": Float8ColwiseParallel(),
            "ffn.out_proj": Float8RowwiseParallel(),
        },
    )

    if compile:
        tp_model = torch.compile(tp_model)

    # apply FSDP
    fsdp_config = {"mesh": mesh["dp"]}
    tp_model = fully_shard(tp_model, **fsdp_config)

    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
    x_fp32_tp_input = x_fp32.clone()

    tp_out = tp_model(x_fp32_tp_input)
    tp_out.sum().backward()
    torch.cuda.synchronize()

    # TODO(future PR): test numerics, and add more cases


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)


if __name__ == "__main__":
    # float8 only works on CUDA H100 so we only test cuda and we follow
    # other test files to not use TestCase but instead just add the test
    # cases in the main func.
    device_mesh = setup_distributed()

    tests = [
        _test_fp8_mlp_tensor_parallelism_eager,
        _test_fp8_mlp_tensor_parallelism_compile,
    ]

    for test in tqdm(tests, desc="Running tests"):
        try:
            test(device_mesh)
        except Exception as e:
            print(f"Test {test.__name__} failed with error: {e}")
            raise e

    torch.distributed.destroy_process_group()
```
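Since this test drives multiple processes and builds a 2D ("dp", "tp") device mesh from `WORLD_SIZE`, it is meant to be launched via `torchrun` (in the PR it is exercised through `./test/float8/test_dtensor.sh`, per the Test Plan above). A sketch of a manual launch follows; the file path is illustrative, not the actual location in the repo:

```
# torchrun sets WORLD_SIZE/RANK for each process; 4 GPUs gives a 2x2 (dp, tp) mesh
torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py
```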